Spaces:
Build error
Build error
| """Tests for FusionLayer, WeightCalculator, and AutoTurnaround.""" | |
| from __future__ import annotations | |
| import time | |
| import pytest | |
| from unittest.mock import AsyncMock, MagicMock | |
| from memorybridge.core.schemas import ( | |
| AffectSignal, | |
| GestureSignal, | |
| ModalityWeight, | |
| SignalBufferState, | |
| VisionSnapshot, | |
| ) | |
| from memorybridge.fusion.weight_calculator import WeightCalculator | |
| from memorybridge.fusion.auto_turnaround import AutoTurnaround | |
class TestWeightCalculator:
    """Unit tests for WeightCalculator normalization and per-modality weighting."""

    def test_normalize_uniform(self):
        """Equal raw weights normalize to an even split that sums to 1."""
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.5, "b": 0.5}
        normalized = calc._normalize(weights)
        assert abs(sum(normalized.values()) - 1.0) < 1e-6
        # The sum check alone would also accept a lopsided split such as
        # {"a": 1.0, "b": 0.0}; pin each value so "uniform" is really tested.
        assert abs(normalized["a"] - 0.5) < 1e-6
        assert abs(normalized["b"] - 0.5) < 1e-6

    def test_normalize_zero_total_returns_all_zeros(self):
        """Zero total means no modality cleared the threshold, so all weights must be 0 (not uniform)."""
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.0, "b": 0.0}
        normalized = calc._normalize(weights)
        assert normalized["a"] == 0.0
        assert normalized["b"] == 0.0

    def test_compute_all_below_threshold_excluded(self):
        """When all confidences are below threshold, all ModalityWeights have included=False."""
        calc = WeightCalculator(threshold=0.3)
        result = calc.compute({
            "text_retrieval": 0.0, "kg_facts": 0.0, "gesture": 0.0,
            "affect": 0.0, "air_sign": 0.0,
        })
        assert isinstance(result, list)
        assert all(isinstance(w, ModalityWeight) for w in result)
        assert all(not w.included for w in result)
        assert all(w.weight == 0.0 for w in result)

    def test_compute_single_modality_above_threshold(self):
        """Single modality above threshold gets weight=1.0; every other modality is excluded."""
        calc = WeightCalculator(threshold=0.3)
        result = calc.compute({
            "text_retrieval": 0.8, "kg_facts": 0.0, "gesture": 0.0,
            "affect": 0.0, "air_sign": 0.0,
        })
        by_name = {w.modality: w for w in result}
        assert by_name["text_retrieval"].included is True
        assert abs(by_name["text_retrieval"].weight - 1.0) < 1e-6
        assert by_name["gesture"].included is False
        # Every other modality was fed 0.0 confidence (below the 0.3 threshold),
        # so none may be included or carry weight -- not just "gesture".
        others = [w for w in result if w.modality != "text_retrieval"]
        assert all(not w.included for w in others)
        assert all(w.weight == 0.0 for w in others)
class TestAutoTurnaround:
    """Unit tests for AutoTurnaround's dissatisfaction detectors."""

    def _make_frustrated_buffer(self, confidence: float = 0.9, n_windows: int = 2) -> SignalBufferState:
        """Build a buffer of ``n_windows`` snapshots, each carrying a "frustrated" affect.

        Snapshot ``k`` is stamped ``k`` seconds after the current time, and every
        affect signal uses the given ``confidence``.
        """
        base = time.time()
        frames = []
        for offset in range(n_windows):
            stamp = base + offset
            mood = AffectSignal(affect_class="frustrated", confidence=confidence, timestamp=stamp)
            frames.append(VisionSnapshot(timestamp=stamp, affect=mood))
        return SignalBufferState(snapshots=frames)

    def test_check_empty_buffer_returns_false(self):
        """An empty buffer carries no dissatisfaction signal, so check() -> False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        empty = SignalBufferState()
        assert detector.check(empty) is False

    def test_check_frustrated_affect_triggers(self):
        """Two consecutive frustrated snapshots at 0.9 confidence with threshold 0.85 -> True."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        frustrated = self._make_frustrated_buffer(confidence=0.9, n_windows=2)
        outcome = detector._check_frustrated_affect(frustrated)
        assert outcome is True

    def test_check_frustrated_affect_below_threshold(self):
        """Frustrated affect whose confidence is under the threshold -> False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        lukewarm = self._make_frustrated_buffer(confidence=0.5, n_windows=2)
        outcome = detector._check_frustrated_affect(lukewarm)
        assert outcome is False

    def test_check_thumbs_down_empty_buffer_returns_false(self):
        """With no snapshots at all, _check_thumbs_down -> False."""
        detector = AutoTurnaround()
        empty = SignalBufferState()
        assert detector._check_thumbs_down(empty) is False

    def test_check_thumbs_down_triggers(self):
        """A thumbs_down above the confidence threshold in the latest snapshot -> True."""
        now = time.time()
        gesture = GestureSignal(gesture_class="thumbs_down", confidence=0.95, timestamp=now)
        latest = VisionSnapshot(timestamp=now, gestures=[gesture])
        state = SignalBufferState(snapshots=[latest])
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        outcome = detector._check_thumbs_down(state)
        assert outcome is True