File size: 11,054 Bytes
9a6a5aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
import json
import logging
import re
from typing import Any, Dict, List, Optional

from config import settings
from config.conversation_config import conversation_config
from src.application.services.dynamic_pattern_manager import dynamic_pattern_manager
from src.domain.models.conversation import ConversationContext, Entity, QueryClassification
from src.infrastructure.providers.llm_provider import LLMClient

class QueryClassifier:
    """Classifies user queries to determine if they need new tool data or can
    be answered from conversation history.

    Strategy: a fast rule-based pass runs first; when its confidence clears
    ``_RULE_CONFIDENCE_THRESHOLD`` the result is returned directly, otherwise
    an LLM is consulted (falling back to the rule-based result when no LLM is
    configured or the call fails).
    """

    # Rule-based results above this confidence skip the LLM round-trip.
    _RULE_CONFIDENCE_THRESHOLD = 0.8

    def __init__(self):
        self.llm_client = LLMClient()

    async def classify_query(self, user_message: str,
                             conversation_context: "ConversationContext") -> "QueryClassification":
        """Classify a user query using LLM assistance and rule-based logic.

        Args:
            user_message: The raw user message to classify.
            conversation_context: Prior turns used to detect history references.

        Returns:
            A QueryClassification describing how the query should be handled.
        """
        # Cheap pass first; only escalate to the LLM when uncertain.
        rule_based_result = self._rule_based_classification(user_message, conversation_context)
        if rule_based_result.confidence > self._RULE_CONFIDENCE_THRESHOLD:
            return rule_based_result

        return await self._llm_classification(user_message, conversation_context, rule_based_result)

    def _rule_based_classification(self, user_message: str,
                                   conversation_context: "ConversationContext") -> "QueryClassification":
        """Fast rule-based classification for obvious cases.

        Scores the lowercased message against dynamically managed
        history/tool/pronoun regex patterns and maps the scores to a
        classification using thresholds from ``conversation_config``.
        """
        message_lower = user_message.lower()

        # Patterns are managed externally and may change at runtime, so
        # refresh before reading them.
        dynamic_pattern_manager.refresh_if_needed()
        history_indicators = dynamic_pattern_manager.get_history_patterns()
        tool_indicators = dynamic_pattern_manager.get_tool_patterns()
        pronoun_patterns = dynamic_pattern_manager.get_pronoun_patterns()

        # Pronouns without clear antecedents suggest a reference to history.
        has_pronouns = any(re.search(pattern, message_lower) for pattern in pronoun_patterns)

        history_score = sum(1 for pattern in history_indicators if re.search(pattern, message_lower))
        tool_score = sum(1 for pattern in tool_indicators if re.search(pattern, message_lower))

        config = conversation_config.query_classification

        has_history = len(conversation_context.turns) > 0
        recent_tool_usage = any(
            turn.tool_used
            for turn in conversation_context.get_recent_turns(config.recent_turns_check)
        )

        # History indicators plus existing history -> answer from history.
        if history_score > 0 and has_history:
            confidence = min(
                config.max_confidence,
                config.min_history_confidence + (history_score * config.confidence_increment),
            )
            return QueryClassification(
                query_type="history",
                confidence=confidence,
                reasoning=f"Found {history_score} dynamic history indicators and conversation history exists"
            )

        # Tool indicators, or nothing to answer from -> fetch new data.
        # NOTE: tool confidence starts 0.1 above the history baseline.
        if tool_score > 0 or not has_history:
            confidence = min(
                config.max_confidence,
                config.min_history_confidence + 0.1 + (tool_score * config.confidence_increment),
            )
            return QueryClassification(
                query_type="tool",
                confidence=confidence,
                reasoning=f"Found {tool_score} dynamic tool indicators or no conversation history",
                needs_tool_data=True
            )

        # Only reachable when history exists but neither score fired.
        if has_pronouns and recent_tool_usage:
            return QueryClassification(
                query_type="history",
                confidence=config.pronoun_confidence,
                reasoning="Contains pronouns and recent tool usage suggests reference to history"
            )

        # Default to tool query with low confidence.
        return QueryClassification(
            query_type="tool",
            confidence=config.default_confidence,
            reasoning="Uncertain classification, defaulting to tool query",
            needs_tool_data=True
        )

    async def _llm_classification(self, user_message: str,
                                  conversation_context: "ConversationContext",
                                  rule_based_result: "QueryClassification") -> "QueryClassification":
        """Use the LLM for sophisticated query classification.

        Falls back to ``rule_based_result`` when no LLM token is configured or
        when the LLM call fails or returns malformed output.
        """
        if not settings.HF_TOKEN:
            # No LLM available; the rule-based answer is the best we have.
            return rule_based_result

        recent_context = self._prepare_context_for_llm(conversation_context)
        classification_prompt = self._build_classification_prompt(user_message, recent_context, rule_based_result)

        try:
            response = await self.llm_client.complete_json(classification_prompt)

            # Guard against non-object JSON (e.g. a list or string) from the
            # LLM; treat it as a failure and fall back below.
            if not isinstance(response, dict):
                raise ValueError(f"expected JSON object, got {type(response).__name__}")

            # Missing fields fall back to the rule-based values per-field.
            return QueryClassification(
                query_type=response.get("query_type", rule_based_result.query_type),
                confidence=response.get("confidence", rule_based_result.confidence),
                reasoning=response.get("reasoning", rule_based_result.reasoning),
                referenced_entities=response.get("referenced_entities", []),
                needs_tool_data=response.get("needs_tool_data", rule_based_result.needs_tool_data)
            )

        except Exception as e:
            # Best-effort: log and degrade to the rule-based result rather
            # than failing the whole request.
            logging.getLogger(__name__).warning("LLM classification failed: %s", e)
            return rule_based_result

    def _prepare_context_for_llm(self, conversation_context: "ConversationContext") -> str:
        """Render the last few conversation turns as plain text for the LLM."""
        recent_turns = conversation_context.get_recent_turns(5)

        if not recent_turns:
            return "No previous conversation history."

        context_parts = []
        for i, turn in enumerate(recent_turns, 1):
            turn_summary = f"Turn {i}:\n"
            turn_summary += f"  User: {turn.user_message}\n"

            if turn.tool_used:
                turn_summary += f"  Tool used: {turn.tool_used}\n"
                if turn.tool_params:
                    # Compact separators keep the prompt short.
                    turn_summary += f"  Parameters: {json.dumps(turn.tool_params, separators=(',', ':'))}\n"
                turn_summary += f"  Result: {turn.response_summary}\n"
            else:
                # Truncate free-form responses to keep the prompt bounded.
                turn_summary += f"  Response: {turn.full_response[:100]}...\n"

            if turn.extracted_entities:
                entities = [f"{e.type}:{e.name}" for e in turn.extracted_entities]
                turn_summary += f"  Entities: {', '.join(entities)}\n"

            context_parts.append(turn_summary)

        return "\n".join(context_parts)

    def _build_classification_prompt(self, user_message: str, context: str,
                                     rule_based_result: "QueryClassification") -> str:
        """Build the LLM prompt for query classification."""

        return f"""You are a query classification assistant for a Topcoder MCP agent.

Your task is to determine if a user query can be answered from conversation history or needs new external data.

CONVERSATION CONTEXT:
{context}

USER'S NEW MESSAGE: "{user_message}"

RULE-BASED ANALYSIS:
- Preliminary classification: {rule_based_result.query_type}
- Confidence: {rule_based_result.confidence}
- Reasoning: {rule_based_result.reasoning}

CLASSIFICATION TYPES:
1. "history" - Can be answered from conversation context (previous results, mentioned entities, etc.)
2. "tool" - Needs new data from external APIs/tools
3. "chat" - General conversation, greetings, or explanations

ANALYZE THE QUERY FOR:
- References to previous results ("the last challenge", "that user", "it")
- Temporal references ("what you just told me", "from before")
- Pronouns without clear new antecedents
- Requests for new/different/more data vs clarification of existing data
- Entity references that were mentioned in previous conversation

Respond ONLY with a JSON object:
{{
    "query_type": "history|tool|chat",
    "confidence": 0.0-1.0,
    "reasoning": "Brief explanation of your decision",
    "referenced_entities": ["entity1", "entity2"],
    "needs_tool_data": true/false
}}

Be especially careful to identify when users are asking about previously retrieved data vs requesting new data."""

    def extract_entities(self, text: str,
                         previous_entities: "Optional[List[Entity]]" = None) -> "List[Entity]":
        """Extract entities from text that might be referenced later.

        Args:
            text: Free text to scan with the dynamically managed patterns.
            previous_entities: Currently unused; reserved for future
                antecedent resolution — TODO confirm intent with the author.

        Returns:
            Entities found via the dynamic regex patterns (may contain
            duplicates if multiple patterns match the same span).
        """
        entities = []

        if not conversation_config.entity_extraction.use_dynamic_patterns:
            return entities

        # Patterns are managed externally and may change at runtime.
        dynamic_pattern_manager.refresh_if_needed()
        patterns = dynamic_pattern_manager.get_entity_patterns()

        config = conversation_config.entity_extraction

        for pattern_name, pattern in patterns.items():
            try:
                for match in re.finditer(pattern, text, re.IGNORECASE):
                    # Take the first capturing group that matched; patterns
                    # without groups yield nothing and are skipped.
                    value = next((group for group in match.groups() if group), None)
                    if value:
                        entity_type = self._determine_entity_type(pattern_name, value)
                        entities.append(Entity(
                            name=value.strip(),
                            type=entity_type,
                            value=value.strip(),
                            confidence=config.default_entity_confidence
                        ))
            except Exception as e:
                # A single malformed pattern must not abort extraction.
                logging.getLogger(__name__).warning("Error in pattern %s: %s", pattern_name, e)
                continue

        return entities

    def _determine_entity_type(self, pattern_name: str, value: str) -> str:
        """Determine entity type from pattern name, falling back to value shape.

        Checks the pattern name for "id"/"name"/"handle" markers (in that
        priority order), then infers from the value itself as a last resort.
        """
        pattern_lower = pattern_name.lower()

        if "id" in pattern_lower:
            if "challenge" in pattern_lower:
                return "challenge_id"
            elif "skill" in pattern_lower:
                return "skill_id"
            elif "member" in pattern_lower or "user" in pattern_lower:
                return "user_id"
            else:
                return "entity_id"

        elif "name" in pattern_lower:
            if "challenge" in pattern_lower:
                return "challenge_name"
            elif "skill" in pattern_lower:
                return "skill_name"
            else:
                return "entity_name"

        elif "handle" in pattern_lower:
            return "user_handle"

        else:
            # Pattern name gave no hint: infer from the value's shape.
            if value.isdigit():
                return "entity_id"
            elif len(value) <= 30 and "_" in value:
                # Short snake_case-looking tokens are treated as handles.
                return "user_handle"
            else:
                return "entity_name"