tcmmichaelb139 committed on
Commit
eae9ce4
·
1 Parent(s): 20066b2

updated docker file

Browse files
Files changed (2) hide show
  1. Dockerfile +2 -0
  2. tests/test_hf_api.py +305 -0
Dockerfile CHANGED
@@ -1,6 +1,8 @@
1
  FROM python:3.12-slim
2
  WORKDIR /code
3
 
 
 
4
  RUN pip install uv
5
 
6
  COPY pyproject.toml uv.lock ./
 
1
  FROM python:3.12-slim
2
  WORKDIR /code
3
 
4
+ ENV HF_HOME=/code/.cache
5
+
6
  RUN pip install uv
7
 
8
  COPY pyproject.toml uv.lock ./
tests/test_hf_api.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pytest
2
+ import requests
3
+ import time
4
+ import re
5
+
6
+
7
+ def get_final_answer(text: str) -> int | None:
8
+ numbers = re.findall(r"\d+", text)
9
+ return int(numbers[-1]) if numbers else None
10
+
11
+
12
# Base URL of the deployed Hugging Face Space that these integration tests exercise.
BASE_URL = "https://tcmmichaelb139-evolutiontransformer.hf.space"
13
+
14
+
15
def await_task_completion(task_id, timeout=60):
    """Poll the task-status endpoint every 2s until the task finishes.

    Args:
        task_id: Identifier returned by the /generate or /merge endpoints.
        timeout: Maximum number of seconds to wait before failing the test.

    Returns:
        The task's ``result`` payload on SUCCESS, or ``{"error": ...}`` when
        the server reports a 500 for the task.

    Note: on timeout this calls ``pytest.fail`` (which raises), so there is
    no fall-through return path — the original's trailing ``return None``
    after the ``while/else`` was unreachable dead code and has been removed.
    """
    deadline = time.time() + timeout
    while time.time() < deadline:
        status_response = requests.get(f"{BASE_URL}/tasks/{task_id}")

        print(status_response.json())  # debug aid: full status payload in test logs

        if status_response.status_code == 500:
            # Server-side task failure: surface the detail as an error dict.
            return {"error": status_response.json().get("detail", "Unknown error")}
        assert status_response.status_code == 200
        status_data = status_response.json()

        if status_data["status"] == "SUCCESS":
            return status_data["result"]

        time.sleep(2)

    pytest.fail(
        f"Task {task_id} did not complete within the {timeout}-second timeout."
    )
37
+
38
+
39
def test_generate_endpoint_svamp():
    """Submit a generation request against the svamp model and check the answer."""
    payload = {
        "model_name": "svamp",
        "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
        "max_new_tokens": 50,
        "temperature": 0.7,
    }
    resp = requests.post(f"{BASE_URL}/generate", json=payload)

    assert resp.status_code == 200
    body = resp.json()

    assert "task_id" in body
    final_result = await_task_completion(body["task_id"])

    assert "response" in final_result
    # The model's free-text output should end with the correct sum.
    assert get_final_answer(final_result["response"]) == 14
66
+
67
+
68
def test_merge_then_inference_svamp_1():
    """Merge svamp with tinystories (identity recipe), then run inference on the result."""
    merge_resp = requests.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 1.0)] for i in range(24)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    assert merge_resp.status_code == 200
    merge_body = merge_resp.json()
    assert "task_id" in merge_body

    merged_model = await_task_completion(merge_body["task_id"])["response"]

    # Give the service a moment to register the freshly merged model.
    time.sleep(5)

    gen_resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": merged_model,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert gen_resp.status_code == 200
    gen_body = gen_resp.json()
    assert "task_id" in gen_body

    final_result = await_task_completion(gen_body["task_id"])

    assert "response" in final_result
    assert get_final_answer(final_result["response"]) == 14
117
+
118
+
119
def test_merge_then_inference_svamp_2():
    """Chain two merges (oversized then damped recipe), then run inference on the result."""
    first_merge = requests.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i % 24, 0, 1.0 if i < 24 else 0.5)] for i in range(48)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    assert first_merge.status_code == 200
    first_body = first_merge.json()
    assert "task_id" in first_body

    intermediate_model = await_task_completion(first_body["task_id"])["response"]

    second_merge = requests.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": intermediate_model,
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.25)] for i in range(24)],
            "embedding_lambdas": [0.0, 0.0],
            "linear_lambdas": [0.0, 0.0],
            "merged_name": "svamp_merged",
        },
    )

    assert second_merge.status_code == 200
    second_body = second_merge.json()
    assert "task_id" in second_body
    final_model = await_task_completion(second_body["task_id"])["response"]

    # Give the service a moment to register the freshly merged model.
    time.sleep(5)

    gen_resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": final_model,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert gen_resp.status_code == 200
    gen_body = gen_resp.json()
    assert "task_id" in gen_body

    final_result = await_task_completion(gen_body["task_id"])

    assert "response" in final_result
    assert get_final_answer(final_result["response"]) == 14
188
+
189
+
190
def test_merge_two_children_then_merge():
    """Create two merged children from the base models, then merge the children together."""

    def _run_merge(spec):
        # Submit a merge job and block until the merged model's name is available.
        resp = requests.post(f"{BASE_URL}/merge", json=spec)
        assert resp.status_code == 200
        body = resp.json()
        assert "task_id" in body
        return await_task_completion(body["task_id"])["response"]

    child1_name = _run_merge(
        {
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 0.8)] for i in range(12)]
            + [[(i, 1, 0.6)] for i in range(12)],
            "embedding_lambdas": [0.7, 0.3],
            "linear_lambdas": [0.8, 0.2],
            "merged_name": "child1",
        }
    )

    child2_name = _run_merge(
        {
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 1, 0.9)] for i in range(8)]
            + [[(i, 0, 0.4)] for i in range(16)],
            "embedding_lambdas": [0.2, 0.9],
            "linear_lambdas": [0.3, 0.7],
            "merged_name": "child2",
        }
    )

    final_model_name = _run_merge(
        {
            "model1_name": child1_name,
            "model2_name": child2_name,
            "layer_recipe": [[(i, 0, 0.6), (i, 1, 0.4)] for i in range(24)],
            "embedding_lambdas": [0.5, 0.5],
            "linear_lambdas": [0.6, 0.4],
            "merged_name": "final_merged",
        }
    )

    # Give the service a moment to register the freshly merged model.
    time.sleep(5)

    gen_resp = requests.post(
        f"{BASE_URL}/generate",
        json={
            "model_name": final_model_name,
            "prompt": "A spider has 8 legs. A fly has 6 legs. How many legs do they have in total?\nAnswer:",
            "max_new_tokens": 50,
            "temperature": 0.7,
        },
    )

    assert gen_resp.status_code == 200
    gen_body = gen_resp.json()
    assert "task_id" in gen_body

    final_result = await_task_completion(gen_body["task_id"])

    assert "response" in final_result
    assert get_final_answer(final_result["response"]) == 14
278
+
279
+
280
def test_merge_fail():
    """A merge request whose layer recipe exceeds the limit should report an error."""
    merge_resp = requests.post(
        f"{BASE_URL}/merge",
        json={
            "model1_name": "svamp",
            "model2_name": "tinystories",
            "layer_recipe": [[(i, 0, 1.0)] for i in range(50)],
            "embedding_lambdas": [1.0, 1.0],
            "linear_lambdas": [1.0, 1.0],
            "merged_name": "svamp_merged",
        },
    )

    # The request itself is accepted; the failure surfaces in the task result.
    assert merge_resp.status_code == 200
    merge_body = merge_resp.json()
    assert "task_id" in merge_body

    outcome = await_task_completion(merge_body["task_id"])
    assert "response" not in outcome
    assert "error" in outcome
    assert "Layer recipe too long" in outcome["error"]