File size: 7,309 Bytes
6524169 81b01a7 40a90bb 81b01a7 6524169 81b01a7 6524169 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 | """Shared helpers for skill classification output."""
from __future__ import annotations
import json
import re
# Canonical skill names. normalize_skill returns a label unchanged only when it
# is a member of this set. Grouped here by the app or device each skill drives.
VALID_SKILLS: frozenset[str] = frozenset(
    [
        # Device settings and hardware
        "bluetooth_enable",
        "wifi_enable",
        "camera_take_photo",
        "create_alarm",
        # Communication
        "gmail_send_email",
        "whatsapp_send_message",
        "slack_open_channel",
        "contacts_search",
        # Productivity, search, and travel
        "calendar_create_event",
        "linkedin_search_person",
        "youtube_search",
        "uber_request_ride",
        # Music
        "spotify_pause",
        "spotify_play_playlist",
        "spotify_search_play",
    ]
)
# Model sometimes invents close-but-wrong skill names. Each group below lists
# the invented labels that should be read as one canonical skill; the
# comprehension flattens the groups into alias -> canonical pairs, keeping the
# aliases in the order written.
SKILL_ALIASES: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        (
            "create_alarm",
            ("alarm_set", "set_alarm", "alarm_create", "create_alarms", "wake_up_alarm"),
        ),
        ("whatsapp_send_message", ("send_message", "send_whatsapp", "whatsapp_message")),
        ("gmail_send_email", ("send_email", "compose_email")),
        ("slack_open_channel", ("open_slack", "slack_channel")),
        ("contacts_search", ("search_contacts", "contact_search")),
        ("spotify_play_playlist", ("spotify_create_playlist",)),
    )
    for alias in aliases
}
def extract_skill(text: str) -> str | None:
"""Extract the skill name from model output JSON."""
text = text.strip()
if not text:
return None
match = re.search(r'\{[^{}]*"skill"\s*:\s*"([^"]+)"[^{}]*\}', text)
if match:
return match.group(1)
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
try:
payload = json.loads(text[start : end + 1])
except json.JSONDecodeError:
return None
skill = payload.get("skill")
return skill if isinstance(skill, str) and skill else None
def normalize_skill(raw_skill: str | None) -> str | None:
    """Map a raw model skill label to a known skill, if possible.

    The label is stripped, lowercased, and has every hyphen and space turned
    into an underscore. The cleaned label is returned as-is when it is in
    ``VALID_SKILLS``; otherwise it is looked up in ``SKILL_ALIASES``.

    Returns:
        The canonical skill name, or ``None`` for falsy or non-string input
        and for labels found in neither table.
    """
    if not raw_skill or not isinstance(raw_skill, str):
        return None
    # One-pass translation of "-" and " " to "_".
    candidate = raw_skill.strip().lower().translate(str.maketrans("- ", "__"))
    return candidate if candidate in VALID_SKILLS else SKILL_ALIASES.get(candidate)
def infer_skill_from_prompt(prompt: str) -> str | None:
"""Keyword fallback when the model returns an unknown skill."""
text = prompt.lower().strip()
if not text:
return None
if re.search(r"gmail|@\w+\.\w+|send\s+(an?\s+)?email|compose\s+email|write\s+mail", text):
return "gmail_send_email"
if "whatsapp" in text or re.search(
r"\b(send|message|text)\s+\w+\s+a\s+message\b", text
):
return "whatsapp_send_message"
rules: list[tuple[str, re.Pattern[str]]] = [
("contacts_search", re.compile(r"\bcontacts\b|phone book|address book")),
("slack_open_channel", re.compile(r"slack.*channel|channel.*slack|#\w+.*slack")),
("youtube_search", re.compile(r"youtube")),
("linkedin_search_person", re.compile(r"linkedin")),
(
"spotify_search_play",
re.compile(r"spotify.*(search|find).*(play|and play)|search.*spotify.*play"),
),
(
"spotify_play_playlist",
re.compile(r"playlist|liked songs|discover weekly|daily mix"),
),
("spotify_pause", re.compile(r"pause.*spotify|stop.*spotify|hold spotify")),
("uber_request_ride", re.compile(r"\buber\b|book.*ride|get.*ride")),
("create_alarm", re.compile(r"\balarm\b|wake me up|wake up at|wake up tomorrow")),
(
"calendar_create_event",
re.compile(r"calendar|appointment|schedule.*meeting|book.*appointment"),
),
("bluetooth_enable", re.compile(r"bluetooth")),
("wifi_enable", re.compile(r"\bwifi\b|wlan")),
("camera_take_photo", re.compile(r"camera|take a photo|take a picture|selfie|snap a")),
]
for skill, pattern in rules:
if pattern.search(text):
return skill
if "spotify" in text:
return "spotify_play_playlist"
return None
def resolve_skill(raw_skill: str | None, prompt: str) -> str | None:
    """Return a valid skill from model output, aliases, or prompt keywords.

    The model's label, normalized through ``normalize_skill``, is preferred.
    The keyword guess from ``infer_skill_from_prompt`` is used when the label
    cannot be normalized. There is one override: when the model says
    ``gmail_send_email`` but the keyword guess is ``whatsapp_send_message``,
    the prompt reads "<send|message|text> <word> a message", and the prompt
    contains none of "gmail", "email", or "@", the keyword guess is returned
    instead.
    """
    keyword_guess = infer_skill_from_prompt(prompt)
    model_guess = normalize_skill(raw_skill)
    lowered = prompt.lower()

    looks_like_chat = keyword_guess == "whatsapp_send_message" and (
        re.search(r"\b(send|message|text)\s+\w+\s+a\s+message\b", lowered) is not None
    )
    mentions_email = "gmail" in lowered or "email" in lowered or "@" in prompt

    if model_guess == "gmail_send_email" and looks_like_chat and not mentions_email:
        return keyword_guess
    return model_guess or keyword_guess
def _parse_json_payload(text: str) -> dict | None:
text = text.strip()
if not text:
return None
start = text.find("{")
end = text.rfind("}")
if start == -1 or end == -1 or end <= start:
return None
try:
payload = json.loads(text[start : end + 1])
except json.JSONDecodeError:
return None
return payload if isinstance(payload, dict) else None
def extract_intent(text: str) -> dict | None:
    """Extract skill and parameters from model output JSON.

    Uses the JSON object found by ``_parse_json_payload``. A missing or null
    ``"parameters"`` entry becomes ``{}``.

    Returns:
        ``{"skill": <str>, "parameters": <dict>}``, or ``None`` when no
        non-empty JSON object is found, when ``"skill"`` is missing, empty, or
        not a string, or when ``"parameters"`` is present but not a dict.
    """
    payload = _parse_json_payload(text)
    skill = payload.get("skill") if payload else None
    if not isinstance(skill, str) or skill == "":
        return None
    params = payload.get("parameters")
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return None
    return {"skill": skill, "parameters": params}
def format_intent_json(skill: str, parameters: dict | None = None) -> str:
payload: dict = {"skill": skill, "parameters": parameters or {}}
return json.dumps(payload, separators=(",", ":"))
def normalize_param(value: str) -> str:
    """Lowercase *value* and collapse every whitespace run to a single space."""
    # str.split() with no argument already drops leading/trailing whitespace,
    # so a separate strip() is unnecessary.
    return " ".join(value.lower().split())


def parameters_match(
    predicted: dict,
    expected: dict,
    *,
    required_only: bool = True,
) -> bool:
    """Compare parameter dicts with normalized lowercase string matching.

    Each expected value is compared with the predicted value for the same key
    after ``str()`` and ``normalize_param``. Case and whitespace differences
    are therefore ignored, and ``7`` matches ``"7"``. Keys that appear only in
    *predicted* are ignored.

    An expected value of ``None`` marks an optional parameter. With
    ``required_only=True`` (the default) such keys are skipped. With
    ``required_only=False`` the prediction must leave them unset (missing or
    ``None``).

    Args:
        predicted: Predicted parameters; ``None`` is treated as ``{}``.
        expected: Expected parameters.
        required_only: Whether ``None``-valued expected keys are skipped.

    Returns:
        True when every checked expected parameter matches.
    """
    if predicted is None:
        predicted = {}
    for key, expected_value in expected.items():
        predicted_value = predicted.get(key)
        if expected_value is None:
            # Previously a None expectation with required_only=False always
            # failed (and accidentally matched the string "none").
            if not required_only and predicted_value is not None:
                return False
            continue
        if predicted_value is None:
            return False
        if normalize_param(str(predicted_value)) != normalize_param(str(expected_value)):
            return False
    return True
def intent_matches(
    predicted: dict | None,
    expected_skill: str,
    expected_parameters: dict | None = None,
) -> bool:
    """Check if predicted intent matches expected skill and parameters.

    The predicted skill is passed through ``normalize_skill``, so aliases and
    spelling variants count as matches. A label that ``normalize_skill``
    cannot map is compared verbatim. When *expected_parameters* is empty or
    ``None``, only the skill is checked. Otherwise parameters are compared
    with ``parameters_match`` using its defaults.

    Args:
        predicted: Intent dict with ``"skill"`` and optionally
            ``"parameters"`` (the shape ``extract_intent`` returns), or None.
        expected_skill: Canonical skill name to compare against.
        expected_parameters: Expected parameter values, or None to skip the
            parameter check.

    Returns:
        True when the skill (and, if requested, the parameters) match.
    """
    if not predicted:
        return False
    raw_skill = predicted.get("skill")
    predicted_skill = normalize_skill(raw_skill) or raw_skill
    if predicted_skill != expected_skill:
        return False
    if not expected_parameters:
        return True
    predicted_parameters = predicted.get("parameters")
    # A missing, null, or non-dict "parameters" entry would otherwise reach
    # parameters_match, which calls .get() on it; treat it as no parameters.
    if not isinstance(predicted_parameters, dict):
        predicted_parameters = {}
    return parameters_match(predicted_parameters, expected_parameters)
def format_skill_json(skill: str) -> str:
    """Serialize a bare skill label as compact JSON, e.g. ``{"skill":"x"}``."""
    payload = dict(skill=skill)
    return json.dumps(payload, separators=(",", ":"))
|