mozzic committed on
Commit
5a1d3ae
·
verified ·
1 Parent(s): 7b56cbf

Upload src\reasoning.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. src//reasoning.py +360 -0
src//reasoning.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLM reasoning layer for answering questions with citations.
3
+ Ensures all responses are grounded in retrieved context.
4
+ """
5
+
6
+ import os
7
+ import re
8
+ from typing import List, Dict, Optional
9
+
10
+ from src.models import QueryRequest, AgentResponse, Citation, ContextUnit
11
+ from src.retrieval import RetrievalResult, ContextBuilder
12
+ from src.groq_integration import GroqReasoningEngine
13
+
14
+
15
class CitationExtractor:
    """Extract cell references from LLM responses."""

    # One pass over the text finds both supported citation shapes:
    #   group 1 -> "Cell X" / "cell X"
    #   group 2 -> "(X)", e.g. "(cell_3)"
    # The two alternatives can never overlap (the first contains no "(" and the
    # second contains no whitespace), so this finds the same tokens as running
    # the two patterns separately, but in order of appearance.
    _REFERENCE_PATTERN = re.compile(
        r'[Cc]ell\s+([a-zA-Z_0-9]+)|\(([a-zA-Z_0-9]+)\)'
    )

    @staticmethod
    def extract_citations(response_text: str, retrieved_units: List[ContextUnit]) -> List[Citation]:
        """
        Extract cell citations from an LLM response.

        Looks for tokens written as "Cell X" / "cell X" or wrapped in
        parentheses such as "(cell_X)", and keeps only those that equal the
        ``cell_id`` of one of ``retrieved_units``.

        Args:
            response_text: Raw answer text. ``None`` or an empty string yields
                no citations instead of raising.
            retrieved_units: Units whose cells may legitimately be cited.

        Returns:
            One ``Citation`` per distinct cited cell, ordered by first
            appearance in ``response_text`` (deterministic across runs).
        """
        if not response_text or not retrieved_units:
            return []

        # Index units by cell id once; the first unit wins on duplicate ids.
        units_by_id = {}
        for unit in retrieved_units:
            units_by_id.setdefault(unit.cell.cell_id, unit)

        # dict.fromkeys de-duplicates while preserving order of first mention.
        mentioned_ids = dict.fromkeys(
            match.group(1) or match.group(2)
            for match in CitationExtractor._REFERENCE_PATTERN.finditer(response_text)
        )

        citations = []
        for cell_id in mentioned_ids:
            unit = units_by_id.get(cell_id)
            if unit is None:
                # Token looked like a reference but is not a retrieved cell.
                continue
            citations.append(
                Citation(
                    cell_id=cell_id,
                    cell_type=unit.cell.cell_type,
                    content_snippet=CitationExtractor._get_snippet(unit),
                    # The placeholder intent is reported as "no intent".
                    intent=unit.intent if unit.intent != "[Pending intent inference]" else None,
                )
            )

        return citations

    @staticmethod
    def _get_snippet(unit: ContextUnit, max_length: int = 100) -> str:
        """Return the first ``max_length`` characters of the unit's cell source."""
        return unit.cell.source[:max_length]
53
+
54
+ @staticmethod
55
+ def _get_snippet(unit: ContextUnit, max_length: int = 100) -> str:
56
+ """Get content snippet from unit."""
57
+ return unit.cell.source[:max_length]
58
+
59
+
60
class HallucinationDetector:
    """Flag responses that may contain claims not grounded in the context."""

    @staticmethod
    def check_for_unsupported_claims(response: str, context: str) -> bool:
        """
        Report whether the response uses an appeal-to-authority or
        generalizing phrase that does not occur in the context.

        This is a simple case-insensitive substring heuristic, not a semantic
        check.

        Args:
            response: Answer text to inspect.
            context: Context text the answer was supposed to be based on.

        Returns:
            True if at least one indicator phrase occurs in ``response`` but
            not in ``context``; False otherwise.
        """
        answer_text = response.lower()
        grounding_text = context.lower()

        # Phrases that tend to introduce outside knowledge.
        suspicious_phrases = (
            "according to", "experts say", "research shows",
            "it's known that", "generally", "typically",
        )

        return any(
            phrase in answer_text and phrase not in grounding_text
            for phrase in suspicious_phrases
        )
81
+
82
+
83
class ReasoningEngine:
    """
    LLM reasoning engine for answering questions.

    Answer providers are tried in order: Groq, then OpenAI, then a rule-based
    fallback that needs no network access. A failure in one provider always
    falls through to the next one.
    """

    def __init__(self):
        # Either client may be None when the provider could not be set up.
        self.groq_client = self._init_groq()
        self.openai_client = self._init_openai()

    def _init_groq(self):
        """Initialize Groq client (preferred for speed and cost).

        Returns:
            A ``GroqReasoningEngine``, or None if constructing it raised.
        """
        try:
            return GroqReasoningEngine()
        except Exception:
            # Best effort: an unavailable provider must not block start-up.
            return None

    def _init_openai(self):
        """Initialize OpenAI client (fallback).

        Returns:
            An ``OpenAI`` client, or None when ``OPENAI_API_KEY`` is unset,
            starts with ``sk-placeholder``, or the client cannot be created.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key or api_key.startswith("sk-placeholder"):
            return None
        try:
            # Imported lazily so the package is only required when a key is set.
            from openai import OpenAI
            return OpenAI(api_key=api_key)
        except Exception:
            # Best effort: an unavailable provider must not block start-up.
            return None

    def reason(self, query: str, retrieval_result: RetrievalResult, conversation_history: Optional[List[Dict]] = None) -> AgentResponse:
        """
        Reason about a question given retrieved context.

        Args:
            query: The user's question.
            retrieval_result: Retrieval output; its ``units`` are used both to
                build the LLM context and to validate citations.
            conversation_history: Optional prior turns, forwarded to Groq only.

        Returns:
            An ``AgentResponse`` carrying the answer, the citations found in
            it, a confidence score, a hallucination-risk flag and the
            retrieved units.
        """
        units = retrieval_result.units

        # Build context for LLM
        context = ContextBuilder.build_context_for_llm(units, query)

        answer = self._generate_answer(query, context, retrieval_result, conversation_history)

        # Extract citations
        citations = CitationExtractor.extract_citations(answer, units)

        # Check for hallucination risk
        has_hallucination_risk = HallucinationDetector.check_for_unsupported_claims(
            answer, context
        )

        # Calculate confidence
        confidence = self._calculate_confidence(len(citations), len(units))

        return AgentResponse(
            answer=answer,
            citations=citations,
            confidence=confidence,
            has_hallucination_risk=has_hallucination_risk,
            retrieved_units=units,
        )

    def _generate_answer(self, query: str, context: str, retrieval_result: RetrievalResult, conversation_history: Optional[List[Dict]] = None) -> str:
        """Produce an answer from the first provider that succeeds.

        Each provider failure is reported and the next provider is tried, so
        an OpenAI error after a Groq error still reaches the rule-based
        fallback instead of propagating to the caller.

        Returns:
            The answer text; a provider answer of None is returned as "".
        """
        if self.groq_client:
            try:
                groq_result = self.groq_client.reason_with_context(
                    query, context, conversation_history=conversation_history
                )
                answer = groq_result["answer"]
                return answer if answer is not None else ""
            except Exception as e:
                print(f"Groq query failed: {e}. Using OpenAI fallback.")

        if self.openai_client:
            try:
                return self._query_openai(query, context)
            except Exception as e:
                print(f"OpenAI query failed: {e}. Using fallback.")

        # Use fallback reasoning
        return self._generate_answer_fallback(query, retrieval_result)

    def _query_openai(self, query: str, context: str) -> str:
        """Query OpenAI API.

        Returns:
            The content of the first completion choice, or "" when the API
            returned no content.

        Raises:
            RuntimeError: If no OpenAI client is configured.
        """
        if not self.openai_client:
            raise RuntimeError("OpenAI client not available")

        prompt = f"""
Based on the following notebook context, answer the question.
Cite specific cells when referencing information.

Context:
{context}

Question: {query}

Answer:"""

        response = self.openai_client.chat.completions.create(
            model="gpt-4",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=500
        )

        # message.content is optional in the API response; keep the str contract.
        return response.choices[0].message.content or ""

    def _generate_answer_fallback(self, query: str, retrieval_result: RetrievalResult) -> str:
        """Generate answer using simple keyword rules, without any LLM."""
        query_lower = query.lower()

        # Handle specific question types
        # Match: "what is this notebook about", "whats the notebook about", "what's this about", etc.
        overview_phrases = (
            "what is this notebook about", "what does this notebook",
            "whats the notebook", "what's this", "what about this notebook",
            "what is this about", "describe this notebook", "tell me about this",
        )
        if any(phrase in query_lower for phrase in overview_phrases):
            return self._summarize_notebook(retrieval_result)

        if "why" in query_lower:
            return self._explain_decision(retrieval_result, query)

        # Default: summarize whatever was retrieved
        if retrieval_result.units:
            return self._summarize_notebook(retrieval_result)

        return "I couldn't find specific information about that in the notebook context."

    @staticmethod
    def _join_unique(items: List[str]) -> str:
        """Join items with ", ", dropping duplicates but keeping first-seen order."""
        return ", ".join(dict.fromkeys(items))

    def _summarize_notebook(self, retrieval_result: RetrievalResult) -> str:
        """Generate a keyword-based summary of what the notebook is about.

        Each unit's intent and lower-cased cell source are scanned for known
        keywords (data loading, preprocessing, models, metrics, plots). The
        detected topics are listed in the order they were first seen, so the
        same input always yields the same text.
        """
        data_sources = []
        data_operations = []
        models = []
        metrics = []
        visualizations = []
        code_cells = 0
        markdown_cells = 0

        for unit in retrieval_result.units:
            intent = unit.intent.lower() if unit.intent else ""
            source = unit.cell.source.lower()

            if unit.cell.cell_type == "code":
                code_cells += 1
            elif unit.cell.cell_type == "markdown":
                markdown_cells += 1

            # Data loading/sources
            if "load data" in intent or "read" in source or "dataset" in source:
                if "iris" in source:
                    data_sources.append("Iris dataset")
                elif "csv" in source or "pd.read_csv" in source:
                    data_sources.append("CSV data files")
                elif "excel" in source or "xlsx" in source:
                    data_sources.append("Excel spreadsheets")
                else:
                    data_sources.append("external datasets")

            # Data operations
            if "preprocess" in intent or "clean" in source or "drop" in source:
                data_operations.append("data cleaning and preprocessing")
            if "filter" in source or "select" in source:
                data_operations.append("data filtering")
            if "merge" in source or "join" in source:
                data_operations.append("data merging")

            # Models
            if "model" in intent or "fit" in source or "train" in source:
                if "randomforest" in source:
                    models.append("Random Forest classifier")
                elif "regression" in source:
                    models.append("regression model")
                # "nn" must be a standalone token (e.g. "torch.nn", "nn.Linear");
                # a plain substring test also fires on "cannot", "running", ...
                elif "neural" in source or re.search(r"\bnn\b", source):
                    models.append("neural network")
                else:
                    models.append("machine learning model")

            # Evaluation metrics
            if "accuracy" in source or "precision" in source or "recall" in source or "f1" in source:
                metrics.append("classification metrics")
            if "rmse" in source or "mse" in source:
                metrics.append("regression metrics")
            if "auc" in source or "roc" in source:
                metrics.append("ROC/AUC analysis")

            # Visualizations
            if "visualize" in intent or "plot" in source or "matplotlib" in source or "seaborn" in source:
                if "scatter" in source:
                    visualizations.append("scatter plots")
                elif "hist" in source:
                    visualizations.append("histograms")
                elif "bar" in source:
                    visualizations.append("bar charts")
                else:
                    visualizations.append("data visualizations")

        join_unique = self._join_unique

        # Build comprehensive summary
        summary = []

        # Main purpose
        if data_sources and models:
            summary.append(f"This is a machine learning notebook that analyzes {join_unique(data_sources)}")
        elif data_sources:
            summary.append(f"This notebook analyzes {join_unique(data_sources)}")
        elif models:
            summary.append("This notebook demonstrates machine learning model development and evaluation")
        else:
            summary.append("This is a data analysis notebook")

        # Data operations
        if data_operations:
            summary.append(f"It includes {join_unique(data_operations)}")

        # Models and evaluation
        if models or metrics:
            model_desc = f"Uses {join_unique(models)}" if models else "Includes model training"
            if metrics:
                model_desc += f" with {join_unique(metrics)}"
            summary.append(model_desc)

        # Visualizations
        if visualizations:
            summary.append(f"Includes {join_unique(visualizations)} for data exploration and results visualization")

        # Notebook structure (only code and markdown cells are counted)
        total_cells = code_cells + markdown_cells
        if total_cells > 0:
            summary.append(f"\n**Notebook Structure:** {code_cells} code cells, {markdown_cells} documentation cells")

        return ". ".join(summary) + "."

    def _explain_decision(self, retrieval_result: RetrievalResult, query: str) -> str:
        """Explain why certain decisions were made.

        Only questions about removing/dropping data are handled specifically:
        the first retrieved cell whose source mentions "drop" or "remove" is
        quoted (first 150 characters). Anything else gets a generic sentence.
        """
        query_lower = query.lower()

        # Look for common decisions
        if "remove" in query_lower or "drop" in query_lower:
            for unit in retrieval_result.units:
                source_lower = unit.cell.source.lower()
                if "drop" in source_lower or "remove" in source_lower:
                    return f"Data was removed/cleaned as shown in: {unit.cell.source[:150]}"

        return "The notebook shows standard data preprocessing and modeling steps."

    def _calculate_confidence(self, num_citations: int, num_units: int) -> float:
        """Calculate confidence score.

        Returns:
            0.0 when nothing was retrieved; otherwise the ratio of citations
            to retrieved units, capped at 1.0 and never below the 0.7 baseline
            (which is also the score when there are no citations).
        """
        if num_units == 0:
            return 0.0

        # If we have units but no explicit citations, give baseline confidence (0.7)
        if num_citations == 0:
            return 0.7

        # With citations, confidence increases
        base_confidence = min(num_citations / max(num_units, 1), 1.0)
        return max(base_confidence, 0.7)
332
+
333
+
334
class ContextualAnsweringSystem:
    """End-to-end system for context-aware question answering."""

    def __init__(self, retrieval_engine, use_llm: bool = True):
        """
        Args:
            retrieval_engine: Object exposing ``retrieve(query, top_k=...)``.
            use_llm: When False, the reasoning engine's LLM clients are
                disabled so answers come from its rule-based fallback only.
        """
        self.retrieval_engine = retrieval_engine
        self.reasoning_engine = ReasoningEngine()
        self.use_llm = use_llm

        if not use_llm:
            # Honour the flag: with no clients set, ReasoningEngine.reason
            # takes its non-LLM fallback branch.
            self.reasoning_engine.groq_client = None
            self.reasoning_engine.openai_client = None

    def answer_question(self, query: str, top_k: int = 5, conversation_history: Optional[List[Dict]] = None) -> AgentResponse:
        """
        Answer a question about the notebook context.

        Args:
            query: User's natural language question
            top_k: Number of cells to retrieve for context
            conversation_history: Previous conversation for context

        Returns:
            AgentResponse with answer, citations, and context
        """
        # Step 1: Retrieve relevant context
        retrieval_result = self.retrieval_engine.retrieve(query, top_k=top_k)

        # Step 2: Reason and generate answer with conversation context
        return self.reasoning_engine.reason(query, retrieval_result, conversation_history)