"""Shared helpers for skill classification output.""" from __future__ import annotations import json import re VALID_SKILLS: frozenset[str] = frozenset( { "bluetooth_enable", "calendar_create_event", "camera_take_photo", "contacts_search", "create_alarm", "gmail_send_email", "linkedin_search_person", "slack_open_channel", "spotify_pause", "spotify_play_playlist", "spotify_search_play", "uber_request_ride", "whatsapp_send_message", "wifi_enable", "youtube_search", } ) # Model sometimes invents close-but-wrong skill names. SKILL_ALIASES: dict[str, str] = { "alarm_set": "create_alarm", "set_alarm": "create_alarm", "alarm_create": "create_alarm", "create_alarms": "create_alarm", "wake_up_alarm": "create_alarm", "send_message": "whatsapp_send_message", "send_whatsapp": "whatsapp_send_message", "whatsapp_message": "whatsapp_send_message", "send_email": "gmail_send_email", "compose_email": "gmail_send_email", "open_slack": "slack_open_channel", "slack_channel": "slack_open_channel", "search_contacts": "contacts_search", "contact_search": "contacts_search", "spotify_create_playlist": "spotify_play_playlist", } def extract_skill(text: str) -> str | None: """Extract the skill name from model output JSON.""" text = text.strip() if not text: return None match = re.search(r'\{[^{}]*"skill"\s*:\s*"([^"]+)"[^{}]*\}', text) if match: return match.group(1) start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end <= start: return None try: payload = json.loads(text[start : end + 1]) except json.JSONDecodeError: return None skill = payload.get("skill") return skill if isinstance(skill, str) and skill else None def normalize_skill(raw_skill: str | None) -> str | None: """Map a raw model skill label to a known skill, if possible.""" if not raw_skill or not isinstance(raw_skill, str): return None skill = raw_skill.strip().lower().replace("-", "_").replace(" ", "_") if skill in VALID_SKILLS: return skill if skill in SKILL_ALIASES: return SKILL_ALIASES[skill] return None def infer_skill_from_prompt(prompt: str) -> str | None: """Keyword fallback when the model returns an unknown skill.""" text = prompt.lower().strip() if not text: return None if re.search(r"gmail|@\w+\.\w+|send\s+(an?\s+)?email|compose\s+email|write\s+mail", text): return "gmail_send_email" if "whatsapp" in text or re.search( r"\b(send|message|text)\s+\w+\s+a\s+message\b", text ): return "whatsapp_send_message" rules: list[tuple[str, re.Pattern[str]]] = [ ("contacts_search", re.compile(r"\bcontacts\b|phone book|address book")), ("slack_open_channel", re.compile(r"slack.*channel|channel.*slack|#\w+.*slack")), ("youtube_search", re.compile(r"youtube")), ("linkedin_search_person", re.compile(r"linkedin")), ( "spotify_search_play", re.compile(r"spotify.*(search|find).*(play|and play)|search.*spotify.*play"), ), ( "spotify_play_playlist", re.compile(r"playlist|liked songs|discover weekly|daily mix"), ), ("spotify_pause", re.compile(r"pause.*spotify|stop.*spotify|hold spotify")), ("uber_request_ride", re.compile(r"\buber\b|book.*ride|get.*ride")), ("create_alarm", re.compile(r"\balarm\b|wake me up|wake up at|wake up tomorrow")), ( "calendar_create_event", re.compile(r"calendar|appointment|schedule.*meeting|book.*appointment"), ), ("bluetooth_enable", re.compile(r"bluetooth")), ("wifi_enable", re.compile(r"\bwifi\b|wlan")), ("camera_take_photo", re.compile(r"camera|take a photo|take a picture|selfie|snap a")), ] for skill, pattern in rules: if pattern.search(text): return skill if "spotify" in text: return "spotify_play_playlist" return None def resolve_skill(raw_skill: str | None, prompt: str) -> str | None: """Return a valid skill from model output, aliases, or prompt keywords.""" inferred = infer_skill_from_prompt(prompt) normalized = normalize_skill(raw_skill) text = prompt.lower() bare_message = re.search(r"\b(send|message|text)\s+\w+\s+a\s+message\b", text) if ( inferred == "whatsapp_send_message" and bare_message and "gmail" not in text and "email" not in text and "@" not in prompt and normalized == "gmail_send_email" ): return inferred if normalized: return normalized return inferred def _parse_json_payload(text: str) -> dict | None: text = text.strip() if not text: return None start = text.find("{") end = text.rfind("}") if start == -1 or end == -1 or end <= start: return None try: payload = json.loads(text[start : end + 1]) except json.JSONDecodeError: return None return payload if isinstance(payload, dict) else None def extract_intent(text: str) -> dict | None: """Extract skill and parameters from model output JSON.""" payload = _parse_json_payload(text) if not payload: return None skill = payload.get("skill") if not isinstance(skill, str) or not skill: return None parameters = payload.get("parameters", {}) if parameters is None: parameters = {} if not isinstance(parameters, dict): return None return {"skill": skill, "parameters": parameters} def format_intent_json(skill: str, parameters: dict | None = None) -> str: payload: dict = {"skill": skill, "parameters": parameters or {}} return json.dumps(payload, separators=(",", ":")) def normalize_param(value: str) -> str: return " ".join(value.lower().strip().split()) def parameters_match( predicted: dict, expected: dict, *, required_only: bool = True, ) -> bool: """Compare parameter dicts with normalized lowercase string matching.""" for key, expected_value in expected.items(): if required_only and expected_value is None: continue predicted_value = predicted.get(key) if predicted_value is None: return False if normalize_param(str(predicted_value)) != normalize_param(str(expected_value)): return False return True def intent_matches( predicted: dict | None, expected_skill: str, expected_parameters: dict | None = None, ) -> bool: """Check if predicted intent matches expected skill and parameters.""" if not predicted: return False predicted_skill = normalize_skill(predicted.get("skill")) or predicted.get("skill") if predicted_skill != expected_skill: return False if not expected_parameters: return True return parameters_match( predicted.get("parameters", {}), expected_parameters, ) def format_skill_json(skill: str) -> str: return json.dumps({"skill": skill}, separators=(",", ":"))