File size: 7,309 Bytes
6524169
 
 
 
 
 
 
81b01a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40a90bb
81b01a7
 
6524169
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81b01a7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6524169
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
"""Shared helpers for skill classification output."""

from __future__ import annotations

import json
import re

# Canonical skill names. normalize_skill only ever returns a member of this
# set, either directly or through SKILL_ALIASES (whose targets all live here).
VALID_SKILLS: frozenset[str] = frozenset(
    (
        "bluetooth_enable",
        "calendar_create_event",
        "camera_take_photo",
        "contacts_search",
        "create_alarm",
        "gmail_send_email",
        "linkedin_search_person",
        "slack_open_channel",
        "spotify_pause",
        "spotify_play_playlist",
        "spotify_search_play",
        "uber_request_ride",
        "whatsapp_send_message",
        "wifi_enable",
        "youtube_search",
    )
)

# Model sometimes invents close-but-wrong skill names. Each canonical skill is
# paired with the invented labels that should map onto it; the comprehension
# flattens the groups into an ``alias -> canonical`` lookup table.
SKILL_ALIASES: dict[str, str] = {
    alias: canonical
    for canonical, aliases in (
        (
            "create_alarm",
            ("alarm_set", "set_alarm", "alarm_create", "create_alarms", "wake_up_alarm"),
        ),
        ("whatsapp_send_message", ("send_message", "send_whatsapp", "whatsapp_message")),
        ("gmail_send_email", ("send_email", "compose_email")),
        ("slack_open_channel", ("open_slack", "slack_channel")),
        ("contacts_search", ("search_contacts", "contact_search")),
        # NOTE(review): a "create playlist" label is mapped to *playing* a
        # playlist, presumably because no create skill exists — confirm intent.
        ("spotify_play_playlist", ("spotify_create_playlist",)),
    )
    for alias in aliases
}


def extract_skill(text: str) -> str | None:
    """Pull the ``"skill"`` value out of a model response.

    A regex over flat (brace-free) JSON objects is tried first. If it finds
    nothing, the span from the first ``{`` to the last ``}`` is decoded as
    JSON and its ``skill`` field is read.

    Returns:
        The skill string, or None for blank input, no/unparseable JSON, or a
        ``skill`` field that is missing, empty, or not a string.
    """
    cleaned = text.strip()
    if cleaned == "":
        return None

    flat_hit = re.search(r'\{[^{}]*"skill"\s*:\s*"([^"]+)"[^{}]*\}', cleaned)
    if flat_hit is not None:
        return flat_hit.group(1)

    first_open = cleaned.find("{")
    last_close = cleaned.rfind("}")
    if -1 in (first_open, last_close) or last_close <= first_open:
        return None

    try:
        decoded = json.loads(cleaned[first_open : last_close + 1])
    except json.JSONDecodeError:
        return None

    value = decoded.get("skill")
    if isinstance(value, str) and value:
        return value
    return None


def normalize_skill(raw_skill: str | None) -> str | None:
    """Map a raw model skill label to a known skill, if possible.

    The label is stripped and lowercased, and every run of whitespace and/or
    hyphens is collapsed into a single underscore. This means ``"Create Alarm"``,
    ``"create-alarm"`` and ``"create \\t alarm"`` all become ``"create_alarm"``.
    The result is accepted if it is in ``VALID_SKILLS`` or is a key of
    ``SKILL_ALIASES``.

    Returns:
        The canonical skill name, or None when *raw_skill* is empty, not a
        string, or not recognised.
    """
    if not raw_skill or not isinstance(raw_skill, str):
        return None

    # Collapse whole runs of separators, not single characters, so labels
    # such as "create  alarm" or "create - alarm" don't become "create__alarm".
    skill = re.sub(r"[\s-]+", "_", raw_skill.strip().lower())
    if skill in VALID_SKILLS:
        return skill
    return SKILL_ALIASES.get(skill)


def infer_skill_from_prompt(prompt: str) -> str | None:
    """Keyword fallback when the model returns an unknown skill.

    The lowercased prompt is tested against an ordered table of regexes and
    the first hit wins, so table order encodes precedence. E-mail and WhatsApp
    cues are checked first. A bare mention of "spotify" is the last resort and
    maps to ``spotify_play_playlist``.

    Returns:
        A skill name, or None when the prompt is blank or nothing matches.
    """
    text = prompt.lower().strip()
    if not text:
        return None

    rules: tuple[tuple[str, str], ...] = (
        # E-mail cues beat everything: "gmail", an address, or an email verb.
        (
            "gmail_send_email",
            r"gmail|@\w+\.\w+|send\s+(an?\s+)?email|compose\s+email|write\s+mail",
        ),
        ("whatsapp_send_message", r"whatsapp|\b(send|message|text)\s+\w+\s+a\s+message\b"),
        ("contacts_search", r"\bcontacts\b|phone book|address book"),
        ("slack_open_channel", r"slack.*channel|channel.*slack|#\w+.*slack"),
        ("youtube_search", r"youtube"),
        ("linkedin_search_person", r"linkedin"),
        # An explicit "pause" must outrank the play rules below. Without this,
        # "pause my playlist on spotify" hit the playlist rule, and
        # "spotify, pause" fell through to the bare-"spotify" fallback (play).
        ("spotify_pause", r"\bpause\b.*spotify|spotify.*\bpause\b"),
        (
            "spotify_search_play",
            r"spotify.*(search|find).*(play|and play)|search.*spotify.*play",
        ),
        ("spotify_play_playlist", r"playlist|liked songs|discover weekly|daily mix"),
        ("spotify_pause", r"pause.*spotify|stop.*spotify|hold spotify"),
        ("uber_request_ride", r"\buber\b|book.*ride|get.*ride"),
        ("create_alarm", r"\balarm\b|wake me up|wake up at|wake up tomorrow"),
        ("calendar_create_event", r"calendar|appointment|schedule.*meeting|book.*appointment"),
        ("bluetooth_enable", r"bluetooth"),
        # Accept the common "wi-fi" / "wi fi" spellings as well as "wifi".
        ("wifi_enable", r"\bwi[-\s]?fi\b|wlan"),
        ("camera_take_photo", r"camera|take a photo|take a picture|selfie|snap a"),
        # Last resort: any other Spotify mention is treated as "play something".
        ("spotify_play_playlist", r"spotify"),
    )

    for skill, pattern in rules:
        if re.search(pattern, text):
            return skill
    return None


def resolve_skill(raw_skill: str | None, prompt: str) -> str | None:
    """Pick the final skill for *prompt* given the model's raw label.

    The normalized model label normally wins. Keyword inference from
    ``infer_skill_from_prompt`` is only a fallback, with one override: the
    inferred ``whatsapp_send_message`` is returned instead of the model's
    ``gmail_send_email`` when all of these hold:

    - the prompt has the shape "send/message/text <word> a message";
    - the inference says WhatsApp;
    - the prompt has no e-mail hint ("gmail", "email" or "@").

    Returns:
        A skill name, or None when neither source yields one.
    """
    from_prompt = infer_skill_from_prompt(prompt)
    from_model = normalize_skill(raw_skill)

    lowered = prompt.lower()
    email_hint = "gmail" in lowered or "email" in lowered or "@" in prompt
    chat_shaped = (
        re.search(r"\b(send|message|text)\s+\w+\s+a\s+message\b", lowered)
        is not None
    )
    model_overreached = (
        from_model == "gmail_send_email"
        and from_prompt == "whatsapp_send_message"
        and chat_shaped
        and not email_hint
    )
    if model_overreached:
        return from_prompt

    return from_model if from_model else from_prompt


def _parse_json_payload(text: str) -> dict | None:
    text = text.strip()
    if not text:
        return None

    start = text.find("{")
    end = text.rfind("}")
    if start == -1 or end == -1 or end <= start:
        return None

    try:
        payload = json.loads(text[start : end + 1])
    except json.JSONDecodeError:
        return None

    return payload if isinstance(payload, dict) else None


def extract_intent(text: str) -> dict | None:
    """Parse model output into ``{"skill": str, "parameters": dict}``.

    The payload comes from ``_parse_json_payload``.

    Returns:
        The intent dict, or None in any of these cases:
        - the payload helper yields nothing or an empty dict;
        - ``skill`` is missing, empty, or not a string;
        - ``parameters`` is present but neither null nor a dict.
        A missing or null ``parameters`` is returned as ``{}``.
    """
    payload = _parse_json_payload(text)
    if not payload:
        return None

    skill = payload.get("skill")
    if not (isinstance(skill, str) and skill):
        return None

    params = payload.get("parameters")
    if params is None:
        params = {}
    elif not isinstance(params, dict):
        return None

    return {"skill": skill, "parameters": params}


def format_intent_json(skill: str, parameters: dict | None = None) -> str:
    """Serialize an intent as compact JSON ``{"skill":...,"parameters":{...}}``.

    A falsy *parameters* (None or an empty dict) is written as ``{}``.
    """
    body = {"skill": skill, "parameters": parameters if parameters else {}}
    return json.dumps(body, separators=(",", ":"))


def normalize_param(value: str) -> str:
    """Canonicalise a parameter value for comparison.

    Lowercases *value*, drops leading and trailing whitespace, and collapses
    every internal whitespace run to a single space.
    """
    tokens = value.lower().split()
    return " ".join(tokens)


def parameters_match(
    predicted: dict,
    expected: dict,
    *,
    required_only: bool = True,
) -> bool:
    """Compare parameter dicts with normalized lowercase string matching.

    Every key in *expected* is checked against *predicted*; extra keys in
    *predicted* are ignored. Non-None values are compared after ``str()`` and
    ``normalize_param`` (lowercased, whitespace collapsed).

    An expected value of None marks an optional parameter:
    - with ``required_only=True`` (default) the key is skipped;
    - with ``required_only=False`` the prediction must also be absent or None.

    Args:
        predicted: Parameters produced by the model; None is treated as ``{}``.
        expected: Reference parameters.
        required_only: Skip keys whose expected value is None.

    Returns:
        True when every checked key matches, else False.
    """
    if predicted is None:
        predicted = {}

    for key, expected_value in expected.items():
        predicted_value = predicted.get(key)
        if expected_value is None:
            if required_only:
                continue
            # An expected None means "should be absent". Comparing str(None)
            # would instead require the model to emit the literal "None".
            if predicted_value is not None:
                return False
            continue
        if predicted_value is None:
            return False
        if normalize_param(str(predicted_value)) != normalize_param(str(expected_value)):
            return False
    return True


def intent_matches(
    predicted: dict | None,
    expected_skill: str,
    expected_parameters: dict | None = None,
) -> bool:
    """Check if predicted intent matches expected skill and parameters.

    The predicted skill is passed through ``normalize_skill``. It falls back
    to the raw label when that label is not recognised, and is then compared
    with *expected_skill*. If *expected_parameters* is empty or None, only the
    skill is compared. Otherwise ``parameters_match`` runs with its defaults.

    A predicted ``parameters`` that is missing, None, or otherwise falsy is
    treated as ``{}``. Any other non-dict value counts as a mismatch rather
    than raising.

    Returns:
        True on a match, False otherwise (including when *predicted* is
        None or empty).
    """
    if not predicted:
        return False

    raw_skill = predicted.get("skill")
    predicted_skill = normalize_skill(raw_skill) or raw_skill
    if predicted_skill != expected_skill:
        return False
    if not expected_parameters:
        return True

    # `.get("parameters", {})` yields None when the key is present but null,
    # which made parameters_match raise AttributeError.
    predicted_parameters = predicted.get("parameters") or {}
    if not isinstance(predicted_parameters, dict):
        return False
    return parameters_match(predicted_parameters, expected_parameters)


def format_skill_json(skill: str) -> str:
    """Serialize *skill* as compact JSON of the form ``{"skill":"<name>"}``."""
    compact = (",", ":")
    return json.dumps({"skill": skill}, separators=compact)