# File size: 6,643 Bytes
# 40a90bb
"""
Evaluate the fine-tuned intent-extraction model locally.

Run:
    python -m src.evaluate_intent
    python -m src.evaluate_intent --model-path ./trained_model/adapter
"""

from __future__ import annotations

import argparse
import json
from pathlib import Path

import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer

from src.classifier_prompt import build_intent_messages
from src.paths import DATA_DIR, TRAINED_MODEL_DIR
from src.skill_utils import extract_intent, intent_matches

# Identifier passed to AutoModelForCausalLM.from_pretrained as the base
# checkpoint when the model path holds a LoRA adapter.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
# NOTE(review): not referenced in the visible part of this module — confirm
# whether it is still needed here.
MAX_SEQ_LENGTH = 2048
# Upper bound on the number of tokens generated for each prompt.
MAX_NEW_TOKENS = 128
# Adapter directory used as the --model-path default and as the first fallback.
DEFAULT_MODEL_PATH = TRAINED_MODEL_DIR / "adapter"
# Merged-model directory tried when no usable adapter is found.
FALLBACK_MODEL_PATH = TRAINED_MODEL_DIR / "merged"
# JSON file containing the evaluation cases.
EVAL_PROMPTS_PATH = DATA_DIR / "eval_intent_prompts.json"


def load_eval_prompts() -> list[dict]:
    """Parse the JSON document stored at ``EVAL_PROMPTS_PATH`` and return it."""
    raw_text = EVAL_PROMPTS_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)


def pick_device() -> str:
    """Choose the torch device name to run on.

    CUDA is preferred, then Apple MPS, and CPU is used when neither backend
    reports itself as available.
    """
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # Only probed when CUDA is unavailable, so the check order matters.
        device = "mps"
    else:
        device = "cpu"
    return device


def is_adapter_path(model_path: Path) -> bool:
    """Tell whether *model_path* contains an ``adapter_config.json`` entry."""
    marker = model_path.joinpath("adapter_config.json")
    return marker.exists()


def _is_nonempty_file(path: Path) -> bool:
    """Return True when *path* is an existing regular file holding at least one byte."""
    return path.is_file() and path.stat().st_size > 0


def _indexed_shard_names(index_path: Path) -> set[str]:
    """Return the shard file names listed in a safetensors index file.

    An unreadable, malformed, or empty index yields an empty set so the caller
    can treat it as "no shards listed" instead of crashing.
    """
    try:
        index = json.loads(index_path.read_text(encoding="utf-8"))
    except (OSError, ValueError):
        # ValueError covers both invalid JSON and undecodable bytes.
        return set()

    weight_map = index.get("weight_map") if isinstance(index, dict) else None
    if not isinstance(weight_map, dict):
        return set()
    return {shard for shard in weight_map.values() if isinstance(shard, str) and shard}


def is_complete_merged_model(model_path: Path) -> bool:
    """Check whether *model_path* holds a usable merged model.

    The directory must contain a non-empty ``config.json`` plus either

    * every shard named in ``model.safetensors.index.json`` (each non-empty), or
    * a non-empty single ``model.safetensors`` file.

    Args:
        model_path: Directory to inspect.

    Returns:
        True when the config and all required weight files are present and
        non-empty, False otherwise. A corrupt index, or one that lists no
        shards, never counts as complete on its own.
    """
    if not _is_nonempty_file(model_path / "config.json"):
        return False

    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        shard_names = _indexed_shard_names(index_path)
        if shard_names:
            return all(_is_nonempty_file(model_path / shard) for shard in shard_names)
        # An index without shards proves nothing (all() over an empty set is
        # vacuously True), so fall through to the single-file check.

    return _is_nonempty_file(model_path / "model.safetensors")


def resolve_model_path(path: str) -> Path:
    """Pick the model directory to evaluate, falling back to known locations.

    Resolution order:

    1. *path* itself, when it is a LoRA adapter or a complete merged model.
    2. ``DEFAULT_MODEL_PATH``, when it is a LoRA adapter.
    3. ``FALLBACK_MODEL_PATH``, when it is a complete merged model.

    Args:
        path: Model location requested by the caller.

    Returns:
        The first usable directory in the order above.

    Raises:
        FileNotFoundError: If none of the candidate locations is usable.
    """
    model_path = Path(path)
    if model_path.exists():
        if is_adapter_path(model_path) or is_complete_merged_model(model_path):
            return model_path
        print(f"Warning: {model_path} looks incomplete; trying adapter fallback.")
    else:
        # Say so explicitly: otherwise a mistyped path quietly evaluates a
        # different model than the one that was asked for.
        print(f"Warning: {model_path} does not exist; trying adapter fallback.")

    # Both helpers already return False for a missing directory, so no
    # separate existence check is needed before calling them.
    if is_adapter_path(DEFAULT_MODEL_PATH):
        print(f"Using LoRA adapter at {DEFAULT_MODEL_PATH}")
        return DEFAULT_MODEL_PATH

    if is_complete_merged_model(FALLBACK_MODEL_PATH):
        print(f"Using merged model at {FALLBACK_MODEL_PATH}")
        return FALLBACK_MODEL_PATH

    raise FileNotFoundError(
        f"No usable trained model found. Tried {model_path}, a LoRA adapter at "
        f"{DEFAULT_MODEL_PATH}, and a complete merged model at {FALLBACK_MODEL_PATH}."
    )


def load_model(model_path: Path, device: str):
    """Load the model and tokenizer found at *model_path* onto *device*.

    When *model_path* is a LoRA adapter, ``BASE_MODEL`` is loaded first and the
    adapter is applied on top of it; otherwise the weights are read directly
    from *model_path*. The tokenizer always comes from *model_path*.

    Returns:
        A ``(model, tokenizer)`` pair with the model moved to *device* and
        switched to eval mode.
    """
    # Half precision on accelerators, full precision on CPU.
    use_half_precision = device in {"cuda", "mps"}
    dtype = torch.float16 if use_half_precision else torch.float32

    adapter = is_adapter_path(model_path)
    if adapter:
        print(f"Loading base model: {BASE_MODEL}")
    else:
        print("Loading merged model weights")

    tokenizer = AutoTokenizer.from_pretrained(model_path)
    weights_source = BASE_MODEL if adapter else model_path
    model = AutoModelForCausalLM.from_pretrained(
        weights_source,
        dtype=dtype,
        low_cpu_mem_usage=True,
    )
    if adapter:
        model = PeftModel.from_pretrained(model, str(model_path))

    model.to(device)
    model.eval()
    return model, tokenizer


def generate_intent(model, tokenizer, prompt: str, device: str) -> str:
    """Generate the model's raw intent text for a single user prompt.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer exposing ``apply_chat_template`` and ``decode``.
        prompt: User prompt, wrapped into chat messages by
            ``build_intent_messages``.
        device: Torch device name the encoded inputs are moved to.

    Returns:
        The newly generated text only (prompt tokens removed), with special
        tokens skipped and surrounding whitespace stripped.
    """
    messages = build_intent_messages(prompt)
    # Ask for a dict explicitly so the result shape does not depend on the
    # library's default, and so the attention mask is available for generate().
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(device)
    input_ids = encoded["input_ids"]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=MAX_NEW_TOKENS,
            use_cache=True,
            do_sample=False,  # greedy decoding
        )

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_length = input_ids.shape[1]
    generated = outputs[0][prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()


def evaluate(model, tokenizer, test_prompts: list[dict], device: str) -> tuple[float, float]:
    """Run every evaluation case through the model and report accuracy.

    Each case is a dict with a ``"prompt"`` string and an ``"expected"`` dict
    holding ``"skill"`` and, optionally, ``"parameters"``. A per-case report and
    a final summary are printed to stdout.

    Args:
        model: Model handed to ``generate_intent``.
        tokenizer: Tokenizer handed to ``generate_intent``.
        test_prompts: Evaluation cases; may be empty.
        device: Torch device name handed to ``generate_intent``.

    Returns:
        ``(skill_accuracy, full_accuracy)`` as fractions in ``[0, 1]``;
        ``(0.0, 0.0)`` when *test_prompts* is empty.
    """
    total = len(test_prompts)
    if total == 0:
        # Nothing to score; returning early avoids dividing by zero below.
        print("No evaluation prompts provided; nothing to evaluate.")
        return 0.0, 0.0

    skill_correct = 0
    full_correct = 0

    for index, case in enumerate(test_prompts, start=1):
        prompt = case["prompt"]
        expected = case["expected"]
        expected_skill = expected["skill"]
        expected_params = expected.get("parameters", {})

        raw_output = generate_intent(model, tokenizer, prompt, device)
        predicted = extract_intent(raw_output)

        skill_pass = predicted is not None and predicted.get("skill") == expected_skill
        full_pass = intent_matches(predicted, expected_skill, expected_params)

        skill_correct += int(skill_pass)
        full_correct += int(full_pass)

        # Show the raw model text when no intent could be extracted from it.
        predicted_text = json.dumps(predicted, separators=(",", ":")) if predicted else raw_output

        print(f"--- [{index}/{total}] ---")
        print(f"Prompt:    {prompt}")
        print(f"Expected:  {json.dumps(expected, separators=(',', ':'))}")
        print(f"Predicted: {predicted_text}")
        print(f"Skill:     {'PASS' if skill_pass else 'FAIL'}")
        print(f"Full:      {'PASS' if full_pass else 'FAIL'}")
        print()

    skill_accuracy = skill_correct / total
    full_accuracy = full_correct / total
    print(f"Skill accuracy: {skill_correct}/{total} ({skill_accuracy:.1%})")
    print(f"Full intent accuracy: {full_correct}/{total} ({full_accuracy:.1%})")
    return skill_accuracy, full_accuracy


def main() -> None:
    """Command-line entry point: parse arguments, load the model, run the evaluation.

    Raises:
        FileNotFoundError: If the evaluation prompt file is missing.
    """
    cli = argparse.ArgumentParser(description="Evaluate the intent extraction model.")
    cli.add_argument(
        "--model-path",
        default=str(DEFAULT_MODEL_PATH),
        help=f"Path to LoRA adapter or merged model (default: {DEFAULT_MODEL_PATH})",
    )
    requested_path = cli.parse_args().model_path

    # Fail before any model loading if there is nothing to evaluate against.
    if not EVAL_PROMPTS_PATH.exists():
        missing_message = (
            f"Eval prompts not found at {EVAL_PROMPTS_PATH}. "
            "Run `python scripts/generate_intent_dataset.py` first."
        )
        raise FileNotFoundError(missing_message)

    cases = load_eval_prompts()
    device = pick_device()
    resolved_path = resolve_model_path(requested_path)

    print(f"Device: {device}")
    print(f"Loading model from {resolved_path.resolve()}")
    model, tokenizer = load_model(resolved_path, device)

    print(f"Running intent evaluation on {len(cases)} prompts...\n")
    evaluate(model, tokenizer, cases, device)


# Run the evaluation only when this module is executed as a script
# (e.g. `python -m src.evaluate_intent`), not when it is imported.
if __name__ == "__main__":
    main()