File size: 5,559 Bytes
4942b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
#!/usr/bin/env python3
"""
Evaluate a fine-tuned Gemma 4 model.

Usage:
    python scripts/evaluate.py --model checkpoints/finetuned/lora_adapter \
        --eval-data data/processed/train_eval.jsonl
"""

import argparse
import json
import time

import torch
from unsloth import FastModel


def parse_args():
    """Define the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``model``, ``base_model``,
        ``eval_data``, ``max_samples``, ``max_new_tokens``,
        ``max_seq_length`` and ``temperature``.
    """
    cli = argparse.ArgumentParser(description="Evaluate fine-tuned Gemma 4")

    # Model / data locations.
    cli.add_argument(
        "--model",
        type=str,
        required=True,
        help="Path to fine-tuned LoRA adapter or model name",
    )
    cli.add_argument(
        "--base-model",
        type=str,
        default=None,
        help="Base model name (if loading LoRA adapter separately)",
    )
    cli.add_argument(
        "--eval-data",
        type=str,
        required=True,
        help="Path to evaluation JSONL file",
    )

    # Evaluation / generation budget.
    cli.add_argument("--max-samples", type=int, default=100)
    cli.add_argument("--max-new-tokens", type=int, default=512)
    cli.add_argument("--max-seq-length", type=int, default=2048)

    # Sampling: 0.0 selects greedy decoding downstream.
    cli.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="0.0 for greedy (deterministic)",
    )

    parsed = cli.parse_args()
    return parsed


def load_eval_data(path, max_samples):
    """Load up to *max_samples* chat examples from a JSONL file.

    Each non-blank line is parsed as JSON. Only JSON objects with a
    ``"messages"`` list of at least two entries are kept. The last entry is
    later used as the expected answer and the rest as the prompt.

    Args:
        path: Path to the JSONL file, read as UTF-8.
        max_samples: Maximum number of examples to return. ``None`` means
            no limit. ``0`` or a negative value returns an empty list.

    Returns:
        List of the kept example dicts, in file order.

    Raises:
        ValueError: If a non-blank line is not valid JSON. The message
            includes the file path and the 1-based line number.
    """
    if max_samples is not None and max_samples <= 0:
        return []

    data = []
    with open(path, encoding="utf-8") as f:
        for lineno, line in enumerate(f, start=1):
            # Tolerate blank lines (e.g. a trailing newline at EOF), which
            # json.loads would otherwise reject.
            if not line.strip():
                continue
            try:
                item = json.loads(line)
            except json.JSONDecodeError as err:
                raise ValueError(f"{path}:{lineno}: invalid JSON: {err}") from err

            messages = item.get("messages") if isinstance(item, dict) else None
            if isinstance(messages, list) and len(messages) >= 2:
                data.append(item)
                # Check the limit only after a successful append so exactly
                # max_samples items are returned.
                if max_samples is not None and len(data) >= max_samples:
                    break
    return data


def exact_match(prediction, expected):
    """Return True if the two strings are equal ignoring case and outer whitespace.

    Both arguments are stripped of leading/trailing whitespace and
    lower-cased before the equality comparison.
    """
    def _norm(text):
        return text.strip().lower()

    return _norm(prediction) == _norm(expected)


def contains_match(prediction, expected):
    """Check if the expected answer is contained in the prediction.

    Both strings are stripped of leading/trailing whitespace and lower-cased
    before a substring test.

    An empty (or whitespace-only) *expected* is special-cased. Because
    ``"" in s`` is True for every ``s``, the plain substring test would
    credit any prediction at all. Instead, an empty reference matches only
    an empty prediction, which is the same result as exact match.
    """
    pred_clean = prediction.strip().lower()
    exp_clean = expected.strip().lower()
    if not exp_clean:
        return not pred_clean
    return exp_clean in pred_clean


def _preview(text, limit=100):
    """Return *text* cut to *limit* characters, adding "..." only when truncated."""
    # Non-string content (e.g. a list of content parts) is shown via str().
    text = text if isinstance(text, str) else str(text)
    return text if len(text) <= limit else text[:limit] + "..."


def _generate(model, tokenizer, input_messages, max_new_tokens, temperature):
    """Generate a completion for one chat prompt.

    Returns:
        Tuple ``(prediction_text, num_new_tokens)``. Only tokens produced
        after the prompt are decoded, with special tokens skipped.
    """
    # return_dict=True yields input_ids *and* attention_mask, so generate()
    # receives an explicit mask. Indexing by key also works when `tokenizer`
    # is a processor returning a BatchFeature rather than a bare tensor.
    inputs = tokenizer.apply_chat_template(
        input_messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    # temperature > 0 enables sampling; otherwise decode greedily and pass
    # temperature=None so no sampling parameter is set.
    sampling = temperature > 0
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            temperature=temperature if sampling else None,
            do_sample=sampling,
        )

    new_tokens = outputs[0][prompt_len:]
    prediction = tokenizer.decode(new_tokens, skip_special_tokens=True)
    return prediction, len(new_tokens)


def main():
    """Evaluate the model given by ``--model`` on a JSONL chat dataset.

    For each example, all messages except the last form the prompt and the
    last message's ``content`` is the reference answer. The generated text is
    scored with ``exact_match`` and ``contains_match``. The script prints a
    human-readable summary followed by a single ``METRICS:`` line meant for
    machine parsing.
    """
    args = parse_args()

    print("=" * 60)
    print("Gemma 4 Evaluation")
    print("=" * 60)
    print(f"Model:      {args.model}")
    print(f"Eval data:  {args.eval_data}")
    print(f"Max samples: {args.max_samples}")
    print("=" * 60)
    if args.base_model:
        # Only --model is passed to FastModel.from_pretrained below; make the
        # unused option visible instead of silently ignoring it.
        print(f"Warning: --base-model={args.base_model!r} is currently ignored; "
              "only --model is loaded.")

    # Load model (4-bit) and switch it to inference mode.
    print("\nLoading model...")
    model, tokenizer = FastModel.from_pretrained(
        model_name=args.model,
        max_seq_length=args.max_seq_length,
        load_in_4bit=True,
    )
    FastModel.for_inference(model)

    eval_data = load_eval_data(args.eval_data, args.max_samples)
    print(f"Loaded {len(eval_data)} evaluation examples")

    results = []
    total_tokens = 0
    # perf_counter is monotonic, so elapsed time can't be skewed by clock changes.
    start_time = time.perf_counter()

    for i, sample in enumerate(eval_data):
        messages = sample["messages"]
        # Everything except the last message is the prompt; the last
        # message's content is the expected answer.
        input_messages = messages[:-1]
        expected = messages[-1]["content"]

        prediction, num_new = _generate(
            model, tokenizer, input_messages,
            args.max_new_tokens, args.temperature,
        )
        total_tokens += num_new

        em = exact_match(prediction, expected)
        cm = contains_match(prediction, expected)
        results.append({
            "idx": i,
            "exact_match": em,
            "contains_match": cm,
            "prediction_len": len(prediction),
            "expected_len": len(expected),
        })

        if i < 3:
            print(f"\n--- Example {i+1} ---")
            print(f"Input:    {_preview(input_messages[-1]['content'])}")
            print(f"Expected: {_preview(expected)}")
            print(f"Got:      {_preview(prediction)}")
            print(f"EM: {em} | Contains: {cm}")

        if (i + 1) % 10 == 0:
            print(f"  Evaluated {i+1}/{len(eval_data)}...")

    # Aggregate metrics; every ratio is guarded against an empty dataset.
    elapsed = time.perf_counter() - start_time
    n = len(results)
    em_hits = sum(r["exact_match"] for r in results)
    cm_hits = sum(r["contains_match"] for r in results)
    exact_match_acc = em_hits / n if n else 0
    contains_match_acc = cm_hits / n if n else 0
    avg_pred_len = sum(r["prediction_len"] for r in results) / n if n else 0
    tokens_per_sec = total_tokens / elapsed if elapsed > 0 else 0

    print("\n" + "=" * 60)
    print("Results")
    print("=" * 60)
    print(f"  Samples evaluated:  {n}")
    print(f"  Exact match:        {exact_match_acc:.4f} ({em_hits}/{n})")
    print(f"  Contains match:     {contains_match_acc:.4f} ({cm_hits}/{n})")
    print(f"  Avg prediction len: {avg_pred_len:.0f} chars")
    print(f"  Inference speed:    {tokens_per_sec:.1f} tokens/sec")
    print(f"  Total time:         {elapsed:.1f}s")

    # Parseable metrics line for AutoResearch
    print(f"\nMETRICS: exact_match={exact_match_acc:.4f} "
          f"contains_match={contains_match_acc:.4f} "
          f"tokens_per_sec={tokens_per_sec:.1f} "
          f"eval_time={elapsed:.1f}")


if __name__ == "__main__":
    # Run the evaluation only when executed as a script, not when imported.
    main()