flint-1.2B / evaluate.py
tekkmaven's picture
Add evaluation script for benchmarking
945e942 verified
"""
Flint-1.2B Evaluation
======================

Lightweight evaluation during and after training.
Supports: perplexity, few-shot benchmarks, tool-use accuracy.

Usage:
    # During training (called automatically every eval_interval steps)
    python evaluate.py --checkpoint /path/to/step_XXXX --quick

    # Full evaluation after training
    python evaluate.py --checkpoint /path/to/step_XXXX --full

    # Evaluate a specific capability
    python evaluate.py --checkpoint /path/to/step_XXXX --task reasoning
    python evaluate.py --checkpoint /path/to/step_XXXX --task tool_use

Note: standalone checkpoint loading is not implemented yet, so running this
file directly only parses the arguments and prints a TODO message.
"""
import os
import json
import argparse
import time
from pathlib import Path
from typing import Dict, List, Any
import numpy as np
import jax
import jax.numpy as jnp
# ============================================================
# EVALUATION DATASETS (lightweight, included)
# ============================================================
# Built-in reasoning prompts used by evaluate_reasoning when no problem list is
# passed. Every prompt ends with an opening <think> tag. "expected_contains"
# lists the answer strings looked for in the generated text, and "category" is
# copied into the per-problem details.
REASONING_PROBLEMS = [
    {
        "prompt": "What is 247 + 389? Think step by step.\n\n<think>\n",
        "expected_contains": ["636"],
        "category": "arithmetic",
    },
    {
        "prompt": "If a shirt costs $45 and is 20% off, what do you pay? Think step by step.\n\n<think>\n",
        "expected_contains": ["36", "$36"],
        "category": "word_problem",
    },
    {
        "prompt": "Complete the pattern: 2, 6, 12, 20, ?\n\n<think>\n",
        "expected_contains": ["30"],
        "category": "pattern",
    },
    {
        "prompt": "A train travels 120 miles in 2 hours. What is its speed in mph?\n\n<think>\n",
        "expected_contains": ["60"],
        "category": "word_problem",
    },
    {
        "prompt": "What is the next prime number after 23?\n\n<think>\n",
        "expected_contains": ["29"],
        "category": "number_theory",
    },
]
# Built-in tool-calling prompts used by evaluate_tool_use when no problem list
# is passed. Each prompt already contains a closed <think> block and ends right
# after an opening <tool_call> tag, so the model is expected to continue with
# the call itself.
# NOTE(review): "expected_format" is not read by evaluate_tool_use as written in
# this file; it only documents the intended shape of the output.
TOOL_USE_PROBLEMS = [
    {
        "prompt": "User: What's the weather in London?\n\n<think>\nI need to check the weather API.\n</think>\n\n<tool_call>\n",
        "expected_format": '{"name"',  # Should produce valid JSON tool call
        "category": "weather",
    },
    {
        "prompt": "User: Search for recent papers about quantum computing\n\n<think>\nI should use a search function.\n</think>\n\n<tool_call>\n",
        "expected_format": '{"name"',
        "category": "search",
    },
    {
        "prompt": "User: Calculate the compound interest on $1000 at 5% for 3 years\n\n<think>\nI'll use the calculator tool for precision.\n</think>\n\n<tool_call>\n",
        "expected_format": '{"name"',
        "category": "calculation",
    },
]
# Sentence-completion samples: a context, four candidate endings (each with a
# leading space so it can be appended to the context directly), and the index
# of the correct ending.
# NOTE(review): none of the evaluation functions visible in this file reference
# these samples.
HELLASWAG_SAMPLES = [
    {
        "context": "A woman is sitting at a table. She picks up a glass of water and",
        "choices": [
            " takes a sip from it.",
            " throws it at the wall.",
            " puts it in her pocket.",
            " starts singing to it.",
        ],
        "answer": 0,
    },
    {
        "context": "The chef places the raw chicken in the oven and",
        "choices": [
            " sets the timer for 45 minutes.",
            " immediately eats it raw.",
            " puts on ice skates.",
            " plants a tree inside the oven.",
        ],
        "answer": 0,
    },
]
# ============================================================
# EVALUATION FUNCTIONS
# ============================================================
def compute_perplexity(
    forward_fn,
    params,
    texts: List[str],
    tokenizer,
    max_length: int = 512,
) -> float:
    """Return the token-level perplexity of the model over ``texts``.

    Each text is encoded with ``tokenizer.encode`` (truncated to
    ``max_length``), passed to ``forward_fn(params, input_ids)`` as a batch of
    one, and scored with next-token cross-entropy. Losses are pooled across all
    texts before exponentiating, so longer texts carry more weight.

    Args:
        forward_fn: Callable returning logits indexed as
            ``[batch, position, vocab]`` for a ``(1, seq_len)`` int32 array.
        params: Model parameters, passed through to ``forward_fn``.
        texts: Texts to score.
        tokenizer: Object whose ``encode`` returns a sequence of token ids.
        max_length: Maximum number of tokens kept per text.

    Returns:
        ``exp`` of the mean per-token negative log-likelihood, or
        ``float('inf')`` when no text yields at least two tokens.
    """
    nll_sum = 0.0
    n_scored = 0
    for text in texts:
        token_ids = tokenizer.encode(text, max_length=max_length, truncation=True)
        # A single token has no "next token" to predict, so it cannot be scored.
        if len(token_ids) >= 2:
            batch = jnp.array([token_ids], dtype=jnp.int32)
            all_logits = forward_fn(params, batch)
            # Position t predicts token t + 1: drop the final logit row and
            # the first token so predictions and targets line up.
            next_token_logits = all_logits[0, :-1, :]
            targets = jnp.array(token_ids[1:])
            log_probs = jax.nn.log_softmax(next_token_logits, axis=-1)
            positions = jnp.arange(len(targets))
            per_token_nll = -log_probs[positions, targets]
            nll_sum += float(per_token_nll.sum())
            n_scored += len(targets)
    if n_scored == 0:
        return float('inf')
    return np.exp(nll_sum / n_scored)
def _contains_answer(text: str, answer: str) -> bool:
    """Return True if ``answer`` occurs in ``text`` without being part of a longer number.

    A plain substring test would accept "36" inside "360" or "136". An
    occurrence is rejected only when the answer starts (or ends) with a digit
    and is directly preceded (or followed) by another digit. Non-numeric
    answers therefore behave exactly like ``answer in text``.
    """
    start = text.find(answer)
    while start != -1:
        end = start + len(answer)
        glued_before = (
            answer[:1].isdigit() and start > 0 and text[start - 1].isdigit()
        )
        glued_after = (
            answer[-1:].isdigit() and end < len(text) and text[end].isdigit()
        )
        if not (glued_before or glued_after):
            return True
        start = text.find(answer, start + 1)
    return False


def evaluate_reasoning(
    generate_fn,
    params,
    tokenizer,
    problems: List[Dict] = None,
    max_new_tokens: int = 256,
) -> Dict[str, Any]:
    """Evaluate reasoning capability on a list of prompt/answer problems.

    For every problem the prompt is encoded, passed to ``generate_fn`` and the
    first returned sequence is decoded (special tokens kept). The decoded text
    is then checked for:

    1. a closing ``</think>`` tag (``has_think``),
    2. any of the problem's ``expected_contains`` strings (``correct``);
       numeric answers must not be embedded in a longer number,
    3. at least two newlines, used as a proxy for step-by-step output
       (``has_steps``).

    NOTE(review): if ``generate_fn`` returns the prompt followed by the new
    tokens, the prompt text takes part in these checks as well -- confirm
    against the generate implementation.

    Args:
        generate_fn: Called as
            ``generate_fn(params, input_ids, max_new_tokens=...)``; element
            ``[0]`` of the result is decoded.
        params: Model parameters, passed through to ``generate_fn``.
        tokenizer: Provides ``encode(text, return_tensors="np")`` and
            ``decode(ids, skip_special_tokens=False)``.
        problems: Dicts with ``prompt``, ``expected_contains`` and
            ``category``. Defaults to ``REASONING_PROBLEMS``.
        max_new_tokens: Generation budget per problem.

    Returns:
        Dict with the counters ``total``, ``correct``, ``used_think_tags`` and
        ``showed_steps``, the rates ``accuracy``, ``think_rate`` and
        ``step_rate`` (0.0 for an empty problem list), and ``details``: one
        dict per problem with ``category``, ``has_think``, ``correct``,
        ``length`` and ``has_steps`` always present.
    """
    if problems is None:
        problems = REASONING_PROBLEMS
    results = {
        "total": len(problems),
        "correct": 0,
        "used_think_tags": 0,
        "showed_steps": 0,
        "details": [],
    }
    for problem in problems:
        # Generate
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        generated = tokenizer.decode(output[0], skip_special_tokens=False)
        # Every key is always present so consumers see a uniform schema.
        detail = {
            "category": problem["category"],
            "has_think": "</think>" in generated,
            "correct": any(
                _contains_answer(generated, exp)
                for exp in problem["expected_contains"]
            ),
            "length": len(generated),
            # Two or more newlines means the text spans at least three lines.
            "has_steps": generated.count("\n") >= 2,
        }
        results["used_think_tags"] += int(detail["has_think"])
        results["correct"] += int(detail["correct"])
        results["showed_steps"] += int(detail["has_steps"])
        results["details"].append(detail)
    # max(..., 1) keeps the rates defined (0.0) for an empty problem list.
    denominator = max(results["total"], 1)
    results["accuracy"] = results["correct"] / denominator
    results["think_rate"] = results["used_think_tags"] / denominator
    results["step_rate"] = results["showed_steps"] / denominator
    return results
def _extract_tool_call_text(generated: str) -> str:
    """Return the candidate tool-call payload from the generated text.

    Keeps everything before the first ``</tool_call>`` and, within that, only
    what follows the last ``<tool_call>``. Either tag may be absent, in which
    case that side is left untrimmed. Surrounding whitespace is stripped.
    """
    before_close = generated.split("</tool_call>", 1)[0]
    # rpartition yields the whole string as its last item when the tag is absent.
    return before_close.rpartition("<tool_call>")[2].strip()


def evaluate_tool_use(
    generate_fn,
    params,
    tokenizer,
    problems: List[Dict] = None,
    max_new_tokens: int = 128,
) -> Dict[str, Any]:
    """Evaluate tool-calling capability on a list of prompts.

    For every problem the prompt is encoded, passed to ``generate_fn`` and the
    first returned sequence is decoded (special tokens kept). The payload
    between the ``<tool_call>`` tags (or the whole text when the tags are
    missing) is then checked for:

    1. ``valid_json``: the payload parses as JSON,
    2. ``has_name``: it is a JSON object with a ``"name"`` key,
    3. ``has_args``: it is a JSON object with an ``"arguments"`` or ``"args"``
       key.

    Whether the tool name is appropriate for the prompt is not checked.

    Args:
        generate_fn: Called as
            ``generate_fn(params, input_ids, max_new_tokens=...)``; element
            ``[0]`` of the result is decoded.
        params: Model parameters, passed through to ``generate_fn``.
        tokenizer: Provides ``encode(text, return_tensors="np")`` and
            ``decode(ids, skip_special_tokens=False)``.
        problems: Dicts with ``prompt`` and ``category``. Defaults to
            ``TOOL_USE_PROBLEMS``.
        max_new_tokens: Generation budget per problem.

    Returns:
        Dict with the counters ``total``, ``valid_json``, ``has_name``,
        ``has_args`` and ``complete`` (problems with both a name and
        arguments), the rates ``json_rate`` and ``complete_rate`` (0.0 for an
        empty problem list), and ``details``: one dict per problem with
        ``category``, ``valid_json``, ``has_name`` and ``has_args`` always
        present.
    """
    if problems is None:
        problems = TOOL_USE_PROBLEMS
    results = {
        "total": len(problems),
        "valid_json": 0,
        "has_name": 0,
        "has_args": 0,
        "complete": 0,
        "details": [],
    }
    for problem in problems:
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        generated = tokenizer.decode(output[0], skip_special_tokens=False)
        detail = {
            "category": problem["category"],
            "valid_json": False,
            "has_name": False,
            "has_args": False,
        }
        try:
            tool_call = json.loads(_extract_tool_call_text(generated))
        except (ValueError, RecursionError):
            # json.JSONDecodeError is a ValueError; malformed output is an
            # expected outcome here, not an error, so it is just scored 0.
            pass
        else:
            detail["valid_json"] = True
            # Only a JSON object can carry a name/arguments. Scalars, strings
            # and lists are valid JSON but not a usable tool call.
            if isinstance(tool_call, dict):
                detail["has_name"] = "name" in tool_call
                detail["has_args"] = (
                    "arguments" in tool_call or "args" in tool_call
                )
        results["valid_json"] += int(detail["valid_json"])
        results["has_name"] += int(detail["has_name"])
        results["has_args"] += int(detail["has_args"])
        # Count completeness per problem; min() over the two totals would
        # credit a name from one problem and arguments from another.
        results["complete"] += int(detail["has_name"] and detail["has_args"])
        results["details"].append(detail)
    # max(..., 1) keeps the rates defined (0.0) for an empty problem list.
    denominator = max(results["total"], 1)
    results["json_rate"] = results["valid_json"] / denominator
    results["complete_rate"] = results["complete"] / denominator
    return results
def quick_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
    step: int,
) -> Dict[str, Any]:
    """Run a small evaluation pass for monitoring during training.

    Scores perplexity on three fixed sentences, the first two entries of
    ``REASONING_PROBLEMS`` and the first entry of ``TOOL_USE_PROBLEMS``, then
    prints a one-line summary. It is intended to finish in under 30 seconds,
    but nothing here enforces that.

    Returns:
        Dict with ``step``, ``timestamp`` (wall-clock time at the start),
        ``perplexity``, ``reasoning_accuracy``, ``reasoning_think_rate``,
        ``tool_use_json_rate`` and ``eval_time_seconds``.
    """
    print(f"\n[Eval] Quick evaluation at step {step}...")
    start = time.time()
    timestamp = time.time()

    # Perplexity on a few samples
    probe_sentences = [
        "The quick brown fox jumps over the lazy dog.",
        "In mathematics, a prime number is a natural number greater than 1.",
        "To solve this equation, we first isolate the variable on one side.",
    ]
    ppl = compute_perplexity(forward_fn, params, probe_sentences, tokenizer)

    # Quick reasoning check (just 2 problems)
    reasoning = evaluate_reasoning(
        generate_fn, params, tokenizer, REASONING_PROBLEMS[:2]
    )

    # Quick tool use check (just 1 problem)
    tool_use = evaluate_tool_use(
        generate_fn, params, tokenizer, TOOL_USE_PROBLEMS[:1]
    )

    elapsed = time.time() - start
    print(f"[Eval] PPL={ppl:.2f} | Reasoning={reasoning['accuracy']:.0%} | "
          f"Think={reasoning['think_rate']:.0%} | Tools={tool_use['json_rate']:.0%} | "
          f"Time={elapsed:.1f}s")
    return {
        "step": step,
        "timestamp": timestamp,
        "perplexity": ppl,
        "reasoning_accuracy": reasoning["accuracy"],
        "reasoning_think_rate": reasoning["think_rate"],
        "tool_use_json_rate": tool_use["json_rate"],
        "eval_time_seconds": elapsed,
    }
def full_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
) -> Dict[str, Any]:
    """Run the complete reasoning and tool-use suites and print a summary.

    Both suites use their default problem lists. ``forward_fn`` is accepted
    for signature parity with ``quick_eval`` but is not used here, so no
    perplexity is computed.

    Returns:
        Dict with the full result dicts under ``"reasoning"`` and
        ``"tool_use"``.
    """
    print("\n[Eval] Running full evaluation...")

    # Reasoning (all problems)
    print("[Eval] Reasoning...")
    reasoning_report = evaluate_reasoning(generate_fn, params, tokenizer)

    # Tool use (all problems)
    print("[Eval] Tool use...")
    tool_report = evaluate_tool_use(generate_fn, params, tokenizer)

    results = {"reasoning": reasoning_report, "tool_use": tool_report}

    # Summary
    rule = "=" * 50
    print("\n" + rule)
    print(" FLINT-1.2B EVALUATION RESULTS")
    print(rule)
    print(f" Reasoning accuracy: {results['reasoning']['accuracy']:.1%}")
    print(f" Think tag usage: {results['reasoning']['think_rate']:.1%}")
    print(f" Step-by-step rate: {results['reasoning']['step_rate']:.1%}")
    print(f" Tool JSON valid: {results['tool_use']['json_rate']:.1%}")
    print(f" Tool complete rate: {results['tool_use']['complete_rate']:.1%}")
    print(rule)
    return results
if __name__ == "__main__":
    # Command-line entry point. The flags are parsed but not acted on yet:
    # standalone checkpoint loading is still a TODO, so nothing is evaluated.
    cli = argparse.ArgumentParser()
    cli.add_argument("--checkpoint", required=True)
    for mode_flag in ("--quick", "--full"):
        cli.add_argument(mode_flag, action="store_true")
    cli.add_argument("--task", choices=["reasoning", "tool_use", "perplexity"])
    args = cli.parse_args()

    print(f"[Eval] Loading checkpoint: {args.checkpoint}")
    # In production: load checkpoint and create forward/generate functions
    print("[Eval] TODO: Implement checkpoint loading for standalone eval")