#!/usr/bin/env python3
"""
Evaluate a fine-tuned Gemma 4 model.

Usage:
    python scripts/evaluate.py --model checkpoints/finetuned/lora_adapter \
        --eval-data data/processed/train_eval.jsonl
"""

import argparse
import json
import sys
import time

# NOTE: ``torch`` and ``unsloth`` are imported lazily inside the functions that
# need them, so ``--help``, argument errors, and importing the pure-Python
# helpers below do not require the GPU stack to be loaded.


def parse_args(argv=None):
    """Parse command-line options.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behavior.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="Evaluate fine-tuned Gemma 4")
    parser.add_argument("--model", type=str, required=True,
                        help="Path to fine-tuned LoRA adapter or model name")
    parser.add_argument("--base-model", type=str, default=None,
                        help="Base model name (if loading LoRA adapter separately)")
    parser.add_argument("--eval-data", type=str, required=True,
                        help="Path to evaluation JSONL file")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--max-new-tokens", type=int, default=512)
    parser.add_argument("--max-seq-length", type=int, default=2048)
    parser.add_argument("--temperature", type=float, default=0.0,
                        help="0.0 for greedy (deterministic)")
    return parser.parse_args(argv)


def load_eval_data(path, max_samples):
    """Load evaluation examples from a JSONL file.

    Only JSON objects that carry a ``"messages"`` list with at least two
    entries are kept; blank lines and other records are skipped.

    Args:
        path: Path to the JSONL file (read as UTF-8).
        max_samples: Maximum number of examples to return. ``None`` means no
            limit; zero or a negative value yields an empty list.

    Returns:
        A list of the decoded JSON objects, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        OSError: If the file cannot be opened.
    """
    data = []
    if max_samples is not None and max_samples <= 0:
        return data

    with open(path, encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            item = json.loads(line)
            if not isinstance(item, dict):
                continue
            messages = item.get("messages")
            if isinstance(messages, list) and len(messages) >= 2:
                data.append(item)
                if max_samples is not None and len(data) >= max_samples:
                    break
    return data


def _normalize(text):
    """Return *text* stripped of surrounding whitespace and lower-cased."""
    return text.strip().lower()


def exact_match(prediction, expected):
    """Return True if both strings are equal after strip + lower-casing."""
    return _normalize(prediction) == _normalize(expected)


def contains_match(prediction, expected):
    """Return True if the normalized *expected* occurs inside *prediction*.

    Note that an empty (or whitespace-only) *expected* matches any prediction.
    """
    return _normalize(expected) in _normalize(prediction)


def _generate_prediction(model, tokenizer, input_messages, max_new_tokens,
                         temperature):
    """Generate a completion for *input_messages*.

    Returns:
        A ``(prediction_text, new_token_count)`` tuple covering only the
        tokens produced after the prompt.
    """
    import torch  # deferred: see the note below the module imports

    # Request a dict so the result shape does not depend on the library's
    # default return type, and so the attention mask reaches ``generate``.
    encoded = tokenizer.apply_chat_template(
        input_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = encoded["input_ids"].shape[1]

    do_sample = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            # Greedy decoding: pass None rather than 0.0 as the temperature.
            temperature=temperature if do_sample else None,
            do_sample=do_sample,
        )

    # Decode only the new tokens (everything after the prompt).
    new_tokens = outputs[0][prompt_len:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return prediction, len(new_tokens)


def main():
    """Run the evaluation: load model and data, generate, score, and report."""
    args = parse_args()

    print("=" * 60)
    print("Gemma 4 Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Eval data: {args.eval_data}")
    print(f"Max samples: {args.max_samples}")
    print("=" * 60)

    if args.base_model is not None:
        # The option is parsed but nothing below consumes it; say so on stderr
        # so stdout (and the METRICS line) stays unchanged.
        print("Warning: --base-model is currently ignored; "
              "the model is loaded from --model only.", file=sys.stderr)

    # Load model
    from unsloth import FastModel  # deferred heavy import

    print("\nLoading model...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )
    FastModel.for_inference(model)

    # Load eval data
    eval_data = load_eval_data(args.eval_data, args.max_samples)
    print(f"Loaded {len(eval_data)} evaluation examples")

    # Evaluate
    results = []
    total_tokens = 0
    start_time = time.perf_counter()

    for i, sample in enumerate(eval_data):
        messages = sample["messages"]
        # Use all messages except the last (expected answer) as input
        input_messages = messages[:-1]
        expected = messages[-1]["content"]

        prediction, n_new_tokens = _generate_prediction(
            model, tokenizer, input_messages,
            args.max_new_tokens, args.temperature,
        )
        total_tokens += n_new_tokens

        # Score
        em = exact_match(prediction, expected)
        cm = contains_match(prediction, expected)
        results.append({
            "idx": i,
            "exact_match": em,
            "contains_match": cm,
            "prediction_len": len(prediction),
            "expected_len": len(expected),
        })

        if i < 3:
            print(f"\n--- Example {i+1} ---")
            print(f"Input: {input_messages[-1]['content'][:100]}...")
            print(f"Expected: {expected[:100]}...")
            print(f"Got: {prediction[:100]}...")
            print(f"EM: {em} | Contains: {cm}")

        if (i + 1) % 10 == 0:
            print(f" Evaluated {i+1}/{len(eval_data)}...")

    # Compute metrics
    elapsed = time.perf_counter() - start_time
    n = len(results)
    em_hits = sum(r["exact_match"] for r in results)
    cm_hits = sum(r["contains_match"] for r in results)
    exact_match_acc = em_hits / n if n else 0
    contains_match_acc = cm_hits / n if n else 0
    avg_pred_len = sum(r["prediction_len"] for r in results) / n if n else 0
    tokens_per_sec = total_tokens / elapsed if elapsed > 0 else 0

    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f" Samples evaluated: {n}")
    print(f" Exact match: {exact_match_acc:.4f} ({em_hits}/{n})")
    print(f" Contains match: {contains_match_acc:.4f} ({cm_hits}/{n})")
    print(f" Avg prediction len: {avg_pred_len:.0f} chars")
    print(f" Inference speed: {tokens_per_sec:.1f} tokens/sec")
    print(f" Total time: {elapsed:.1f}s")

    # Parseable metrics line for AutoResearch
    print(f"\nMETRICS: exact_match={exact_match_acc:.4f} "
          f"contains_match={contains_match_acc:.4f} "
          f"tokens_per_sec={tokens_per_sec:.1f} "
          f"eval_time={elapsed:.1f}")


if __name__ == "__main__":
    main()