import unittest

import torch

from inference import StreamingInference


class DummyModel(torch.nn.Module):
    """Minimal stand-in model that returns random logits of the expected shape."""

    def __init__(self):
        super().__init__()
        self.vocab_size = 128

    def forward(self, x):
        batch_size, seq_len = x.shape
        return torch.randn(batch_size, seq_len, self.vocab_size)


class DummyTokenizer:
    """Maps characters to small integer ids and back; lossy, but sufficient for tests."""

    def __call__(self, text):
        return [ord(c) % 100 + 1 for c in text]

    def decode(self, token_ids):
        return "".join(chr((tid - 1) % 100 + 32) for tid in token_ids)


class InferenceTest(unittest.TestCase):
    def test_streaming_inference(self):
        model = DummyModel()
        tokenizer = DummyTokenizer()
        infer = StreamingInference(
            model=model,
            tokenizer=tokenizer,
            max_context_length=20,
            batch_size=1,
        )
        infer.start()
        infer.submit_input("Hello")
        out = infer.get_response(timeout=5)
        infer.stop()
        self.assertIsInstance(out, str)


if __name__ == "__main__":
    unittest.main()