"""
Unit tests for the drift detection module.
"""

import pytest
from modules.drift import DriftDetector


@pytest.fixture
def detector():
    return DriftDetector()


class TestDriftDetector:
    """Tests for semantic drift detection."""

    def test_normal_query_no_drift(self, detector):
        drift, scores = detector.analyze_drift("I need a good water bottle.")
        assert drift == "normal", f"Expected 'normal', got '{drift}'"
        assert all(isinstance(v, float) for v in scores.values())

    def test_price_sensitive_detection(self, detector):
        # Feed multiple budget-oriented queries to build up EWMA
        for q in ["cheapest option", "budget under $20", "show me the cheapest"]:
            drift, _ = detector.analyze_drift(q)
        assert drift == "price_sensitive", f"Expected 'price_sensitive' after budget queries, got '{drift}'"

    def test_eco_trend_detection(self, detector):
        for q in ["sustainable organic products", "eco-friendly recycled", "I want plant-based items"]:
            drift, _ = detector.analyze_drift(q)
        assert drift == "eco_trend", f"Expected 'eco_trend' after eco queries, got '{drift}'"

    def test_summer_shift_detection(self, detector):
        for q in ["summer beach sandals", "hot weather lightweight", "UV protection for sun"]:
            drift, _ = detector.analyze_drift(q)
        assert drift == "summer_shift", f"Expected 'summer_shift' after summer queries, got '{drift}'"

    def test_scores_have_all_concepts(self, detector):
        _, scores = detector.analyze_drift("test query")
        expected = {"price_sensitive", "summer_shift", "eco_trend"}
        assert set(scores.keys()) == expected

    def test_history_accumulates(self, detector):
        for i in range(5):
            detector.analyze_drift(f"query {i}")
        assert len(detector.history) == 5

    def test_ewma_scores_available(self, detector):
        detector.analyze_drift("some query")
        ewma = detector.get_ewma_scores()
        assert isinstance(ewma, dict)
        assert len(ewma) == 3

    def test_history_series_length_matches(self, detector):
        for i in range(10):
            detector.analyze_drift(f"query {i}")
        series = detector.get_history_series()
        for concept, data in series.items():
            assert len(data) == 10, f"{concept} series length {len(data)} != 10"
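

# Illustrative sketch only, not DriftDetector's actual implementation: the
# detection tests above feed several queries in a row because an exponentially
# weighted moving average (EWMA) approaches a sustained signal gradually
# rather than jumping to it. The helper name `ewma_update`, the smoothing
# factor of 0.5, and the sample scores are all assumptions made for this
# example; the real module may use different names and values.
def ewma_update(previous, observation, alpha=0.5):
    """Blend a new observation into a running EWMA value."""
    return alpha * observation + (1 - alpha) * previous


def test_ewma_sketch_builds_up_over_repeated_signals():
    value = 0.0
    trace = []
    for score in [1.0, 1.0, 1.0]:  # three consecutive strong signals
        value = ewma_update(value, score)
        trace.append(value)
    # Each step moves halfway from the current value toward 1.0.
    assert trace == [0.5, 0.75, 0.875]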