"""
Flint-1.2B Evaluation
======================
Lightweight evaluation during and after training.
Supports: perplexity, few-shot benchmarks, tool-use accuracy.

Usage:
    # During training (called automatically every eval_interval steps)
    python evaluate.py --checkpoint /path/to/step_XXXX --quick

    # Full evaluation after training
    python evaluate.py --checkpoint /path/to/step_XXXX --full

    # Evaluate specific capability
    python evaluate.py --checkpoint /path/to/step_XXXX --task reasoning
    python evaluate.py --checkpoint /path/to/step_XXXX --task tool_use
"""

import os
import json
import argparse
import time
from pathlib import Path
from typing import Dict, List, Any, Optional

import numpy as np
import jax
import jax.numpy as jnp


# ============================================================
# EVALUATION DATASETS (lightweight, included)
# ============================================================

# Marker whose presence in a generation counts as "used think tags".
# NOTE(review): the literal was an empty string in the reviewed copy of this
# file (so the check was always true); "<think>" is inferred from the
# `has_think` / `used_think_tags` naming -- confirm against the training format.
THINK_OPEN_TAG = "<think>"

# NOTE(review): the prompts below contain runs of bare newlines where markup
# (e.g. think / tool-call tags) may have been lost; they are kept verbatim.
REASONING_PROBLEMS = [
    {
        "prompt": "What is 247 + 389? Think step by step.\n\n\n",
        "expected_contains": ["636"],
        "category": "arithmetic"
    },
    {
        "prompt": "If a shirt costs $45 and is 20% off, what do you pay? Think step by step.\n\n\n",
        "expected_contains": ["36", "$36"],
        "category": "word_problem"
    },
    {
        "prompt": "Complete the pattern: 2, 6, 12, 20, ?\n\n\n",
        "expected_contains": ["30"],
        "category": "pattern"
    },
    {
        "prompt": "A train travels 120 miles in 2 hours. What is its speed in mph?\n\n\n",
        "expected_contains": ["60"],
        "category": "word_problem"
    },
    {
        "prompt": "What is the next prime number after 23?\n\n\n",
        "expected_contains": ["29"],
        "category": "number_theory"
    },
]

TOOL_USE_PROBLEMS = [
    {
        "prompt": "User: What's the weather in London?\n\n\nI need to check the weather API.\n\n\n\n",
        "expected_format": '{"name"',  # Should produce valid JSON tool call
        "category": "weather"
    },
    {
        "prompt": "User: Search for recent papers about quantum computing\n\n\nI should use a search function.\n\n\n\n",
        "expected_format": '{"name"',
        "category": "search"
    },
    {
        "prompt": "User: Calculate the compound interest on $1000 at 5% for 3 years\n\n\nI'll use the calculator tool for precision.\n\n\n\n",
        "expected_format": '{"name"',
        "category": "calculation"
    },
]

# Not referenced by any evaluation function in this module.
HELLASWAG_SAMPLES = [
    {
        "context": "A woman is sitting at a table. She picks up a glass of water and",
        "choices": [
            " takes a sip from it.",
            " throws it at the wall.",
            " puts it in her pocket.",
            " starts singing to it.",
        ],
        "answer": 0,
    },
    {
        "context": "The chef places the raw chicken in the oven and",
        "choices": [
            " sets the timer for 45 minutes.",
            " immediately eats it raw.",
            " puts on ice skates.",
            " plants a tree inside the oven.",
        ],
        "answer": 0,
    },
]


# ============================================================
# EVALUATION FUNCTIONS
# ============================================================

def compute_perplexity(
    forward_fn,
    params,
    texts: List[str],
    tokenizer,
    max_length: int = 512,
) -> float:
    """Compute token-level perplexity over a list of texts.

    Each text is encoded with ``tokenizer.encode(text, max_length=...,
    truncation=True)``, run through ``forward_fn(params, input_ids)`` as a
    batch of one, and scored with next-token cross-entropy. Losses are
    summed over all texts and divided by the total number of predicted
    tokens before exponentiating.

    Args:
        forward_fn: Callable ``(params, input_ids) -> logits``; the result is
            indexed as ``logits[0, :-1, :]``.
        params: Model parameters, passed through to ``forward_fn`` untouched.
        texts: Texts to score. Texts encoding to fewer than 2 tokens are
            skipped because they contain no next-token prediction.
        tokenizer: Object exposing the ``encode`` call described above.
        max_length: Truncation length forwarded to the tokenizer.

    Returns:
        The perplexity as a Python float, or ``inf`` when no text
        contributed any tokens.
    """
    total_loss = 0.0
    total_tokens = 0

    for text in texts:
        tokens = tokenizer.encode(text, max_length=max_length, truncation=True)
        if len(tokens) < 2:
            continue

        input_ids = jnp.array([tokens], dtype=jnp.int32)
        logits = forward_fn(params, input_ids)

        # Compute cross-entropy: position i predicts token i + 1.
        shift_logits = logits[0, :-1, :]
        shift_labels = jnp.array(tokens[1:])
        log_probs = jax.nn.log_softmax(shift_logits, axis=-1)
        token_losses = -log_probs[jnp.arange(len(shift_labels)), shift_labels]

        total_loss += float(token_losses.sum())
        total_tokens += len(shift_labels)

    if total_tokens == 0:
        return float('inf')

    avg_loss = total_loss / total_tokens
    # float() so callers get a plain Python float rather than np.float64.
    return float(np.exp(avg_loss))


def evaluate_reasoning(
    generate_fn,
    params,
    tokenizer,
    problems: Optional[List[Dict]] = None,
    max_new_tokens: int = 256,
) -> Dict[str, Any]:
    """
    Evaluate reasoning capability.

    Checks if the model:
    1. Uses think tags (``THINK_OPEN_TAG`` appears in the decoded output)
    2. Produces correct answers (any ``expected_contains`` string appears
       as a substring of the decoded output)
    3. Shows step-by-step reasoning (decoded output spans more than two lines)

    Args:
        generate_fn: Callable invoked as
            ``generate_fn(params, input_ids, max_new_tokens=...)``; element
            ``[0]`` of its result is decoded.
        params: Model parameters, passed through to ``generate_fn``.
        tokenizer: Object exposing ``encode(prompt, return_tensors="np")``
            and ``decode(ids, skip_special_tokens=False)``.
        problems: Problem dicts with ``prompt``, ``expected_contains`` and
            ``category`` keys. Defaults to ``REASONING_PROBLEMS``.
        max_new_tokens: Generation budget forwarded to ``generate_fn``.

    Returns:
        Dict with raw counts (``total``, ``correct``, ``used_think_tags``,
        ``showed_steps``), per-problem ``details`` and the derived
        ``accuracy``, ``think_rate`` and ``step_rate`` fractions.
    """
    if problems is None:
        problems = REASONING_PROBLEMS

    results = {
        "total": len(problems),
        "correct": 0,
        "used_think_tags": 0,
        "showed_steps": 0,
        "details": [],
    }

    for problem in problems:
        # Generate
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        # NOTE(review): if generate_fn echoes the prompt, the checks below
        # also see the prompt text -- confirm against the generate_fn contract.
        generated = tokenizer.decode(output[0], skip_special_tokens=False)

        # Check metrics
        detail = {
            "category": problem["category"],
            "has_think": THINK_OPEN_TAG in generated,
            "correct": any(exp in generated for exp in problem["expected_contains"]),
            "length": len(generated),
            # More than two lines <=> at least two newline characters.
            "has_steps": generated.count("\n") >= 2,
        }

        if detail["has_think"]:
            results["used_think_tags"] += 1
        if detail["correct"]:
            results["correct"] += 1
        if detail["has_steps"]:
            results["showed_steps"] += 1

        results["details"].append(detail)

    denominator = max(results["total"], 1)  # avoid division by zero on empty input
    results["accuracy"] = results["correct"] / denominator
    results["think_rate"] = results["used_think_tags"] / denominator
    results["step_rate"] = results["showed_steps"] / denominator

    return results


def _extract_first_json_object(text: str) -> Optional[Dict[str, Any]]:
    """Return the first JSON object embedded in *text*, or None if there is none.

    Every ``{`` is tried as a potential start so that surrounding prose or
    markup (before or after the object) does not prevent parsing.
    """
    decoder = json.JSONDecoder()
    start = text.find("{")
    while start != -1:
        try:
            candidate, _ = decoder.raw_decode(text[start:])
        except json.JSONDecodeError:
            pass
        else:
            if isinstance(candidate, dict):
                return candidate
        start = text.find("{", start + 1)
    return None


def evaluate_tool_use(
    generate_fn,
    params,
    tokenizer,
    problems: Optional[List[Dict]] = None,
    max_new_tokens: int = 128,
) -> Dict[str, Any]:
    """
    Evaluate tool-calling capability.

    Checks if the model:
    1. Produces valid JSON tool calls (a JSON object can be parsed out of
       the decoded output)
    2. Uses appropriate tool names (the object has a ``"name"`` key)
    3. Includes required arguments (the object has an ``"arguments"`` or
       ``"args"`` key)

    Args:
        generate_fn: Callable invoked as
            ``generate_fn(params, input_ids, max_new_tokens=...)``; element
            ``[0]`` of its result is decoded.
        params: Model parameters, passed through to ``generate_fn``.
        tokenizer: Object exposing ``encode(prompt, return_tensors="np")``
            and ``decode(ids, skip_special_tokens=False)``.
        problems: Problem dicts with ``prompt`` and ``category`` keys.
            Defaults to ``TOOL_USE_PROBLEMS``.
        max_new_tokens: Generation budget forwarded to ``generate_fn``.

    Returns:
        Dict with raw counts (``total``, ``valid_json``, ``has_name``,
        ``has_args``, ``complete``), per-problem ``details`` and the derived
        ``json_rate`` and ``complete_rate`` fractions. ``complete`` counts
        calls where the *same* object carries both a name and arguments.
    """
    if problems is None:
        problems = TOOL_USE_PROBLEMS

    results = {
        "total": len(problems),
        "valid_json": 0,
        "has_name": 0,
        "has_args": 0,
        "complete": 0,
        "details": [],
    }

    for problem in problems:
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        generated = tokenizer.decode(output[0], skip_special_tokens=False)

        # Pull the JSON object out of the generation; any text or tags
        # surrounding it are ignored.
        tool_call = _extract_first_json_object(generated)
        has_name = tool_call is not None and "name" in tool_call
        has_args = tool_call is not None and (
            "arguments" in tool_call or "args" in tool_call
        )

        detail = {
            "category": problem["category"],
            "valid_json": tool_call is not None,
            "has_name": has_name,
            "has_args": has_args,
        }

        results["valid_json"] += int(detail["valid_json"])
        results["has_name"] += int(has_name)
        results["has_args"] += int(has_args)
        # Count completeness per call; min(has_name, has_args) over the whole
        # run would over-count when different calls satisfy each condition.
        results["complete"] += int(has_name and has_args)

        results["details"].append(detail)

    denominator = max(results["total"], 1)  # avoid division by zero on empty input
    results["json_rate"] = results["valid_json"] / denominator
    results["complete_rate"] = results["complete"] / denominator

    return results


def quick_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
    step: int,
) -> Dict[str, Any]:
    """
    Quick evaluation for during-training monitoring.

    Runs perplexity on three fixed sentences, the first two reasoning
    problems and the first tool-use problem, then prints a one-line summary.

    Args:
        forward_fn: Passed to ``compute_perplexity``.
        generate_fn: Passed to ``evaluate_reasoning`` / ``evaluate_tool_use``.
        params: Model parameters, passed through to both callables.
        tokenizer: Tokenizer shared by all three evaluations.
        step: Training step, recorded in the result and the log line.

    Returns:
        Flat dict with ``step``, ``timestamp``, ``perplexity``,
        ``reasoning_accuracy``, ``reasoning_think_rate``,
        ``tool_use_json_rate`` and ``eval_time_seconds``.
    """
    print(f"\n[Eval] Quick evaluation at step {step}...")
    # Monotonic clock for the duration; wall clock only for the timestamp.
    start = time.perf_counter()

    results = {
        "step": step,
        "timestamp": time.time(),
    }

    # Perplexity on a few samples
    eval_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In mathematics, a prime number is a natural number greater than 1.",
        "To solve this equation, we first isolate the variable on one side.",
    ]
    ppl = compute_perplexity(forward_fn, params, eval_texts, tokenizer)
    results["perplexity"] = ppl

    # Quick reasoning check (just 2 problems)
    reasoning = evaluate_reasoning(generate_fn, params, tokenizer, REASONING_PROBLEMS[:2])
    results["reasoning_accuracy"] = reasoning["accuracy"]
    results["reasoning_think_rate"] = reasoning["think_rate"]

    # Quick tool use check (just 1 problem)
    tool_use = evaluate_tool_use(generate_fn, params, tokenizer, TOOL_USE_PROBLEMS[:1])
    results["tool_use_json_rate"] = tool_use["json_rate"]

    elapsed = time.perf_counter() - start
    results["eval_time_seconds"] = elapsed

    print(f"[Eval] PPL={ppl:.2f} | Reasoning={reasoning['accuracy']:.0%} | "
          f"Think={reasoning['think_rate']:.0%} | Tools={tool_use['json_rate']:.0%} | "
          f"Time={elapsed:.1f}s")

    return results


def full_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
) -> Dict[str, Any]:
    """Full evaluation after training.

    Runs every reasoning and tool-use problem and prints a summary table.
    ``forward_fn`` is accepted for signature parity with ``quick_eval`` but
    is not used here (no perplexity pass is run).

    Returns:
        Dict with ``reasoning`` and ``tool_use`` keys holding the result
        dicts of ``evaluate_reasoning`` and ``evaluate_tool_use``.
    """
    print("\n[Eval] Running full evaluation...")
    results = {}

    # Reasoning (all problems)
    print("[Eval] Reasoning...")
    results["reasoning"] = evaluate_reasoning(generate_fn, params, tokenizer)

    # Tool use (all problems)
    print("[Eval] Tool use...")
    results["tool_use"] = evaluate_tool_use(generate_fn, params, tokenizer)

    # Summary
    print("\n" + "=" * 50)
    print(" FLINT-1.2B EVALUATION RESULTS")
    print("=" * 50)
    print(f" Reasoning accuracy: {results['reasoning']['accuracy']:.1%}")
    print(f" Think tag usage: {results['reasoning']['think_rate']:.1%}")
    print(f" Step-by-step rate: {results['reasoning']['step_rate']:.1%}")
    print(f" Tool JSON valid: {results['tool_use']['json_rate']:.1%}")
    print(f" Tool complete rate: {results['tool_use']['complete_rate']:.1%}")
    print("=" * 50)

    return results


def main() -> None:
    """Parse CLI arguments; checkpoint loading itself is still a TODO stub."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    parser.add_argument("--quick", action="store_true")
    parser.add_argument("--full", action="store_true")
    parser.add_argument("--task", choices=["reasoning", "tool_use", "perplexity"])
    args = parser.parse_args()

    print(f"[Eval] Loading checkpoint: {args.checkpoint}")
    # In production: load checkpoint and create forward/generate functions
    print("[Eval] TODO: Implement checkpoint loading for standalone eval")


if __name__ == "__main__":
    main()