# android-skill-router / src/skill_utils.py
# Commit 40a90bb (kriyanshi): Ship v2 intent extraction with API, demo UI, eval,
# and benchmark suite.
"""Shared helpers for skill classification output."""
from __future__ import annotations
import json
import re
# Canonical skill identifiers the router knows how to dispatch, grouped by
# domain for readability (grouping has no runtime meaning; it is a frozenset).
VALID_SKILLS: frozenset[str] = frozenset(
    (
        # Device settings and hardware
        "bluetooth_enable",
        "wifi_enable",
        "camera_take_photo",
        "create_alarm",
        # Productivity and communication
        "calendar_create_event",
        "contacts_search",
        "gmail_send_email",
        "whatsapp_send_message",
        "slack_open_channel",
        "linkedin_search_person",
        # Media
        "spotify_play_playlist",
        "spotify_search_play",
        "spotify_pause",
        "youtube_search",
        # Transport
        "uber_request_ride",
    )
)
# The model sometimes invents close-but-wrong skill names. Each canonical skill
# is listed with the invented variants seen for it; the comprehension flattens
# that into an alias -> canonical lookup (insertion order follows this table).
SKILL_ALIASES: dict[str, str] = {
    alias: canonical
    for canonical, variants in (
        (
            "create_alarm",
            ("alarm_set", "set_alarm", "alarm_create", "create_alarms", "wake_up_alarm"),
        ),
        ("whatsapp_send_message", ("send_message", "send_whatsapp", "whatsapp_message")),
        ("gmail_send_email", ("send_email", "compose_email")),
        ("slack_open_channel", ("open_slack", "slack_channel")),
        ("contacts_search", ("search_contacts", "contact_search")),
        ("spotify_play_playlist", ("spotify_create_playlist",)),
    )
    for alias in variants
}
def extract_skill(text: str) -> str | None:
    """Pull the ``skill`` value out of a model reply containing JSON.

    A quick regex pass handles a flat object such as ``{"skill": "wifi_enable"}``
    even when it is surrounded by other text. If that fails (e.g. the object
    contains a nested ``parameters`` object), the span from the first ``{`` to
    the last ``}`` is decoded with :func:`json.loads`.

    Returns the skill string, or ``None`` when the text is blank, has no
    decodable object, or the ``skill`` value is missing, empty, or not a string.
    """
    cleaned = text.strip()
    if not cleaned:
        return None

    # Fast path: a brace-free object with a quoted, non-empty "skill" value.
    quick = re.search(r'\{[^{}]*"skill"\s*:\s*"([^"]+)"[^{}]*\}', cleaned)
    if quick is not None:
        return quick.group(1)

    # Slow path: decode the outermost {...} span as JSON.
    first_brace, last_brace = cleaned.find("{"), cleaned.rfind("}")
    if first_brace < 0 or last_brace <= first_brace:
        return None
    try:
        decoded = json.loads(cleaned[first_brace : last_brace + 1])
    except json.JSONDecodeError:
        return None

    value = decoded.get("skill")
    if isinstance(value, str) and value:
        return value
    return None
def normalize_skill(raw_skill: str | None) -> str | None:
    """Map a raw model skill label onto a known skill name, if possible.

    The label is trimmed and lowercased, and ``-`` and spaces become ``_``.
    The result is returned if it is in ``VALID_SKILLS``; otherwise it is
    looked up in ``SKILL_ALIASES``. Returns ``None`` for empty or non-string
    input and for labels found in neither table.
    """
    if not (raw_skill and isinstance(raw_skill, str)):
        return None

    key = raw_skill.strip().lower()
    for separator in ("-", " "):
        key = key.replace(separator, "_")

    if key in VALID_SKILLS:
        return key
    return SKILL_ALIASES.get(key)
# Keyword patterns for infer_skill_from_prompt, compiled once at import time
# instead of on every call.
_EMAIL_PROMPT_RE = re.compile(
    r"gmail|@\w+\.\w+|send\s+(an?\s+)?email|compose\s+email|write\s+mail"
)
# "send/message/text <someone> a message" with no app named.
_BARE_MESSAGE_RE = re.compile(r"\b(send|message|text)\s+\w+\s+a\s+message\b")
# Evaluated in order; the first matching pattern wins, so more specific
# intents must appear before more generic ones.
_PROMPT_RULES: tuple[tuple[str, re.Pattern[str]], ...] = (
    ("contacts_search", re.compile(r"\bcontacts\b|phone book|address book")),
    ("slack_open_channel", re.compile(r"slack.*channel|channel.*slack|#\w+.*slack")),
    ("youtube_search", re.compile(r"youtube")),
    ("linkedin_search_person", re.compile(r"linkedin")),
    (
        "spotify_search_play",
        re.compile(r"spotify.*(search|find).*(play|and play)|search.*spotify.*play"),
    ),
    # Must precede the generic playlist rule: "pause my Discover Weekly on
    # Spotify" is a pause request. \b keeps "nonstop" from counting as "stop";
    # "spotify ... pause" covers commands that name the app first.
    (
        "spotify_pause",
        re.compile(r"\b(?:pause|stop).*spotify|spotify.*\bpause|hold spotify"),
    ),
    (
        "spotify_play_playlist",
        re.compile(r"playlist|liked songs|discover weekly|daily mix"),
    ),
    ("uber_request_ride", re.compile(r"\buber\b|book.*ride|get.*ride")),
    ("create_alarm", re.compile(r"\balarm\b|wake me up|wake up at|wake up tomorrow")),
    (
        "calendar_create_event",
        re.compile(r"calendar|appointment|schedule.*meeting|book.*appointment"),
    ),
    ("bluetooth_enable", re.compile(r"bluetooth")),
    # Accept the common "wi-fi" and "wi fi" spellings as well as "wifi".
    ("wifi_enable", re.compile(r"\bwi[- ]?fi\b|wlan")),
    ("camera_take_photo", re.compile(r"camera|take a photo|take a picture|selfie|snap a")),
)


def infer_skill_from_prompt(prompt: str) -> str | None:
    """Keyword fallback when the model returns an unknown skill.

    The prompt is lowercased and checked in this order:

    1. Email cues ("gmail", an address-like ``@x.y``, "send (an) email", ...)
       -> ``gmail_send_email``.
    2. "whatsapp", or "send/message/text <name> a message"
       -> ``whatsapp_send_message``.
    3. ``_PROMPT_RULES``, first match wins.
    4. Any remaining mention of "spotify" -> ``spotify_play_playlist``.

    Returns the inferred skill name, or ``None`` if the prompt is blank or
    nothing matches.
    """
    text = prompt.lower().strip()
    if not text:
        return None
    if _EMAIL_PROMPT_RE.search(text):
        return "gmail_send_email"
    if "whatsapp" in text or _BARE_MESSAGE_RE.search(text):
        return "whatsapp_send_message"
    for skill, pattern in _PROMPT_RULES:
        if pattern.search(text):
            return skill
    if "spotify" in text:
        return "spotify_play_playlist"
    return None
def resolve_skill(raw_skill: str | None, prompt: str) -> str | None:
    """Return a valid skill from model output, aliases, or prompt keywords.

    The model's label (normalized via ``normalize_skill``) is preferred, with
    one override. When the prompt reads like "send/message/text <name> a
    message", contains none of "gmail", "email" or "@", and the model answered
    ``gmail_send_email``, the keyword inference ``whatsapp_send_message`` wins.
    If the model label cannot be normalized, the keyword inference is returned
    (which may be ``None``).
    """
    from_keywords = infer_skill_from_prompt(prompt)
    from_model = normalize_skill(raw_skill)

    lowered = prompt.lower()
    looks_like_chat = (
        re.search(r"\b(send|message|text)\s+\w+\s+a\s+message\b", lowered) is not None
    )
    mentions_email = "gmail" in lowered or "email" in lowered or "@" in prompt

    # The model tends to pick Gmail for generic "send X a message" prompts.
    prefer_chat = (
        from_keywords == "whatsapp_send_message"
        and looks_like_chat
        and not mentions_email
        and from_model == "gmail_send_email"
    )
    if prefer_chat:
        return from_keywords
    return from_model or from_keywords
def _parse_json_payload(text: str) -> dict | None:
text = text.strip()
if not text:
return None
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
try:
payload = json.loads(text[start : end + 1])
except json.JSONDecodeError:
return None
return payload if isinstance(payload, dict) else None
def extract_intent(text: str) -> dict | None:
    """Extract skill and parameters from model output JSON.

    Returns ``{"skill": <str>, "parameters": <dict>}``. A missing or ``null``
    ``parameters`` becomes ``{}``. Returns ``None`` when:

    - no JSON object can be decoded, or the object is empty;
    - ``skill`` is missing, empty, or not a string;
    - ``parameters`` is present but is neither null nor an object.
    """
    payload = _parse_json_payload(text)
    if not payload:
        return None

    skill = payload.get("skill")
    if not (isinstance(skill, str) and skill):
        return None

    params = payload.get("parameters")
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return None
    return {"skill": skill, "parameters": params}
def format_intent_json(skill: str, parameters: dict | None = None) -> str:
    """Serialize an intent as compact JSON: ``{"skill":...,"parameters":{...}}``.

    A missing or empty *parameters* is written as ``{}``. No whitespace is
    emitted between tokens.
    """
    body: dict = {"skill": skill, "parameters": parameters if parameters else {}}
    return json.dumps(body, separators=(",", ":"))
def normalize_param(value: str) -> str:
    """Lowercase *value* and collapse each whitespace run to one space.

    Leading and trailing whitespace is dropped, e.g. ``"  New   York "`` ->
    ``"new york"``.
    """
    # str.split() with no argument already discards leading/trailing whitespace.
    return " ".join(value.lower().split())


def parameters_match(
    predicted: dict,
    expected: dict,
    *,
    required_only: bool = True,
) -> bool:
    """Compare parameter dicts with normalized lowercase string matching.

    Every key in *expected* is checked; extra keys in *predicted* are ignored.
    Values are compared as ``normalize_param(str(value))``, so ``5`` matches
    ``"5"`` and ``"Bob  Smith"`` matches ``"bob smith"``.

    Args:
        predicted: Parameters produced by the model. ``None`` or any non-dict
            is treated as an empty dict.
        expected: Reference parameters.
        required_only: If true (default), expected keys whose value is
            ``None`` are optional and skipped. If false, such keys require the
            predicted value to be ``None`` or absent too.

    Returns:
        ``True`` when every checked key matches, else ``False``.
    """
    if not isinstance(predicted, dict):
        predicted = {}
    for key, expected_value in expected.items():
        predicted_value = predicted.get(key)
        if expected_value is None:
            # A None expectation is either optional (required_only) or means
            # "must be absent/null"; comparing str(None) would reject a
            # correct None prediction.
            if required_only or predicted_value is None:
                continue
            return False
        if predicted_value is None:
            return False
        if normalize_param(str(predicted_value)) != normalize_param(str(expected_value)):
            return False
    return True
def intent_matches(
    predicted: dict | None,
    expected_skill: str,
    expected_parameters: dict | None = None,
) -> bool:
    """Check if predicted intent matches expected skill and parameters.

    The predicted ``skill`` is passed through ``normalize_skill`` so known
    aliases (e.g. ``set_alarm``) count as their canonical skill. If it cannot
    be normalized, the raw label is compared as-is. Parameters are compared
    only when *expected_parameters* is non-empty, using ``parameters_match``
    with its default ``required_only=True``.

    Returns ``False`` for a missing or empty *predicted* intent.
    """
    if not predicted:
        return False
    raw_skill = predicted.get("skill")
    predicted_skill = normalize_skill(raw_skill) or raw_skill
    if predicted_skill != expected_skill:
        return False
    if not expected_parameters:
        return True
    predicted_parameters = predicted.get("parameters")
    # An intent may carry "parameters": null (or a non-dict); treat that as
    # "no parameters" rather than letting parameters_match fail on .get().
    if not isinstance(predicted_parameters, dict):
        predicted_parameters = {}
    return parameters_match(predicted_parameters, expected_parameters)
def format_skill_json(skill: str) -> str:
    """Serialize *skill* as compact JSON of the form ``{"skill":"<name>"}``."""
    payload = {"skill": skill}
    return json.dumps(payload, separators=(",", ":"))