Spaces:
Runtime error
Runtime error
| """Query preprocessing for analysis, routing, and vectorization - Phase B (Steps 3-5).""" | |
| import logging | |
| from typing import Dict, List, Any, Tuple, Optional | |
| from dataclasses import dataclass | |
| from enum import Enum | |
| import re | |
| # System imports | |
| import sys | |
| import os | |
| sys.path.append(os.path.dirname(os.path.dirname(__file__))) | |
| from my_config import MY_CONFIG | |
| from llama_index.embeddings.huggingface import HuggingFaceEmbedding | |
| class QueryType(Enum): | |
| """Query type classifications for DRIFT routing.""" | |
| SPECIFIC_ENTITY = "specific_entity" | |
| RELATIONSHIP_QUERY = "relationship_query" | |
| BROAD_THEMATIC = "broad_thematic" | |
| COMPARATIVE = "comparative" | |
| COMPLEX_REASONING = "complex_reasoning" | |
| FACTUAL_LOOKUP = "factual_lookup" | |
| class SearchStrategy(Enum): | |
| """Search strategy determined by DRIFT routing.""" | |
| LOCAL_SEARCH = "local_search" | |
| GLOBAL_SEARCH = "global_search" | |
| HYBRID_SEARCH = "hybrid_search" | |
| class QueryAnalysis: | |
| """Results of query analysis step.""" | |
| query_type: QueryType | |
| complexity_score: float # 0.0 to 1.0 | |
| entities_mentioned: List[str] | |
| key_concepts: List[str] | |
| intent_description: str | |
| context_requirements: Dict[str, Any] | |
| estimated_scope: str # "narrow", "moderate", "broad" | |
| class DriftRoutingResult: | |
| """Results of DRIFT routing decision.""" | |
| search_strategy: SearchStrategy | |
| reasoning: str | |
| confidence: float # 0.0 to 1.0 | |
| parameters: Dict[str, Any] | |
| original_query: str # Added to fix answer generation | |
| fallback_strategy: Optional[SearchStrategy] = None | |
| class VectorizedQuery: | |
| """Results of query vectorization.""" | |
| embedding: List[float] | |
| embedding_model: str | |
| normalized_query: str | |
| semantic_keywords: List[str] | |
| similarity_threshold: float | |
| class QueryAnalyzer: | |
| """Handles Step 3: Query Analysis with intent detection and complexity assessment.""" | |
| def __init__(self, config: Any): | |
| self.config = config | |
| self.logger = logging.getLogger('graphrag_query') | |
| # Entity extraction patterns | |
| self.entity_patterns = [ | |
| r'\b[A-Z][a-z]+(?:\s+[A-Z][a-z]+)*\b', # Proper nouns | |
| r'\b(?:company|organization|person|place|event)\s+(?:named|called)?\s*["\']?([^"\']+)["\']?', | |
| r'\bwho\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)', | |
| r'\bwhat\s+is\s+([A-Z][a-z]+(?:\s+[A-Z][a-z]+)*)', | |
| ] | |
| # Complexity indicators | |
| self.complexity_indicators = { | |
| 'high': ['compare', 'analyze', 'evaluate', 'relationship', 'impact', 'why', 'how'], | |
| 'medium': ['describe', 'explain', 'summarize', 'list', 'identify'], | |
| 'low': ['who', 'what', 'when', 'where', 'is', 'are'] | |
| } | |
| self.logger.info("QueryAnalyzer initialized for Step 3 processing") | |
| async def analyze_query(self, query: str) -> QueryAnalysis: | |
| """Analyze query for intent, complexity, and entities.""" | |
| self.logger.info(f"Starting Step 3: Query Analysis for: {query[:100]}...") | |
| try: | |
| # Extract entities and concepts | |
| entities = self._extract_entities(query) | |
| concepts = self._extract_key_concepts(query) | |
| query_type = self._classify_query_type(query, entities, concepts) | |
| complexity = self._calculate_complexity(query, query_type) | |
| intent = self._determine_intent(query, query_type) | |
| scope = self._estimate_scope(query, entities, concepts, complexity) | |
| # Build context | |
| context_reqs = self._analyze_context_requirements(query, query_type, entities) | |
| analysis = QueryAnalysis( | |
| query_type=query_type, | |
| complexity_score=complexity, | |
| entities_mentioned=entities, | |
| key_concepts=concepts, | |
| intent_description=intent, | |
| context_requirements=context_reqs, | |
| estimated_scope=scope | |
| ) | |
| self.logger.info(f"Step 3 completed: Query type={query_type.value}, " | |
| f"complexity={complexity:.2f}, entities={len(entities)}, scope={scope}") | |
| return analysis | |
| except Exception as e: | |
| self.logger.error(f"Step 3 Query Analysis failed: {e}") | |
| raise | |
| def _extract_entities(self, query: str) -> List[str]: | |
| """Extract named entities from query text.""" | |
| entities = set() | |
| for pattern in self.entity_patterns: | |
| matches = re.findall(pattern, query, re.IGNORECASE) | |
| entities.update(matches) | |
| # Filter entities | |
| filtered_entities = [ | |
| entity.strip() for entity in entities | |
| if len(entity.strip()) > 2 and entity.lower() not in | |
| {'the', 'and', 'are', 'is', 'was', 'were', 'this', 'that', 'what', 'who', 'how'} | |
| ] | |
| return list(set(filtered_entities)) | |
| def _extract_key_concepts(self, query: str) -> List[str]: | |
| """Extract key conceptual terms from query.""" | |
| # Extract concepts | |
| concepts = [] | |
| # Find domain terms | |
| domain_terms = [ | |
| 'revenue', 'profit', 'growth', 'market', 'strategy', 'technology', | |
| 'product', 'service', 'customer', 'partnership', 'acquisition', | |
| 'investment', 'research', 'development', 'innovation', 'competition' | |
| ] | |
| query_lower = query.lower() | |
| for term in domain_terms: | |
| if term in query_lower: | |
| concepts.append(term) | |
| return concepts | |
| def _classify_query_type(self, query: str, entities: List[str], concepts: List[str]) -> QueryType: | |
| """Classify the type of query for routing decisions.""" | |
| query_lower = query.lower() | |
| # Check patterns | |
| if any(word in query_lower for word in ['compare', 'versus', 'vs', 'difference']): | |
| return QueryType.COMPARATIVE | |
| if any(word in query_lower for word in ['relationship', 'connect', 'related', 'between']): | |
| return QueryType.RELATIONSHIP_QUERY | |
| if len(entities) > 0 and any(word in query_lower for word in ['who is', 'what is', 'about']): | |
| return QueryType.SPECIFIC_ENTITY | |
| if any(word in query_lower for word in ['analyze', 'evaluate', 'why', 'how', 'impact']): | |
| return QueryType.COMPLEX_REASONING | |
| if len(concepts) > 2 or any(word in query_lower for word in ['overall', 'general', 'trend']): | |
| return QueryType.BROAD_THEMATIC | |
| return QueryType.FACTUAL_LOOKUP | |
| def _calculate_complexity(self, query: str, query_type: QueryType) -> float: | |
| """Calculate query complexity score (0.0 to 1.0).""" | |
| base_score = 0.3 | |
| query_lower = query.lower() | |
| # Base complexity | |
| type_scores = { | |
| QueryType.FACTUAL_LOOKUP: 0.2, | |
| QueryType.SPECIFIC_ENTITY: 0.3, | |
| QueryType.RELATIONSHIP_QUERY: 0.6, | |
| QueryType.BROAD_THEMATIC: 0.7, | |
| QueryType.COMPARATIVE: 0.8, | |
| QueryType.COMPLEX_REASONING: 0.9 | |
| } | |
| base_score = type_scores.get(query_type, 0.5) | |
| # Adjust complexity | |
| for level, indicators in self.complexity_indicators.items(): | |
| count = sum(1 for indicator in indicators if indicator in query_lower) | |
| if level == 'high': | |
| base_score += count * 0.2 | |
| elif level == 'medium': | |
| base_score += count * 0.1 | |
| else: | |
| base_score -= count * 0.05 | |
| # Query length and structure | |
| if len(query.split()) > 15: | |
| base_score += 0.1 | |
| if '?' in query and len(query.split('?')) > 2: | |
| base_score += 0.15 | |
| return min(1.0, max(0.0, base_score)) | |
| def _determine_intent(self, query: str, query_type: QueryType) -> str: | |
| """Determine the user's intent based on query analysis.""" | |
| intent_map = { | |
| QueryType.FACTUAL_LOOKUP: "Seeking specific factual information", | |
| QueryType.SPECIFIC_ENTITY: "Requesting details about a particular entity", | |
| QueryType.RELATIONSHIP_QUERY: "Exploring connections and relationships", | |
| QueryType.BROAD_THEMATIC: "Understanding broad themes or patterns", | |
| QueryType.COMPARATIVE: "Comparing entities or concepts", | |
| QueryType.COMPLEX_REASONING: "Requiring analytical reasoning and insights" | |
| } | |
| return intent_map.get(query_type, "General information seeking") | |
| def _estimate_scope(self, query: str, entities: List[str], concepts: List[str], complexity: float) -> str: | |
| """Estimate the scope of information needed.""" | |
| if len(entities) == 1 and complexity < 0.4: | |
| return "narrow" | |
| elif len(entities) > 3 or len(concepts) > 3 or complexity > 0.7: | |
| return "broad" | |
| else: | |
| return "moderate" | |
| def _analyze_context_requirements(self, query: str, query_type: QueryType, entities: List[str]) -> Dict[str, Any]: | |
| """Analyze what context information is needed.""" | |
| return { | |
| "requires_entity_details": len(entities) > 0, | |
| "requires_relationships": query_type in [QueryType.RELATIONSHIP_QUERY, QueryType.COMPARATIVE], | |
| "requires_historical_context": any(word in query.lower() for word in ['history', 'past', 'previous', 'before']), | |
| "requires_quantitative_data": any(word in query.lower() for word in ['number', 'amount', 'count', 'revenue', 'profit']), | |
| "primary_entities": entities[:3] # Focus on top 3 entities | |
| } | |
| class DriftRouter: | |
| """Handles Step 4: DRIFT Routing for optimal search strategy selection.""" | |
| def __init__(self, config: Any, graph_stats: Dict[str, Any]): | |
| self.config = config | |
| self.graph_stats = graph_stats | |
| self.logger = logging.getLogger('graphrag_query') | |
| # Routing thresholds | |
| self.local_search_threshold = 0.4 | |
| self.global_search_threshold = 0.7 | |
| self.entity_count_threshold = 10 # Based on graph size | |
| self.logger.info("DriftRouter initialized for Step 4 processing") | |
| async def determine_search_strategy(self, query_analysis: QueryAnalysis, original_query: str) -> DriftRoutingResult: | |
| """ | |
| Determine optimal search strategy using DRIFT methodology (Step 4). | |
| Args: | |
| query_analysis: Results from Step 3 query analysis | |
| original_query: The original user query | |
| Returns: | |
| DriftRoutingResult with search strategy and parameters | |
| """ | |
| self.logger.info(f"Starting Step 4: DRIFT Routing for {query_analysis.query_type.value}") | |
| try: | |
| # Apply routing logic | |
| strategy, reasoning, confidence, params = self._apply_drift_logic(query_analysis) | |
| # Fallback strategy | |
| fallback = self._determine_fallback_strategy(strategy) | |
| result = DriftRoutingResult( | |
| search_strategy=strategy, | |
| reasoning=reasoning, | |
| confidence=confidence, | |
| parameters=params, | |
| original_query=original_query, | |
| fallback_strategy=fallback | |
| ) | |
| self.logger.info(f"Step 4 completed: Strategy={strategy.value}, " | |
| f"confidence={confidence:.2f}, reasoning={reasoning[:50]}...") | |
| return result | |
| except Exception as e: | |
| self.logger.error(f"Step 4 DRIFT Routing failed: {e}") | |
| raise | |
| def _apply_drift_logic(self, analysis: QueryAnalysis) -> Tuple[SearchStrategy, str, float, Dict[str, Any]]: | |
| """Apply DRIFT (Distributed Retrieval and Information Filtering Technique) logic.""" | |
| # Decision factors | |
| complexity = analysis.complexity_score | |
| entity_count = len(analysis.entities_mentioned) | |
| scope = analysis.estimated_scope | |
| query_type = analysis.query_type | |
| # Local search conditions | |
| if (query_type == QueryType.SPECIFIC_ENTITY and | |
| entity_count <= 2 and | |
| complexity < self.local_search_threshold): | |
| return ( | |
| SearchStrategy.LOCAL_SEARCH, | |
| f"Specific entity query with low complexity ({complexity:.2f})", | |
| 0.9, | |
| { | |
| "max_depth": 2, | |
| "entity_focus": analysis.entities_mentioned, | |
| "include_neighbors": True, | |
| "max_results": 20 | |
| } | |
| ) | |
| # Global search conditions | |
| if (complexity > self.global_search_threshold or | |
| scope == "broad" or | |
| query_type in [QueryType.BROAD_THEMATIC, QueryType.COMPLEX_REASONING]): | |
| return ( | |
| SearchStrategy.GLOBAL_SEARCH, | |
| f"High complexity ({complexity:.2f}) or broad scope requiring global context", | |
| 0.85, | |
| { | |
| "community_level": "high", | |
| "max_communities": 10, | |
| "include_summary": True, | |
| "max_results": 50 | |
| } | |
| ) | |
| # Hybrid search for intermediate cases | |
| if (query_type == QueryType.RELATIONSHIP_QUERY or | |
| query_type == QueryType.COMPARATIVE or | |
| entity_count > 2): | |
| return ( | |
| SearchStrategy.HYBRID_SEARCH, | |
| f"Relationship/comparative query or multiple entities ({entity_count})", | |
| 0.75, | |
| { | |
| "local_depth": 2, | |
| "global_communities": 5, | |
| "balance_weight": 0.6, # Favor local over global | |
| "max_results": 35 | |
| } | |
| ) | |
| # Default to local search with moderate confidence | |
| return ( | |
| SearchStrategy.LOCAL_SEARCH, | |
| "Default local search for moderate complexity query", | |
| 0.6, | |
| { | |
| "max_depth": 3, | |
| "entity_focus": analysis.entities_mentioned, | |
| "include_neighbors": True, | |
| "max_results": 25 | |
| } | |
| ) | |
| def _determine_fallback_strategy(self, primary_strategy: SearchStrategy) -> Optional[SearchStrategy]: | |
| """Determine fallback strategy if primary fails.""" | |
| fallback_map = { | |
| SearchStrategy.LOCAL_SEARCH: SearchStrategy.GLOBAL_SEARCH, | |
| SearchStrategy.GLOBAL_SEARCH: SearchStrategy.LOCAL_SEARCH, | |
| SearchStrategy.HYBRID_SEARCH: SearchStrategy.LOCAL_SEARCH | |
| } | |
| return fallback_map.get(primary_strategy) | |
| class QueryVectorizer: | |
| """Handles Step 5: Query Vectorization for semantic similarity matching.""" | |
| def __init__(self, config: Any): | |
| self.config = config | |
| self.logger = logging.getLogger('graphrag_query') | |
| # Initialize embedding model using same pattern as other files | |
| self.embedding_model = HuggingFaceEmbedding( | |
| model_name=MY_CONFIG.EMBEDDING_MODEL | |
| ) | |
| self.model_name = MY_CONFIG.EMBEDDING_MODEL | |
| self.embedding_dimension = MY_CONFIG.EMBEDDING_LENGTH | |
| self.logger.info(f"QueryVectorizer initialized with {self.model_name}") | |
| async def vectorize_query(self, query: str, query_analysis: QueryAnalysis) -> VectorizedQuery: | |
| """ | |
| Generate query embeddings for similarity matching (Step 5). | |
| Args: | |
| query: Original query text | |
| query_analysis: Results from Step 3 | |
| Returns: | |
| VectorizedQuery with embeddings and metadata | |
| """ | |
| self.logger.info(f"Starting Step 5: Query Vectorization for: {query[:100]}...") | |
| try: | |
| # Normalize query | |
| normalized_query = self._normalize_query(query, query_analysis) | |
| # Generate embedding | |
| embedding = await self._generate_embedding(normalized_query) | |
| # Extract keywords | |
| semantic_keywords = self._extract_semantic_keywords(query, query_analysis) | |
| # Set similarity threshold | |
| similarity_threshold = self._calculate_similarity_threshold(query_analysis) | |
| result = VectorizedQuery( | |
| embedding=embedding, | |
| embedding_model=self.model_name, | |
| normalized_query=normalized_query, | |
| semantic_keywords=semantic_keywords, | |
| similarity_threshold=similarity_threshold | |
| ) | |
| self.logger.info(f"Step 5 completed: Embedding dimension={len(embedding)}, " | |
| f"threshold={similarity_threshold:.3f}, keywords={len(semantic_keywords)}") | |
| return result | |
| except Exception as e: | |
| self.logger.error(f"Step 5 Query Vectorization failed: {e}") | |
| raise | |
| def _normalize_query(self, query: str, analysis: QueryAnalysis) -> str: | |
| """Normalize query text for better embedding quality.""" | |
| # Start with original query | |
| normalized = query.strip() | |
| # Add important entities and concepts for context | |
| if analysis.entities_mentioned: | |
| entity_context = " ".join(analysis.entities_mentioned[:3]) | |
| normalized = f"{normalized} [Entities: {entity_context}]" | |
| if analysis.key_concepts: | |
| concept_context = " ".join(analysis.key_concepts[:3]) | |
| normalized = f"{normalized} [Concepts: {concept_context}]" | |
| return normalized | |
| async def _generate_embedding(self, text: str) -> List[float]: | |
| """Generate embedding for text using configured model.""" | |
| try: | |
| embedding = await self.embedding_model.aget_text_embedding(text) | |
| return embedding | |
| except Exception as e: | |
| self.logger.error(f"Embedding generation failed: {e}") | |
| # Fallback to synchronous call if async fails | |
| return self.embedding_model.get_text_embedding(text) | |
| def _extract_semantic_keywords(self, query: str, analysis: QueryAnalysis) -> List[str]: | |
| """Extract semantic keywords for additional matching.""" | |
| keywords = set() | |
| # Add entities and concepts | |
| keywords.update(analysis.entities_mentioned) | |
| keywords.update(analysis.key_concepts) | |
| # Add query-specific terms based on type | |
| if analysis.query_type == QueryType.RELATIONSHIP_QUERY: | |
| keywords.update(['relationship', 'connection', 'related', 'linked']) | |
| elif analysis.query_type == QueryType.COMPARATIVE: | |
| keywords.update(['comparison', 'versus', 'difference', 'similar']) | |
| elif analysis.query_type == QueryType.BROAD_THEMATIC: | |
| keywords.update(['theme', 'pattern', 'trend', 'overview']) | |
| # Filter and return as list | |
| return [kw for kw in keywords if len(kw) > 2] | |
| def _calculate_similarity_threshold(self, analysis: QueryAnalysis) -> float: | |
| """Calculate appropriate similarity threshold based on query characteristics.""" | |
| base_threshold = 0.7 | |
| # Adjust based on query complexity | |
| if analysis.complexity_score > 0.7: | |
| base_threshold -= 0.1 # Lower threshold for complex queries | |
| elif analysis.complexity_score < 0.3: | |
| base_threshold += 0.1 # Higher threshold for simple queries | |
| # Adjust based on scope | |
| if analysis.estimated_scope == "narrow": | |
| base_threshold += 0.05 | |
| elif analysis.estimated_scope == "broad": | |
| base_threshold -= 0.05 | |
| # Ensure reasonable bounds | |
| return max(0.5, min(0.9, base_threshold)) | |
| class QueryPreprocessor: | |
| """Main class coordinating all query preprocessing steps (Steps 3-5).""" | |
| def __init__(self, config: Any, graph_stats: Dict[str, Any]): | |
| self.config = config | |
| self.graph_stats = graph_stats | |
| self.logger = logging.getLogger('graphrag_query') | |
| # Initialize component processors | |
| self.analyzer = QueryAnalyzer(config) | |
| self.router = DriftRouter(config, graph_stats) | |
| self.vectorizer = QueryVectorizer(config) | |
| self.logger.info("QueryPreprocessor initialized for Steps 3-5") | |
| async def preprocess_query(self, query: str) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]: | |
| """ | |
| Execute complete query preprocessing pipeline (Steps 3-5). | |
| Args: | |
| query: User's natural language query | |
| Returns: | |
| Tuple of (analysis, routing, vectorization) results | |
| """ | |
| self.logger.info(f"Starting Phase B: Query Preprocessing Pipeline for: {query[:100]}...") | |
| try: | |
| # Query analysis | |
| analysis = await self.analyzer.analyze_query(query) | |
| # Query routing | |
| routing = await self.router.determine_search_strategy(analysis, query) | |
| # Query vectorization | |
| vectorization = await self.vectorizer.vectorize_query(query, analysis) | |
| self.logger.info(f"Phase B completed successfully: " | |
| f"Type={analysis.query_type.value}, " | |
| f"Strategy={routing.search_strategy.value}, " | |
| f"Embedding_dim={len(vectorization.embedding)}") | |
| return analysis, routing, vectorization | |
| except Exception as e: | |
| self.logger.error(f"Query preprocessing pipeline failed: {e}") | |
| raise | |
| # Exports | |
| async def create_query_preprocessor(config: Any, graph_stats: Dict[str, Any]) -> QueryPreprocessor: | |
| """Create and initialize QueryPreprocessor.""" | |
| return QueryPreprocessor(config, graph_stats) | |
| async def preprocess_query_pipeline(query: str, config: Any, graph_stats: Dict[str, Any]) -> Tuple[QueryAnalysis, DriftRoutingResult, VectorizedQuery]: | |
| """ | |
| Convenience function for complete query preprocessing. | |
| Args: | |
| query: User's natural language query | |
| config: Application configuration | |
| graph_stats: Graph database statistics | |
| Returns: | |
| Complete preprocessing results | |
| """ | |
| preprocessor = await create_query_preprocessor(config, graph_stats) | |
| return await preprocessor.preprocess_query(query) | |
| __all__ = [ | |
| 'QueryAnalyzer', 'DriftRouter', 'QueryVectorizer', 'QueryPreprocessor', | |
| 'create_query_preprocessor', 'preprocess_query_pipeline', | |
| 'QueryAnalysis', 'DriftRoutingResult', 'VectorizedQuery', | |
| 'QueryType', 'SearchStrategy' | |
| ] |