# Snapshot header from the hosting platform (user "abhishekrn",
# commit 9a6a5aa, message "history, follow-up") — preserved as a
# comment so the module remains valid, importable Python.
import json
import re
from typing import Any, Dict, List, Optional

from config import settings
from config.conversation_config import conversation_config
from src.application.services.dynamic_pattern_manager import dynamic_pattern_manager
from src.domain.models.conversation import ConversationContext, Entity, QueryClassification
from src.infrastructure.providers.llm_provider import LLMClient
class QueryClassifier:
    """Decides whether a user query needs fresh tool data or can be answered
    from the existing conversation history.

    Combines a fast rule-based pass with an LLM-backed pass for cases the
    rules cannot settle confidently.
    """

    def __init__(self) -> None:
        # Client used only by the LLM-backed classification path.
        self.llm_client = LLMClient()
async def classify_query(self, user_message: str, conversation_context: ConversationContext) -> QueryClassification:
"""Classify a user query using LLM assistance and rule-based logic."""
# First, try rule-based classification for obvious cases
rule_based_result = self._rule_based_classification(user_message, conversation_context)
if rule_based_result.confidence > 0.8:
return rule_based_result
# Use LLM for more nuanced classification
return await self._llm_classification(user_message, conversation_context, rule_based_result)
def _rule_based_classification(self, user_message: str, conversation_context: ConversationContext) -> QueryClassification:
"""Fast rule-based classification for obvious cases."""
message_lower = user_message.lower()
# Get patterns dynamically from pattern manager
dynamic_pattern_manager.refresh_if_needed()
history_indicators = dynamic_pattern_manager.get_history_patterns()
tool_indicators = dynamic_pattern_manager.get_tool_patterns()
pronoun_patterns = dynamic_pattern_manager.get_pronoun_patterns()
# Check for pronouns without clear antecedents (suggesting reference to history)
has_pronouns = any(bool(re.search(pattern, message_lower)) for pattern in pronoun_patterns)
# Calculate scores
history_score = sum(1 for pattern in history_indicators if re.search(pattern, message_lower))
tool_score = sum(1 for pattern in tool_indicators if re.search(pattern, message_lower))
# Get configuration
config = conversation_config.query_classification
# Check if there's relevant history
has_history = len(conversation_context.turns) > 0
recent_tool_usage = any(turn.tool_used for turn in conversation_context.get_recent_turns(config.recent_turns_check))
# Classification logic using configurable thresholds
if history_score > 0 and has_history:
confidence = min(config.max_confidence, config.min_history_confidence + (history_score * config.confidence_increment))
return QueryClassification(
query_type="history",
confidence=confidence,
reasoning=f"Found {history_score} dynamic history indicators and conversation history exists"
)
if tool_score > 0 or not has_history:
confidence = min(config.max_confidence, config.min_history_confidence + 0.1 + (tool_score * config.confidence_increment))
return QueryClassification(
query_type="tool",
confidence=confidence,
reasoning=f"Found {tool_score} dynamic tool indicators or no conversation history",
needs_tool_data=True
)
if has_pronouns and recent_tool_usage:
return QueryClassification(
query_type="history",
confidence=config.pronoun_confidence,
reasoning="Contains pronouns and recent tool usage suggests reference to history"
)
# Default to tool query with low confidence
return QueryClassification(
query_type="tool",
confidence=config.default_confidence,
reasoning="Uncertain classification, defaulting to tool query",
needs_tool_data=True
)
async def _llm_classification(self, user_message: str, conversation_context: ConversationContext,
rule_based_result: QueryClassification) -> QueryClassification:
"""Use LLM for sophisticated query classification."""
if not settings.HF_TOKEN:
# Fallback to rule-based if no LLM available
return rule_based_result
# Prepare context for LLM
recent_context = self._prepare_context_for_llm(conversation_context)
classification_prompt = self._build_classification_prompt(user_message, recent_context, rule_based_result)
try:
# Get LLM decision
response = await self.llm_client.complete_json(classification_prompt)
return QueryClassification(
query_type=response.get("query_type", rule_based_result.query_type),
confidence=response.get("confidence", rule_based_result.confidence),
reasoning=response.get("reasoning", rule_based_result.reasoning),
referenced_entities=response.get("referenced_entities", []),
needs_tool_data=response.get("needs_tool_data", rule_based_result.needs_tool_data)
)
except Exception as e:
print(f"LLM classification failed: {e}")
# Fallback to rule-based result
return rule_based_result
def _prepare_context_for_llm(self, conversation_context: ConversationContext) -> str:
"""Prepare conversation context for LLM analysis."""
recent_turns = conversation_context.get_recent_turns(5)
if not recent_turns:
return "No previous conversation history."
context_parts = []
for i, turn in enumerate(recent_turns, 1):
turn_summary = f"Turn {i}:\n"
turn_summary += f" User: {turn.user_message}\n"
if turn.tool_used:
turn_summary += f" Tool used: {turn.tool_used}\n"
if turn.tool_params:
turn_summary += f" Parameters: {json.dumps(turn.tool_params, separators=(',', ':'))}\n"
turn_summary += f" Result: {turn.response_summary}\n"
else:
turn_summary += f" Response: {turn.full_response[:100]}...\n"
if turn.extracted_entities:
entities = [f"{e.type}:{e.name}" for e in turn.extracted_entities]
turn_summary += f" Entities: {', '.join(entities)}\n"
context_parts.append(turn_summary)
return "\n".join(context_parts)
def _build_classification_prompt(self, user_message: str, context: str, rule_based_result: QueryClassification) -> str:
"""Build the LLM prompt for query classification."""
return f"""You are a query classification assistant for a Topcoder MCP agent.
Your task is to determine if a user query can be answered from conversation history or needs new external data.
CONVERSATION CONTEXT:
{context}
USER'S NEW MESSAGE: "{user_message}"
RULE-BASED ANALYSIS:
- Preliminary classification: {rule_based_result.query_type}
- Confidence: {rule_based_result.confidence}
- Reasoning: {rule_based_result.reasoning}
CLASSIFICATION TYPES:
1. "history" - Can be answered from conversation context (previous results, mentioned entities, etc.)
2. "tool" - Needs new data from external APIs/tools
3. "chat" - General conversation, greetings, or explanations
ANALYZE THE QUERY FOR:
- References to previous results ("the last challenge", "that user", "it")
- Temporal references ("what you just told me", "from before")
- Pronouns without clear new antecedents
- Requests for new/different/more data vs clarification of existing data
- Entity references that were mentioned in previous conversation
Respond ONLY with a JSON object:
{{
"query_type": "history|tool|chat",
"confidence": 0.0-1.0,
"reasoning": "Brief explanation of your decision",
"referenced_entities": ["entity1", "entity2"],
"needs_tool_data": true/false
}}
Be especially careful to identify when users are asking about previously retrieved data vs requesting new data."""
def extract_entities(self, text: str, previous_entities: List[Entity] = None) -> List[Entity]:
"""Extract entities from text that might be referenced later."""
entities = []
if not conversation_config.entity_extraction.use_dynamic_patterns:
return entities
# Get patterns dynamically from pattern manager
dynamic_pattern_manager.refresh_if_needed()
patterns = dynamic_pattern_manager.get_entity_patterns()
config = conversation_config.entity_extraction
for pattern_name, pattern in patterns.items():
try:
matches = re.finditer(pattern, text, re.IGNORECASE)
for match in matches:
# Get the first non-None group
value = next((group for group in match.groups() if group), None)
if value:
# Determine entity type based on pattern name
entity_type = self._determine_entity_type(pattern_name, value)
entities.append(Entity(
name=value.strip(),
type=entity_type,
value=value.strip(),
confidence=config.default_entity_confidence
))
except Exception as e:
print(f"Error in pattern {pattern_name}: {e}")
continue
return entities
def _determine_entity_type(self, pattern_name: str, value: str) -> str:
"""Determine entity type from pattern name and value."""
pattern_lower = pattern_name.lower()
if "id" in pattern_lower:
if "challenge" in pattern_lower:
return "challenge_id"
elif "skill" in pattern_lower:
return "skill_id"
elif "member" in pattern_lower or "user" in pattern_lower:
return "user_id"
else:
return "entity_id"
elif "name" in pattern_lower:
if "challenge" in pattern_lower:
return "challenge_name"
elif "skill" in pattern_lower:
return "skill_name"
else:
return "entity_name"
elif "handle" in pattern_lower:
return "user_handle"
else:
# Try to infer from value pattern
if value.isdigit():
return "entity_id"
elif len(value) <= 30 and "_" in value:
return "user_handle"
else:
return "entity_name"