# NOTE: this file was captured from a Hugging Face Spaces page; the page's
# status banner ("Spaces: Sleeping") leaked into the capture and is not code.
import json
import logging
import re
from typing import Any, Dict, List, Optional

from src.domain.models.conversation import QueryClassification, ConversationContext, Entity
from src.infrastructure.providers.llm_provider import LLMClient
from src.application.services.dynamic_pattern_manager import dynamic_pattern_manager
from config import settings
from config.conversation_config import conversation_config
logger = logging.getLogger(__name__)


class QueryClassifier:
    """Classifies user queries to determine if they need new tool data or can be answered from history."""

    def __init__(self):
        # One shared client; every LLM-backed classification reuses it.
        self.llm_client = LLMClient()

    async def classify_query(self, user_message: str, conversation_context: ConversationContext) -> QueryClassification:
        """Classify a user query using LLM assistance and rule-based logic.

        Cheap rule-based heuristics run first; the LLM is consulted only when
        the rule-based confidence is not decisive (<= 0.8).
        """
        # First, try rule-based classification for obvious cases
        rule_based_result = self._rule_based_classification(user_message, conversation_context)
        if rule_based_result.confidence > 0.8:
            return rule_based_result
        # Use LLM for more nuanced classification
        return await self._llm_classification(user_message, conversation_context, rule_based_result)

    def _rule_based_classification(self, user_message: str, conversation_context: ConversationContext) -> QueryClassification:
        """Fast rule-based classification for obvious cases.

        Scores the message against dynamically managed regex pattern sets
        (history / tool / pronoun indicators) and combines the scores with
        conversation state using thresholds from ``conversation_config``.
        """
        message_lower = user_message.lower()

        # Get patterns dynamically from pattern manager
        dynamic_pattern_manager.refresh_if_needed()
        history_indicators = dynamic_pattern_manager.get_history_patterns()
        tool_indicators = dynamic_pattern_manager.get_tool_patterns()
        pronoun_patterns = dynamic_pattern_manager.get_pronoun_patterns()

        # Check for pronouns without clear antecedents (suggesting reference to history)
        has_pronouns = any(bool(re.search(pattern, message_lower)) for pattern in pronoun_patterns)

        # One point per matching indicator pattern.
        history_score = sum(1 for pattern in history_indicators if re.search(pattern, message_lower))
        tool_score = sum(1 for pattern in tool_indicators if re.search(pattern, message_lower))

        # Get configuration
        config = conversation_config.query_classification

        # Check if there's relevant history
        has_history = len(conversation_context.turns) > 0
        recent_tool_usage = any(turn.tool_used for turn in conversation_context.get_recent_turns(config.recent_turns_check))

        # Classification logic using configurable thresholds.
        # Order matters: explicit history indicators win over tool indicators.
        if history_score > 0 and has_history:
            confidence = min(config.max_confidence, config.min_history_confidence + (history_score * config.confidence_increment))
            return QueryClassification(
                query_type="history",
                confidence=confidence,
                reasoning=f"Found {history_score} dynamic history indicators and conversation history exists"
            )

        if tool_score > 0 or not has_history:
            # NOTE(review): reuses min_history_confidence + 0.1 as the tool
            # baseline — presumably a stand-in for a dedicated
            # min_tool_confidence setting; confirm against conversation_config.
            confidence = min(config.max_confidence, config.min_history_confidence + 0.1 + (tool_score * config.confidence_increment))
            return QueryClassification(
                query_type="tool",
                confidence=confidence,
                reasoning=f"Found {tool_score} dynamic tool indicators or no conversation history",
                needs_tool_data=True
            )

        if has_pronouns and recent_tool_usage:
            return QueryClassification(
                query_type="history",
                confidence=config.pronoun_confidence,
                reasoning="Contains pronouns and recent tool usage suggests reference to history"
            )

        # Default to tool query with low confidence
        return QueryClassification(
            query_type="tool",
            confidence=config.default_confidence,
            reasoning="Uncertain classification, defaulting to tool query",
            needs_tool_data=True
        )

    async def _llm_classification(self, user_message: str, conversation_context: ConversationContext,
                                  rule_based_result: QueryClassification) -> QueryClassification:
        """Use LLM for sophisticated query classification.

        Falls back to ``rule_based_result`` when no LLM token is configured or
        the LLM call fails — classification is best-effort, never fatal.
        """
        if not settings.HF_TOKEN:
            # Fallback to rule-based if no LLM available
            return rule_based_result

        # Prepare context for LLM
        recent_context = self._prepare_context_for_llm(conversation_context)
        classification_prompt = self._build_classification_prompt(user_message, recent_context, rule_based_result)

        try:
            # Get LLM decision; any field the LLM omits falls back to the
            # rule-based value so the result is always fully populated.
            response = await self.llm_client.complete_json(classification_prompt)
            return QueryClassification(
                query_type=response.get("query_type", rule_based_result.query_type),
                confidence=response.get("confidence", rule_based_result.confidence),
                reasoning=response.get("reasoning", rule_based_result.reasoning),
                referenced_entities=response.get("referenced_entities", []),
                needs_tool_data=response.get("needs_tool_data", rule_based_result.needs_tool_data)
            )
        except Exception as e:
            logger.exception("LLM classification failed: %s", e)
            # Fallback to rule-based result
            return rule_based_result

    def _prepare_context_for_llm(self, conversation_context: ConversationContext) -> str:
        """Prepare conversation context for LLM analysis.

        Summarizes up to the last 5 turns; tool turns show tool/params/result,
        plain turns show a truncated response, plus any extracted entities.
        """
        recent_turns = conversation_context.get_recent_turns(5)
        if not recent_turns:
            return "No previous conversation history."

        context_parts = []
        for i, turn in enumerate(recent_turns, 1):
            turn_summary = f"Turn {i}:\n"
            turn_summary += f"  User: {turn.user_message}\n"
            if turn.tool_used:
                turn_summary += f"  Tool used: {turn.tool_used}\n"
                if turn.tool_params:
                    # Compact separators keep the prompt short.
                    turn_summary += f"  Parameters: {json.dumps(turn.tool_params, separators=(',', ':'))}\n"
                turn_summary += f"  Result: {turn.response_summary}\n"
            else:
                turn_summary += f"  Response: {turn.full_response[:100]}...\n"
            if turn.extracted_entities:
                entities = [f"{e.type}:{e.name}" for e in turn.extracted_entities]
                turn_summary += f"  Entities: {', '.join(entities)}\n"
            context_parts.append(turn_summary)

        return "\n".join(context_parts)

    def _build_classification_prompt(self, user_message: str, context: str, rule_based_result: QueryClassification) -> str:
        """Build the LLM prompt for query classification."""
        return f"""You are a query classification assistant for a Topcoder MCP agent.
Your task is to determine if a user query can be answered from conversation history or needs new external data.
CONVERSATION CONTEXT:
{context}
USER'S NEW MESSAGE: "{user_message}"
RULE-BASED ANALYSIS:
- Preliminary classification: {rule_based_result.query_type}
- Confidence: {rule_based_result.confidence}
- Reasoning: {rule_based_result.reasoning}
CLASSIFICATION TYPES:
1. "history" - Can be answered from conversation context (previous results, mentioned entities, etc.)
2. "tool" - Needs new data from external APIs/tools
3. "chat" - General conversation, greetings, or explanations
ANALYZE THE QUERY FOR:
- References to previous results ("the last challenge", "that user", "it")
- Temporal references ("what you just told me", "from before")
- Pronouns without clear new antecedents
- Requests for new/different/more data vs clarification of existing data
- Entity references that were mentioned in previous conversation
Respond ONLY with a JSON object:
{{
"query_type": "history|tool|chat",
"confidence": 0.0-1.0,
"reasoning": "Brief explanation of your decision",
"referenced_entities": ["entity1", "entity2"],
"needs_tool_data": true/false
}}
Be especially careful to identify when users are asking about previously retrieved data vs requesting new data."""

    def extract_entities(self, text: str, previous_entities: Optional[List[Entity]] = None) -> List[Entity]:
        """Extract entities from text that might be referenced later.

        ``previous_entities`` is accepted for interface compatibility but is
        currently unused. Returns an empty list when dynamic patterns are
        disabled in configuration.
        """
        entities: List[Entity] = []

        if not conversation_config.entity_extraction.use_dynamic_patterns:
            return entities

        # Get patterns dynamically from pattern manager
        dynamic_pattern_manager.refresh_if_needed()
        patterns = dynamic_pattern_manager.get_entity_patterns()
        config = conversation_config.entity_extraction

        for pattern_name, pattern in patterns.items():
            try:
                matches = re.finditer(pattern, text, re.IGNORECASE)
                for match in matches:
                    # Get the first non-None group
                    value = next((group for group in match.groups() if group), None)
                    if value:
                        # Determine entity type based on pattern name
                        entity_type = self._determine_entity_type(pattern_name, value)
                        entities.append(Entity(
                            name=value.strip(),
                            type=entity_type,
                            value=value.strip(),
                            confidence=config.default_entity_confidence
                        ))
            except Exception as e:
                # Patterns come from an external manager and may be malformed;
                # a bad pattern must not abort extraction for the rest.
                logger.warning("Error in pattern %s: %s", pattern_name, e)
                continue

        return entities

    def _determine_entity_type(self, pattern_name: str, value: str) -> str:
        """Determine entity type from pattern name and value.

        Pattern-name keywords ("id" / "name" / "handle" plus a domain word)
        take priority; otherwise the value's shape is used as a fallback.
        """
        pattern_lower = pattern_name.lower()

        if "id" in pattern_lower:
            if "challenge" in pattern_lower:
                return "challenge_id"
            elif "skill" in pattern_lower:
                return "skill_id"
            elif "member" in pattern_lower or "user" in pattern_lower:
                return "user_id"
            else:
                return "entity_id"
        elif "name" in pattern_lower:
            if "challenge" in pattern_lower:
                return "challenge_name"
            elif "skill" in pattern_lower:
                return "skill_name"
            else:
                return "entity_name"
        elif "handle" in pattern_lower:
            return "user_handle"
        else:
            # Try to infer from value pattern
            if value.isdigit():
                return "entity_id"
            elif len(value) <= 30 and "_" in value:
                return "user_handle"
            else:
                return "entity_name"