| """ |
| Flint-1.2B Evaluation |
| ====================== |
| |
| Lightweight evaluation during and after training. |
| Supports: perplexity, few-shot benchmarks, tool-use accuracy. |
| |
| Usage: |
| # During training (called automatically every eval_interval steps) |
| python evaluate.py --checkpoint /path/to/step_XXXX --quick |
| |
| # Full evaluation after training |
| python evaluate.py --checkpoint /path/to/step_XXXX --full |
| |
| # Evaluate specific capability |
| python evaluate.py --checkpoint /path/to/step_XXXX --task reasoning |
| python evaluate.py --checkpoint /path/to/step_XXXX --task tool_use |
| """ |
|
|
| import os |
| import json |
| import argparse |
| import time |
| from pathlib import Path |
| from typing import Dict, List, Any |
|
|
| import numpy as np |
| import jax |
| import jax.numpy as jnp |
|
|
|
|
| |
| |
| |
|
|
# Built-in reasoning probes. Every prompt ends with an opened <think> tag;
# "expected_contains" lists substrings, any one of which marks the generated
# text as correct; "category" is a free-form label copied into the results.
REASONING_PROBLEMS = [
    {"prompt": prompt, "expected_contains": answers, "category": category}
    for prompt, answers, category in (
        (
            "What is 247 + 389? Think step by step.\n\n<think>\n",
            ["636"],
            "arithmetic",
        ),
        (
            "If a shirt costs $45 and is 20% off, what do you pay? Think step by step.\n\n<think>\n",
            ["36", "$36"],
            "word_problem",
        ),
        (
            "Complete the pattern: 2, 6, 12, 20, ?\n\n<think>\n",
            ["30"],
            "pattern",
        ),
        (
            "A train travels 120 miles in 2 hours. What is its speed in mph?\n\n<think>\n",
            ["60"],
            "word_problem",
        ),
        (
            "What is the next prime number after 23?\n\n<think>\n",
            ["29"],
            "number_theory",
        ),
    )
]
|
|
# Built-in tool-calling probes. Every prompt ends right after an opened
# <tool_call> tag, so the model is expected to continue with the call payload.
# All entries share the same "expected_format" marker.
TOOL_USE_PROBLEMS = [
    {"prompt": prompt, "expected_format": '{"name"', "category": category}
    for prompt, category in (
        (
            "User: What's the weather in London?\n\n<think>\nI need to check the weather API.\n</think>\n\n<tool_call>\n",
            "weather",
        ),
        (
            "User: Search for recent papers about quantum computing\n\n<think>\nI should use a search function.\n</think>\n\n<tool_call>\n",
            "search",
        ),
        (
            "User: Calculate the compound interest on $1000 at 5% for 3 years\n\n<think>\nI'll use the calculator tool for precision.\n</think>\n\n<tool_call>\n",
            "calculation",
        ),
    )
]
|
|
# Sentence-completion samples: a context, four candidate continuations and an
# "answer" field (presumably the index of the expected continuation within
# "choices" -- confirm against the consumer; nothing in this section reads it).
HELLASWAG_SAMPLES = [
    {"context": context, "choices": list(endings), "answer": answer}
    for context, endings, answer in (
        (
            "A woman is sitting at a table. She picks up a glass of water and",
            (
                " takes a sip from it.",
                " throws it at the wall.",
                " puts it in her pocket.",
                " starts singing to it.",
            ),
            0,
        ),
        (
            "The chef places the raw chicken in the oven and",
            (
                " sets the timer for 45 minutes.",
                " immediately eats it raw.",
                " puts on ice skates.",
                " plants a tree inside the oven.",
            ),
            0,
        ),
    )
]
|
|
|
|
| |
| |
| |
|
|
def compute_perplexity(
    forward_fn,
    params,
    texts: List[str],
    tokenizer,
    max_length: int = 512,
) -> float:
    """Return the pooled token-level perplexity of the model over ``texts``.

    Each text is encoded with
    ``tokenizer.encode(text, max_length=max_length, truncation=True)`` and fed
    to ``forward_fn(params, input_ids)`` as a batch of one. The logits at
    position ``i`` are scored against token ``i + 1``. Negative
    log-likelihoods are summed over all texts and divided by the number of
    scored tokens before exponentiating.

    Texts that encode to fewer than two tokens are skipped. If no token is
    scored at all, ``float('inf')`` is returned.
    """
    nll_sum = 0.0
    scored = 0

    for text in texts:
        token_ids = tokenizer.encode(text, max_length=max_length, truncation=True)
        if len(token_ids) >= 2:
            batch = jnp.array([token_ids], dtype=jnp.int32)
            logits = forward_fn(params, batch)

            # Next-token setup: drop the last position's logits and the first
            # token so predictions and targets line up.
            targets = jnp.array(token_ids[1:])
            n_targets = len(targets)
            log_probs = jax.nn.log_softmax(logits[0, :-1, :], axis=-1)
            picked = log_probs[jnp.arange(n_targets), targets]

            nll_sum += float((-picked).sum())
            scored += n_targets

    if scored == 0:
        return float('inf')

    return np.exp(nll_sum / scored)
|
|
|
|
def evaluate_reasoning(
    generate_fn,
    params,
    tokenizer,
    problems: List[Dict] = None,
    max_new_tokens: int = 256,
) -> Dict[str, Any]:
    """Score generated answers to reasoning prompts with simple text heuristics.

    For every problem the prompt is encoded with
    ``tokenizer.encode(prompt, return_tensors="np")``, passed to
    ``generate_fn(params, input_ids, max_new_tokens=...)`` and the first
    returned sequence is decoded with special tokens kept. Three checks are
    applied to the decoded text:

    1. ``has_think``: the text contains a closing ``</think>`` tag.
    2. ``correct``: the text contains at least one of the problem's
       ``expected_contains`` substrings (plain substring match).
    3. ``has_steps``: the text contains at least two newlines, i.e. spans
       three or more lines.

    Args:
        generate_fn: Callable ``(params, input_ids, max_new_tokens=...)``.
        params: Passed through to ``generate_fn`` untouched.
        tokenizer: Object exposing ``encode`` and ``decode`` as used above.
        problems: Dicts with ``prompt``, ``expected_contains`` and
            ``category`` keys. Defaults to ``REASONING_PROBLEMS``.
        max_new_tokens: Generation budget forwarded to ``generate_fn``.

    Returns:
        Dict with ``total``, ``correct``, ``used_think_tags``,
        ``showed_steps``, per-problem ``details`` and the rates ``accuracy``,
        ``think_rate`` and ``step_rate`` (all 0.0 for an empty problem list).
        Every detail record carries the same keys, including ``has_steps``.
    """
    if problems is None:
        problems = REASONING_PROBLEMS

    details = []
    for problem in problems:
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        generated = tokenizer.decode(output[0], skip_special_tokens=False)

        # NOTE(review): if generate_fn returns the prompt together with the
        # continuation, the prompt's own newlines already satisfy the step
        # heuristic -- confirm what output[0] contains.
        details.append({
            "category": problem["category"],
            "has_think": "</think>" in generated,
            "correct": any(exp in generated for exp in problem["expected_contains"]),
            "length": len(generated),
            # Always recorded (True or False) so every record has one schema.
            "has_steps": generated.count("\n") >= 2,
        })

    total = len(problems)
    results = {
        "total": total,
        "correct": sum(d["correct"] for d in details),
        "used_think_tags": sum(d["has_think"] for d in details),
        "showed_steps": sum(d["has_steps"] for d in details),
        "details": details,
    }

    # Guard the denominator so an empty problem list yields 0.0 rates.
    denominator = max(total, 1)
    results["accuracy"] = results["correct"] / denominator
    results["think_rate"] = results["used_think_tags"] / denominator
    results["step_rate"] = results["showed_steps"] / denominator

    return results
|
|
|
|
def evaluate_tool_use(
    generate_fn,
    params,
    tokenizer,
    problems: List[Dict] = None,
    max_new_tokens: int = 128,
) -> Dict[str, Any]:
    """Score generated tool calls for JSON validity and basic structure.

    For every problem the prompt is encoded with
    ``tokenizer.encode(prompt, return_tensors="np")``, passed to
    ``generate_fn(params, input_ids, max_new_tokens=...)`` and the first
    returned sequence is decoded with special tokens kept. The text before
    the first ``</tool_call>`` (or all of it when the tag is absent) is
    stripped and parsed as JSON. Per problem the following is recorded:

    1. ``valid_json``: the payload parses as JSON.
    2. ``has_name``: the payload is a JSON object with a ``"name"`` key.
    3. ``has_args``: the payload is a JSON object with an ``"arguments"``
       or ``"args"`` key.

    The tool name itself and the problem's ``expected_format`` field are not
    checked.

    Args:
        generate_fn: Callable ``(params, input_ids, max_new_tokens=...)``.
        params: Passed through to ``generate_fn`` untouched.
        tokenizer: Object exposing ``encode`` and ``decode`` as used above.
        problems: Dicts with at least ``prompt`` and ``category`` keys.
            Defaults to ``TOOL_USE_PROBLEMS``.
        max_new_tokens: Generation budget forwarded to ``generate_fn``.

    Returns:
        Dict with ``total``, ``valid_json``, ``has_name``, ``has_args``,
        per-problem ``details``, ``json_rate`` and ``complete_rate``.
        ``complete_rate`` is the fraction of problems whose single call has
        both a name and arguments. Rates are 0.0 for an empty problem list.
    """
    if problems is None:
        problems = TOOL_USE_PROBLEMS

    results = {
        "total": len(problems),
        "valid_json": 0,
        "has_name": 0,
        "has_args": 0,
        "details": [],
    }
    complete_calls = 0

    for problem in problems:
        input_ids = tokenizer.encode(problem["prompt"], return_tensors="np")
        output = generate_fn(params, input_ids, max_new_tokens=max_new_tokens)
        generated = tokenizer.decode(output[0], skip_special_tokens=False)

        # Anything after the closing tag is not part of the call payload.
        payload = generated.split("</tool_call>", 1)[0].strip()

        try:
            tool_call = json.loads(payload)
            parsed = True
        except (json.JSONDecodeError, RecursionError):
            # Malformed or pathologically nested output counts as invalid;
            # unrelated errors are deliberately not swallowed here.
            tool_call = None
            parsed = False

        # Valid JSON that is not an object (number, list, string, null) has
        # no keys to inspect, so it can never carry a name or arguments.
        is_object = parsed and isinstance(tool_call, dict)
        has_name = is_object and "name" in tool_call
        has_args = is_object and ("arguments" in tool_call or "args" in tool_call)

        results["details"].append({
            "category": problem["category"],
            "valid_json": parsed,
            "has_name": has_name,
            "has_args": has_args,
        })
        results["valid_json"] += parsed
        results["has_name"] += has_name
        results["has_args"] += has_args
        # Completeness is judged per call; combining the aggregate counters
        # would credit a name and arguments that came from different calls.
        complete_calls += has_name and has_args

    denominator = max(results["total"], 1)
    results["json_rate"] = results["valid_json"] / denominator
    results["complete_rate"] = complete_calls / denominator

    return results
|
|
|
|
def quick_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
    step: int,
) -> Dict[str, Any]:
    """Run a reduced evaluation pass intended for in-training monitoring.

    Computes perplexity on three fixed sentences, reasoning metrics on the
    first two ``REASONING_PROBLEMS`` and the JSON rate on the first
    ``TOOL_USE_PROBLEMS`` entry, prints a one-line summary and returns the
    metrics together with ``step``, a wall-clock ``timestamp`` and the
    elapsed ``eval_time_seconds``.

    The pass is meant to finish in under 30 seconds; nothing here enforces
    that budget.
    """
    print(f"\n[Eval] Quick evaluation at step {step}...")
    started_at = time.time()
    stamped_at = time.time()

    probe_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "In mathematics, a prime number is a natural number greater than 1.",
        "To solve this equation, we first isolate the variable on one side.",
    ]
    ppl = compute_perplexity(forward_fn, params, probe_texts, tokenizer)

    # Small slices of the built-in problem sets keep the pass short.
    reasoning = evaluate_reasoning(generate_fn, params, tokenizer, REASONING_PROBLEMS[:2])
    tool_use = evaluate_tool_use(generate_fn, params, tokenizer, TOOL_USE_PROBLEMS[:1])

    elapsed = time.time() - started_at

    summary = (
        f"[Eval] PPL={ppl:.2f} | Reasoning={reasoning['accuracy']:.0%} | "
        f"Think={reasoning['think_rate']:.0%} | Tools={tool_use['json_rate']:.0%} | "
        f"Time={elapsed:.1f}s"
    )
    print(summary)

    return {
        "step": step,
        "timestamp": stamped_at,
        "perplexity": ppl,
        "reasoning_accuracy": reasoning["accuracy"],
        "reasoning_think_rate": reasoning["think_rate"],
        "tool_use_json_rate": tool_use["json_rate"],
        "eval_time_seconds": elapsed,
    }
|
|
|
|
def full_eval(
    forward_fn,
    generate_fn,
    params,
    tokenizer,
) -> Dict[str, Any]:
    """Run the complete reasoning and tool-use suites and print a report.

    Both suites run on their default problem sets. The returned dict maps
    ``"reasoning"`` and ``"tool_use"`` to the respective result dicts.
    ``forward_fn`` is accepted but not used by this function.
    """
    print("\n[Eval] Running full evaluation...")

    print("[Eval] Reasoning...")
    reasoning = evaluate_reasoning(generate_fn, params, tokenizer)

    print("[Eval] Tool use...")
    tool_use = evaluate_tool_use(generate_fn, params, tokenizer)

    results = {"reasoning": reasoning, "tool_use": tool_use}

    rule = "=" * 50
    report_lines = (
        "\n" + rule,
        " FLINT-1.2B EVALUATION RESULTS",
        rule,
        f" Reasoning accuracy: {reasoning['accuracy']:.1%}",
        f" Think tag usage: {reasoning['think_rate']:.1%}",
        f" Step-by-step rate: {reasoning['step_rate']:.1%}",
        f" Tool JSON valid: {tool_use['json_rate']:.1%}",
        f" Tool complete rate: {tool_use['complete_rate']:.1%}",
        rule,
    )
    for line in report_lines:
        print(line)

    return results
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point. Arguments are parsed and the checkpoint path
    # is echoed, but standalone checkpoint loading is not implemented yet, so
    # no evaluation is actually run from here.
    parser = argparse.ArgumentParser()
    parser.add_argument("--checkpoint", required=True)
    for mode_flag in ("--quick", "--full"):
        parser.add_argument(mode_flag, action="store_true")
    parser.add_argument(
        "--task",
        choices=["reasoning", "tool_use", "perplexity"],
    )
    args = parser.parse_args()

    for message in (
        f"[Eval] Loading checkpoint: {args.checkpoint}",
        "[Eval] TODO: Implement checkpoint loading for standalone eval",
    ):
        print(message)
|
|