import re
from dataclasses import dataclass

from chat_state import ConversationState
import sql_tools
|
|
|
|
# Intent labels emitted by classify_intent via IntentResult.intent.
SMALLTALK: str = "smalltalk"
CREATE_TABLE: str = "create_table"
EDIT_TABLE: str = "edit_table"
SQL_QUERY: str = "sql_query"
UNKNOWN: str = "unknown"
|
|
@dataclass(frozen=True)
class IntentResult:
    """Immutable outcome of classifying one user message.

    Being a frozen dataclass, instances compare by field values and
    cannot be reassigned after construction.
    """

    # Intent label, one of the module-level constants (SMALLTALK, ...).
    intent: str
    # Heuristic score; classify_intent only emits fixed values from 0.0 to 0.95.
    confidence: float
    # Short tag naming the rule that produced this result.
    reason: str
|
|
|
|
| def _has_active_schema(state): |
| return bool((getattr(state, "active_schema", "") or "").strip()) |
|
|
|
|
# Normalized messages that are small talk in their entirety: greetings,
# thanks, farewells and questions about the assistant itself.
_SMALLTALK_EXACT = frozenset({
    "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
    "obrigado", "obrigada", "valeu", "thanks", "thank you",
    "tudo bem", "tudo bom", "tudo", "tchau", "xau", "ate mais", "ate logo",
    "de nada", "por nada", "imagina",
    "como voce esta", "como voce esta hoje", "qual seu nome",
    "me conte uma piada", "conte uma piada", "vamos conversar",
    "o que voce faz", "como voce funciona", "como funciona",
})

# Phrases that mark a message as small talk even inside a longer message.
# Both "obrigado" and "obrigada" are listed so the two inflections behave
# the same way, as they already do in _SMALLTALK_EXACT.
_SMALLTALK_FRAGMENTS = (
    "como voce esta",
    "qual seu nome",
    "conte uma piada",
    "vamos conversar",
    "obrigado",
    "obrigada",
    "tudo bem",
    "tudo bom",
)

# Fragments must match as whole words/phrases; a plain substring test would
# let e.g. "estudo bem" trigger "tudo bem".
_SMALLTALK_FRAGMENT_RE = re.compile(
    r"\b(?:" + "|".join(re.escape(fragment) for fragment in _SMALLTALK_FRAGMENTS) + r")\b"
)


def _is_smalltalk(message):
    """Return True if *message* looks like casual conversation.

    The message is first passed through ``sql_tools.normalize_text``. It is
    small talk when the normalized text equals one of ``_SMALLTALK_EXACT``
    or contains one of ``_SMALLTALK_FRAGMENTS`` as a whole word/phrase.
    An empty (falsy) normalized text is never small talk.

    NOTE(review): classify_intent checks this before any table/SQL intent,
    so a request that merely contains a fragment (e.g. "... obrigado") is
    classified as small talk - confirm that precedence is intended.
    """
    normalized = sql_tools.normalize_text(message)
    if not normalized:
        return False
    if normalized in _SMALLTALK_EXACT:
        return True
    return _SMALLTALK_FRAGMENT_RE.search(normalized) is not None
|
|
|
|
def classify_intent(message, state=None, chat_history=None):
    """Classify a user message into one of the module's intent labels.

    Rules are tried in a fixed order and the first match wins:

    1. empty normalized text          -> UNKNOWN      (0.0,  "empty_message")
    2. small-talk phrase              -> SMALLTALK    (0.95, "smalltalk_phrase")
    3. unsupported data mutation      -> UNKNOWN      (0.9,  "unsupported_data_mutation")
    4. table edit                     -> EDIT_TABLE   (0.9 or 0.7, "table_edit_terms")
    5. explicit CREATE TABLE request  -> CREATE_TABLE (0.9,  "explicit_create_table")
    6. SQL query terms                -> SQL_QUERY    (0.86, "sql_query_terms")
    7. otherwise                      -> UNKNOWN      (0.25, "no_intent_match")

    Args:
        message: Raw user text.
        state: Anything accepted by ``ConversationState.from_value``.
        chat_history: Passed through unchanged to the ``sql_tools`` helpers;
            its expected shape is defined there (not visible here).

    Returns:
        An ``IntentResult``.
    """
    conversation = ConversationState.from_value(state)

    if not sql_tools.normalize_text(message):
        return IntentResult(UNKNOWN, 0.0, "empty_message")

    if _is_smalltalk(message):
        return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")

    # The most recent CREATE TABLE in the history takes priority over the
    # schema stored on the conversation state.
    history_schema = sql_tools.last_create_table_from_history(chat_history)
    schema_context = history_schema or conversation.active_schema
    if sql_tools.is_unsupported_data_mutation_for_schema(message, schema_context):
        return IntentResult(UNKNOWN, 0.9, "unsupported_data_mutation")

    edited_table = sql_tools.edit_create_table_from_message(
        message, chat_history, conversation.active_schema
    )
    if edited_table:
        # A concrete edited table was produced: high confidence.
        return IntentResult(EDIT_TABLE, 0.9, "table_edit_terms")
    if sql_tools.is_table_edit_intent(message):
        # Edit wording without a produced table: only confident when there
        # is already an active schema to edit.
        if _has_active_schema(conversation):
            edit_confidence = 0.9
        else:
            edit_confidence = 0.7
        return IntentResult(EDIT_TABLE, edit_confidence, "table_edit_terms")

    if sql_tools.is_create_table_intent(message):
        return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")

    if sql_tools.is_sql_intent(message, conversation.active_schema):
        return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")

    return IntentResult(UNKNOWN, 0.25, "no_intent_match")
|
|