# OmniCoreX / tests/test_inference.py
# Kosasih's picture
# Create tests/test_inference.py
# c65e908 verified
import unittest
import torch
from inference import StreamingInference
class DummyModel(torch.nn.Module):
    """Minimal stand-in model that returns random values instead of real predictions.

    The module has no parameters; it exists only so the inference pipeline
    has something callable that produces correctly shaped output.
    """

    def __init__(self):
        super().__init__()
        # Size of the last output dimension produced by ``forward``.
        self.vocab_size = 128

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return a random tensor of shape ``(batch_size, seq_len, vocab_size)``.

        Args:
            x: A 2-D tensor; only its shape and device are used.

        Returns:
            Normally distributed random values, created on the same device
            as ``x`` so downstream ops do not hit a device mismatch.
        """
        batch_size, seq_len = x.shape
        return torch.randn(
            batch_size, seq_len, self.vocab_size, device=x.device
        )
class DummyTokenizer:
    """Toy character-level tokenizer used only to exercise the pipeline.

    Note that ``decode`` is *not* the inverse of ``__call__``: the two use
    different offsets, so a round trip does not reproduce the input text.
    """

    def __call__(self, text: str) -> list:
        """Map each character of ``text`` to an integer ID in ``1..100``."""
        return [ord(c) % 100 + 1 for c in text]

    def decode(self, token_ids) -> str:
        """Map each ID to a character with code point in ``32..131`` and join them."""
        return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)
class InferenceTest(unittest.TestCase):
    """Smoke test for ``StreamingInference`` driven by a dummy model and tokenizer."""

    def test_streaming_inference(self):
        """Submit one prompt and check that the response is a string."""
        model = DummyModel()
        tokenizer = DummyTokenizer()
        infer = StreamingInference(
            model=model,
            tokenizer=tokenizer,
            max_context_length=20,
            batch_size=1,
        )
        infer.start()
        try:
            infer.submit_input("Hello")
            # NOTE(review): timeout is presumably in seconds — confirm
            # against StreamingInference.get_response.
            out = infer.get_response(timeout=5)
        finally:
            # Always stop the engine, even when submit/get_response raises,
            # so a failing run never leaves a started engine behind.
            infer.stop()
        self.assertIsInstance(out, str)
# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()