| """ |
| Route skill names from the classifier model to trajectory JSON files. |
| |
| Flow: |
| prompt -> model -> {"skill": "spotify_play_playlist"} -> trajectories/spotify_play_playlist.json |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from pathlib import Path |
|
|
| from src.paths import PROJECT_ROOT |
|
|
# Base directory that the relative paths in SKILL_TO_TRAJECTORY are joined
# onto when route_skill builds a trajectory path. Alias of PROJECT_ROOT.
ROOT_DIR = PROJECT_ROOT
|
|
# Skill names the classifier may emit. Each skill's trajectory lives in a
# file named after the skill under "trajectories/" (relative to ROOT_DIR).
_KNOWN_SKILLS: tuple[str, ...] = (
    "bluetooth_enable",
    "calendar_create_event",
    "camera_take_photo",
    "contacts_search",
    "create_alarm",
    "gmail_send_email",
    "linkedin_search_person",
    "slack_open_channel",
    "spotify_pause",
    "spotify_play_playlist",
    "spotify_search_play",
    "uber_request_ride",
    "whatsapp_send_message",
    "wifi_enable",
    "youtube_search",
)

# Skill name -> trajectory JSON path relative to ROOT_DIR.
SKILL_TO_TRAJECTORY: dict[str, str] = {
    name: f"trajectories/{name}.json" for name in _KNOWN_SKILLS
}
|
|
|
|
class SkillNotFoundError(KeyError):
    """Raised when the skill name is not in SKILL_TO_TRAJECTORY.

    Subclasses ``KeyError`` so callers that catch ``KeyError`` keep working.
    ``__str__`` is overridden because ``KeyError.__str__`` returns the
    ``repr`` of a single argument (it is designed to display a missing key),
    which would wrap this class's human-readable message in extra quotes.
    """

    def __str__(self) -> str:
        # Only the single-message form is special-cased; anything else
        # falls back to the standard KeyError formatting.
        if len(self.args) == 1:
            return str(self.args[0])
        return super().__str__()
|
|
|
|
class TrajectoryNotFoundError(FileNotFoundError):
    """Raised when a skill's trajectory file is missing or unusable.

    ``route_skill`` raises it when the mapped trajectory file cannot be
    found on disk. ``load_trajectory`` also raises it when the file cannot
    be parsed as JSON or lacks a top-level ``steps`` field. It subclasses
    ``FileNotFoundError``, so handlers for ``FileNotFoundError`` or
    ``OSError`` catch it too.
    """
|
|
|
|
def route_skill(skill: str) -> Path:
    """Return the trajectory file path for a skill name.

    Args:
        skill: Skill identifier, e.g. ``"spotify_play_playlist"``.

    Returns:
        ``ROOT_DIR`` joined with the relative path mapped in
        ``SKILL_TO_TRAJECTORY`` (not resolved).

    Raises:
        SkillNotFoundError: ``skill`` is not a key of ``SKILL_TO_TRAJECTORY``.
        TrajectoryNotFoundError: the mapped path is not an existing regular
            file (missing, or something like a directory at that location).
    """
    relative = SKILL_TO_TRAJECTORY.get(skill)
    if relative is None:
        raise SkillNotFoundError(f"Unknown skill: {skill!r}")

    path = ROOT_DIR / relative
    # is_file() rather than exists(): a directory at this location would pass
    # exists() and only fail later, with an unrelated IsADirectoryError, when
    # the trajectory is read.
    if not path.is_file():
        raise TrajectoryNotFoundError(
            f"Trajectory not found for skill {skill!r}: {path}"
        )
    return path
|
|
|
|
def load_trajectory(skill: str) -> dict:
    """Load and return the trajectory JSON for a skill.

    Args:
        skill: Skill identifier; routed through ``route_skill``.

    Returns:
        The parsed trajectory. It is guaranteed to be a ``dict`` containing a
        ``"steps"`` key; no other fields are validated here.

    Raises:
        SkillNotFoundError: the skill is unknown (from ``route_skill``).
        TrajectoryNotFoundError: the file is missing, is not decodable as
            UTF-8 JSON, or has no ``steps`` field.
    """
    path = route_skill(skill).resolve()
    try:
        # "utf-8-sig" decodes plain UTF-8 identically but also strips a
        # leading BOM, which json.loads rejects when given a str.
        data = json.loads(path.read_text(encoding="utf-8-sig"))
    except (UnicodeDecodeError, json.JSONDecodeError) as exc:
        # Undecodable bytes are reported the same way as malformed JSON
        # instead of escaping as a bare UnicodeDecodeError.
        raise TrajectoryNotFoundError(
            f"Trajectory file for {skill!r} is not valid JSON ({path}). "
            "If this file was a symlink, replace it with the actual trajectory JSON."
        ) from exc
    if not isinstance(data, dict) or "steps" not in data:
        raise TrajectoryNotFoundError(
            f"Trajectory file for {skill!r} is missing a 'steps' field: {path}"
        )
    return data
|
|
|
|
def route_from_model_output(model_output: str) -> Path:
    """Extract the skill from the model's raw (JSON) output and route it.

    Args:
        model_output: Text produced by the classifier model.

    Returns:
        The trajectory path returned by ``route_skill`` for the extracted skill.

    Raises:
        ValueError: ``extract_skill`` found no skill in ``model_output``.
        SkillNotFoundError, TrajectoryNotFoundError: propagated from ``route_skill``.
    """
    # Imported inside the function, as route_prompt also does for its helpers.
    from src.skill_utils import extract_skill

    extracted = extract_skill(model_output)
    if extracted is not None:
        return route_skill(extracted)
    raise ValueError(f"Could not extract skill from model output: {model_output!r}")
|
|
|
|
def route_prompt(
    prompt: str, model_path: str | Path | None = None
) -> tuple[str, Path, dict]:
    """Classify a prompt with the model, then load its trajectory.

    The model is loaded (via ``load_model``) on every call.

    Args:
        prompt: Free-text user request to classify.
        model_path: Model location, passed through ``resolve_model_path``.
            Defaults to ``TRAINED_MODEL_DIR / "adapter"``.

    Returns:
        ``(skill, trajectory_path, trajectory)``: the extracted skill name,
        the path from ``route_skill``, and the dict from ``load_trajectory``.

    Raises:
        ValueError: no skill could be extracted from the model output. The
            message includes the raw output to make misclassifications
            debuggable.
        SkillNotFoundError, TrajectoryNotFoundError: propagated from routing.
    """
    from src.evaluate import generate_skill, load_model, pick_device, resolve_model_path
    from src.paths import TRAINED_MODEL_DIR
    from src.skill_utils import extract_skill

    if model_path is None:
        model_path = TRAINED_MODEL_DIR / "adapter"

    device = pick_device()
    resolved_path = resolve_model_path(str(model_path))
    model, tokenizer = load_model(resolved_path, device)

    raw_output = generate_skill(model, tokenizer, prompt, device)
    skill = extract_skill(raw_output)
    if skill is None:
        # Include the raw output (as route_from_model_output does) so a bad
        # generation can be diagnosed from the error alone.
        raise ValueError(
            f"Model did not return a skill for prompt: {prompt!r} "
            f"(raw output: {raw_output!r})"
        )

    trajectory_path = route_skill(skill)
    trajectory = load_trajectory(skill)
    return skill, trajectory_path, trajectory
|
|
|
|
def _print_trajectory_summary(path: Path, data: dict) -> None:
    """Print the trajectory path, task and result lines shared by both CLI modes."""
    print(f"Trajectory: {path}")
    # load_trajectory only guarantees a "steps" key, so a trajectory without
    # "task" must not crash the CLI with a KeyError.
    print(f"Task: {data.get('task', '<missing>')}")
    print("Result: trajectory file found")


def _main() -> None:
    """Command-line entry point.

    With ``--skill`` the skill name is routed directly (no model). Otherwise
    the positional prompt is classified with the trained model first. In both
    modes the skill, trajectory path and task are printed.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Route skills to trajectory files.")
    parser.add_argument("prompt", nargs="?", help="User prompt to classify and route")
    parser.add_argument("--skill", help="Route a skill name directly (skip model)")
    parser.add_argument(
        "--model-path",
        # None lets route_prompt apply its own default, so the default model
        # location is defined in one place only.
        default=None,
        help="Path to trained model for prompt classification "
        "(default: the 'adapter' directory under TRAINED_MODEL_DIR)",
    )
    args = parser.parse_args()

    if args.skill:
        skill = args.skill
        print(f"Skill: {skill}")
        path = route_skill(skill)
        data = load_trajectory(skill)
        _print_trajectory_summary(path, data)
        return

    if not args.prompt:
        parser.error("Provide a prompt or --skill")

    print(f"Prompt: {args.prompt}")
    skill, path, data = route_prompt(args.prompt, args.model_path)
    print(f"Skill: {skill}")
    _print_trajectory_summary(path, data)
|
|
|
|
# Run the CLI defined in _main when this file is executed directly.
if __name__ == "__main__":
    _main()
|
|