"""Tests for FusionLayer, WeightCalculator, and AutoTurnaround."""

from __future__ import annotations

import time
from unittest.mock import AsyncMock, MagicMock

import pytest

from memorybridge.core.schemas import (
    AffectSignal,
    GestureSignal,
    ModalityWeight,
    SignalBufferState,
    VisionSnapshot,
)
from memorybridge.fusion.weight_calculator import WeightCalculator
from memorybridge.fusion.auto_turnaround import AutoTurnaround


class TestWeightCalculator:
    """Unit tests for ``WeightCalculator`` normalization and thresholding."""

    def test_normalize_uniform(self):
        """Non-zero input weights are rescaled so that they sum to 1."""
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.5, "b": 0.5}
        normalized = calc._normalize(weights)
        assert sum(normalized.values()) == pytest.approx(1.0, abs=1e-6)

    def test_normalize_zero_total_returns_all_zeros(self):
        """Zero total means no modality cleared threshold — all weights must be 0, not uniform."""
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.0, "b": 0.0}
        normalized = calc._normalize(weights)
        assert normalized["a"] == 0.0
        assert normalized["b"] == 0.0

    def test_compute_all_below_threshold_excluded(self):
        """When all confidences are below threshold, all ModalityWeights have included=False."""
        calc = WeightCalculator(threshold=0.3)
        result = calc.compute({
            "text_retrieval": 0.0,
            "kg_facts": 0.0,
            "gesture": 0.0,
            "affect": 0.0,
            "air_sign": 0.0,
        })
        assert isinstance(result, list)
        assert all(isinstance(w, ModalityWeight) for w in result)
        assert all(not w.included for w in result)
        assert all(w.weight == 0.0 for w in result)

    def test_compute_single_modality_above_threshold(self):
        """Single modality above threshold gets weight=1.0."""
        calc = WeightCalculator(threshold=0.3)
        result = calc.compute({
            "text_retrieval": 0.8,
            "kg_facts": 0.0,
            "gesture": 0.0,
            "affect": 0.0,
            "air_sign": 0.0,
        })
        by_name = {w.modality: w for w in result}
        assert by_name["text_retrieval"].included is True
        assert by_name["text_retrieval"].weight == pytest.approx(1.0, abs=1e-6)
        assert by_name["gesture"].included is False


class TestAutoTurnaround:
    """Unit tests for ``AutoTurnaround`` dissatisfaction detection."""

    def _make_frustrated_buffer(
        self, confidence: float = 0.9, n_windows: int = 2
    ) -> SignalBufferState:
        """Build a buffer of ``n_windows`` snapshots, each with a "frustrated" affect.

        Every snapshot's affect carries the given ``confidence``; snapshot
        timestamps start at the current wall-clock time and advance by 1 each.
        """
        ts = time.time()
        snapshots = [
            VisionSnapshot(
                timestamp=ts + i,
                affect=AffectSignal(
                    affect_class="frustrated",
                    confidence=confidence,
                    timestamp=ts + i,
                ),
            )
            for i in range(n_windows)
        ]
        return SignalBufferState(snapshots=snapshots)

    def test_check_empty_buffer_returns_false(self):
        """Empty buffer → no dissatisfaction signal → False."""
        turnaround = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        assert turnaround.check(SignalBufferState()) is False

    def test_check_frustrated_affect_triggers(self):
        """2 consecutive frustrated snapshots at 0.9 conf with threshold 0.85 → True."""
        turnaround = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        buffer = self._make_frustrated_buffer(confidence=0.9, n_windows=2)
        assert turnaround._check_frustrated_affect(buffer) is True

    def test_check_frustrated_affect_below_threshold(self):
        """Frustrated affect below confidence threshold → False."""
        turnaround = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        buffer = self._make_frustrated_buffer(confidence=0.5, n_windows=2)
        assert turnaround._check_frustrated_affect(buffer) is False

    def test_check_thumbs_down_empty_buffer_returns_false(self):
        """No snapshots → _check_thumbs_down returns False."""
        turnaround = AutoTurnaround()
        assert turnaround._check_thumbs_down(SignalBufferState()) is False

    def test_check_thumbs_down_triggers(self):
        """thumbs_down with confidence above threshold in latest snapshot → True."""
        ts = time.time()
        snap = VisionSnapshot(
            timestamp=ts,
            gestures=[
                GestureSignal(gesture_class="thumbs_down", confidence=0.95, timestamp=ts)
            ],
        )
        buffer = SignalBufferState(snapshots=[snap])
        turnaround = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        assert turnaround._check_thumbs_down(buffer) is True