File size: 6,046 Bytes
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24492a8
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24492a8
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
"""
Evaluate the fine-tuned skill-classification model.

Loads the trained model, runs 50 held-out prompts, and reports per-example
PASS/FAIL plus overall accuracy.

Uses transformers + PEFT so evaluation works on Mac/CPU without Unsloth.

Run:
    python -m src.evaluate
    python -m src.evaluate --model-path ./trained_model/adapter
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.paths import DATA_DIR, TRAINED_MODEL_DIR
from src.skill_utils import extract_skill
from src.classifier_prompt import build_classifier_messages

# Hugging Face checkpoint that a LoRA adapter is loaded on top of.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
# Not referenced by the evaluation code in this module; presumably mirrors the
# training configuration — confirm before relying on it.
MAX_SEQ_LENGTH = 2_048

# Model locations tried by resolve_model_path: the adapter first, then a
# merged checkpoint.
DEFAULT_MODEL_PATH = TRAINED_MODEL_DIR / "adapter"
FALLBACK_MODEL_PATH = TRAINED_MODEL_DIR / "merged"

# Held-out evaluation cases read by load_eval_prompts.
EVAL_PROMPTS_PATH = DATA_DIR / "eval_prompts.json"


def load_eval_prompts() -> list[dict[str, str]]:
    with EVAL_PROMPTS_PATH.open(encoding="utf-8") as handle:
        prompts = json.load(handle)
    if len(prompts) != 50:
        raise ValueError(f"Expected 50 eval prompts in {EVAL_PROMPTS_PATH}, got {len(prompts)}")
    return prompts


def pick_device() -> str:
    """Return the name of the preferred torch device.

    Accelerators are probed in priority order — CUDA first, then Apple MPS —
    and ``"cpu"`` is returned when neither is available.
    """
    # Lambdas keep each probe lazy: MPS is only queried if CUDA is absent.
    probes = (
        ("cuda", lambda: torch.cuda.is_available()),
        ("mps", lambda: torch.backends.mps.is_available()),
    )
    return next((name for name, probe in probes if probe()), "cpu")


def is_adapter_path(model_path: Path) -> bool:
    """Report whether *model_path* is a PEFT adapter directory.

    The only signal used is the presence of an ``adapter_config.json`` file.
    """
    config_file = model_path.joinpath("adapter_config.json")
    return config_file.exists()


def is_complete_merged_model(model_path: Path) -> bool:
    """Return True if *model_path* looks like a fully written merged checkpoint.

    The directory qualifies when it has a non-empty ``config.json`` and either:

    * a ``model.safetensors.index.json`` that parses, lists at least one shard
      in its ``weight_map``, and every listed shard is a non-empty file; or
    * (when there is no index) a single non-empty ``model.safetensors``.

    An unreadable, truncated, or malformed index is reported as incomplete
    (False) rather than raised, so callers can fall back to another location.
    """

    def _non_empty_file(path: Path) -> bool:
        # A zero-byte file is what an interrupted save typically leaves behind.
        return path.is_file() and path.stat().st_size > 0

    if not _non_empty_file(model_path / "config.json"):
        return False

    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        try:
            index = json.loads(index_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # ValueError covers JSONDecodeError and UnicodeDecodeError.
            return False
        weight_map = index.get("weight_map") if isinstance(index, dict) else None
        if not isinstance(weight_map, dict) or not weight_map:
            # No shards listed means no weights; all() would be vacuously True.
            return False
        return all(
            _non_empty_file(model_path / shard) for shard in set(weight_map.values())
        )

    return _non_empty_file(model_path / "model.safetensors")


def resolve_model_path(path: str) -> Path:
    """Choose which model directory to evaluate.

    The requested *path* is used as-is when it is a PEFT adapter directory
    (per ``is_adapter_path``) or a complete merged checkpoint (per
    ``is_complete_merged_model``). Otherwise a warning is printed and the
    fallbacks are tried in order: ``DEFAULT_MODEL_PATH`` as an adapter, then
    ``FALLBACK_MODEL_PATH`` as a merged model. The chosen fallback is printed.

    Raises:
        FileNotFoundError: If neither the requested path nor any fallback is
            usable. The message lists every location that was checked.
    """
    model_path = Path(path)
    if not model_path.exists():
        # Previously this fell through silently, hiding typos in --model-path.
        print(f"Warning: {model_path} does not exist; trying fallbacks.")
    elif is_adapter_path(model_path) or is_complete_merged_model(model_path):
        return model_path
    else:
        print(f"Warning: {model_path} looks incomplete; trying adapter fallback.")

    # Both predicates require a file inside the directory, so a separate
    # exists() check on the directory is unnecessary.
    if is_adapter_path(DEFAULT_MODEL_PATH):
        print(f"Using LoRA adapter at {DEFAULT_MODEL_PATH}")
        return DEFAULT_MODEL_PATH

    if is_complete_merged_model(FALLBACK_MODEL_PATH):
        print(f"Using merged model at {FALLBACK_MODEL_PATH}")
        return FALLBACK_MODEL_PATH

    raise FileNotFoundError(
        f"No usable trained model found (requested: {model_path}). Expected a "
        f"LoRA adapter at {DEFAULT_MODEL_PATH} or a complete merged model at "
        f"{FALLBACK_MODEL_PATH}."
    )


def load_model(model_path: Path, device: str):
    """Load the model and tokenizer at *model_path* and move the model to *device*.

    If *model_path* is an adapter directory (per ``is_adapter_path``), the
    ``BASE_MODEL`` checkpoint is loaded first and the adapter is attached with
    ``PeftModel.from_pretrained``. Otherwise *model_path* is loaded directly as
    a full causal-LM checkpoint. Weights use float16 on ``"cuda"``/``"mps"`` and
    float32 on any other device. The model is put in eval mode.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    half_precision = device in {"cuda", "mps"}
    # NOTE(review): older transformers releases name this kwarg `torch_dtype`.
    load_kwargs = {
        "dtype": torch.float16 if half_precision else torch.float32,
        "low_cpu_mem_usage": True,
    }

    adapter = is_adapter_path(model_path)
    print(f"Loading base model: {BASE_MODEL}" if adapter else "Loading merged model weights")

    # The tokenizer is read from model_path in both branches, so an adapter
    # directory is assumed to also hold tokenizer files — confirm with training.
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    if adapter:
        base_model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, **load_kwargs)
        model = PeftModel.from_pretrained(base_model, str(model_path))
    else:
        model = AutoModelForCausalLM.from_pretrained(model_path, **load_kwargs)

    model.to(device)
    model.eval()
    return model, tokenizer


def generate_skill(model, tokenizer, prompt: str, device: str) -> str:
    """Generate the model's raw classification text for *prompt*.

    The prompt is wrapped by ``build_classifier_messages`` and rendered with
    the tokenizer's chat template, with the generation prompt appended. Up to
    64 new tokens are produced with ``do_sample=False``. Only the newly
    generated tokens are decoded; special tokens are skipped and surrounding
    whitespace is stripped.

    Returns:
        The decoded continuation. Parsing it into a skill label is left to the
        caller.
    """
    messages = build_classifier_messages(prompt)
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Depending on the transformers version (return_dict default), this is
    # either a bare input_ids tensor or a BatchEncoding-style mapping.
    if isinstance(encoded, torch.Tensor):
        input_ids, attention_mask = encoded, None
    else:
        input_ids, attention_mask = encoded["input_ids"], encoded.get("attention_mask")
    input_ids = input_ids.to(device)

    # A single unpadded sequence: every position is a real token, so an
    # all-ones mask is exact. Passing it stops generate() from inferring the
    # mask from pad tokens (and emitting a warning).
    if attention_mask is None:
        attention_mask = torch.ones_like(input_ids)
    else:
        attention_mask = attention_mask.to(device)

    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            pad_token_id=pad_token_id,
            max_new_tokens=64,
            use_cache=True,
            do_sample=False,
        )

    # generate() returns prompt + continuation; keep only the continuation.
    generated = outputs[0][input_ids.shape[1]:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def evaluate(model, tokenizer, test_prompts: list[dict[str, str]], device: str) -> float:
    """Run every case through the model, print a per-case report, return accuracy.

    Each case must provide ``"prompt"`` and ``"expected"``. The raw generation
    from ``generate_skill`` is parsed with ``extract_skill``. A case passes when
    the parsed value equals ``expected`` exactly. When ``extract_skill`` returns
    ``None``, the raw model output is printed in its place to aid debugging.

    Returns:
        Fraction of passing cases, in ``[0, 1]``.

    Raises:
        ValueError: If *test_prompts* is empty, because accuracy is undefined.
            This replaces a bare ZeroDivisionError.
    """
    if not test_prompts:
        raise ValueError("No eval prompts to evaluate.")

    total = len(test_prompts)
    correct = 0

    for index, case in enumerate(test_prompts, start=1):
        prompt = case["prompt"]
        expected = case["expected"]

        raw_output = generate_skill(model, tokenizer, prompt, device)
        predicted = extract_skill(raw_output)
        passed = predicted == expected
        correct += int(passed)

        print(f"--- [{index}/{total}] ---")
        print(f"Prompt:    {prompt}")
        print(f"Expected:  {expected}")
        print(f"Predicted: {predicted if predicted is not None else raw_output}")
        print(f"Result:    {'PASS' if passed else 'FAIL'}")
        print()

    accuracy = correct / total
    print(f"Accuracy: {correct}/{total} ({accuracy:.1%})")
    return accuracy


def _parse_args() -> argparse.Namespace:
    """Parse this script's command-line options (currently only --model-path)."""
    cli = argparse.ArgumentParser(description="Evaluate the skill model.")
    cli.add_argument(
        "--model-path",
        default=str(DEFAULT_MODEL_PATH),
        help=f"Path to LoRA adapter or merged model (default: {DEFAULT_MODEL_PATH})",
    )
    return cli.parse_args()


def main() -> None:
    """Command-line entry point: load prompts and model, then run the evaluation."""
    options = _parse_args()

    # Prompts are loaded and checked before any model loading, so a bad eval
    # file fails fast.
    cases = load_eval_prompts()

    device = pick_device()
    chosen_path = resolve_model_path(options.model_path)
    print(f"Device: {device}")
    print(f"Loading model from {chosen_path.resolve()}")
    model, tokenizer = load_model(chosen_path, device)

    print(f"Running evaluation on {len(cases)} prompts...\n")
    evaluate(model, tokenizer, cases, device)


if __name__ == "__main__":
    main()