| """Shared helpers for skill classification output.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import re |
|
|
# The closed set of skill names that normalize_skill accepts as-is.
VALID_SKILLS: frozenset[str] = frozenset(
    (
        # Connectivity toggles
        "bluetooth_enable",
        "wifi_enable",
        # Productivity / communication
        "calendar_create_event",
        "contacts_search",
        "create_alarm",
        "gmail_send_email",
        "linkedin_search_person",
        "slack_open_channel",
        "whatsapp_send_message",
        # Media / camera
        "camera_take_photo",
        "spotify_pause",
        "spotify_play_playlist",
        "spotify_search_play",
        "youtube_search",
        # Transport
        "uber_request_ride",
    )
)
|
|
| |
# Alternate labels a model may emit, mapped to their canonical skill name.
# Written as canonical -> aliases for readability; the comprehension flattens
# it into alias -> canonical while preserving insertion order.
SKILL_ALIASES: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        (
            "create_alarm",
            ("alarm_set", "set_alarm", "alarm_create", "create_alarms", "wake_up_alarm"),
        ),
        ("whatsapp_send_message", ("send_message", "send_whatsapp", "whatsapp_message")),
        ("gmail_send_email", ("send_email", "compose_email")),
        ("slack_open_channel", ("open_slack", "slack_channel")),
        ("contacts_search", ("search_contacts", "contact_search")),
        ("spotify_play_playlist", ("spotify_create_playlist",)),
    )
    for alias in aliases
}
|
|
|
|
| def extract_skill(text: str) -> str | None: |
| """Extract the skill name from model output JSON.""" |
| text = text.strip() |
| if not text: |
| return None |
|
|
| match = re.search(r'\{[^{}]*"skill"\s*:\s*"([^"]+)"[^{}]*\}', text) |
| if match: |
| return match.group(1) |
|
|
| start = text.find("{") |
| end = text.rfind("}") |
| if start == -1 or end == -1 or end <= start: |
| return None |
|
|
| try: |
| payload = json.loads(text[start : end + 1]) |
| except json.JSONDecodeError: |
| return None |
|
|
| skill = payload.get("skill") |
| return skill if isinstance(skill, str) and skill else None |
|
|
|
|
| def normalize_skill(raw_skill: str | None) -> str | None: |
| """Map a raw model skill label to a known skill, if possible.""" |
| if not raw_skill or not isinstance(raw_skill, str): |
| return None |
|
|
| skill = raw_skill.strip().lower().replace("-", "_").replace(" ", "_") |
| if skill in VALID_SKILLS: |
| return skill |
| if skill in SKILL_ALIASES: |
| return SKILL_ALIASES[skill] |
| return None |
|
|
|
|
| def infer_skill_from_prompt(prompt: str) -> str | None: |
| """Keyword fallback when the model returns an unknown skill.""" |
| text = prompt.lower().strip() |
| if not text: |
| return None |
|
|
| if re.search(r"gmail|@\w+\.\w+|send\s+(an?\s+)?email|compose\s+email|write\s+mail", text): |
| return "gmail_send_email" |
|
|
| if "whatsapp" in text or re.search( |
| r"\b(send|message|text)\s+\w+\s+a\s+message\b", text |
| ): |
| return "whatsapp_send_message" |
|
|
| rules: list[tuple[str, re.Pattern[str]]] = [ |
| ("contacts_search", re.compile(r"\bcontacts\b|phone book|address book")), |
| ("slack_open_channel", re.compile(r"slack.*channel|channel.*slack|#\w+.*slack")), |
| ("youtube_search", re.compile(r"youtube")), |
| ("linkedin_search_person", re.compile(r"linkedin")), |
| ( |
| "spotify_search_play", |
| re.compile(r"spotify.*(search|find).*(play|and play)|search.*spotify.*play"), |
| ), |
| ( |
| "spotify_play_playlist", |
| re.compile(r"playlist|liked songs|discover weekly|daily mix"), |
| ), |
| ("spotify_pause", re.compile(r"pause.*spotify|stop.*spotify|hold spotify")), |
| ("uber_request_ride", re.compile(r"\buber\b|book.*ride|get.*ride")), |
| ("create_alarm", re.compile(r"\balarm\b|wake me up|wake up at|wake up tomorrow")), |
| ( |
| "calendar_create_event", |
| re.compile(r"calendar|appointment|schedule.*meeting|book.*appointment"), |
| ), |
| ("bluetooth_enable", re.compile(r"bluetooth")), |
| ("wifi_enable", re.compile(r"\bwifi\b|wlan")), |
| ("camera_take_photo", re.compile(r"camera|take a photo|take a picture|selfie|snap a")), |
| ] |
|
|
| for skill, pattern in rules: |
| if pattern.search(text): |
| return skill |
|
|
| if "spotify" in text: |
| return "spotify_play_playlist" |
|
|
| return None |
|
|
|
|
def resolve_skill(raw_skill: str | None, prompt: str) -> str | None:
    """Pick the final skill for *prompt* given the model's *raw_skill* label.

    Preference order:

    * the model's label, normalised through normalize_skill;
    * otherwise the keyword guess from infer_skill_from_prompt.

    There is one override. Suppose the model answered ``gmail_send_email``
    but the prompt reads like "<send|message|text> <someone> a message" and
    contains no email cue ("gmail", "email", "@"). If the keyword guess is
    ``whatsapp_send_message``, that guess is returned instead.
    """
    guess = infer_skill_from_prompt(prompt)
    label = normalize_skill(raw_skill)

    lowered = prompt.lower()
    looks_like_chat = (
        guess == "whatsapp_send_message"
        and re.search(r"\b(send|message|text)\s+\w+\s+a\s+message\b", lowered) is not None
        and not any(cue in lowered for cue in ("gmail", "email"))
        and "@" not in prompt
    )
    if label == "gmail_send_email" and looks_like_chat:
        return guess

    return label if label else guess
|
|
|
|
def _parse_json_payload(text: str) -> dict | None:
    """Decode the outermost ``{...}`` span of *text* as a JSON object.

    The span runs from the first ``{`` to the last ``}``, so prose before or
    after the JSON is ignored. Returns None when *text* is blank, has no such
    span, fails to decode, or does not decode to a dict.
    """
    body = text.strip()
    first = body.find("{")
    last = body.rfind("}")
    if not body or first < 0 or last <= first:
        return None

    try:
        decoded = json.loads(body[first : last + 1])
    except json.JSONDecodeError:
        decoded = None
    return decoded if isinstance(decoded, dict) else None
|
|
|
|
def extract_intent(text: str) -> dict | None:
    """Parse model output into ``{"skill": str, "parameters": dict}``.

    Returns None when any of these hold:

    * no JSON object can be decoded from *text*, or it decodes to ``{}``;
    * ``"skill"`` is missing, empty, or not a string;
    * ``"parameters"`` is present but is neither a dict nor null.

    A missing or null ``"parameters"`` becomes ``{}``.
    """
    payload = _parse_json_payload(text)
    if not payload:
        return None

    skill = payload.get("skill")
    if not (isinstance(skill, str) and skill):
        return None

    params = payload.get("parameters")
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return None

    return {"skill": skill, "parameters": params}
|
|
|
|
| def format_intent_json(skill: str, parameters: dict | None = None) -> str: |
| payload: dict = {"skill": skill, "parameters": parameters or {}} |
| return json.dumps(payload, separators=(",", ":")) |
|
|
|
|
def normalize_param(value: str) -> str:
    """Lower-case *value* and collapse every whitespace run to one space.

    Leading and trailing whitespace is dropped, so ``"  New   York "``
    becomes ``"new york"``.
    """
    # str.split() with no argument already discards edge whitespace.
    words = value.lower().split()
    return " ".join(words)
|
|
|
|
def parameters_match(
    predicted: dict,
    expected: dict,
    *,
    required_only: bool = True,
) -> bool:
    """Compare parameter dicts with normalized lowercase string matching.

    Every key in *expected* is checked against *predicted*; extra keys in
    *predicted* are ignored. Concrete values are compared after ``str()`` and
    normalize_param (case- and whitespace-insensitive), so ``7`` matches
    ``"7"``.

    An expected value of None marks an optional parameter:

    * ``required_only=True`` (default): the key is skipped entirely.
    * ``required_only=False``: the prediction must also leave the key unset
      (missing or None); any concrete predicted value is a mismatch.

    Returns True only if every checked key matches.
    """
    for key, expected_value in expected.items():
        if expected_value is None:
            if required_only:
                continue
            # Expected None means "should be unset". Comparing it as the
            # string "none" would make the key impossible to satisfy.
            if predicted.get(key) is not None:
                return False
            continue
        predicted_value = predicted.get(key)
        if predicted_value is None:
            return False
        if normalize_param(str(predicted_value)) != normalize_param(str(expected_value)):
            return False
    return True
|
|
|
|
| def intent_matches( |
| predicted: dict | None, |
| expected_skill: str, |
| expected_parameters: dict | None = None, |
| ) -> bool: |
| """Check if predicted intent matches expected skill and parameters.""" |
| if not predicted: |
| return False |
|
|
| predicted_skill = normalize_skill(predicted.get("skill")) or predicted.get("skill") |
| if predicted_skill != expected_skill: |
| return False |
| if not expected_parameters: |
| return True |
| return parameters_match( |
| predicted.get("parameters", {}), |
| expected_parameters, |
| ) |
|
|
|
|
def format_skill_json(skill: str) -> str:
    """Serialise *skill* as compact JSON of the form ``{"skill":"<name>"}``."""
    payload = {"skill": skill}
    return json.dumps(payload, separators=(",", ":"))
|
|