File size: 13,674 Bytes
4e8d206
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
"""
Enhanced Reasoning Engine for Supernova AI
Provides sophisticated problem-solving capabilities through structured reasoning,
multi-tool coordination, and knowledge synthesis.
"""
import torch
import numpy as np
try:
    import sympy as sp
except ImportError:
    sp = None
import re
import json
from typing import List, Dict, Any, Optional, Tuple
from dataclasses import dataclass
from enum import Enum

from .tools import ToolOrchestrator, ToolCall


class ReasoningType(Enum):
    """Kind of thinking a single reasoning step asks for.

    Every member's value is simply its name in lowercase, so members can be
    recovered from plain strings via ``ReasoningType("causal")``.
    """

    ANALYTICAL = "analytical"    # systematic, evidence-based breakdown
    CREATIVE = "creative"        # exploratory / unconventional ideas
    COMPARATIVE = "comparative"  # weighing alternatives against each other
    CAUSAL = "causal"            # cause-and-effect explanation
    SEQUENTIAL = "sequential"    # ordered, step-by-step progression
    EVALUATIVE = "evaluative"    # judgement against criteria, recommendations


@dataclass
class ReasoningStep:
    """One unit of work in a reasoning chain.

    Steps are created by the query decomposer. ``result`` and ``confidence``
    are filled in (overwritten) once the step has been executed.
    """

    # 1-based position of the step within its chain.
    step_number: int
    # Human-readable summary; also reused as a label when passing context forward.
    description: str
    # Which style of reasoning the step should apply.
    reasoning_type: ReasoningType
    # Name of the tool to invoke, or None to use model generation instead.
    tool_needed: Optional[str] = None
    # Text sent to the tool or the model for this step.
    query: Optional[str] = None
    # Output of the step once executed (or an error description).
    result: Optional[str] = None
    # Confidence in ``result``; 0.8 until execution sets a real value.
    confidence: float = 0.8


@dataclass
class KnowledgeDomain:
    """A subject area together with supporting evidence about it."""

    # Name of the subject area.
    domain: str
    # Confidence associated with this domain.
    confidence: float
    # Where the information came from.
    sources: List[str]
    # Individual facts collected for the domain.
    key_facts: List[str]


class EnhancedReasoningEngine:
    """Multi-step reasoning engine combining tool calls with model generation.

    A query is classified by :meth:`analyze_query_complexity`. Simple queries
    are answered directly, by a routed tool or by a single generation.
    Other queries are decomposed into :class:`ReasoningStep` objects, executed
    in order and stitched into one final answer by
    :meth:`synthesize_final_response`.
    """

    # Per-reasoning-type instruction appended to every generation prompt.
    # Kept at class level so it is not rebuilt on each call.
    _REASONING_GUIDANCE = {
        ReasoningType.ANALYTICAL: "Analyze systematically, consider multiple factors, and provide evidence-based insights.",
        ReasoningType.CREATIVE: "Think creatively, explore innovative solutions, and consider unconventional approaches.",
        ReasoningType.COMPARATIVE: "Compare different perspectives, weigh pros and cons, and identify key differences.",
        ReasoningType.CAUSAL: "Identify cause-and-effect relationships, trace underlying mechanisms, and explain why things happen.",
        ReasoningType.SEQUENTIAL: "Break down into logical steps, show progression, and maintain clear sequencing.",
        ReasoningType.EVALUATIVE: "Make judgments based on criteria, assess quality and effectiveness, and provide recommendations."
    }

    def __init__(self, tool_orchestrator: ToolOrchestrator):
        """Create an engine that dispatches tool steps to ``tool_orchestrator``."""
        self.tools = tool_orchestrator
        # Not read by any method of this class.
        self.conversation_context = []
        # Top-level domain -> keywords whose presence in a query signals it.
        self.domain_expertise = {
            'science': ['physics', 'chemistry', 'biology', 'mathematics', 'astronomy'],
            'technology': ['programming', 'ai', 'computing', 'engineering', 'electronics'],
            'humanities': ['history', 'literature', 'philosophy', 'psychology', 'sociology'],
            'medicine': ['anatomy', 'pharmacology', 'diagnosis', 'treatment', 'research'],
            'business': ['finance', 'management', 'economics', 'marketing', 'strategy'],
            'arts': ['music', 'visual arts', 'design', 'architecture', 'performance']
        }

    @staticmethod
    def _mentions_term(text: str, term: str) -> bool:
        """Return True if ``term`` occurs in ``text`` as a whole word or phrase.

        A trailing "s" is tolerated so simple plurals still match ("designs").
        Plain substring tests are avoided because short keywords such as
        "ai" would otherwise fire inside unrelated words ("explain", "maintain").
        """
        return re.search(rf"\b{re.escape(term)}s?\b", text) is not None

    def analyze_query_complexity(self, query: str) -> Dict[str, Any]:
        """Classify a query's complexity and the knowledge domains it touches.

        Returns:
            A dict with keys:
            - ``'complexity'``: one of ``'simple'``, ``'moderate'``, ``'complex'``.
            - ``'domains'``: list of domain names from ``self.domain_expertise``.
            - ``'multi_step_needed'``: whether sequencing/multi-part phrasing was found.
            - ``'estimated_steps'``: rough step estimate, capped at 5.
        """
        complexity_indicators = {
            'simple': ['what is', 'define', 'who is', 'when did'],
            'moderate': ['how does', 'why does', 'explain', 'compare', 'analyze'],
            'complex': ['evaluate', 'synthesize', 'create', 'design', 'solve for multiple', 'consider all factors']
        }
        text = query.lower()

        domains_detected = [
            domain
            for domain, keywords in self.domain_expertise.items()
            if any(self._mentions_term(text, keyword) for keyword in keywords)
        ]

        # Levels are scanned from least to most complex, so the most complex
        # matching level wins; 'simple' is the default when nothing matches.
        # Substring matching is kept here on purpose so inflected verbs
        # ("evaluated", "designing") still count.
        complexity_level = 'simple'
        for level, indicators in complexity_indicators.items():
            if any(indicator in text for indicator in indicators):
                complexity_level = level

        multi_step_phrases = ['step by step', 'multiple', 'several', 'both', 'compare and contrast']
        requires_multi_step = (
            any(self._mentions_term(text, phrase) for phrase in multi_step_phrases)
            # "first <anything> then" sequencing, e.g. "first outline it, then refine".
            or re.search(r"\bfirst\b.*\bthen\b", text, re.DOTALL) is not None
        )

        return {
            'complexity': complexity_level,
            'domains': domains_detected,
            'multi_step_needed': requires_multi_step,
            'estimated_steps': min(5, len(domains_detected) + (2 if requires_multi_step else 1))
        }

    def decompose_complex_query(self, query: str, analysis: Dict[str, Any]) -> List[ReasoningStep]:
        """Break a query into an ordered list of reasoning steps.

        Steps are added in this order:
        1. An optional web search, for moderate/complex queries that mention recency.
        2. An optional math-engine call, for moderate/complex queries that mention computation.
        3. One model-generation step per detected domain.
        4. A final synthesis step, for complex queries.

        If none apply, a single "Direct response" step is returned.

        Args:
            query: The user's query.
            analysis: Output of :meth:`analyze_query_complexity`; its
                ``'complexity'`` and ``'domains'`` keys are read.

        Returns:
            Steps numbered consecutively from 1.
        """
        steps = []
        step_num = 1
        text = query.lower()

        # Step 1: Information gathering
        if analysis['complexity'] in ['moderate', 'complex']:
            # Recency cues mean the answer needs current information from the web.
            if any(term in text for term in ['current', 'latest', 'recent', 'today', '2024', '2025']):
                steps.append(ReasoningStep(
                    step_number=step_num,
                    description="Gather current information from web sources",
                    reasoning_type=ReasoningType.ANALYTICAL,
                    tool_needed="serper",
                    query=query
                ))
                step_num += 1

            # Computation cues route to the math engine.
            if any(term in text for term in ['calculate', 'compute', 'solve', 'derivative', 'integral']):
                steps.append(ReasoningStep(
                    step_number=step_num,
                    description="Perform mathematical computation",
                    reasoning_type=ReasoningType.ANALYTICAL,
                    tool_needed="math_engine",
                    query=query
                ))
                step_num += 1

        # Step 2: Domain-specific analysis
        for domain in analysis['domains']:
            steps.append(ReasoningStep(
                step_number=step_num,
                description=f"Analyze from {domain} perspective",
                reasoning_type=ReasoningType.ANALYTICAL,
                tool_needed=None,  # Will use model generation with domain context
                query=f"From a {domain} perspective: {query}"
            ))
            step_num += 1

        # Step 3: Synthesis and evaluation
        if analysis['complexity'] == 'complex':
            steps.append(ReasoningStep(
                step_number=step_num,
                description="Synthesize information and provide comprehensive analysis",
                reasoning_type=ReasoningType.EVALUATIVE,
                tool_needed=None,
                query=query
            ))

        return steps if steps else [ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL, query=query)]

    def execute_reasoning_chain(self, steps: List[ReasoningStep], model, tokenizer) -> List[ReasoningStep]:
        """Execute each step in order and return them with results filled in.

        Steps are mutated in place: ``result`` and ``confidence`` are set.
        The output of every successful step is accumulated and passed as
        context to later model-generation steps. Failed tool calls are
        recorded on the step but not fed forward.
        """
        results = []
        context_info = []

        for step in steps:
            if step.tool_needed:
                tool_call = ToolCall(tool=step.tool_needed, query=step.query)
                executed_call = self.tools.execute_tool_call(tool_call)

                if executed_call.result:
                    step.result = executed_call.result
                    step.confidence = 0.9
                    context_info.append(f"{step.description}: {executed_call.result}")
                else:
                    step.result = f"Tool execution failed: {executed_call.error}"
                    step.confidence = 0.3
            else:
                enhanced_context = self._build_enhanced_context(step, context_info)
                try:
                    response = self._generate_with_context(model, tokenizer, enhanced_context, step.query)
                    step.result = response
                    step.confidence = 0.7
                    context_info.append(f"{step.description}: {response}")
                except Exception as e:
                    # Deliberate best effort: one failed generation should not abort
                    # the chain. The low confidence makes synthesis drop this step.
                    step.result = f"Generation failed: {str(e)}"
                    step.confidence = 0.2

            results.append(step)

        return results

    def _build_enhanced_context(self, step: ReasoningStep, context_info: List[str]) -> str:
        """Build the system-style preamble for generating ``step``.

        The preamble contains the persona lines, the bulleted prior step outputs
        (if any), the guidance for ``step.reasoning_type`` and the step's focus.
        """
        context_parts = [
            "You are Supernova, an advanced AI assistant with deep expertise across multiple domains.",
            "Apply sophisticated reasoning and provide comprehensive, nuanced responses.",
            ""
        ]

        if context_info:
            context_parts.extend([
                "Previous analysis steps:",
                *[f"- {info}" for info in context_info],
                ""
            ])

        guidance = self._REASONING_GUIDANCE.get(step.reasoning_type, 'Provide thorough analysis.')
        context_parts.extend([
            f"Reasoning approach: {guidance}",
            f"Focus area: {step.description}",
            ""
        ])

        return "\n".join(context_parts)

    def _generate_with_context(self, model, tokenizer, context: str, query: str, max_tokens: int = 400,
                               temperature: float = 0.8, top_k: int = 50) -> str:
        """Sample a response to ``query`` from ``model``, conditioned on ``context``.

        Args:
            model: Autoregressive model called as ``model(ids)`` that returns
                ``(logits, _)``. ``model.cfg.n_positions`` gives its context window.
            tokenizer: Exposes ``encode(text, return_tensors="pt")`` and
                ``decode(ids)``. If it has a non-None ``eos_token_id``,
                sampling stops as soon as that token is drawn.
            context: Preamble placed before the query.
            query: The user query for this step.
            max_tokens: Maximum number of new tokens to sample.
            temperature: Logit divisor; must be > 0.
            top_k: Sample only among the ``top_k`` most likely tokens;
                ``<= 0`` disables top-k filtering.

        Returns:
            The newly generated text only, without the prompt, stripped of
            surrounding whitespace.
        """
        full_prompt = f"{context}\nUser Query: {query}\n\nDetailed Response:"

        model.eval()
        device = next(model.parameters()).device
        input_ids = tokenizer.encode(full_prompt, return_tensors="pt").to(device)
        prompt_len = input_ids.size(1)
        window = model.cfg.n_positions
        eos_id = getattr(tokenizer, "eos_token_id", None)

        with torch.no_grad():
            for _ in range(max_tokens):
                # Feed at most the last `window` tokens so the input fits the model.
                input_cond = input_ids[:, -window:]

                logits, _ = model(input_cond)
                logits = logits[:, -1, :] / temperature

                if top_k > 0:
                    kth_best, _ = torch.topk(logits, min(top_k, logits.size(-1)))
                    logits[logits < kth_best[:, [-1]]] = -float("Inf")

                probs = torch.softmax(logits, dim=-1)
                next_id = torch.multinomial(probs, num_samples=1)
                if eos_id is not None and next_id.item() == eos_id:
                    break
                input_ids = torch.cat([input_ids, next_id], dim=1)

        # Decode only the sampled tokens. Splitting the full text on the
        # "Detailed Response:" marker fails when that marker also appears
        # in the context (e.g. inside earlier step outputs).
        return tokenizer.decode(input_ids[0, prompt_len:].tolist()).strip()

    def synthesize_final_response(self, steps: List[ReasoningStep], original_query: str) -> str:
        """Combine executed steps into a single answer string.

        Only steps with a non-empty result and ``confidence > 0.5`` are used.
        When more than two steps succeed, a header and a "Key Insights" footer
        are added. ``original_query`` is accepted for interface compatibility
        and is currently unused.
        """
        successful_steps = [step for step in steps if step.result and step.confidence > 0.5]

        if not successful_steps:
            return "I apologize, but I encountered difficulties processing your request. Could you please rephrase or provide more specific details?"

        is_multi_step = len(successful_steps) > 2
        response_parts = []

        if is_multi_step:
            response_parts.extend(["Here's my comprehensive analysis:", ""])

        # Tool output and model output are both inserted verbatim, one blank line apart.
        for step in successful_steps:
            response_parts.extend([step.result, ""])

        if is_multi_step:
            confidence_score = sum(step.confidence for step in successful_steps) / len(successful_steps)
            response_parts.extend([
                "**Key Insights:**",
                "• Multiple perspectives have been considered",
                f"• Analysis confidence: {confidence_score:.1%}",
            ])
            # Only claim both sources were used when a web-search step and a
            # model-generated step actually contributed.
            used_web = any(step.tool_needed == 'serper' for step in successful_steps)
            used_model = any(step.tool_needed is None for step in successful_steps)
            if used_web and used_model:
                response_parts.append("• Both current information and domain expertise were utilized")

        return "\n".join(response_parts).strip()

    def process_complex_query(self, query: str, model, tokenizer) -> str:
        """Answer ``query``, choosing direct handling or multi-step reasoning.

        Simple single-step queries are first offered to the tool router. If no
        tool produces a result, they fall back to one model generation. All
        other queries go through decompose -> execute -> synthesize.
        """
        analysis = self.analyze_query_complexity(query)

        if analysis['complexity'] == 'simple' and not analysis['multi_step_needed']:
            tool_call = self.tools.route_query(query)
            if tool_call:
                executed_call = self.tools.execute_tool_call(tool_call)
                if executed_call.result:
                    return executed_call.result

            context = self._build_enhanced_context(
                ReasoningStep(1, "Direct response", ReasoningType.ANALYTICAL),
                []
            )
            return self._generate_with_context(model, tokenizer, context, query)

        reasoning_steps = self.decompose_complex_query(query, analysis)
        executed_steps = self.execute_reasoning_chain(reasoning_steps, model, tokenizer)

        return self.synthesize_final_response(executed_steps, query)


# Import torch and other needed modules here to avoid import issues
import torch
try:
    import sympy as sp
    import numpy as np
except ImportError:
    pass