Spaces:
Runtime error
Runtime error
Upload 8 files
Browse files- query_graph_functions/__init__.py +112 -0
- query_graph_functions/answer_synthesis.py +408 -0
- query_graph_functions/follow_up_search.py +429 -0
- query_graph_functions/knowledge_retrieval.py +843 -0
- query_graph_functions/query_preprocessing.py +592 -0
- query_graph_functions/response_management.py +259 -0
- query_graph_functions/setup.py +361 -0
- query_graph_functions/vector_augmentation.py +262 -0
query_graph_functions/__init__.py
ADDED
|
@@ -0,0 +1,112 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Query Graph Functions Package
|
| 3 |
+
|
| 4 |
+
Core modules for graph-based retrieval augmentation implementation.
|
| 5 |
+
|
| 6 |
+
Package contents:
|
| 7 |
+
- setup.py: Initialization and connection functionality (Phase A: Steps 1-2)
|
| 8 |
+
- query_preprocessing.py: Query analysis, routing, and vectorization (Phase B: Steps 3-5)
|
| 9 |
+
- knowledge_retrieval.py: Community search and data extraction (Phase C: Steps 6-8)
|
| 10 |
+
- follow_up_search.py: Follow-up search and entity extraction (Phase D: Steps 9-12)
|
| 11 |
+
- vector_augmentation.py: Vector search enhancement (Phase E: Steps 13-14)
|
| 12 |
+
- answer_synthesis.py: Final answer generation (Phase F: Steps 15-16)
|
| 13 |
+
- response_management.py: Metadata generation and file persistence (Phase G: Steps 17-20)
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
# Phase A: Initialization (Steps 1-2)
|
| 17 |
+
from .setup import GraphRAGSetup, create_graphrag_setup
|
| 18 |
+
|
| 19 |
+
# Phase B: Query Preprocessing (Steps 3-5)
|
| 20 |
+
from .query_preprocessing import (
|
| 21 |
+
QueryAnalyzer,
|
| 22 |
+
DriftRouter,
|
| 23 |
+
QueryVectorizer,
|
| 24 |
+
QueryPreprocessor,
|
| 25 |
+
create_query_preprocessor,
|
| 26 |
+
preprocess_query_pipeline,
|
| 27 |
+
QueryAnalysis,
|
| 28 |
+
DriftRoutingResult,
|
| 29 |
+
VectorizedQuery,
|
| 30 |
+
QueryType,
|
| 31 |
+
SearchStrategy
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
# Phase C: Knowledge Retrieval (Steps 6-8)
|
| 35 |
+
from .knowledge_retrieval import (
|
| 36 |
+
CommunitySearchEngine,
|
| 37 |
+
CommunityResult,
|
| 38 |
+
EntityResult,
|
| 39 |
+
RelationshipResult
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Phase D: Follow-up Search (Steps 9-12)
|
| 43 |
+
from .follow_up_search import (
|
| 44 |
+
FollowUpSearch,
|
| 45 |
+
FollowUpQuestion,
|
| 46 |
+
LocalSearchResult,
|
| 47 |
+
IntermediateAnswer
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
# Phase E: Vector Search Augmentation (Steps 13-14)
|
| 51 |
+
from .vector_augmentation import (
|
| 52 |
+
VectorAugmentationEngine,
|
| 53 |
+
VectorSearchResult,
|
| 54 |
+
AugmentationResult
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
# Phase F: Answer Synthesis (Steps 15-16)
|
| 58 |
+
from .answer_synthesis import (
|
| 59 |
+
AnswerSynthesisEngine,
|
| 60 |
+
SynthesisResult,
|
| 61 |
+
SourceEvidence
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
# Phase G: Response Management (Steps 17-20)
|
| 65 |
+
from .response_management import (
|
| 66 |
+
ResponseManager,
|
| 67 |
+
ResponseMetadata
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
__version__ = "1.3.0"
__author__ = "AllyCat GraphRAG Team"
__description__ = "Graph-based retrieval augmentation implementation for AllyCat"

# Export main classes and functions
# NOTE: each phase section below must stay in sync with the corresponding
# `from .<module> import ...` statement above; add new public names in both places.
__all__ = [
    # Phase A: Initialization
    "GraphRAGSetup",
    "create_graphrag_setup",
    # Phase B: Query Preprocessing
    "QueryAnalyzer",
    "DriftRouter",
    "QueryVectorizer",
    "QueryPreprocessor",
    "create_query_preprocessor",
    "preprocess_query_pipeline",
    "QueryAnalysis",
    "DriftRoutingResult",
    "VectorizedQuery",
    "QueryType",
    "SearchStrategy",
    # Phase C: Knowledge Retrieval
    "CommunitySearchEngine",
    "CommunityResult",
    "EntityResult",
    "RelationshipResult",
    # Phase D: Follow-up Search
    "FollowUpSearch",
    "FollowUpQuestion",
    "LocalSearchResult",
    "IntermediateAnswer",
    # Phase E: Vector Augmentation
    "VectorAugmentationEngine",
    "VectorSearchResult",
    "AugmentationResult",
    # Phase F: Answer Synthesis
    "AnswerSynthesisEngine",
    "SynthesisResult",
    "SourceEvidence",
    # Phase G: Response Management
    "ResponseManager",
    "ResponseMetadata"
]
|
query_graph_functions/answer_synthesis.py
ADDED
|
@@ -0,0 +1,408 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Answer synthesis module for final response generation. - Phase F (Steps 15-16)"""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import json
|
| 5 |
+
from typing import Dict, List, Any, Optional
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
from .setup import GraphRAGSetup
|
| 10 |
+
from .query_preprocessing import DriftRoutingResult, QueryAnalysis
|
| 11 |
+
from .vector_augmentation import AugmentationResult
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@dataclass
class SourceEvidence:
    """Evidence source with attribution and confidence.

    One piece of retrieved context carried into Phase F synthesis, tagged
    with the phase and retrieval mechanism that produced it.
    """
    source_type: str   # 'community', 'entity', 'relationship', 'vector_doc'
    source_id: str     # identifier of the originating record (or a synthetic id)
    content: str       # evidence text that is fed into the synthesis prompt
    confidence: float  # relevance/similarity score used to rank evidence
    phase: str         # originating pipeline phase: 'C', 'D', or 'E'
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class SynthesisResult:
    """Phase F synthesis result with comprehensive answer.

    Produced by AnswerSynthesisEngine; on failure a fallback instance with
    confidence_score 0.0 and strategy 'fallback' is returned instead.
    """
    final_answer: str                      # formatted markdown answer text
    confidence_score: float                # combined synthesis confidence in [0.0, 1.0]
    source_evidence: List[SourceEvidence]  # ranked evidence used for the answer
    synthesis_strategy: str                # 'comprehensive_drift' or 'fallback'
    coverage_assessment: Dict[str, float]  # coverage ratios per evidence dimension
    execution_time: float                  # wall-clock seconds spent in Phase F
    metadata: Dict[str, Any]               # diagnostics (source counts, phase tags, errors)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AnswerSynthesisEngine:
    """
    Answer synthesis engine implementing Phase F (Steps 15-16).

    Handles the final answer generation process:
    - Context assembly and evidence ranking (Step 15)
    - Final answer generation with confidence scoring (Step 16)

    All public entry points degrade gracefully: failures produce fallback
    answers/results rather than propagating exceptions to callers.
    """

    def __init__(self, setup: GraphRAGSetup):
        """Bind shared resources from the GraphRAG setup object.

        Args:
            setup: Initialized GraphRAGSetup exposing `llm` and `config`.
        """
        self.setup = setup
        self.llm = setup.llm
        self.config = setup.config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Synthesis parameters. NOTE(review): neither value is consumed by the
        # methods in this class -- presumably read by callers or reserved for
        # future quality gating; confirm before removing.
        self.min_confidence_threshold = 0.7
        self.max_synthesis_length = 2000

    async def execute_answer_synthesis_phase(self,
                                             analysis: QueryAnalysis,
                                             routing: DriftRoutingResult,
                                             community_results: Dict[str, Any],
                                             follow_up_results: Dict[str, Any],
                                             augmentation_results: AugmentationResult) -> SynthesisResult:
        """
        Execute answer synthesis phase with comprehensive integration.

        Args:
            analysis: Query analysis results
            routing: Routing decision parameters
            community_results: Community search results (Phase C)
            follow_up_results: Follow-up search results (Phase D)
            augmentation_results: Vector augmentation results (Phase E)

        Returns:
            Synthesis result with final answer. On any internal failure a
            zero-confidence fallback SynthesisResult is returned (never raises).
        """
        start_time = datetime.now()

        try:
            # Step 15: gather evidence from Phases C/D/E and rank by confidence.
            self.logger.info("Starting Step 15: Context Assembly and Ranking")
            assembled_context = await self._assemble_and_rank_context(
                analysis, community_results, follow_up_results, augmentation_results
            )

            # Step 16: LLM-guided synthesis over the ranked evidence.
            self.logger.info("Starting Step 16: Final Answer Generation")
            final_answer, confidence = await self._generate_final_answer(
                analysis, routing, assembled_context
            )

            execution_time = (datetime.now() - start_time).total_seconds()

            synthesis_result = SynthesisResult(
                final_answer=final_answer,
                confidence_score=confidence,
                source_evidence=assembled_context['evidence'],
                synthesis_strategy='comprehensive_drift',
                coverage_assessment=assembled_context['coverage'],
                execution_time=execution_time,
                metadata={
                    'sources_integrated': len(assembled_context['evidence']),
                    'phase_coverage': assembled_context['phase_coverage'],
                    'synthesis_method': 'llm_guided',
                    'phase': 'answer_synthesis',
                    'step_range': '15-16'
                }
            )

            self.logger.info(f"Phase F completed: confidence {confidence:.3f}, {len(assembled_context['evidence'])} sources integrated")
            return synthesis_result

        except Exception as e:
            self.logger.error(f"Answer synthesis phase failed: {e}")
            # Degrade gracefully: return fallback synthesis instead of raising.
            return self._create_fallback_synthesis(
                community_results, follow_up_results,
                (datetime.now() - start_time).total_seconds(), str(e)
            )

    async def _assemble_and_rank_context(self,
                                         analysis: QueryAnalysis,
                                         community_results: Dict[str, Any],
                                         follow_up_results: Dict[str, Any],
                                         augmentation_results: AugmentationResult) -> Dict[str, Any]:
        """
        Step 15: Assemble and rank all context from Phases C, D, and E.

        Prioritizes information by relevance, confidence, and source diversity.

        Returns:
            Dict with keys:
                'evidence' (top 15 SourceEvidence, sorted by confidence desc),
                'coverage' (per-dimension ratios plus overall confidence),
                'phase_coverage' (evidence counts per phase).
        """
        evidence_sources = []

        # Phase C: community summaries become evidence items.
        if 'communities' in community_results:
            for community in community_results['communities']:
                evidence_sources.append(SourceEvidence(
                    source_type='community',
                    source_id=community.community_id,
                    content=community.summary,
                    confidence=community.similarity_score,
                    phase='C'
                ))

        # Phase D: each intermediate Q/A pair becomes one evidence item.
        if 'intermediate_answers' in follow_up_results:
            for answer in follow_up_results['intermediate_answers']:
                evidence_sources.append(SourceEvidence(
                    source_type='entity_search',
                    source_id=f"followup_{len(evidence_sources)}",
                    content=f"Q: {answer.question}\nA: {answer.answer}",
                    confidence=answer.confidence,
                    phase='D'
                ))

        # Phase E: semantically similar documents from vector augmentation.
        if augmentation_results and augmentation_results.vector_results:
            for i, vector_result in enumerate(augmentation_results.vector_results):
                evidence_sources.append(SourceEvidence(
                    source_type='vector_doc',
                    source_id=f"vector_{i}",
                    content=vector_result.content,
                    confidence=vector_result.similarity_score,
                    phase='E'
                ))

        # Rank evidence, best-first. Highest-confidence items drive the prompt.
        ranked_evidence = sorted(evidence_sources, key=lambda x: x.confidence, reverse=True)

        # Coverage ratios: fraction of each phase's candidates that survived as
        # evidence (max(1, ...) guards against division by zero on empty phases).
        coverage = {
            'community_coverage': len([e for e in ranked_evidence if e.phase == 'C']) / max(1, len(community_results.get('communities', []))),
            'entity_coverage': len([e for e in ranked_evidence if e.phase == 'D']) / max(1, len(follow_up_results.get('intermediate_answers', []))),
            'vector_coverage': len([e for e in ranked_evidence if e.phase == 'E']) / max(1, len(augmentation_results.vector_results) if augmentation_results else 1),
            'overall_confidence': sum(e.confidence for e in ranked_evidence) / max(1, len(ranked_evidence))
        }

        phase_coverage = {
            'phase_c': len([e for e in ranked_evidence if e.phase == 'C']),
            'phase_d': len([e for e in ranked_evidence if e.phase == 'D']),
            'phase_e': len([e for e in ranked_evidence if e.phase == 'E'])
        }

        return {
            'evidence': ranked_evidence[:15],  # Top 15 pieces of evidence
            'coverage': coverage,
            'phase_coverage': phase_coverage
        }

    async def _generate_final_answer(self,
                                     analysis: QueryAnalysis,
                                     routing: DriftRoutingResult,
                                     assembled_context: Dict[str, Any]) -> tuple[str, float]:
        """
        Step 16: Generate comprehensive final answer using LLM synthesis.

        Creates a structured, comprehensive response with proper source
        attribution. Falls back to raw evidence concatenation (confidence 0.5)
        if the LLM call or formatting fails.

        Returns:
            (formatted answer text, synthesis confidence in [0.0, 1.0])
        """
        try:
            # Build the prompt from the original query and ranked evidence.
            synthesis_prompt = self._create_synthesis_prompt(
                routing.original_query,
                assembled_context['evidence']
            )

            # Synchronous LLM call; response object is stringified per the
            # llama-index style `complete()` contract.
            response = self.llm.complete(synthesis_prompt)
            final_answer = str(response).strip()

            # Score the synthesis from evidence quality + coverage.
            synthesis_confidence = self._calculate_synthesis_confidence(
                assembled_context['evidence'], assembled_context['coverage']
            )

            # Wrap the raw LLM text in the standard presentation template.
            formatted_answer = self._format_final_answer(
                final_answer, assembled_context['evidence'], synthesis_confidence
            )

            return formatted_answer, synthesis_confidence

        except Exception as e:
            self.logger.error(f"Final answer generation failed: {e}")
            return self._create_fallback_answer(assembled_context['evidence']), 0.5

    def _create_synthesis_prompt(self, original_query: str, evidence: List[SourceEvidence]) -> str:
        """Create comprehensive synthesis prompt for the LLM.

        Includes at most the top 10 evidence items, each truncated to 500
        characters, followed by explicit synthesis instructions.
        """
        prompt_parts = [
            f"# Query: {original_query}",
            "",
            "You are an expert synthesizing information from multiple sources.",
            "Create a comprehensive, accurate answer using the following evidence:",
            "",
            "## Evidence Sources:",
            ""
        ]

        for i, source in enumerate(evidence[:10], 1):  # Top 10 sources
            prompt_parts.extend([
                f"### Source {i} ({source.phase} - {source.source_type}, confidence: {source.confidence:.3f})",
                source.content[:500] + ("..." if len(source.content) > 500 else ""),
                ""
            ])

        prompt_parts.extend([
            "## Instructions:",
            "1. Synthesize a comprehensive answer addressing the original query",
            "2. Prioritize high-confidence sources (>0.8)",
            "3. Include specific details and examples from the evidence",
            "4. Structure the response clearly with sections if appropriate",
            "5. Do not mention source IDs or technical details",
            "6. Focus on factual accuracy and completeness",
            "",
            "## Comprehensive Answer:"
        ])

        return "\n".join(prompt_parts)

    def _calculate_synthesis_confidence(self, evidence: List[SourceEvidence], coverage: Dict[str, float]) -> float:
        """Calculate overall synthesis confidence from evidence quality and coverage.

        Weighted blend: 50% mean evidence confidence, 30% mean coverage ratio,
        20% phase diversity (fraction of the 3 phases represented). Clamped to 1.0.
        Returns 0.0 when there is no evidence.
        """
        if not evidence:
            return 0.0

        # Mean evidence confidence and mean coverage ratio.
        evidence_confidence = sum(e.confidence for e in evidence) / len(evidence)
        coverage_score = sum(coverage.values()) / len(coverage)

        # Diversity bonus: reward answers grounded in multiple phases.
        phase_diversity = len(set(e.phase for e in evidence)) / 3.0  # 3 phases max

        # Combined weighted score.
        synthesis_confidence = (evidence_confidence * 0.5) + (coverage_score * 0.3) + (phase_diversity * 0.2)

        return min(synthesis_confidence, 1.0)

    def _format_final_answer(self, answer: str, evidence: List[SourceEvidence], confidence: float) -> str:
        """Format the final answer with the standard header/footer template.

        Appends a metadata footer (confidence, source count, phase coverage)
        below a horizontal rule.
        """
        formatted_parts = [
            "# Comprehensive Answer",
            "",
            answer,
            "",
            "---",
            "",
            f"**Answer Confidence**: {confidence:.1%}",
            f"**Sources Integrated**: {len(evidence)} evidence sources",
            f"**Multi-Phase Coverage**: {len(set(e.phase for e in evidence))} phases (C: Community, D: Entity, E: Vector)",
            ""
        ]

        return "\n".join(formatted_parts)

    def _create_fallback_answer(self, evidence: List[SourceEvidence]) -> str:
        """Create fallback answer when LLM synthesis fails.

        Concatenates up to the top 3 evidence items (each truncated to 300
        characters) into a minimal summary.
        """
        if not evidence:
            return "Unable to generate answer due to insufficient evidence."

        # Simple concatenation of top evidence.
        fallback_parts = [
            "# Answer Summary",
            "",
            "Based on available evidence:",
            ""
        ]

        for i, source in enumerate(evidence[:3], 1):
            fallback_parts.extend([
                f"## Source {i} (Confidence: {source.confidence:.2f})",
                source.content[:300] + ("..." if len(source.content) > 300 else ""),
                ""
            ])

        return "\n".join(fallback_parts)

    def _create_fallback_synthesis(self, community_results: Dict, follow_up_results: Dict,
                                   execution_time: float, error: str) -> SynthesisResult:
        """Create fallback synthesis result when the whole phase fails.

        The original error text is preserved in metadata for diagnostics;
        the user-facing message is generic.
        """
        return SynthesisResult(
            # BUGFIX: removed stray leading space from the user-facing message.
            final_answer="Response failed due to technical error. Please try again.",
            confidence_score=0.0,
            source_evidence=[],
            synthesis_strategy='fallback',
            coverage_assessment={'overall_confidence': 0.0},
            execution_time=execution_time,
            metadata={'error': error, 'fallback': True}
        )

    def combine_phase_results(self,
                              phase_c_answer: str,
                              follow_up_results: Dict[str, Any],
                              augmentation_results=None) -> str:
        """
        Combine Phase C, D, and E results into an enhanced markdown answer.

        Creates a comprehensive response by integrating results from multiple
        phases. Returns the Phase C answer unchanged when there are no
        intermediate answers or when combination fails.
        """
        try:
            intermediate_answers = follow_up_results.get('intermediate_answers', [])

            if not intermediate_answers:
                return phase_c_answer

            # Start with the Phase C (global/community) answer.
            enhanced_parts = [
                "## Global Context (Phase C)",
                phase_c_answer.strip(),
                "",
                "## Detailed Information (Phase D)"
            ]

            # Append each Phase D intermediate Q/A with its confidence.
            for i, answer in enumerate(intermediate_answers, 1):
                enhanced_parts.extend([
                    f"**{i}. {answer.question}**",
                    answer.answer,
                    f"*Confidence: {answer.confidence:.2f}*",
                    ""
                ])

            # Append Phase E vector augmentation if available. hasattr() guards
            # against callers passing a plain dict or stub object.
            if augmentation_results and hasattr(augmentation_results, 'vector_results') and augmentation_results.vector_results:
                enhanced_parts.extend([
                    "## Vector Augmentation (Phase E)",
                    f"**Semantic Enhancement** (Confidence: {augmentation_results.augmentation_confidence:.2f})",
                    ""
                ])

                # Top 3 vector results only.
                for i, vector_result in enumerate(augmentation_results.vector_results[:3], 1):
                    enhanced_parts.extend([
                        f"**Vector Result {i}** (Similarity: {vector_result.similarity_score:.3f})",
                        vector_result.content,  # Show full content without truncation
                        ""
                    ])

            # Supporting evidence: key entities across all intermediate answers.
            if intermediate_answers:
                enhanced_parts.extend([
                    "## Supporting Evidence",
                    # BUGFIX: sort the de-duplicated entities so output is
                    # deterministic (joining a raw set gave arbitrary order).
                    "**Key Entities Found:** " + ", ".join(sorted(
                        set(entity for answer in intermediate_answers
                            for entity in answer.supporting_entities[:3])
                    )),
                    ""
                ])

            return "\n".join(enhanced_parts)

        except Exception as e:
            self.logger.error(f"Failed to combine phase results: {e}")
            return phase_c_answer

    def generate_error_response(self, error_message: str) -> Dict[str, Any]:
        """
        Generate standardized error response.

        Creates a consistent error format for failed synthesis operations.

        Args:
            error_message: Human-readable description of the failure.

        Returns:
            Dict with a user-facing 'answer' and diagnostic 'metadata'.
        """
        return {
            "answer": f"Sorry, I encountered an error during answer synthesis: {error_message}",
            "metadata": {
                "status": "synthesis_error",
                "error_message": error_message,
                "synthesis_stage": "failed",
                "confidence_score": 0.0,
                "timestamp": datetime.now().isoformat()
            }
        }
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
# Exports: public API of this module (mirrored by the package __init__).
__all__ = ['AnswerSynthesisEngine', 'SynthesisResult', 'SourceEvidence']
|
query_graph_functions/follow_up_search.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Follow-up search module for local graph traversal. - Phase D (Steps 9-12)"""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, List, Any
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
import re
|
| 7 |
+
from datetime import datetime
|
| 8 |
+
|
| 9 |
+
# Project imports
|
| 10 |
+
from .setup import GraphRAGSetup
|
| 11 |
+
from .query_preprocessing import DriftRoutingResult
|
| 12 |
+
from .knowledge_retrieval import CommunityResult, EntityResult, RelationshipResult
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@dataclass
class FollowUpQuestion:
    """Represents a follow-up question from Phase C.

    Built in Step 9 from the raw question strings produced by the
    community-search phase.
    """
    question: str                  # raw follow-up question text
    question_id: int               # 1-based position in the question list
    extracted_entities: List[str]  # up to 3 capitalized keyword candidates mined from the text
    query_type: str                # currently always 'search'
    confidence: float              # fixed prior of 0.8 assigned in Step 9
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
@dataclass
class LocalSearchResult:
    """Results from local graph traversal (Step 10).

    NOTE(review): produced by `_multi_hop_traversal`, which is not fully
    visible here -- field semantics below are inferred from names; confirm
    against the traversal implementation.
    """
    seed_entities: List[EntityResult]                  # entry-point entities for the traversal
    traversed_entities: List[EntityResult]             # entities reached during traversal
    traversed_relationships: List[RelationshipResult]  # relationships encountered along the way
    search_depth: int                                  # hop depth used for this search
    total_nodes_visited: int                           # count of graph nodes visited
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@dataclass
class IntermediateAnswer:
    """Intermediate answer for a follow-up question (Step 12).

    Consumed by Phase F, which renders each item as a Q/A evidence pair and
    reads `supporting_entities` for the evidence summary.
    """
    question_id: int                # matches FollowUpQuestion.question_id
    question: str                   # the follow-up question being answered
    answer: str                     # generated answer text
    confidence: float               # answer confidence score
    reasoning: str                  # explanation of how the answer was derived
    supporting_entities: List[str]  # entity names backing the answer
    supporting_evidence: List[str]  # evidence snippets backing the answer
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class FollowUpSearch:
|
| 48 |
+
"""Follow-up search module for local graph traversal."""
|
| 49 |
+
def __init__(self, setup: GraphRAGSetup):
    """Bind shared resources and set traversal limits.

    Args:
        setup: Initialized GraphRAGSetup providing the Neo4j connection.
    """
    self.setup = setup
    self.neo4j_conn = setup.neo4j_conn
    self.logger = logging.getLogger(__name__)

    # Configuration: traversal limits and confidence floors. NOTE(review):
    # these are presumably consumed by the traversal/extraction helpers
    # (not all visible here) -- confirm before tuning.
    self.max_traversal_depth = 2            # maximum hops from seed entities
    self.max_entities_per_hop = 20          # breadth cap per traversal hop
    self.min_entity_confidence = 0.7        # floor for keeping entities
    self.min_relationship_confidence = 0.6  # floor for keeping relationships
|
| 59 |
+
|
| 60 |
+
async def execute_follow_up_phase(self,
                                  phase_c_results: Dict[str, Any],
                                  routing_result: DriftRoutingResult) -> Dict[str, Any]:
    """
    Execute follow-up search pipeline based on initial results.

    Runs Steps 9-12 strictly in sequence; each step's output feeds the next:
    question processing -> local traversal -> entity extraction -> answers.

    Args:
        phase_c_results: Results from community search with follow-up questions
        routing_result: Routing configuration parameters

    Returns:
        Dictionary with intermediate answers and entity information. On any
        failure, returns {'error': <message>, 'intermediate_answers': []}
        instead of raising.
    """
    try:
        self.logger.info("Starting Follow-up Search (Steps 9-12)")

        # Step 9: wrap raw question strings into FollowUpQuestion objects.
        # Questions live under initial_answer.follow_up_questions; the chained
        # .get() calls default to [] so a missing key yields zero questions.
        self.logger.info("Starting Step 9: Follow-up Question Processing")
        follow_up_questions = await self._process_follow_up_questions(
            phase_c_results.get('initial_answer', {}).get('follow_up_questions', []),
            routing_result
        )
        self.logger.info(f"Step 9 completed: {len(follow_up_questions)} questions processed")

        # Step 10: multi-hop traversal seeded from each question's keywords.
        self.logger.info("Starting Step 10: Local Graph Traversal")
        local_search_results = await self._execute_local_traversal(
            follow_up_questions,
            phase_c_results.get('communities', []),
            routing_result
        )
        self.logger.info(f"Step 10 completed: {len(local_search_results)} searches performed")

        # Step 11: extract detailed entity information from the traversals.
        self.logger.info("Starting Step 11: Detailed Entity Extraction")
        detailed_entities = await self._extract_detailed_entities(
            local_search_results,
            routing_result
        )
        self.logger.info(f"Step 11 completed: {len(detailed_entities)} detailed entities extracted")

        # Step 12: answer each follow-up question from the gathered context.
        self.logger.info("Starting Step 12: Intermediate Answer Generation")
        intermediate_answers = await self._generate_intermediate_answers(
            follow_up_questions,
            local_search_results,
            detailed_entities,
            routing_result
        )
        self.logger.info(f"Step 12 completed: {len(intermediate_answers)} intermediate answers generated")

        # Compile all step outputs plus execution stats for downstream phases.
        phase_d_results = {
            'follow_up_questions': follow_up_questions,
            'local_search_results': local_search_results,
            'detailed_entities': detailed_entities,
            'intermediate_answers': intermediate_answers,
            'execution_stats': {
                'questions_processed': len(follow_up_questions),
                'local_searches_executed': len(local_search_results),
                'entities_extracted': len(detailed_entities),
                'answers_generated': len(intermediate_answers),
                'timestamp': datetime.now().isoformat()
            }
        }

        self.logger.info(f"Phase D completed: {len(intermediate_answers)} detailed answers generated")
        return phase_d_results

    except Exception as e:
        self.logger.error(f"Phase D execution failed: {e}")
        # Graceful degradation: callers always receive the
        # 'intermediate_answers' key, even on failure.
        return {'error': str(e), 'intermediate_answers': []}
|
| 132 |
+
|
| 133 |
+
async def _process_follow_up_questions(self,
|
| 134 |
+
questions: List[str],
|
| 135 |
+
routing_result: DriftRoutingResult) -> List[FollowUpQuestion]:
|
| 136 |
+
"""Simple: just wrap questions in FollowUpQuestion objects."""
|
| 137 |
+
processed_questions = []
|
| 138 |
+
|
| 139 |
+
for i, question in enumerate(questions):
|
| 140 |
+
# Extract keywords
|
| 141 |
+
keywords = re.findall(r'\b[A-Z][a-z]+\b|\b[A-Z]{2,}\b', question)
|
| 142 |
+
keywords = [k for k in keywords if k not in ['What', 'Which', 'Who', 'How', 'Are', 'The']]
|
| 143 |
+
|
| 144 |
+
follow_up = FollowUpQuestion(
|
| 145 |
+
question=question,
|
| 146 |
+
question_id=i + 1,
|
| 147 |
+
extracted_entities=keywords[:3], # Top 3 keywords
|
| 148 |
+
query_type='search',
|
| 149 |
+
confidence=0.8
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
processed_questions.append(follow_up)
|
| 153 |
+
self.logger.info(f"Question {i+1}: {question} -> Keywords: {keywords[:3]}")
|
| 154 |
+
|
| 155 |
+
return processed_questions
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
    async def _execute_local_traversal(self,
                                       questions: List[FollowUpQuestion],
                                       communities: List[CommunityResult],
                                       routing_result: DriftRoutingResult) -> List[LocalSearchResult]:
        """
        Step 10: Execute local graph traversal for each follow-up question.

        Performs multi-hop traversal from seed entities to find detailed information.

        Args:
            questions: Follow-up questions with pre-extracted keyword entities.
            communities: Phase C communities, forwarded to the seed lookup.
            routing_result: Routing parameters forwarded to the traversal.

        Returns:
            One LocalSearchResult per question that yielded seed entities.
            Questions with no seeds, or whose traversal raised, are skipped
            (logged, not propagated), so the result list may be shorter than
            the question list.
        """
        local_results = []

        for question in questions:
            try:
                # Find seed entities matching this question's keywords.
                seed_entities = await self._find_seed_entities(
                    question.extracted_entities,
                    communities
                )

                if not seed_entities:
                    # No anchor in the graph: nothing to traverse for this question.
                    self.logger.warning(f"No seed entities found for question: {question.question}")
                    continue

                # Multi-hop traversal outward from the seeds.
                traversal_result = await self._multi_hop_traversal(
                    seed_entities,
                    question,
                    routing_result
                )

                local_results.append(traversal_result)
                self.logger.info(f" Traversal for Q{question.question_id}: {traversal_result.total_nodes_visited} nodes visited")

            except Exception as e:
                # Per-question isolation: one failed traversal must not abort the rest.
                self.logger.error(f"Local traversal failed for question {question.question_id}: {e}")

        return local_results
|
| 196 |
+
|
| 197 |
+
async def _find_seed_entities(self,
|
| 198 |
+
entity_names: List[str],
|
| 199 |
+
communities: List[CommunityResult]) -> List[EntityResult]:
|
| 200 |
+
"""Just search the graph for entities matching the keywords."""
|
| 201 |
+
if not entity_names:
|
| 202 |
+
return []
|
| 203 |
+
|
| 204 |
+
# Search query
|
| 205 |
+
conditions = " OR ".join([f"n.name CONTAINS '{name}'" for name in entity_names])
|
| 206 |
+
query = f"""
|
| 207 |
+
MATCH (n)
|
| 208 |
+
WHERE n.name IS NOT NULL AND ({conditions})
|
| 209 |
+
RETURN n.id as entity_id, n.name as name, n.content as content,
|
| 210 |
+
n.confidence as confidence,
|
| 211 |
+
n.degree_centrality as degree_centrality,
|
| 212 |
+
n.betweenness_centrality as betweenness_centrality,
|
| 213 |
+
n.closeness_centrality as closeness_centrality,
|
| 214 |
+
labels(n) as node_types
|
| 215 |
+
ORDER BY n.degree_centrality DESC
|
| 216 |
+
LIMIT 20
|
| 217 |
+
"""
|
| 218 |
+
|
| 219 |
+
try:
|
| 220 |
+
results = self.neo4j_conn.execute_query(query, {})
|
| 221 |
+
entities = []
|
| 222 |
+
|
| 223 |
+
for record in results:
|
| 224 |
+
entity = EntityResult(
|
| 225 |
+
entity_id=record['entity_id'],
|
| 226 |
+
name=record['name'],
|
| 227 |
+
content=record['content'],
|
| 228 |
+
confidence=record['confidence'],
|
| 229 |
+
degree_centrality=record['degree_centrality'],
|
| 230 |
+
betweenness_centrality=record['betweenness_centrality'],
|
| 231 |
+
closeness_centrality=record['closeness_centrality'],
|
| 232 |
+
# Set community info
|
| 233 |
+
community_id='found',
|
| 234 |
+
node_type=', '.join(record['node_types']) if record['node_types'] else 'Entity'
|
| 235 |
+
)
|
| 236 |
+
entities.append(entity)
|
| 237 |
+
|
| 238 |
+
return entities
|
| 239 |
+
|
| 240 |
+
except Exception as e:
|
| 241 |
+
self.logger.error(f"Search failed: {e}")
|
| 242 |
+
return []
|
| 243 |
+
|
| 244 |
+
async def _multi_hop_traversal(self,
|
| 245 |
+
seed_entities: List[EntityResult],
|
| 246 |
+
question: FollowUpQuestion,
|
| 247 |
+
routing_result: DriftRoutingResult) -> LocalSearchResult:
|
| 248 |
+
"""Execute multi-hop graph traversal from seed entities."""
|
| 249 |
+
|
| 250 |
+
all_entities = list(seed_entities)
|
| 251 |
+
all_relationships = []
|
| 252 |
+
visited_node_ids = {entity.entity_id for entity in seed_entities}
|
| 253 |
+
|
| 254 |
+
current_entities = seed_entities
|
| 255 |
+
|
| 256 |
+
for hop in range(self.max_traversal_depth):
|
| 257 |
+
if not current_entities:
|
| 258 |
+
break
|
| 259 |
+
|
| 260 |
+
# Get entity IDs for this hop
|
| 261 |
+
current_ids = [entity.entity_id for entity in current_entities]
|
| 262 |
+
|
| 263 |
+
# Multi-hop traversal query
|
| 264 |
+
traversal_query = """
|
| 265 |
+
MATCH (seed)-[r]-(neighbor)
|
| 266 |
+
WHERE seed.id IN $current_ids
|
| 267 |
+
AND NOT (neighbor.id IN $visited_ids)
|
| 268 |
+
AND r.confidence >= $min_rel_confidence
|
| 269 |
+
AND neighbor.confidence >= $min_entity_confidence
|
| 270 |
+
AND neighbor.name IS NOT NULL
|
| 271 |
+
AND neighbor.content IS NOT NULL
|
| 272 |
+
RETURN DISTINCT
|
| 273 |
+
seed.id as seed_id,
|
| 274 |
+
neighbor.id as neighbor_id,
|
| 275 |
+
neighbor.name as neighbor_name,
|
| 276 |
+
neighbor.content as neighbor_content,
|
| 277 |
+
neighbor.confidence as neighbor_confidence,
|
| 278 |
+
neighbor.degree_centrality as degree_centrality,
|
| 279 |
+
neighbor.betweenness_centrality as betweenness_centrality,
|
| 280 |
+
neighbor.closeness_centrality as closeness_centrality,
|
| 281 |
+
labels(neighbor) as neighbor_types,
|
| 282 |
+
type(r) as relationship_type,
|
| 283 |
+
r.confidence as relationship_confidence
|
| 284 |
+
ORDER BY neighbor.degree_centrality DESC, r.confidence DESC
|
| 285 |
+
LIMIT $max_results
|
| 286 |
+
"""
|
| 287 |
+
|
| 288 |
+
try:
|
| 289 |
+
results = self.neo4j_conn.execute_query(
|
| 290 |
+
traversal_query,
|
| 291 |
+
{
|
| 292 |
+
'current_ids': current_ids,
|
| 293 |
+
'visited_ids': list(visited_node_ids),
|
| 294 |
+
'min_rel_confidence': self.min_relationship_confidence,
|
| 295 |
+
'min_entity_confidence': self.min_entity_confidence,
|
| 296 |
+
'max_results': self.max_entities_per_hop
|
| 297 |
+
}
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
next_hop_entities = []
|
| 301 |
+
|
| 302 |
+
for record in results:
|
| 303 |
+
neighbor_id = record['neighbor_id']
|
| 304 |
+
|
| 305 |
+
if neighbor_id not in visited_node_ids:
|
| 306 |
+
# Create entity result
|
| 307 |
+
entity = EntityResult(
|
| 308 |
+
entity_id=neighbor_id,
|
| 309 |
+
name=record['neighbor_name'],
|
| 310 |
+
content=record['neighbor_content'],
|
| 311 |
+
confidence=record['neighbor_confidence'],
|
| 312 |
+
degree_centrality=record['degree_centrality'] or 0.0,
|
| 313 |
+
betweenness_centrality=record['betweenness_centrality'] or 0.0,
|
| 314 |
+
closeness_centrality=record['closeness_centrality'] or 0.0,
|
| 315 |
+
# Set community info
|
| 316 |
+
community_id='unknown',
|
| 317 |
+
node_type=', '.join(record['neighbor_types']) if record['neighbor_types'] else 'Entity'
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
all_entities.append(entity)
|
| 321 |
+
next_hop_entities.append(entity)
|
| 322 |
+
visited_node_ids.add(neighbor_id)
|
| 323 |
+
|
| 324 |
+
# Create relationship result using REAL schema attributes
|
| 325 |
+
relationship = RelationshipResult(
|
| 326 |
+
start_node=record['seed_id'],
|
| 327 |
+
end_node=neighbor_id,
|
| 328 |
+
relationship_type=record['relationship_type'],
|
| 329 |
+
confidence=record['relationship_confidence']
|
| 330 |
+
# Using REAL schema: startNode, endNode
|
| 331 |
+
)
|
| 332 |
+
|
| 333 |
+
all_relationships.append(relationship)
|
| 334 |
+
|
| 335 |
+
current_entities = next_hop_entities
|
| 336 |
+
self.logger.info(f" Hop {hop + 1}: Found {len(next_hop_entities)} new entities")
|
| 337 |
+
|
| 338 |
+
except Exception as e:
|
| 339 |
+
self.logger.error(f"Multi-hop traversal failed at hop {hop + 1}: {e}")
|
| 340 |
+
break
|
| 341 |
+
|
| 342 |
+
return LocalSearchResult(
|
| 343 |
+
seed_entities=seed_entities,
|
| 344 |
+
traversed_entities=all_entities,
|
| 345 |
+
traversed_relationships=all_relationships,
|
| 346 |
+
search_depth=min(hop + 1, self.max_traversal_depth),
|
| 347 |
+
total_nodes_visited=len(visited_node_ids)
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
async def _extract_detailed_entities(self,
|
| 351 |
+
local_results: List[LocalSearchResult],
|
| 352 |
+
routing_result: DriftRoutingResult) -> List[EntityResult]:
|
| 353 |
+
"""
|
| 354 |
+
Step 11: Extract detailed entity information from local search results.
|
| 355 |
+
|
| 356 |
+
Combines and ranks entities from all local searches.
|
| 357 |
+
"""
|
| 358 |
+
all_entities = []
|
| 359 |
+
entity_scores = {}
|
| 360 |
+
|
| 361 |
+
# Collect all entities and calculate importance scores
|
| 362 |
+
for search_result in local_results:
|
| 363 |
+
for entity in search_result.traversed_entities:
|
| 364 |
+
if entity.entity_id not in entity_scores:
|
| 365 |
+
# Calculate entity importance score
|
| 366 |
+
importance_score = (
|
| 367 |
+
0.4 * entity.confidence +
|
| 368 |
+
0.3 * entity.degree_centrality +
|
| 369 |
+
0.2 * entity.betweenness_centrality +
|
| 370 |
+
0.1 * entity.closeness_centrality
|
| 371 |
+
)
|
| 372 |
+
|
| 373 |
+
entity_scores[entity.entity_id] = {
|
| 374 |
+
'entity': entity,
|
| 375 |
+
'importance_score': importance_score,
|
| 376 |
+
'appearance_count': 1
|
| 377 |
+
}
|
| 378 |
+
all_entities.append(entity)
|
| 379 |
+
else:
|
| 380 |
+
# Increment appearance count for entities found in multiple searches
|
| 381 |
+
entity_scores[entity.entity_id]['appearance_count'] += 1
|
| 382 |
+
|
| 383 |
+
# Sort entities by importance score and appearance frequency
|
| 384 |
+
sorted_entities = sorted(
|
| 385 |
+
entity_scores.values(),
|
| 386 |
+
key=lambda x: (x['appearance_count'], x['importance_score']),
|
| 387 |
+
reverse=True
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Return top entities
|
| 391 |
+
max_entities = routing_result.parameters.get('max_detailed_entities', 50)
|
| 392 |
+
detailed_entities = [item['entity'] for item in sorted_entities[:max_entities]]
|
| 393 |
+
|
| 394 |
+
self.logger.info(f"Extracted {len(detailed_entities)} detailed entities from {len(all_entities)} total")
|
| 395 |
+
return detailed_entities
|
| 396 |
+
|
| 397 |
+
async def _generate_intermediate_answers(self,
|
| 398 |
+
questions: List[FollowUpQuestion],
|
| 399 |
+
local_results: List[LocalSearchResult],
|
| 400 |
+
detailed_entities: List[EntityResult],
|
| 401 |
+
routing_result: DriftRoutingResult) -> List[IntermediateAnswer]:
|
| 402 |
+
"""Simple: just list the entity names we found."""
|
| 403 |
+
answers = []
|
| 404 |
+
|
| 405 |
+
for i, question in enumerate(questions):
|
| 406 |
+
# Get entities from search result
|
| 407 |
+
entities = local_results[i].traversed_entities if i < len(local_results) else []
|
| 408 |
+
entity_names = [e.name for e in entities[:10]]
|
| 409 |
+
|
| 410 |
+
# Simple answer with entity names
|
| 411 |
+
answer_text = f"Found entities: {', '.join(entity_names)}" if entity_names else "No specific entities found."
|
| 412 |
+
|
| 413 |
+
answer = IntermediateAnswer(
|
| 414 |
+
question_id=question.question_id,
|
| 415 |
+
question=question.question,
|
| 416 |
+
answer=answer_text,
|
| 417 |
+
confidence=0.8,
|
| 418 |
+
reasoning=f"Found {len(entity_names)} entities matching the search criteria.",
|
| 419 |
+
supporting_entities=entity_names,
|
| 420 |
+
supporting_evidence=[]
|
| 421 |
+
)
|
| 422 |
+
answers.append(answer)
|
| 423 |
+
|
| 424 |
+
return answers
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
# Exports
# Public API of this module: the Phase D engine plus its result dataclasses.
__all__ = ['FollowUpSearch', 'FollowUpQuestion', 'LocalSearchResult', 'IntermediateAnswer']
|
| 429 |
+
|
query_graph_functions/knowledge_retrieval.py
ADDED
|
@@ -0,0 +1,843 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Knowledge Retrieval Module - Phase C (Steps 6-8)
|
| 3 |
+
|
| 4 |
+
Performs community search and data extraction using graph database structures.
|
| 5 |
+
Handles community retrieval, data extraction, and initial answer generation.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import logging
|
| 9 |
+
import numpy as np
|
| 10 |
+
import json
|
| 11 |
+
from typing import Dict, List, Tuple, Any
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
from .setup import GraphRAGSetup
|
| 16 |
+
from .query_preprocessing import DriftRoutingResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class CommunityResult:
    """Enhanced community result with comprehensive properties."""
    community_id: str                     # Unique community identifier
    similarity_score: float               # Query-to-community embedding similarity
    summary: str                          # Community summary text
    key_entities: List[str]               # Names of the community's key entities
    member_ids: List[str]  # Direct member access
    modularity_score: float  # Community quality
    level: int                            # Community hierarchy level (assumed — confirm against builder)
    internal_edges: int                   # Edge count inside the community
    member_count: int                     # Number of member nodes
    centrality_stats: Dict[str, float]  # Aggregated centrality measures
    confidence_score: float               # Overall confidence for this community
    search_index: str  # Optimized search key
    termination_criteria: Dict[str, Any]  # Selection criteria applied when this community was chosen
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@dataclass
class EntityResult:
    """Entity result with attributes from graph database."""
    entity_id: str                 # Node 'id' property
    name: str                      # Node display name
    content: str                   # Node textual content
    confidence: float              # Confidence value stored on the node
    degree_centrality: float       # Precomputed centrality measures from the graph
    betweenness_centrality: float
    closeness_centrality: float
    community_id: str              # Owning community id ('found'/'unknown' when unresolved)
    node_type: str                 # Comma-joined Neo4j labels
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@dataclass
class RelationshipResult:
    """Relationship result with graph database attributes."""
    start_node: str          # Id of the relationship's start node
    end_node: str            # Id of the relationship's end node
    relationship_type: str   # Neo4j relationship type (type(r))
    confidence: float        # Confidence value stored on the relationship
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
class CommunitySearchEngine:
|
| 61 |
+
"""Knowledge retrieval engine for community search and entity extraction."""
|
| 62 |
+
|
| 63 |
+
    def __init__(self, setup: GraphRAGSetup):
        """
        Args:
            setup: Shared GraphRAGSetup providing the Neo4j connection and config.
        """
        self.setup = setup
        self.neo4j_conn = setup.neo4j_conn
        self.config = setup.config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Initialize search optimization
        self.community_search_index = {}  # community_id -> search-index entry (filled by _load_community_search_index)
        self.centrality_cache = {}  # centrality cache (not populated in the visible code — confirm usage)
|
| 72 |
+
|
| 73 |
+
async def execute_primer_phase(self,
|
| 74 |
+
query_embedding: List[float],
|
| 75 |
+
routing_result: DriftRoutingResult) -> Dict[str, Any]:
|
| 76 |
+
"""Execute community search and knowledge retrieval."""
|
| 77 |
+
start_time = datetime.now()
|
| 78 |
+
|
| 79 |
+
try:
|
| 80 |
+
# Community retrieval
|
| 81 |
+
self.logger.info("Starting community retrieval")
|
| 82 |
+
communities = await self._retrieve_communities_enhanced(
|
| 83 |
+
query_embedding, routing_result
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
# Data extraction
|
| 87 |
+
self.logger.info("Starting data extraction")
|
| 88 |
+
extracted_data = await self._extract_community_data_enhanced(communities)
|
| 89 |
+
|
| 90 |
+
# Answer generation
|
| 91 |
+
self.logger.info("Starting answer generation")
|
| 92 |
+
initial_answer = await self._generate_initial_answer_enhanced(
|
| 93 |
+
extracted_data, routing_result
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
execution_time = (datetime.now() - start_time).total_seconds()
|
| 97 |
+
|
| 98 |
+
return {
|
| 99 |
+
'communities': communities,
|
| 100 |
+
'extracted_data': extracted_data,
|
| 101 |
+
'initial_answer': initial_answer,
|
| 102 |
+
'execution_time': execution_time,
|
| 103 |
+
'metadata': {
|
| 104 |
+
'communities_retrieved': len(communities),
|
| 105 |
+
'entities_extracted': len(extracted_data.get('entities', [])),
|
| 106 |
+
'relationships_extracted': len(extracted_data.get('relationships', [])),
|
| 107 |
+
'phase': 'primer',
|
| 108 |
+
'step_range': '6-8'
|
| 109 |
+
}
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
except Exception as e:
|
| 113 |
+
self.logger.error(f"Primer phase execution failed: {e}")
|
| 114 |
+
raise
|
| 115 |
+
|
| 116 |
+
    async def _retrieve_communities_enhanced(self,
                                             query_embedding: List[float],
                                             routing_result: DriftRoutingResult) -> List[CommunityResult]:
        """
        Step 6: Enhanced community retrieval using comprehensive properties.

        Pipeline: fetch HyDE embeddings -> compute similarity against the query
        embedding -> rank -> apply termination criteria -> hydrate community
        details.

        Args:
            query_embedding: Dense vector for the preprocessed query.
            routing_result: Routing parameters (thresholds, limits).

        Returns:
            Ranked CommunityResult list; [] when no embeddings exist or on error
            (errors are logged, not raised).
        """
        try:
            # Retrieve HyDE embeddings
            hyde_embeddings = await self._retrieve_hyde_embeddings_enhanced()

            if not hyde_embeddings:
                self.logger.warning("No HyDE embeddings found")
                return []

            # Compute similarities
            similarities = self._compute_hyde_similarities_enhanced(
                query_embedding, hyde_embeddings
            )

            # Rank communities
            ranked_communities = self._rank_communities_enhanced(
                similarities, routing_result
            )

            # Apply criteria
            filtered_communities = self._apply_termination_criteria(
                ranked_communities, routing_result
            )

            # Fetch community details
            community_results = await self._fetch_community_details_enhanced(
                filtered_communities
            )

            self.logger.info(f"Retrieved {len(community_results)} enhanced communities")
            return community_results

        except Exception as e:
            self.logger.error(f"Enhanced community retrieval failed: {e}")
            return []
|
| 158 |
+
|
| 159 |
+
async def _load_community_search_index(self):
|
| 160 |
+
"""Load optimized community search index from Neo4j."""
|
| 161 |
+
try:
|
| 162 |
+
query = """
|
| 163 |
+
MATCH (meta:DriftMetadata)
|
| 164 |
+
WHERE meta.community_search_index IS NOT NULL
|
| 165 |
+
RETURN meta.community_search_index as search_index,
|
| 166 |
+
meta.total_communities as total_communities
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
results = self.neo4j_conn.execute_query(query)
|
| 170 |
+
|
| 171 |
+
for record in results:
|
| 172 |
+
# The search index is a nested JSON structure with community IDs as keys
|
| 173 |
+
search_index_data = record['search_index']
|
| 174 |
+
if isinstance(search_index_data, dict):
|
| 175 |
+
# Each community in the search index
|
| 176 |
+
for community_id, community_data in search_index_data.items():
|
| 177 |
+
self.community_search_index[community_id] = community_data
|
| 178 |
+
else:
|
| 179 |
+
self.logger.warning(f"Unexpected search index format: {type(search_index_data)}")
|
| 180 |
+
|
| 181 |
+
self.logger.info(f"Loaded search index for {len(self.community_search_index)} communities")
|
| 182 |
+
|
| 183 |
+
except Exception as e:
|
| 184 |
+
self.logger.error(f"Failed to load community search index: {e}")
|
| 185 |
+
|
| 186 |
+
async def _retrieve_hyde_embeddings_enhanced(self) -> Dict[str, Dict[str, Any]]:
|
| 187 |
+
"""Retrieve HyDE embeddings and metadata."""
|
| 188 |
+
try:
|
| 189 |
+
# Retrieve community embeddings
|
| 190 |
+
query = """
|
| 191 |
+
MATCH (c:Community)
|
| 192 |
+
WHERE c.hyde_embeddings IS NOT NULL
|
| 193 |
+
OPTIONAL MATCH (meta:CommunitiesMetadata)
|
| 194 |
+
RETURN c.id as community_id,
|
| 195 |
+
c.hyde_embeddings as hyde_embeddings,
|
| 196 |
+
c.summary as summary,
|
| 197 |
+
c.key_entities as key_entities,
|
| 198 |
+
c.member_ids as member_ids,
|
| 199 |
+
size(c.hyde_embeddings) as embedding_size,
|
| 200 |
+
meta.modularity_score as global_modularity_score
|
| 201 |
+
"""
|
| 202 |
+
|
| 203 |
+
results = self.neo4j_conn.execute_query(query)
|
| 204 |
+
hyde_embeddings = {}
|
| 205 |
+
|
| 206 |
+
for record in results:
|
| 207 |
+
community_id = record['community_id']
|
| 208 |
+
embeddings_data = record.get('hyde_embeddings')
|
| 209 |
+
|
| 210 |
+
if embeddings_data and community_id:
|
| 211 |
+
hyde_embeddings[community_id] = {
|
| 212 |
+
'embeddings': embeddings_data,
|
| 213 |
+
'summary': record.get('summary', ''),
|
| 214 |
+
'key_entities': record.get('key_entities', []),
|
| 215 |
+
'member_ids': record.get('member_ids', []),
|
| 216 |
+
'embedding_size': record.get('embedding_size', 0),
|
| 217 |
+
'global_modularity_score': record.get('global_modularity_score', 0.0),
|
| 218 |
+
'embedding_type': 'hyde'
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
self.logger.info(f"Retrieved enhanced HyDE embeddings for {len(hyde_embeddings)} communities")
|
| 222 |
+
return hyde_embeddings
|
| 223 |
+
|
| 224 |
+
except Exception as e:
|
| 225 |
+
self.logger.error(f"Failed to retrieve enhanced HyDE embeddings: {e}")
|
| 226 |
+
# Retry logic for embeddings
|
| 227 |
+
self.logger.info("Attempting retry for HyDE embeddings...")
|
| 228 |
+
try:
|
| 229 |
+
import time
|
| 230 |
+
time.sleep(2) # Brief delay before retry
|
| 231 |
+
results = self.neo4j_conn.execute_query(query)
|
| 232 |
+
hyde_embeddings = {}
|
| 233 |
+
|
| 234 |
+
for record in results:
|
| 235 |
+
community_id = record['community_id']
|
| 236 |
+
embeddings_data = record.get('hyde_embeddings')
|
| 237 |
+
|
| 238 |
+
if embeddings_data and community_id:
|
| 239 |
+
hyde_embeddings[community_id] = {
|
| 240 |
+
'embeddings': embeddings_data,
|
| 241 |
+
'summary': record.get('summary', ''),
|
| 242 |
+
'key_entities': record.get('key_entities', []),
|
| 243 |
+
'member_ids': record.get('member_ids', []),
|
| 244 |
+
'embedding_size': record.get('embedding_size', 0),
|
| 245 |
+
'global_modularity': record.get('global_modularity_score', 0.0)
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
self.logger.info(f"Retry successful: Retrieved enhanced HyDE embeddings for {len(hyde_embeddings)} communities")
|
| 249 |
+
return hyde_embeddings
|
| 250 |
+
|
| 251 |
+
except Exception as retry_error:
|
| 252 |
+
self.logger.error(f"Retry also failed: {retry_error}")
|
| 253 |
+
return {}
|
| 254 |
+
|
| 255 |
+
    def _compute_hyde_similarities_enhanced(self,
                                            query_embedding: List[float],
                                            hyde_embeddings: Dict[str, Dict[str, Any]]) -> Dict[str, Dict[str, float]]:
        """
        Enhanced similarity computation with global modularity weighting.

        For each community, computes cosine similarity between the query vector
        and the community's (first) HyDE embedding, then boosts it by up to 20%
        proportionally to the community's global modularity score.

        Args:
            query_embedding: Query vector.
            hyde_embeddings: community_id -> record as produced by
                _retrieve_hyde_embeddings_enhanced().

        Returns:
            community_id -> {'similarity', 'global_modularity_score',
            'embedding_size'}. Communities whose embeddings fail to parse are
            skipped; a zero-norm query yields an empty dict.
        """
        similarities = {}
        query_vec = np.array(query_embedding)
        query_norm = np.linalg.norm(query_vec)

        if query_norm == 0:
            self.logger.warning("Query embedding has zero norm")
            return similarities

        for community_id, embedding_data in hyde_embeddings.items():
            embeddings_list = embedding_data['embeddings']
            global_modularity = embedding_data.get('global_modularity_score', 0.0)

            max_similarity = 0.0

            # Compute similarity
            try:
                # Parse embedding string (embeddings may be stored as JSON text)
                if isinstance(embeddings_list, str):
                    embeddings_list = json.loads(embeddings_list)

                # Process embeddings
                if isinstance(embeddings_list, list) and len(embeddings_list) > 0:
                    # Use first embedding when multiple are stored
                    hyde_vec = np.array(embeddings_list[0] if isinstance(embeddings_list[0], list) else embeddings_list)
                else:
                    hyde_vec = np.array(embeddings_list)

                hyde_norm = np.linalg.norm(hyde_vec)

                if hyde_norm > 0:
                    # Calculate cosine similarity
                    base_similarity = np.dot(query_vec, hyde_vec) / (query_norm * hyde_norm)

                    # Apply weighting: up to +20% boost for high-modularity communities
                    weighted_similarity = base_similarity * (1 + 0.2 * global_modularity)
                    max_similarity = weighted_similarity

            except Exception as e:
                # Skip unparsable embeddings; this community gets no similarity entry.
                self.logger.warning(f"Error computing similarity for community {community_id}: {e}")
                continue

            similarities[community_id] = {
                'similarity': max_similarity,
                'global_modularity_score': global_modularity,
                'embedding_size': embedding_data.get('embedding_size', 0)
            }

        self.logger.info(f"Computed enhanced similarities for {len(similarities)} communities")
        return similarities
|
| 312 |
+
|
| 313 |
+
def _rank_communities_enhanced(self,
|
| 314 |
+
similarities: Dict[str, Dict[str, float]],
|
| 315 |
+
routing_result: DriftRoutingResult) -> List[Tuple[str, Dict[str, float]]]:
|
| 316 |
+
"""
|
| 317 |
+
Enhanced ranking using global modularity and similarity.
|
| 318 |
+
|
| 319 |
+
Ranks communities based on a weighted combination of similarity score and modularity.
|
| 320 |
+
"""
|
| 321 |
+
|
| 322 |
+
# Rank primarily by similarity, with modularity as secondary factor
|
| 323 |
+
|
| 324 |
+
def ranking_score(item):
|
| 325 |
+
_, scores = item
|
| 326 |
+
similarity = scores['similarity']
|
| 327 |
+
global_modularity = scores['global_modularity_score']
|
| 328 |
+
|
| 329 |
+
# Weighted combination (similarity is primary)
|
| 330 |
+
return 0.8 * similarity + 0.2 * global_modularity
|
| 331 |
+
|
| 332 |
+
# Sort by combined ranking score
|
| 333 |
+
ranked = sorted(similarities.items(), key=ranking_score, reverse=True)
|
| 334 |
+
|
| 335 |
+
# Apply similarity threshold
|
| 336 |
+
similarity_threshold = routing_result.parameters.get('similarity_threshold', 0.7)
|
| 337 |
+
filtered_ranked = [
|
| 338 |
+
(cid, scores) for cid, scores in ranked
|
| 339 |
+
if scores['similarity'] >= similarity_threshold
|
| 340 |
+
]
|
| 341 |
+
|
| 342 |
+
self.logger.info(f"Enhanced ranking: {len(filtered_ranked)} communities above threshold {similarity_threshold}")
|
| 343 |
+
return filtered_ranked
|
| 344 |
+
|
| 345 |
+
def _apply_termination_criteria(self,
                                ranked_communities: List[Tuple[str, Dict[str, float]]],
                                routing_result: DriftRoutingResult) -> List[Tuple[str, Dict[str, float]]]:
    """
    Apply termination criteria for community selection.

    Walks the ranked list in order, keeping communities whose global
    modularity meets the floor, and stops once the configured maximum
    number of communities has been collected.
    """

    # Thresholds from routing parameters, with conservative defaults.
    params = routing_result.parameters
    cap = params.get('max_communities', 3)
    modularity_floor = params.get('min_global_modularity', 0.3)

    selected = []
    for candidate in ranked_communities:
        if len(selected) >= cap:
            break
        score_map = candidate[1]
        if score_map['global_modularity_score'] >= modularity_floor:
            selected.append(candidate)

    self.logger.info(f"Applied termination criteria: {len(selected)} communities selected")
    return selected
|
| 370 |
+
|
| 371 |
+
async def _fetch_community_details_enhanced(self,
                                            ranked_communities: List[Tuple[str, Dict[str, float]]]) -> List[CommunityResult]:
    """
    Fetch comprehensive community details with all properties.

    Retrieves detailed information about selected communities including summaries,
    key entities, and member IDs.

    Args:
        ranked_communities: (community_id, score dict) pairs; each score dict
            holds at least 'similarity'.

    Returns:
        One CommunityResult per community whose lookup succeeded; failed
        lookups are logged and skipped rather than aborting the batch.
    """
    community_results = []

    for community_id, scores in ranked_communities:
        try:
            # Query the Community node directly by ID (since embedding communities have id=community_id)
            # NOTE(review): the OPTIONAL MATCH below has no relationship to `c`,
            # so `meta` is an arbitrary CommunitiesMetadata node — presumably a
            # singleton in this graph; confirm against the ingest pipeline.
            detail_query = """
                MATCH (c:Community)
                WHERE c.id = $community_id AND c.hyde_embeddings IS NOT NULL
                OPTIONAL MATCH (meta:CommunitiesMetadata)
                RETURN c.summary as summary,
                       c.key_entities as key_entities,
                       c.member_ids as member_ids,
                       c.internal_edges as internal_edges,
                       c.density as density,
                       c.avg_degree as avg_degree,
                       c.level as level,
                       meta.modularity_score as modularity_score,
                       CASE WHEN c.member_ids IS NOT NULL THEN size(c.member_ids) ELSE 0 END as member_count,
                       c.id as id
                LIMIT 1
            """

            results = self.neo4j_conn.execute_query(
                detail_query,
                {'community_id': community_id}
            )

            if results:
                record = results[0]

                # Create enhanced community result with actual available data from Neo4j
                community_result = CommunityResult(
                    community_id=community_id,
                    similarity_score=scores['similarity'],
                    summary=record.get('summary', ''),
                    key_entities=record.get('key_entities', []),
                    member_ids=record.get('member_ids', []),
                    modularity_score=record.get('modularity_score', 0.0),
                    level=record.get('level', 1),
                    internal_edges=record.get('internal_edges', 0),
                    member_count=record.get('member_count', 0),
                    # 'confidence_score' is not produced by the similarity step,
                    # so this effectively defaults to 0.5 — verify intent.
                    confidence_score=scores.get('confidence_score', 0.5),
                    search_index='',
                    termination_criteria={},
                    centrality_stats={
                        'avg_degree': record.get('avg_degree', 0.0),
                        'density': record.get('density', 0.0)
                    }
                )

                community_results.append(community_result)

        except Exception as e:
            self.logger.error(f"Failed to fetch details for community {community_id}: {e}")
            continue

    self.logger.info(f"Fetched enhanced details for {len(community_results)} communities")
    return community_results
|
| 437 |
+
|
| 438 |
+
async def _extract_community_data_enhanced(self,
                                           communities: List[CommunityResult]) -> Dict[str, Any]:
    """
    Step 7: Enhanced data extraction with centrality measures.

    For every community, pulls its entities (with degree/betweenness/closeness
    centrality) and its relationships (with confidence scores), and records
    per-community statistics. Any failure degrades to empty result lists
    instead of raising.
    """
    try:
        collected_entities = []
        collected_relationships = []
        stats_rows = []

        for community in communities:
            # Per-community entity extraction (includes centrality measures).
            collected_entities.extend(
                await self._extract_entities_with_centrality(community)
            )

            # Per-community relationship extraction.
            collected_relationships.extend(
                await self._extract_relationships_enhanced(community)
            )

            # Summary statistics for downstream context sizing.
            stats_rows.append({
                'community_id': community.community_id,
                'member_count': community.member_count,
                'modularity_score': community.modularity_score,
                'confidence_score': community.confidence_score,
                'centrality_stats': community.centrality_stats
            })

        extracted_data = {
            'entities': collected_entities,
            'relationships': collected_relationships,
            'community_stats': stats_rows,
            'extraction_metadata': {
                'communities_processed': len(communities),
                'entities_extracted': len(collected_entities),
                'relationships_extracted': len(collected_relationships),
                'timestamp': datetime.now().isoformat()
            }
        }

        self.logger.info(f"Enhanced extraction completed: {len(collected_entities)} entities, {len(collected_relationships)} relationships")
        return extracted_data

    except Exception as e:
        self.logger.error(f"Enhanced data extraction failed: {e}")
        return {'entities': [], 'relationships': [], 'community_stats': []}
|
| 489 |
+
|
| 490 |
+
async def _extract_entities_with_centrality(self,
                                            community: CommunityResult) -> List[EntityResult]:
    """
    Extract entities with comprehensive centrality measures.

    Retrieves entities from the community with their associated centrality metrics.
    Prefers direct lookup via the community's member_ids; when none are
    available, falls back to a capped scan of nodes carrying a community_id.

    Returns:
        EntityResult objects ordered by degree centrality (descending);
        an empty list if the query fails.
    """
    try:
        # Use member_ids for direct access if available
        member_ids = community.member_ids if community.member_ids else []

        if member_ids:
            # Direct member access query based on actual schema
            entity_query = """
                MATCH (n)
                WHERE n.id IN $member_ids
                AND n.name IS NOT NULL
                AND n.content IS NOT NULL
                RETURN n.id as entity_id,
                       n.name as name,
                       n.content as content,
                       n.confidence as confidence,
                       n.degree_centrality as degree_centrality,
                       n.betweenness_centrality as betweenness_centrality,
                       n.closeness_centrality as closeness_centrality,
                       labels(n) as node_types
                ORDER BY n.degree_centrality DESC
            """

            results = self.neo4j_conn.execute_query(
                entity_query,
                {'member_ids': member_ids}
            )
        else:
            # Fallback: find entities using community_id pattern matching
            # NOTE(review): this branch does not filter on THIS community's id —
            # it returns up to 20 nodes from any community; confirm intent.
            entity_query = """
                MATCH (n)
                WHERE n.community_id IS NOT NULL
                AND n.name IS NOT NULL
                AND n.content IS NOT NULL
                RETURN n.id as entity_id,
                       n.name as name,
                       n.content as content,
                       n.confidence as confidence,
                       n.degree_centrality as degree_centrality,
                       n.betweenness_centrality as betweenness_centrality,
                       n.closeness_centrality as closeness_centrality,
                       labels(n) as node_types
                ORDER BY n.degree_centrality DESC
                LIMIT 20
            """

            results = self.neo4j_conn.execute_query(entity_query)

    	# Map raw records onto EntityResult objects; missing properties
        # default to '' / 0.0 so downstream math never sees None.
        entities = []
        for record in results:
            entity = EntityResult(
                entity_id=record['entity_id'],
                name=record.get('name', ''),
                content=record.get('content', ''),
                confidence=record.get('confidence', 0.0),
                degree_centrality=record.get('degree_centrality', 0.0),
                betweenness_centrality=record.get('betweenness_centrality', 0.0),
                closeness_centrality=record.get('closeness_centrality', 0.0),
                community_id=community.community_id,
                node_type=record.get('node_types', ['Unknown'])[0] if record.get('node_types') else 'Unknown'
            )
            entities.append(entity)

        return entities

    except Exception as e:
        self.logger.error(f"Failed to extract entities for community {community.community_id}: {e}")
        return []
|
| 564 |
+
|
| 565 |
+
async def _extract_relationships_enhanced(self,
                                          community: CommunityResult) -> List[RelationshipResult]:
    """
    Extract relationships with enhanced properties.

    Retrieves relationship data between entities within the specified community.
    Only edges whose two endpoints share this community's id and whose
    confidence exceeds 0.5 are returned, capped at the 50 most confident.

    Returns:
        RelationshipResult objects; an empty list if the query fails.
    """
    try:
        relationship_query = """
            MATCH (a)-[r]->(b)
            WHERE a.community_id = $community_id
            AND b.community_id = $community_id
            AND r.confidence > 0.5
            RETURN startNode(r).id as start_node,
                   endNode(r).id as end_node,
                   type(r) as relationship_type,
                   r.confidence as confidence
            ORDER BY r.confidence DESC
            LIMIT 50
        """

        results = self.neo4j_conn.execute_query(
            relationship_query,
            {'community_id': community.community_id}
        )

        # Convert raw records to RelationshipResult objects.
        relationships = []
        for record in results:
            relationship = RelationshipResult(
                start_node=record['start_node'],
                end_node=record['end_node'],
                relationship_type=record['relationship_type'],
                confidence=record.get('confidence', 0.0)
            )
            relationships.append(relationship)

        return relationships

    except Exception as e:
        self.logger.error(f"Failed to extract relationships for community {community.community_id}: {e}")
        return []
|
| 606 |
+
|
| 607 |
+
async def _generate_initial_answer_enhanced(self,
                                            extracted_data: Dict[str, Any],
                                            routing_result: DriftRoutingResult) -> Dict[str, Any]:
    """
    Step 8: Context-aware initial answer generation.

    Ranks entities by centrality, keeps only high-confidence relationships,
    builds an LLM context from both, and asks the configured LLM for an
    initial answer with follow-up questions. On failure, returns a dict
    with 'content' and 'error' instead of raising.
    """
    try:
        entity_pool = extracted_data['entities']
        relationship_pool = extracted_data['relationships']
        community_stats = extracted_data['community_stats']

        # Importance = mean of degree and betweenness centrality; keep top 10.
        def importance(entity):
            return (entity.degree_centrality + entity.betweenness_centrality) / 2

        important_entities = sorted(entity_pool, key=importance, reverse=True)[:10]

        # Evidence-strength filter: relationships with confidence >= 0.7.
        strong_relationships = []
        for rel in relationship_pool:
            if rel.confidence >= 0.7:
                strong_relationships.append(rel)

        # Assemble the prompt context and query the configured LLM.
        llm_context = self._prepare_llm_context_enhanced(
            important_entities, strong_relationships, community_stats, routing_result
        )
        llm_response = await self._generate_llm_answer(llm_context, routing_result)

        initial_answer = {
            'content': llm_response['answer'],
            'llm_context': llm_context,
            'context_used': {
                'important_entities': len(important_entities),
                'strong_relationships': len(strong_relationships),
                'communities_analyzed': len(community_stats)
            },
            'confidence_metrics': {
                'avg_entity_centrality': np.mean([e.degree_centrality for e in important_entities]) if important_entities else 0,
                'avg_relationship_confidence': np.mean([r.confidence for r in strong_relationships]) if strong_relationships else 0,
                'avg_community_modularity': np.mean([c['modularity_score'] for c in community_stats]) if community_stats else 0,
                'llm_confidence': llm_response['confidence']
            },
            'follow_up_questions': llm_response['follow_up_questions'],
            'reasoning': llm_response['reasoning']
        }

        self.logger.info("Enhanced initial answer generated with comprehensive context")
        return initial_answer

    except Exception as e:
        self.logger.error(f"Enhanced answer generation failed: {e}")
        return {'content': 'Error generating initial answer', 'error': str(e)}
|
| 668 |
+
|
| 669 |
+
def _prepare_llm_context_enhanced(self,
                                  entities: List[EntityResult],
                                  relationships: List[RelationshipResult],
                                  community_stats: List[Dict[str, Any]],
                                  routing_result: DriftRoutingResult) -> str:
    """Assemble the LLM prompt context: query header, top entities,
    key relationships, an entity-name reference list, community stats,
    and a closing reminder — one section per block, newline-joined."""
    lines = [
        f"Query: {routing_result.original_query}",
        f"Search Strategy: {routing_result.search_strategy.value}",
        "",
        "=== IMPORTANT ENTITIES (Use these specific names in your answer) ===",
    ]

    # Top entities with truncated descriptions plus centrality/confidence evidence.
    for idx, ent in enumerate(entities[:10], 1):
        lines.append(
            f"{idx}. NAME: '{ent.name}' | Description: {ent.content[:100]}... "
            f"| Centrality: {ent.degree_centrality:.3f} | Confidence: {ent.confidence:.3f}"
        )

    lines += ["", "=== KEY RELATIONSHIPS (Use these connections in your answer) ==="]

    # High-confidence edges rendered as arrow notation.
    for idx, edge in enumerate(relationships[:8], 1):
        lines.append(
            f"{idx}. '{edge.start_node}' --[{edge.relationship_type}]--> '{edge.end_node}' "
            f"| Confidence: {edge.confidence:.3f}"
        )

    # Quick-reference list of entity names for the LLM.
    name_list = ', '.join(ent.name for ent in entities[:15])
    lines += [
        "",
        "=== ENTITY NAMES FOR REFERENCE ===",
        f"Available entities: {name_list}",
        "",
        "=== COMMUNITY STATISTICS ===",
    ]

    for stat in community_stats:
        lines.append(
            f"Community {stat['community_id']}: {stat['member_count']} members, "
            f"modularity: {stat['modularity_score']:.3f}"
        )

    lines += ["", "REMEMBER: Use the specific entity names listed above in your answer!"]

    return "\n".join(lines)
|
| 722 |
+
|
| 723 |
+
async def _generate_llm_answer(self,
                               context: str,
                               routing_result: DriftRoutingResult) -> Dict[str, Any]:
    """
    Generate actual LLM response using the configured LLM.

    Uses the LLM from GraphRAGSetup to generate answers with follow-up questions.

    Args:
        context: Pre-built graph context (entities, relationships, stats).
        routing_result: Routing decision for this query. Note: not referenced
            in the body below — presumably kept for interface symmetry.

    Returns:
        Parsed dict with 'answer', 'confidence', 'reasoning' and
        'follow_up_questions'; a low-confidence fallback dict on failure.
    """
    try:
        # Construct comprehensive prompt for LLM
        prompt = f"""
You are an expert knowledge analyst. Answer the user's query using SPECIFIC NAMES and information from the graph data provided below.

IMPORTANT: Use the actual entity names, organization names, and relationship details from the graph data. Do not give generic answers.

GRAPH DATA CONTEXT:
{context}

INSTRUCTIONS:
1. Answer using SPECIFIC ENTITY NAMES from the "IMPORTANT ENTITIES" section above
2. Reference actual relationships and organizations mentioned in the graph data
3. If the query asks for members/organizations, LIST THE ACTUAL NAMES from the entities
4. Use confidence scores and centrality measures as evidence strength indicators
5. Generate follow-up questions based on the specific entities found

RESPONSE FORMAT:
Answer: [Use specific names and details from the graph data above]
Confidence: [0.0-1.0]
Reasoning: [Why these specific entities answer the query]
Follow-up Questions:
1. [Specific question about entities found]
2. [Question about relationships discovered]
3. [Question about community connections]
4. [Question for deeper exploration]
5. [Question about related entities]
"""

        # Call the configured LLM
        llm_response = await self.setup.llm.acomplete(prompt)
        response_text = llm_response.text

        # Parse LLM response
        parsed_response = self._parse_llm_response(response_text)

        self.logger.info(f"LLM generated answer with confidence: {parsed_response['confidence']}")
        return parsed_response

    except Exception as e:
        self.logger.error(f"LLM answer generation failed: {e}")
        # Fallback response
        return {
            'answer': f"Based on the graph analysis, I found relevant information but encountered an issue generating the full response: {str(e)}",
            'confidence': 0.3,
            'reasoning': "LLM generation encountered an error, providing basic analysis from graph data.",
            'follow_up_questions': [
                "What specific aspects would you like me to explore further?",
                "Are there particular entities or relationships of interest?",
                "Should I focus on a specific community or time period?"
            ]
        }
|
| 783 |
+
|
| 784 |
+
def _parse_llm_response(self, response_text: str) -> Dict[str, Any]:
|
| 785 |
+
"""Parse structured LLM response into components."""
|
| 786 |
+
try:
|
| 787 |
+
lines = response_text.strip().split('\n')
|
| 788 |
+
|
| 789 |
+
answer = ""
|
| 790 |
+
confidence = 0.5
|
| 791 |
+
reasoning = ""
|
| 792 |
+
follow_up_questions = []
|
| 793 |
+
|
| 794 |
+
current_section = None
|
| 795 |
+
|
| 796 |
+
for line in lines:
|
| 797 |
+
line = line.strip()
|
| 798 |
+
|
| 799 |
+
if line.startswith("Answer:"):
|
| 800 |
+
current_section = "answer"
|
| 801 |
+
answer = line.replace("Answer:", "").strip()
|
| 802 |
+
elif line.startswith("Confidence:"):
|
| 803 |
+
confidence_text = line.replace("Confidence:", "").strip()
|
| 804 |
+
try:
|
| 805 |
+
confidence = float(confidence_text)
|
| 806 |
+
except (ValueError, TypeError):
|
| 807 |
+
confidence = 0.5
|
| 808 |
+
elif line.startswith("Reasoning:"):
|
| 809 |
+
current_section = "reasoning"
|
| 810 |
+
reasoning = line.replace("Reasoning:", "").strip()
|
| 811 |
+
elif line.startswith("Follow-up Questions:"):
|
| 812 |
+
current_section = "questions"
|
| 813 |
+
elif current_section == "answer" and line:
|
| 814 |
+
answer += " " + line
|
| 815 |
+
elif current_section == "reasoning" and line:
|
| 816 |
+
reasoning += " " + line
|
| 817 |
+
elif current_section == "questions" and line.startswith(("1.", "2.", "3.", "4.", "5.")):
|
| 818 |
+
question = line[2:].strip() # Remove "1. " etc.
|
| 819 |
+
follow_up_questions.append(question)
|
| 820 |
+
|
| 821 |
+
return {
|
| 822 |
+
'answer': answer.strip() if answer else "Unable to generate answer from available context.",
|
| 823 |
+
'confidence': max(0.0, min(1.0, confidence)), # Clamp between 0-1
|
| 824 |
+
'reasoning': reasoning.strip() if reasoning else "Analysis based on graph structure and entity relationships.",
|
| 825 |
+
'follow_up_questions': follow_up_questions if follow_up_questions else [
|
| 826 |
+
"What additional information would be helpful?",
|
| 827 |
+
"Are there specific aspects to explore further?",
|
| 828 |
+
"Should I analyze different communities or relationships?"
|
| 829 |
+
]
|
| 830 |
+
}
|
| 831 |
+
|
| 832 |
+
except Exception as e:
|
| 833 |
+
self.logger.error(f"Failed to parse LLM response: {e}")
|
| 834 |
+
return {
|
| 835 |
+
'answer': response_text[:500] if response_text else "No response generated.",
|
| 836 |
+
'confidence': 0.4,
|
| 837 |
+
'reasoning': "Direct LLM output due to parsing issues.",
|
| 838 |
+
'follow_up_questions': ["What would you like to know more about?"]
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
# Exports
# Public API of this module: the search engine plus its result dataclasses.
__all__ = ['CommunitySearchEngine', 'CommunityResult', 'EntityResult', 'RelationshipResult']
|
query_graph_functions/query_preprocessing.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Query preprocessing for analysis, routing, and vectorization - Phase B (Steps 3-5)."""
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Dict, List, Any, Tuple, Optional
|
| 5 |
+
from dataclasses import dataclass
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import re
|
| 8 |
+
|
| 9 |
+
# System imports
|
| 10 |
+
import sys
|
| 11 |
+
import os
|
| 12 |
+
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
| 13 |
+
|
| 14 |
+
from my_config import MY_CONFIG
|
| 15 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class QueryType(Enum):
    """Query type classifications for DRIFT routing."""
    SPECIFIC_ENTITY = "specific_entity"        # question about one named entity
    RELATIONSHIP_QUERY = "relationship_query"  # how entities connect/relate
    BROAD_THEMATIC = "broad_thematic"          # wide, theme-level question
    COMPARATIVE = "comparative"                # compare/contrast question
    COMPLEX_REASONING = "complex_reasoning"    # multi-step analysis
    FACTUAL_LOOKUP = "factual_lookup"          # simple fact retrieval (default)


class SearchStrategy(Enum):
    """Search strategy determined by DRIFT routing."""
    LOCAL_SEARCH = "local_search"
    GLOBAL_SEARCH = "global_search"
    HYBRID_SEARCH = "hybrid_search"


@dataclass
class QueryAnalysis:
    """Results of query analysis step."""
    query_type: QueryType
    complexity_score: float  # 0.0 to 1.0
    entities_mentioned: List[str]
    key_concepts: List[str]
    intent_description: str
    context_requirements: Dict[str, Any]
    estimated_scope: str  # "narrow", "moderate", "broad"


# Fix: the decorator was previously applied twice (duplicate "@dataclass"
# on consecutive lines) — a single application is sufficient and equivalent.
@dataclass
class DriftRoutingResult:
    """Results of DRIFT routing decision."""
    search_strategy: SearchStrategy
    reasoning: str
    confidence: float  # 0.0 to 1.0
    parameters: Dict[str, Any]
    original_query: str  # Added to fix answer generation
    fallback_strategy: Optional[SearchStrategy] = None
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
@dataclass
class VectorizedQuery:
    """Results of query vectorization."""
    embedding: List[float]          # dense embedding vector of the query
    embedding_model: str            # name of the model that produced it
    normalized_query: str           # cleaned/normalized query text
    semantic_keywords: List[str]    # keywords extracted for semantic matching
    similarity_threshold: float     # cutoff used for downstream similarity checks
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class QueryAnalyzer:
|
| 70 |
+
"""Handles Step 3: Query Analysis with intent detection and complexity assessment."""
|
| 71 |
+
|
| 72 |
+
def __init__(self, config: Any):
    """Initialize analyzer state: config, logger, entity regexes, and
    the keyword buckets that drive complexity scoring."""
    self.config = config
    self.logger = logging.getLogger('graphrag_query')

    # Regexes used to pull candidate named entities out of a query.
    patterns = [
        r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b',  # Proper nouns
        r'\b(?:company|organization|person|place|event)\s+(?:named|called)?\s*["\']?([^"\']+)["\']?',
        r'\bwho\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
        r'\bwhat\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)',
    ]
    self.entity_patterns = patterns

    # Keyword buckets used to nudge the complexity score up or down.
    self.complexity_indicators = {
        'high': ['compare', 'analyze', 'evaluate', 'relationship', 'impact', 'why', 'how'],
        'medium': ['describe', 'explain', 'summarize', 'list', 'identify'],
        'low': ['who', 'what', 'when', 'where', 'is', 'are']
    }

    self.logger.info("QueryAnalyzer initialized for Step 3 processing")
|
| 92 |
+
|
| 93 |
+
async def analyze_query(self, query: str) -> QueryAnalysis:
    """Run the full Step-3 analysis pipeline on a single query.

    Extracts entities and concepts, classifies the query, scores its
    complexity, infers intent/scope/context needs, and bundles everything
    into a QueryAnalysis. Failures are logged and re-raised.
    """
    self.logger.info(f"Starting Step 3: Query Analysis for: {query[:100]}...")

    try:
        found_entities = self._extract_entities(query)
        found_concepts = self._extract_key_concepts(query)
        kind = self._classify_query_type(query, found_entities, found_concepts)
        complexity = self._calculate_complexity(query, kind)
        intent = self._determine_intent(query, kind)
        scope = self._estimate_scope(query, found_entities, found_concepts, complexity)

        # Context requirements depend on type + entities found.
        context_needs = self._analyze_context_requirements(query, kind, found_entities)

        result = QueryAnalysis(
            query_type=kind,
            complexity_score=complexity,
            entities_mentioned=found_entities,
            key_concepts=found_concepts,
            intent_description=intent,
            context_requirements=context_needs,
            estimated_scope=scope
        )

        self.logger.info(f"Step 3 completed: Query type={kind.value}, "
                         f"complexity={complexity:.2f}, entities={len(found_entities)}, scope={scope}")

        return result

    except Exception as e:
        self.logger.error(f"Step 3 Query Analysis failed: {e}")
        raise
|
| 127 |
+
|
| 128 |
+
def _extract_entities(self, query: str) -> List[str]:
|
| 129 |
+
"""Extract named entities from query text."""
|
| 130 |
+
entities = set()
|
| 131 |
+
|
| 132 |
+
for pattern in self.entity_patterns:
|
| 133 |
+
matches = re.findall(pattern, query, re.IGNORECASE)
|
| 134 |
+
entities.update(matches)
|
| 135 |
+
|
| 136 |
+
# Filter entities
|
| 137 |
+
filtered_entities = [
|
| 138 |
+
entity.strip() for entity in entities
|
| 139 |
+
if len(entity.strip()) > 2 and entity.lower() not in
|
| 140 |
+
{'the', 'and', 'are', 'is', 'was', 'were', 'this', 'that', 'what', 'who', 'how'}
|
| 141 |
+
]
|
| 142 |
+
|
| 143 |
+
return list(set(filtered_entities))
|
| 144 |
+
|
| 145 |
+
def _extract_key_concepts(self, query: str) -> List[str]:
|
| 146 |
+
"""Extract key conceptual terms from query."""
|
| 147 |
+
# Extract concepts
|
| 148 |
+
concepts = []
|
| 149 |
+
|
| 150 |
+
# Find domain terms
|
| 151 |
+
domain_terms = [
|
| 152 |
+
'revenue', 'profit', 'growth', 'market', 'strategy', 'technology',
|
| 153 |
+
'product', 'service', 'customer', 'partnership', 'acquisition',
|
| 154 |
+
'investment', 'research', 'development', 'innovation', 'competition'
|
| 155 |
+
]
|
| 156 |
+
|
| 157 |
+
query_lower = query.lower()
|
| 158 |
+
for term in domain_terms:
|
| 159 |
+
if term in query_lower:
|
| 160 |
+
concepts.append(term)
|
| 161 |
+
|
| 162 |
+
return concepts
|
| 163 |
+
|
| 164 |
+
def _classify_query_type(self, query: str, entities: List[str], concepts: List[str]) -> QueryType:
    """Map a query onto a QueryType bucket used by downstream routing.

    Rules are evaluated in priority order: comparative phrasing wins
    over relationship phrasing, and so on down to the FACTUAL_LOOKUP
    default when nothing else matches.
    """
    text = query.lower()

    def mentions(*words: str) -> bool:
        return any(w in text for w in words)

    if mentions('compare', 'versus', 'vs', 'difference'):
        return QueryType.COMPARATIVE
    if mentions('relationship', 'connect', 'related', 'between'):
        return QueryType.RELATIONSHIP_QUERY
    if len(entities) > 0 and mentions('who is', 'what is', 'about'):
        return QueryType.SPECIFIC_ENTITY
    if mentions('analyze', 'evaluate', 'why', 'how', 'impact'):
        return QueryType.COMPLEX_REASONING
    if len(concepts) > 2 or mentions('overall', 'general', 'trend'):
        return QueryType.BROAD_THEMATIC
    return QueryType.FACTUAL_LOOKUP
|
| 185 |
+
|
| 186 |
+
def _calculate_complexity(self, query: str, query_type: QueryType) -> float:
    """Score query complexity on a 0.0-1.0 scale.

    Starts from a per-type baseline, adjusts by how many
    complexity-indicator phrases (from ``self.complexity_indicators``)
    appear in the text, then nudges for long or multi-question queries.
    The result is clamped to [0.0, 1.0].
    """
    baselines = {
        QueryType.FACTUAL_LOOKUP: 0.2,
        QueryType.SPECIFIC_ENTITY: 0.3,
        QueryType.RELATIONSHIP_QUERY: 0.6,
        QueryType.BROAD_THEMATIC: 0.7,
        QueryType.COMPARATIVE: 0.8,
        QueryType.COMPLEX_REASONING: 0.9,
    }
    score = baselines.get(query_type, 0.5)
    text = query.lower()

    # 'high' indicators raise the score most, 'medium' less; any other
    # level (e.g. 'low') lowers it slightly.
    per_hit_weight = {'high': 0.2, 'medium': 0.1}
    for level, indicators in self.complexity_indicators.items():
        hits = sum(1 for phrase in indicators if phrase in text)
        score += hits * per_hit_weight.get(level, -0.05)

    # Long queries and queries containing multiple '?' read as more complex.
    if len(query.split()) > 15:
        score += 0.1
    if '?' in query and len(query.split('?')) > 2:
        score += 0.15

    return min(1.0, max(0.0, score))
|
| 220 |
+
|
| 221 |
+
def _determine_intent(self, query: str, query_type: QueryType) -> str:
    """Translate the classified query type into a human-readable intent label."""
    descriptions = {
        QueryType.FACTUAL_LOOKUP: "Seeking specific factual information",
        QueryType.SPECIFIC_ENTITY: "Requesting details about a particular entity",
        QueryType.RELATIONSHIP_QUERY: "Exploring connections and relationships",
        QueryType.BROAD_THEMATIC: "Understanding broad themes or patterns",
        QueryType.COMPARATIVE: "Comparing entities or concepts",
        QueryType.COMPLEX_REASONING: "Requiring analytical reasoning and insights",
    }
    # Fall back to a generic label for any unmapped type.
    return descriptions.get(query_type, "General information seeking")
|
| 233 |
+
|
| 234 |
+
def _estimate_scope(self, query: str, entities: List[str], concepts: List[str], complexity: float) -> str:
|
| 235 |
+
"""Estimate the scope of information needed."""
|
| 236 |
+
if len(entities) == 1 and complexity < 0.4:
|
| 237 |
+
return "narrow"
|
| 238 |
+
elif len(entities) > 3 or len(concepts) > 3 or complexity > 0.7:
|
| 239 |
+
return "broad"
|
| 240 |
+
else:
|
| 241 |
+
return "moderate"
|
| 242 |
+
|
| 243 |
+
def _analyze_context_requirements(self, query: str, query_type: QueryType, entities: List[str]) -> Dict[str, Any]:
    """Describe which kinds of supporting context the query will need."""
    lowered = query.lower()
    wants_history = any(w in lowered for w in ('history', 'past', 'previous', 'before'))
    wants_numbers = any(w in lowered for w in ('number', 'amount', 'count', 'revenue', 'profit'))
    return {
        "requires_entity_details": len(entities) > 0,
        "requires_relationships": query_type in [QueryType.RELATIONSHIP_QUERY, QueryType.COMPARATIVE],
        "requires_historical_context": wants_history,
        "requires_quantitative_data": wants_numbers,
        "primary_entities": entities[:3]  # Focus on top 3 entities
    }
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class DriftRouter:
    """Handles Step 4: DRIFT Routing for optimal search strategy selection."""

    def __init__(self, config: Any, graph_stats: Dict[str, Any]):
        self.config = config
        self.graph_stats = graph_stats
        self.logger = logging.getLogger('graphrag_query')

        # Thresholds driving the local/global/hybrid decision in
        # _apply_drift_logic below.
        self.local_search_threshold = 0.4
        self.global_search_threshold = 0.7
        self.entity_count_threshold = 10  # Based on graph size

        self.logger.info("DriftRouter initialized for Step 4 processing")

    async def determine_search_strategy(self, query_analysis: QueryAnalysis, original_query: str) -> DriftRoutingResult:
        """
        Determine optimal search strategy using DRIFT methodology (Step 4).

        Args:
            query_analysis: Results from Step 3 query analysis
            original_query: The original user query

        Returns:
            DriftRoutingResult with search strategy and parameters
        """
        self.logger.info(f"Starting Step 4: DRIFT Routing for {query_analysis.query_type.value}")

        try:
            strategy, reasoning, confidence, params = self._apply_drift_logic(query_analysis)

            routing = DriftRoutingResult(
                search_strategy=strategy,
                reasoning=reasoning,
                confidence=confidence,
                parameters=params,
                original_query=original_query,
                fallback_strategy=self._determine_fallback_strategy(strategy),
            )

            self.logger.info(f"Step 4 completed: Strategy={strategy.value}, "
                             f"confidence={confidence:.2f}, reasoning={reasoning[:50]}...")
            return routing

        except Exception as e:
            self.logger.error(f"Step 4 DRIFT Routing failed: {e}")
            raise

    def _apply_drift_logic(self, analysis: QueryAnalysis) -> Tuple[SearchStrategy, str, float, Dict[str, Any]]:
        """Apply DRIFT (Distributed Retrieval and Information Filtering Technique) logic."""
        complexity = analysis.complexity_score
        entity_count = len(analysis.entities_mentioned)
        query_type = analysis.query_type

        # Rule 1: a focused entity lookup stays local.
        simple_entity_lookup = (
            query_type == QueryType.SPECIFIC_ENTITY
            and entity_count <= 2
            and complexity < self.local_search_threshold
        )
        if simple_entity_lookup:
            return (
                SearchStrategy.LOCAL_SEARCH,
                f"Specific entity query with low complexity ({complexity:.2f})",
                0.9,
                {
                    "max_depth": 2,
                    "entity_focus": analysis.entities_mentioned,
                    "include_neighbors": True,
                    "max_results": 20,
                },
            )

        # Rule 2: hard or broad questions need community-level context.
        needs_global_context = (
            complexity > self.global_search_threshold
            or analysis.estimated_scope == "broad"
            or query_type in [QueryType.BROAD_THEMATIC, QueryType.COMPLEX_REASONING]
        )
        if needs_global_context:
            return (
                SearchStrategy.GLOBAL_SEARCH,
                f"High complexity ({complexity:.2f}) or broad scope requiring global context",
                0.85,
                {
                    "community_level": "high",
                    "max_communities": 10,
                    "include_summary": True,
                    "max_results": 50,
                },
            )

        # Rule 3: relationship/comparative questions, or several entities,
        # get the blended strategy.
        if query_type in (QueryType.RELATIONSHIP_QUERY, QueryType.COMPARATIVE) or entity_count > 2:
            return (
                SearchStrategy.HYBRID_SEARCH,
                f"Relationship/comparative query or multiple entities ({entity_count})",
                0.75,
                {
                    "local_depth": 2,
                    "global_communities": 5,
                    "balance_weight": 0.6,  # Favor local over global
                    "max_results": 35,
                },
            )

        # Rule 4: default — local search at moderate confidence.
        return (
            SearchStrategy.LOCAL_SEARCH,
            "Default local search for moderate complexity query",
            0.6,
            {
                "max_depth": 3,
                "entity_focus": analysis.entities_mentioned,
                "include_neighbors": True,
                "max_results": 25,
            },
        )

    def _determine_fallback_strategy(self, primary_strategy: SearchStrategy) -> Optional[SearchStrategy]:
        """Determine fallback strategy if primary fails."""
        alternates = {
            SearchStrategy.LOCAL_SEARCH: SearchStrategy.GLOBAL_SEARCH,
            SearchStrategy.GLOBAL_SEARCH: SearchStrategy.LOCAL_SEARCH,
            SearchStrategy.HYBRID_SEARCH: SearchStrategy.LOCAL_SEARCH,
        }
        return alternates.get(primary_strategy)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
class QueryVectorizer:
    """Handles Step 5: Query Vectorization for semantic similarity matching."""

    def __init__(self, config: Any):
        self.config = config
        self.logger = logging.getLogger('graphrag_query')

        # Use the same embedding model/dimension as the ingestion side so
        # query vectors live in the same space as stored document vectors.
        self.model_name = MY_CONFIG.EMBEDDING_MODEL
        self.embedding_dimension = MY_CONFIG.EMBEDDING_LENGTH
        self.embedding_model = HuggingFaceEmbedding(
            model_name=MY_CONFIG.EMBEDDING_MODEL
        )

        self.logger.info(f"QueryVectorizer initialized with {self.model_name}")

    async def vectorize_query(self, query: str, query_analysis: QueryAnalysis) -> VectorizedQuery:
        """
        Generate query embeddings for similarity matching (Step 5).

        Args:
            query: Original query text
            query_analysis: Results from Step 3

        Returns:
            VectorizedQuery with embeddings and metadata
        """
        self.logger.info(f"Starting Step 5: Query Vectorization for: {query[:100]}...")

        try:
            enriched = self._normalize_query(query, query_analysis)
            vector = await self._generate_embedding(enriched)
            keywords = self._extract_semantic_keywords(query, query_analysis)
            threshold = self._calculate_similarity_threshold(query_analysis)

            vectorized = VectorizedQuery(
                embedding=vector,
                embedding_model=self.model_name,
                normalized_query=enriched,
                semantic_keywords=keywords,
                similarity_threshold=threshold,
            )

            self.logger.info(f"Step 5 completed: Embedding dimension={len(vector)}, "
                             f"threshold={threshold:.3f}, keywords={len(keywords)}")
            return vectorized

        except Exception as e:
            self.logger.error(f"Step 5 Query Vectorization failed: {e}")
            raise

    def _normalize_query(self, query: str, analysis: QueryAnalysis) -> str:
        """Enrich the stripped query with entity/concept hints before embedding."""
        enriched = query.strip()

        if analysis.entities_mentioned:
            entity_context = " ".join(analysis.entities_mentioned[:3])
            enriched = f"{enriched} [Entities: {entity_context}]"

        if analysis.key_concepts:
            concept_context = " ".join(analysis.key_concepts[:3])
            enriched = f"{enriched} [Concepts: {concept_context}]"

        return enriched

    async def _generate_embedding(self, text: str) -> List[float]:
        """Embed *text*; fall back to the synchronous API if the async call fails."""
        try:
            return await self.embedding_model.aget_text_embedding(text)
        except Exception as e:
            self.logger.error(f"Embedding generation failed: {e}")
            # Fallback to synchronous call if async fails
            return self.embedding_model.get_text_embedding(text)

    def _extract_semantic_keywords(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """Extract semantic keywords for additional matching."""
        keywords = set(analysis.entities_mentioned)
        keywords.update(analysis.key_concepts)

        # Type-specific vocabulary expands recall for those query shapes.
        type_terms = {
            QueryType.RELATIONSHIP_QUERY: ('relationship', 'connection', 'related', 'linked'),
            QueryType.COMPARATIVE: ('comparison', 'versus', 'difference', 'similar'),
            QueryType.BROAD_THEMATIC: ('theme', 'pattern', 'trend', 'overview'),
        }
        keywords.update(type_terms.get(analysis.query_type, ()))

        return [kw for kw in keywords if len(kw) > 2]

    def _calculate_similarity_threshold(self, analysis: QueryAnalysis) -> float:
        """Calculate appropriate similarity threshold based on query characteristics."""
        threshold = 0.7

        # Complex queries get a looser threshold, simple ones a stricter one.
        if analysis.complexity_score > 0.7:
            threshold -= 0.1
        elif analysis.complexity_score < 0.3:
            threshold += 0.1

        # Narrow scope tightens, broad scope loosens, by a small margin.
        if analysis.estimated_scope == "narrow":
            threshold += 0.05
        elif analysis.estimated_scope == "broad":
            threshold -= 0.05

        # Clamp to a sane operating range.
        return max(0.5, min(0.9, threshold))
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
class QueryPreprocessor:
    """Main class coordinating all query preprocessing steps (Steps 3-5)."""

    def __init__(self, config: Any, graph_stats: Dict[str, Any]):
        self.config = config
        self.graph_stats = graph_stats
        self.logger = logging.getLogger('graphrag_query')

        # One component per pipeline step: analysis, routing, vectorization.
        self.analyzer = QueryAnalyzer(config)
        self.router = DriftRouter(config, graph_stats)
        self.vectorizer = QueryVectorizer(config)

        self.logger.info("QueryPreprocessor initialized for Steps 3-5")

    async def preprocess_query(self, query: str) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
        """
        Execute complete query preprocessing pipeline (Steps 3-5).

        Args:
            query: User's natural language query

        Returns:
            Tuple of (analysis, routing, vectorization) results
        """
        self.logger.info(f"Starting Phase B: Query Preprocessing Pipeline for: {query[:100]}...")

        try:
            # Step 3 feeds Steps 4 and 5, which are otherwise independent.
            analysis = await self.analyzer.analyze_query(query)
            routing = await self.router.determine_search_strategy(analysis, query)
            vectorization = await self.vectorizer.vectorize_query(query, analysis)

            self.logger.info(f"Phase B completed successfully: "
                             f"Type={analysis.query_type.value}, "
                             f"Strategy={routing.search_strategy.value}, "
                             f"Embedding_dim={len(vectorization.embedding)}")

            return analysis, routing, vectorization

        except Exception as e:
            self.logger.error(f"Query preprocessing pipeline failed: {e}")
            raise
|
| 563 |
+
|
| 564 |
+
|
| 565 |
+
# Exports
|
| 566 |
+
async def create_query_preprocessor(config: Any, graph_stats: Dict[str, Any]) -> QueryPreprocessor:
    """Create and initialize QueryPreprocessor.

    Kept async for interface symmetry with the rest of the pipeline even
    though construction itself performs no awaitable work.
    """
    preprocessor = QueryPreprocessor(config, graph_stats)
    return preprocessor
|
| 569 |
+
|
| 570 |
+
|
| 571 |
+
async def preprocess_query_pipeline(query: str, config: Any, graph_stats: Dict[str, Any]) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
    """
    Convenience function for complete query preprocessing.

    Args:
        query: User's natural language query
        config: Application configuration
        graph_stats: Graph database statistics

    Returns:
        Complete preprocessing results
    """
    runner = await create_query_preprocessor(config, graph_stats)
    results = await runner.preprocess_query(query)
    return results
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Public API of this module: processor classes, factory helpers,
# result dataclasses, and routing enums (re-exported from their
# defining module for convenience).
__all__ = [
    'QueryAnalyzer', 'DriftRouter', 'QueryVectorizer', 'QueryPreprocessor',
    'create_query_preprocessor', 'preprocess_query_pipeline',
    'QueryAnalysis', 'DriftRoutingResult', 'VectorizedQuery',
    'QueryType', 'SearchStrategy'
]
|
query_graph_functions/response_management.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
"""Response management module for metadata generation and file I/O operations - Phase G (Steps 17-20)."""
|
| 3 |
+
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import logging
|
| 7 |
+
from typing import Dict, List, Any
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
|
| 11 |
+
from .setup import GraphRAGSetup
|
| 12 |
+
from .query_preprocessing import QueryAnalysis, DriftRoutingResult, VectorizedQuery
|
| 13 |
+
from .answer_synthesis import SynthesisResult
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
@dataclass
class ResponseMetadata:
    """Complete response metadata structure."""
    query_type: str                  # QueryType.value from Step 3 analysis
    search_strategy: str             # SearchStrategy.value chosen in Step 4
    complexity_score: float          # 0.0-1.0 score from query analysis
    total_time_seconds: float        # end-to-end pipeline wall-clock time
    phases_completed: List[str]      # e.g. ["A-Init", "B-Preprocess", ...]
    status: str                      # "success" or an error status string
    phase_details: Dict[str, Any]    # per-phase metadata dicts keyed by phase
    database_stats: Dict[str, Any]   # node/relationship/community counts etc.
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ResponseManager:
|
| 30 |
+
def __init__(self, setup: GraphRAGSetup):
|
| 31 |
+
self.setup = setup
|
| 32 |
+
self.config = setup.config
|
| 33 |
+
self.logger = logging.getLogger(self.__class__.__name__)
|
| 34 |
+
|
| 35 |
+
def generate_comprehensive_metadata(self,
|
| 36 |
+
analysis: QueryAnalysis,
|
| 37 |
+
routing: DriftRoutingResult,
|
| 38 |
+
vectorization: VectorizedQuery,
|
| 39 |
+
community_results: Dict[str, Any],
|
| 40 |
+
follow_up_results: Dict[str, Any],
|
| 41 |
+
augmentation_results: Any,
|
| 42 |
+
synthesis_results: SynthesisResult,
|
| 43 |
+
total_time: float) -> Dict[str, Any]:
|
| 44 |
+
"""
|
| 45 |
+
Generate comprehensive metadata for query response.
|
| 46 |
+
|
| 47 |
+
Consolidates all phase results into structured metadata format.
|
| 48 |
+
"""
|
| 49 |
+
try:
|
| 50 |
+
communities = community_results.get('communities', [])
|
| 51 |
+
|
| 52 |
+
metadata = {
|
| 53 |
+
# Execution Summary
|
| 54 |
+
"query_type": analysis.query_type.value,
|
| 55 |
+
"search_strategy": routing.search_strategy.value,
|
| 56 |
+
"complexity_score": analysis.complexity_score,
|
| 57 |
+
"total_time_seconds": round(total_time, 2),
|
| 58 |
+
"phases_completed": ["A-Init", "B-Preprocess", "C-Communities", "D-Followup", "E-Vector", "F-Synthesis"],
|
| 59 |
+
"status": "success",
|
| 60 |
+
|
| 61 |
+
# Phase A: Initialization
|
| 62 |
+
"phase_a": self._generate_phase_a_metadata(),
|
| 63 |
+
|
| 64 |
+
# Phase B: Query Preprocessing
|
| 65 |
+
"phase_b": self._generate_phase_b_metadata(analysis, vectorization, routing),
|
| 66 |
+
|
| 67 |
+
# Phase C: Community Search
|
| 68 |
+
"phase_c": self._generate_phase_c_metadata(communities, community_results),
|
| 69 |
+
|
| 70 |
+
# Phase D: Follow-up Search
|
| 71 |
+
"phase_d": self._generate_phase_d_metadata(follow_up_results),
|
| 72 |
+
|
| 73 |
+
# Phase E: Vector Augmentation
|
| 74 |
+
"phase_e": self._generate_phase_e_metadata(augmentation_results),
|
| 75 |
+
|
| 76 |
+
# Phase F: Answer Synthesis
|
| 77 |
+
"phase_f": self._generate_phase_f_metadata(synthesis_results),
|
| 78 |
+
|
| 79 |
+
# Database Statistics
|
| 80 |
+
"database_stats": self._generate_database_stats(follow_up_results, communities, augmentation_results)
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
self.logger.info("Generated comprehensive metadata with all phase details")
|
| 84 |
+
return metadata
|
| 85 |
+
|
| 86 |
+
except Exception as e:
|
| 87 |
+
self.logger.error(f"Failed to generate metadata: {e}")
|
| 88 |
+
return self._generate_fallback_metadata(str(e))
|
| 89 |
+
|
| 90 |
+
def _generate_phase_a_metadata(self) -> Dict[str, Any]:
|
| 91 |
+
"""Generate Phase A initialization metadata."""
|
| 92 |
+
from my_config import MY_CONFIG
|
| 93 |
+
|
| 94 |
+
return {
|
| 95 |
+
"neo4j_connected": bool(self.setup.neo4j_conn),
|
| 96 |
+
"vector_db_ready": bool(self.setup.query_engine),
|
| 97 |
+
"llm_model": getattr(MY_CONFIG, 'LLM_MODEL', 'unknown'),
|
| 98 |
+
"embedding_model": getattr(MY_CONFIG, 'EMBEDDING_MODEL', 'unknown'),
|
| 99 |
+
"drift_config_loaded": bool(self.setup.drift_config)
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
+
def _generate_phase_b_metadata(self, analysis: QueryAnalysis, vectorization: VectorizedQuery, routing: DriftRoutingResult) -> Dict[str, Any]:
|
| 103 |
+
"""Generate Phase B query preprocessing metadata."""
|
| 104 |
+
return {
|
| 105 |
+
"entities_extracted": len(analysis.entities_mentioned),
|
| 106 |
+
"semantic_keywords": len(vectorization.semantic_keywords),
|
| 107 |
+
"embedding_dimensions": len(vectorization.embedding),
|
| 108 |
+
"similarity_threshold": vectorization.similarity_threshold,
|
| 109 |
+
"routing_confidence": round(routing.confidence, 3)
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
def _generate_phase_c_metadata(self, communities: List[Any], community_results: Dict[str, Any]) -> Dict[str, Any]:
|
| 113 |
+
"""Generate Phase C community search metadata."""
|
| 114 |
+
return {
|
| 115 |
+
"communities_found": len(communities),
|
| 116 |
+
"community_ids": [c.community_id for c in communities[:5]],
|
| 117 |
+
"similarities": [round(c.similarity_score, 3) for c in communities[:5]],
|
| 118 |
+
"entities_extracted": len(community_results.get('extracted_data', {}).get('entities', [])),
|
| 119 |
+
"relationships_extracted": len(community_results.get('extracted_data', {}).get('relationships', []))
|
| 120 |
+
}
|
| 121 |
+
|
| 122 |
+
def _generate_phase_d_metadata(self, follow_up_results: Dict[str, Any]) -> Dict[str, Any]:
|
| 123 |
+
"""Generate Phase D follow-up search metadata."""
|
| 124 |
+
intermediate_answers = follow_up_results.get('intermediate_answers', [])
|
| 125 |
+
avg_confidence = 0.0
|
| 126 |
+
if intermediate_answers:
|
| 127 |
+
avg_confidence = sum(a.confidence for a in intermediate_answers) / len(intermediate_answers)
|
| 128 |
+
|
| 129 |
+
return {
|
| 130 |
+
"questions_generated": len(follow_up_results.get('follow_up_questions', [])),
|
| 131 |
+
"graph_traversals": len(follow_up_results.get('local_search_results', [])),
|
| 132 |
+
"entities_found": len(follow_up_results.get('detailed_entities', [])),
|
| 133 |
+
"intermediate_answers": len(intermediate_answers),
|
| 134 |
+
"avg_confidence": round(avg_confidence, 3)
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
def _generate_phase_e_metadata(self, augmentation_results: Any) -> Dict[str, Any]:
|
| 138 |
+
"""Generate Phase E vector augmentation metadata."""
|
| 139 |
+
if not augmentation_results:
|
| 140 |
+
return {"vector_results_count": 0, "augmentation_confidence": 0.0}
|
| 141 |
+
|
| 142 |
+
vector_files = []
|
| 143 |
+
if hasattr(augmentation_results, 'vector_results'):
|
| 144 |
+
for i, result in enumerate(augmentation_results.vector_results):
|
| 145 |
+
file_info = {
|
| 146 |
+
"file_id": i + 1,
|
| 147 |
+
"file_path": getattr(result, 'file_path', 'unknown'),
|
| 148 |
+
"similarity": round(result.similarity_score, 3),
|
| 149 |
+
"content_length": len(result.content),
|
| 150 |
+
"relevance": round(getattr(result, 'relevance_score', 0.0), 3)
|
| 151 |
+
}
|
| 152 |
+
vector_files.append(file_info)
|
| 153 |
+
|
| 154 |
+
return {
|
| 155 |
+
"vector_results_count": len(augmentation_results.vector_results) if hasattr(augmentation_results, 'vector_results') else 0,
|
| 156 |
+
"augmentation_confidence": round(augmentation_results.augmentation_confidence, 3) if hasattr(augmentation_results, 'augmentation_confidence') else 0.0,
|
| 157 |
+
"execution_time": round(augmentation_results.execution_time, 2) if hasattr(augmentation_results, 'execution_time') else 0.0,
|
| 158 |
+
"similarity_threshold": 0.75,
|
| 159 |
+
"vector_files": vector_files
|
| 160 |
+
}
|
| 161 |
+
|
| 162 |
+
def _generate_phase_f_metadata(self, synthesis_results: SynthesisResult) -> Dict[str, Any]:
|
| 163 |
+
"""Generate Phase F answer synthesis metadata."""
|
| 164 |
+
return {
|
| 165 |
+
"synthesis_confidence": round(synthesis_results.confidence_score, 3),
|
| 166 |
+
"sources_integrated": len(synthesis_results.source_evidence),
|
| 167 |
+
"final_answer_length": len(synthesis_results.final_answer),
|
| 168 |
+
"synthesis_method": getattr(synthesis_results, 'synthesis_method', 'comprehensive_fusion')
|
| 169 |
+
}
|
| 170 |
+
|
| 171 |
+
def _generate_database_stats(self, follow_up_results: Dict[str, Any], communities: List[Any], augmentation_results: Any) -> Dict[str, Any]:
|
| 172 |
+
"""Generate database statistics metadata."""
|
| 173 |
+
vector_docs_used = 0
|
| 174 |
+
if augmentation_results and hasattr(augmentation_results, 'vector_results'):
|
| 175 |
+
vector_docs_used = len(augmentation_results.vector_results)
|
| 176 |
+
|
| 177 |
+
return {
|
| 178 |
+
"total_nodes": self.setup.graph_stats.get('node_count', 0),
|
| 179 |
+
"total_relationships": self.setup.graph_stats.get('relationship_count', 0),
|
| 180 |
+
"total_communities": self.setup.graph_stats.get('community_count', 0),
|
| 181 |
+
"nodes_accessed": len(follow_up_results.get('detailed_entities', [])),
|
| 182 |
+
"communities_searched": len(communities),
|
| 183 |
+
"vector_docs_used": vector_docs_used
|
| 184 |
+
}
|
| 185 |
+
|
| 186 |
+
def _generate_fallback_metadata(self, error: str) -> Dict[str, Any]:
|
| 187 |
+
"""Generate minimal metadata when full generation fails."""
|
| 188 |
+
return {
|
| 189 |
+
"status": "metadata_generation_error",
|
| 190 |
+
"error": error,
|
| 191 |
+
"phases_completed": "incomplete",
|
| 192 |
+
"total_time_seconds": 0.0
|
| 193 |
+
}
|
| 194 |
+
|
| 195 |
+
def save_response_to_files(self, user_query: str, result: Dict[str, Any]) -> None:
|
| 196 |
+
"""
|
| 197 |
+
Save query response and metadata to separate files.
|
| 198 |
+
|
| 199 |
+
Handles file I/O operations for response persistence.
|
| 200 |
+
"""
|
| 201 |
+
try:
|
| 202 |
+
timestamp = time.strftime('%Y-%m-%d %H:%M:%S')
|
| 203 |
+
|
| 204 |
+
# Save response to response file
|
| 205 |
+
self._save_response_file(user_query, result, timestamp)
|
| 206 |
+
|
| 207 |
+
# Save metadata to metadata file
|
| 208 |
+
self._save_metadata_file(user_query, result, timestamp)
|
| 209 |
+
|
| 210 |
+
self.logger.info(f"Saved response and metadata for query: {user_query[:50]}...")
|
| 211 |
+
|
| 212 |
+
except Exception as e:
|
| 213 |
+
self.logger.error(f"Failed to save response files: {e}")
|
| 214 |
+
|
| 215 |
+
def _save_response_file(self, user_query: str, result: Dict[str, Any], timestamp: str) -> None:
|
| 216 |
+
"""Save response content to response file."""
|
| 217 |
+
try:
|
| 218 |
+
with open('logs/graphrag_query/graphrag_responses.txt', 'a', encoding='utf-8') as f:
|
| 219 |
+
f.write(f"\n{'='*80}\n")
|
| 220 |
+
f.write(f"QUERY [{timestamp}]: {user_query}\n")
|
| 221 |
+
f.write(f"{'='*80}\n")
|
| 222 |
+
f.write(f"RESPONSE: {result['answer']}\n")
|
| 223 |
+
f.write(f"{'='*80}\n\n")
|
| 224 |
+
except Exception as e:
|
| 225 |
+
self.logger.error(f"Failed to save response file: {e}")
|
| 226 |
+
|
| 227 |
+
def _save_metadata_file(self, user_query: str, result: Dict[str, Any], timestamp: str) -> None:
|
| 228 |
+
"""Save metadata to metadata file."""
|
| 229 |
+
try:
|
| 230 |
+
with open('logs/graphrag_query/graphrag_metadata.txt', 'a', encoding='utf-8') as f:
|
| 231 |
+
f.write(f"\n{'='*80}\n")
|
| 232 |
+
f.write(f"METADATA [{timestamp}]: {user_query}\n")
|
| 233 |
+
f.write(f"{'='*80}\n")
|
| 234 |
+
f.write(json.dumps(result['metadata'], indent=2, default=str))
|
| 235 |
+
f.write(f"\n{'='*80}\n\n")
|
| 236 |
+
except Exception as e:
|
| 237 |
+
self.logger.error(f"Failed to save metadata file: {e}")
|
| 238 |
+
|
| 239 |
+
def format_error_response(self, error_message: str) -> Dict[str, Any]:
    """
    Generate standardized error response with metadata.

    Creates a consistent error format for failed queries. The connectivity
    flags mirror GraphRAGSetup.get_system_status: neo4j_connected is True
    only when both the connection object and its driver exist (the previous
    `bool(x) if x else False` form reported True for a connection whose
    driver had failed to initialize).
    """
    return {
        "answer": f"Sorry, I encountered an error: {error_message}",
        "metadata": {
            "status": "error",
            "error_message": error_message,
            "phases_completed": "incomplete",
            "neo4j_connected": bool(self.setup.neo4j_conn and self.setup.neo4j_conn.driver),
            "vector_engine_ready": bool(self.setup.query_engine),
            "timestamp": datetime.now().isoformat()
        }
    }
# Exports — public API of this module
__all__ = ['ResponseManager', 'ResponseMetadata']
query_graph_functions/setup.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Graph setup module for database and model initialization. Phase A (Steps 1-2)"""
|
| 2 |
+
|
| 3 |
+
import os
|
| 4 |
+
import logging
|
| 5 |
+
from typing import Dict, Optional, Any
|
| 6 |
+
import sys
|
| 7 |
+
sys.path.append('..') # Add parent directory to path for imports
|
| 8 |
+
|
| 9 |
+
from my_config import MY_CONFIG
|
| 10 |
+
from neo4j import GraphDatabase
|
| 11 |
+
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
|
| 12 |
+
from llama_index.core import Settings, VectorStoreIndex, StorageContext
|
| 13 |
+
from llama_index.vector_stores.milvus import MilvusVectorStore
|
| 14 |
+
from llama_index.llms.litellm import LiteLLM
|
| 15 |
+
|
| 16 |
+
# Set up environment
|
| 17 |
+
os.environ['HF_ENDPOINT'] = MY_CONFIG.HF_ENDPOINT
|
| 18 |
+
|
| 19 |
+
# Configure logging
|
| 20 |
+
logging.basicConfig(level=logging.WARNING, format='%(asctime)s - %(levelname)s - %(message)s', force=True)
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
logger.setLevel(logging.INFO)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class Neo4jConnection:
    """
    Neo4j database connection manager.

    Reads connection settings from MY_CONFIG at construction time and
    raises ValueError immediately when a required value is missing, so a
    misconfigured environment fails fast instead of at first query.
    """

    def __init__(self):
        self.uri = MY_CONFIG.NEO4J_URI
        self.username = MY_CONFIG.NEO4J_USER
        self.password = MY_CONFIG.NEO4J_PASSWORD
        self.database = getattr(MY_CONFIG, "NEO4J_DATABASE", None)

        # Validate required configuration (messages name the actual config
        # attributes; the username message previously said NEO4J_USERNAME
        # while the code reads MY_CONFIG.NEO4J_USER)
        if not self.uri:
            raise ValueError("NEO4J_URI config is required")
        if not self.username:
            raise ValueError("NEO4J_USER config is required")
        if not self.password:
            raise ValueError("NEO4J_PASSWORD config is required")
        if not self.database:
            raise ValueError("NEO4J_DATABASE config is required")

        # Lazily created by connect(). The previous annotation
        # Optional[GraphDatabase.driver] named a factory method, not a type;
        # keep it generic since the neo4j Driver class is not imported here.
        self.driver: Optional[Any] = None

    def connect(self):
        """STEP 1.2: Initialize Neo4j driver with verification"""
        if self.driver is None:
            try:
                self.driver = GraphDatabase.driver(
                    self.uri,
                    auth=(self.username, self.password)
                )
                # Raises if the server is unreachable or auth fails
                self.driver.verify_connectivity()
                logger.info(f"Connected to Neo4j at {self.uri}")
            except Exception as e:
                logger.error(f"❌ STEP 1.2 FAILED: Neo4j connection error: {e}")
                # Leave the object usable; callers check .driver for success
                self.driver = None

    def disconnect(self):
        """Clean up Neo4j connection"""
        if self.driver:
            self.driver.close()
            self.driver = None
            logger.info("Neo4j connection closed")

    def execute_query(self, query: str, parameters: Optional[Dict[str, Any]] = None):
        """Execute Cypher query with error handling.

        Args:
            query: Cypher statement to run.
            parameters: Optional query parameters.

        Returns:
            List of result records as plain dicts.

        Raises:
            ConnectionError: If connect() has not succeeded yet.
        """
        if not self.driver:
            raise ConnectionError("Not connected to Neo4j database")

        # Session is scoped to the configured database and closed by the with-block
        with self.driver.session(database=self.database) as session:
            result = session.run(query, parameters or {})
            records = [record.data() for record in result]
            return records
class GraphRAGSetup:
    """
    Main setup class for graph-based retrieval system.

    Handles core initialization and configuration:
    - Database connections (Neo4j and vector database)
    - Model initialization and configuration
    - Graph statistics and validation
    - Search configuration loading

    Each setup stage is best-effort: failures are logged and leave the
    corresponding attribute None/empty instead of aborting __init__.
    """

    def __init__(self):
        logger.info("Starting graph system initialization")

        # Initialize core components
        self.config = MY_CONFIG  # Add config attribute for GraphQueryEngine
        self.neo4j_conn = None       # Neo4jConnection once _setup_neo4j succeeds
        self.query_engine = None     # vector query engine once _setup_vector_search succeeds
        self.graph_stats = {}        # node/relationship/community counts
        self.drift_config = {}       # DRIFT metadata loaded from Neo4j
        self.llm = None
        self.embedding_model = None

        # Execute Step 1 initialization sequence
        self._execute_step1_sequence()

        logger.info("Graph system initialization complete")

    def _execute_step1_sequence(self):
        """Execute complete Step 1 initialization sequence"""
        # STEP 1.1-1.6: Initialize all components
        self._setup_neo4j()               # STEP 1.2
        self._setup_vector_search()       # STEP 1.3-1.6
        self._load_graph_statistics()     # STEP 2.1-2.4
        self._load_drift_configuration()  # STEP 2.5

    def _setup_neo4j(self):
        """STEP 1.2: Initialize Neo4j driver with verification"""
        try:
            logger.info("Initializing Neo4j connection...")
            self.neo4j_conn = Neo4jConnection()
            self.neo4j_conn.connect()

            # Verify connection with test query
            if self.neo4j_conn.driver:
                test_result = self.neo4j_conn.execute_query("MATCH (n) RETURN count(n) as total_nodes LIMIT 1")
                node_count = test_result[0]['total_nodes'] if test_result else 0
                logger.info(f"Neo4j connected - {node_count} nodes found")

        except Exception as e:
            logger.error(f"Neo4j connection error: {e}")
            self.neo4j_conn = None

    def _setup_vector_search(self):
        """STEP 1.3-1.6: Initialize embedding model, vector database and LLM"""
        try:
            logger.info("Setting up vector search and LLM...")

            # STEP 1.5: Load embedding model and register it globally
            self.embedding_model = HuggingFaceEmbedding(
                model_name=MY_CONFIG.EMBEDDING_MODEL
            )
            Settings.embed_model = self.embedding_model
            logger.info(f"Embedding model loaded: {MY_CONFIG.EMBEDDING_MODEL}")

            # STEP 1.6: Connect to vector database based on configuration
            if MY_CONFIG.VECTOR_DB_TYPE == "cloud_zilliz":
                if not MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT or not MY_CONFIG.ZILLIZ_TOKEN:
                    raise ValueError("Cloud database configuration missing. Set ZILLIZ_CLUSTER_ENDPOINT and ZILLIZ_TOKEN in .env")

                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.ZILLIZ_CLUSTER_ENDPOINT,
                    token=MY_CONFIG.ZILLIZ_TOKEN,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False  # never clobber an existing collection
                )
                storage_context = StorageContext.from_defaults(vector_store=vector_store)
                logger.info("Connected to cloud vector database")
            else:
                vector_store = MilvusVectorStore(
                    uri=MY_CONFIG.MILVUS_URI_HYBRID_GRAPH,
                    dim=MY_CONFIG.EMBEDDING_LENGTH,
                    collection_name=MY_CONFIG.COLLECTION_NAME,
                    overwrite=False
                )
                storage_context = StorageContext.from_defaults(vector_store=vector_store)
                logger.info("Connected to local vector database")

            index = VectorStoreIndex.from_vector_store(
                vector_store=vector_store, storage_context=storage_context)
            logger.info("Vector index loaded successfully")

            # STEP 1.4: Initialize LLM provider and register it globally
            llm_model = MY_CONFIG.LLM_MODEL
            self.llm = LiteLLM(model=llm_model)
            Settings.llm = self.llm
            logger.info(f"LLM initialized: {llm_model}")

            self.query_engine = index.as_query_engine()

        except Exception as e:
            logger.error(f"Vector setup error: {e}")
            self.query_engine = None

    def _load_graph_statistics(self):
        """STEP 2.1-2.4: Load and validate graph data structure"""
        try:
            logger.info("Loading graph statistics and validation...")

            if not self.neo4j_conn or not self.neo4j_conn.driver:
                logger.warning("No Neo4j connection for statistics")
                return

            # STEP 2.1: Get node and relationship counts
            stats_query = """
            MATCH (n)
            OPTIONAL MATCH ()-[r]-()
            RETURN count(DISTINCT n) as node_count,
                   count(DISTINCT r) as relationship_count,
                   count(DISTINCT n.community_id) as community_count
            """

            result = self.neo4j_conn.execute_query(stats_query)
            if result:
                stats = result[0]
                self.graph_stats = {
                    'node_count': stats.get('node_count', 0),
                    'relationship_count': stats.get('relationship_count', 0),
                    'community_count': stats.get('community_count', 0)
                }

                logger.info(f"Graph validated - {self.graph_stats['node_count']} nodes, "
                            f"{self.graph_stats['relationship_count']} relationships, "
                            f"{self.graph_stats['community_count']} communities")

        except Exception as e:
            logger.error(f"Graph statistics error: {e}")
            self.graph_stats = {}

    def _load_drift_configuration(self):
        """STEP 2.5: Load DRIFT search metadata and configuration.

        Wrapped in try/except (consistent with _load_graph_statistics) so a
        query failure degrades to an empty config instead of aborting
        __init__ — the query was previously unguarded.
        """
        logger.info("Loading search configuration...")

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.warning("No Neo4j connection for search configuration")
            self.drift_config = {}
            return

        try:
            # Query for all DRIFT-related nodes
            drift_metadata_query = """
            OPTIONAL MATCH (dm:DriftMetadata)
            OPTIONAL MATCH (dc:DriftConfiguration)
            OPTIONAL MATCH (csi:CommunitySearchIndex)
            OPTIONAL MATCH (gm:GraphMetadata)
            OPTIONAL MATCH (cm:CommunitiesMetadata)
            RETURN dm, dc, csi, gm, cm
            """

            result = self.neo4j_conn.execute_query(drift_metadata_query)
            if result and result[0]:
                record = result[0]
                drift_config = {}

                # Extract DriftMetadata properties (merged at the top level)
                if record.get('dm'):
                    drift_config.update(dict(record['dm']))
                    logger.info("DriftMetadata node found")

                # Extract DriftConfiguration properties
                if record.get('dc'):
                    drift_config['configuration'] = dict(record['dc'])
                    logger.info("DriftConfiguration node found")

                # Extract CommunitySearchIndex properties
                if record.get('csi'):
                    drift_config['community_search_index'] = dict(record['csi'])
                    logger.info("CommunitySearchIndex node found")

                # Extract GraphMetadata properties
                if record.get('gm'):
                    drift_config['graph_metadata'] = dict(record['gm'])
                    logger.info("GraphMetadata node found")

                # Extract CommunitiesMetadata properties
                if record.get('cm'):
                    drift_config['communities_metadata'] = dict(record['cm'])
                    logger.info("CommunitiesMetadata node found")

                self.drift_config = drift_config
                logger.info("Search configuration loaded from Neo4j nodes")

            else:
                logger.warning("No metadata nodes found in Neo4j")
                self.drift_config = {}

        except Exception as e:
            logger.error(f"Search configuration error: {e}")
            self.drift_config = {}

    def validate_system_readiness(self):
        """Validate all required components are initialized.

        Returns:
            True when both the Neo4j driver and the vector query engine are
            available. Missing graph statistics only warn — they are
            informational and do not block readiness.
        """
        ready = True

        if not self.neo4j_conn or not self.neo4j_conn.driver:
            logger.error("Neo4j connection not available")
            ready = False

        if not self.query_engine:
            logger.error("Vector query engine not available")
            ready = False

        if not self.graph_stats:
            logger.warning("Graph statistics not loaded")

        if ready:
            logger.info("System readiness validated")

        return ready

    def get_system_status(self):
        """Get detailed system status information as a plain dict."""
        return {
            "neo4j_connected": bool(self.neo4j_conn and self.neo4j_conn.driver),
            "vector_engine_ready": bool(self.query_engine),
            "graph_stats_loaded": bool(self.graph_stats),
            "drift_config_loaded": bool(self.drift_config),
            "llm_ready": bool(self.llm),
            "graph_stats": self.graph_stats,
            "drift_config": self.drift_config
        }

    async def cleanup_async_tasks(self, timeout: float = 2.0) -> None:
        """
        Clean up async tasks and pending operations.

        Handles proper cleanup of LiteLLM and other async tasks to prevent
        'Task was destroyed but it is pending!' warnings.
        """
        try:
            import asyncio

            # Prefer the project-provided cleanup helper if available
            try:
                from litellm_patch import cleanup_all_async_tasks
                await cleanup_all_async_tasks(timeout=timeout)
                logger.info(f"Cleaned up async tasks with timeout {timeout}s")
            except ImportError:
                # Fallback: Cancel pending tasks manually
                pending_tasks = [task for task in asyncio.all_tasks() if not task.done()]
                if pending_tasks:
                    logger.info(f"Cancelling {len(pending_tasks)} pending tasks")
                    for task in pending_tasks:
                        task.cancel()

                    # Wait for cancellation with timeout; gather swallows the
                    # CancelledErrors via return_exceptions=True
                    try:
                        await asyncio.wait_for(
                            asyncio.gather(*pending_tasks, return_exceptions=True),
                            timeout=timeout
                        )
                    except asyncio.TimeoutError:
                        logger.warning("Some tasks did not complete within timeout")

        except Exception as e:
            logger.error(f"Error during async cleanup: {e}")

    def close(self):
        """Clean up all connections"""
        if self.neo4j_conn:
            self.neo4j_conn.disconnect()
        logger.info("Setup cleanup complete")
def create_graphrag_setup():
    """Factory helper: build and return a fully initialized GraphRAGSetup."""
    return GraphRAGSetup()
+
|
| 360 |
+
# Exports
|
| 361 |
+
__all__ = ['GraphRAGSetup', 'create_graphrag_setup']
|
query_graph_functions/vector_augmentation.py
ADDED
|
@@ -0,0 +1,262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Vector augmentation engine implementing Phase E (Steps 13-14).
|
| 3 |
+
|
| 4 |
+
Handles vector search operations and result fusion:
|
| 5 |
+
- Vector similarity search for additional context (Step 13)
|
| 6 |
+
- Result fusion strategy for enhanced answers (Step 14)
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import logging
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import Dict, List, Any, Tuple
|
| 12 |
+
from dataclasses import dataclass
|
| 13 |
+
from datetime import datetime
|
| 14 |
+
|
| 15 |
+
from .setup import GraphRAGSetup
|
| 16 |
+
from .query_preprocessing import DriftRoutingResult
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class VectorSearchResult:
    """Vector search result with similarity and content."""
    document_id: str              # source doc id from node metadata, else "doc_<i>"
    content: str                  # raw text of the retrieved chunk
    similarity_score: float       # score reported by the vector store (0.8 default when absent)
    metadata: Dict[str, Any]      # node metadata / extra_info from the vector store
    source_type: str  # 'vector_db', 'semantic_search'
    relevance_score: float        # similarity_score * 0.9 (down-weighted vs graph results)
@dataclass
class AugmentationResult:
    """Phase E augmentation result with enhanced context."""
    vector_results: List[VectorSearchResult]  # hits kept after thresholding (Step 13)
    enhanced_context: str                     # fused graph + vector context text (Step 14)
    fusion_strategy: str                      # 'graph_vector_hybrid', or 'graph_only' on failure
    augmentation_confidence: float            # 0.0-1.0 blend of avg similarity and result count
    execution_time: float                     # wall-clock seconds for Steps 13-14
    metadata: Dict[str, Any]                  # counts, avg similarity, phase/step markers
class VectorAugmentationEngine:
    """Phase E engine (Steps 13-14): vector similarity search plus fusion of
    vector hits with graph-based results into one enhanced context string."""

    def __init__(self, setup: GraphRAGSetup):
        self.setup = setup
        self.vector_engine = setup.query_engine  # Milvus vector engine
        self.embedding_model = setup.embedding_model
        self.config = setup.config
        self.logger = logging.getLogger(self.__class__.__name__)

        # Vector search parameters
        self.similarity_threshold = 0.75  # minimum score for a hit to be kept
        self.max_vector_results = 10      # cap on hits taken from the engine

    async def execute_vector_augmentation_phase(self,
                                                query_embedding: List[float],
                                                graph_results: Dict[str, Any],
                                                routing_result: DriftRoutingResult) -> AugmentationResult:
        """
        Execute vector augmentation phase with similarity search.

        Args:
            query_embedding: Query vector for similarity matching
            graph_results: Results from graph-based search
            routing_result: Routing decision parameters

        Returns:
            Augmentation results with vector context; on any failure a
            'graph_only' fallback result with zero confidence is returned.
        """
        start_time = datetime.now()

        try:
            # Step 13: Vector Similarity Search
            self.logger.info("Starting Step 13: Vector Similarity Search")
            vector_results = await self._perform_vector_search(
                query_embedding, routing_result
            )

            # Step 14: Result Fusion and Enhancement
            self.logger.info("Starting Step 14: Result Fusion and Enhancement")
            enhanced_context = await self._fuse_results(
                vector_results, graph_results, routing_result
            )

            execution_time = (datetime.now() - start_time).total_seconds()

            augmentation_result = AugmentationResult(
                vector_results=vector_results,
                enhanced_context=enhanced_context,
                fusion_strategy='graph_vector_hybrid',
                augmentation_confidence=self._calculate_augmentation_confidence(vector_results),
                execution_time=execution_time,
                metadata={
                    'vector_results_count': len(vector_results),
                    'avg_similarity': np.mean([r.similarity_score for r in vector_results]) if vector_results else 0,
                    'phase': 'vector_augmentation',
                    'step_range': '13-14'
                }
            )

            self.logger.info(f"Phase E completed: {len(vector_results)} vector results, augmentation confidence: {augmentation_result.augmentation_confidence:.3f}")
            return augmentation_result

        except Exception as e:
            self.logger.error(f"Vector augmentation phase failed: {e}")
            # Return empty augmentation on failure
            return AugmentationResult(
                vector_results=[],
                enhanced_context="",
                fusion_strategy='graph_only',
                augmentation_confidence=0.0,
                execution_time=(datetime.now() - start_time).total_seconds(),
                metadata={'error': str(e), 'fallback': True}
            )

    async def _perform_vector_search(self,
                                     query_embedding: List[float],
                                     routing_result: DriftRoutingResult) -> List[VectorSearchResult]:
        """
        Step 13: Perform comprehensive vector similarity search.

        Uses the Milvus vector database to find semantically similar content.
        Only hits at or above ``similarity_threshold`` are returned.
        """
        try:
            vector_results = []

            # Use the existing vector query engine for similarity search
            if self.vector_engine:
                # Query the vector database with the original query text
                search_results = self.vector_engine.query(routing_result.original_query)

                # Extract vector search results from the response
                if hasattr(search_results, 'source_nodes') and search_results.source_nodes:
                    for i, node in enumerate(search_results.source_nodes[:self.max_vector_results]):
                        # Determine similarity score. NodeWithScore.score may be
                        # None in LlamaIndex; the previous hasattr-only checks
                        # let None through and crashed the threshold comparison.
                        similarity_score = 0.8  # Default similarity
                        if getattr(node, 'score', None) is not None:
                            similarity_score = node.score
                        elif getattr(node, 'similarity', None) is not None:
                            similarity_score = node.similarity
                        elif hasattr(node, 'metadata') and node.metadata and 'score' in node.metadata:
                            similarity_score = node.metadata['score']

                        # Only include results above similarity threshold
                        if similarity_score < self.similarity_threshold:
                            continue

                        # Extract content (handle different node types)
                        if hasattr(node, 'text'):
                            content = node.text
                        elif hasattr(node, 'content'):
                            content = node.content
                        elif hasattr(node, 'get_content'):
                            content = node.get_content()
                        else:
                            content = str(node)

                        # Extract metadata safely
                        node_metadata = {}
                        if hasattr(node, 'metadata') and node.metadata:
                            node_metadata = node.metadata
                        elif hasattr(node, 'extra_info') and node.extra_info:
                            node_metadata = node.extra_info

                        vector_results.append(VectorSearchResult(
                            document_id=node_metadata.get('doc_id', f"doc_{i}"),
                            content=content,
                            similarity_score=similarity_score,
                            metadata=node_metadata,
                            source_type='vector_db',
                            relevance_score=similarity_score * 0.9  # Slightly weighted down
                        ))

                self.logger.info(f"Vector search completed: {len(vector_results)} results above threshold {self.similarity_threshold}")
            else:
                self.logger.warning("Vector engine not available, skipping vector search")

            return vector_results

        except Exception as e:
            self.logger.error(f"Vector search failed: {e}")
            return []

    async def _fuse_results(self,
                            vector_results: List[VectorSearchResult],
                            graph_results: Dict[str, Any],
                            routing_result: DriftRoutingResult) -> str:
        """
        Step 14: Fuse vector and graph results for enhanced context.

        Combines graph-based entity relationships with vector similarity content.
        """
        try:
            fusion_parts = []

            # Start with graph-based context (Phase C & D results)
            if 'initial_answer' in graph_results:
                initial_answer = graph_results['initial_answer']
                if isinstance(initial_answer, dict) and 'content' in initial_answer:
                    fusion_parts.extend([
                        "=== GRAPH-BASED KNOWLEDGE ===",
                        initial_answer['content'],
                        ""
                    ])

            # Add vector-based augmentation
            if vector_results:
                fusion_parts.extend([
                    "=== SEMANTIC AUGMENTATION ===",
                    "Additional relevant information from vector similarity search:",
                    ""
                ])

                for i, result in enumerate(vector_results[:5], 1):  # Top 5 vector results
                    fusion_parts.extend([
                        f"**{i}. Vector Result (Similarity: {result.similarity_score:.3f})**",
                        result.content,  # Show full content without truncation
                        ""
                    ])

            # Add fusion methodology explanation
            fusion_parts.extend([
                "=== FUSION METHODOLOGY ===",
                "This enhanced answer combines graph-based entity relationships with vector semantic similarity search.",
                "Graph results provide structured knowledge connections, while vector search adds contextual depth.",
                ""
            ])

            enhanced_context = "\n".join(fusion_parts)

            self.logger.info(f"Result fusion completed: {len(fusion_parts)} context sections")
            return enhanced_context

        except Exception as e:
            self.logger.error(f"Result fusion failed: {e}")
            return "Graph-based results only (vector fusion failed)"

    def _calculate_augmentation_confidence(self, vector_results: List[VectorSearchResult]) -> float:
        """Calculate confidence score for the augmentation results.

        Blend: 70% average similarity, 30% result-count factor (count
        normalized against 10 results), clamped to 1.0.
        """
        if not vector_results:
            return 0.0

        # Base confidence on average similarity and result count
        avg_similarity = np.mean([r.similarity_score for r in vector_results])
        count_factor = min(len(vector_results) / 10, 1.0)  # Normalize to max 10 results

        # Combined confidence
        confidence = (avg_similarity * 0.7) + (count_factor * 0.3)

        return min(confidence, 1.0)

    def get_augmentation_stats(self) -> Dict[str, Any]:
        """Get statistics about vector augmentation performance."""
        return {
            'similarity_threshold': self.similarity_threshold,
            'max_vector_results': self.max_vector_results,
            'vector_engine_ready': bool(self.vector_engine),
            'embedding_model': str(self.embedding_model) if self.embedding_model else None
        }
|
| 261 |
+
# Export main class and its result dataclasses (public API of this module)
__all__ = ['VectorAugmentationEngine', 'VectorSearchResult', 'AugmentationResult']
|