File size: 5,957 Bytes
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d624b44
f69e608
d624b44
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d624b44
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d624b44
 
 
f69e608
 
 
d624b44
 
f69e608
d624b44
 
f69e608
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
778855c
 
 
 
 
f69e608
 
778855c
 
 
 
 
 
 
f69e608
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""
Semantic drift detector for RetailMind.

Tracks the rolling semantic similarity of incoming user queries against
predefined *concept anchors* (e.g., price-sensitivity, seasonal shift,
eco-trend).  When the exponentially-weighted moving average for any concept
exceeds a configurable threshold, the system flags an active drift — which
triggers the self-healing adapter to rewrite the LLM system prompt.
"""

from __future__ import annotations

import logging
import time
from dataclasses import dataclass, field
from typing import Any

import numpy as np

from modules.shared import get_embedding_model

logger = logging.getLogger(__name__)


@dataclass
class DriftEvent:
    """Record of a single drift measurement produced by ``DriftDetector.analyze_drift``.

    Note: the dataclass is not frozen, so fields can technically be
    reassigned; treat instances as read-only records.
    """

    timestamp: float  # wall-clock time of the measurement (``time.time()`` seconds)
    query: str  # the raw query text that was scored
    scores: dict[str, float]  # concept name -> raw (un-smoothed) cosine similarity
    dominant: str  # concept whose EWMA exceeded the threshold, or "normal"


@dataclass
class DriftDetector:
    """
    Monitors semantic drift across configurable concept anchors.

    Uses EWMA (exponentially weighted moving average) to smooth noisy
    single-query scores into stable trend signals.

    Each query is scored by cosine similarity against one averaged anchor
    embedding per concept; the per-concept EWMA of those scores is compared
    against ``threshold`` to decide whether a drift is active.  Only the most
    recent ``max_history`` events are retained in ``history``.
    """

    threshold: float = 0.38
    ewma_alpha: float = 0.35          # smoothing factor (higher = more reactive)
    history: list[DriftEvent] = field(default_factory=list)
    _ewma: dict[str, float] = field(default_factory=dict)
    _concept_embs: dict[str, Any] = field(default_factory=dict, repr=False)
    # Maximum number of events kept in ``history``; older ones are dropped.
    max_history: int = 200
    # EWMA state accumulated over events already trimmed from ``history``.
    # get_history_series() replays the retained window starting from here
    # (rather than from zero) so its series stays consistent with ``_ewma``.
    # NOTE(review): if external code resets ``history``, reset this too.
    _series_seed: dict[str, float] = field(default_factory=dict, init=False, repr=False)

    def __post_init__(self) -> None:
        """Encode the concept anchor phrases and zero the EWMA state."""
        model = get_embedding_model()
        # Multiple anchor phrases per concept -> averaged embedding for robustness
        concept_phrases = {
            "price_sensitive": [
                "cheap budget discount low price clearance sale savings affordable",
                "what is the cheapest option under twenty dollars bargain deal",
                "I only have a limited budget, show me value picks",
            ],
            "summer_shift": [
                "summer heat warm weather sandals shorts sunscreen beach",
                "lightweight breathable sun protection hot climate UV",
                "vacation tropical poolside outdoor warm temperature",
            ],
            "eco_trend": [
                "eco-friendly sustainable organic recycled environment green",
                "plant-based carbon-neutral zero waste biodegradable vegan",
                "responsible sourcing ethical production renewable materials",
            ],
        }
        for concept, phrases in concept_phrases.items():
            embs = model.encode(phrases, show_progress_bar=False)
            self._concept_embs[concept] = np.mean(embs, axis=0)
            self._ewma[concept] = 0.0
            self._series_seed[concept] = 0.0

        logger.info("DriftDetector initialized with %d concept anchors.", len(concept_phrases))

    # ── Public API ──────────────────────────────────────────────────────────

    def analyze_drift(
        self, query: str, query_emb: np.ndarray | None = None
    ) -> tuple[str, dict[str, float]]:
        """
        Score *query* against all concept anchors and return
        ``(dominant_concept, raw_scores)``.

        Pass *query_emb* to skip re-encoding when the caller already has it.

        Side effects: updates the per-concept EWMA state and appends a
        ``DriftEvent`` to ``history`` (trimmed to ``max_history`` entries).

        Args:
            query: Raw query text; it is also recorded in the event.
            query_emb: Optional precomputed embedding of *query*.

        Returns:
            ``dominant_concept`` is the concept with the highest EWMA that is
            strictly above ``threshold`` (and above 0.0), or ``"normal"`` if
            none qualifies.  ``raw_scores`` maps each concept to this query's
            un-smoothed cosine similarity.
        """
        if query_emb is None:
            query_emb = get_embedding_model().encode([query], show_progress_bar=False)[0]

        # The query norm does not depend on the concept; compute it once.
        query_norm = np.linalg.norm(query_emb)
        raw_scores: dict[str, float] = {}
        for concept, ref_emb in self._concept_embs.items():
            # Cosine similarity; the epsilon guards against a zero norm.
            sim = float(
                np.dot(query_emb, ref_emb)
                / (query_norm * np.linalg.norm(ref_emb) + 1e-10)
            )
            raw_scores[concept] = sim

            # Update EWMA
            prev = self._ewma[concept]
            self._ewma[concept] = self.ewma_alpha * sim + (1 - self.ewma_alpha) * prev

        # Determine dominant drift from the smoothed signal: the highest EWMA
        # that clears the threshold (and 0.0) wins.
        detected = "normal"
        max_smoothed = 0.0
        for concept, smoothed in self._ewma.items():
            if smoothed > self.threshold and smoothed > max_smoothed:
                max_smoothed = smoothed
                detected = concept

        event = DriftEvent(
            timestamp=time.time(),
            query=query,
            scores=raw_scores,
            dominant=detected,
        )
        self.history.append(event)
        self._trim_history()

        logger.debug("Drift analysis: %s | scores=%s | ewma=%s", detected, raw_scores, self._ewma)
        return detected, raw_scores

    def get_ewma_scores(self) -> dict[str, float]:
        """Return current EWMA-smoothed scores for dashboard display."""
        return dict(self._ewma)

    def get_recent_stats(self, n: int = 5) -> dict[str, float] | None:
        """Return the mean raw score per concept over the last *n* queries.

        Returns None when no queries have been analysed yet.

        Raises:
            ValueError: if *n* is less than 1.
        """
        if n < 1:
            # history[-0:] would silently select the whole history.
            raise ValueError(f"n must be >= 1, got {n}")
        if not self.history:
            return None
        recent = self.history[-n:]
        return {
            c: float(np.mean([e.scores[c] for e in recent]))
            for c in self._concept_embs
        }

    def get_history_series(self) -> dict[str, list[float]]:
        """Return the EWMA time-series over the retained history, per concept (for charts).

        The replay starts from the EWMA state accumulated over events already
        trimmed from ``history``. As a result, the last point of each series
        matches get_ewma_scores() even after trimming, provided
        ``ewma_alpha`` has not changed in the meantime.

        Pads with baseline values when fewer than 5 real events exist so the
        chart renders a smooth baseline line on first load.
        """
        series: dict[str, list[float]] = {c: [] for c in self._concept_embs}
        ewma_state = {c: self._series_seed.get(c, 0.0) for c in self._concept_embs}

        # Pad with neutral baseline so chart always has something to show
        padding = max(0, 5 - len(self.history))
        for _ in range(padding):
            for c in self._concept_embs:
                series[c].append(0.15)

        for event in self.history:
            for c in self._concept_embs:
                ewma_state[c] = self.ewma_alpha * event.scores[c] + (1 - self.ewma_alpha) * ewma_state[c]
                series[c].append(ewma_state[c])
        return series

    # ── Internals ───────────────────────────────────────────────────────────

    def _trim_history(self) -> None:
        """Drop events beyond ``max_history``, folding them into ``_series_seed``.

        Dropped events are run through the same EWMA recurrence used in
        analyze_drift. Replaying the retained window from the seed therefore
        reproduces the live EWMA values.
        """
        excess = len(self.history) - self.max_history
        if excess <= 0:
            return
        alpha = self.ewma_alpha
        for event in self.history[:excess]:
            for c in self._concept_embs:
                prev = self._series_seed.get(c, 0.0)
                self._series_seed[c] = alpha * event.scores[c] + (1 - alpha) * prev
        # Rebind (as before) rather than mutate in place.
        self.history = self.history[excess:]