"""
Kirim-1-Math Evaluation Script
Benchmark the model on mathematical reasoning tasks
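
Example usage (flags defined in main() below):
    python evaluate.py --benchmark all --output evaluation_results.json
    python evaluate.py --benchmark gsm8k --num_samples 10 --load_in_4bit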
"""
import json
import argparse
from typing import Dict, Optional
from tqdm import tqdm
import time
from datetime import datetime
from inference_math import KirimMath
import re
class MathEvaluator:
"""Evaluate Kirim-1-Math on mathematical benchmarks"""
def __init__(self, model_path: str = "Kirim-ai/Kirim-1-Math", load_in_4bit: bool = False):
print("Loading model for evaluation...")
self.model = KirimMath(model_path=model_path, load_in_4bit=load_in_4bit)
self.results = {}
def extract_answer(self, solution: str) -> str:
"""Extract the final answer from solution"""
# Look for common answer patterns
patterns = [
r'(?:final answer|answer|solution):\s*\$?([^$\n]+)\$?',
r'=\s*([^\n]+)$',
r'\\boxed{([^}]+)}',
r'therefore[,:]?\s*([^\n]+)',
]
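        # Patterns are tried in order and the first match wins,
        # e.g. "Final answer: 42" -> "42", "\boxed{18}" -> "18".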
for pattern in patterns:
match = re.search(pattern, solution, re.IGNORECASE)
if match:
return match.group(1).strip()
# Return last line as fallback
lines = [line.strip() for line in solution.split('\n') if line.strip()]
return lines[-1] if lines else ""
def check_answer(self, predicted: str, expected: str) -> bool:
"""Check if predicted answer matches expected"""
# Normalize answers
predicted = predicted.lower().strip().replace(' ', '')
expected = expected.lower().strip().replace(' ', '')
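        # e.g. "X = 2" and "x=2" both normalize to "x=2"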
# Direct match
if predicted == expected:
return True
# Try parsing as numbers
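        # e.g. "1,000" and "1000.0" are treated as equal within the 1e-6 tolerance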
try:
pred_num = float(predicted.replace(',', ''))
exp_num = float(expected.replace(',', ''))
return abs(pred_num - exp_num) < 1e-6
        except ValueError:
            # One or both answers are not numeric; fall through to substring check
            pass
# Check if expected is in predicted
return expected in predicted
    def evaluate_gsm8k(self, data_path: Optional[str] = None, num_samples: int = 100) -> Dict:
"""Evaluate on GSM8K dataset"""
print("\n" + "="*60)
print("Evaluating GSM8K (Grade School Math)")
print("="*60)
        # Sample problems for a quick check (a full run would load the dataset from data_path)
sample_problems = [
{
"question": "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?",
"answer": "18"
},
{
"question": "A robe takes 2 bolts of blue fiber and half that much white fiber. How many bolts in total does it take?",
"answer": "3"
},
]
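        # To evaluate on the full GSM8K test split, the problems could instead be
        # pulled from the Hugging Face hub (sketch, assumes the `datasets` package
        # is installed):
        #
        #     from datasets import load_dataset
        #     gsm8k = load_dataset("gsm8k", "main", split="test")
        #     sample_problems = [
        #         {"question": ex["question"],
        #          "answer": ex["answer"].split("####")[-1].strip()}
        #         for ex in gsm8k
        #     ]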
correct = 0
total = min(len(sample_problems), num_samples)
results = []
        for problem in tqdm(sample_problems[:num_samples], desc="GSM8K"):
solution = self.model.solve_problem(
problem["question"],
show_work=True,
use_tools=True,
temperature=0.1
)
predicted = self.extract_answer(solution)
is_correct = self.check_answer(predicted, problem["answer"])
if is_correct:
correct += 1
results.append({
"question": problem["question"],
"expected": problem["answer"],
"predicted": predicted,
"correct": is_correct
})
accuracy = correct / total if total > 0 else 0
print(f"\nGSM8K Results:")
print(f" Correct: {correct}/{total}")
print(f" Accuracy: {accuracy:.2%}")
return {
"benchmark": "GSM8K",
"accuracy": accuracy,
"correct": correct,
"total": total,
"results": results
}
def evaluate_math_benchmark(self, num_samples: int = 50) -> Dict:
"""Evaluate on MATH benchmark"""
print("\n" + "="*60)
print("Evaluating MATH Benchmark")
print("="*60)
sample_problems = [
{
"problem": "Solve for x: x^2 - 5x + 6 = 0",
"answer": "x = 2 or x = 3",
"level": 2
},
{
"problem": "Find the derivative of f(x) = x^3 + 2x^2 - x + 1",
"answer": "3x^2 + 4x - 1",
"level": 3
},
]
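        # As above, these are placeholder problems. A full run would load the MATH
        # test split (for example via `datasets.load_dataset` from one of the MATH
        # mirrors on the Hugging Face hub, if available) and map each record into
        # the same {"problem", "answer", "level"} dict format.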
correct = 0
total = min(len(sample_problems), num_samples)
results = []
for problem in tqdm(sample_problems[:num_samples], desc="MATH"):
solution = self.model.solve_problem(
problem["problem"],
show_work=True,
use_tools=True,
temperature=0.1
)
predicted = self.extract_answer(solution)
is_correct = self.check_answer(predicted, problem["answer"])
if is_correct:
correct += 1
results.append({
"problem": problem["problem"],
"level": problem["level"],
"expected": problem["answer"],
"predicted": predicted,
"correct": is_correct
})
accuracy = correct / total if total > 0 else 0
print(f"\nMATH Benchmark Results:")
print(f" Correct: {correct}/{total}")
print(f" Accuracy: {accuracy:.2%}")
return {
"benchmark": "MATH",
"accuracy": accuracy,
"correct": correct,
"total": total,
"results": results
}
def evaluate_tool_calling(self, num_samples: int = 20) -> Dict:
"""Evaluate tool calling accuracy"""
print("\n" + "="*60)
print("Evaluating Tool Calling")
print("="*60)
test_cases = [
{
"problem": "Calculate 2^128 exactly",
"requires_tool": "calculator",
"expected_tool_use": True
},
{
"problem": "Simplify (x^2 - 1)/(x - 1)",
"requires_tool": "symbolic_solver",
"expected_tool_use": True
},
]
correct_tool_selection = 0
correct_execution = 0
total = min(len(test_cases), num_samples)
for test in tqdm(test_cases[:num_samples], desc="Tool Calling"):
solution = self.model.solve_problem(
test["problem"],
use_tools=True,
temperature=0.1
)
# Check if tool was called
tool_called = "<tool_call>" in solution
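            # The comparison below also credits correctly *not* calling a tool
            # when expected_tool_use is False.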
if tool_called == test["expected_tool_use"]:
correct_tool_selection += 1
# Check if specific tool was used
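            # (coarse heuristic: the tool's name just has to appear in the generated text)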
if test.get("requires_tool") and test["requires_tool"] in solution:
correct_execution += 1
selection_accuracy = correct_tool_selection / total if total > 0 else 0
execution_accuracy = correct_execution / total if total > 0 else 0
print(f"\nTool Calling Results:")
print(f" Tool Selection: {correct_tool_selection}/{total} ({selection_accuracy:.2%})")
print(f" Correct Execution: {correct_execution}/{total} ({execution_accuracy:.2%})")
return {
"benchmark": "Tool Calling",
"selection_accuracy": selection_accuracy,
"execution_accuracy": execution_accuracy,
"total": total
}
def evaluate_bilingual(self, num_samples: int = 20) -> Dict:
"""Evaluate bilingual capabilities"""
print("\n" + "="*60)
print("Evaluating Bilingual Understanding")
print("="*60)
test_cases = [
{
"problem_zh": "解方程: x^2 - 4 = 0",
"problem_en": "Solve the equation: x^2 - 4 = 0",
"answer": "x = 2 or x = -2"
},
{
"problem_zh": "计算导数: f(x) = x^3",
"problem_en": "Calculate the derivative: f(x) = x^3",
"answer": "3x^2"
},
]
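        # Each problem is posed once in Chinese and once in English, and both
        # generations are scored against the same reference answer.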
correct_zh = 0
correct_en = 0
total = min(len(test_cases), num_samples)
for test in tqdm(test_cases[:num_samples], desc="Bilingual"):
# Test Chinese
solution_zh = self.model.solve_problem(test["problem_zh"], temperature=0.1)
predicted_zh = self.extract_answer(solution_zh)
if self.check_answer(predicted_zh, test["answer"]):
correct_zh += 1
# Test English
solution_en = self.model.solve_problem(test["problem_en"], temperature=0.1)
predicted_en = self.extract_answer(solution_en)
if self.check_answer(predicted_en, test["answer"]):
correct_en += 1
accuracy_zh = correct_zh / total if total > 0 else 0
accuracy_en = correct_en / total if total > 0 else 0
print(f"\nBilingual Results:")
print(f" Chinese: {correct_zh}/{total} ({accuracy_zh:.2%})")
print(f" English: {correct_en}/{total} ({accuracy_en:.2%})")
return {
"benchmark": "Bilingual",
"chinese_accuracy": accuracy_zh,
"english_accuracy": accuracy_en,
"total": total
}
    def run_full_evaluation(self, output_path: str = "evaluation_results.json", num_samples: int = 10):
"""Run complete evaluation suite"""
print("\n" + "="*60)
print("KIRIM-1-MATH FULL EVALUATION")
print("="*60)
print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
start_time = time.time()
# Run all benchmarks
results = {
"model": "Kirim-1-Math",
"evaluation_date": datetime.now().isoformat(),
"benchmarks": {}
}
try:
            results["benchmarks"]["gsm8k"] = self.evaluate_gsm8k(num_samples=num_samples)
except Exception as e:
print(f"GSM8K evaluation failed: {e}")
try:
            results["benchmarks"]["math"] = self.evaluate_math_benchmark(num_samples=num_samples)
except Exception as e:
print(f"MATH evaluation failed: {e}")
try:
            results["benchmarks"]["tool_calling"] = self.evaluate_tool_calling(num_samples=num_samples)
except Exception as e:
print(f"Tool calling evaluation failed: {e}")
try:
            results["benchmarks"]["bilingual"] = self.evaluate_bilingual(num_samples=num_samples)
except Exception as e:
print(f"Bilingual evaluation failed: {e}")
# Calculate overall metrics
end_time = time.time()
results["total_time_seconds"] = round(end_time - start_time, 2)
# Save results
with open(output_path, 'w', encoding='utf-8') as f:
json.dump(results, f, indent=2, ensure_ascii=False)
print("\n" + "="*60)
print("EVALUATION COMPLETE")
print("="*60)
print(f"Total time: {results['total_time_seconds']:.2f}s")
print(f"Results saved to: {output_path}")
return results
def main():
parser = argparse.ArgumentParser(description="Evaluate Kirim-1-Math")
parser.add_argument("--model_path", type=str, default="Kirim-ai/Kirim-1-Math")
parser.add_argument("--load_in_4bit", action="store_true")
parser.add_argument("--benchmark", type=str, choices=["gsm8k", "math", "tools", "bilingual", "all"], default="all")
parser.add_argument("--num_samples", type=int, default=10)
parser.add_argument("--output", type=str, default="evaluation_results.json")
args = parser.parse_args()
evaluator = MathEvaluator(
model_path=args.model_path,
load_in_4bit=args.load_in_4bit
)
if args.benchmark == "all":
        evaluator.run_full_evaluation(output_path=args.output, num_samples=args.num_samples)
elif args.benchmark == "gsm8k":
evaluator.evaluate_gsm8k(num_samples=args.num_samples)
elif args.benchmark == "math":
evaluator.evaluate_math_benchmark(num_samples=args.num_samples)
elif args.benchmark == "tools":
evaluator.evaluate_tool_calling(num_samples=args.num_samples)
elif args.benchmark == "bilingual":
evaluator.evaluate_bilingual(num_samples=args.num_samples)
if __name__ == "__main__":
main()