# memorybridge/tests/test_fusion_layer.py
# Uploaded by kimandrew927 — commit 1004967 ("Initial Space deployment")
"""Tests for FusionLayer, WeightCalculator, and AutoTurnaround."""
from __future__ import annotations
import time
import pytest
from unittest.mock import AsyncMock, MagicMock
from memorybridge.core.schemas import (
AffectSignal,
GestureSignal,
ModalityWeight,
SignalBufferState,
VisionSnapshot,
)
from memorybridge.fusion.weight_calculator import WeightCalculator
from memorybridge.fusion.auto_turnaround import AutoTurnaround
class TestWeightCalculator:
    """Unit tests for WeightCalculator normalisation and threshold gating."""

    # Modality names fed to compute() throughout these tests.
    _MODALITIES = ("text_retrieval", "kg_facts", "gesture", "affect", "air_sign")

    @classmethod
    def _confidences(cls, **overrides: float) -> dict[str, float]:
        """Return a confidence map with every modality at 0.0 except *overrides*."""
        confidences = dict.fromkeys(cls._MODALITIES, 0.0)
        confidences.update(overrides)
        return confidences

    def test_normalize_uniform(self):
        """Equal raw weights normalise to equal shares that sum to 1."""
        calc = WeightCalculator(threshold=0.3)
        normalized = calc._normalize({"a": 0.5, "b": 0.5})
        assert sum(normalized.values()) == pytest.approx(1.0)
        assert normalized == pytest.approx({"a": 0.5, "b": 0.5})

    def test_normalize_zero_total_returns_all_zeros(self):
        """Zero total means no modality cleared threshold — all weights must be 0, not uniform."""
        calc = WeightCalculator(threshold=0.3)
        normalized = calc._normalize({"a": 0.0, "b": 0.0})
        assert normalized["a"] == 0.0
        assert normalized["b"] == 0.0

    def test_compute_all_below_threshold_excluded(self):
        """When all confidences are below threshold, all ModalityWeights have included=False."""
        calc = WeightCalculator(threshold=0.3)
        confidences = self._confidences()
        result = calc.compute(confidences)
        assert isinstance(result, list)
        assert all(isinstance(w, ModalityWeight) for w in result)
        # all() is vacuously true on an empty list, so first pin that every
        # input modality is actually reported before checking its fields.
        assert set(confidences) <= {w.modality for w in result}
        assert all(not w.included for w in result)
        assert all(w.weight == 0.0 for w in result)

    def test_compute_single_modality_above_threshold(self):
        """Single modality above threshold gets weight=1.0; every other one is excluded."""
        calc = WeightCalculator(threshold=0.3)
        confidences = self._confidences(text_retrieval=0.8)
        by_name = {w.modality: w for w in calc.compute(confidences)}
        assert set(confidences) <= set(by_name)
        assert by_name["text_retrieval"].included is True
        assert by_name["text_retrieval"].weight == pytest.approx(1.0)
        for name in (n for n in confidences if n != "text_retrieval"):
            # Confidence 0.0 is below the 0.3 threshold, so each must be gated out.
            assert by_name[name].included is False
            assert by_name[name].weight == 0.0
class TestAutoTurnaround:
    """Unit tests for AutoTurnaround's dissatisfaction detectors."""

    @staticmethod
    def _turnaround() -> AutoTurnaround:
        """AutoTurnaround configured with threshold 0.85 over 2 windows."""
        return AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)

    def _make_frustrated_buffer(
        self, confidence: float = 0.9, n_windows: int = 2
    ) -> SignalBufferState:
        """Build a buffer of *n_windows* snapshots, each carrying a "frustrated" affect.

        Snapshot timestamps start at the current wall-clock time and advance by
        one second per snapshot. Each snapshot's affect shares its timestamp and
        carries the given *confidence*.
        """
        base = time.time()
        frames = []
        for offset in range(n_windows):
            stamp = base + offset
            mood = AffectSignal(
                affect_class="frustrated", confidence=confidence, timestamp=stamp
            )
            frames.append(VisionSnapshot(timestamp=stamp, affect=mood))
        return SignalBufferState(snapshots=frames)

    def test_check_empty_buffer_returns_false(self):
        """With nothing buffered there is no dissatisfaction signal, so check() → False."""
        detector = self._turnaround()
        empty = SignalBufferState()
        assert detector.check(empty) is False

    def test_check_frustrated_affect_triggers(self):
        """Two back-to-back frustrated windows at 0.9 confidence (threshold 0.85) → True."""
        detector = self._turnaround()
        frustrated = self._make_frustrated_buffer(confidence=0.9, n_windows=2)
        assert detector._check_frustrated_affect(frustrated) is True

    def test_check_frustrated_affect_below_threshold(self):
        """Frustrated windows whose confidence sits under the threshold → False."""
        detector = self._turnaround()
        weak = self._make_frustrated_buffer(confidence=0.5, n_windows=2)
        assert detector._check_frustrated_affect(weak) is False

    def test_check_thumbs_down_empty_buffer_returns_false(self):
        """A default-configured detector sees no thumbs_down in an empty buffer → False."""
        detector = AutoTurnaround()
        assert detector._check_thumbs_down(SignalBufferState()) is False

    def test_check_thumbs_down_triggers(self):
        """A confident thumbs_down gesture in the newest snapshot → True."""
        now = time.time()
        thumbs = GestureSignal(gesture_class="thumbs_down", confidence=0.95, timestamp=now)
        latest = VisionSnapshot(timestamp=now, gestures=[thumbs])
        buffer = SignalBufferState(snapshots=[latest])
        detector = self._turnaround()
        assert detector._check_thumbs_down(buffer) is True