Shizu0n's picture
fix: unsupported data mutation routing
92923b1
from dataclasses import dataclass
from chat_state import ConversationState
import sql_tools
# Intent labels stored in IntentResult.intent by classify_intent.
# Fallback label for empty, unmatched, or deliberately rejected messages.
UNKNOWN = "unknown"
# Conversational messages (greetings, thanks, jokes, ...).
SMALLTALK = "smalltalk"
# Schema-definition requests.
CREATE_TABLE = "create_table"
EDIT_TABLE = "edit_table"
# Queries against an existing schema.
SQL_QUERY = "sql_query"
@dataclass(frozen=True)
class IntentResult:
    """Immutable outcome of a single intent classification.

    Being a frozen dataclass, instances are hashable, compare by value,
    and reject attribute assignment after construction.
    """

    # Intent label, e.g. "smalltalk" or "sql_query".
    intent: str
    # Heuristic score attached by the classifier (not a calibrated probability).
    confidence: float
    # Short machine-readable tag naming the rule that produced the result.
    reason: str
def _has_active_schema(state):
    """Return True when ``state`` carries a non-blank ``active_schema``.

    A missing attribute, ``None``, or any other falsy value counts as
    "no schema", as does a string made only of whitespace.
    """
    schema = getattr(state, "active_schema", "")
    if not schema:
        return False
    return bool(schema.strip())
# Whole-message phrases treated as smalltalk. They are unaccented lowercase,
# presumably matching the output of sql_tools.normalize_text — confirm there.
_SMALLTALK_EXACT = frozenset({
    "oi", "ola", "hi", "hello", "hey", "bom dia", "boa tarde", "boa noite",
    "obrigado", "obrigada", "valeu", "thanks", "thank you",
    "tudo bem", "tudo bom", "tudo", "tchau", "xau", "ate mais", "ate logo",
    "de nada", "por nada", "imagina",
    "como voce esta", "como voce esta hoje", "qual seu nome",
    "me conte uma piada", "conte uma piada", "vamos conversar",
    "o que voce faz", "como voce funciona", "como funciona",
})

# Phrases that mark a message as smalltalk wherever they appear in it.
_SMALLTALK_FRAGMENTS = (
    "como voce esta",
    "qual seu nome",
    "conte uma piada",
    "vamos conversar",
    "obrigado",
    # Feminine form of "thank you"; the exact set already accepts it, so the
    # fragment list must too, or "obrigada pela ajuda" would be missed while
    # "obrigado pela ajuda" is caught.
    "obrigada",
    "tudo bem",
    "tudo bom",
)


def _is_smalltalk(message):
    """Return True if ``message`` looks like conversational smalltalk.

    The message is normalized with ``sql_tools.normalize_text``. It then
    counts as smalltalk if it equals one of the known phrases, or contains
    one of the smalltalk fragments as a substring.
    """
    normalized = sql_tools.normalize_text(message)
    if normalized in _SMALLTALK_EXACT:
        return True
    # NOTE(review): plain substring matching means a message that merely
    # contains e.g. "obrigado" is classified as smalltalk even if it also
    # carries a real request; confirm this trade-off is intended.
    return any(fragment in normalized for fragment in _SMALLTALK_FRAGMENTS)
def classify_intent(message, state=None, chat_history=None):
    """Classify a chat message into one of the module's intent labels.

    The checks run in a fixed priority order, and the first match wins:
    empty message -> smalltalk -> unsupported data mutation -> table edit
    -> create table -> SQL query -> fallback UNKNOWN.

    Args:
        message: Raw user text.
        state: Any value accepted by ``ConversationState.from_value``
            (the default is None).
        chat_history: Prior conversation, forwarded to the sql_tools
            helpers. Its format is defined there and is not visible here.

    Returns:
        IntentResult with the label, a fixed heuristic confidence, and a
        reason tag naming the rule that matched.
    """
    state = ConversationState.from_value(state)

    if not sql_tools.normalize_text(message):
        return IntentResult(UNKNOWN, 0.0, "empty_message")

    if _is_smalltalk(message):
        return IntentResult(SMALLTALK, 0.95, "smalltalk_phrase")

    # A schema recovered from the chat history wins over the stored active
    # schema for this check. The `or` only reads state.active_schema when the
    # history yields nothing.
    history_schema = sql_tools.last_create_table_from_history(chat_history)
    mutation_schema = history_schema or state.active_schema
    if sql_tools.is_unsupported_data_mutation_for_schema(message, mutation_schema):
        return IntentResult(UNKNOWN, 0.9, "unsupported_data_mutation")

    # NOTE(review): the checks below use only state.active_schema, not the
    # history-derived schema used above; confirm that asymmetry is intended.
    active_schema = state.active_schema
    edited_table = sql_tools.edit_create_table_from_message(
        message, chat_history, active_schema
    )
    if edited_table or sql_tools.is_table_edit_intent(message):
        # Be more confident when an edit was actually produced or a schema
        # already exists to be edited.
        strong_signal = bool(edited_table) or _has_active_schema(state)
        confidence = 0.9 if strong_signal else 0.7
        return IntentResult(EDIT_TABLE, confidence, "table_edit_terms")

    if sql_tools.is_create_table_intent(message):
        return IntentResult(CREATE_TABLE, 0.9, "explicit_create_table")

    if sql_tools.is_sql_intent(message, active_schema):
        return IntentResult(SQL_QUERY, 0.86, "sql_query_terms")

    return IntentResult(UNKNOWN, 0.25, "no_intent_match")