android-skill-router / src /evaluate.py
kriyanshi's picture
Improve skill classifier training for contacts, Gmail, and Slack.
24492a8
Raw
History Blame Contribute Delete
6.05 kB
"""
Evaluate the fine-tuned skill-classification model.
Loads the trained model, runs 50 held-out prompts, and reports per-example
PASS/FAIL plus overall accuracy.
Uses transformers + PEFT so evaluation works on Mac/CPU without Unsloth.
Run:
python -m src.evaluate
python -m src.evaluate --model-path ./trained_model/adapter
"""
from __future__ import annotations
import argparse
import json
from pathlib import Path
import torch
from peft import PeftModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.paths import DATA_DIR, TRAINED_MODEL_DIR
from src.skill_utils import extract_skill
from src.classifier_prompt import build_classifier_messages
# Hub id of the base checkpoint that a LoRA adapter directory is applied on top of.
BASE_MODEL = "Qwen/Qwen2.5-3B-Instruct"
# NOTE(review): not referenced in this file; presumably shared with training — confirm before removing.
MAX_SEQ_LENGTH = 2048

# Preferred checkpoint: the PEFT/LoRA adapter directory.
DEFAULT_MODEL_PATH = TRAINED_MODEL_DIR.joinpath("adapter")
# Secondary checkpoint: full merged weights, tried when the adapter is unusable.
FALLBACK_MODEL_PATH = TRAINED_MODEL_DIR.joinpath("merged")
# JSON list of held-out {"prompt": ..., "expected": ...} cases.
EVAL_PROMPTS_PATH = DATA_DIR.joinpath("eval_prompts.json")
def load_eval_prompts(
    path: str | Path | None = None,
    expected_count: int | None = 50,
) -> list[dict[str, str]]:
    """Load and validate the held-out evaluation cases.

    Args:
        path: JSON file to read; defaults to ``EVAL_PROMPTS_PATH``.
        expected_count: Exact number of cases required, or ``None`` to accept
            any non-empty list.

    Returns:
        The parsed list of cases, each a dict with ``"prompt"`` and ``"expected"``.

    Raises:
        ValueError: If the file is not a JSON list, has the wrong number of cases,
            is empty, or contains a case missing ``"prompt"``/``"expected"``.
    """
    prompts_path = EVAL_PROMPTS_PATH if path is None else Path(path)
    with prompts_path.open(encoding="utf-8") as handle:
        prompts = json.load(handle)

    if not isinstance(prompts, list):
        raise ValueError(
            f"Expected a JSON list of eval prompts in {prompts_path}, "
            f"got {type(prompts).__name__}"
        )
    if expected_count is not None and len(prompts) != expected_count:
        raise ValueError(
            f"Expected {expected_count} eval prompts in {prompts_path}, got {len(prompts)}"
        )
    if not prompts:
        # Only reachable with expected_count=None; an empty set makes accuracy undefined.
        raise ValueError(f"No eval prompts found in {prompts_path}")

    # Validate every case now so a malformed file fails before the slow model load.
    for position, case in enumerate(prompts, start=1):
        if not isinstance(case, dict) or "prompt" not in case or "expected" not in case:
            raise ValueError(
                f"Eval prompt #{position} in {prompts_path} must be an object with "
                f"'prompt' and 'expected' keys, got {case!r}"
            )
    return prompts
def pick_device() -> str:
    """Return the preferred torch device name: "cuda", then "mps", else "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "mps" if torch.backends.mps.is_available() else "cpu"
def is_adapter_path(model_path: Path) -> bool:
    """Return True when *model_path* contains an ``adapter_config.json`` (PEFT adapter layout)."""
    marker = model_path.joinpath("adapter_config.json")
    return marker.exists()
def _is_nonempty_file(path: Path) -> bool:
    """Return True if *path* is a regular file containing at least one byte."""
    try:
        return path.is_file() and path.stat().st_size > 0
    except OSError:
        return False


def is_complete_merged_model(model_path: Path) -> bool:
    """Return True if *model_path* holds a fully written merged (non-adapter) model.

    Requires a non-empty ``config.json`` plus weights. When a
    ``model.safetensors.index.json`` is present, every shard it lists must exist
    and be non-empty. Without an index, a non-empty ``model.safetensors`` is
    required.

    A missing, empty, unreadable or malformed file yields False instead of
    raising. The purpose of this check is to reject partially written saves.
    """
    if not _is_nonempty_file(model_path / "config.json"):
        return False

    index_path = model_path / "model.safetensors.index.json"
    if index_path.exists():
        try:
            index = json.loads(index_path.read_text(encoding="utf-8"))
        except (OSError, ValueError):
            # JSONDecodeError/UnicodeDecodeError are ValueErrors. A truncated index
            # is exactly the kind of incomplete save this function should report.
            return False
        weight_map = index.get("weight_map") if isinstance(index, dict) else None
        if not isinstance(weight_map, dict) or not weight_map:
            # An index that names no shards cannot describe a usable model.
            return False
        shard_names = set(weight_map.values())
        return all(
            isinstance(shard, str) and _is_nonempty_file(model_path / shard)
            for shard in shard_names
        )

    return _is_nonempty_file(model_path / "model.safetensors")
def resolve_model_path(path: str) -> Path:
    """Pick the trained checkpoint to evaluate.

    *path* is returned unchanged when it is a PEFT adapter directory or a complete
    merged model. Otherwise a warning is printed and the function falls back, in
    order, to the adapter at ``DEFAULT_MODEL_PATH`` and then the merged model at
    ``FALLBACK_MODEL_PATH``.

    Args:
        path: Requested model directory (LoRA adapter or merged model).

    Returns:
        The directory that should be loaded.

    Raises:
        FileNotFoundError: If neither *path* nor either fallback is usable.
    """
    model_path = Path(path)
    if model_path.exists():
        if is_adapter_path(model_path) or is_complete_merged_model(model_path):
            return model_path
        print(f"Warning: {model_path} looks incomplete; trying fallback model paths.")
    else:
        # Report this explicitly. Otherwise a mistyped --model-path silently
        # evaluates whichever fallback checkpoint happens to exist.
        print(f"Warning: {model_path} does not exist; trying fallback model paths.")

    adapter_path = DEFAULT_MODEL_PATH
    if adapter_path.exists() and is_adapter_path(adapter_path):
        print(f"Using LoRA adapter at {adapter_path}")
        return adapter_path

    merged_path = FALLBACK_MODEL_PATH
    if merged_path.exists() and is_complete_merged_model(merged_path):
        print(f"Using merged model at {merged_path}")
        return merged_path

    raise FileNotFoundError(
        f"No usable trained model found at {model_path}. Expected a LoRA adapter "
        f"at {DEFAULT_MODEL_PATH} or a complete merged model at {FALLBACK_MODEL_PATH}."
    )
def load_model(model_path: Path, device: str):
    """Load the tokenizer and causal LM found at *model_path* and move the model to *device*.

    If *model_path* is a PEFT adapter directory, ``BASE_MODEL`` is loaded and the
    adapter is applied on top. Otherwise *model_path* is loaded as full (merged)
    weights. The tokenizer is always read from *model_path*. Weights use float16
    on "cuda"/"mps" and float32 elsewhere. The model is put in eval mode.

    Returns:
        ``(model, tokenizer)``.
    """
    half_precision = device in {"cuda", "mps"}
    dtype = torch.float16 if half_precision else torch.float32
    uses_adapter = is_adapter_path(model_path)

    if uses_adapter:
        print(f"Loading base model: {BASE_MODEL}")
    else:
        print("Loading merged model weights")
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # Adapter layout: start from the hub base model. Merged layout: load local weights.
    weights_source = BASE_MODEL if uses_adapter else model_path
    model = AutoModelForCausalLM.from_pretrained(
        weights_source,
        dtype=dtype,
        low_cpu_mem_usage=True,
    )
    if uses_adapter:
        model = PeftModel.from_pretrained(model, str(model_path))

    model.to(device)
    model.eval()
    return model, tokenizer
def generate_skill(
    model,
    tokenizer,
    prompt: str,
    device: str,
    *,
    max_new_tokens: int = 64,
) -> str:
    """Greedily generate the model's raw classification text for one user prompt.

    Args:
        model: Causal LM (or PEFT-wrapped model) already on *device*.
        tokenizer: Tokenizer with a chat template.
        prompt: User request to classify.
        device: Torch device string the encoded inputs are moved to.
        max_new_tokens: Upper bound on generated tokens (default 64).

    Returns:
        The decoded continuation only (prompt tokens removed), stripped of
        surrounding whitespace and special tokens.
    """
    messages = build_classifier_messages(prompt)
    # Request a dict explicitly. The code no longer depends on the library's
    # default return type, and it receives the attention mask alongside the ids.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_dict=True,
        return_tensors="pt",
    ).to(device)
    input_ids = encoded["input_ids"]
    prompt_length = input_ids.shape[1]

    with torch.inference_mode():
        outputs = model.generate(
            input_ids=input_ids,
            # Pass the mask so generate() need not infer it from pad/eos tokens.
            attention_mask=encoded.get("attention_mask"),
            max_new_tokens=max_new_tokens,
            use_cache=True,
            do_sample=False,
        )

    # generate() returns prompt + continuation; keep only the newly generated tokens.
    generated = outputs[0][prompt_length:]
    return tokenizer.decode(generated, skip_special_tokens=True).strip()
def evaluate(model, tokenizer, test_prompts: list[dict[str, str]], device: str) -> float:
    """Classify every eval case, print per-case results, and return overall accuracy.

    Args:
        model: Model to evaluate.
        tokenizer: Matching tokenizer.
        test_prompts: Cases with ``"prompt"`` and ``"expected"`` keys.
        device: Torch device string used for generation.

    Returns:
        Fraction (0.0-1.0) of cases whose extracted skill equals ``expected``.

    Raises:
        ValueError: If *test_prompts* is empty, because accuracy is undefined.
    """
    total = len(test_prompts)
    if total == 0:
        # Fail before any model work; the ratio below would otherwise divide by zero.
        raise ValueError("No eval prompts to evaluate; accuracy is undefined.")

    correct = 0
    failed_indices: list[int] = []
    for index, case in enumerate(test_prompts, start=1):
        prompt = case["prompt"]
        expected = case["expected"]
        raw_output = generate_skill(model, tokenizer, prompt, device)
        predicted = extract_skill(raw_output)
        passed = predicted == expected
        if passed:
            correct += 1
        else:
            failed_indices.append(index)
        print(f"--- [{index}/{total}] ---")
        print(f"Prompt: {prompt}")
        print(f"Expected: {expected}")
        # When extract_skill returns None, show the raw generation so malformed
        # outputs are visible in the log.
        print(f"Predicted: {predicted if predicted is not None else raw_output}")
        print(f"Result: {'PASS' if passed else 'FAIL'}")
        print()

    accuracy = correct / total
    print(f"Accuracy: {correct}/{total} ({accuracy:.1%})")
    if failed_indices:
        print(f"Failed cases: {', '.join(str(i) for i in failed_indices)}")
    return accuracy
def main() -> None:
    """CLI entry point: parse ``--model-path``, load the model, and run the eval set."""
    cli = argparse.ArgumentParser(description="Evaluate the skill model.")
    cli.add_argument(
        "--model-path",
        default=str(DEFAULT_MODEL_PATH),
        help=f"Path to LoRA adapter or merged model (default: {DEFAULT_MODEL_PATH})",
    )
    options = cli.parse_args()

    # Read the eval set before touching the model so a bad prompts file fails fast.
    eval_cases = load_eval_prompts()
    device = pick_device()
    checkpoint = resolve_model_path(options.model_path)

    print(f"Device: {device}")
    print(f"Loading model from {checkpoint.resolve()}")
    model, tokenizer = load_model(checkpoint, device)

    print(f"Running evaluation on {len(eval_cases)} prompts...\n")
    evaluate(model, tokenizer, eval_cases, device)


if __name__ == "__main__":
    main()