"""
Semantic drift detector for RetailMind.
Tracks the rolling semantic similarity of incoming user queries against
predefined *concept anchors* (e.g., price-sensitivity, seasonal shift,
eco-trend). When the exponentially-weighted moving average for any concept
exceeds a configurable threshold the system flags an active drift — which
triggers the self-healing adapter to rewrite the LLM system prompt.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from typing import Any
import numpy as np
from modules.shared import get_embedding_model
logger = logging.getLogger(__name__)
@dataclass(frozen=True)
class DriftEvent:
    """Immutable record of a single drift measurement."""

    # Unix timestamp (seconds since epoch) when the measurement was taken.
    timestamp: float
    # The raw user query that was scored.
    query: str
    # Raw similarity score per concept anchor.
    scores: dict[str, float]
    # Concept with the strongest smoothed signal, or "normal".
    dominant: str
@dataclass
class DriftDetector:
    """
    Monitors semantic drift across configurable concept anchors.

    Each anchor is the mean embedding of several representative phrases.
    Incoming queries are cosine-scored against every anchor, and the raw
    scores are smoothed with an EWMA (exponentially weighted moving
    average) so a single noisy query does not trigger a drift flag.
    """

    # EWMA level a concept must exceed to be reported as the dominant drift.
    threshold: float = 0.38
    # Smoothing factor (higher = more reactive to the latest query).
    ewma_alpha: float = 0.35
    # Rolling log of per-query measurements (capped at 200 entries).
    history: list[DriftEvent] = field(default_factory=list)
    # Current smoothed score per concept, keyed by concept name.
    _ewma: dict[str, float] = field(default_factory=dict)
    # Mean anchor embedding per concept (numpy arrays; excluded from repr).
    _concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)

    def __post_init__(self) -> None:
        """Encode the anchor phrases once and zero the per-concept EWMA state."""
        model = get_embedding_model()
        # Multiple anchor phrases per concept — the averaged embedding is
        # more robust than any single phrase.
        concept_phrases = {
            "price_sensitive": [
                "cheap budget discount low price clearance sale savings affordable",
                "what is the cheapest option under twenty dollars bargain deal",
                "I only have a limited budget, show me value picks",
            ],
            "summer_shift": [
                "summer heat warm weather sandals shorts sunscreen beach",
                "lightweight breathable sun protection hot climate UV",
                "vacation tropical poolside outdoor warm temperature",
            ],
            "eco_trend": [
                "eco-friendly sustainable organic recycled environment green",
                "plant-based carbon-neutral zero waste biodegradable vegan",
                "responsible sourcing ethical production renewable materials",
            ],
        }
        for concept, phrases in concept_phrases.items():
            embs = model.encode(phrases, show_progress_bar=False)
            self._concept_embs[concept] = np.mean(embs, axis=0)
            self._ewma[concept] = 0.0
        logger.info("DriftDetector initialized with %d concept anchors.", len(concept_phrases))

    # ── Public API ──────────────────────────────────────────────────────────
    def analyze_drift(
        self, query: str, query_emb: Any | None = None
    ) -> tuple[str, dict[str, float]]:
        """
        Score *query* against all concept anchors and return
        ``(dominant_concept, raw_scores)``.

        ``dominant_concept`` is ``"normal"`` unless some concept's smoothed
        EWMA exceeds ``self.threshold``.  Pass *query_emb* to skip
        re-encoding when the caller already has the embedding.
        """
        if query_emb is None:
            query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]

        raw_scores: dict[str, float] = {}
        # The query norm is loop-invariant; compute it once instead of
        # once per concept.
        query_norm = np.linalg.norm(query_emb)
        for concept, ref_emb in self._concept_embs.items():
            # Cosine similarity; the epsilon guards against a zero-norm vector.
            sim = float(
                np.dot(query_emb, ref_emb)
                / (query_norm * np.linalg.norm(ref_emb) + 1e-10)
            )
            raw_scores[concept] = sim
            # Update the EWMA trend for this concept.
            prev = self._ewma[concept]
            self._ewma[concept] = self.ewma_alpha * sim + (1 - self.ewma_alpha) * prev

        # Dominant drift = concept with the highest smoothed score that also
        # clears the threshold; otherwise report "normal".
        detected = "normal"
        max_smoothed = 0.0
        for concept, smoothed in self._ewma.items():
            if smoothed > self.threshold and smoothed > max_smoothed:
                max_smoothed = smoothed
                detected = concept

        event = DriftEvent(
            timestamp=time.time(),
            query=query,
            scores=raw_scores,
            dominant=detected,
        )
        self.history.append(event)
        # Bound memory use: keep only the most recent 200 events.
        if len(self.history) > 200:
            self.history = self.history[-200:]

        logger.debug("Drift analysis: %s | scores=%s | ewma=%s", detected, raw_scores, self._ewma)
        return detected, raw_scores

    def get_ewma_scores(self) -> dict[str, float]:
        """Return a copy of the current EWMA-smoothed scores for dashboards."""
        return dict(self._ewma)

    def get_recent_stats(self, window: int = 5) -> dict[str, float] | None:
        """
        Return per-concept raw scores averaged over the last *window*
        queries, or ``None`` when no queries have been analyzed yet.
        """
        if not self.history:
            return None
        recent = self.history[-window:]
        return {
            c: float(np.mean([e.scores[c] for e in recent]))
            for c in self._concept_embs
        }

    def get_history_series(self) -> dict[str, list[float]]:
        """
        Return the full EWMA time-series for each concept (for charts).

        Pads with baseline values when fewer than 5 real events exist so the
        chart renders a smooth baseline line on first load.
        """
        series: dict[str, list[float]] = {c: [] for c in self._concept_embs}
        ewma_state = {c: 0.0 for c in self._concept_embs}
        # Pad with a neutral baseline so the chart always has something to show.
        padding = max(0, 5 - len(self.history))
        for _ in range(padding):
            for c in self._concept_embs:
                series[c].append(0.15)
        # Replay the recorded raw scores through the same EWMA recurrence
        # used live in analyze_drift (state starts from zero, matching
        # __post_init__).
        for event in self.history:
            for c in self._concept_embs:
                ewma_state[c] = (
                    self.ewma_alpha * event.scores[c]
                    + (1 - self.ewma_alpha) * ewma_state[c]
                )
                series[c].append(ewma_state[c])
        return series