# NOTE(review): removed non-Python artifacts that preceded this module
# (scraped page chrome: "Spaces:", "Runtime error", file size, commit hash,
# and a dumped line-number gutter). They were not part of the source and
# made the file syntactically invalid.
"""Query preprocessing for analysis, routing, and vectorization - Phase B (Steps 3-5)."""
import logging
from typing import Dict, List, Any, Tuple, Optional
from dataclasses import dataclass
from enum import Enum
import re
# System imports
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from my_config import MY_CONFIG
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
class QueryType(Enum):
    """Query type classifications for DRIFT routing."""
    SPECIFIC_ENTITY = "specific_entity"        # details about one named entity
    RELATIONSHIP_QUERY = "relationship_query"  # connections between entities
    BROAD_THEMATIC = "broad_thematic"          # broad themes or patterns
    COMPARATIVE = "comparative"                # comparing entities or concepts
    COMPLEX_REASONING = "complex_reasoning"    # analytical "why"/"how" questions
    FACTUAL_LOOKUP = "factual_lookup"          # simple factual lookup (default)
class SearchStrategy(Enum):
    """Search strategy determined by DRIFT routing."""
    LOCAL_SEARCH = "local_search"    # entity-focused neighborhood search
    GLOBAL_SEARCH = "global_search"  # community-level, corpus-wide search
    HYBRID_SEARCH = "hybrid_search"  # combines local depth with global communities
@dataclass
class QueryAnalysis:
    """Results of query analysis step."""
    query_type: QueryType                 # classification driving DRIFT routing
    complexity_score: float               # 0.0 to 1.0
    entities_mentioned: List[str]         # named entities extracted from the query
    key_concepts: List[str]               # domain concepts found in the query
    intent_description: str               # human-readable summary of user intent
    context_requirements: Dict[str, Any]  # flags describing what context is needed
    estimated_scope: str                  # "narrow", "moderate", "broad"
@dataclass
class DriftRoutingResult:
    """Results of DRIFT routing decision.

    Produced by DriftRouter (Step 4) and consumed by the search/answer
    stages downstream.
    """
    # NOTE: the original source stacked @dataclass twice on this class; the
    # redundant duplicate decorator has been removed.
    search_strategy: SearchStrategy                     # chosen strategy
    reasoning: str                                      # why this strategy was picked
    confidence: float                                   # 0.0 to 1.0
    parameters: Dict[str, Any]                          # strategy-specific search parameters
    original_query: str                                 # Added to fix answer generation
    fallback_strategy: Optional[SearchStrategy] = None  # used if primary strategy fails
@dataclass
class VectorizedQuery:
    """Results of query vectorization."""
    embedding: List[float]        # embedding vector for the normalized query
    embedding_model: str          # name of the model that produced the embedding
    normalized_query: str         # query text augmented with entity/concept hints
    semantic_keywords: List[str]  # extra keywords for lexical matching
    similarity_threshold: float   # similarity cutoff, clamped to [0.5, 0.9]
class QueryAnalyzer:
    """Handles Step 3: Query Analysis with intent detection and complexity assessment."""

    def __init__(self, config: Any):
        """Set up extraction patterns and complexity heuristics.

        Args:
            config: Application configuration (stored for parity with the
                other processors; not read directly by this class).
        """
        self.config = config
        self.logger = logging.getLogger('graphrag_query')
        # Entity extraction patterns, pre-compiled with per-pattern flags.
        # BUGFIX: the original code passed re.IGNORECASE to every findall
        # call, which defeated the capitalization-based proper-noun pattern
        # and made it match virtually every word in the query. Case
        # sensitivity is now chosen per pattern.
        self.entity_patterns = [
            # Proper nouns: runs of capitalized words (case-sensitive on purpose)
            re.compile(r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b'),
            # "company named X" style phrases (keywords are case-insensitive)
            re.compile(
                r'\b(?:company|organization|person|place|event)\s+(?:named|called)?\s*["\']?([^"\']+)["\']?',
                re.IGNORECASE),
            # "Who is X" / "What is X" with a capitalized subject
            re.compile(r'\b[Ww]ho\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)'),
            re.compile(r'\b[Ww]hat\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)'),
        ]
        # Keyword lists that nudge the complexity score up or down.
        self.complexity_indicators = {
            'high': ['compare', 'analyze', 'evaluate', 'relationship', 'impact', 'why', 'how'],
            'medium': ['describe', 'explain', 'summarize', 'list', 'identify'],
            'low': ['who', 'what', 'when', 'where', 'is', 'are']
        }
        self.logger.info("QueryAnalyzer initialized for Step 3 processing")

    async def analyze_query(self, query: str) -> QueryAnalysis:
        """Analyze query for intent, complexity, and entities.

        Args:
            query: The user's natural-language query.

        Returns:
            QueryAnalysis bundling query type, complexity, entities, concepts,
            intent description, context requirements and estimated scope.

        Raises:
            Exception: re-raises any failure from the helper steps after
                logging it.
        """
        self.logger.info(f"Starting Step 3: Query Analysis for: {query[:100]}...")
        try:
            # Extract entities and concepts
            entities = self._extract_entities(query)
            concepts = self._extract_key_concepts(query)
            query_type = self._classify_query_type(query, entities, concepts)
            complexity = self._calculate_complexity(query, query_type)
            intent = self._determine_intent(query, query_type)
            scope = self._estimate_scope(query, entities, concepts, complexity)
            # Build context requirements
            context_reqs = self._analyze_context_requirements(query, query_type, entities)
            analysis = QueryAnalysis(
                query_type=query_type,
                complexity_score=complexity,
                entities_mentioned=entities,
                key_concepts=concepts,
                intent_description=intent,
                context_requirements=context_reqs,
                estimated_scope=scope
            )
            self.logger.info(f"Step 3 completed: Query type={query_type.value}, "
                             f"complexity={complexity:.2f}, entities={len(entities)}, scope={scope}")
            return analysis
        except Exception as e:
            self.logger.error(f"Step 3 Query Analysis failed: {e}")
            raise

    def _extract_entities(self, query: str) -> List[str]:
        """Extract named entities from query text.

        Returns:
            De-duplicated entity strings longer than two characters that are
            not common stop-words. Order is not guaranteed (set-based).
        """
        entities = set()
        for pattern in self.entity_patterns:
            # Compiled patterns carry their own flags (see __init__).
            entities.update(pattern.findall(query))
        # Filter out short tokens and common stop-words.
        stopwords = {'the', 'and', 'are', 'is', 'was', 'were', 'this', 'that', 'what', 'who', 'how'}
        filtered_entities = [
            entity.strip() for entity in entities
            if len(entity.strip()) > 2 and entity.lower() not in stopwords
        ]
        return list(set(filtered_entities))

    def _extract_key_concepts(self, query: str) -> List[str]:
        """Extract key conceptual terms from query.

        Matches against a fixed vocabulary of business/domain terms; returns
        the terms found, in vocabulary order.
        """
        concepts = []
        # Fixed domain vocabulary checked by substring match.
        domain_terms = [
            'revenue', 'profit', 'growth', 'market', 'strategy', 'technology',
            'product', 'service', 'customer', 'partnership', 'acquisition',
            'investment', 'research', 'development', 'innovation', 'competition'
        ]
        query_lower = query.lower()
        for term in domain_terms:
            if term in query_lower:
                concepts.append(term)
        return concepts

    def _classify_query_type(self, query: str, entities: List[str], concepts: List[str]) -> QueryType:
        """Classify the type of query for routing decisions.

        Checks are ordered by specificity; the first matching rule wins, and
        FACTUAL_LOOKUP is the fallback.
        """
        query_lower = query.lower()
        if any(word in query_lower for word in ['compare', 'versus', 'vs', 'difference']):
            return QueryType.COMPARATIVE
        if any(word in query_lower for word in ['relationship', 'connect', 'related', 'between']):
            return QueryType.RELATIONSHIP_QUERY
        if len(entities) > 0 and any(word in query_lower for word in ['who is', 'what is', 'about']):
            return QueryType.SPECIFIC_ENTITY
        if any(word in query_lower for word in ['analyze', 'evaluate', 'why', 'how', 'impact']):
            return QueryType.COMPLEX_REASONING
        if len(concepts) > 2 or any(word in query_lower for word in ['overall', 'general', 'trend']):
            return QueryType.BROAD_THEMATIC
        return QueryType.FACTUAL_LOOKUP

    def _calculate_complexity(self, query: str, query_type: QueryType) -> float:
        """Calculate query complexity score (0.0 to 1.0).

        Starts from a per-type base score and adjusts it with indicator
        keywords and query length/structure, clamped to [0.0, 1.0].
        """
        query_lower = query.lower()
        # Base complexity per query type (dead pre-assignment removed).
        type_scores = {
            QueryType.FACTUAL_LOOKUP: 0.2,
            QueryType.SPECIFIC_ENTITY: 0.3,
            QueryType.RELATIONSHIP_QUERY: 0.6,
            QueryType.BROAD_THEMATIC: 0.7,
            QueryType.COMPARATIVE: 0.8,
            QueryType.COMPLEX_REASONING: 0.9
        }
        base_score = type_scores.get(query_type, 0.5)
        # Keyword-based adjustment: high/medium indicators raise the score,
        # low indicators lower it slightly.
        for level, indicators in self.complexity_indicators.items():
            count = sum(1 for indicator in indicators if indicator in query_lower)
            if level == 'high':
                base_score += count * 0.2
            elif level == 'medium':
                base_score += count * 0.1
            else:
                base_score -= count * 0.05
        # Long queries and multi-question queries are considered harder.
        if len(query.split()) > 15:
            base_score += 0.1
        if '?' in query and len(query.split('?')) > 2:
            base_score += 0.15
        return min(1.0, max(0.0, base_score))

    def _determine_intent(self, query: str, query_type: QueryType) -> str:
        """Determine the user's intent based on query analysis."""
        intent_map = {
            QueryType.FACTUAL_LOOKUP: "Seeking specific factual information",
            QueryType.SPECIFIC_ENTITY: "Requesting details about a particular entity",
            QueryType.RELATIONSHIP_QUERY: "Exploring connections and relationships",
            QueryType.BROAD_THEMATIC: "Understanding broad themes or patterns",
            QueryType.COMPARATIVE: "Comparing entities or concepts",
            QueryType.COMPLEX_REASONING: "Requiring analytical reasoning and insights"
        }
        return intent_map.get(query_type, "General information seeking")

    def _estimate_scope(self, query: str, entities: List[str], concepts: List[str], complexity: float) -> str:
        """Estimate the scope of information needed ("narrow"/"moderate"/"broad")."""
        if len(entities) == 1 and complexity < 0.4:
            return "narrow"
        elif len(entities) > 3 or len(concepts) > 3 or complexity > 0.7:
            return "broad"
        else:
            return "moderate"

    def _analyze_context_requirements(self, query: str, query_type: QueryType, entities: List[str]) -> Dict[str, Any]:
        """Analyze what context information is needed for answering."""
        return {
            "requires_entity_details": len(entities) > 0,
            "requires_relationships": query_type in [QueryType.RELATIONSHIP_QUERY, QueryType.COMPARATIVE],
            "requires_historical_context": any(word in query.lower() for word in ['history', 'past', 'previous', 'before']),
            "requires_quantitative_data": any(word in query.lower() for word in ['number', 'amount', 'count', 'revenue', 'profit']),
            "primary_entities": entities[:3]  # Focus on top 3 entities
        }
class DriftRouter:
    """Handles Step 4: DRIFT Routing for optimal search strategy selection."""

    def __init__(self, config: Any, graph_stats: Dict[str, Any]):
        """Store configuration/graph statistics and set routing thresholds."""
        self.config = config
        self.graph_stats = graph_stats
        self.logger = logging.getLogger('graphrag_query')
        # Thresholds that steer the local/global/hybrid decision below.
        self.local_search_threshold = 0.4
        self.global_search_threshold = 0.7
        self.entity_count_threshold = 10  # Based on graph size
        self.logger.info("DriftRouter initialized for Step 4 processing")

    async def determine_search_strategy(self, query_analysis: QueryAnalysis, original_query: str) -> DriftRoutingResult:
        """Pick the optimal search strategy for a query (Step 4).

        Args:
            query_analysis: Results from Step 3 query analysis.
            original_query: The original user query, carried along so later
                stages can generate the answer.

        Returns:
            DriftRoutingResult with the chosen strategy, its parameters and
            a fallback strategy.
        """
        self.logger.info(f"Starting Step 4: DRIFT Routing for {query_analysis.query_type.value}")
        try:
            chosen, why, score, options = self._apply_drift_logic(query_analysis)
            routing = DriftRoutingResult(
                search_strategy=chosen,
                reasoning=why,
                confidence=score,
                parameters=options,
                original_query=original_query,
                fallback_strategy=self._determine_fallback_strategy(chosen),
            )
            self.logger.info(f"Step 4 completed: Strategy={chosen.value}, "
                             f"confidence={score:.2f}, reasoning={why[:50]}...")
            return routing
        except Exception as e:
            self.logger.error(f"Step 4 DRIFT Routing failed: {e}")
            raise

    def _apply_drift_logic(self, analysis: QueryAnalysis) -> Tuple[SearchStrategy, str, float, Dict[str, Any]]:
        """Apply DRIFT (Distributed Retrieval and Information Filtering Technique) logic."""
        score = analysis.complexity_score
        n_entities = len(analysis.entities_mentioned)
        qtype = analysis.query_type

        # 1) Narrow, simple entity lookups -> local neighborhood search.
        if (qtype == QueryType.SPECIFIC_ENTITY
                and n_entities <= 2
                and score < self.local_search_threshold):
            local_params = {
                "max_depth": 2,
                "entity_focus": analysis.entities_mentioned,
                "include_neighbors": True,
                "max_results": 20,
            }
            return (SearchStrategy.LOCAL_SEARCH,
                    f"Specific entity query with low complexity ({score:.2f})",
                    0.9,
                    local_params)

        # 2) Complex or broad questions -> community-level global search.
        if (score > self.global_search_threshold
                or analysis.estimated_scope == "broad"
                or qtype in [QueryType.BROAD_THEMATIC, QueryType.COMPLEX_REASONING]):
            global_params = {
                "community_level": "high",
                "max_communities": 10,
                "include_summary": True,
                "max_results": 50,
            }
            return (SearchStrategy.GLOBAL_SEARCH,
                    f"High complexity ({score:.2f}) or broad scope requiring global context",
                    0.85,
                    global_params)

        # 3) Relationship/comparison questions or several entities -> hybrid.
        if (qtype == QueryType.RELATIONSHIP_QUERY
                or qtype == QueryType.COMPARATIVE
                or n_entities > 2):
            hybrid_params = {
                "local_depth": 2,
                "global_communities": 5,
                "balance_weight": 0.6,  # Favor local over global
                "max_results": 35,
            }
            return (SearchStrategy.HYBRID_SEARCH,
                    f"Relationship/comparative query or multiple entities ({n_entities})",
                    0.75,
                    hybrid_params)

        # 4) Everything else: local search with moderate confidence.
        default_params = {
            "max_depth": 3,
            "entity_focus": analysis.entities_mentioned,
            "include_neighbors": True,
            "max_results": 25,
        }
        return (SearchStrategy.LOCAL_SEARCH,
                "Default local search for moderate complexity query",
                0.6,
                default_params)

    def _determine_fallback_strategy(self, primary_strategy: SearchStrategy) -> Optional[SearchStrategy]:
        """Determine fallback strategy if primary fails."""
        if primary_strategy == SearchStrategy.LOCAL_SEARCH:
            return SearchStrategy.GLOBAL_SEARCH
        if primary_strategy in (SearchStrategy.GLOBAL_SEARCH, SearchStrategy.HYBRID_SEARCH):
            return SearchStrategy.LOCAL_SEARCH
        return None
class QueryVectorizer:
    """Handles Step 5: Query Vectorization for semantic similarity matching."""

    def __init__(self, config: Any):
        """Load the shared HuggingFace embedding model configured in MY_CONFIG."""
        self.config = config
        self.logger = logging.getLogger('graphrag_query')
        # Initialize embedding model using same pattern as other files
        self.embedding_model = HuggingFaceEmbedding(
            model_name=MY_CONFIG.EMBEDDING_MODEL
        )
        self.model_name = MY_CONFIG.EMBEDDING_MODEL
        self.embedding_dimension = MY_CONFIG.EMBEDDING_LENGTH
        self.logger.info(f"QueryVectorizer initialized with {self.model_name}")

    async def vectorize_query(self, query: str, query_analysis: QueryAnalysis) -> VectorizedQuery:
        """Generate query embeddings for similarity matching (Step 5).

        Args:
            query: Original query text.
            query_analysis: Results from Step 3.

        Returns:
            VectorizedQuery with the embedding vector and matching metadata.
        """
        self.logger.info(f"Starting Step 5: Query Vectorization for: {query[:100]}...")
        try:
            prepared = self._normalize_query(query, query_analysis)
            vector = await self._generate_embedding(prepared)
            keywords = self._extract_semantic_keywords(query, query_analysis)
            threshold = self._calculate_similarity_threshold(query_analysis)
            outcome = VectorizedQuery(
                embedding=vector,
                embedding_model=self.model_name,
                normalized_query=prepared,
                semantic_keywords=keywords,
                similarity_threshold=threshold,
            )
            self.logger.info(f"Step 5 completed: Embedding dimension={len(vector)}, "
                             f"threshold={threshold:.3f}, keywords={len(keywords)}")
            return outcome
        except Exception as e:
            self.logger.error(f"Step 5 Query Vectorization failed: {e}")
            raise

    def _normalize_query(self, query: str, analysis: QueryAnalysis) -> str:
        """Augment the raw query with entity/concept hints for better embeddings."""
        pieces = [query.strip()]
        # Append up to three entities and three concepts as bracketed context
        # markers so the embedding captures them explicitly.
        if analysis.entities_mentioned:
            pieces.append("[Entities: " + " ".join(analysis.entities_mentioned[:3]) + "]")
        if analysis.key_concepts:
            pieces.append("[Concepts: " + " ".join(analysis.key_concepts[:3]) + "]")
        return " ".join(pieces)

    async def _generate_embedding(self, text: str) -> List[float]:
        """Generate embedding for text using configured model."""
        try:
            return await self.embedding_model.aget_text_embedding(text)
        except Exception as e:
            self.logger.error(f"Embedding generation failed: {e}")
            # Fallback to synchronous call if async fails
            return self.embedding_model.get_text_embedding(text)

    def _extract_semantic_keywords(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """Collect keywords (entities, concepts, type-specific terms) for matching."""
        type_specific = {
            QueryType.RELATIONSHIP_QUERY: ('relationship', 'connection', 'related', 'linked'),
            QueryType.COMPARATIVE: ('comparison', 'versus', 'difference', 'similar'),
            QueryType.BROAD_THEMATIC: ('theme', 'pattern', 'trend', 'overview'),
        }
        pool = set(analysis.entities_mentioned) | set(analysis.key_concepts)
        pool.update(type_specific.get(analysis.query_type, ()))
        # Drop very short tokens before returning.
        return [word for word in pool if len(word) > 2]

    def _calculate_similarity_threshold(self, analysis: QueryAnalysis) -> float:
        """Derive a similarity cutoff from query complexity and scope."""
        threshold = 0.7
        # Complex queries get a looser cutoff, trivial ones a stricter one.
        if analysis.complexity_score > 0.7:
            threshold -= 0.1
        elif analysis.complexity_score < 0.3:
            threshold += 0.1
        # Scope nudges the cutoff in the same spirit.
        if analysis.estimated_scope == "narrow":
            threshold += 0.05
        elif analysis.estimated_scope == "broad":
            threshold -= 0.05
        # Clamp to a sensible band.
        return max(0.5, min(0.9, threshold))
class QueryPreprocessor:
    """Main class coordinating all query preprocessing steps (Steps 3-5)."""

    def __init__(self, config: Any, graph_stats: Dict[str, Any]):
        """Instantiate the three stage processors (analyzer, router, vectorizer)."""
        self.config = config
        self.graph_stats = graph_stats
        self.logger = logging.getLogger('graphrag_query')
        # Initialize component processors
        self.analyzer = QueryAnalyzer(config)
        self.router = DriftRouter(config, graph_stats)
        self.vectorizer = QueryVectorizer(config)
        self.logger.info("QueryPreprocessor initialized for Steps 3-5")

    async def preprocess_query(self, query: str) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
        """Execute complete query preprocessing pipeline (Steps 3-5).

        Args:
            query: User's natural language query.

        Returns:
            Tuple of (analysis, routing, vectorization) results.
        """
        self.logger.info(f"Starting Phase B: Query Preprocessing Pipeline for: {query[:100]}...")
        try:
            # Step 3 -> Step 4 -> Step 5, each stage feeding the next.
            step3 = await self.analyzer.analyze_query(query)
            step4 = await self.router.determine_search_strategy(step3, query)
            step5 = await self.vectorizer.vectorize_query(query, step3)
            self.logger.info(f"Phase B completed successfully: "
                             f"Type={step3.query_type.value}, "
                             f"Strategy={step4.search_strategy.value}, "
                             f"Embedding_dim={len(step5.embedding)}")
            return step3, step4, step5
        except Exception as e:
            self.logger.error(f"Query preprocessing pipeline failed: {e}")
            raise
# Exports
async def create_query_preprocessor(config: Any, graph_stats: Dict[str, Any]) -> QueryPreprocessor:
    """Build a ready-to-use QueryPreprocessor instance."""
    preprocessor = QueryPreprocessor(config, graph_stats)
    return preprocessor
async def preprocess_query_pipeline(query: str, config: Any, graph_stats: Dict[str, Any]) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]:
    """Convenience function running Steps 3-5 end to end for one query.

    Args:
        query: User's natural language query
        config: Application configuration
        graph_stats: Graph database statistics

    Returns:
        Complete preprocessing results as (analysis, routing, vectorization).
    """
    pipeline = await create_query_preprocessor(config, graph_stats)
    return await pipeline.preprocess_query(query)
# Public API of this module (classes, factory helpers and result types).
__all__ = [
    'QueryAnalyzer', 'DriftRouter', 'QueryVectorizer', 'QueryPreprocessor',
    'create_query_preprocessor', 'preprocess_query_pipeline',
    'QueryAnalysis', 'DriftRoutingResult', 'VectorizedQuery',
    'QueryType', 'SearchStrategy'
]