Kirim1 committed
Commit e00d2ee · verified · 1 Parent(s): d5d55c9

Create evaluate.py

Files changed (1)
  1. evaluate.py +366 -0
evaluate.py ADDED
@@ -0,0 +1,366 @@
+ """
+ Kirim-1-Math Evaluation Script
+ Benchmark the model on mathematical reasoning tasks.
+ """
+
+ import argparse
+ import json
+ import re
+ import time
+ from datetime import datetime
+ from typing import Dict, Optional
+
+ from tqdm import tqdm
+
+ from inference_math import KirimMath
+
+
+ class MathEvaluator:
+     """Evaluate Kirim-1-Math on mathematical benchmarks."""
+
+     def __init__(self, model_path: str = "Kirim-ai/Kirim-1-Math", load_in_4bit: bool = False):
+         print("Loading model for evaluation...")
+         self.model = KirimMath(model_path=model_path, load_in_4bit=load_in_4bit)
+         self.results = {}
+
+     def extract_answer(self, solution: str) -> str:
+         """Extract the final answer from a generated solution."""
+         # Check the most specific marker first: \boxed{} is the canonical
+         # final-answer wrapper, then explicit answer phrasing.
+         patterns = [
+             r'\\boxed\{([^}]+)\}',
+             r'(?:final answer|answer|solution):\s*\$?([^$\n]+)\$?',
+             r'therefore[,:]?\s*([^\n]+)',
+             r'=\s*([^\n]+)$',
+         ]
+
+         for pattern in patterns:
+             match = re.search(pattern, solution, re.IGNORECASE)
+             if match:
+                 return match.group(1).strip()
+
+         # Fall back to the last non-empty line.
+         lines = [line.strip() for line in solution.split('\n') if line.strip()]
+         return lines[-1] if lines else ""
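+
+     # Illustrative behavior of the patterns above (hypothetical inputs):
+     #   "The total is \boxed{18}."      -> "18"        (boxed marker)
+     #   "Final answer: 42"              -> "42"        (answer phrasing)
+     #   "no marker here\nlast line"     -> "last line" (fallback)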
+
+     def check_answer(self, predicted: str, expected: str) -> bool:
+         """Check whether the predicted answer matches the expected one."""
+         # Normalize: lowercase and drop all spaces.
+         predicted = predicted.lower().strip().replace(' ', '')
+         expected = expected.lower().strip().replace(' ', '')
+
+         # Exact string match.
+         if predicted == expected:
+             return True
+
+         # Numeric comparison with a small tolerance.
+         try:
+             pred_num = float(predicted.replace(',', ''))
+             exp_num = float(expected.replace(',', ''))
+             return abs(pred_num - exp_num) < 1e-6
+         except ValueError:
+             pass
+
+         # Lenient fallback: expected answer contained in the prediction.
+         return expected in predicted
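+
+     # Quick sanity check of the matching rules (illustrative, not exhaustive):
+     #   check_answer("18", "18")                 -> True  (exact match)
+     #   check_answer("18.0", "18")               -> True  (numeric tolerance)
+     #   check_answer("x = 2 or x = 3", "x = 2")  -> True  (substring fallback)
+     #   check_answer("17", "18")                 -> False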
+
+     def evaluate_gsm8k(self, data_path: Optional[str] = None, num_samples: int = 100) -> Dict:
+         """Evaluate on the GSM8K dataset."""
+         print("\n" + "="*60)
+         print("Evaluating GSM8K (Grade School Math)")
+         print("="*60)
+
+         # Built-in sample problems; in production, load the full dataset
+         # from data_path instead.
+         sample_problems = [
+             {
+                 "question": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
+                 "answer": "18"
+             },
+             {
+                 "question": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
+                 "answer": "3"
+             },
+         ]
+
+         correct = 0
+         total = min(len(sample_problems), num_samples)
+         results = []
+
+         for problem in tqdm(sample_problems[:num_samples], desc="GSM8K"):
+             solution = self.model.solve_problem(
+                 problem["question"],
+                 show_work=True,
+                 use_tools=True,
+                 temperature=0.1
+             )
+
+             predicted = self.extract_answer(solution)
+             is_correct = self.check_answer(predicted, problem["answer"])
+
+             if is_correct:
+                 correct += 1
+
+             results.append({
+                 "question": problem["question"],
+                 "expected": problem["answer"],
+                 "predicted": predicted,
+                 "correct": is_correct
+             })
+
+         accuracy = correct / total if total > 0 else 0
+
+         print("\nGSM8K Results:")
+         print(f"  Correct: {correct}/{total}")
+         print(f"  Accuracy: {accuracy:.2%}")
+
+         return {
+             "benchmark": "GSM8K",
+             "accuracy": accuracy,
+             "correct": correct,
+             "total": total,
+             "results": results
+         }
+
+     def evaluate_math_benchmark(self, num_samples: int = 50) -> Dict:
+         """Evaluate on the MATH benchmark."""
+         print("\n" + "="*60)
+         print("Evaluating MATH Benchmark")
+         print("="*60)
+
+         sample_problems = [
+             {
+                 "problem": "Solve for x: x^2 - 5x + 6 = 0",
+                 "answer": "x = 2 or x = 3",
+                 "level": 2
+             },
+             {
+                 "problem": "Find the derivative of f(x) = x^3 + 2x^2 - x + 1",
+                 "answer": "3x^2 + 4x - 1",
+                 "level": 3
+             },
+         ]
+
+         correct = 0
+         total = min(len(sample_problems), num_samples)
+         results = []
+
+         for problem in tqdm(sample_problems[:num_samples], desc="MATH"):
+             solution = self.model.solve_problem(
+                 problem["problem"],
+                 show_work=True,
+                 use_tools=True,
+                 temperature=0.1
+             )
+
+             predicted = self.extract_answer(solution)
+             is_correct = self.check_answer(predicted, problem["answer"])
+
+             if is_correct:
+                 correct += 1
+
+             results.append({
+                 "problem": problem["problem"],
+                 "level": problem["level"],
+                 "expected": problem["answer"],
+                 "predicted": predicted,
+                 "correct": is_correct
+             })
+
+         accuracy = correct / total if total > 0 else 0
+
+         print("\nMATH Benchmark Results:")
+         print(f"  Correct: {correct}/{total}")
+         print(f"  Accuracy: {accuracy:.2%}")
+
+         return {
+             "benchmark": "MATH",
+             "accuracy": accuracy,
+             "correct": correct,
+             "total": total,
+             "results": results
+         }
+
+     def evaluate_tool_calling(self, num_samples: int = 20) -> Dict:
+         """Evaluate tool-calling accuracy."""
+         print("\n" + "="*60)
+         print("Evaluating Tool Calling")
+         print("="*60)
+
+         test_cases = [
+             {
+                 "problem": "Calculate 2^128 exactly",
+                 "requires_tool": "calculator",
+                 "expected_tool_use": True
+             },
+             {
+                 "problem": "Simplify (x^2 - 1)/(x - 1)",
+                 "requires_tool": "symbolic_solver",
+                 "expected_tool_use": True
+             },
+         ]
+
+         correct_tool_selection = 0
+         correct_execution = 0
+         total = min(len(test_cases), num_samples)
+
+         for test in tqdm(test_cases[:num_samples], desc="Tool Calling"):
+             solution = self.model.solve_problem(
+                 test["problem"],
+                 use_tools=True,
+                 temperature=0.1
+             )
+
+             # Did the model emit a tool call at all?
+             tool_called = "<tool_call>" in solution
+
+             if tool_called == test["expected_tool_use"]:
+                 correct_tool_selection += 1
+
+             # Did it invoke the specific tool this problem requires?
+             if test.get("requires_tool") and test["requires_tool"] in solution:
+                 correct_execution += 1
+
+         selection_accuracy = correct_tool_selection / total if total > 0 else 0
+         execution_accuracy = correct_execution / total if total > 0 else 0
+
+         print("\nTool Calling Results:")
+         print(f"  Tool Selection: {correct_tool_selection}/{total} ({selection_accuracy:.2%})")
+         print(f"  Correct Execution: {correct_execution}/{total} ({execution_accuracy:.2%})")
+
+         return {
+             "benchmark": "Tool Calling",
+             "selection_accuracy": selection_accuracy,
+             "execution_accuracy": execution_accuracy,
+             "total": total
+         }
+
+     def evaluate_bilingual(self, num_samples: int = 20) -> Dict:
+         """Evaluate bilingual (Chinese/English) capabilities."""
+         print("\n" + "="*60)
+         print("Evaluating Bilingual Understanding")
+         print("="*60)
+
+         test_cases = [
+             {
+                 "problem_zh": "解方程: x^2 - 4 = 0",
+                 "problem_en": "Solve the equation: x^2 - 4 = 0",
+                 "answer": "x = 2 or x = -2"
+             },
+             {
+                 "problem_zh": "计算导数: f(x) = x^3",
+                 "problem_en": "Calculate the derivative: f(x) = x^3",
+                 "answer": "3x^2"
+             },
+         ]
+
+         correct_zh = 0
+         correct_en = 0
+         total = min(len(test_cases), num_samples)
+
+         for test in tqdm(test_cases[:num_samples], desc="Bilingual"):
+             # Chinese prompt
+             solution_zh = self.model.solve_problem(test["problem_zh"], temperature=0.1)
+             predicted_zh = self.extract_answer(solution_zh)
+             if self.check_answer(predicted_zh, test["answer"]):
+                 correct_zh += 1
+
+             # English prompt
+             solution_en = self.model.solve_problem(test["problem_en"], temperature=0.1)
+             predicted_en = self.extract_answer(solution_en)
+             if self.check_answer(predicted_en, test["answer"]):
+                 correct_en += 1
+
+         accuracy_zh = correct_zh / total if total > 0 else 0
+         accuracy_en = correct_en / total if total > 0 else 0
+
+         print("\nBilingual Results:")
+         print(f"  Chinese: {correct_zh}/{total} ({accuracy_zh:.2%})")
+         print(f"  English: {correct_en}/{total} ({accuracy_en:.2%})")
+
+         return {
+             "benchmark": "Bilingual",
+             "chinese_accuracy": accuracy_zh,
+             "english_accuracy": accuracy_en,
+             "total": total
+         }
+
+     def run_full_evaluation(self, output_path: str = "evaluation_results.json"):
+         """Run the complete evaluation suite."""
+         print("\n" + "="*60)
+         print("KIRIM-1-MATH FULL EVALUATION")
+         print("="*60)
+         print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
+
+         start_time = time.time()
+
+         # Run all benchmarks, continuing past individual failures.
+         results = {
+             "model": "Kirim-1-Math",
+             "evaluation_date": datetime.now().isoformat(),
+             "benchmarks": {}
+         }
+
+         try:
+             results["benchmarks"]["gsm8k"] = self.evaluate_gsm8k(num_samples=10)
+         except Exception as e:
+             print(f"GSM8K evaluation failed: {e}")
+
+         try:
+             results["benchmarks"]["math"] = self.evaluate_math_benchmark(num_samples=10)
+         except Exception as e:
+             print(f"MATH evaluation failed: {e}")
+
+         try:
+             results["benchmarks"]["tool_calling"] = self.evaluate_tool_calling(num_samples=10)
+         except Exception as e:
+             print(f"Tool calling evaluation failed: {e}")
+
+         try:
+             results["benchmarks"]["bilingual"] = self.evaluate_bilingual(num_samples=10)
+         except Exception as e:
+             print(f"Bilingual evaluation failed: {e}")
+
+         # Record total wall-clock time.
+         end_time = time.time()
+         results["total_time_seconds"] = round(end_time - start_time, 2)
+
+         # Save results as UTF-8 JSON so the Chinese test cases stay readable.
+         with open(output_path, 'w', encoding='utf-8') as f:
+             json.dump(results, f, indent=2, ensure_ascii=False)
+
+         print("\n" + "="*60)
+         print("EVALUATION COMPLETE")
+         print("="*60)
+         print(f"Total time: {results['total_time_seconds']:.2f}s")
+         print(f"Results saved to: {output_path}")
+
+         return results
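+
+     # Shape of the saved results file (illustrative sketch):
+     # {
+     #   "model": "Kirim-1-Math",
+     #   "evaluation_date": "<ISO timestamp>",
+     #   "benchmarks": {"gsm8k": {...}, "math": {...}, "tool_calling": {...}, "bilingual": {...}},
+     #   "total_time_seconds": 12.34
+     # }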
+
+
+ def main():
+     parser = argparse.ArgumentParser(description="Evaluate Kirim-1-Math")
+     parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math")
+     parser.add_argument("--load_in_4bit", action="store_true")
+     parser.add_argument("--benchmark", type=str, choices=["gsm8k", "math", "tools", "bilingual", "all"], default="all")
+     parser.add_argument("--num_samples", type=int, default=10)
+     parser.add_argument("--output", type=str, default="evaluation_results.json")
+
+     args = parser.parse_args()
+
+     evaluator = MathEvaluator(
+         model_path=args.model_path,
+         load_in_4bit=args.load_in_4bit
+     )
+
+     if args.benchmark == "all":
+         evaluator.run_full_evaluation(output_path=args.output)
+     elif args.benchmark == "gsm8k":
+         evaluator.evaluate_gsm8k(num_samples=args.num_samples)
+     elif args.benchmark == "math":
+         evaluator.evaluate_math_benchmark(num_samples=args.num_samples)
+     elif args.benchmark == "tools":
+         evaluator.evaluate_tool_calling(num_samples=args.num_samples)
+     elif args.benchmark == "bilingual":
+         evaluator.evaluate_bilingual(num_samples=args.num_samples)
+
+
+ if __name__ == "__main__":
+     main()
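+
+ # Example invocations (a sketch; assumes the companion inference_math.py
+ # module providing KirimMath is available next to this script):
+ #   python evaluate.py --benchmark gsm8k --num_samples 10
+ #   python evaluate.py --benchmark all --load_in_4bit --output results.json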