Kosasih commited on
Commit
c65e908
·
verified ·
1 Parent(s): 1830953

Create tests/test_inference.py

Browse files
Files changed (1) hide show
  1. tests/test_inference.py +33 -0
tests/test_inference.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import unittest
2
+ import torch
3
+ from inference import StreamingInference
4
+
5
+ class DummyModel(torch.nn.Module):
6
+ def __init__(self):
7
+ super().__init__()
8
+ self.vocab_size = 128
9
+ def forward(self, x):
10
+ batch_size, seq_len = x.shape
11
+ return torch.randn(batch_size, seq_len, self.vocab_size)
12
+
13
+ class DummyTokenizer:
14
+ def __call__(self, text):
15
+ return [ord(c) % 100 + 1 for c in text]
16
+ def decode(self, token_ids):
17
+ return "".join([chr((tid - 1) % 100 + 32) for tid in token_ids])
18
+
19
+ class InferenceTest(unittest.TestCase):
20
+ def test_streaming_inference(self):
21
+ model = DummyModel()
22
+ tokenizer = DummyTokenizer()
23
+ infer = StreamingInference(model=model, tokenizer=tokenizer, max_context_length=20, batch_size=1)
24
+ infer.start()
25
+
26
+ infer.submit_input("Hello")
27
+ out = infer.get_response(timeout=5)
28
+ infer.stop()
29
+
30
+ self.assertIsInstance(out, str)
31
+
32
+ if __name__ == "__main__":
33
+ unittest.main()