niloydebbarma committed on
Commit
9e5bc69
·
verified ·
1 Parent(s): c9efcfc

Upload 8 files

Browse files
query_graph_functions/__init__.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Query Graph Functions Package
3
+
4
+ Core modules for graph-based retrieval augmentation implementation.
5
+
6
+ Package contents:
7
+ - setup.py: Initialization and connection functionality (Phase A: Steps 1-2)
8
+ - query_preprocessing.py: Query analysis, routing, and vectorization (Phase B: Steps 3-5)
9
+ - knowledge_retrieval.py: Community search and data extraction (Phase C: Steps 6-8)
10
+ - follow_up_search.py: Follow-up search and entity extraction (Phase D: Steps 9-12)
11
+ - vector_augmentation.py: Vector search enhancement (Phase E: Steps 13-14)
12
+ - answer_synthesis.py: Final answer generation (Phase F: Steps 15-16)
13
+ - response_management.py: Metadata generation and file persistence (Phase G: Steps 17-20)
14
+ """
15
+
16
+ # Phase A: Initialization (Steps 1-2)
17
+ from .setup import GraphRAGSetup, create_graphrag_setup
18
+
19
+ # Phase B: Query Preprocessing (Steps 3-5)
20
+ from .query_preprocessing import (
21
+ QueryAnalyzer,
22
+ DriftRouter,
23
+ QueryVectorizer,
24
+ QueryPreprocessor,
25
+ create_query_preprocessor,
26
+ preprocess_query_pipeline,
27
+ QueryAnalysis,
28
+ DriftRoutingResult,
29
+ VectorizedQuery,
30
+ QueryType,
31
+ SearchStrategy
32
+ )
33
+
34
+ # Phase C: Knowledge Retrieval (Steps 6-8)
35
+ from .knowledge_retrieval import (
36
+ CommunitySearchEngine,
37
+ CommunityResult,
38
+ EntityResult,
39
+ RelationshipResult
40
+ )
41
+
42
+ # Phase D: Follow-up Search (Steps 9-12)
43
+ from .follow_up_search import (
44
+ FollowUpSearch,
45
+ FollowUpQuestion,
46
+ LocalSearchResult,
47
+ IntermediateAnswer
48
+ )
49
+
50
+ # Phase E: Vector Search Augmentation (Steps 13-14)
51
+ from .vector_augmentation import (
52
+ VectorAugmentationEngine,
53
+ VectorSearchResult,
54
+ AugmentationResult
55
+ )
56
+
57
+ # Phase F: Answer Synthesis (Steps 15-16)
58
+ from .answer_synthesis import (
59
+ AnswerSynthesisEngine,
60
+ SynthesisResult,
61
+ SourceEvidence
62
+ )
63
+
64
+ # Phase G: Response Management (Steps 17-20)
65
+ from .response_management import (
66
+ ResponseManager,
67
+ ResponseMetadata
68
+ )
69
+
70
+ __version__ = "1.3.0"
71
+ __author__ = "AllyCat GraphRAG Team"
72
+ __description__ = "Graph-based retrieval augmentation implementation for AllyCat"
73
+
74
+ # Export main classes and functions
75
+ __all__ = [
76
+ # Phase A: Initialization
77
+ "GraphRAGSetup",
78
+ "create_graphrag_setup",
79
+ # Phase B: Query Preprocessing
80
+ "QueryAnalyzer",
81
+ "DriftRouter",
82
+ "QueryVectorizer",
83
+ "QueryPreprocessor",
84
+ "create_query_preprocessor",
85
+ "preprocess_query_pipeline",
86
+ "QueryAnalysis",
87
+ "DriftRoutingResult",
88
+ "VectorizedQuery",
89
+ "QueryType",
90
+ "SearchStrategy",
91
+ # Phase C: Knowledge Retrieval
92
+ "CommunitySearchEngine",
93
+ "CommunityResult",
94
+ "EntityResult",
95
+ "RelationshipResult",
96
+ # Phase D: Follow-up Search
97
+ "FollowUpSearch",
98
+ "FollowUpQuestion",
99
+ "LocalSearchResult",
100
+ "IntermediateAnswer",
101
+ # Phase E: Vector Augmentation
102
+ "VectorAugmentationEngine",
103
+ "VectorSearchResult",
104
+ "AugmentationResult",
105
+ # Phase F: Answer Synthesis
106
+ "AnswerSynthesisEngine",
107
+ "SynthesisResult",
108
+ "SourceEvidence",
109
+ # Phase G: Response Management
110
+ "ResponseManager",
111
+ "ResponseMetadata"
112
+ ]
query_graph_functions/answer_synthesis.py ADDED
@@ -0,0 +1,408 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Answer synthesis module for final response generation. - Phase F (Steps 15-16)"""
2
+
3
+ import logging
4
+ import json
5
+ from typing import Dict, List, Any, Optional
6
+ from dataclasses import dataclass
7
+ from datetime import datetime
8
+
9
+ from .setup import GraphRAGSetup
10
+ from .query_preprocessing import DriftRoutingResult, QueryAnalysis
11
+ from .vector_augmentation import AugmentationResult
12
+
13
+
14
@dataclass
class SourceEvidence:
    """A single piece of evidence feeding answer synthesis, with attribution.

    Instances are built in Step 15 from Phase C/D/E outputs and ranked by
    ``confidence`` before being handed to the LLM prompt.
    """
    source_type: str  # 'community', 'entity', 'relationship', 'vector_doc'
    source_id: str  # identifier of the originating community/answer/document
    content: str  # text excerpt used in the synthesis prompt
    confidence: float  # similarity/confidence score; presumably in [0, 1] — confirm upstream
    phase: str  # originating pipeline phase: 'C', 'D', 'E'
22
+
23
+
24
@dataclass
class SynthesisResult:
    """Phase F synthesis result with comprehensive answer.

    Produced by ``AnswerSynthesisEngine.execute_answer_synthesis_phase``;
    a fallback instance (empty evidence, zero confidence) is returned on failure.
    """
    final_answer: str  # formatted markdown answer shown to the user
    confidence_score: float  # combined evidence/coverage/diversity score, capped at 1.0
    source_evidence: List[SourceEvidence]  # ranked evidence actually used (top 15)
    synthesis_strategy: str  # 'comprehensive_drift' on success, 'fallback' on error
    coverage_assessment: Dict[str, float]  # per-phase coverage ratios plus overall confidence
    execution_time: float  # wall-clock seconds for the whole phase
    metadata: Dict[str, Any]  # bookkeeping: source counts, phase coverage, error info
34
+
35
+
36
class AnswerSynthesisEngine:
    """
    Answer synthesis engine implementing Phase F (Steps 15-16).

    Handles final answer generation process:
    - Context assembly and evidence ranking (Step 15)
    - Final answer generation with confidence scoring (Step 16)

    All public entry points degrade gracefully: on any internal failure a
    fallback answer/result is returned instead of raising.
    """

    def __init__(self, setup: GraphRAGSetup):
        """Bind shared resources (LLM, config) from the GraphRAG setup object."""
        self.setup = setup
        self.llm = setup.llm
        self.config = setup.config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Synthesis parameters (tunable thresholds, currently fixed)
        self.min_confidence_threshold = 0.7
        self.max_synthesis_length = 2000

    async def execute_answer_synthesis_phase(self,
                                             analysis: QueryAnalysis,
                                             routing: DriftRoutingResult,
                                             community_results: Dict[str, Any],
                                             follow_up_results: Dict[str, Any],
                                             augmentation_results: AugmentationResult) -> SynthesisResult:
        """
        Execute answer synthesis phase with comprehensive integration.

        Args:
            analysis: Query analysis results
            routing: Routing decision parameters
            community_results: Community search results (Phase C)
            follow_up_results: Follow-up search results (Phase D)
            augmentation_results: Vector augmentation results (Phase E)

        Returns:
            Synthesis result with final answer; a fallback SynthesisResult
            (confidence 0.0, no evidence) if any step raises.
        """
        start_time = datetime.now()

        try:
            # Step 15: gather and rank evidence from Phases C, D, E
            self.logger.info("Starting Step 15: Context Assembly and Ranking")
            assembled_context = await self._assemble_and_rank_context(
                analysis, community_results, follow_up_results, augmentation_results
            )

            # Step 16: LLM-guided final answer generation
            self.logger.info("Starting Step 16: Final Answer Generation")
            final_answer, confidence = await self._generate_final_answer(
                analysis, routing, assembled_context
            )

            execution_time = (datetime.now() - start_time).total_seconds()

            synthesis_result = SynthesisResult(
                final_answer=final_answer,
                confidence_score=confidence,
                source_evidence=assembled_context['evidence'],
                synthesis_strategy='comprehensive_drift',
                coverage_assessment=assembled_context['coverage'],
                execution_time=execution_time,
                metadata={
                    'sources_integrated': len(assembled_context['evidence']),
                    'phase_coverage': assembled_context['phase_coverage'],
                    'synthesis_method': 'llm_guided',
                    'phase': 'answer_synthesis',
                    'step_range': '15-16'
                }
            )

            self.logger.info(f"Phase F completed: confidence {confidence:.3f}, {len(assembled_context['evidence'])} sources integrated")
            return synthesis_result

        except Exception as e:
            self.logger.error(f"Answer synthesis phase failed: {e}")
            # Return fallback synthesis on failure
            return self._create_fallback_synthesis(
                community_results, follow_up_results,
                (datetime.now() - start_time).total_seconds(), str(e)
            )

    async def _assemble_and_rank_context(self,
                                         analysis: QueryAnalysis,
                                         community_results: Dict[str, Any],
                                         follow_up_results: Dict[str, Any],
                                         augmentation_results: AugmentationResult) -> Dict[str, Any]:
        """
        Step 15: Assemble and rank all context from Phases C, D, and E.

        Prioritizes information by relevance, confidence, and source diversity.
        Returns a dict with 'evidence' (top-15 ranked SourceEvidence),
        'coverage' (per-phase ratios) and 'phase_coverage' (per-phase counts).
        """
        evidence_sources = []

        # Phase C: one evidence item per retrieved community summary
        if 'communities' in community_results:
            for community in community_results['communities']:
                evidence_sources.append(SourceEvidence(
                    source_type='community',
                    source_id=community.community_id,
                    content=community.summary,
                    confidence=community.similarity_score,
                    phase='C'
                ))

        # Phase D: one evidence item per intermediate Q&A pair
        if 'intermediate_answers' in follow_up_results:
            for answer in follow_up_results['intermediate_answers']:
                evidence_sources.append(SourceEvidence(
                    source_type='entity_search',
                    source_id=f"followup_{len(evidence_sources)}",
                    content=f"Q: {answer.question}\nA: {answer.answer}",
                    confidence=answer.confidence,
                    phase='D'
                ))

        # Phase E: one evidence item per vector search hit
        if augmentation_results and augmentation_results.vector_results:
            for i, vector_result in enumerate(augmentation_results.vector_results):
                evidence_sources.append(SourceEvidence(
                    source_type='vector_doc',
                    source_id=f"vector_{i}",
                    content=vector_result.content,
                    confidence=vector_result.similarity_score,
                    phase='E'
                ))

        # Rank all evidence globally by confidence, highest first
        ranked_evidence = sorted(evidence_sources, key=lambda x: x.confidence, reverse=True)

        # Coverage ratios: evidence captured per phase vs. items available.
        # max(1, ...) guards the division when a phase produced nothing.
        coverage = {
            'community_coverage': len([e for e in ranked_evidence if e.phase == 'C']) / max(1, len(community_results.get('communities', []))),
            'entity_coverage': len([e for e in ranked_evidence if e.phase == 'D']) / max(1, len(follow_up_results.get('intermediate_answers', []))),
            'vector_coverage': len([e for e in ranked_evidence if e.phase == 'E']) / max(1, len(augmentation_results.vector_results) if augmentation_results else 1),
            'overall_confidence': sum(e.confidence for e in ranked_evidence) / max(1, len(ranked_evidence))
        }

        phase_coverage = {
            'phase_c': len([e for e in ranked_evidence if e.phase == 'C']),
            'phase_d': len([e for e in ranked_evidence if e.phase == 'D']),
            'phase_e': len([e for e in ranked_evidence if e.phase == 'E'])
        }

        return {
            'evidence': ranked_evidence[:15],  # Top 15 pieces of evidence
            'coverage': coverage,
            'phase_coverage': phase_coverage
        }

    async def _generate_final_answer(self,
                                     analysis: QueryAnalysis,
                                     routing: DriftRoutingResult,
                                     assembled_context: Dict[str, Any]) -> tuple[str, float]:
        """
        Step 16: Generate comprehensive final answer using LLM synthesis.

        Returns (formatted_answer, confidence). Falls back to a plain
        evidence summary with confidence 0.5 if the LLM call fails.
        """
        try:
            # Build the synthesis prompt from the top-ranked evidence
            synthesis_prompt = self._create_synthesis_prompt(
                routing.original_query,
                assembled_context['evidence']
            )

            # Generate answer (synchronous LLM call; str() normalizes the response object)
            response = self.llm.complete(synthesis_prompt)
            final_answer = str(response).strip()

            # Score the synthesis from evidence quality and coverage
            synthesis_confidence = self._calculate_synthesis_confidence(
                assembled_context['evidence'], assembled_context['coverage']
            )

            # Wrap the raw LLM text with headers and attribution footer
            formatted_answer = self._format_final_answer(
                final_answer, assembled_context['evidence'], synthesis_confidence
            )

            return formatted_answer, synthesis_confidence

        except Exception as e:
            self.logger.error(f"Final answer generation failed: {e}")
            return self._create_fallback_answer(assembled_context['evidence']), 0.5

    def _create_synthesis_prompt(self, original_query: str, evidence: List[SourceEvidence]) -> str:
        """Create comprehensive synthesis prompt for LLM (query + top-10 sources + instructions)."""
        prompt_parts = [
            f"# Query: {original_query}",
            "",
            "You are an expert synthesizing information from multiple sources.",
            "Create a comprehensive, accurate answer using the following evidence:",
            "",
            "## Evidence Sources:",
            ""
        ]

        for i, source in enumerate(evidence[:10], 1):  # Top 10 sources
            prompt_parts.extend([
                f"### Source {i} ({source.phase} - {source.source_type}, confidence: {source.confidence:.3f})",
                # Truncate long evidence to keep the prompt bounded
                source.content[:500] + ("..." if len(source.content) > 500 else ""),
                ""
            ])

        prompt_parts.extend([
            "## Instructions:",
            "1. Synthesize a comprehensive answer addressing the original query",
            "2. Prioritize high-confidence sources (>0.8)",
            "3. Include specific details and examples from the evidence",
            "4. Structure the response clearly with sections if appropriate",
            "5. Do not mention source IDs or technical details",
            "6. Focus on factual accuracy and completeness",
            "",
            "## Comprehensive Answer:"
        ])

        return "\n".join(prompt_parts)

    def _calculate_synthesis_confidence(self, evidence: List[SourceEvidence], coverage: Dict[str, float]) -> float:
        """Calculate overall synthesis confidence based on evidence quality and coverage.

        Weighted blend: 50% mean evidence confidence, 30% mean coverage,
        20% phase diversity (how many of C/D/E contributed). Capped at 1.0.
        """
        if not evidence:
            return 0.0

        evidence_confidence = sum(e.confidence for e in evidence) / len(evidence)
        coverage_score = sum(coverage.values()) / len(coverage)

        # Diversity bonus: fraction of the 3 phases represented in the evidence
        phase_diversity = len(set(e.phase for e in evidence)) / 3.0  # 3 phases max

        synthesis_confidence = (evidence_confidence * 0.5) + (coverage_score * 0.3) + (phase_diversity * 0.2)

        return min(synthesis_confidence, 1.0)

    def _format_final_answer(self, answer: str, evidence: List[SourceEvidence], confidence: float) -> str:
        """Format the final answer with proper structure and attribution footer."""
        formatted_parts = [
            "# Comprehensive Answer",
            "",
            answer,
            "",
            "---",
            "",
            f"**Answer Confidence**: {confidence:.1%}",
            f"**Sources Integrated**: {len(evidence)} evidence sources",
            f"**Multi-Phase Coverage**: {len(set(e.phase for e in evidence))} phases (C: Community, D: Entity, E: Vector)",
            ""
        ]

        return "\n".join(formatted_parts)

    def _create_fallback_answer(self, evidence: List[SourceEvidence]) -> str:
        """Create fallback answer when LLM synthesis fails: concatenate top-3 evidence."""
        if not evidence:
            return "Unable to generate answer due to insufficient evidence."

        fallback_parts = [
            "# Answer Summary",
            "",
            "Based on available evidence:",
            ""
        ]

        for i, source in enumerate(evidence[:3], 1):
            fallback_parts.extend([
                f"## Source {i} (Confidence: {source.confidence:.2f})",
                source.content[:300] + ("..." if len(source.content) > 300 else ""),
                ""
            ])

        return "\n".join(fallback_parts)

    def _create_fallback_synthesis(self, community_results: Dict, follow_up_results: Dict,
                                   execution_time: float, error: str) -> SynthesisResult:
        """Create fallback synthesis result when the whole phase fails."""
        return SynthesisResult(
            # BUGFIX: removed stray leading space from the user-facing message
            final_answer="Response failed due to technical error. Please try again.",
            confidence_score=0.0,
            source_evidence=[],
            synthesis_strategy='fallback',
            coverage_assessment={'overall_confidence': 0.0},
            execution_time=execution_time,
            metadata={'error': error, 'fallback': True}
        )

    def combine_phase_results(self,
                              phase_c_answer: str,
                              follow_up_results: Dict[str, Any],
                              augmentation_results=None) -> str:
        """
        Combine Phase C, D, and E results into enhanced answer.

        Creates a comprehensive markdown response by integrating results from
        multiple phases. Returns the Phase C answer unchanged when there are
        no intermediate answers or when combination fails.
        """
        try:
            intermediate_answers = follow_up_results.get('intermediate_answers', [])

            if not intermediate_answers:
                return phase_c_answer

            # Start with Phase C answer
            enhanced_parts = [
                "## Global Context (Phase C)",
                phase_c_answer.strip(),
                "",
                "## Detailed Information (Phase D)"
            ]

            # Add intermediate answers from Phase D
            for i, answer in enumerate(intermediate_answers, 1):
                enhanced_parts.extend([
                    f"**{i}. {answer.question}**",
                    answer.answer,
                    f"*Confidence: {answer.confidence:.2f}*",
                    ""
                ])

            # Add Phase E vector augmentation if available
            if augmentation_results and hasattr(augmentation_results, 'vector_results') and augmentation_results.vector_results:
                enhanced_parts.extend([
                    "## Vector Augmentation (Phase E)",
                    f"**Semantic Enhancement** (Confidence: {augmentation_results.augmentation_confidence:.2f})",
                    ""
                ])

                # Add top vector results
                for i, vector_result in enumerate(augmentation_results.vector_results[:3], 1):
                    enhanced_parts.extend([
                        f"**Vector Result {i}** (Similarity: {vector_result.similarity_score:.3f})",
                        vector_result.content,  # Show full content without truncation
                        ""
                    ])

            # Add supporting evidence if available
            if intermediate_answers:
                # BUGFIX: sort the de-duplicated entity names so the output
                # is deterministic (joining a raw set gave arbitrary order)
                unique_entities = sorted(set(
                    entity for answer in intermediate_answers
                    for entity in answer.supporting_entities[:3]
                ))
                enhanced_parts.extend([
                    "## Supporting Evidence",
                    "**Key Entities Found:** " + ", ".join(unique_entities),
                    ""
                ])

            return "\n".join(enhanced_parts)

        except Exception as e:
            self.logger.error(f"Failed to combine phase results: {e}")
            return phase_c_answer

    def generate_error_response(self, error_message: str) -> Dict[str, Any]:
        """
        Generate standardized error response.

        Creates consistent error format for failed synthesis operations.
        """
        return {
            "answer": f"Sorry, I encountered an error during answer synthesis: {error_message}",
            "metadata": {
                "status": "synthesis_error",
                "error_message": error_message,
                "synthesis_stage": "failed",
                "confidence_score": 0.0,
                "timestamp": datetime.now().isoformat()
            }
        }
405
+
406
+
407
+ # Exports
408
+ __all__ = ['AnswerSynthesisEngine', 'SynthesisResult', 'SourceEvidence']
query_graph_functions/follow_up_search.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Follow-up search module for local graph traversal. - Phase D (Steps 9-12)"""
2
+
3
+ import logging
4
+ from typing import Dict, List, Any
5
+ from dataclasses import dataclass
6
+ import re
7
+ from datetime import datetime
8
+
9
+ # Project imports
10
+ from .setup import GraphRAGSetup
11
+ from .query_preprocessing import DriftRoutingResult
12
+ from .knowledge_retrieval import CommunityResult, EntityResult, RelationshipResult
13
+
14
+
15
@dataclass
class FollowUpQuestion:
    """A follow-up question carried over from Phase C, with extracted keywords.

    Built in Step 9; ``extracted_entities`` seeds the Step 10 graph traversal.
    """
    question: str  # original follow-up question text from Phase C
    question_id: int  # 1-based position within the batch of questions
    extracted_entities: List[str]  # up to 3 capitalized keywords pulled from the question
    query_type: str  # currently always 'search' — see _process_follow_up_questions
    confidence: float  # fixed at 0.8 by the current implementation
23
+
24
+
25
@dataclass
class LocalSearchResult:
    """Results from local graph traversal (Step 10) for one follow-up question."""
    seed_entities: List[EntityResult]  # entities matched directly from question keywords
    traversed_entities: List[EntityResult]  # seeds plus every neighbor reached during hops
    traversed_relationships: List[RelationshipResult]  # edges crossed during traversal
    search_depth: int  # number of hops actually performed (<= max_traversal_depth)
    total_nodes_visited: int  # size of the visited-node set, seeds included
33
+
34
+
35
@dataclass
class IntermediateAnswer:
    """Intermediate answer for a follow-up question (Step 12 output).

    Consumed by Phase F, which turns each instance into 'entity_search'
    evidence for the final synthesis.
    """
    question_id: int  # matches FollowUpQuestion.question_id
    question: str  # the follow-up question being answered
    answer: str  # current implementation: a comma-joined list of entity names
    confidence: float  # fixed at 0.8 by the current implementation
    reasoning: str  # short human-readable justification
    supporting_entities: List[str]  # names of entities backing the answer (up to 10)
    supporting_evidence: List[str]  # currently always empty — reserved for future use
45
+
46
+
47
class FollowUpSearch:
    """Follow-up search module for local graph traversal (Phase D, Steps 9-12).

    Pipeline: wrap Phase C follow-up questions (Step 9), traverse the graph
    locally from keyword-matched seed entities (Step 10), rank the entities
    found (Step 11), and emit one IntermediateAnswer per question (Step 12).
    """

    def __init__(self, setup: GraphRAGSetup):
        """Bind the shared Neo4j connection from the GraphRAG setup object."""
        self.setup = setup
        self.neo4j_conn = setup.neo4j_conn
        self.logger = logging.getLogger(__name__)

        # Traversal configuration (fixed tuning knobs)
        self.max_traversal_depth = 2
        self.max_entities_per_hop = 20
        self.min_entity_confidence = 0.7
        self.min_relationship_confidence = 0.6

    async def execute_follow_up_phase(self,
                                      phase_c_results: Dict[str, Any],
                                      routing_result: DriftRoutingResult) -> Dict[str, Any]:
        """
        Execute follow-up search pipeline based on initial results.

        Args:
            phase_c_results: Results from community search with follow-up questions
            routing_result: Routing configuration parameters

        Returns:
            Dictionary with intermediate answers and entity information;
            on failure, ``{'error': ..., 'intermediate_answers': []}``.
        """
        try:
            self.logger.info("Starting Follow-up Search (Steps 9-12)")

            # Step 9: wrap raw question strings into FollowUpQuestion objects
            self.logger.info("Starting Step 9: Follow-up Question Processing")
            follow_up_questions = await self._process_follow_up_questions(
                phase_c_results.get('initial_answer', {}).get('follow_up_questions', []),
                routing_result
            )
            self.logger.info(f"Step 9 completed: {len(follow_up_questions)} questions processed")

            # Step 10: local graph traversal, one result per question (aligned)
            self.logger.info("Starting Step 10: Local Graph Traversal")
            local_search_results = await self._execute_local_traversal(
                follow_up_questions,
                phase_c_results.get('communities', []),
                routing_result
            )
            self.logger.info(f"Step 10 completed: {len(local_search_results)} searches performed")

            # Step 11: merge and rank entities across all searches
            self.logger.info("Starting Step 11: Detailed Entity Extraction")
            detailed_entities = await self._extract_detailed_entities(
                local_search_results,
                routing_result
            )
            self.logger.info(f"Step 11 completed: {len(detailed_entities)} detailed entities extracted")

            # Step 12: one intermediate answer per question
            self.logger.info("Starting Step 12: Intermediate Answer Generation")
            intermediate_answers = await self._generate_intermediate_answers(
                follow_up_questions,
                local_search_results,
                detailed_entities,
                routing_result
            )
            self.logger.info(f"Step 12 completed: {len(intermediate_answers)} intermediate answers generated")

            # Compile results
            phase_d_results = {
                'follow_up_questions': follow_up_questions,
                'local_search_results': local_search_results,
                'detailed_entities': detailed_entities,
                'intermediate_answers': intermediate_answers,
                'execution_stats': {
                    'questions_processed': len(follow_up_questions),
                    'local_searches_executed': len(local_search_results),
                    'entities_extracted': len(detailed_entities),
                    'answers_generated': len(intermediate_answers),
                    'timestamp': datetime.now().isoformat()
                }
            }

            self.logger.info(f"Phase D completed: {len(intermediate_answers)} detailed answers generated")
            return phase_d_results

        except Exception as e:
            self.logger.error(f"Phase D execution failed: {e}")
            return {'error': str(e), 'intermediate_answers': []}

    async def _process_follow_up_questions(self,
                                           questions: List[str],
                                           routing_result: DriftRoutingResult) -> List[FollowUpQuestion]:
        """Step 9: wrap questions in FollowUpQuestion objects with extracted keywords."""
        processed_questions = []

        for i, question in enumerate(questions):
            # Heuristic keyword extraction: capitalized words and acronyms,
            # minus common interrogatives
            keywords = re.findall(r'\b[A-Z][a-z]+\b|\b[A-Z]{2,}\b', question)
            keywords = [k for k in keywords if k not in ['What', 'Which', 'Who', 'How', 'Are', 'The']]

            follow_up = FollowUpQuestion(
                question=question,
                question_id=i + 1,
                extracted_entities=keywords[:3],  # Top 3 keywords
                query_type='search',
                confidence=0.8
            )

            processed_questions.append(follow_up)
            self.logger.info(f"Question {i+1}: {question} -> Keywords: {keywords[:3]}")

        return processed_questions

    async def _execute_local_traversal(self,
                                       questions: List[FollowUpQuestion],
                                       communities: List[CommunityResult],
                                       routing_result: DriftRoutingResult) -> List[LocalSearchResult]:
        """
        Step 10: Execute local graph traversal for each follow-up question.

        Returns exactly one LocalSearchResult per question, in question order.
        BUGFIX: the original skipped questions with no seed entities
        (``continue``), which shifted the positional pairing that
        _generate_intermediate_answers relies on; an empty result is now
        appended instead so results stay aligned with questions.
        """
        local_results = []

        for question in questions:
            try:
                # Find seed entities matching the question keywords
                seed_entities = await self._find_seed_entities(
                    question.extracted_entities,
                    communities
                )

                if not seed_entities:
                    self.logger.warning(f"No seed entities found for question: {question.question}")
                    # Keep alignment with `questions` via an empty placeholder
                    local_results.append(LocalSearchResult(
                        seed_entities=[],
                        traversed_entities=[],
                        traversed_relationships=[],
                        search_depth=0,
                        total_nodes_visited=0
                    ))
                    continue

                # Multi-hop traversal from the seeds
                traversal_result = await self._multi_hop_traversal(
                    seed_entities,
                    question,
                    routing_result
                )

                local_results.append(traversal_result)
                self.logger.info(f" Traversal for Q{question.question_id}: {traversal_result.total_nodes_visited} nodes visited")

            except Exception as e:
                self.logger.error(f"Local traversal failed for question {question.question_id}: {e}")
                # Preserve alignment even on failure
                local_results.append(LocalSearchResult(
                    seed_entities=[],
                    traversed_entities=[],
                    traversed_relationships=[],
                    search_depth=0,
                    total_nodes_visited=0
                ))

        return local_results

    async def _find_seed_entities(self,
                                  entity_names: List[str],
                                  communities: List[CommunityResult]) -> List[EntityResult]:
        """Search the graph for entities whose name contains any of the keywords.

        SECURITY FIX: keywords are now passed as a Cypher parameter instead of
        being interpolated into the query string, which allowed Cypher
        injection via crafted question text.
        """
        if not entity_names:
            return []

        # Parameterized search query: ANY(...) replaces the f-string OR chain
        query = """
        MATCH (n)
        WHERE n.name IS NOT NULL
          AND ANY(term IN $terms WHERE n.name CONTAINS term)
        RETURN n.id as entity_id, n.name as name, n.content as content,
               n.confidence as confidence,
               n.degree_centrality as degree_centrality,
               n.betweenness_centrality as betweenness_centrality,
               n.closeness_centrality as closeness_centrality,
               labels(n) as node_types
        ORDER BY n.degree_centrality DESC
        LIMIT 20
        """

        try:
            results = self.neo4j_conn.execute_query(query, {'terms': entity_names})
            entities = []

            for record in results:
                entity = EntityResult(
                    entity_id=record['entity_id'],
                    name=record['name'],
                    content=record['content'],
                    confidence=record['confidence'],
                    degree_centrality=record['degree_centrality'],
                    betweenness_centrality=record['betweenness_centrality'],
                    closeness_centrality=record['closeness_centrality'],
                    # Community membership is not resolved at this point
                    community_id='found',
                    node_type=', '.join(record['node_types']) if record['node_types'] else 'Entity'
                )
                entities.append(entity)

            return entities

        except Exception as e:
            self.logger.error(f"Search failed: {e}")
            return []

    async def _multi_hop_traversal(self,
                                   seed_entities: List[EntityResult],
                                   question: FollowUpQuestion,
                                   routing_result: DriftRoutingResult) -> LocalSearchResult:
        """Execute multi-hop graph traversal from seed entities.

        Expands outward up to ``max_traversal_depth`` hops, collecting
        unvisited neighbors that pass the entity/relationship confidence
        thresholds. BUGFIX: the loop variable ``hop`` was read after the
        loop and would be unbound if the depth were 0; depth is now tracked
        explicitly.
        """
        all_entities = list(seed_entities)
        all_relationships = []
        visited_node_ids = {entity.entity_id for entity in seed_entities}

        current_entities = seed_entities
        depth_reached = 0  # hops actually completed

        for hop in range(self.max_traversal_depth):
            if not current_entities:
                break

            depth_reached = hop + 1
            # Frontier node IDs for this hop
            current_ids = [entity.entity_id for entity in current_entities]

            # Parameterized one-hop expansion query
            traversal_query = """
            MATCH (seed)-[r]-(neighbor)
            WHERE seed.id IN $current_ids
              AND NOT (neighbor.id IN $visited_ids)
              AND r.confidence >= $min_rel_confidence
              AND neighbor.confidence >= $min_entity_confidence
              AND neighbor.name IS NOT NULL
              AND neighbor.content IS NOT NULL
            RETURN DISTINCT
                seed.id as seed_id,
                neighbor.id as neighbor_id,
                neighbor.name as neighbor_name,
                neighbor.content as neighbor_content,
                neighbor.confidence as neighbor_confidence,
                neighbor.degree_centrality as degree_centrality,
                neighbor.betweenness_centrality as betweenness_centrality,
                neighbor.closeness_centrality as closeness_centrality,
                labels(neighbor) as neighbor_types,
                type(r) as relationship_type,
                r.confidence as relationship_confidence
            ORDER BY neighbor.degree_centrality DESC, r.confidence DESC
            LIMIT $max_results
            """

            try:
                results = self.neo4j_conn.execute_query(
                    traversal_query,
                    {
                        'current_ids': current_ids,
                        'visited_ids': list(visited_node_ids),
                        'min_rel_confidence': self.min_relationship_confidence,
                        'min_entity_confidence': self.min_entity_confidence,
                        'max_results': self.max_entities_per_hop
                    }
                )

                next_hop_entities = []

                for record in results:
                    neighbor_id = record['neighbor_id']

                    if neighbor_id not in visited_node_ids:
                        # New neighbor: record entity, mark visited, queue for next hop
                        entity = EntityResult(
                            entity_id=neighbor_id,
                            name=record['neighbor_name'],
                            content=record['neighbor_content'],
                            confidence=record['neighbor_confidence'],
                            degree_centrality=record['degree_centrality'] or 0.0,
                            betweenness_centrality=record['betweenness_centrality'] or 0.0,
                            closeness_centrality=record['closeness_centrality'] or 0.0,
                            # Community membership is unknown for traversed nodes
                            community_id='unknown',
                            node_type=', '.join(record['neighbor_types']) if record['neighbor_types'] else 'Entity'
                        )

                        all_entities.append(entity)
                        next_hop_entities.append(entity)
                        visited_node_ids.add(neighbor_id)

                        # Record the edge that led here
                        relationship = RelationshipResult(
                            start_node=record['seed_id'],
                            end_node=neighbor_id,
                            relationship_type=record['relationship_type'],
                            confidence=record['relationship_confidence']
                        )

                        all_relationships.append(relationship)

                current_entities = next_hop_entities
                self.logger.info(f" Hop {depth_reached}: Found {len(next_hop_entities)} new entities")

            except Exception as e:
                self.logger.error(f"Multi-hop traversal failed at hop {depth_reached}: {e}")
                break

        return LocalSearchResult(
            seed_entities=seed_entities,
            traversed_entities=all_entities,
            traversed_relationships=all_relationships,
            search_depth=depth_reached,
            total_nodes_visited=len(visited_node_ids)
        )

    async def _extract_detailed_entities(self,
                                         local_results: List[LocalSearchResult],
                                         routing_result: DriftRoutingResult) -> List[EntityResult]:
        """
        Step 11: Extract detailed entity information from local search results.

        Combines entities from all searches, scores each by a weighted blend of
        confidence and centralities, then ranks by (appearance count, score).
        """
        all_entities = []
        entity_scores = {}

        # Collect all entities and calculate importance scores
        for search_result in local_results:
            for entity in search_result.traversed_entities:
                if entity.entity_id not in entity_scores:
                    # Weighted importance: confidence dominates, centralities refine
                    importance_score = (
                        0.4 * entity.confidence +
                        0.3 * entity.degree_centrality +
                        0.2 * entity.betweenness_centrality +
                        0.1 * entity.closeness_centrality
                    )

                    entity_scores[entity.entity_id] = {
                        'entity': entity,
                        'importance_score': importance_score,
                        'appearance_count': 1
                    }
                    all_entities.append(entity)
                else:
                    # Entity seen by multiple searches — boost its rank
                    entity_scores[entity.entity_id]['appearance_count'] += 1

        # Rank: cross-search frequency first, then importance score
        sorted_entities = sorted(
            entity_scores.values(),
            key=lambda x: (x['appearance_count'], x['importance_score']),
            reverse=True
        )

        # Return top entities (cap configurable via routing parameters)
        max_entities = routing_result.parameters.get('max_detailed_entities', 50)
        detailed_entities = [item['entity'] for item in sorted_entities[:max_entities]]

        self.logger.info(f"Extracted {len(detailed_entities)} detailed entities from {len(all_entities)} total")
        return detailed_entities

    async def _generate_intermediate_answers(self,
                                             questions: List[FollowUpQuestion],
                                             local_results: List[LocalSearchResult],
                                             detailed_entities: List[EntityResult],
                                             routing_result: DriftRoutingResult) -> List[IntermediateAnswer]:
        """Step 12: build one IntermediateAnswer per question listing found entities.

        Relies on ``local_results`` being positionally aligned with
        ``questions`` (guaranteed by _execute_local_traversal).
        """
        answers = []

        for i, question in enumerate(questions):
            # Entities from the matching search result, if present
            entities = local_results[i].traversed_entities if i < len(local_results) else []
            entity_names = [e.name for e in entities[:10]]

            # Simple answer with entity names
            answer_text = f"Found entities: {', '.join(entity_names)}" if entity_names else "No specific entities found."

            answer = IntermediateAnswer(
                question_id=question.question_id,
                question=question.question,
                answer=answer_text,
                confidence=0.8,
                reasoning=f"Found {len(entity_names)} entities matching the search criteria.",
                supporting_entities=entity_names,
                supporting_evidence=[]
            )
            answers.append(answer)

        return answers
425
+
426
+
427
+ # Exports
428
+ __all__ = ['FollowUpSearch', 'FollowUpQuestion', 'LocalSearchResult', 'IntermediateAnswer']
429
+
query_graph_functions/knowledge_retrieval.py ADDED
@@ -0,0 +1,843 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Knowledge Retrieval Module - Phase C (Steps 6-8)
3
+
4
+ Performs community search and data extraction using graph database structures.
5
+ Handles community retrieval, data extraction, and initial answer generation.
6
+ """
7
+
8
+ import logging
9
+ import numpy as np
10
+ import json
11
+ from typing import Dict, List, Tuple, Any
12
+ from dataclasses import dataclass
13
+ from datetime import datetime
14
+
15
+ from .setup import GraphRAGSetup
16
+ from .query_preprocessing import DriftRoutingResult
17
+
18
+
19
@dataclass
class CommunityResult:
    """Enhanced community result with comprehensive properties.

    Aggregates everything Phase C retrieves about one community: its
    similarity to the query, summary text, membership, and quality metrics.
    """
    community_id: str  # Neo4j Community node id
    similarity_score: float  # query-vs-HyDE similarity (modularity-weighted upstream)
    summary: str  # LLM-generated community summary text
    key_entities: List[str]  # names of the community's most salient entities
    member_ids: List[str]  # Direct member access (node ids belonging to the community)
    modularity_score: float  # Community quality (global modularity from metadata)
    level: int  # hierarchy level of the community
    internal_edges: int  # number of edges fully inside the community
    member_count: int  # size of member_ids (0 when unknown)
    centrality_stats: Dict[str, float]  # Aggregated centrality measures (e.g. avg_degree, density)
    confidence_score: float  # selection confidence (defaults to 0.5 when not provided)
    search_index: str  # Optimized search key; currently left empty by the fetcher
    termination_criteria: Dict[str, Any]  # criteria applied during selection; currently empty
35
+
36
+
37
@dataclass
class EntityResult:
    """Entity result with attributes from graph database.

    One graph node plus the centrality measures stored on it.
    """
    entity_id: str  # node id in Neo4j
    name: str  # display name of the entity
    content: str  # textual description stored on the node
    confidence: float  # extraction confidence stored on the node
    degree_centrality: float  # pre-computed degree centrality
    betweenness_centrality: float  # pre-computed betweenness centrality
    closeness_centrality: float  # pre-computed closeness centrality
    community_id: str  # id of the community this entity was extracted for
    node_type: str  # first Neo4j label of the node ('Unknown' if none)
49
+
50
+
51
@dataclass
class RelationshipResult:
    """Relationship result with graph database attributes.

    A single directed edge between two entities.
    """
    start_node: str  # id of the source node
    end_node: str  # id of the target node
    relationship_type: str  # Neo4j relationship type
    confidence: float  # confidence stored on the relationship
58
+
59
+
60
+ class CommunitySearchEngine:
61
+ """Knowledge retrieval engine for community search and entity extraction."""
62
+
63
    def __init__(self, setup: GraphRAGSetup):
        """Bind shared connections and configuration from the setup object.

        Args:
            setup: Initialized GraphRAGSetup providing the Neo4j connection,
                configuration, and (elsewhere) the LLM client.
        """
        self.setup = setup
        self.neo4j_conn = setup.neo4j_conn  # shared Neo4j connection wrapper
        self.config = setup.config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize search optimization caches.
        # community_search_index is filled lazily by _load_community_search_index.
        self.community_search_index = {}
        # NOTE(review): centrality_cache is never populated in this module —
        # confirm whether it is still needed.
        self.centrality_cache = {}
72
+
73
    async def execute_primer_phase(self,
                                 query_embedding: List[float],
                                 routing_result: DriftRoutingResult) -> Dict[str, Any]:
        """Execute community search and knowledge retrieval.

        Runs Phase C end to end (steps 6-8): community retrieval, data
        extraction, and initial answer generation.

        Args:
            query_embedding: Dense embedding of the user query.
            routing_result: DRIFT routing decision carrying search parameters.

        Returns:
            Dict with 'communities', 'extracted_data', 'initial_answer',
            'execution_time' (seconds), and a 'metadata' summary.

        Raises:
            Exception: Any failure is logged and re-raised to the caller.
        """
        start_time = datetime.now()

        try:
            # Step 6: Community retrieval
            self.logger.info("Starting community retrieval")
            communities = await self._retrieve_communities_enhanced(
                query_embedding, routing_result
            )

            # Step 7: Data extraction
            self.logger.info("Starting data extraction")
            extracted_data = await self._extract_community_data_enhanced(communities)

            # Step 8: Answer generation
            self.logger.info("Starting answer generation")
            initial_answer = await self._generate_initial_answer_enhanced(
                extracted_data, routing_result
            )

            execution_time = (datetime.now() - start_time).total_seconds()

            return {
                'communities': communities,
                'extracted_data': extracted_data,
                'initial_answer': initial_answer,
                'execution_time': execution_time,
                'metadata': {
                    'communities_retrieved': len(communities),
                    'entities_extracted': len(extracted_data.get('entities', [])),
                    'relationships_extracted': len(extracted_data.get('relationships', [])),
                    'phase': 'primer',
                    'step_range': '6-8'
                }
            }

        except Exception as e:
            self.logger.error(f"Primer phase execution failed: {e}")
            raise
115
+
116
+ async def _retrieve_communities_enhanced(self,
117
+ query_embedding: List[float],
118
+ routing_result: DriftRoutingResult) -> List[CommunityResult]:
119
+ """
120
+ Step 6: Enhanced community retrieval using comprehensive properties.
121
+
122
+ Retrieves relevant communities based on query embedding similarity.
123
+ """
124
+ try:
125
+ # Retrieve HyDE embeddings
126
+ hyde_embeddings = await self._retrieve_hyde_embeddings_enhanced()
127
+
128
+ if not hyde_embeddings:
129
+ self.logger.warning("No HyDE embeddings found")
130
+ return []
131
+
132
+ # Compute similarities
133
+ similarities = self._compute_hyde_similarities_enhanced(
134
+ query_embedding, hyde_embeddings
135
+ )
136
+
137
+ # Rank communities
138
+ ranked_communities = self._rank_communities_enhanced(
139
+ similarities, routing_result
140
+ )
141
+
142
+ # Apply criteria
143
+ filtered_communities = self._apply_termination_criteria(
144
+ ranked_communities, routing_result
145
+ )
146
+
147
+ # Fetch community details
148
+ community_results = await self._fetch_community_details_enhanced(
149
+ filtered_communities
150
+ )
151
+
152
+ self.logger.info(f"Retrieved {len(community_results)} enhanced communities")
153
+ return community_results
154
+
155
+ except Exception as e:
156
+ self.logger.error(f"Enhanced community retrieval failed: {e}")
157
+ return []
158
+
159
+ async def _load_community_search_index(self):
160
+ """Load optimized community search index from Neo4j."""
161
+ try:
162
+ query = """
163
+ MATCH (meta:DriftMetadata)
164
+ WHERE meta.community_search_index IS NOT NULL
165
+ RETURN meta.community_search_index as search_index,
166
+ meta.total_communities as total_communities
167
+ """
168
+
169
+ results = self.neo4j_conn.execute_query(query)
170
+
171
+ for record in results:
172
+ # The search index is a nested JSON structure with community IDs as keys
173
+ search_index_data = record['search_index']
174
+ if isinstance(search_index_data, dict):
175
+ # Each community in the search index
176
+ for community_id, community_data in search_index_data.items():
177
+ self.community_search_index[community_id] = community_data
178
+ else:
179
+ self.logger.warning(f"Unexpected search index format: {type(search_index_data)}")
180
+
181
+ self.logger.info(f"Loaded search index for {len(self.community_search_index)} communities")
182
+
183
+ except Exception as e:
184
+ self.logger.error(f"Failed to load community search index: {e}")
185
+
186
+ async def _retrieve_hyde_embeddings_enhanced(self) -> Dict[str, Dict[str, Any]]:
187
+ """Retrieve HyDE embeddings and metadata."""
188
+ try:
189
+ # Retrieve community embeddings
190
+ query = """
191
+ MATCH (c:Community)
192
+ WHERE c.hyde_embeddings IS NOT NULL
193
+ OPTIONAL MATCH (meta:CommunitiesMetadata)
194
+ RETURN c.id as community_id,
195
+ c.hyde_embeddings as hyde_embeddings,
196
+ c.summary as summary,
197
+ c.key_entities as key_entities,
198
+ c.member_ids as member_ids,
199
+ size(c.hyde_embeddings) as embedding_size,
200
+ meta.modularity_score as global_modularity_score
201
+ """
202
+
203
+ results = self.neo4j_conn.execute_query(query)
204
+ hyde_embeddings = {}
205
+
206
+ for record in results:
207
+ community_id = record['community_id']
208
+ embeddings_data = record.get('hyde_embeddings')
209
+
210
+ if embeddings_data and community_id:
211
+ hyde_embeddings[community_id] = {
212
+ 'embeddings': embeddings_data,
213
+ 'summary': record.get('summary', ''),
214
+ 'key_entities': record.get('key_entities', []),
215
+ 'member_ids': record.get('member_ids', []),
216
+ 'embedding_size': record.get('embedding_size', 0),
217
+ 'global_modularity_score': record.get('global_modularity_score', 0.0),
218
+ 'embedding_type': 'hyde'
219
+ }
220
+
221
+ self.logger.info(f"Retrieved enhanced HyDE embeddings for {len(hyde_embeddings)} communities")
222
+ return hyde_embeddings
223
+
224
+ except Exception as e:
225
+ self.logger.error(f"Failed to retrieve enhanced HyDE embeddings: {e}")
226
+ # Retry logic for embeddings
227
+ self.logger.info("Attempting retry for HyDE embeddings...")
228
+ try:
229
+ import time
230
+ time.sleep(2) # Brief delay before retry
231
+ results = self.neo4j_conn.execute_query(query)
232
+ hyde_embeddings = {}
233
+
234
+ for record in results:
235
+ community_id = record['community_id']
236
+ embeddings_data = record.get('hyde_embeddings')
237
+
238
+ if embeddings_data and community_id:
239
+ hyde_embeddings[community_id] = {
240
+ 'embeddings': embeddings_data,
241
+ 'summary': record.get('summary', ''),
242
+ 'key_entities': record.get('key_entities', []),
243
+ 'member_ids': record.get('member_ids', []),
244
+ 'embedding_size': record.get('embedding_size', 0),
245
+ 'global_modularity': record.get('global_modularity_score', 0.0)
246
+ }
247
+
248
+ self.logger.info(f"Retry successful: Retrieved enhanced HyDE embeddings for {len(hyde_embeddings)} communities")
249
+ return hyde_embeddings
250
+
251
+ except Exception as retry_error:
252
+ self.logger.error(f"Retry also failed: {retry_error}")
253
+ return {}
254
+
255
+ def _compute_hyde_similarities_enhanced(self,
256
+ query_embedding: List[float],
257
+ hyde_embeddings: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
258
+ """
259
+ Enhanced similarity computation with global modularity weighting.
260
+
261
+ Calculates similarity scores between query embedding and community embeddings.
262
+ """
263
+ similarities = {}
264
+ query_vec = np.array(query_embedding)
265
+ query_norm = np.linalg.norm(query_vec)
266
+
267
+ if query_norm == 0:
268
+ self.logger.warning("Query embedding has zero norm")
269
+ return similarities
270
+
271
+ for community_id, embedding_data in hyde_embeddings.items():
272
+ embeddings_list = embedding_data['embeddings']
273
+ global_modularity = embedding_data.get('global_modularity_score', 0.0)
274
+
275
+ max_similarity = 0.0
276
+
277
+ # Compute similarity
278
+ try:
279
+ # Parse embedding string
280
+ if isinstance(embeddings_list, str):
281
+ embeddings_list = json.loads(embeddings_list)
282
+
283
+ # Process embeddings
284
+ if isinstance(embeddings_list, list) and len(embeddings_list) > 0:
285
+ # Use first embedding
286
+ hyde_vec = np.array(embeddings_list[0] if isinstance(embeddings_list[0], list) else embeddings_list)
287
+ else:
288
+ hyde_vec = np.array(embeddings_list)
289
+
290
+ hyde_norm = np.linalg.norm(hyde_vec)
291
+
292
+ if hyde_norm > 0:
293
+ # Calculate similarity
294
+ base_similarity = np.dot(query_vec, hyde_vec) / (query_norm * hyde_norm)
295
+
296
+ # Apply weighting
297
+ weighted_similarity = base_similarity * (1 + 0.2 * global_modularity)
298
+ max_similarity = weighted_similarity
299
+
300
+ except Exception as e:
301
+ self.logger.warning(f"Error computing similarity for community {community_id}: {e}")
302
+ continue
303
+
304
+ similarities[community_id] = {
305
+ 'similarity': max_similarity,
306
+ 'global_modularity_score': global_modularity,
307
+ 'embedding_size': embedding_data.get('embedding_size', 0)
308
+ }
309
+
310
+ self.logger.info(f"Computed enhanced similarities for {len(similarities)} communities")
311
+ return similarities
312
+
313
+ def _rank_communities_enhanced(self,
314
+ similarities: Dict[str, Dict[str, float]],
315
+ routing_result: DriftRoutingResult) -> List[Tuple[str, Dict[str, float]]]:
316
+ """
317
+ Enhanced ranking using global modularity and similarity.
318
+
319
+ Ranks communities based on a weighted combination of similarity score and modularity.
320
+ """
321
+
322
+ # Rank primarily by similarity, with modularity as secondary factor
323
+
324
+ def ranking_score(item):
325
+ _, scores = item
326
+ similarity = scores['similarity']
327
+ global_modularity = scores['global_modularity_score']
328
+
329
+ # Weighted combination (similarity is primary)
330
+ return 0.8 * similarity + 0.2 * global_modularity
331
+
332
+ # Sort by combined ranking score
333
+ ranked = sorted(similarities.items(), key=ranking_score, reverse=True)
334
+
335
+ # Apply similarity threshold
336
+ similarity_threshold = routing_result.parameters.get('similarity_threshold', 0.7)
337
+ filtered_ranked = [
338
+ (cid, scores) for cid, scores in ranked
339
+ if scores['similarity'] >= similarity_threshold
340
+ ]
341
+
342
+ self.logger.info(f"Enhanced ranking: {len(filtered_ranked)} communities above threshold {similarity_threshold}")
343
+ return filtered_ranked
344
+
345
+ def _apply_termination_criteria(self,
346
+ ranked_communities: List[Tuple[str, Dict[str, float]]],
347
+ routing_result: DriftRoutingResult) -> List[Tuple[str, Dict[str, float]]]:
348
+ """
349
+ Apply termination criteria for community selection.
350
+
351
+ Limits the number of communities selected based on threshold parameters.
352
+ """
353
+
354
+ # Get termination criteria from routing or defaults
355
+ max_communities = routing_result.parameters.get('max_communities', 3)
356
+ min_global_modularity = routing_result.parameters.get('min_global_modularity', 0.3)
357
+
358
+ # Apply criteria
359
+ filtered = []
360
+ for community_id, scores in ranked_communities:
361
+ if len(filtered) >= max_communities:
362
+ break
363
+
364
+ # Check global modularity threshold
365
+ if scores['global_modularity_score'] >= min_global_modularity:
366
+ filtered.append((community_id, scores))
367
+
368
+ self.logger.info(f"Applied termination criteria: {len(filtered)} communities selected")
369
+ return filtered
370
+
371
    async def _fetch_community_details_enhanced(self,
                                              ranked_communities: List[Tuple[str, Dict[str, float]]]) -> List[CommunityResult]:
        """
        Fetch comprehensive community details with all properties.

        For each selected community, queries its Community node (plus global
        metadata) and builds a CommunityResult. A community whose lookup
        fails is logged and skipped rather than aborting the whole batch.

        Args:
            ranked_communities: (community_id, score-dict) pairs that survived
                ranking and termination criteria.

        Returns:
            CommunityResult objects in the same order as the input (minus
            any failed lookups).
        """
        community_results = []

        for community_id, scores in ranked_communities:
            try:
                # Query the Community node directly by ID (since embedding communities have id=community_id)
                detail_query = """
                MATCH (c:Community)
                WHERE c.id = $community_id AND c.hyde_embeddings IS NOT NULL
                OPTIONAL MATCH (meta:CommunitiesMetadata)
                RETURN c.summary as summary,
                       c.key_entities as key_entities,
                       c.member_ids as member_ids,
                       c.internal_edges as internal_edges,
                       c.density as density,
                       c.avg_degree as avg_degree,
                       c.level as level,
                       meta.modularity_score as modularity_score,
                       CASE WHEN c.member_ids IS NOT NULL THEN size(c.member_ids) ELSE 0 END as member_count,
                       c.id as id
                LIMIT 1
                """

                results = self.neo4j_conn.execute_query(
                    detail_query,
                    {'community_id': community_id}
                )

                if results:
                    record = results[0]

                    # Create enhanced community result with actual available data from Neo4j.
                    # NOTE(review): record.get(...) defaults only apply when a key is
                    # absent — if the driver returns explicit nulls these fields may
                    # be None; confirm against the connection wrapper's behavior.
                    community_result = CommunityResult(
                        community_id=community_id,
                        similarity_score=scores['similarity'],
                        summary=record.get('summary', ''),
                        key_entities=record.get('key_entities', []),
                        member_ids=record.get('member_ids', []),
                        modularity_score=record.get('modularity_score', 0.0),
                        level=record.get('level', 1),
                        internal_edges=record.get('internal_edges', 0),
                        member_count=record.get('member_count', 0),
                        # 'confidence_score' is not produced by the similarity step,
                        # so this currently always falls back to 0.5.
                        confidence_score=scores.get('confidence_score', 0.5),
                        search_index='',
                        termination_criteria={},
                        centrality_stats={
                            'avg_degree': record.get('avg_degree', 0.0),
                            'density': record.get('density', 0.0)
                        }
                    )

                    community_results.append(community_result)

            except Exception as e:
                self.logger.error(f"Failed to fetch details for community {community_id}: {e}")
                continue

        self.logger.info(f"Fetched enhanced details for {len(community_results)} communities")
        return community_results
437
+
438
+ async def _extract_community_data_enhanced(self,
439
+ communities: List[CommunityResult]) -> Dict[str, Any]:
440
+ """
441
+ Step 7: Enhanced data extraction with centrality measures.
442
+
443
+ Extracts:
444
+ - Entities with degree/betweenness/closeness centrality
445
+ - Relationships with confidence scores
446
+ - Community statistics and properties
447
+ """
448
+ try:
449
+ all_entities = []
450
+ all_relationships = []
451
+ community_stats = []
452
+
453
+ for community in communities:
454
+ # Extract entities with centrality measures
455
+ entities = await self._extract_entities_with_centrality(community)
456
+ all_entities.extend(entities)
457
+
458
+ # Extract relationships with properties
459
+ relationships = await self._extract_relationships_enhanced(community)
460
+ all_relationships.extend(relationships)
461
+
462
+ # Collect community statistics
463
+ community_stats.append({
464
+ 'community_id': community.community_id,
465
+ 'member_count': community.member_count,
466
+ 'modularity_score': community.modularity_score,
467
+ 'confidence_score': community.confidence_score,
468
+ 'centrality_stats': community.centrality_stats
469
+ })
470
+
471
+ extracted_data = {
472
+ 'entities': all_entities,
473
+ 'relationships': all_relationships,
474
+ 'community_stats': community_stats,
475
+ 'extraction_metadata': {
476
+ 'communities_processed': len(communities),
477
+ 'entities_extracted': len(all_entities),
478
+ 'relationships_extracted': len(all_relationships),
479
+ 'timestamp': datetime.now().isoformat()
480
+ }
481
+ }
482
+
483
+ self.logger.info(f"Enhanced extraction completed: {len(all_entities)} entities, {len(all_relationships)} relationships")
484
+ return extracted_data
485
+
486
+ except Exception as e:
487
+ self.logger.error(f"Enhanced data extraction failed: {e}")
488
+ return {'entities': [], 'relationships': [], 'community_stats': []}
489
+
490
    async def _extract_entities_with_centrality(self,
                                              community: CommunityResult) -> List[EntityResult]:
        """
        Extract entities with comprehensive centrality measures.

        When the community carries ``member_ids``, entities are fetched
        directly by id; otherwise a broad fallback query is used.

        Args:
            community: The community whose entities should be extracted.

        Returns:
            EntityResult objects ordered by degree centrality (descending);
            empty list on failure.
        """
        try:
            # Use member_ids for direct access if available
            member_ids = community.member_ids if community.member_ids else []

            if member_ids:
                # Direct member access query based on actual schema
                entity_query = """
                MATCH (n)
                WHERE n.id IN $member_ids
                AND n.name IS NOT NULL
                AND n.content IS NOT NULL
                RETURN n.id as entity_id,
                       n.name as name,
                       n.content as content,
                       n.confidence as confidence,
                       n.degree_centrality as degree_centrality,
                       n.betweenness_centrality as betweenness_centrality,
                       n.closeness_centrality as closeness_centrality,
                       labels(n) as node_types
                ORDER BY n.degree_centrality DESC
                """

                results = self.neo4j_conn.execute_query(
                    entity_query,
                    {'member_ids': member_ids}
                )
            else:
                # Fallback: find entities using community_id pattern matching.
                # NOTE(review): this fallback matches nodes from ANY community
                # (it only checks community_id IS NOT NULL), capped at 20 —
                # confirm whether it should filter on this community's id.
                entity_query = """
                MATCH (n)
                WHERE n.community_id IS NOT NULL
                AND n.name IS NOT NULL
                AND n.content IS NOT NULL
                RETURN n.id as entity_id,
                       n.name as name,
                       n.content as content,
                       n.confidence as confidence,
                       n.degree_centrality as degree_centrality,
                       n.betweenness_centrality as betweenness_centrality,
                       n.closeness_centrality as closeness_centrality,
                       labels(n) as node_types
                ORDER BY n.degree_centrality DESC
                LIMIT 20
                """

                results = self.neo4j_conn.execute_query(entity_query)

            entities = []
            for record in results:
                entity = EntityResult(
                    entity_id=record['entity_id'],
                    name=record.get('name', ''),
                    content=record.get('content', ''),
                    confidence=record.get('confidence', 0.0),
                    degree_centrality=record.get('degree_centrality', 0.0),
                    betweenness_centrality=record.get('betweenness_centrality', 0.0),
                    closeness_centrality=record.get('closeness_centrality', 0.0),
                    # Tag the entity with the community we extracted it for.
                    community_id=community.community_id,
                    node_type=record.get('node_types', ['Unknown'])[0] if record.get('node_types') else 'Unknown'
                )
                entities.append(entity)

            return entities

        except Exception as e:
            self.logger.error(f"Failed to extract entities for community {community.community_id}: {e}")
            return []
564
+
565
    async def _extract_relationships_enhanced(self,
                                            community: CommunityResult) -> List[RelationshipResult]:
        """
        Extract relationships with enhanced properties.

        Retrieves up to 50 intra-community edges (both endpoints tagged with
        this community's id) whose confidence exceeds 0.5, ordered by
        confidence descending.

        Args:
            community: The community whose internal relationships to fetch.

        Returns:
            RelationshipResult objects; empty list on failure.
        """
        try:
            relationship_query = """
            MATCH (a)-[r]->(b)
            WHERE a.community_id = $community_id
            AND b.community_id = $community_id
            AND r.confidence > 0.5
            RETURN startNode(r).id as start_node,
                   endNode(r).id as end_node,
                   type(r) as relationship_type,
                   r.confidence as confidence
            ORDER BY r.confidence DESC
            LIMIT 50
            """

            results = self.neo4j_conn.execute_query(
                relationship_query,
                {'community_id': community.community_id}
            )

            relationships = []
            for record in results:
                relationship = RelationshipResult(
                    start_node=record['start_node'],
                    end_node=record['end_node'],
                    relationship_type=record['relationship_type'],
                    confidence=record.get('confidence', 0.0)
                )
                relationships.append(relationship)

            return relationships

        except Exception as e:
            self.logger.error(f"Failed to extract relationships for community {community.community_id}: {e}")
            return []
606
+
607
    async def _generate_initial_answer_enhanced(self,
                                              extracted_data: Dict[str, Any],
                                              routing_result: DriftRoutingResult) -> Dict[str, Any]:
        """
        Step 8: Context-aware initial answer generation.

        Ranks entities by centrality, keeps high-confidence relationships,
        builds an LLM context string, and asks the configured LLM for an
        initial answer with follow-up questions.

        Args:
            extracted_data: Output of _extract_community_data_enhanced.
            routing_result: DRIFT routing decision (query text, strategy).

        Returns:
            Dict with 'content', 'llm_context', 'context_used',
            'confidence_metrics', 'follow_up_questions', and 'reasoning'.
            On failure: {'content': ..., 'error': ...} only.
        """
        try:
            entities = extracted_data['entities']
            relationships = extracted_data['relationships']
            community_stats = extracted_data['community_stats']

            # Rank entities by importance: mean of degree and betweenness
            # centrality, keep the top 10.
            important_entities = sorted(
                entities,
                key=lambda e: (e.degree_centrality + e.betweenness_centrality) / 2,
                reverse=True
            )[:10]

            # Select high-confidence relationships (>= 0.7) as strong evidence.
            strong_relationships = [
                r for r in relationships
                if r.confidence >= 0.7
            ]

            # Prepare context for LLM
            llm_context = self._prepare_llm_context_enhanced(
                important_entities, strong_relationships, community_stats, routing_result
            )

            # Generate initial answer using configured LLM
            llm_response = await self._generate_llm_answer(llm_context, routing_result)

            initial_answer = {
                'content': llm_response['answer'],
                'llm_context': llm_context,
                'context_used': {
                    'important_entities': len(important_entities),
                    'strong_relationships': len(strong_relationships),
                    'communities_analyzed': len(community_stats)
                },
                'confidence_metrics': {
                    # Averages guard against empty inputs to avoid np.mean([]) warnings.
                    'avg_entity_centrality': np.mean([e.degree_centrality for e in important_entities]) if important_entities else 0,
                    'avg_relationship_confidence': np.mean([r.confidence for r in strong_relationships]) if strong_relationships else 0,
                    'avg_community_modularity': np.mean([c['modularity_score'] for c in community_stats]) if community_stats else 0,
                    'llm_confidence': llm_response['confidence']
                },
                'follow_up_questions': llm_response['follow_up_questions'],
                'reasoning': llm_response['reasoning']
            }

            self.logger.info("Enhanced initial answer generated with comprehensive context")
            return initial_answer

        except Exception as e:
            self.logger.error(f"Enhanced answer generation failed: {e}")
            # NOTE(review): this fallback dict lacks the keys the success path
            # provides (e.g. 'follow_up_questions') — confirm callers tolerate that.
            return {'content': 'Error generating initial answer', 'error': str(e)}
668
+
669
+ def _prepare_llm_context_enhanced(self,
670
+ entities: List[EntityResult],
671
+ relationships: List[RelationshipResult],
672
+ community_stats: List[Dict[str, Any]],
673
+ routing_result: DriftRoutingResult) -> str:
674
+ """Prepare enhanced context for LLM with comprehensive information."""
675
+
676
+ context_parts = [
677
+ f"Query: {routing_result.original_query}",
678
+ f"Search Strategy: {routing_result.search_strategy.value}",
679
+ "",
680
+ "=== IMPORTANT ENTITIES (Use these specific names in your answer) ===",
681
+ ]
682
+
683
+ for i, entity in enumerate(entities[:10], 1): # Show more entities
684
+ context_parts.append(
685
+ f"{i}. NAME: '{entity.name}' | Description: {entity.content[:100]}... "
686
+ f"| Centrality: {entity.degree_centrality:.3f} | Confidence: {entity.confidence:.3f}"
687
+ )
688
+
689
+ context_parts.extend([
690
+ "",
691
+ "=== KEY RELATIONSHIPS (Use these connections in your answer) ===",
692
+ ])
693
+
694
+ for i, rel in enumerate(relationships[:8], 1): # Show more relationships
695
+ context_parts.append(
696
+ f"{i}. '{rel.start_node}' --[{rel.relationship_type}]--> '{rel.end_node}' "
697
+ f"| Confidence: {rel.confidence:.3f}"
698
+ )
699
+
700
+ # Add quick reference list of all entity names
701
+ entity_names = [entity.name for entity in entities[:15]]
702
+ context_parts.extend([
703
+ "",
704
+ "=== ENTITY NAMES FOR REFERENCE ===",
705
+ f"Available entities: {', '.join(entity_names)}",
706
+ "",
707
+ "=== COMMUNITY STATISTICS ===",
708
+ ])
709
+
710
+ for stat in community_stats:
711
+ context_parts.append(
712
+ f"Community {stat['community_id']}: {stat['member_count']} members, "
713
+ f"modularity: {stat['modularity_score']:.3f}"
714
+ )
715
+
716
+ context_parts.extend([
717
+ "",
718
+ "REMEMBER: Use the specific entity names listed above in your answer!"
719
+ ])
720
+
721
+ return "\n".join(context_parts)
722
+
723
    async def _generate_llm_answer(self,
                                 context: str,
                                 routing_result: DriftRoutingResult) -> Dict[str, Any]:
        """
        Generate actual LLM response using the configured LLM.

        Sends the prepared graph context to ``self.setup.llm`` and parses the
        structured reply (answer / confidence / reasoning / follow-ups).

        Args:
            context: Context string from _prepare_llm_context_enhanced.
            routing_result: DRIFT routing decision (currently unused here
                beyond being part of the call signature).

        Returns:
            Dict with 'answer', 'confidence', 'reasoning',
            'follow_up_questions'; a low-confidence fallback dict on failure.
        """
        try:
            # Construct comprehensive prompt for LLM
            prompt = f"""
You are an expert knowledge analyst. Answer the user's query using SPECIFIC NAMES and information from the graph data provided below.

IMPORTANT: Use the actual entity names, organization names, and relationship details from the graph data. Do not give generic answers.

GRAPH DATA CONTEXT:
{context}

INSTRUCTIONS:
1. Answer using SPECIFIC ENTITY NAMES from the "IMPORTANT ENTITIES" section above
2. Reference actual relationships and organizations mentioned in the graph data
3. If the query asks for members/organizations, LIST THE ACTUAL NAMES from the entities
4. Use confidence scores and centrality measures as evidence strength indicators
5. Generate follow-up questions based on the specific entities found

RESPONSE FORMAT:
Answer: [Use specific names and details from the graph data above]
Confidence: [0.0-1.0]
Reasoning: [Why these specific entities answer the query]
Follow-up Questions:
1. [Specific question about entities found]
2. [Question about relationships discovered]
3. [Question about community connections]
4. [Question for deeper exploration]
5. [Question about related entities]
"""

            # Call the configured LLM (async completion).
            llm_response = await self.setup.llm.acomplete(prompt)
            response_text = llm_response.text

            # Parse the structured reply into its components.
            parsed_response = self._parse_llm_response(response_text)

            self.logger.info(f"LLM generated answer with confidence: {parsed_response['confidence']}")
            return parsed_response

        except Exception as e:
            self.logger.error(f"LLM answer generation failed: {e}")
            # Fallback response keeps the same shape as the success path.
            return {
                'answer': f"Based on the graph analysis, I found relevant information but encountered an issue generating the full response: {str(e)}",
                'confidence': 0.3,
                'reasoning': "LLM generation encountered an error, providing basic analysis from graph data.",
                'follow_up_questions': [
                    "What specific aspects would you like me to explore further?",
                    "Are there particular entities or relationships of interest?",
                    "Should I focus on a specific community or time period?"
                ]
            }
783
+
784
+ def _parse_llm_response(self, response_text: str) -> Dict[str, Any]:
785
+ """Parse structured LLM response into components."""
786
+ try:
787
+ lines = response_text.strip().split('\n')
788
+
789
+ answer = ""
790
+ confidence = 0.5
791
+ reasoning = ""
792
+ follow_up_questions = []
793
+
794
+ current_section = None
795
+
796
+ for line in lines:
797
+ line = line.strip()
798
+
799
+ if line.startswith("Answer:"):
800
+ current_section = "answer"
801
+ answer = line.replace("Answer:", "").strip()
802
+ elif line.startswith("Confidence:"):
803
+ confidence_text = line.replace("Confidence:", "").strip()
804
+ try:
805
+ confidence = float(confidence_text)
806
+ except (ValueError, TypeError):
807
+ confidence = 0.5
808
+ elif line.startswith("Reasoning:"):
809
+ current_section = "reasoning"
810
+ reasoning = line.replace("Reasoning:", "").strip()
811
+ elif line.startswith("Follow-up Questions:"):
812
+ current_section = "questions"
813
+ elif current_section == "answer" and line:
814
+ answer += " " + line
815
+ elif current_section == "reasoning" and line:
816
+ reasoning += " " + line
817
+ elif current_section == "questions" and line.startswith(("1.", "2.", "3.", "4.", "5.")):
818
+ question = line[2:].strip() # Remove "1. " etc.
819
+ follow_up_questions.append(question)
820
+
821
+ return {
822
+ 'answer': answer.strip() if answer else "Unable to generate answer from available context.",
823
+ 'confidence': max(0.0, min(1.0, confidence)), # Clamp between 0-1
824
+ 'reasoning': reasoning.strip() if reasoning else "Analysis based on graph structure and entity relationships.",
825
+ 'follow_up_questions': follow_up_questions if follow_up_questions else [
826
+ "What additional information would be helpful?",
827
+ "Are there specific aspects to explore further?",
828
+ "Should I analyze different communities or relationships?"
829
+ ]
830
+ }
831
+
832
+ except Exception as e:
833
+ self.logger.error(f"Failed to parse LLM response: {e}")
834
+ return {
835
+ 'answer': response_text[:500] if response_text else "No response generated.",
836
+ 'confidence': 0.4,
837
+ 'reasoning': "Direct LLM output due to parsing issues.",
838
+ 'follow_up_questions': ["What would you like to know more about?"]
839
+ }
840
+
841
+
842
# Exports: the public API of this module.
__all__ = ['CommunitySearchEngine', 'CommunityResult', 'EntityResult', 'RelationshipResult']
query_graph_functions/query_preprocessing.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Query preprocessing for analysis, routing, and vectorization - Phase B (Steps 3-5)."""
2
+
3
+ import logging
4
+ from typing import Dict, List, Any, Tuple, Optional
5
+ from dataclasses import dataclass
6
+ from enum import Enum
7
+ import re
8
+
9
+ # System imports
10
+ import sys
11
+ import os
12
+ sys.path.append(os.path.dirname(os.path.dirname(__file__)))
13
+
14
+ from my_config import MY_CONFIG
15
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
16
+
17
+
18
+ class QueryType(Enum):
19
+ """Query type classifications for DRIFT routing."""
20
+ SPECIFIC_ENTITY = "specific_entity"
21
+ RELATIONSHIP_QUERY = "relationship_query"
22
+ BROAD_THEMATIC = "broad_thematic"
23
+ COMPARATIVE = "comparative"
24
+ COMPLEX_REASONING = "complex_reasoning"
25
+ FACTUAL_LOOKUP = "factual_lookup"
26
+
27
+
28
+ class SearchStrategy(Enum):
29
+ """Search strategy determined by DRIFT routing."""
30
+ LOCAL_SEARCH = "local_search"
31
+ GLOBAL_SEARCH = "global_search"
32
+ HYBRID_SEARCH = "hybrid_search"
33
+
34
+
35
+ @dataclass
36
+ class QueryAnalysis:
37
+ """Results of query analysis step."""
38
+ query_type: QueryType
39
+ complexity_score: float # 0.0 to 1.0
40
+ entities_mentioned: List[str]
41
+ key_concepts: List[str]
42
+ intent_description: str
43
+ context_requirements: Dict[str, Any]
44
+ estimated_scope: str # "narrow", "moderate", "broad"
45
+
46
+
47
+ @dataclass
48
+ @dataclass
49
+ class DriftRoutingResult:
50
+ """Results of DRIFT routing decision."""
51
+ search_strategy: SearchStrategy
52
+ reasoning: str
53
+ confidence: float # 0.0 to 1.0
54
+ parameters: Dict[str, Any]
55
+ original_query: str # Added to fix answer generation
56
+ fallback_strategy: Optional[SearchStrategy] = None
57
+
58
+
59
+ @dataclass
60
+ class VectorizedQuery:
61
+ """Results of query vectorization."""
62
+ embedding: List[float]
63
+ embedding_model: str
64
+ normalized_query: str
65
+ semantic_keywords: List[str]
66
+ similarity_threshold: float
67
+
68
+
69
+ class QueryAnalyzer:
70
+ """Handles Step 3: Query Analysis with intent detection and complexity assessment."""
71
+
72
+ def __init__(self, config: Any):
73
+ self.config = config
74
+ self.logger = logging.getLogger('graphrag_query')
75
+
76
+ # Entity extraction patterns
77
+ self.entity_patterns = [
78
+ r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', # Proper nouns
79
+ r'\b(?:company|organization|person|place|event)\s+(?:named|called)?\s*["\']?([^"\']+)["\']?',
80
+ r'\bwho\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
81
+ r'\bwhat\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
82
+ ]
83
+
84
+ # Complexity indicators
85
+ self.complexity_indicators = {
86
+ 'high': ['compare', 'analyze', 'evaluate', 'relationship', 'impact', 'why', 'how'],
87
+ 'medium': ['describe', 'explain', 'summarize', 'list', 'identify'],
88
+ 'low': ['who', 'what', 'when', 'where', 'is', 'are']
89
+ }
90
+
91
+ self.logger.info("QueryAnalyzer initialized for Step 3 processing")
92
+
93
+ async def analyze_query(self, query: str) -> QueryAnalysis:
94
+ """Analyze query for intent, complexity, and entities."""
95
+ self.logger.info(f"Starting Step 3: Query Analysis for: {query[:100]}...")
96
+
97
+ try:
98
+ # Extract entities and concepts
99
+ entities = self._extract_entities(query)
100
+ concepts = self._extract_key_concepts(query)
101
+ query_type = self._classify_query_type(query, entities, concepts)
102
+ complexity = self._calculate_complexity(query, query_type)
103
+ intent = self._determine_intent(query, query_type)
104
+ scope = self._estimate_scope(query, entities, concepts, complexity)
105
+
106
+ # Build context
107
+ context_reqs = self._analyze_context_requirements(query, query_type, entities)
108
+
109
+ analysis = QueryAnalysis(
110
+ query_type=query_type,
111
+ complexity_score=complexity,
112
+ entities_mentioned=entities,
113
+ key_concepts=concepts,
114
+ intent_description=intent,
115
+ context_requirements=context_reqs,
116
+ estimated_scope=scope
117
+ )
118
+
119
+ self.logger.info(f"Step 3 completed: Query type={query_type.value}, "
120
+ f"complexity={complexity:.2f}, entities={len(entities)}, scope={scope}")
121
+
122
+ return analysis
123
+
124
+ except Exception as e:
125
+ self.logger.error(f"Step 3 Query Analysis failed: {e}")
126
+ raise
127
+
128
+ def _extract_entities(self, query: str) -> List[str]:
129
+ """Extract named entities from query text."""
130
+ entities = set()
131
+
132
+ for pattern in self.entity_patterns:
133
+ matches = re.findall(pattern, query, re.IGNORECASE)
134
+ entities.update(matches)
135
+
136
+ # Filter entities
137
+ filtered_entities = [
138
+ entity.strip() for entity in entities
139
+ if len(entity.strip()) > 2 and entity.lower() not in
140
+ {'the', 'and', 'are', 'is', 'was', 'were', 'this', 'that', 'what', 'who', 'how'}
141
+ ]
142
+
143
+ return list(set(filtered_entities))
144
+
145
+ def _extract_key_concepts(self, query: str) -> List[str]:
146
+ """Extract key conceptual terms from query."""
147
+ # Extract concepts
148
+ concepts = []
149
+
150
+ # Find domain terms
151
+ domain_terms = [
152
+ 'revenue', 'profit', 'growth', 'market', 'strategy', 'technology',
153
+ 'product', 'service', 'customer', 'partnership', 'acquisition',
154
+ 'investment', 'research', 'development', 'innovation', 'competition'
155
+ ]
156
+
157
+ query_lower = query.lower()
158
+ for term in domain_terms:
159
+ if term in query_lower:
160
+ concepts.append(term)
161
+
162
+ return concepts
163
+
164
+ def _classify_query_type(self, query: str, entities: List[str], concepts: List[str]) -> QueryType:
165
+ """Classify the type of query for routing decisions."""
166
+ query_lower = query.lower()
167
+
168
+ # Check patterns
169
+ if any(word in query_lower for word in ['compare', 'versus', 'vs', 'difference']):
170
+ return QueryType.COMPARATIVE
171
+
172
+ if any(word in query_lower for word in ['relationship', 'connect', 'related', 'between']):
173
+ return QueryType.RELATIONSHIP_QUERY
174
+
175
+ if len(entities) > 0 and any(word in query_lower for word in ['who is', 'what is', 'about']):
176
+ return QueryType.SPECIFIC_ENTITY
177
+
178
+ if any(word in query_lower for word in ['analyze', 'evaluate', 'why', 'how', 'impact']):
179
+ return QueryType.COMPLEX_REASONING
180
+
181
+ if len(concepts) > 2 or any(word in query_lower for word in ['overall', 'general', 'trend']):
182
+ return QueryType.BROAD_THEMATIC
183
+
184
+ return QueryType.FACTUAL_LOOKUP
185
+
186
+ def _calculate_complexity(self, query: str, query_type: QueryType) -> float:
187
+ """Calculate query complexity score (0.0 to 1.0)."""
188
+ base_score = 0.3
189
+ query_lower = query.lower()
190
+
191
+ # Base complexity
192
+ type_scores = {
193
+ QueryType.FACTUAL_LOOKUP: 0.2,
194
+ QueryType.SPECIFIC_ENTITY: 0.3,
195
+ QueryType.RELATIONSHIP_QUERY: 0.6,
196
+ QueryType.BROAD_THEMATIC: 0.7,
197
+ QueryType.COMPARATIVE: 0.8,
198
+ QueryType.COMPLEX_REASONING: 0.9
199
+ }
200
+
201
+ base_score = type_scores.get(query_type, 0.5)
202
+
203
+ # Adjust complexity
204
+ for level, indicators in self.complexity_indicators.items():
205
+ count = sum(1 for indicator in indicators if indicator in query_lower)
206
+ if level == 'high':
207
+ base_score += count * 0.2
208
+ elif level == 'medium':
209
+ base_score += count * 0.1
210
+ else:
211
+ base_score -= count * 0.05
212
+
213
+ # Query length and structure
214
+ if len(query.split()) > 15:
215
+ base_score += 0.1
216
+ if '?' in query and len(query.split('?')) > 2:
217
+ base_score += 0.15
218
+
219
+ return min(1.0, max(0.0, base_score))
220
+
221
+ def _determine_intent(self, query: str, query_type: QueryType) -> str:
222
+ """Determine the user's intent based on query analysis."""
223
+ intent_map = {
224
+ QueryType.FACTUAL_LOOKUP: "Seeking specific factual information",
225
+ QueryType.SPECIFIC_ENTITY: "Requesting details about a particular entity",
226
+ QueryType.RELATIONSHIP_QUERY: "Exploring connections and relationships",
227
+ QueryType.BROAD_THEMATIC: "Understanding broad themes or patterns",
228
+ QueryType.COMPARATIVE: "Comparing entities or concepts",
229
+ QueryType.COMPLEX_REASONING: "Requiring analytical reasoning and insights"
230
+ }
231
+
232
+ return intent_map.get(query_type, "General information seeking")
233
+
234
+ def _estimate_scope(self, query: str, entities: List[str], concepts: List[str], complexity: float) -> str:
235
+ """Estimate the scope of information needed."""
236
+ if len(entities) == 1 and complexity < 0.4:
237
+ return "narrow"
238
+ elif len(entities) > 3 or len(concepts) > 3 or complexity > 0.7:
239
+ return "broad"
240
+ else:
241
+ return "moderate"
242
+
243
+ def _analyze_context_requirements(self, query: str, query_type: QueryType, entities: List[str]) -> Dict[str, Any]:
244
+ """Analyze what context information is needed."""
245
+ return {
246
+ "requires_entity_details": len(entities) > 0,
247
+ "requires_relationships": query_type in [QueryType.RELATIONSHIP_QUERY, QueryType.COMPARATIVE],
248
+ "requires_historical_context": any(word in query.lower() for word in ['history', 'past', 'previous', 'before']),
249
+ "requires_quantitative_data": any(word in query.lower() for word in ['number', 'amount', 'count', 'revenue', 'profit']),
250
+ "primary_entities": entities[:3] # Focus on top 3 entities
251
+ }
252
+
253
+
254
+ class DriftRouter:
255
+ """Handles Step 4: DRIFT Routing for optimal search strategy selection."""
256
+
257
+ def __init__(self, config: Any, graph_stats: Dict[str, Any]):
258
+ self.config = config
259
+ self.graph_stats = graph_stats
260
+ self.logger = logging.getLogger('graphrag_query')
261
+
262
+ # Routing thresholds
263
+ self.local_search_threshold = 0.4
264
+ self.global_search_threshold = 0.7
265
+ self.entity_count_threshold = 10 # Based on graph size
266
+
267
+ self.logger.info("DriftRouter initialized for Step 4 processing")
268
+
269
+ async def determine_search_strategy(self, query_analysis: QueryAnalysis, original_query: str) -> DriftRoutingResult:
270
+ """
271
+ Determine optimal search strategy using DRIFT methodology (Step 4).
272
+
273
+ Args:
274
+ query_analysis: Results from Step 3 query analysis
275
+ original_query: The original user query
276
+
277
+ Returns:
278
+ DriftRoutingResult with search strategy and parameters
279
+ """
280
+ self.logger.info(f"Starting Step 4: DRIFT Routing for {query_analysis.query_type.value}")
281
+
282
+ try:
283
+ # Apply routing logic
284
+ strategy, reasoning, confidence, params = self._apply_drift_logic(query_analysis)
285
+
286
+ # Fallback strategy
287
+ fallback = self._determine_fallback_strategy(strategy)
288
+
289
+ result = DriftRoutingResult(
290
+ search_strategy=strategy,
291
+ reasoning=reasoning,
292
+ confidence=confidence,
293
+ parameters=params,
294
+ original_query=original_query,
295
+ fallback_strategy=fallback
296
+ )
297
+
298
+ self.logger.info(f"Step 4 completed: Strategy={strategy.value}, "
299
+ f"confidence={confidence:.2f}, reasoning={reasoning[:50]}...")
300
+
301
+ return result
302
+
303
+ except Exception as e:
304
+ self.logger.error(f"Step 4 DRIFT Routing failed: {e}")
305
+ raise
306
+
307
+ def _apply_drift_logic(self, analysis: QueryAnalysis) -> Tuple[SearchStrategy, str, float, Dict[str, Any]]:
308
+ """Apply DRIFT (Distributed Retrieval and Information Filtering Technique) logic."""
309
+
310
+ # Decision factors
311
+ complexity = analysis.complexity_score
312
+ entity_count = len(analysis.entities_mentioned)
313
+ scope = analysis.estimated_scope
314
+ query_type = analysis.query_type
315
+
316
+ # Local search conditions
317
+ if (query_type == QueryType.SPECIFIC_ENTITY and
318
+ entity_count <= 2 and
319
+ complexity < self.local_search_threshold):
320
+
321
+ return (
322
+ SearchStrategy.LOCAL_SEARCH,
323
+ f"Specific entity query with low complexity ({complexity:.2f})",
324
+ 0.9,
325
+ {
326
+ "max_depth": 2,
327
+ "entity_focus": analysis.entities_mentioned,
328
+ "include_neighbors": True,
329
+ "max_results": 20
330
+ }
331
+ )
332
+
333
+ # Global search conditions
334
+ if (complexity > self.global_search_threshold or
335
+ scope == "broad" or
336
+ query_type in [QueryType.BROAD_THEMATIC, QueryType.COMPLEX_REASONING]):
337
+
338
+ return (
339
+ SearchStrategy.GLOBAL_SEARCH,
340
+ f"High complexity ({complexity:.2f}) or broad scope requiring global context",
341
+ 0.85,
342
+ {
343
+ "community_level": "high",
344
+ "max_communities": 10,
345
+ "include_summary": True,
346
+ "max_results": 50
347
+ }
348
+ )
349
+
350
+ # Hybrid search for intermediate cases
351
+ if (query_type == QueryType.RELATIONSHIP_QUERY or
352
+ query_type == QueryType.COMPARATIVE or
353
+ entity_count > 2):
354
+
355
+ return (
356
+ SearchStrategy.HYBRID_SEARCH,
357
+ f"Relationship/comparative query or multiple entities ({entity_count})",
358
+ 0.75,
359
+ {
360
+ "local_depth": 2,
361
+ "global_communities": 5,
362
+ "balance_weight": 0.6, # Favor local over global
363
+ "max_results": 35
364
+ }
365
+ )
366
+
367
+ # Default to local search with moderate confidence
368
+ return (
369
+ SearchStrategy.LOCAL_SEARCH,
370
+ "Default local search for moderate complexity query",
371
+ 0.6,
372
+ {
373
+ "max_depth": 3,
374
+ "entity_focus": analysis.entities_mentioned,
375
+ "include_neighbors": True,
376
+ "max_results": 25
377
+ }
378
+ )
379
+
380
+ def _determine_fallback_strategy(self, primary_strategy: SearchStrategy) -> Optional[SearchStrategy]:
381
+ """Determine fallback strategy if primary fails."""
382
+ fallback_map = {
383
+ SearchStrategy.LOCAL_SEARCH: SearchStrategy.GLOBAL_SEARCH,
384
+ SearchStrategy.GLOBAL_SEARCH: SearchStrategy.LOCAL_SEARCH,
385
+ SearchStrategy.HYBRID_SEARCH: SearchStrategy.LOCAL_SEARCH
386
+ }
387
+
388
+ return fallback_map.get(primary_strategy)
389
+
390
+
391
+ class QueryVectorizer:
392
+ """Handles Step 5: Query Vectorization for semantic similarity matching."""
393
+
394
+ def __init__(self, config: Any):
395
+ self.config = config
396
+ self.logger = logging.getLogger('graphrag_query')
397
+
398
+ # Initialize embedding model using same pattern as other files
399
+ self.embedding_model = HuggingFaceEmbedding(
400
+ model_name=MY_CONFIG.EMBEDDING_MODEL
401
+ )
402
+
403
+ self.model_name = MY_CONFIG.EMBEDDING_MODEL
404
+ self.embedding_dimension = MY_CONFIG.EMBEDDING_LENGTH
405
+
406
+ self.logger.info(f"QueryVectorizer initialized with {self.model_name}")
407
+
408
+ async def vectorize_query(self, query: str, query_analysis: QueryAnalysis) -> VectorizedQuery:
409
+ """
410
+ Generate query embeddings for similarity matching (Step 5).
411
+
412
+ Args:
413
+ query: Original query text
414
+ query_analysis: Results from Step 3
415
+
416
+ Returns:
417
+ VectorizedQuery with embeddings and metadata
418
+ """
419
+ self.logger.info(f"Starting Step 5: Query Vectorization for: {query[:100]}...")
420
+
421
+ try:
422
+ # Normalize query
423
+ normalized_query = self._normalize_query(query, query_analysis)
424
+
425
+ # Generate embedding
426
+ embedding = await self._generate_embedding(normalized_query)
427
+
428
+ # Extract keywords
429
+ semantic_keywords = self._extract_semantic_keywords(query, query_analysis)
430
+
431
+ # Set similarity threshold
432
+ similarity_threshold = self._calculate_similarity_threshold(query_analysis)
433
+
434
+ result = VectorizedQuery(
435
+ embedding=embedding,
436
+ embedding_model=self.model_name,
437
+ normalized_query=normalized_query,
438
+ semantic_keywords=semantic_keywords,
439
+ similarity_threshold=similarity_threshold
440
+ )
441
+
442
+ self.logger.info(f"Step 5 completed: Embedding dimension={len(embedding)}, "
443
+ f"threshold={similarity_threshold:.3f}, keywords={len(semantic_keywords)}")
444
+
445
+ return result
446
+
447
+ except Exception as e:
448
+ self.logger.error(f"Step 5 Query Vectorization failed: {e}")
449
+ raise
450
+
451
+ def _normalize_query(self, query: str, analysis: QueryAnalysis) -> str:
452
+ """Normalize query text for better embedding quality."""
453
+ # Start with original query
454
+ normalized = query.strip()
455
+
456
+ # Add important entities and concepts for context
457
+ if analysis.entities_mentioned:
458
+ entity_context = " ".join(analysis.entities_mentioned[:3])
459
+ normalized = f"{normalized} [Entities: {entity_context}]"
460
+
461
+ if analysis.key_concepts:
462
+ concept_context = " ".join(analysis.key_concepts[:3])
463
+ normalized = f"{normalized} [Concepts: {concept_context}]"
464
+
465
+ return normalized
466
+
467
+ async def _generate_embedding(self, text: str) -> List[float]:
468
+ """Generate embedding for text using configured model."""
469
+ try:
470
+ embedding = await self.embedding_model.aget_text_embedding(text)
471
+ return embedding
472
+ except Exception as e:
473
+ self.logger.error(f"Embedding generation failed: {e}")
474
+ # Fallback to synchronous call if async fails
475
+ return self.embedding_model.get_text_embedding(text)
476
+
477
+ def _extract_semantic_keywords(self, query: str, analysis: QueryAnalysis) -> List[str]:
478
+ """Extract semantic keywords for additional matching."""
479
+ keywords = set()
480
+
481
+ # Add entities and concepts
482
+ keywords.update(analysis.entities_mentioned)
483
+ keywords.update(analysis.key_concepts)
484
+
485
+ # Add query-specific terms based on type
486
+ if analysis.query_type == QueryType.RELATIONSHIP_QUERY:
487
+ keywords.update(['relationship', 'connection', 'related', 'linked'])
488
+ elif analysis.query_type == QueryType.COMPARATIVE:
489
+ keywords.update(['comparison', 'versus', 'difference', 'similar'])
490
+ elif analysis.query_type == QueryType.BROAD_THEMATIC:
491
+ keywords.update(['theme', 'pattern', 'trend', 'overview'])
492
+
493
+ # Filter and return as list
494
+ return [kw for kw in keywords if len(kw) > 2]
495
+
496
+ def _calculate_similarity_threshold(self, analysis: QueryAnalysis) -> float:
497
+ """Calculate appropriate similarity threshold based on query characteristics."""
498
+ base_threshold = 0.7
499
+
500
+ # Adjust based on query complexity
501
+ if analysis.complexity_score > 0.7:
502
+ base_threshold -= 0.1 # Lower threshold for complex queries
503
+ elif analysis.complexity_score < 0.3:
504
+ base_threshold += 0.1 # Higher threshold for simple queries
505
+
506
+ # Adjust based on scope
507
+ if analysis.estimated_scope == "narrow":
508
+ base_threshold += 0.05
509
+ elif analysis.estimated_scope == "broad":
510
+ base_threshold -= 0.05
511
+
512
+ # Ensure reasonable bounds
513
+ return max(0.5, min(0.9, base_threshold))
514
+
515
+
516
+ class QueryPreprocessor:
517
+ """Main class coordinating all query preprocessing steps (Steps 3-5)."""
518
+
519
+ def __init__(self, config: Any, graph_stats: Dict[str, Any]):
520
+ self.config = config
521
+ self.graph_stats = graph_stats
522
+ self.logger = logging.getLogger('graphrag_query')
523
+
524
+ # Initialize component processors
525
+ self.analyzer = QueryAnalyzer(config)
526
+ self.router = DriftRouter(config, graph_stats)
527
+ self.vectorizer = QueryVectorizer(config)
528
+
529
+ self.logger.info("QueryPreprocessor initialized for Steps 3-5")
530
+
531
+ async def preprocess_query(self, query: str) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
532
+ """
533
+ Execute complete query preprocessing pipeline (Steps 3-5).
534
+
535
+ Args:
536
+ query: User's natural language query
537
+
538
+ Returns:
539
+ Tuple of (analysis, routing, vectorization) results
540
+ """
541
+ self.logger.info(f"Starting Phase B: Query Preprocessing Pipeline for: {query[:100]}...")
542
+
543
+ try:
544
+ # Query analysis
545
+ analysis = await self.analyzer.analyze_query(query)
546
+
547
+ # Query routing
548
+ routing = await self.router.determine_search_strategy(analysis, query)
549
+
550
+ # Query vectorization
551
+ vectorization = await self.vectorizer.vectorize_query(query, analysis)
552
+
553
+ self.logger.info(f"Phase B completed successfully: "
554
+ f"Type={analysis.query_type.value}, "
555
+ f"Strategy={routing.search_strategy.value}, "
556
+ f"Embedding_dim={len(vectorization.embedding)}")
557
+
558
+ return analysis, routing, vectorization
559
+
560
+ except Exception as e:
561
+ self.logger.error(f"Query preprocessing pipeline failed: {e}")
562
+ raise
563
+
564
+
565
+ # Exports
566
+ async def create_query_preprocessor(config: Any, graph_stats: Dict[str, Any]) -> QueryPreprocessor:
567
+ """Create and initialize QueryPreprocessor."""
568
+ return QueryPreprocessor(config, graph_stats)
569
+
570
+
571
+ async def preprocess_query_pipeline(query: str, config: Any, graph_stats: Dict[str, Any]) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
572
+ """
573
+ Convenience function for complete query preprocessing.
574
+
575
+ Args:
576
+ query: User's natural language query
577
+ config: Application configuration
578
+ graph_stats: Graph database statistics
579
+
580
+ Returns:
581
+ Complete preprocessing results
582
+ """
583
+ preprocessor = await create_query_preprocessor(config, graph_stats)
584
+ return await preprocessor.preprocess_query(query)
585
+
586
+
587
+ __all__ = [
588
+ 'QueryAnalyzer', 'DriftRouter', 'QueryVectorizer', 'QueryPreprocessor',
589
+ 'create_query_preprocessor', 'preprocess_query_pipeline',
590
+ 'QueryAnalysis', 'DriftRoutingResult', 'VectorizedQuery',
591
+ 'QueryType', 'SearchStrategy'
592
+ ]
query_graph_functions/response_management.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ """Response management module for metadata generation and file I/O operations - Phase G (Steps 17-20)."""
3
+
4
+ import time
5
+ import json
6
+ import logging
7
+ from typing import Dict, List, Any
8
+ from dataclasses import dataclass
9
+ from datetime import datetime
10
+
11
+ from .setup import GraphRAGSetup
12
+ from .query_preprocessing import QueryAnalysis, DriftRoutingResult, VectorizedQuery
13
+ from .answer_synthesis import SynthesisResult
14
+
15
+
16
+ @dataclass
17
+ class ResponseMetadata:
18
+ """Complete response metadata structure."""
19
+ query_type: str
20
+ search_strategy: str
21
+ complexity_score: float
22
+ total_time_seconds: float
23
+ phases_completed: List[str]
24
+ status: str
25
+ phase_details: Dict[str, Any]
26
+ database_stats: Dict[str, Any]
27
+
28
+
29
+ class ResponseManager:
30
+ def __init__(self, setup: GraphRAGSetup):
31
+ self.setup = setup
32
+ self.config = setup.config
33
+ self.logger = logging.getLogger(self.__class__.__name__)
34
+
35
+ def generate_comprehensive_metadata(self,
36
+ analysis: QueryAnalysis,
37
+ routing: DriftRoutingResult,
38
+ vectorization: VectorizedQuery,
39
+ community_results: Dict[str, Any],
40
+ follow_up_results: Dict[str, Any],
41
+ augmentation_results: Any,
42
+ synthesis_results: SynthesisResult,
43
+ total_time: float) -> Dict[str, Any]:
44
+ """
45
+ Generate comprehensive metadata for query response.
46
+
47
+ Consolidates all phase results into structured metadata format.
48
+ """
49
+ try:
50
+ communities = community_results.get('communities', [])
51
+
52
+ metadata = {
53
+ # Execution Summary
54
+ "query_type": analysis.query_type.value,
55
+ "search_strategy": routing.search_strategy.value,
56
+ "complexity_score": analysis.complexity_score,
57
+ "total_time_seconds": round(total_time, 2),
58
+ "phases_completed": ["A-Init", "B-Preprocess", "C-Communities", "D-Followup", "E-Vector", "F-Synthesis"],
59
+ "status": "success",
60
+
61
+ # Phase A: Initialization
62
+ "phase_a": self._generate_phase_a_metadata(),
63
+
64
+ # Phase B: Query Preprocessing
65
+ "phase_b": self._generate_phase_b_metadata(analysis, vectorization, routing),
66
+
67
+ # Phase C: Community Search
68
+ "phase_c": self._generate_phase_c_metadata(communities, community_results),
69
+
70
+ # Phase D: Follow-up Search
71
+ "phase_d": self._generate_phase_d_metadata(follow_up_results),
72
+
73
+ # Phase E: Vector Augmentation
74
+ "phase_e": self._generate_phase_e_metadata(augmentation_results),
75
+
76
+ # Phase F: Answer Synthesis
77
+ "phase_f": self._generate_phase_f_metadata(synthesis_results),
78
+
79
+ # Database Statistics
80
+ "database_stats": self._generate_database_stats(follow_up_results, communities, augmentation_results)
81
+ }
82
+
83
+ self.logger.info("Generated comprehensive metadata with all phase details")
84
+ return metadata
85
+
86
+ except Exception as e:
87
+ self.logger.error(f"Failed to generate metadata: {e}")
88
+ return self._generate_fallback_metadata(str(e))
89
+
90
+ def _generate_phase_a_metadata(self) -> Dict[str, Any]:
91
+ """Generate Phase A initialization metadata."""
92
+ from my_config import MY_CONFIG
93
+
94
+ return {
95
+ "neo4j_connected": bool(self.setup.neo4j_conn),
96
+ "vector_db_ready": bool(self.setup.query_engine),
97
+ "llm_model": getattr(MY_CONFIG, 'LLM_MODEL', 'unknown'),
98
+ "embedding_model": getattr(MY_CONFIG, 'EMBEDDING_MODEL', 'unknown'),
99
+ "drift_config_loaded": bool(self.setup.drift_config)
100
+ }
101
+
102
+ def _generate_phase_b_metadata(self, analysis: QueryAnalysis, vectorization: VectorizedQuery, routing: DriftRoutingResult) -> Dict[str, Any]:
103
+ """Generate Phase B query preprocessing metadata."""
104
+ return {
105
+ "entities_extracted": len(analysis.entities_mentioned),
106
+ "semantic_keywords": len(vectorization.semantic_keywords),
107
+ "embedding_dimensions": len(vectorization.embedding),
108
+ "similarity_threshold": vectorization.similarity_threshold,
109
+ "routing_confidence": round(routing.confidence, 3)
110
+ }
111
+
112
+ def _generate_phase_c_metadata(self, communities: List[Any], community_results: Dict[str, Any]) -> Dict[str, Any]:
113
+ """Generate Phase C community search metadata."""
114
+ return {
115
+ "communities_found": len(communities),
116
+ "community_ids": [c.community_id for c in communities[:5]],
117
+ "similarities": [round(c.similarity_score, 3) for c in communities[:5]],
118
+ "entities_extracted": len(community_results.get('extracted_data', {}).get('entities', [])),
119
+ "relationships_extracted": len(community_results.get('extracted_data', {}).get('relationships', []))
120
+ }
121
+
122
+ def _generate_phase_d_metadata(self, follow_up_results: Dict[str, Any]) -> Dict[str, Any]:
123
+ """Generate Phase D follow-up search metadata."""
124
+ intermediate_answers = follow_up_results.get('intermediate_answers', [])
125
+ avg_confidence = 0.0
126
+ if intermediate_answers:
127
+ avg_confidence = sum(a.confidence for a in intermediate_answers) / len(intermediate_answers)
128
+
129
+ return {
130
+ "questions_generated": len(follow_up_results.get('follow_up_questions', [])),
131
+ "graph_traversals": len(follow_up_results.get('local_search_results', [])),
132
+ "entities_found": len(follow_up_results.get('detailed_entities', [])),
133
+ "intermediate_answers": len(intermediate_answers),
134
+ "avg_confidence": round(avg_confidence, 3)
135
+ }
136
+
137
+ def _generate_phase_e_metadata(self, augmentation_results: Any) -> Dict[str, Any]:
138
+ """Generate Phase E vector augmentation metadata."""
139
+ if not augmentation_results:
140
+ return {"vector_results_count": 0, "augmentation_confidence": 0.0}
141
+
142
+ vector_files = []
143
+ if hasattr(augmentation_results, 'vector_results'):
144
+ for i, result in enumerate(augmentation_results.vector_results):
145
+ file_info = {
146
+ "file_id": i + 1,
147
+ "file_path": getattr(result, 'file_path', 'unknown'),
148
+ "similarity": round(result.similarity_score, 3),
149
+ "content_length": len(result.content),
150
+ "relevance": round(getattr(result, 'relevance_score', 0.0), 3)
151
+ }
152
+ vector_files.append(file_info)
153
+
154
+ return {
155
+ "vector_results_count": len(augmentation_results.vector_results) if hasattr(augmentation_results, 'vector_results') else 0,
156
+ "augmentation_confidence": round(augmentation_results.augmentation_confidence, 3) if hasattr(augmentation_results, 'augmentation_confidence') else 0.0,
157
+ "execution_time": round(augmentation_results.execution_time, 2) if hasattr(augmentation_results, 'execution_time') else 0.0,
158
+ "similarity_threshold": 0.75,
159
+ "vector_files": vector_files
160
+ }
161
+
162
+ def _generate_phase_f_metadata(self, synthesis_results: SynthesisResult) -> Dict[str, Any]:
163
+ """Generate Phase F answer synthesis metadata."""
164
+ return {
165
+ "synthesis_confidence": round(synthesis_results.confidence_score, 3),
166
+ "sources_integrated": len(synthesis_results.source_evidence),
167
+ "final_answer_length": len(synthesis_results.final_answer),
168
+ "synthesis_method": getattr(synthesis_results, 'synthesis_method', 'comprehensive_fusion')
169
+ }
170
+
171
+ def _generate_database_stats(self, follow_up_results: Dict[str, Any], communities: List[Any], augmentation_results: Any) -> Dict[str, Any]:
172
+ """Generate database statistics metadata."""
173
+ vector_docs_used = 0
174
+ if augmentation_results and hasattr(augmentation_results, 'vector_results'):
175
+ vector_docs_used = len(augmentation_results.vector_results)
176
+
177
+ return {
178
+ "total_nodes": self.setup.graph_stats.get('node_count', 0),
179
+ "total_relationships": self.setup.graph_stats.get('relationship_count', 0),
180
+ "total_communities": self.setup.graph_stats.get('community_count', 0),
181
+ "nodes_accessed": len(follow_up_results.get('detailed_entities', [])),
182
+ "communities_searched": len(communities),
183
+ "vector_docs_used": vector_docs_used
184
+ }
185
+
186
+ def _generate_fallback_metadata(self, error: str) -> Dict[str, Any]:
187
+ """Generate minimal metadata when full generation fails."""
188
+ return {
189
+ "status": "metadata_generation_error",
190
+ "error": error,
191
+ "phases_completed": "incomplete",
192
+ "total_time_seconds": 0.0
193
+ }
194
+
195
+ def save_response_to_files(self, user_query: str, result: Dict[str, Any]) -> None:
196
+ """
197
+ Save query response and metadata to separate files.
198
+
199
+ Handles file I/O operations for response persistence.
200
+ """
201
+ try:
202
+ timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
203
+
204
+ # Save response to response file
205
+ self._save_response_file(user_query, result, timestamp)
206
+
207
+ # Save metadata to metadata file
208
+ self._save_metadata_file(user_query, result, timestamp)
209
+
210
+ self.logger.info(f"Saved response and metadata for query: {user_query[:50]}...")
211
+
212
+ except Exception as e:
213
+ self.logger.error(f"Failed to save response files: {e}")
214
+
215
+ def _save_response_file(self, user_query: str, result: Dict[str, Any], timestamp: str) -> None:
216
+ """Save response content to response file."""
217
+ try:
218
+ with open('logs/graphrag_query/graphrag_responses.txt', 'a', encoding='utf-8') as f:
219
+ f.write(f"\n{'='*80}\n")
220
+ f.write(f"QUERY [{timestamp}]: {user_query}\n")
221
+ f.write(f"{'='*80}\n")
222
+ f.write(f"RESPONSE: {result['answer']}\n")
223
+ f.write(f"{'='*80}\n\n")
224
+ except Exception as e:
225
+ self.logger.error(f"Failed to save response file: {e}")
226
+
227
+ def _save_metadata_file(self, user_query: str, result: Dict[str, Any], timestamp: str) -> None:
228
+ """Save metadata to metadata file."""
229
+ try:
230
+ with open('logs/graphrag_query/graphrag_metadata.txt', 'a', encoding='utf-8') as f:
231
+ f.write(f"\n{'='*80}\n")
232
+ f.write(f"METADATA [{timestamp}]: {user_query}\n")
233
+ f.write(f"{'='*80}\n")
234
+ f.write(json.dumps(result['metadata'], indent=2, default=str))
235
+ f.write(f"\n{'='*80}\n\n")
236
+ except Exception as e:
237
+ self.logger.error(f"Failed to save metadata file: {e}")
238
+
239
+ def format_error_response(self, error_message: str) -> Dict[str, Any]:
240
+ """
241
+ Generate standardized error response with metadata.
242
+
243
+ Creates consistent error format for failed queries.
244
+ """
245
+ return {
246
+ "answer": f"Sorry, I encountered an error: {error_message}",
247
+ "metadata": {
248
+ "status": "error",
249
+ "error_message": error_message,
250
+ "phases_completed": "incomplete",
251
+ "neo4j_connected": bool(self.setup.neo4j_conn) if self.setup.neo4j_conn else False,
252
+ "vector_engine_ready": bool(self.setup.query_engine) if self.setup.query_engine else False,
253
+ "timestamp": datetime.now().isoformat()
254
+ }
255
+ }
256
+
257
+
258
+ # Exports
259
+ __all__ = ['ResponseManager', 'ResponseMetadata']
query_graph_functions/setup.py ADDED
@@ -0,0 +1,361 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Graph setup module for database and model initialization. Phase A (Steps 1-2)"""
2
+
3
+ import os
4
+ import logging
5
+ from typing import Dict, Optional, Any
6
+ import sys
7
+ sys.path.append('..') # Add parent directory to path for imports
8
+
9
+ from my_config import MY_CONFIG
10
+ from neo4j import GraphDatabase
11
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
+ from llama_index.core import Settings, VectorStoreIndex, StorageContext
13
+ from llama_index.vector_stores.milvus import MilvusVectorStore
14
+ from llama_index.llms.litellm import LiteLLM
15
+
16
+ # Set up environment
17
+ os.environ['HF_ENDPOINT'] = MY_CONFIG.HF_ENDPOINT
18
+
19
+ # Configure logging
20
+ logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
21
+ logger = logging.getLogger(__name__)
22
+ logger.setLevel(logging.INFO)
23
+
24
+
25
class Neo4jConnection:
    """
    Neo4j database connection manager.

    Reads connection settings from MY_CONFIG, validates them up front,
    lazily opens a driver on connect(), and exposes a small query API.
    Connection failures are logged (best-effort) rather than raised;
    callers check ``self.driver`` before querying.
    """

    def __init__(self):
        self.uri = MY_CONFIG.NEO4J_URI
        self.username = MY_CONFIG.NEO4J_USER
        self.password = MY_CONFIG.NEO4J_PASSWORD
        self.database = getattr(MY_CONFIG, "NEO4J_DATABASE", None)

        # Fail fast on missing configuration. Messages name the actual
        # MY_CONFIG attributes read above (the old message said
        # NEO4J_USERNAME although the attribute is NEO4J_USER).
        if not self.uri:
            raise ValueError("NEO4J_URI config is required")
        if not self.username:
            raise ValueError("NEO4J_USER config is required")
        if not self.password:
            raise ValueError("NEO4J_PASSWORD config is required")
        if not self.database:
            raise ValueError("NEO4J_DATABASE config is required")

        # Driver instance; None until connect() succeeds. Annotated as Any
        # because neo4j's Driver class is a runtime value here, and
        # GraphDatabase.driver is a factory method, not a type.
        self.driver: Optional[Any] = None

    def connect(self):
        """STEP 1.2: Initialize Neo4j driver with verification.

        Idempotent: does nothing if a driver already exists. On failure
        the error is logged and self.driver is reset to None.
        """
        if self.driver is None:
            try:
                self.driver = GraphDatabase.driver(
                    self.uri,
                    auth=(self.username, self.password)
                )
                # Raises if the server is unreachable or auth is wrong.
                self.driver.verify_connectivity()
                logger.info(f"Connected to Neo4j at {self.uri}")
            except Exception as e:
                logger.error(f"❌ STEP 1.2 FAILED: Neo4j connection error: {e}")
                self.driver = None

    def disconnect(self):
        """Close and release the Neo4j driver if one is open."""
        if self.driver:
            self.driver.close()
            self.driver = None
            logger.info("Neo4j connection closed")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Run a Cypher query and return all records as plain dicts.

        Args:
            query: Cypher query string.
            parameters: Optional query parameters.

        Raises:
            ConnectionError: if connect() has not been called or failed.
        """
        if not self.driver:
            raise ConnectionError("Not connected to Neo4j database")

        # Materialize results before the session closes so callers get a
        # plain list instead of a consumed cursor.
        with self.driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            records = [record.data() for record in result]
            return records
78
+
79
+
80
class GraphRAGSetup:
    """
    Main setup class for graph-based retrieval system.

    Handles core initialization and configuration:
    - Database connections (Neo4j and vector database)
    - Model initialization and configuration
    - Graph statistics and validation
    - Search configuration loading

    Every stage is best-effort: failures are logged and the corresponding
    attribute is left None/empty, so callers should check
    validate_system_readiness() before issuing queries.
    """

    def __init__(self):
        logger.info("Starting graph system initialization")

        # Initialize core components to safe defaults; the step sequence
        # below fills them in (or leaves them unset on failure).
        self.config = MY_CONFIG # Add config attribute for GraphQueryEngine
        self.neo4j_conn = None          # Neo4jConnection once _setup_neo4j succeeds
        self.query_engine = None        # LlamaIndex query engine over the vector store
        self.graph_stats = {}           # node/relationship/community counts
        self.drift_config = {}          # DRIFT metadata loaded from Neo4j nodes
        self.llm = None                 # LiteLLM instance
        self.embedding_model = None     # HuggingFace embedding model

        # Execute Step 1 initialization sequence
        self._execute_step1_sequence()

        logger.info("Graph system initialization complete")

    def _execute_step1_sequence(self):
        """Execute complete Step 1 initialization sequence"""
        # STEP 1.1-1.6: Initialize all components, in dependency order —
        # Neo4j first, since graph statistics and DRIFT configuration are
        # both read from it afterwards.
        self._setup_neo4j()  # STEP 1.2
        self._setup_vector_search()  # STEP 1.3-1.6
        self._load_graph_statistics()  # STEP 2.1-2.4
        self._load_drift_configuration()  # STEP 2.5

    def _setup_neo4j(self):
        """STEP 1.2: Initialize Neo4j driver with verification"""
        try:
            logger.info("Initializing Neo4j connection...")
            self.neo4j_conn = Neo4jConnection()
            self.neo4j_conn.connect()

            # Verify connection with test query (cheap count over all nodes)
            if self.neo4j_conn.driver:
                test_result = self.neo4j_conn.execute_query("MATCH (n) RETURN count(n) as total_nodes LIMIT 1")
                node_count = test_result[0]['total_nodes'] if test_result else 0
                logger.info(f"Neo4j connected - {node_count} nodes found")

        except Exception as e:
            # Best-effort: a missing graph DB degrades features but does not
            # abort initialization.
            logger.error(f"Neo4j connection error: {e}")
            self.neo4j_conn = None

    def _setup_vector_search(self):
        """STEP 1.3-1.6: Initialize vector database and LLM components"""
        try:
            logger.info("Setting up vector search and LLM...")

            # STEP 1.5: Load embedding model (also registered as the
            # global LlamaIndex default via Settings).
            self.embedding_model = HuggingFaceEmbedding(
                model_name=MY_CONFIG.EMBEDDING_MODEL
            )
            Settings.embed_model = self.embedding_model
            logger.info(f"Embedding model loaded: {MY_CONFIG.EMBEDDING_MODEL}")

            # STEP 1.6: Connect to vector database based on configuration
            # (managed cloud cluster vs. local Milvus instance).
            if MY_CONFIG.VECTOR_DB_TYPE == "cloud_zilliz":
                if not MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT or not MY_CONFIG.ZILLIZ_TOKEN:
                    raise ValueError("Cloud database configuration missing. Set ZILLIZ_CLUSTER_ENDPOINT and ZILLIZ_TOKEN in .env")

                # overwrite=False: attach to the existing collection, never drop it.
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT,
                    token=MY_CONFIG.ZILLIZ_TOKEN,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False
                )
                storage_context = StorageContext.from_defaults(vector_store=vector_store)
                logger.info("Connected to cloud vector database")
            else:
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.MILVUS_URI_HYBRID_GRAPH,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False
                )
                storage_context = StorageContext.from_defaults(vector_store=vector_store)
                logger.info("Connected to local vector database")

            # Wrap the existing collection in an index (no re-ingestion).
            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store, storage_context=storage_context)
            logger.info("Vector index loaded successfully")

            # STEP 1.4: Initialize LLM provider (also set as global default).
            llm_model = MY_CONFIG.LLM_MODEL
            self.llm = LiteLLM(model=llm_model)
            Settings.llm = self.llm
            logger.info(f"LLM initialized: {llm_model}")

            self.query_engine = index.as_query_engine()

        except Exception as e:
            # Any failure above leaves the engine unusable; callers check
            # query_engine via validate_system_readiness().
            logger.error(f"Vector setup error: {e}")
            self.query_engine = None

    def _load_graph_statistics(self):
        """STEP 2.1-2.4: Load and validate graph data structure"""
        try:
            logger.info("Loading graph statistics and validation...")

            if not self.neo4j_conn or not self.neo4j_conn.driver:
                logger.warning("No Neo4j connection for statistics")
                return

            # STEP 2.1: Get node and relationship counts.
            # community_count counts distinct n.community_id values
            # (nodes without the property are ignored by count()).
            stats_query = """
            MATCH (n)
            OPTIONAL MATCH ()-[r]-()
            RETURN count(DISTINCT n) as node_count,
                   count(DISTINCT r) as relationship_count,
                   count(DISTINCT n.community_id) as community_count
            """

            result = self.neo4j_conn.execute_query(stats_query)
            if result:
                stats = result[0]
                self.graph_stats = {
                    'node_count': stats.get('node_count', 0),
                    'relationship_count': stats.get('relationship_count', 0),
                    'community_count': stats.get('community_count', 0)
                }

                logger.info(f"Graph validated - {self.graph_stats['node_count']} nodes, "
                            f"{self.graph_stats['relationship_count']} relationships, "
                            f"{self.graph_stats['community_count']} communities")

        except Exception as e:
            logger.error(f"Graph statistics error: {e}")
            self.graph_stats = {}

    def _load_drift_configuration(self):
        """STEP 2.5: Load DRIFT search metadata and configuration"""
        logger.info("Loading search configuration...")

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.warning("No Neo4j connection for search configuration")
            self.drift_config = {}
            return

        # Query for all DRIFT-related nodes in one round trip.
        # NOTE(review): with multiple nodes per label this OPTIONAL MATCH
        # chain produces a cross product; only the first row is consumed
        # below — confirm each metadata label is a singleton.
        drift_metadata_query = """
        OPTIONAL MATCH (dm:DriftMetadata)
        OPTIONAL MATCH (dc:DriftConfiguration)
        OPTIONAL MATCH (csi:CommunitySearchIndex)
        OPTIONAL MATCH (gm:GraphMetadata)
        OPTIONAL MATCH (cm:CommunitiesMetadata)
        RETURN dm, dc, csi, gm, cm
        """

        result = self.neo4j_conn.execute_query(drift_metadata_query)
        if result and result[0]:
            record = result[0]
            drift_config = {}

            # Extract DriftMetadata properties (merged at the top level)
            if record.get('dm'):
                dm_props = dict(record['dm'])
                drift_config.update(dm_props)
                logger.info("DriftMetadata node found")

            # Extract DriftConfiguration properties
            if record.get('dc'):
                dc_props = dict(record['dc'])
                drift_config['configuration'] = dc_props
                logger.info("DriftConfiguration node found")

            # Extract CommunitySearchIndex properties
            if record.get('csi'):
                csi_props = dict(record['csi'])
                drift_config['community_search_index'] = csi_props
                logger.info("CommunitySearchIndex node found")

            # Extract GraphMetadata properties
            if record.get('gm'):
                gm_props = dict(record['gm'])
                drift_config['graph_metadata'] = gm_props
                logger.info("GraphMetadata node found")

            # Extract CommunitiesMetadata properties
            if record.get('cm'):
                cm_props = dict(record['cm'])
                drift_config['communities_metadata'] = cm_props
                logger.info("CommunitiesMetadata node found")

            self.drift_config = drift_config
            logger.info("Search configuration loaded from Neo4j nodes")

        else:
            logger.warning("No metadata nodes found in Neo4j")
            self.drift_config = {}

    def validate_system_readiness(self):
        """Validate all required components are initialized.

        Returns True only when both the Neo4j driver and the vector query
        engine are available; missing graph statistics only warn.
        """
        ready = True

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.error("Neo4j connection not available")
            ready = False

        if not self.query_engine:
            logger.error("Vector query engine not available")
            ready = False

        if not self.graph_stats:
            logger.warning("Graph statistics not loaded")

        if ready:
            logger.info("System readiness validated")

        return ready

    def get_system_status(self):
        """Get detailed system status information as a plain dict."""
        return {
            "neo4j_connected": bool(self.neo4j_conn and self.neo4j_conn.driver),
            "vector_engine_ready": bool(self.query_engine),
            "graph_stats_loaded": bool(self.graph_stats),
            "drift_config_loaded": bool(self.drift_config),
            "llm_ready": bool(self.llm),
            "graph_stats": self.graph_stats,
            "drift_config": self.drift_config
        }

    async def cleanup_async_tasks(self, timeout: float = 2.0) -> None:
        """
        Clean up async tasks and pending operations.

        Handles proper cleanup of LiteLLM and other async tasks to prevent
        'Task was destroyed but it is pending!' warnings.

        Args:
            timeout: Seconds to wait for tasks to finish cancelling.
        """
        try:
            import asyncio

            # Prefer the project-provided cleanup helper if installed.
            try:
                from litellm_patch import cleanup_all_async_tasks
                await cleanup_all_async_tasks(timeout=timeout)
                logger.info(f"Cleaned up async tasks with timeout {timeout}s")
            except ImportError:
                # Fallback: Cancel pending tasks manually.
                pending_tasks = [task for task in asyncio.all_tasks() if not task.done()]
                if pending_tasks:
                    logger.info(f"Cancelling {len(pending_tasks)} pending tasks")
                    for task in pending_tasks:
                        task.cancel()

                    # Wait for cancellation with timeout; stragglers are
                    # logged but not re-raised.
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*pending_tasks, return_exceptions=True),
                            timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Some tasks did not complete within timeout")

        except Exception as e:
            logger.error(f"Error during async cleanup: {e}")

    def close(self):
        """Clean up all connections"""
        if self.neo4j_conn:
            self.neo4j_conn.disconnect()
        logger.info("Setup cleanup complete")
353
+
354
+
355
def create_graphrag_setup():
    """Build and return a fully initialized GraphRAGSetup instance."""
    setup_instance = GraphRAGSetup()
    return setup_instance
358
+
359
+
360
+ # Exports
361
+ __all__ = ['GraphRAGSetup', 'create_graphrag_setup']
query_graph_functions/vector_augmentation.py ADDED
@@ -0,0 +1,262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Vector augmentation engine implementing Phase E (Steps 13-14).
3
+
4
+ Handles vector search operations and result fusion:
5
+ - Vector similarity search for additional context (Step 13)
6
+ - Result fusion strategy for enhanced answers (Step 14)
7
+ """
8
+
9
+ import logging
10
+ import numpy as np
11
+ from typing import Dict, List, Any, Tuple
12
+ from dataclasses import dataclass
13
+ from datetime import datetime
14
+
15
+ from .setup import GraphRAGSetup
16
+ from .query_preprocessing import DriftRoutingResult
17
+
18
+
19
@dataclass
class VectorSearchResult:
    """Single document hit from vector similarity search.

    Produced by VectorAugmentationEngine._perform_vector_search; only hits
    whose similarity_score meets the engine's threshold are kept.
    """
    document_id: str  # 'doc_id' from node metadata, or synthetic "doc_<i>" fallback
    content: str  # extracted node text (full content, not truncated)
    similarity_score: float  # score reported by the vector engine (0.8 default if absent)
    metadata: Dict[str, Any]  # node metadata / extra_info passed through as-is
    source_type: str  # 'vector_db', 'semantic_search'
    relevance_score: float  # similarity_score * 0.9 (vector hits weighted slightly down)
+
29
+
30
+ @dataclass
31
+ class AugmentationResult:
32
+ """Phase E augmentation result with enhanced context."""
33
+ vector_results: List[VectorSearchResult]
34
+ enhanced_context: str
35
+ fusion_strategy: str
36
+ augmentation_confidence: float
37
+ execution_time: float
38
+ metadata: Dict[str, Any]
39
+
40
+
41
+ class VectorAugmentationEngine:
42
+ def __init__(self, setup: GraphRAGSetup):
43
+ self.setup = setup
44
+ self.vector_engine = setup.query_engine # Milvus vector engine
45
+ self.embedding_model = setup.embedding_model
46
+ self.config = setup.config
47
+ self.logger = logging.getLogger(self.__class__.__name__)
48
+
49
+ # Vector search parameters
50
+ self.similarity_threshold = 0.75
51
+ self.max_vector_results = 10
52
+
53
+ async def execute_vector_augmentation_phase(self,
54
+ query_embedding: List[float],
55
+ graph_results: Dict[str, Any],
56
+ routing_result: DriftRoutingResult) -> AugmentationResult:
57
+ """
58
+ Execute vector augmentation phase with similarity search.
59
+
60
+ Args:
61
+ query_embedding: Query vector for similarity matching
62
+ graph_results: Results from graph-based search
63
+ routing_result: Routing decision parameters
64
+
65
+ Returns:
66
+ Augmentation results with vector context
67
+ """
68
+ start_time = datetime.now()
69
+
70
+ try:
71
+ # Step 13: Vector Similarity Search
72
+ self.logger.info("Starting Step 13: Vector Similarity Search")
73
+ vector_results = await self._perform_vector_search(
74
+ query_embedding, routing_result
75
+ )
76
+
77
+ # Step 14: Result Fusion and Enhancement
78
+ self.logger.info("Starting Step 14: Result Fusion and Enhancement")
79
+ enhanced_context = await self._fuse_results(
80
+ vector_results, graph_results, routing_result
81
+ )
82
+
83
+ execution_time = (datetime.now() - start_time).total_seconds()
84
+
85
+ augmentation_result = AugmentationResult(
86
+ vector_results=vector_results,
87
+ enhanced_context=enhanced_context,
88
+ fusion_strategy='graph_vector_hybrid',
89
+ augmentation_confidence=self._calculate_augmentation_confidence(vector_results),
90
+ execution_time=execution_time,
91
+ metadata={
92
+ 'vector_results_count': len(vector_results),
93
+ 'avg_similarity': np.mean([r.similarity_score for r in vector_results]) if vector_results else 0,
94
+ 'phase': 'vector_augmentation',
95
+ 'step_range': '13-14'
96
+ }
97
+ )
98
+
99
+ self.logger.info(f"Phase E completed: {len(vector_results)} vector results, augmentation confidence: {augmentation_result.augmentation_confidence:.3f}")
100
+ return augmentation_result
101
+
102
+ except Exception as e:
103
+ self.logger.error(f"Vector augmentation phase failed: {e}")
104
+ # Return empty augmentation on failure
105
+ return AugmentationResult(
106
+ vector_results=[],
107
+ enhanced_context="",
108
+ fusion_strategy='graph_only',
109
+ augmentation_confidence=0.0,
110
+ execution_time=(datetime.now() - start_time).total_seconds(),
111
+ metadata={'error': str(e), 'fallback': True}
112
+ )
113
+
114
+ async def _perform_vector_search(self,
115
+ query_embedding: List[float],
116
+ routing_result: DriftRoutingResult) -> List[VectorSearchResult]:
117
+ """
118
+ Step 13: Perform comprehensive vector similarity search.
119
+
120
+ Uses the Milvus vector database to find semantically similar content.
121
+ """
122
+ try:
123
+ vector_results = []
124
+
125
+ # Use the existing vector query engine for similarity search
126
+ if self.vector_engine:
127
+ # Query the vector database with the embedding
128
+ search_results = self.vector_engine.query(routing_result.original_query)
129
+
130
+ # Extract vector search results from the response
131
+ if hasattr(search_results, 'source_nodes') and search_results.source_nodes:
132
+ for i, node in enumerate(search_results.source_nodes[:self.max_vector_results]):
133
+ # Calculate similarity score (handle different node types)
134
+ similarity_score = 0.8 # Default similarity
135
+ if hasattr(node, 'score'):
136
+ similarity_score = node.score
137
+ elif hasattr(node, 'similarity'):
138
+ similarity_score = node.similarity
139
+ elif hasattr(node, 'metadata') and 'score' in node.metadata:
140
+ similarity_score = node.metadata['score']
141
+
142
+ # Extract content (handle different node types)
143
+ content = ""
144
+ if hasattr(node, 'text'):
145
+ content = node.text
146
+ elif hasattr(node, 'content'):
147
+ content = node.content
148
+ elif hasattr(node, 'get_content'):
149
+ content = node.get_content()
150
+ else:
151
+ content = str(node)
152
+
153
+ # Extract metadata safely
154
+ node_metadata = {}
155
+ if hasattr(node, 'metadata') and node.metadata:
156
+ node_metadata = node.metadata
157
+ elif hasattr(node, 'extra_info') and node.extra_info:
158
+ node_metadata = node.extra_info
159
+
160
+ vector_result = VectorSearchResult(
161
+ document_id=node_metadata.get('doc_id', f"doc_{i}"),
162
+ content=content,
163
+ similarity_score=similarity_score,
164
+ metadata=node_metadata,
165
+ source_type='vector_db',
166
+ relevance_score=similarity_score * 0.9 # Slightly weighted down
167
+ )
168
+
169
+ # Only include results above similarity threshold
170
+ if similarity_score >= self.similarity_threshold:
171
+ vector_results.append(vector_result)
172
+
173
+ self.logger.info(f"Vector search completed: {len(vector_results)} results above threshold {self.similarity_threshold}")
174
+ else:
175
+ self.logger.warning("Vector engine not available, skipping vector search")
176
+
177
+ return vector_results
178
+
179
+ except Exception as e:
180
+ self.logger.error(f"Vector search failed: {e}")
181
+ return []
182
+
183
+ async def _fuse_results(self,
184
+ vector_results: List[VectorSearchResult],
185
+ graph_results: Dict[str, Any],
186
+ routing_result: DriftRoutingResult) -> str:
187
+ """
188
+ Step 14: Fuse vector and graph results for enhanced context.
189
+
190
+ Combines graph-based entity relationships with vector similarity content.
191
+ """
192
+ try:
193
+ fusion_parts = []
194
+
195
+ # Start with graph-based context (Phase C & D results)
196
+ if 'initial_answer' in graph_results:
197
+ initial_answer = graph_results['initial_answer']
198
+ if isinstance(initial_answer, dict) and 'content' in initial_answer:
199
+ fusion_parts.extend([
200
+ "=== GRAPH-BASED KNOWLEDGE ===",
201
+ initial_answer['content'],
202
+ ""
203
+ ])
204
+
205
+ # Add vector-based augmentation
206
+ if vector_results:
207
+ fusion_parts.extend([
208
+ "=== SEMANTIC AUGMENTATION ===",
209
+ "Additional relevant information from vector similarity search:",
210
+ ""
211
+ ])
212
+
213
+ for i, result in enumerate(vector_results[:5], 1): # Top 5 vector results
214
+ fusion_parts.extend([
215
+ f"**{i}. Vector Result (Similarity: {result.similarity_score:.3f})**",
216
+ result.content, # Show full content without truncation
217
+ ""
218
+ ])
219
+
220
+ # Add fusion methodology explanation
221
+ fusion_parts.extend([
222
+ "=== FUSION METHODOLOGY ===",
223
+ "This enhanced answer combines graph-based entity relationships with vector semantic similarity search.",
224
+ "Graph results provide structured knowledge connections, while vector search adds contextual depth.",
225
+ ""
226
+ ])
227
+
228
+ enhanced_context = "\n".join(fusion_parts)
229
+
230
+ self.logger.info(f"Result fusion completed: {len(fusion_parts)} context sections")
231
+ return enhanced_context
232
+
233
+ except Exception as e:
234
+ self.logger.error(f"Result fusion failed: {e}")
235
+ return "Graph-based results only (vector fusion failed)"
236
+
237
+ def _calculate_augmentation_confidence(self, vector_results: List[VectorSearchResult]) -> float:
238
+ """Calculate confidence score for the augmentation results."""
239
+ if not vector_results:
240
+ return 0.0
241
+
242
+ # Base confidence on average similarity and result count
243
+ avg_similarity = np.mean([r.similarity_score for r in vector_results])
244
+ count_factor = min(len(vector_results) / 10, 1.0) # Normalize to max 10 results
245
+
246
+ # Combined confidence
247
+ confidence = (avg_similarity * 0.7) + (count_factor * 0.3)
248
+
249
+ return min(confidence, 1.0)
250
+
251
+ def get_augmentation_stats(self) -> Dict[str, Any]:
252
+ """Get statistics about vector augmentation performance."""
253
+ return {
254
+ 'similarity_threshold': self.similarity_threshold,
255
+ 'max_vector_results': self.max_vector_results,
256
+ 'vector_engine_ready': bool(self.vector_engine),
257
+ 'embedding_model': str(self.embedding_model) if self.embedding_model else None
258
+ }
259
+
260
+
261
+ # Export main class
262
+ __all__ = ['VectorAugmentationEngine', 'VectorSearchResult', 'AugmentationResult']