File size: 9,577 Bytes
0f07ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
import unittest
import subprocess
import time

import grpc
import backend_pb2
import backend_pb2_grpc

class TestBackendServicer(unittest.TestCase):
    """
    Integration tests for the gRPC backend service.

    A fresh ``backend.py`` process is started before every test by ``setUp``
    and stopped afterwards by ``tearDown``. unittest invokes both hooks
    automatically around each test method, so the tests must not call them
    by hand: doing so would spawn a second server on the same address and
    lose the handle of the first one, which would then never be terminated.

    Failures are left to propagate so that unittest reports the real
    assertion message or RPC error together with its traceback.
    """

    # Address the backend subprocess listens on and the tests connect to.
    ADDRESS = "localhost:50051"
    # Model loaded by every text-generation test.
    TEXT_MODEL = "mlx-community/Llama-3.2-1B-Instruct-4bit"
    # Model loaded by the embedding test.
    EMBEDDING_MODEL = "intfloat/e5-mistral-7b-instruct"

    def setUp(self):
        """Start the backend server in a subprocess and give it time to boot."""
        self.service = subprocess.Popen(["python", "backend.py", "--addr", self.ADDRESS])
        # Fixed grace period for the server to start listening.
        time.sleep(10)

    def tearDown(self) -> None:
        """Terminate the backend subprocess and wait for it to exit."""
        self.service.terminate()
        self.service.wait()

    def _load_model(self, stub, model=None):
        """Load *model* (default: ``TEXT_MODEL``) and assert that it succeeded.

        Returns the LoadModel response so callers can inspect it further.
        """
        if model is None:
            model = self.TEXT_MODEL
        response = stub.LoadModel(backend_pb2.ModelOptions(Model=model))
        self.assertTrue(response.success, f"LoadModel failed for {model}: {response.message}")
        return response

    def test_server_startup(self):
        """
        This method tests if the server starts and answers the Health check
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = stub.Health(backend_pb2.HealthMessage())
            self.assertEqual(response.message, b'OK')

    def test_load_model(self):
        """
        This method tests if the model is loaded successfully
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            response = self._load_model(stub)
            self.assertEqual(response.message, "MLX model loaded successfully")

    def test_text(self):
        """
        This method tests if text is generated successfully for a plain prompt
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub)
            req = backend_pb2.PredictOptions(Prompt="The capital of France is")
            resp = stub.Predict(req)
            self.assertIsNotNone(resp.message)

    def test_sampling_params(self):
        """
        This method tests if all sampling parameters are correctly processed
        NOTE: this does NOT test for correctness, just that we received a compatible response
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub)

            req = backend_pb2.PredictOptions(
                Prompt="The capital of France is",
                TopP=0.8,
                Tokens=50,
                Temperature=0.7,
                TopK=40,
                PresencePenalty=0.1,
                FrequencyPenalty=0.2,
                MinP=0.05,
                Seed=42,
                StopPrompts=["\n"],
                IgnoreEOS=True,
            )
            resp = stub.Predict(req)
            self.assertIsNotNone(resp.message)

    def test_embedding(self):
        """
        This method tests if the embeddings are generated successfully
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub, self.EMBEDDING_MODEL)
            embedding_request = backend_pb2.PredictOptions(Embeddings="This is a test sentence.")
            embedding_response = stub.Embedding(embedding_request)
            self.assertIsNotNone(embedding_response.embeddings)
            # Assuming ``embeddings`` is a repeated protobuf field: such fields
            # are container objects, not ``list`` instances, so copy the values
            # into a real list before checking them.
            embeddings = list(embedding_response.embeddings)
            # assert that the list is not empty
            self.assertTrue(len(embeddings) > 0, "embedding vector should not be empty")
            # assert that it only contains floats
            self.assertTrue(
                all(isinstance(value, float) for value in embeddings),
                "embedding vector should contain only floats",
            )

    def test_concurrent_requests(self):
        """
        This method tests that concurrent requests don't corrupt each other's cache state.
        This is a regression test for the race condition in the original implementation.
        """
        import concurrent.futures

        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub)

            def make_request(prompt):
                req = backend_pb2.PredictOptions(Prompt=prompt, Tokens=20)
                return stub.Predict(req)

            # Run 5 concurrent requests with different prompts
            prompts = [
                "The capital of France is",
                "The capital of Germany is",
                "The capital of Italy is",
                "The capital of Spain is",
                "The capital of Portugal is",
            ]

            with concurrent.futures.ThreadPoolExecutor(max_workers=5) as executor:
                futures = [executor.submit(make_request, p) for p in prompts]
                # f.result() re-raises any exception raised inside a worker.
                results = [f.result() for f in concurrent.futures.as_completed(futures)]

            # All results should be non-empty
            messages = [r.message for r in results]
            self.assertTrue(all(len(m) > 0 for m in messages), "All requests should return non-empty responses")
            print(f"Concurrent test passed: {len(messages)} responses received")

    def test_cache_reuse(self):
        """
        This method sends the same prompt twice, which is the scenario in which
        the backend is expected to reuse cached KV states.
        NOTE: it only checks that both requests return a response; it does not
        measure whether the cache was actually hit.
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub)

            prompt = "The quick brown fox jumps over the lazy dog. "

            # First request - expected to populate the cache
            req1 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
            resp1 = stub.Predict(req1)
            self.assertIsNotNone(resp1.message)

            # Second request with same prompt - expected to reuse the cache
            req2 = backend_pb2.PredictOptions(Prompt=prompt, Tokens=10)
            resp2 = stub.Predict(req2)
            self.assertIsNotNone(resp2.message)

            print(f"Cache reuse test passed: first={len(resp1.message)} bytes, second={len(resp2.message)} bytes")

    def test_prefix_cache_reuse(self):
        """
        This method sends a prompt followed by an extension of that prompt,
        which is the scenario in which a shared prefix is expected to benefit
        from cached KV states.
        NOTE: it only checks that both requests return a response; it does not
        measure whether the cache was actually hit.
        """
        with grpc.insecure_channel(self.ADDRESS) as channel:
            stub = backend_pb2_grpc.BackendStub(channel)
            self._load_model(stub)

            # First request with base prompt
            prompt_base = "Once upon a time in a land far away, "
            req1 = backend_pb2.PredictOptions(Prompt=prompt_base, Tokens=10)
            resp1 = stub.Predict(req1)
            self.assertIsNotNone(resp1.message)

            # Second request with extended prompt (same prefix)
            prompt_extended = prompt_base + "there lived a brave knight who "
            req2 = backend_pb2.PredictOptions(Prompt=prompt_extended, Tokens=10)
            resp2 = stub.Predict(req2)
            self.assertIsNotNone(resp2.message)

            print(f"Prefix cache test passed: base={len(resp1.message)} bytes, extended={len(resp2.message)} bytes")


# Unit tests for ThreadSafeLRUPromptCache are in test_mlx_cache.py