"""
Kirim-1-Math evaluation script.

Benchmark the model on mathematical reasoning tasks.
"""
|
|
|
|
|
import torch |
|
|
import json |
|
|
import argparse |
|
|
from typing import List, Dict, Any |
|
|
from tqdm import tqdm |
|
|
import time |
|
|
from datetime import datetime |
|
|
from inference_math import KirimMath |
|
|
import re |
|
|
|
|
|
|
|
|
class MathEvaluator:
    """Evaluate Kirim-1-Math on mathematical benchmarks.

    Each ``evaluate_*`` method prints a banner, runs the model over a problem
    set (a small built-in sample unless a data file is supplied), prints a
    summary and returns a dict of metrics.
    """

    # Patterns tried by ``extract_answer`` after \boxed{...}, in this order.
    # An explicit "answer:" / "final answer:" / "solution:" label.
    _LABELLED_ANSWER_RE = re.compile(
        r'(?:final answer|answer|solution):\s*\$?([^$\n]+)\$?', re.IGNORECASE
    )
    # Text after the LAST "=" on the final line. Without re.MULTILINE, "$"
    # only matches at the end of the string (or just before a final newline).
    _TRAILING_EQUALS_RE = re.compile(r'=\s*([^=\n]+)$')
    # The rest of the line following "therefore".
    _THEREFORE_RE = re.compile(r'therefore[,:]?\s*([^\n]+)', re.IGNORECASE)
    # Signed integers/decimals; used to take the last number out of free text.
    _NUMBER_RE = re.compile(r'-?\d+(?:\.\d+)?')
    # Absolute tolerance for numeric answer comparison.
    _NUMERIC_TOLERANCE = 1e-6

    def __init__(self, model_path: str = "Kirim-ai/Kirim-1-Math", load_in_4bit: bool = False):
        """Load the model under evaluation.

        Args:
            model_path: Model identifier or local path, forwarded to ``KirimMath``.
            load_in_4bit: Forwarded unchanged to ``KirimMath``.
        """
        print("Loading model for evaluation...")
        self.model = KirimMath(model_path=model_path, load_in_4bit=load_in_4bit)
        # Holds the last full-evaluation report (set by run_full_evaluation).
        self.results = {}

    # ------------------------------------------------------------------ #
    # Small helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _print_header(title: str) -> None:
        """Print *title* framed by two 60-character '=' rules."""
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)

    @staticmethod
    def _select(items: List[Any], num_samples) -> List[Any]:
        """Return the first *num_samples* items (all when None, none when <= 0)."""
        if num_samples is None:
            return list(items)
        # Clamp so a negative count selects nothing instead of slicing from the end.
        return list(items[:max(num_samples, 0)])

    @staticmethod
    def _accuracy(correct: int, total: int) -> float:
        """Return correct/total, or 0 when there is nothing to score."""
        return correct / total if total > 0 else 0

    @staticmethod
    def _extract_boxed(text: str):
        """Return the contents of the last ``\\boxed{...}`` in *text*, or None.

        Braces are matched by depth, so nested LaTeX such as
        ``\\boxed{\\frac{1}{2}}`` is returned whole. Returns None when there is
        no ``\\boxed{`` or its braces are unbalanced.
        """
        marker = "\\boxed{"
        start = text.rfind(marker)
        if start == -1:
            return None
        begin = start + len(marker)
        depth = 0
        for pos in range(begin, len(text)):
            char = text[pos]
            if char == "{":
                depth += 1
            elif char == "}":
                if depth == 0:
                    return text[begin:pos].strip()
                depth -= 1
        return None

    @staticmethod
    def _normalize(text: str) -> str:
        """Lower-case *text* and remove all whitespace (None becomes "")."""
        return "".join((text or "").lower().split())

    @staticmethod
    def _parse_number(text: str):
        """Return float(text), or None when *text* is not a plain number."""
        try:
            return float(text)
        except ValueError:
            return None

    @staticmethod
    def _load_gsm8k(data_path: str) -> List[Dict[str, Any]]:
        """Load GSM8K-style problems from a JSON Lines file.

        Each non-blank line must be a JSON object with ``"question"`` and
        ``"answer"`` keys. When the answer contains the GSM8K ``####`` marker,
        only the text after the last marker (the final answer) is kept.

        Raises:
            OSError: if the file cannot be opened.
            json.JSONDecodeError, KeyError: on a malformed line.
        """
        problems = []
        with open(data_path, encoding='utf-8') as handle:
            for raw_line in handle:
                line = raw_line.strip()
                if not line:
                    continue
                record = json.loads(line)
                answer = str(record["answer"])
                if "####" in answer:
                    answer = answer.rsplit("####", 1)[1]
                problems.append({"question": record["question"], "answer": answer.strip()})
        return problems

    def _run_answer_benchmark(self, problems, question_key: str, desc: str, extra_keys=()):
        """Solve and grade each problem against its ``"answer"`` field.

        Args:
            problems: Problem dicts containing *question_key* and ``"answer"``.
            question_key: Key holding the problem text (also used in records).
            desc: Progress-bar label.
            extra_keys: Additional problem keys copied into each record.

        Returns:
            ``(correct_count, records)`` where each record holds the question,
            the *extra_keys*, ``expected``, ``predicted`` and ``correct``.
        """
        correct = 0
        records = []
        for problem in tqdm(problems, desc=desc):
            solution = self.model.solve_problem(
                problem[question_key],
                show_work=True,
                use_tools=True,
                temperature=0.1,
            )
            predicted = self.extract_answer(solution)
            is_correct = self.check_answer(predicted, problem["answer"])
            if is_correct:
                correct += 1

            record = {question_key: problem[question_key]}
            record.update((key, problem[key]) for key in extra_keys)
            record.update(expected=problem["answer"], predicted=predicted, correct=is_correct)
            records.append(record)
        return correct, records

    # ------------------------------------------------------------------ #
    # Answer extraction and grading
    # ------------------------------------------------------------------ #
    def extract_answer(self, solution: str) -> str:
        """Extract the final answer from a model solution.

        Tries, in order: the last ``\\boxed{...}``; an "answer:" /
        "final answer:" / "solution:" label; the text after the last "=" on
        the final line; a "therefore" clause. Falls back to the last non-empty
        line, or "" for an empty solution.
        """
        if not solution:
            return ""

        boxed = self._extract_boxed(solution)
        if boxed:
            return boxed

        for pattern in (self._LABELLED_ANSWER_RE, self._TRAILING_EQUALS_RE, self._THEREFORE_RE):
            match = pattern.search(solution)
            if match:
                return match.group(1).strip()

        lines = [line.strip() for line in solution.split('\n') if line.strip()]
        return lines[-1] if lines else ""

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Return True when *predicted* matches *expected*.

        Comparison ignores case and whitespace. A numeric reference (commas
        ignored) is compared within an absolute tolerance to the prediction,
        or to the LAST number in the prediction when it is not a bare number,
        so "$18" and "18 dollars" match "18" but "13" does not match "3".
        A non-numeric reference matches on equality or as a substring of the
        prediction. An empty reference never matches.
        """
        pred = self._normalize(predicted)
        exp = self._normalize(expected)

        if not exp:
            return False
        if pred == exp:
            return True

        exp_num = self._parse_number(exp.replace(',', ''))
        if exp_num is not None:
            pred_plain = pred.replace(',', '')
            pred_num = self._parse_number(pred_plain)
            if pred_num is None:
                numbers = self._NUMBER_RE.findall(pred_plain)
                if not numbers:
                    return False
                pred_num = float(numbers[-1])
            return abs(pred_num - exp_num) < self._NUMERIC_TOLERANCE

        return exp in pred

    # ------------------------------------------------------------------ #
    # Benchmarks
    # ------------------------------------------------------------------ #
    def evaluate_gsm8k(self, data_path: str = None, num_samples: int = 100) -> Dict:
        """Evaluate on GSM8K (grade-school math word problems).

        Args:
            data_path: Optional JSON Lines file in GSM8K format (see
                ``_load_gsm8k``). When omitted, a small built-in sample is used.
            num_samples: Maximum number of problems to evaluate (None = all).

        Returns:
            Dict with ``benchmark``, ``accuracy``, ``correct``, ``total`` and
            per-problem ``results``.
        """
        self._print_header("Evaluating GSM8K (Grade School Math)")

        if data_path:
            problems = self._load_gsm8k(data_path)
        else:
            problems = [
                {
                    "question": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
                    "answer": "18",
                },
                {
                    "question": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
                    "answer": "3",
                },
            ]

        selected = self._select(problems, num_samples)
        total = len(selected)
        correct, results = self._run_answer_benchmark(selected, "question", "GSM8K")
        accuracy = self._accuracy(correct, total)

        print("\nGSM8K Results:")
        print(f" Correct: {correct}/{total}")
        print(f" Accuracy: {accuracy:.2%}")

        return {
            "benchmark": "GSM8K",
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "results": results,
        }

    def evaluate_math_benchmark(self, num_samples: int = 50) -> Dict:
        """Evaluate on a built-in sample of MATH-style problems.

        Args:
            num_samples: Maximum number of problems to evaluate (None = all).

        Returns:
            Dict with ``benchmark``, ``accuracy``, ``correct``, ``total`` and
            per-problem ``results`` (including each problem's ``level``).
        """
        self._print_header("Evaluating MATH Benchmark")

        problems = [
            {
                "problem": "Solve for x: x^2 - 5x + 6 = 0",
                "answer": "x = 2 or x = 3",
                "level": 2,
            },
            {
                "problem": "Find the derivative of f(x) = x^3 + 2x^2 - x + 1",
                "answer": "3x^2 + 4x - 1",
                "level": 3,
            },
        ]

        selected = self._select(problems, num_samples)
        total = len(selected)
        correct, results = self._run_answer_benchmark(
            selected, "problem", "MATH", extra_keys=("level",)
        )
        accuracy = self._accuracy(correct, total)

        print("\nMATH Benchmark Results:")
        print(f" Correct: {correct}/{total}")
        print(f" Accuracy: {accuracy:.2%}")

        return {
            "benchmark": "MATH",
            "accuracy": accuracy,
            "correct": correct,
            "total": total,
            "results": results,
        }

    def evaluate_tool_calling(self, num_samples: int = 20) -> Dict:
        """Evaluate whether the model calls tools when expected.

        Both metrics are string heuristics over the returned solution text.

        Args:
            num_samples: Maximum number of test cases (None = all).

        Returns:
            Dict with ``benchmark``, ``selection_accuracy``,
            ``execution_accuracy`` and ``total``.
        """
        self._print_header("Evaluating Tool Calling")

        test_cases = [
            {
                "problem": "Calculate 2^128 exactly",
                "requires_tool": "calculator",
                "expected_tool_use": True,
            },
            {
                "problem": "Simplify (x^2 - 1)/(x - 1)",
                "requires_tool": "symbolic_solver",
                "expected_tool_use": True,
            },
        ]

        selected = self._select(test_cases, num_samples)
        total = len(selected)
        correct_tool_selection = 0
        correct_execution = 0

        for test in tqdm(selected, desc="Tool Calling"):
            solution = self.model.solve_problem(
                test["problem"],
                use_tools=True,
                temperature=0.1,
            )

            # NOTE(review): assumes tool invocations appear as a literal
            # "<tool_call>" marker in the returned text; confirm against
            # KirimMath's output format.
            tool_called = "<tool_call>" in solution
            if tool_called == test["expected_tool_use"]:
                correct_tool_selection += 1

            # Heuristic: counted as executed when the required tool's name
            # appears anywhere in the solution text.
            required_tool = test.get("requires_tool")
            if required_tool and required_tool in solution:
                correct_execution += 1

        selection_accuracy = self._accuracy(correct_tool_selection, total)
        execution_accuracy = self._accuracy(correct_execution, total)

        print("\nTool Calling Results:")
        print(f" Tool Selection: {correct_tool_selection}/{total} ({selection_accuracy:.2%})")
        print(f" Correct Execution: {correct_execution}/{total} ({execution_accuracy:.2%})")

        return {
            "benchmark": "Tool Calling",
            "selection_accuracy": selection_accuracy,
            "execution_accuracy": execution_accuracy,
            "total": total,
        }

    def evaluate_bilingual(self, num_samples: int = 20) -> Dict:
        """Evaluate the same problems posed in Chinese and in English.

        Args:
            num_samples: Maximum number of problem pairs (None = all).

        Returns:
            Dict with ``benchmark``, ``chinese_accuracy``,
            ``english_accuracy`` and ``total``.
        """
        self._print_header("Evaluating Bilingual Understanding")

        test_cases = [
            {
                "problem_zh": "解方程: x^2 - 4 = 0",
                "problem_en": "Solve the equation: x^2 - 4 = 0",
                "answer": "x = 2 or x = -2",
            },
            {
                "problem_zh": "计算导数: f(x) = x^3",
                "problem_en": "Calculate the derivative: f(x) = x^3",
                "answer": "3x^2",
            },
        ]

        selected = self._select(test_cases, num_samples)
        total = len(selected)
        correct = {"zh": 0, "en": 0}

        for test in tqdm(selected, desc="Bilingual"):
            # Chinese first, then English, for every problem pair.
            for lang in ("zh", "en"):
                solution = self.model.solve_problem(test[f"problem_{lang}"], temperature=0.1)
                if self.check_answer(self.extract_answer(solution), test["answer"]):
                    correct[lang] += 1

        accuracy_zh = self._accuracy(correct["zh"], total)
        accuracy_en = self._accuracy(correct["en"], total)

        print("\nBilingual Results:")
        print(f" Chinese: {correct['zh']}/{total} ({accuracy_zh:.2%})")
        print(f" English: {correct['en']}/{total} ({accuracy_en:.2%})")

        return {
            "benchmark": "Bilingual",
            "chinese_accuracy": accuracy_zh,
            "english_accuracy": accuracy_en,
            "total": total,
        }

    def run_full_evaluation(self, output_path: str = "evaluation_results.json", num_samples: int = 10):
        """Run every benchmark, save a JSON report and return it.

        A benchmark that raises is reported (printed and stored under
        ``"errors"``) and the remaining benchmarks still run.

        Args:
            output_path: Where the JSON report is written (UTF-8).
            num_samples: Maximum number of problems per benchmark.

        Returns:
            The report dict that was written to *output_path*.
        """
        self._print_header("KIRIM-1-MATH FULL EVALUATION")
        started_at = datetime.now()
        print(f"Start time: {started_at.strftime('%Y-%m-%d %H:%M:%S')}")
        start = time.perf_counter()

        results = {
            "model": "Kirim-1-Math",
            "evaluation_date": started_at.isoformat(),
            "benchmarks": {},
            "errors": {},
        }

        suites = (
            ("gsm8k", "GSM8K", self.evaluate_gsm8k),
            ("math", "MATH", self.evaluate_math_benchmark),
            ("tool_calling", "Tool calling", self.evaluate_tool_calling),
            ("bilingual", "Bilingual", self.evaluate_bilingual),
        )
        for key, label, evaluate in suites:
            try:
                results["benchmarks"][key] = evaluate(num_samples=num_samples)
            except Exception as e:
                # Top-level boundary: keep going so one failing suite does not
                # discard the others, but record the failure in the report.
                print(f"{label} evaluation failed: {e}")
                results["errors"][key] = str(e)

        results["total_time_seconds"] = round(time.perf_counter() - start, 2)
        _write_json(output_path, results)
        self.results = results

        self._print_header("EVALUATION COMPLETE")
        print(f"Total time: {results['total_time_seconds']:.2f}s")
        print(f"Results saved to: {output_path}")

        return results


def _write_json(path: str, data: Dict[str, Any]) -> None:
    """Write *data* to *path* as indented UTF-8 JSON, keeping non-ASCII text."""
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(data, f, indent=2, ensure_ascii=False)


def main():
    """Command-line entry point: run one benchmark (or all) and save results."""
    parser = argparse.ArgumentParser(description="Evaluate Kirim-1-Math")
    parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math")
    parser.add_argument("--load_in_4bit", action="store_true")
    parser.add_argument(
        "--benchmark",
        type=str,
        choices=["gsm8k", "math", "tools", "bilingual", "all"],
        default="all",
        help="Benchmark to run",
    )
    parser.add_argument(
        "--num_samples", type=int, default=10,
        help="Maximum number of problems per benchmark",
    )
    parser.add_argument(
        "--output", type=str, default="evaluation_results.json",
        help="Path of the JSON results file",
    )
    args = parser.parse_args()

    evaluator = MathEvaluator(
        model_path=args.model_path,
        load_in_4bit=args.load_in_4bit,
    )

    if args.benchmark == "all":
        evaluator.run_full_evaluation(output_path=args.output, num_samples=args.num_samples)
        return

    runners = {
        "gsm8k": evaluator.evaluate_gsm8k,
        "math": evaluator.evaluate_math_benchmark,
        "tools": evaluator.evaluate_tool_calling,
        "bilingual": evaluator.evaluate_bilingual,
    }
    result = runners[args.benchmark](num_samples=args.num_samples)
    _write_json(args.output, result)
    print(f"Results saved to: {args.output}")


if __name__ == "__main__":
    main()