banhmi-gemma4-e4b / scripts /evaluate.py
bradduy's picture
Add Unsloth training pipeline (train, evaluate, export, prepare_data, training_logger)
4942b80 verified
#!/usr/bin/env python3
"""
Evaluate a fine-tuned Gemma 4 model.
Usage:
python scripts/evaluate.py --model checkpoints/finetuned/lora_adapter \
--eval-data data/processed/train_eval.jsonl
"""
import argparse
import json
import time
import unicodedata

import torch
from unsloth import FastModel
def parse_args():
    """Build the evaluation CLI and return the parsed argument namespace.

    Options:
        --model           LoRA adapter path or model name to evaluate (required).
        --base-model      Optional base model name (default None).
        --eval-data       Path to the evaluation JSONL file (required).
        --max-samples     Maximum number of examples to evaluate (default 100).
        --max-new-tokens  Generation budget per example (default 512).
        --max-seq-length  Sequence length passed to the model loader (default 2048).
        --temperature     Sampling temperature; 0.0 means greedy (default 0.0).
    """
    cli = argparse.ArgumentParser(description="Evaluate fine-tuned Gemma 4")
    # (flag, add_argument keyword options) in registration order.
    option_table = (
        ("--model", {
            "type": str,
            "required": True,
            "help": "Path to fine-tuned LoRA adapter or model name",
        }),
        ("--base-model", {
            "type": str,
            "default": None,
            "help": "Base model name (if loading LoRA adapter separately)",
        }),
        ("--eval-data", {
            "type": str,
            "required": True,
            "help": "Path to evaluation JSONL file",
        }),
        ("--max-samples", {"type": int, "default": 100}),
        ("--max-new-tokens", {"type": int, "default": 512}),
        ("--max-seq-length", {"type": int, "default": 2048}),
        ("--temperature", {
            "type": float,
            "default": 0.0,
            "help": "0.0 for greedy (deterministic)",
        }),
    )
    for flag, settings in option_table:
        cli.add_argument(flag, **settings)
    return cli.parse_args()
def load_eval_data(path, max_samples):
    """Load chat-formatted evaluation examples from a JSONL file.

    Each non-blank line must be a JSON value. Only objects whose ``messages``
    field is a list with at least two entries are kept: ``main`` uses all but
    the last message as the prompt and the last one as the reference answer.
    Other lines are skipped.

    Args:
        path: Path to the JSONL file, read as UTF-8.
        max_samples: Maximum number of examples to return. ``None`` means no
            limit; a value <= 0 returns an empty list.

    Returns:
        list[dict]: The kept examples, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    data = []
    # Check the cap up front so 0/negative values yield nothing rather than
    # letting one example through before the post-append check runs.
    if max_samples is not None and max_samples <= 0:
        return data
    # Explicit UTF-8: JSON text is UTF-8, and the platform default encoding
    # may not be.
    with open(path, encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line) instead of
                # failing in json.loads.
                continue
            item = json.loads(line)
            messages = item.get("messages") if isinstance(item, dict) else None
            # Require a real list so main() can index messages[-1]["content"].
            if isinstance(messages, list) and len(messages) >= 2:
                data.append(item)
                if max_samples is not None and len(data) >= max_samples:
                    break
    return data
def exact_match(prediction, expected):
    """Return True when *prediction* equals *expected* after normalization.

    Both strings are stripped of surrounding whitespace, case-folded, and
    converted to Unicode NFC. Precomposed and decomposed spellings of the same
    accented text (e.g. "é" vs "e" + combining acute) therefore compare equal.

    Args:
        prediction: Decoded model output.
        expected: Reference answer.

    Returns:
        bool: Whether the normalized strings are identical.
    """
    # casefold() is the caseless-matching form of lower() (e.g. "ß" -> "ss");
    # NFC afterwards recomposes combining sequences into canonical form.
    pred_clean = unicodedata.normalize("NFC", prediction.strip().casefold())
    exp_clean = unicodedata.normalize("NFC", expected.strip().casefold())
    return pred_clean == exp_clean
def contains_match(prediction, expected):
    """Return True when the normalized *expected* text occurs in *prediction*.

    Both strings are stripped of surrounding whitespace, case-folded, and
    converted to Unicode NFC before the substring test. This keeps the check
    insensitive to case and to precomposed vs decomposed accents. An empty
    (or whitespace-only) reference matches only an empty prediction.

    Args:
        prediction: Decoded model output.
        expected: Reference answer to look for.

    Returns:
        bool: Whether the reference is contained in the prediction.
    """
    pred_clean = unicodedata.normalize("NFC", prediction.strip().casefold())
    exp_clean = unicodedata.normalize("NFC", expected.strip().casefold())
    if not exp_clean:
        # "" is a substring of every string; without this guard an empty
        # reference would score as a hit for any output.
        return not pred_clean
    return exp_clean in pred_clean
def _preview(text, limit=100):
    """Return the first *limit* items of *text* for display, adding "..." only if truncated."""
    suffix = "..." if len(text) > limit else ""
    return f"{text[:limit]}{suffix}"


def _generate_reply(model, tokenizer, input_messages, max_new_tokens, temperature):
    """Generate a reply to *input_messages* and return only the newly generated token ids.

    A ``temperature`` <= 0 selects greedy decoding (``do_sample=False``).
    """
    # Ask for a dict explicitly so input_ids and attention_mask are both
    # available, rather than depending on the library's default return type.
    inputs = tokenizer.apply_chat_template(
        input_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)
    input_ids = inputs["input_ids"]
    with torch.no_grad():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=inputs.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            temperature=temperature if temperature > 0 else None,
            do_sample=temperature > 0,
        )
    # Slice off the prompt so only the new tokens are decoded and counted.
    return outputs[0][input_ids.shape[1]:]


def _report(results, total_tokens, elapsed):
    """Print aggregate metrics for *results* followed by the single-line METRICS summary."""
    n = len(results)
    em_hits = sum(r["exact_match"] for r in results)
    cm_hits = sum(r["contains_match"] for r in results)
    exact_match_acc = em_hits / n if n else 0
    contains_match_acc = cm_hits / n if n else 0
    avg_pred_len = sum(r["prediction_len"] for r in results) / n if n else 0
    tokens_per_sec = total_tokens / elapsed if elapsed > 0 else 0
    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f" Samples evaluated: {n}")
    print(f" Exact match: {exact_match_acc:.4f} ({em_hits}/{n})")
    print(f" Contains match: {contains_match_acc:.4f} ({cm_hits}/{n})")
    print(f" Avg prediction len: {avg_pred_len:.0f} chars")
    print(f" Inference speed: {tokens_per_sec:.1f} tokens/sec")
    print(f" Total time: {elapsed:.1f}s")
    # Parseable metrics line for AutoResearch (format must stay stable).
    print(f"\nMETRICS: exact_match={exact_match_acc:.4f} "
          f"contains_match={contains_match_acc:.4f} "
          f"tokens_per_sec={tokens_per_sec:.1f} "
          f"eval_time={elapsed:.1f}")


def main():
    """Run the evaluation CLI end to end.

    Steps:
      1. Load the model named by ``--model`` via ``FastModel.from_pretrained``
         in 4-bit.
      2. Read the examples returned by ``load_eval_data``.
      3. For each example, prompt with every message except the last and
         generate a reply.
      4. Score the reply against the last message's ``content`` using
         ``exact_match`` and ``contains_match``.
      5. Print a summary and a ``METRICS:`` line.

    Elapsed time and tokens/sec cover the whole evaluation loop, including
    tokenization and decoding.
    """
    args = parse_args()
    print("=" * 60)
    print("Gemma 4 Evaluation")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Eval data: {args.eval_data}")
    print(f"Max samples: {args.max_samples}")
    print("=" * 60)
    if args.base_model:
        # The loader below only reads --model; say so instead of silently
        # ignoring the flag.
        print(f"Note: --base-model ({args.base_model}) is not used; "
              "the model is loaded from --model only.")

    # Load model
    print("\nLoading model...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )
    FastModel.for_inference(model)

    # Load eval data
    eval_data = load_eval_data(args.eval_data, args.max_samples)
    print(f"Loaded {len(eval_data)} evaluation examples")

    # Evaluate
    results = []
    total_tokens = 0
    start_time = time.time()
    for i, sample in enumerate(eval_data):
        messages = sample["messages"]
        # Use all messages except the last (expected answer) as input
        input_messages = messages[:-1]
        expected = messages[-1]["content"]

        new_tokens = _generate_reply(
            model, tokenizer, input_messages,
            args.max_new_tokens, args.temperature,
        )
        prediction = tokenizer.decode(new_tokens, skip_special_tokens=True)
        total_tokens += len(new_tokens)

        # Score
        em = exact_match(prediction, expected)
        cm = contains_match(prediction, expected)
        results.append({
            "idx": i,
            "exact_match": em,
            "contains_match": cm,
            "prediction_len": len(prediction),
            "expected_len": len(expected),
        })

        if i < 3:
            print(f"\n--- Example {i+1} ---")
            print(f"Input: {_preview(input_messages[-1]['content'])}")
            print(f"Expected: {_preview(expected)}")
            print(f"Got: {_preview(prediction)}")
            print(f"EM: {em} | Contains: {cm}")
        if (i + 1) % 10 == 0:
            print(f" Evaluated {i+1}/{len(eval_data)}...")

    elapsed = time.time() - start_time
    _report(results, total_tokens, elapsed)
if __name__ == "__main__":
main()