Spaces:
Sleeping
Sleeping
| import pytest | |
| from unittest.mock import MagicMock, patch | |
| from aetheris.inference import InferenceEngine | |
@pytest.fixture
def mock_model():
    """Patch ``aetheris.inference.HybridMambaMoE`` and yield the mocked instance.

    The patch stays active for as long as the consuming test runs, because
    the ``yield`` happens inside the ``with`` block.

    Yields:
        The ``MagicMock`` returned when the patched model class is called.
        Calling that instance returns ``{'logits': tensor}`` where the tensor
        is random with shape ``(1, 10, 50257)``.
    """
    with patch("aetheris.inference.HybridMambaMoE") as MockModel:
        mock_instance = MockModel.return_value
        # `.to(...)` hands back the same mock so chained calls keep this setup.
        mock_instance.to.return_value = mock_instance
        mock_instance.eval.return_value = None
        # Forward pass: the result is read through the 'logits' key.
        # Shape: (batch_size, seq_len, vocab_size)
        mock_instance.return_value = {'logits': torch.randn(1, 10, 50257)}
        yield mock_instance
@pytest.fixture
def mock_tokenizer():
    """Patch ``aetheris.inference.get_tokenizer`` and yield a stub tokenizer.

    Yields:
        A ``MagicMock`` tokenizer whose ``encode`` returns the tensor
        ``[[1, 2, 3]]``, whose ``decode`` returns ``"token"``, and whose
        ``eos_token_id`` is ``50256``. The patched ``get_tokenizer`` returns
        this same object while the fixture is active.
    """
    with patch("aetheris.inference.get_tokenizer") as mock_get_tokenizer:
        mock_tok = MagicMock()
        mock_tok.encode.return_value = torch.tensor([[1, 2, 3]])
        mock_tok.decode.return_value = "token"
        mock_tok.eos_token_id = 50256
        mock_get_tokenizer.return_value = mock_tok
        yield mock_tok
@pytest.fixture
def mock_utils():
    """Patch ``aetheris.inference.load_latest_checkpoint`` for the test's duration.

    Yields:
        The mock that replaces ``load_latest_checkpoint``, so tests can
        inspect how it was called.
    """
    with patch("aetheris.inference.load_latest_checkpoint") as mock_load:
        yield mock_load
| import torch | |
def test_inference_initialization(mock_model, mock_tokenizer, mock_utils):
    """An engine built from a config path exposes a model and a tokenizer.

    Also checks that the patched checkpoint loader was called exactly once.
    """
    config_file = "configs/default.yaml"
    engine = InferenceEngine(config_path=config_file)

    model = engine.model
    assert model is not None
    tokenizer = engine.tokenizer
    assert tokenizer is not None

    mock_utils.assert_called_once()
def test_generate_full(mock_model, mock_tokenizer, mock_utils):
    """``generate_full`` returns a non-empty string for a prompt."""
    vocab_size = 50257
    engine = InferenceEngine()
    engine.model.config.torch_dtype = torch.float32

    def fake_forward(x):
        # Return a dict with random logits sized to the input's second
        # dimension: (batch, seq_len, vocab_size).
        seq_len = x.shape[1]
        return {'logits': torch.randn(1, seq_len, vocab_size)}

    # Every call to the mocked model is routed through the stub above.
    engine.model.side_effect = fake_forward

    output = engine.generate_full("test prompt", max_new_tokens=5)

    assert isinstance(output, str)
    assert output != ""
def test_generate_stream(mock_model, mock_tokenizer, mock_utils):
    """Streaming ``generate`` yields one string per requested new token."""
    requested_tokens = 5
    vocab_size = 50257

    engine = InferenceEngine()
    engine.model.config.torch_dtype = torch.float32

    def fake_forward(x):
        # Random logits sized to the input: (batch, seq_len, vocab_size).
        return {'logits': torch.randn(1, x.shape[1], vocab_size)}

    engine.model.side_effect = fake_forward

    stream = engine.generate(
        "test prompt", max_new_tokens=requested_tokens, stream=True
    )
    pieces = list(stream)

    assert len(pieces) == requested_tokens
    for piece in pieces:
        assert isinstance(piece, str)