File size: 2,549 Bytes
47affa0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad5be9b
 
47affa0
 
 
 
 
 
 
 
 
 
 
 
ad5be9b
 
47affa0
 
 
 
 
 
 
 
 
 
 
 
 
92923b1
 
 
 
47affa0
 
ad5be9b
 
 
 
 
47affa0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import re
from dataclasses import dataclass

from chat_state import ConversationState
import sql_tools


# Intent labels placed in ``IntentResult.intent`` by ``classify_intent``.
# They are plain strings.
SMALLTALK, CREATE_TABLE, EDIT_TABLE, SQL_QUERY, UNKNOWN = (
    "smalltalk",
    "create_table",
    "edit_table",
    "sql_query",
    "unknown",
)

@dataclass(frozen=True)
class IntentResult:
    """Immutable outcome of ``classify_intent``.

    The dataclass is frozen, so instances compare by value and are hashable.
    """

    # Intent label; classify_intent fills it with one of the module-level
    # constants SMALLTALK, CREATE_TABLE, EDIT_TABLE, SQL_QUERY or UNKNOWN.
    intent: str
    # Heuristic score; classify_intent only emits fixed values (0.0 to 0.95).
    confidence: float
    # Short machine-readable code naming the rule that produced the result.
    reason: str


def _has_active_schema(state):
    """Return True when *state* has an ``active_schema`` that is not blank.

    A missing attribute, a falsy value (e.g. None or "") and a
    whitespace-only string all count as "no active schema".
    """
    schema = getattr(state, "active_schema", "")
    if not schema:
        return False
    return bool(schema.strip())


# Whole-message small-talk phrases, compared against the normalized text.
_SMALLTALK_EXACT = frozenset({
    "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
    "obrigado", "obrigada", "valeu", "thanks", "thank you",
    "tudo bem", "tudo bom", "tudo", "tchau", "xau", "ate mais", "ate logo",
    "de nada", "por nada", "imagina",
    "como voce esta", "como voce esta hoje", "qual seu nome",
    "me conte uma piada", "conte uma piada", "vamos conversar",
    "o que voce faz", "como voce funciona", "como funciona",
})

# Phrases that mark a message as small talk wherever they appear in it.
# Both gendered forms of "thank you" are listed, mirroring _SMALLTALK_EXACT.
_SMALLTALK_FRAGMENTS = (
    "como voce esta",
    "qual seu nome",
    "conte uma piada",
    "vamos conversar",
    "obrigado",
    "obrigada",
    "tudo bem",
    "tudo bom",
)

# Fragments must match whole words. A plain substring test would also fire
# inside longer words (e.g. "tudo bem" inside "estudo bem").
_SMALLTALK_FRAGMENT_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(frag) for frag in _SMALLTALK_FRAGMENTS) + r")\b"
)


def _is_smalltalk(message):
    """Return True if *message* reads as conversational small talk.

    The message is first passed through ``sql_tools.normalize_text``. It then
    counts as small talk if the normalized text is exactly one of
    ``_SMALLTALK_EXACT``, or if it contains one of ``_SMALLTALK_FRAGMENTS``
    as whole words.

    Args:
        message: Raw user text.

    Returns:
        bool: True for small talk, False otherwise.
    """
    normalized = sql_tools.normalize_text(message)
    if normalized in _SMALLTALK_EXACT:
        return True
    return _SMALLTALK_FRAGMENT_RE.search(normalized) is not None


def classify_intent(message, state=None, chat_history=None):
    """Classify a chat message into one of the module-level intent labels.

    Rules are tried in priority order and the first one that fires wins:

    1. text that normalizes to empty        -> UNKNOWN (0.0, "empty_message")
    2. small-talk phrase                    -> SMALLTALK (0.95)
    3. data mutation that ``sql_tools``
       rejects for the schema in context    -> UNKNOWN (0.9)
    4. edit of a table definition           -> EDIT_TABLE (0.9 when an edited
       table was produced or the state has a non-blank active schema,
       otherwise 0.7)
    5. explicit table creation              -> CREATE_TABLE (0.9)
    6. SQL query wording                    -> SQL_QUERY (0.86)

    Anything else yields UNKNOWN with confidence 0.25.

    Args:
        message: The user's raw text.
        state: Any value accepted by ``ConversationState.from_value``; may
            be None.
        chat_history: Prior conversation, handed unchanged to the
            ``sql_tools`` helpers that inspect it.

    Returns:
        IntentResult: the label, a fixed heuristic confidence and a short
        reason code.
    """
    conversation = ConversationState.from_value(state)

    if not sql_tools.normalize_text(message):
        return IntentResult(UNKNOWN, 0.0, "empty_message")

    if _is_smalltalk(message):
        return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")

    # A CREATE TABLE recovered from the history, when truthy, takes
    # precedence over the state's schema for the data-mutation check.
    history_schema = sql_tools.last_create_table_from_history(chat_history)
    if sql_tools.is_unsupported_data_mutation_for_schema(
        message, history_schema or conversation.active_schema
    ):
        return IntentResult(UNKNOWN, 0.9, "unsupported_data_mutation")

    active_schema = conversation.active_schema
    edited_table = sql_tools.edit_create_table_from_message(
        message, chat_history, active_schema
    )
    if edited_table or sql_tools.is_table_edit_intent(message):
        # Confidence is higher when there is a concrete table to edit.
        if edited_table or _has_active_schema(conversation):
            confidence = 0.9
        else:
            confidence = 0.7
        return IntentResult(EDIT_TABLE, confidence, "table_edit_terms")

    if sql_tools.is_create_table_intent(message):
        return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")

    if sql_tools.is_sql_intent(message, active_schema):
        return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")

    return IntentResult(UNKNOWN, 0.25, "no_intent_match")