import unittest

import torch

from inference import StreamingInference


class DummyModel(torch.nn.Module):
    """Minimal stand-in model: returns random logits over a small vocabulary."""

    def __init__(self):
        super().__init__()
        self.vocab_size = 128

    def forward(self, x):
        batch_size, seq_len = x.shape
        # Random logits of shape (batch, seq_len, vocab_size); no real computation.
        return torch.randn(batch_size, seq_len, self.vocab_size)


class DummyTokenizer:
    """Character-level tokenizer stub mapping characters to small token ids."""

    def __call__(self, text):
        # Token ids in the range 1..100 (0 is left free, e.g. for padding).
        return [ord(c) % 100 + 1 for c in text]

    def decode(self, token_ids):
        # Map ids back into the printable ASCII range; not an exact inverse
        # of __call__, which is fine for a smoke test.
        return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)


class InferenceTest(unittest.TestCase):
    def test_streaming_inference(self):
        model = DummyModel()
        tokenizer = DummyTokenizer()
        infer = StreamingInference(
            model=model,
            tokenizer=tokenizer,
            max_context_length=20,
            batch_size=1,
        )
        infer.start()
        infer.submit_input("Hello")
        out = infer.get_response(timeout=5)
        infer.stop()
        # Smoke test: we only assert that a decoded string comes back.
        self.assertIsInstance(out, str)


if __name__ == "__main__":
    unittest.main()