# Vortex-13b-V1 / science_bench.py
# Uploaded by Zandy-Wandy — commit 5c43f61 (verified): "Upload Vortex model"
"""
Science benchmarks for Vortex model.
Evaluates performance across 7 science domains.
"""
import torch
from typing import Dict, List, Tuple
from dataclasses import dataclass
@dataclass
class BenchmarkResult:
    """Outcome of running one domain benchmark against a model."""

    domain: str              # benchmark domain, e.g. "physics"
    accuracy: float          # fraction of questions answered correctly (0.0-1.0)
    total_questions: int     # number of questions actually evaluated
    correct_answers: int     # how many of those were judged correct
    details: List[Dict]      # per-question records (question/expected/generated/correct)
class ScienceBenchmark:
    """
    Base class for science benchmarks.

    Subclasses supply the domain-specific pieces (load_questions,
    format_prompt, extract_answer, check_answer); evaluate() runs the
    shared prompt -> generate -> score loop.
    """

    def __init__(self, name: str, domain: str):
        self.name = name        # human-readable benchmark identifier
        self.domain = domain    # key used in the results dictionary

    def load_questions(self) -> List[Dict]:
        """Load benchmark questions."""
        raise NotImplementedError

    def evaluate(
        self,
        model,
        tokenizer,
        device: torch.device,
        max_samples: int = 100,
    ) -> BenchmarkResult:
        """
        Evaluate model on benchmark.

        Args:
            model: Vortex model
            tokenizer: Tokenizer
            device: Torch device
            max_samples: Maximum samples to evaluate

        Returns:
            BenchmarkResult
        """
        questions = self.load_questions()[:max_samples]
        correct = 0
        details = []
        for q in questions:
            # Format prompt and tokenize
            prompt = self.format_prompt(q)
            inputs = tokenizer(prompt, return_tensors="pt").to(device)
            # Greedy decoding: do_sample=False is sufficient on its own.
            # FIX: the original also passed temperature=0.0, which is dead
            # under greedy decoding and is warned about / rejected by recent
            # transformers generation-config validation.
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=50,
                    do_sample=False,
                )
            # Decode the full sequence (prompt echo included); extract_answer
            # is responsible for stripping the prompt portion.
            generated = tokenizer.decode(outputs[0], skip_special_tokens=True)
            answer = self.extract_answer(generated)
            # Check correctness against the reference answer
            is_correct = self.check_answer(answer, q["answer"])
            if is_correct:
                correct += 1
            details.append({
                "question": q["question"],
                "expected": q["answer"],
                "generated": answer,
                "correct": is_correct,
            })
        # Guard against an empty question set (placeholder benchmarks).
        accuracy = correct / len(questions) if questions else 0.0
        return BenchmarkResult(
            domain=self.domain,
            accuracy=accuracy,
            total_questions=len(questions),
            correct_answers=correct,
            details=details,
        )

    def format_prompt(self, question: Dict) -> str:
        """Format question into prompt."""
        raise NotImplementedError

    def extract_answer(self, text: str) -> str:
        """Extract answer from generated text."""
        raise NotImplementedError

    def check_answer(self, predicted: str, expected: str) -> bool:
        """Check if predicted answer matches expected."""
        raise NotImplementedError
class PhysicsBenchmark(ScienceBenchmark):
    """Physics benchmark (Feynman Questions style)."""

    def __init__(self):
        super().__init__("physics_benchmark", "physics")

    def load_questions(self) -> List[Dict]:
        # Placeholder - would load from dataset
        formula_q = {
            "question": "What is the formula for kinetic energy?",
            "answer": "KE = 1/2 mv^2",
            "type": "formula",
        }
        concept_q = {
            "question": "Explain Newton's first law of motion.",
            "answer": "An object at rest stays at rest unless acted upon by a force.",
            "type": "conceptual",
        }
        return [formula_q, concept_q]

    def format_prompt(self, question: Dict) -> str:
        return f"Question: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        # Keep only what follows the last "Answer:" marker, if any.
        marker = "Answer:"
        if marker in text:
            return text.rpartition(marker)[2].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Loose, case-insensitive match: either string may contain the other.
        pred = predicted.lower()
        exp = expected.lower()
        return exp in pred or pred in exp
class MathBenchmark(ScienceBenchmark):
    """Math benchmark (MATH dataset style)."""

    def __init__(self):
        super().__init__("math_benchmark", "math")

    def load_questions(self) -> List[Dict]:
        algebra_q = {
            "question": "Solve for x: 2x + 5 = 15",
            "answer": "x = 5",
            "type": "algebra",
        }
        calculus_q = {
            "question": "What is the derivative of x^2?",
            "answer": "2x",
            "type": "calculus",
        }
        return [algebra_q, calculus_q]

    def format_prompt(self, question: Dict) -> str:
        return f"Problem: {question['question']}\nSolution:"

    def extract_answer(self, text: str) -> str:
        # Keep only what follows the last "Solution:" marker, if any.
        marker = "Solution:"
        if marker in text:
            return text.rpartition(marker)[2].strip()
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Exact match after lowercasing and collapsing all whitespace.
        def canon(s: str) -> str:
            return " ".join(s.lower().split())

        return canon(predicted) == canon(expected)
class ChemistryBenchmark(ScienceBenchmark):
    """Chemistry benchmark."""

    def __init__(self):
        super().__init__("chemistry_benchmark", "chemistry")

    def load_questions(self) -> List[Dict]:
        return [
            {
                "question": "What is the chemical formula for water?",
                "answer": "H2O",
                "type": "factual",
            },
            {
                "question": "How many protons does carbon have?",
                "answer": "6",
                "type": "factual",
            },
        ]

    def format_prompt(self, question: Dict) -> str:
        return f"Chemistry question: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        # Keep only what follows the last "Answer:" marker, if any.
        marker = "Answer:"
        tail = text.rpartition(marker)[2] if marker in text else text
        return tail.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # One-way containment: the expected answer must appear in the prediction.
        return expected.strip().lower() in predicted.strip().lower()
class BiologyBenchmark(ScienceBenchmark):
    """Biology benchmark."""

    def __init__(self):
        super().__init__("biology_benchmark", "biology")

    def load_questions(self) -> List[Dict]:
        return [
            {
                "question": "What is the powerhouse of the cell?",
                "answer": "mitochondria",
                "type": "factual",
            },
            {
                "question": "What molecule carries genetic information?",
                "answer": "DNA",
                "type": "factual",
            },
        ]

    def format_prompt(self, question: Dict) -> str:
        return f"Biology: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        # Keep only what follows the last "Answer:" marker, if any.
        marker = "Answer:"
        tail = text.rpartition(marker)[2] if marker in text else text
        return tail.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # One-way containment: the expected answer must appear in the prediction.
        return expected.strip().lower() in predicted.strip().lower()
# Placeholder for other domains
class EarthScienceBenchmark(ScienceBenchmark):
    """Earth-science benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("earth_science_benchmark", "earth")

    def load_questions(self) -> List[Dict]:
        # No dataset wired up yet; evaluate() handles the empty list.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Earth Science: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Exact match after trimming and lowercasing both sides.
        return predicted.strip().lower() == expected.strip().lower()
class SpaceScienceBenchmark(ScienceBenchmark):
    """Space-science benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("space_science_benchmark", "space")

    def load_questions(self) -> List[Dict]:
        # No dataset wired up yet; evaluate() handles the empty list.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Space Science: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Exact match after trimming and lowercasing both sides.
        return predicted.strip().lower() == expected.strip().lower()
class ZoologyBenchmark(ScienceBenchmark):
    """Zoology benchmark (question set not yet populated)."""

    def __init__(self):
        super().__init__("zoology_benchmark", "zoology")

    def load_questions(self) -> List[Dict]:
        # No dataset wired up yet; evaluate() handles the empty list.
        return []

    def format_prompt(self, question: Dict) -> str:
        return f"Zoology: {question['question']}\nAnswer:"

    def extract_answer(self, text: str) -> str:
        return text.strip()

    def check_answer(self, predicted: str, expected: str) -> bool:
        # Exact match after trimming and lowercasing both sides.
        return predicted.strip().lower() == expected.strip().lower()
def run_all_benchmarks(
    model,
    tokenizer,
    device: torch.device,
    max_samples_per_domain: int = 50,
) -> Dict[str, BenchmarkResult]:
    """
    Run all benchmarks and return results.

    Args:
        model: Vortex model
        tokenizer: Tokenizer
        device: Torch device
        max_samples_per_domain: Max samples per domain

    Returns:
        Dictionary mapping domain to results
    """
    suite = (
        PhysicsBenchmark(),
        MathBenchmark(),
        ChemistryBenchmark(),
        BiologyBenchmark(),
        EarthScienceBenchmark(),
        SpaceScienceBenchmark(),
        ZoologyBenchmark(),
    )
    results: Dict[str, BenchmarkResult] = {}
    for benchmark in suite:
        print(f"Running {benchmark.name}...")
        outcome = benchmark.evaluate(
            model, tokenizer, device, max_samples=max_samples_per_domain
        )
        results[benchmark.domain] = outcome
        print(f" Accuracy: {outcome.accuracy:.2%} ({outcome.correct_answers}/{outcome.total_questions})")
    return results
def print_summary(results: Dict[str, BenchmarkResult]):
    """Print summary of benchmark results."""
    rule = "=" * 60
    print("\n" + rule)
    print("BENCHMARK RESULTS")
    print(rule)
    for domain, result in results.items():
        print(f"{domain:15} {result.accuracy:6.2%} ({result.correct_answers}/{result.total_questions})")
    # Overall average, counting only domains that had any questions.
    scored = [r.accuracy for r in results.values() if r.total_questions > 0]
    if scored:
        avg = sum(scored) / len(scored)
        print(f"{'OVERALL':15} {avg:6.2%}")
    print(rule)
if __name__ == "__main__":
    # Smoke-test entry point: no model is loaded here, just usage hints.
    print("This script benchmarks the model across science domains.")
    print("To run full benchmarks, integrate with a trained model.")