| """Tests for the streaming predictor.""" |
|
|
| from unittest.mock import MagicMock, patch |
|
|
| import numpy as np |
| import pytest |
| import torch |
|
|
|
|
| def _make_mock_model(n_vertices=100, hidden=64): |
| """Create a mock model that returns predictable outputs.""" |
| model = MagicMock() |
| model.feature_dims = {"text": (2, 32), "audio": (2, 32)} |
| model.config = MagicMock() |
| model.config.hidden = hidden |
|
|
| def fake_forward(batch, pool_outputs=True): |
| return torch.randn(1, n_vertices, 10) |
|
|
| model.__call__ = fake_forward |
| model.return_value = torch.randn(1, n_vertices, 10) |
| return model |
|
|
|
|
class TestStreamingPredictor:
    """Checks on StreamingPredictor buffering, emission cadence, flush and reset."""

    @staticmethod
    def _random_frame(modalities=("text", "audio")):
        """Return one frame: a fresh random (2, 32) tensor per modality name."""
        return {name: torch.randn(2, 32) for name in modalities}

    def test_buffer_fill(self):
        """Every push made before the window is full is expected to return None."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=5, step_trs=1)

        # Four frames against a window of five: each push should yield nothing.
        for _ in range(4):
            assert predictor.push_frame(self._random_frame()) is None

    def test_prediction_emission(self):
        """The push that completes the window should return a (100,) ndarray."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=3, step_trs=1)

        outputs = [predictor.push_frame(self._random_frame()) for _ in range(3)]
        last = outputs[-1]

        # Only the result of the final (third) push is inspected.
        assert last is not None
        assert isinstance(last, np.ndarray)
        assert last.shape == (100,)

    def test_step_control(self):
        """With window_trs=3 and step_trs=2, only pushes 3 and 5 should emit."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=3, step_trs=2)

        outputs = [predictor.push_frame(self._random_frame()) for _ in range(5)]

        # True where a push produced a prediction, False where it returned None.
        emitted = [out is not None for out in outputs]
        assert emitted == [False, False, True, False, True]

    def test_missing_modality(self):
        """Frames that carry only "text" should still lead to a prediction."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=2, step_trs=1)

        # The "audio" modality is deliberately left out of every frame.
        outputs = [
            predictor.push_frame(self._random_frame(("text",))) for _ in range(2)
        ]

        assert outputs[-1] is not None

    def test_flush(self):
        """flush() after three pushed frames should return at least one item."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=3, step_trs=1)

        for _ in range(3):
            predictor.push_frame(self._random_frame())

        flushed = predictor.flush()
        assert len(flushed) >= 1

    def test_reset(self):
        """reset() should empty the buffer and zero the frames-since-emit count."""
        from cortexlab.inference.streaming import StreamingPredictor

        predictor = StreamingPredictor(_make_mock_model(), window_trs=3, step_trs=1)

        for _ in range(3):
            predictor.push_frame(self._random_frame(("text",)))

        predictor.reset()

        # Private state is inspected directly to confirm it was cleared.
        assert len(predictor._buffer) == 0
        assert predictor._frames_since_emit == 0
|
|