# Aetheris-Inference / tests/test_inference.py
# Pomilon
# Deploy Aetheris to HF Space
# 1df0e33
import pytest
from unittest.mock import MagicMock, patch
from aetheris.inference import InferenceEngine
@pytest.fixture
def mock_model():
    """Patch ``aetheris.inference.HybridMambaMoE`` and yield its instance mock.

    The yielded object is ``HybridMambaMoE.return_value``, i.e. whatever the
    engine gets back when it constructs the model class.

    * ``.to(...)`` returns the same mock, so code that reassigns
      ``model = model.to(device)`` (presumably what the engine does) still
      holds this mock.
    * ``.eval()`` returns ``None``.
    * Calling the mock returns ``{'logits': <tensor of shape (1, 10, 50257)>}``.
      Tests that need input-dependent logits replace this with ``side_effect``.
    """
    with patch("aetheris.inference.HybridMambaMoE") as mock_model_cls:
        mock_instance = mock_model_cls.return_value
        # Keep `.to(...)` chaining back to the same mock object.
        mock_instance.to.return_value = mock_instance
        mock_instance.eval.return_value = None
        # Forward pass: a dict whose 'logits' entry is
        # (batch_size, seq_len, vocab_size).
        mock_instance.return_value = {'logits': torch.randn(1, 10, 50257)}
        yield mock_instance
@pytest.fixture
def mock_tokenizer():
    """Patch ``aetheris.inference.get_tokenizer`` to return a stub tokenizer.

    The stub has the following behavior:

    * ``encode(...)`` always returns the tensor ``[[1, 2, 3]]``.
    * ``decode(...)`` always returns the string ``"token"``.
    * ``eos_token_id`` is ``50256``.

    The stub itself (not the patched ``get_tokenizer``) is yielded.
    """
    fake_tokenizer = MagicMock(
        eos_token_id=50256,
        **{
            "encode.return_value": torch.tensor([[1, 2, 3]]),
            "decode.return_value": "token",
        },
    )
    # Extra kwargs to `patch` configure the replacement mock, so calling
    # get_tokenizer(...) hands back the stub above.
    with patch("aetheris.inference.get_tokenizer", return_value=fake_tokenizer):
        yield fake_tokenizer
@pytest.fixture
def mock_utils():
    """Replace ``aetheris.inference.load_latest_checkpoint`` for one test.

    The replacement mock is yielded so the test can inspect how checkpoint
    loading was invoked. The patch is undone when the fixture is torn down.
    """
    checkpoint_patcher = patch("aetheris.inference.load_latest_checkpoint")
    with checkpoint_patcher as load_checkpoint_mock:
        yield load_checkpoint_mock
import torch
def test_inference_initialization(mock_model, mock_tokenizer, mock_utils):
    """Constructing the engine wires up a model and tokenizer and loads one checkpoint."""
    engine = InferenceEngine(config_path="configs/default.yaml")

    # Both collaborators must be present after construction.
    assert engine.model is not None
    assert engine.tokenizer is not None

    # The patched checkpoint loader must have been invoked exactly once.
    assert mock_utils.call_count == 1
def test_generate_full(mock_model, mock_tokenizer, mock_utils):
    """Non-streaming generation returns a non-empty decoded string.

    The patched model is called by the generation loop with the running token
    ids and must return a dict with ``'logits'`` of shape
    (batch, seq_len, vocab_size).

    Unseeded random logits could make the mocked EOS id (50256) the chosen
    token and, presumably, end generation early. To avoid that, one fixed
    non-EOS token dominates every position by a wide margin. The result is the
    same under greedy decoding or sampling, and on every run.
    """
    vocab_size = 50257
    next_token_id = 42  # any id other than the mocked eos_token_id (50256)

    def fake_forward(input_ids):
        # (batch, seq_len, vocab_size); seq_len follows the input so the
        # shape matches whatever slice the engine takes.
        logits = torch.zeros(1, input_ids.shape[1], vocab_size)
        logits[..., next_token_id] = 100.0
        return {'logits': logits}

    engine = InferenceEngine()
    engine.model.config.torch_dtype = torch.float32
    engine.model.side_effect = fake_forward

    output = engine.generate_full("test prompt", max_new_tokens=5)

    assert isinstance(output, str)
    assert len(output) > 0
def test_generate_stream(mock_model, mock_tokenizer, mock_utils):
    """Streaming generation yields one decoded string per requested new token.

    With unseeded random logits, the mocked EOS id (50256) could occasionally
    be selected. That would (presumably) stop the stream before 5 tokens and
    make the length assertion flaky. Here the logits are deterministic and a
    fixed non-EOS token overwhelmingly dominates, so exactly
    ``max_new_tokens`` items are expected on every run.
    """
    vocab_size = 50257
    next_token_id = 42  # any id other than the mocked eos_token_id (50256)

    def fake_forward(input_ids):
        # (batch, seq_len, vocab_size) logits with a single dominant token.
        logits = torch.zeros(1, input_ids.shape[1], vocab_size)
        logits[..., next_token_id] = 100.0
        return {'logits': logits}

    engine = InferenceEngine()
    engine.model.config.torch_dtype = torch.float32
    engine.model.side_effect = fake_forward

    tokens = list(engine.generate("test prompt", max_new_tokens=5, stream=True))

    assert len(tokens) == 5
    assert all(isinstance(t, str) for t in tokens)