android-skill-router / src /evaluate_intent.py
kriyanshi's picture
Ship v2 intent extraction with API, demo UI, eval, and benchmark suite.
40a90bb
Raw
History Blame Contribute Delete
6.64 kB
"""
Evaluate the fine-tuned intent-extraction model locally.
Run:
python -m src.evaluate_intent
python -m src.evaluate_intent --model-path ./trained_model/adapter
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.classifier_prompt import build_intent_messages
from src.paths import DATA_DIR, TRAINED_MODEL_DIR
from src.skill_utils import extract_intent, intent_matches
# Hugging Face model id loaded as the base network when the checkpoint is a LoRA adapter.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
# NOTE(review): not referenced by the code visible in this module — presumably
# mirrors the training sequence length; confirm before relying on or removing it.
MAX_SEQ_LENGTH = 2048
# Upper bound on the number of tokens generated per prompt (passed to model.generate).
MAX_NEW_TOKENS = 128
# Default value of --model-path and the first location tried as a fallback (LoRA adapter).
DEFAULT_MODEL_PATH = TRAINED_MODEL_DIR / "adapter"
# Second fallback location, checked for a complete merged model.
FALLBACK_MODEL_PATH = TRAINED_MODEL_DIR / "merged"
# JSON file containing the evaluation cases read by load_eval_prompts().
EVAL_PROMPTS_PATH = DATA_DIR / "eval_intent_prompts.json"
def load_eval_prompts() -> list[dict]:
    """Read and parse the evaluation cases stored at ``EVAL_PROMPTS_PATH``.

    Returns:
        The decoded JSON content of the file (UTF-8).
    """
    raw_text = EVAL_PROMPTS_PATH.read_text(encoding="utf-8")
    return json.loads(raw_text)
def pick_device() -> str:
    """Choose the torch device string to run on.

    CUDA is preferred, then Apple MPS; CPU is used when neither backend
    reports itself as available.

    Returns:
        One of ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    return chosen
def is_adapter_path(model_path: Path) -> bool:
    """Report whether *model_path* contains an ``adapter_config.json`` entry.

    Args:
        model_path: Directory to inspect.

    Returns:
        ``True`` if ``adapter_config.json`` exists directly under *model_path*.
    """
    adapter_config = model_path.joinpath("adapter_config.json")
    return adapter_config.exists()
def is_complete_merged_model(model_path: Path) -> bool:
    """Check whether *model_path* holds a complete set of merged model files.

    A directory passes when it has a non-empty ``config.json`` and either

    * every shard named in the ``weight_map`` of
      ``model.safetensors.index.json`` is present and non-empty, or
    * a non-empty single-file ``model.safetensors`` is present.

    An index that cannot be read, is not valid JSON, or names no shards at
    all cannot vouch for any weights.  In that case the index is ignored and
    the single-file check decides, instead of raising or treating an empty
    shard list as "complete".

    Args:
        model_path: Directory expected to contain the merged checkpoint.

    Returns:
        ``True`` if the files described above are present and non-empty.
    """

    def _has_content(candidate: Path) -> bool:
        # A single stat() covers both "missing" and "zero bytes" and avoids
        # the window between separate exists() and stat() calls.
        try:
            return candidate.stat().st_size > 0
        except OSError:
            return False

    if not _has_content(model_path / "config.json"):
        return False

    shard_names: set = set()
    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        try:
            index = json.loads(index_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # json.JSONDecodeError and UnicodeDecodeError are ValueErrors.
            index = None
        weight_map = index.get("weight_map") if isinstance(index, dict) else None
        if isinstance(weight_map, dict):
            shard_names = set(weight_map.values())

    if shard_names:
        # A non-string shard entry means the index is corrupt: report incomplete.
        return all(
            isinstance(shard, str) and _has_content(model_path / shard)
            for shard in shard_names
        )

    return _has_content(model_path / "model.safetensors")
def resolve_model_path(path: str) -> Path:
    """Pick the checkpoint directory to evaluate.

    The requested *path* is returned as-is when it exists and is either a
    LoRA adapter or a complete merged model.  Otherwise ``DEFAULT_MODEL_PATH``
    is tried as an adapter, then ``FALLBACK_MODEL_PATH`` as a merged model.

    Args:
        path: Checkpoint location requested by the caller.

    Returns:
        The first usable checkpoint directory found.

    Raises:
        FileNotFoundError: If neither the requested path nor the fallbacks
            are usable.
    """
    model_path = Path(path)
    if model_path.exists():
        if is_adapter_path(model_path) or is_complete_merged_model(model_path):
            return model_path
        # The path exists but is neither an adapter nor a complete merged model.
        print(f"Warning: {model_path} looks incomplete; trying adapter fallback.")

    adapter_path, merged_path = DEFAULT_MODEL_PATH, FALLBACK_MODEL_PATH

    adapter_usable = adapter_path.exists() and is_adapter_path(adapter_path)
    if adapter_usable:
        print(f"Using LoRA adapter at {adapter_path}")
        return adapter_path

    merged_usable = merged_path.exists() and is_complete_merged_model(merged_path)
    if not merged_usable:
        raise FileNotFoundError(
            "No usable trained model found. Expected a complete merged model or "
            f"LoRA adapter at {DEFAULT_MODEL_PATH}."
        )

    print(f"Using merged model at {merged_path}")
    return merged_path
def load_model(model_path: Path, device: str):
    """Load the model and tokenizer for evaluation.

    When *model_path* is a LoRA adapter, the weights come from ``BASE_MODEL``
    and the adapter is applied on top with ``PeftModel``; otherwise the
    weights are read directly from *model_path*.  The tokenizer is always
    loaded from *model_path*.

    Args:
        model_path: Adapter or merged-model directory.
        device: Target device string; ``"cuda"`` and ``"mps"`` load in
            float16, anything else in float32.

    Returns:
        A ``(model, tokenizer)`` tuple with the model moved to *device* and
        switched to eval mode.
    """
    dtype = torch.float16 if device in ("cuda", "mps") else torch.float32
    uses_adapter = is_adapter_path(model_path)

    if uses_adapter:
        print(f"Loading base model: {BASE_MODEL}")
    else:
        print("Loading merged model weights")

    tokenizer = AutoTokenizer.from_pretrained(model_path)

    weights_source = BASE_MODEL if uses_adapter else model_path
    model = AutoModelForCausalLM.from_pretrained(
        weights_source,
        dtype=dtype,
        low_cpu_mem_usage=True,
    )
    if uses_adapter:
        # Wrap the freshly loaded base network with the adapter weights.
        model = PeftModel.from_pretrained(model, str(model_path))

    model.to(device)
    model.eval()
    return model, tokenizer
def generate_intent(model, tokenizer, prompt: str, device: str) -> str:
    """Generate the model's raw intent text for a single user prompt.

    The prompt is wrapped by ``build_intent_messages``, rendered through the
    tokenizer's chat template, and decoded greedily (``do_sample=False``) for
    at most ``MAX_NEW_TOKENS`` new tokens.

    Args:
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching *model*, with a chat template.
        prompt: User text to classify.
        device: Device string the inputs are moved to.

    Returns:
        The decoded completion only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    messages = build_intent_messages(prompt)
    # return_dict=True yields a mapping with input_ids (and attention_mask)
    # instead of depending on the call returning a bare tensor.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)
    input_ids = encoded["input_ids"]
    prompt_length = input_ids.shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            # Pass the mask explicitly so generate() does not have to infer it.
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=MAX_NEW_TOKENS,
            use_cache=True,
            do_sample=False,
        )

    # generate() returns prompt + completion; keep only the newly generated tail.
    completion_ids = outputs[0][prompt_length:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
def evaluate(model, tokenizer, test_prompts: list[dict], device: str) -> tuple[float, float]:
    """Run every evaluation case through the model and report accuracy.

    Each case must provide ``"prompt"`` and ``"expected"``; ``expected`` must
    contain ``"skill"`` and may contain ``"parameters"`` (defaults to ``{}``).
    A per-case report and a final summary are printed to stdout.

    Args:
        model: Model passed through to ``generate_intent``.
        tokenizer: Tokenizer passed through to ``generate_intent``.
        test_prompts: Evaluation cases as described above.
        device: Device string passed through to ``generate_intent``.

    Returns:
        ``(skill_accuracy, full_accuracy)`` as fractions in ``[0, 1]``.
        ``(0.0, 0.0)`` is returned for an empty *test_prompts* list.
    """
    total = len(test_prompts)
    if total == 0:
        # Without this guard the accuracy division below raises ZeroDivisionError.
        print("No eval prompts provided; nothing to evaluate.")
        return 0.0, 0.0

    skill_correct = 0
    full_correct = 0
    for index, case in enumerate(test_prompts, start=1):
        prompt = case["prompt"]
        expected = case["expected"]
        expected_skill = expected["skill"]
        expected_params = expected.get("parameters", {})

        raw_output = generate_intent(model, tokenizer, prompt, device)
        predicted = extract_intent(raw_output)

        # Skill-only match is scored separately from the full intent match.
        skill_pass = predicted is not None and predicted.get("skill") == expected_skill
        full_pass = intent_matches(predicted, expected_skill, expected_params)
        skill_correct += int(skill_pass)
        full_correct += int(full_pass)

        print(f"--- [{index}/{total}] ---")
        print(f"Prompt: {prompt}")
        print(f"Expected: {json.dumps(expected, separators=(',', ':'))}")
        # Show the raw text when nothing could be extracted from the output.
        print(f"Predicted: {json.dumps(predicted, separators=(',', ':')) if predicted else raw_output}")
        print(f"Skill: {'PASS' if skill_pass else 'FAIL'}")
        print(f"Full: {'PASS' if full_pass else 'FAIL'}")
        print()

    skill_accuracy = skill_correct / total
    full_accuracy = full_correct / total
    print(f"Skill accuracy: {skill_correct}/{total} ({skill_accuracy:.1%})")
    print(f"Full intent accuracy: {full_correct}/{total} ({full_accuracy:.1%})")
    return skill_accuracy, full_accuracy
def main() -> None:
    """Command-line entry point: parse arguments, load the model, run the eval.

    Raises:
        FileNotFoundError: If the eval prompts file is missing (or, via
            ``resolve_model_path``, if no usable model is found).
    """
    cli = argparse.ArgumentParser(description="Evaluate the intent extraction model.")
    cli.add_argument(
        "--model-path",
        default=str(DEFAULT_MODEL_PATH),
        help=f"Path to LoRA adapter or merged model (default: {DEFAULT_MODEL_PATH})",
    )
    options = cli.parse_args()

    # Fail before any model loading if there is nothing to evaluate against.
    if not EVAL_PROMPTS_PATH.exists():
        missing_message = (
            f"Eval prompts not found at {EVAL_PROMPTS_PATH}. "
            "Run `python scripts/generate_intent_dataset.py` first."
        )
        raise FileNotFoundError(missing_message)

    test_prompts = load_eval_prompts()
    device = pick_device()
    model_path = resolve_model_path(options.model_path)

    print(f"Device: {device}")
    print(f"Loading model from {model_path.resolve()}")
    model, tokenizer = load_model(model_path, device)

    print(f"Running intent evaluation on {len(test_prompts)} prompts...\n")
    evaluate(model, tokenizer, test_prompts, device)
# Run the evaluation only when executed as a script or via
# ``python -m src.evaluate_intent``; importing the module has no side effects.
if __name__ == "__main__":
    main()