File size: 4,422 Bytes
1004967
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""Tests for FusionLayer, WeightCalculator, and AutoTurnaround."""

from __future__ import annotations

import time

import pytest
from unittest.mock import AsyncMock, MagicMock

from memorybridge.core.schemas import (
    AffectSignal,
    GestureSignal,
    ModalityWeight,
    SignalBufferState,
    VisionSnapshot,
)
from memorybridge.fusion.weight_calculator import WeightCalculator
from memorybridge.fusion.auto_turnaround import AutoTurnaround


class TestWeightCalculator:
    """Unit tests for ``WeightCalculator._normalize`` and ``WeightCalculator.compute``."""

    # Threshold handed to every calculator built in this class.
    _THRESHOLD = 0.3
    # Modality names passed to ``compute``, in the order the tests supply them.
    _MODALITIES = ("text_retrieval", "kg_facts", "gesture", "affect", "air_sign")

    def _confidences(self, **overrides):
        """Return a 0.0 confidence for every modality, with *overrides* applied."""
        confidences = dict.fromkeys(self._MODALITIES, 0.0)
        confidences.update(overrides)
        return confidences

    def test_normalize_uniform(self):
        """Two equal weights must normalize to values that sum to 1.0."""
        calculator = WeightCalculator(threshold=self._THRESHOLD)
        scaled = calculator._normalize({"a": 0.5, "b": 0.5})
        total = sum(scaled.values())
        assert abs(total - 1.0) < 1e-6

    def test_normalize_zero_total_returns_all_zeros(self):
        """All-zero input must stay all-zero instead of becoming a uniform split."""
        calculator = WeightCalculator(threshold=self._THRESHOLD)
        scaled = calculator._normalize({"a": 0.0, "b": 0.0})
        for key in ("a", "b"):
            assert scaled[key] == 0.0

    def test_compute_all_below_threshold_excluded(self):
        """With every confidence at 0.0, each ModalityWeight is excluded with weight 0.0."""
        calculator = WeightCalculator(threshold=self._THRESHOLD)
        modality_weights = calculator.compute(self._confidences())
        assert isinstance(modality_weights, list)
        for entry in modality_weights:
            assert isinstance(entry, ModalityWeight)
        assert not any(entry.included for entry in modality_weights)
        assert all(entry.weight == 0.0 for entry in modality_weights)

    def test_compute_single_modality_above_threshold(self):
        """The only modality above the threshold is included with weight 1.0."""
        calculator = WeightCalculator(threshold=self._THRESHOLD)
        modality_weights = calculator.compute(self._confidences(text_retrieval=0.8))
        lookup = {}
        for entry in modality_weights:
            lookup[entry.modality] = entry
        text_entry = lookup["text_retrieval"]
        assert text_entry.included is True
        assert abs(text_entry.weight - 1.0) < 1e-6
        assert lookup["gesture"].included is False


class TestAutoTurnaround:
    """Unit tests for ``AutoTurnaround.check`` and its private signal checks."""

    def _make_frustrated_buffer(self, confidence: float = 0.9, n_windows: int = 2) -> SignalBufferState:
        """Build a buffer of *n_windows* snapshots, each carrying a "frustrated" affect.

        Snapshot timestamps start at the current time and increase by one per
        snapshot; each affect signal shares the timestamp of its snapshot.
        """
        start = time.time()
        frames = []
        for offset in range(n_windows):
            moment = start + offset
            signal = AffectSignal(affect_class="frustrated", confidence=confidence, timestamp=moment)
            frames.append(VisionSnapshot(timestamp=moment, affect=signal))
        return SignalBufferState(snapshots=frames)

    def test_check_empty_buffer_returns_false(self):
        """``check`` on a default-constructed buffer must return False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        empty_state = SignalBufferState()
        assert detector.check(empty_state) is False

    def test_check_frustrated_affect_triggers(self):
        """Two frustrated snapshots at confidence 0.9 against a 0.85 threshold must return True."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        state = self._make_frustrated_buffer(confidence=0.9, n_windows=2)
        outcome = detector._check_frustrated_affect(state)
        assert outcome is True

    def test_check_frustrated_affect_below_threshold(self):
        """Frustrated snapshots at confidence 0.5 against a 0.85 threshold must return False."""
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        state = self._make_frustrated_buffer(confidence=0.5, n_windows=2)
        outcome = detector._check_frustrated_affect(state)
        assert outcome is False

    def test_check_thumbs_down_empty_buffer_returns_false(self):
        """``_check_thumbs_down`` on a default-constructed buffer must return False."""
        detector = AutoTurnaround()
        empty_state = SignalBufferState()
        assert detector._check_thumbs_down(empty_state) is False

    def test_check_thumbs_down_triggers(self):
        """A single snapshot with a 0.95-confidence thumbs_down gesture must return True."""
        moment = time.time()
        gesture = GestureSignal(gesture_class="thumbs_down", confidence=0.95, timestamp=moment)
        frame = VisionSnapshot(timestamp=moment, gestures=[gesture])
        state = SignalBufferState(snapshots=[frame])
        detector = AutoTurnaround(dissatisfaction_threshold=0.85, window_count=2)
        assert detector._check_thumbs_down(state) is True