"""
Unit tests for the drift detection module (``modules.drift.DriftDetector``).
"""
import pytest
from modules.drift import DriftDetector
@pytest.fixture
def detector() -> DriftDetector:
    """Build a brand-new DriftDetector for the requesting test.

    The fixture uses pytest's default (function) scope, so no EWMA or
    history state leaks from one test into another.
    """
    fresh_detector = DriftDetector()
    return fresh_detector
class TestDriftDetector:
    """Tests for semantic drift detection.

    Every test takes a ``detector`` argument supplied by the module-level
    ``detector`` fixture.
    """

    # Concept labels the detector reports scores for. Both the key-set check
    # and the EWMA size check read this one constant, so they cannot disagree.
    CONCEPTS = frozenset({"price_sensitive", "summer_shift", "eco_trend"})

    @staticmethod
    def _final_drift(detector, queries):
        """Feed *queries* to *detector* in order and return the last drift label.

        Several queries are fed so the detector's EWMA can build up before the
        label is checked. Returns ``None`` if *queries* is empty, instead of
        leaving the result unbound.
        """
        drift = None
        for query in queries:
            drift, _ = detector.analyze_drift(query)
        return drift

    def test_normal_query_no_drift(self, detector):
        drift, scores = detector.analyze_drift("I need a good water bottle.")
        assert drift == "normal", f"Expected 'normal', got '{drift}'"
        # all() over an empty mapping is True, so make sure there is something to check.
        assert scores, "analyze_drift returned no scores"
        assert all(isinstance(v, float) for v in scores.values())

    def test_price_sensitive_detection(self, detector):
        # Feed multiple budget-oriented queries to build up EWMA
        drift = self._final_drift(
            detector, ["cheapest option", "budget under $20", "show me the cheapest"]
        )
        assert drift == "price_sensitive", f"Expected 'price_sensitive' after budget queries, got '{drift}'"

    def test_eco_trend_detection(self, detector):
        drift = self._final_drift(
            detector,
            ["sustainable organic products", "eco-friendly recycled", "I want plant-based items"],
        )
        assert drift == "eco_trend", f"Expected 'eco_trend' after eco queries, got '{drift}'"

    def test_summer_shift_detection(self, detector):
        drift = self._final_drift(
            detector, ["summer beach sandals", "hot weather lightweight", "UV protection for sun"]
        )
        assert drift == "summer_shift", f"Expected 'summer_shift' after summer queries, got '{drift}'"

    def test_scores_have_all_concepts(self, detector):
        _, scores = detector.analyze_drift("test query")
        assert set(scores.keys()) == self.CONCEPTS

    def test_history_accumulates(self, detector):
        n_queries = 5
        for i in range(n_queries):
            detector.analyze_drift(f"query {i}")
        assert len(detector.history) == n_queries

    def test_ewma_scores_available(self, detector):
        detector.analyze_drift("some query")
        ewma = detector.get_ewma_scores()
        assert isinstance(ewma, dict)
        # One smoothed value per tracked concept.
        assert len(ewma) == len(self.CONCEPTS)

    def test_history_series_length_matches(self, detector):
        n_queries = 10
        for i in range(n_queries):
            detector.analyze_drift(f"query {i}")
        series = detector.get_history_series()
        # An empty mapping would skip the loop below and pass vacuously.
        assert series, "get_history_series returned no concepts"
        for concept, data in series.items():
            assert len(data) == n_queries, f"{concept} series length {len(data)} != {n_queries}"
|