Spaces:
Build error
Build error
File size: 4,422 Bytes
"""Tests for FusionLayer, WeightCalculator, and AutoTurnaround."""
from __future__ import annotations
import time
import pytest
from unittest.mock import AsyncMock, MagicMock
from memorybridge.core.schemas import (
AffectSignal,
GestureSignal,
ModalityWeight,
SignalBufferState,
VisionSnapshot,
)
from memorybridge.fusion.weight_calculator import WeightCalculator
from memorybridge.fusion.auto_turnaround import AutoTurnaround
class TestWeightCalculator:
    """Unit tests for WeightCalculator normalisation and threshold gating."""

    def test_normalize_uniform(self):
        """Equal raw weights normalise to equal shares that sum to 1.

        Checking only the total would also accept a skewed split such as
        {"a": 1.0, "b": 0.0}; the test name promises a uniform result, so each
        share is pinned to 0.5 as well.
        """
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.5, "b": 0.5}
        normalized = calc._normalize(weights)
        assert abs(sum(normalized.values()) - 1.0) < 1e-6
        assert abs(normalized["a"] - 0.5) < 1e-6
        assert abs(normalized["b"] - 0.5) < 1e-6

    def test_normalize_zero_total_returns_all_zeros(self):
        """Zero total means no modality cleared threshold — all weights must be 0, not uniform."""
        calc = WeightCalculator(threshold=0.3)
        weights = {"a": 0.0, "b": 0.0}
        normalized = calc._normalize(weights)
        assert normalized["a"] == 0.0
        assert normalized["b"] == 0.0

    def test_compute_all_below_threshold_excluded(self):
        """When all confidences are below threshold, all ModalityWeights have included=False."""
        calc = WeightCalculator(threshold=0.3)
        confidences = {
            "text_retrieval": 0.0, "kg_facts": 0.0, "gesture": 0.0,
            "affect": 0.0, "air_sign": 0.0,
        }
        result = calc.compute(confidences)
        assert isinstance(result, list)
        # all() is True for an empty list, so without this guard an empty
        # result would satisfy every check below.
        assert result, "compute() returned no ModalityWeights"
        assert {w.modality for w in result} >= set(confidences)
        assert all(isinstance(w, ModalityWeight) for w in result)
        assert all(not w.included for w in result)
        assert all(w.weight == 0.0 for w in result)

    def test_compute_single_modality_above_threshold(self):
        """Single modality above threshold gets weight=1.0; excluded ones get 0.0."""
        calc = WeightCalculator(threshold=0.3)
        result = calc.compute({
            "text_retrieval": 0.8, "kg_facts": 0.0, "gesture": 0.0,
            "affect": 0.0, "air_sign": 0.0,
        })
        by_name = {w.modality: w for w in result}
        assert by_name["text_retrieval"].included is True
        assert abs(by_name["text_retrieval"].weight - 1.0) < 1e-6
        assert by_name["gesture"].included is False
        # Excluded modalities carry zero weight, as in the all-below-threshold case.
        assert by_name["gesture"].weight == 0.0
class TestAutoTurnaround:
    """Unit tests for AutoTurnaround's frustrated-affect and thumbs-down checks."""

    def _make_frustrated_buffer(self, confidence: float = 0.9, n_windows: int = 2) -> SignalBufferState:
        """Return a buffer holding ``n_windows`` snapshots tagged "frustrated".

        Snapshot timestamps start at the current ``time.time()`` and advance by
        one per window; each snapshot's affect signal shares that timestamp and
        uses the supplied ``confidence``.
        """
        base = time.time()
        frames = []
        for offset in range(n_windows):
            stamp = base + offset
            mood = AffectSignal(affect_class="frustrated", confidence=confidence, timestamp=stamp)
            frames.append(VisionSnapshot(timestamp=stamp, affect=mood))
        return SignalBufferState(snapshots=frames)

    def test_check_empty_buffer_returns_false(self):
        """An empty buffer carries no dissatisfaction signal, so check() is False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        empty = SignalBufferState()
        assert detector.check(empty) is False

    def test_check_frustrated_affect_triggers(self):
        """Two frustrated windows at 0.9 confidence against a 0.85 threshold -> True."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        frustrated = self._make_frustrated_buffer(confidence=0.9, n_windows=2)
        assert detector._check_frustrated_affect(frustrated) is True

    def test_check_frustrated_affect_below_threshold(self):
        """Two frustrated windows at 0.5 confidence against a 0.85 threshold -> False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        mild = self._make_frustrated_buffer(confidence=0.5, n_windows=2)
        assert detector._check_frustrated_affect(mild) is False

    def test_check_thumbs_down_empty_buffer_returns_false(self):
        """With default settings, an empty buffer makes _check_thumbs_down() False."""
        assert AutoTurnaround()._check_thumbs_down(SignalBufferState()) is False

    def test_check_thumbs_down_triggers(self):
        """A 0.95-confidence thumbs_down in the latest snapshot, threshold 0.85 -> True."""
        now = time.time()
        thumbs_down = GestureSignal(gesture_class="thumbs_down", confidence=0.95, timestamp=now)
        latest = VisionSnapshot(timestamp=now, gestures=[thumbs_down])
        single = SignalBufferState(snapshots=[latest])
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        assert detector._check_thumbs_down(single) is True
|