| |
|
|
| """ |
| Query Expansion System for CogniChat RAG Application |
| |
| This module implements advanced query expansion techniques to improve retrieval quality: |
| - QueryAnalyzer: Extracts intent, entities, and keywords |
| - QueryRephraser: Generates natural language variations |
| - MultiQueryExpander: Creates diverse query formulations |
| - MultiHopReasoner: Connects concepts across documents |
| - FallbackStrategies: Handles edge cases gracefully |
| |
| Author: CogniChat Team |
| Date: October 19, 2025 |
| """ |
|
|
| import re |
| from typing import List, Dict, Any, Optional |
| from dataclasses import dataclass |
| from enum import Enum |
|
|
|
|
class QueryStrategy(Enum):
    """Available query-expansion strategies, ordered by cost and coverage."""
    QUICK = "quick"                   # one synonym variation only
    BALANCED = "balanced"             # synonym + expanded + simplified variations
    COMPREHENSIVE = "comprehensive"   # all rephrasing techniques, plus intent extras
|
|
|
|
@dataclass
class QueryAnalysis:
    """Structured result of analyzing a single user query."""
    intent: str                   # e.g. 'definition', 'how_to', 'comparison', 'general'
    entities: List[str]           # capitalized words / quoted phrases found in the query
    keywords: List[str]           # stop-word-filtered content words
    complexity: str               # one of 'simple', 'medium', 'complex'
    domain: Optional[str] = None  # detected technical domain, if any
|
|
|
|
@dataclass
class ExpandedQuery:
    """Original query plus the variations generated for retrieval."""
    original: str                  # the user's unmodified query
    variations: List[str]          # de-duplicated variations; original appears first
    strategy_used: QueryStrategy   # strategy that produced the variations
    analysis: QueryAnalysis        # analysis the variations were derived from
|
|
|
|
class QueryAnalyzer:
    """
    Analyzes queries to extract intent, entities, and key information.

    All analysis here is pattern-based; the optional LLM is only stored
    for callers/extensions and is not consulted by this class itself.
    """

    def __init__(self, llm=None):
        """
        Initialize QueryAnalyzer.

        Args:
            llm: Optional LangChain LLM for advanced analysis
        """
        self.llm = llm
        # Intent detection is first-match-wins, in this insertion order.
        self.intent_patterns = {
            'definition': r'\b(what is|define|meaning of|definition)\b',
            'how_to': r'\b(how to|how do|how can|steps to)\b',
            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
            'explanation': r'\b(why|explain|reason|cause)\b',
            'listing': r'\b(list|enumerate|what are|types of)\b',
            'example': r'\b(example|instance|sample|case)\b',
        }

    def analyze(self, query: str) -> QueryAnalysis:
        """
        Analyze query to extract intent, entities, and keywords.

        Args:
            query: User's original query

        Returns:
            QueryAnalysis object with extracted information
        """
        query_lower = query.lower()

        entities = self._extract_entities(query)
        keywords = self._extract_keywords(query)

        return QueryAnalysis(
            intent=self._detect_intent(query_lower),
            entities=entities,
            keywords=keywords,
            complexity=self._assess_complexity(query, entities, keywords),
            domain=self._detect_domain(query_lower),
        )

    def _detect_intent(self, query_lower: str) -> str:
        """Detect query intent using pattern matching; 'general' if none match."""
        for intent, pattern in self.intent_patterns.items():
            if re.search(pattern, query_lower):
                return intent
        return 'general'

    def _extract_entities(self, query: str) -> List[str]:
        """Extract named entities (simplified heuristic version).

        Capitalized words (excluding common question words) and quoted
        phrases count as entities. Punctuation glued to a word is stripped
        (e.g. "Java?" -> "Java"), and the result is de-duplicated while
        preserving first-seen order — `list(set(...))` would make the
        ordering vary between runs.
        """
        question_words = {'what', 'how', 'why', 'when', 'where', 'which'}
        entities: List[str] = []

        for raw in query.split():
            word = raw.strip('.,;:!?"\'()[]{}')
            if word and word[0].isupper() and word.lower() not in question_words:
                entities.append(word)

        # Quoted phrases are treated as explicit entities.
        entities.extend(re.findall(r'"([^"]+)"', query))

        return list(dict.fromkeys(entities))

    def _extract_keywords(self, query: str) -> List[str]:
        """Extract important keywords from query.

        Stop words and very short tokens are dropped; duplicates are removed
        (keeping first occurrence) so repeated words don't dominate the
        downstream keyword-focused variations. Capped at 10 keywords.
        """
        stop_words = {
            'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
            'what', 'how', 'why', 'when', 'where', 'which', 'who',
            'do', 'does', 'did', 'can', 'could', 'should', 'would',
            'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'
        }

        words = re.findall(r'\b\w+\b', query.lower())
        keywords = [w for w in words if w not in stop_words and len(w) > 2]

        return list(dict.fromkeys(keywords))[:10]

    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
        """Assess query complexity from a weighted word/entity/keyword count."""
        word_count = len(query.split())
        entity_count = len(entities)
        keyword_count = len(keywords)

        # Entities weigh more than plain keywords; thresholds are heuristic.
        score = word_count + (entity_count * 2) + (keyword_count * 1.5)

        if score < 15:
            return 'simple'
        elif score < 30:
            return 'medium'
        else:
            return 'complex'

    def _detect_domain(self, query_lower: str) -> Optional[str]:
        """Detect technical domain if present.

        Markers are matched as whole words so short ones like 'ai' don't
        fire on substrings of unrelated words (e.g. 'email', 'training'),
        and 'sql' doesn't match inside 'mysql'. First matching domain wins.
        """
        domains = {
            'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
            'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
            'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
            'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
            'database': ['sql', 'query', 'database', 'table', 'index'],
            'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
        }

        for domain, markers in domains.items():
            for marker in markers:
                if re.search(r'\b' + re.escape(marker) + r'\b', query_lower):
                    return domain

        return None
|
|
|
|
class QueryRephraser:
    """
    Generates natural language variations of queries using multiple strategies.

    The optional LLM is only stored for callers/extensions; all rephrasing
    here is rule-based.
    """

    # Word -> replacement for the synonym variation. Only the first matching
    # word in a query is swapped. Hoisted to a class constant so the table is
    # not rebuilt on every call.
    _SYNONYMS = {
        'error': 'issue',
        'problem': 'issue',
        'fix': 'resolve',
        'use': 'utilize',
        'create': 'generate',
        'make': 'create',
        'get': 'retrieve',
        'show': 'display',
        'find': 'locate',
        'explain': 'describe',
    }

    def __init__(self, llm=None):
        """
        Initialize QueryRephraser.

        Args:
            llm: LangChain LLM for generating variations
        """
        self.llm = llm

    def generate_variations(
        self,
        query: str,
        analysis: QueryAnalysis,
        strategy: QueryStrategy = QueryStrategy.BALANCED
    ) -> List[str]:
        """
        Generate query variations based on strategy.

        The original query is always first; empty variations are dropped and
        exact duplicates are removed while preserving order.

        Args:
            query: Original query
            analysis: Query analysis results
            strategy: Expansion strategy to use

        Returns:
            List of query variations
        """
        variations = [query]

        if strategy == QueryStrategy.QUICK:
            variations.append(self._synonym_variation(query, analysis))

        elif strategy == QueryStrategy.BALANCED:
            variations.append(self._synonym_variation(query, analysis))
            variations.append(self._expanded_variation(query, analysis))
            variations.append(self._simplified_variation(query, analysis))

        elif strategy == QueryStrategy.COMPREHENSIVE:
            variations.append(self._synonym_variation(query, analysis))
            variations.append(self._expanded_variation(query, analysis))
            variations.append(self._simplified_variation(query, analysis))
            variations.append(self._keyword_focused(query, analysis))
            variations.append(self._context_variation(query, analysis))

            # Procedural intents get an extra "guide" style formulation.
            if analysis.intent in ['how_to', 'explanation']:
                variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")

        variations = [v for v in variations if v]
        return list(dict.fromkeys(variations))

    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate a variation by swapping the first known word for a synonym.

        Fix over the naive approach: the casing of untouched words is
        preserved (proper nouns like "Python" survive), and when no synonym
        applies the query is returned unchanged so the caller's
        de-duplication drops the redundant variation instead of keeping a
        casing-mangled near-duplicate.
        """
        words = query.split()
        for i, word in enumerate(words):
            replacement = self._SYNONYMS.get(word.lower())
            if replacement is not None:
                # Keep sentence-initial capitalization.
                words[i] = replacement.capitalize() if i == 0 else replacement
                return ' '.join(words)
        return query

    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate an expanded version with more detail, shaped by intent."""
        if analysis.intent == 'definition':
            return f"Detailed explanation and definition of {' '.join(analysis.keywords)}"
        elif analysis.intent == 'how_to':
            return f"Step-by-step guide on {query.lower()}"
        elif analysis.intent == 'comparison':
            return f"Comprehensive comparison: {query}"
        else:
            return f"Detailed information about {query.lower()}"

    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate a simplified version focusing on the top core keywords."""
        if len(analysis.keywords) >= 2:
            return ' '.join(analysis.keywords[:3])
        return query

    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
        """Create a keyword-focused variation (useful for BM25-style retrieval)."""
        keywords = analysis.keywords + analysis.entities
        return ' '.join(keywords[:5])

    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Append domain context when one was detected; otherwise unchanged."""
        if analysis.domain:
            return f"{query} in {analysis.domain} context"
        return query
|
|
|
|
class MultiQueryExpander:
    """
    Main query expansion orchestrator that combines analysis and rephrasing.
    """

    def __init__(self, llm=None):
        """
        Initialize MultiQueryExpander.

        Args:
            llm: LangChain LLM for advanced expansions
        """
        self.analyzer = QueryAnalyzer(llm)
        self.rephraser = QueryRephraser(llm)

    def expand(
        self,
        query: str,
        strategy: QueryStrategy = QueryStrategy.BALANCED,
        max_queries: int = 6
    ) -> ExpandedQuery:
        """
        Expand query into multiple variations.

        Args:
            query: Original user query
            strategy: Expansion strategy
            max_queries: Maximum number of queries to generate

        Returns:
            ExpandedQuery object with all variations
        """
        # Analyze once, then let the rephraser generate variations from it.
        query_analysis = self.analyzer.analyze(query)
        all_variations = self.rephraser.generate_variations(query, query_analysis, strategy)

        return ExpandedQuery(
            original=query,
            variations=all_variations[:max_queries],  # cap to the requested limit
            strategy_used=strategy,
            analysis=query_analysis,
        )
|
|
|
|
class MultiHopReasoner:
    """
    Implements multi-hop reasoning to connect concepts across documents.
    Useful for complex queries that require information from multiple sources.
    """

    def __init__(self, llm=None):
        """
        Initialize MultiHopReasoner.

        Args:
            llm: LangChain LLM for reasoning
        """
        self.llm = llm

    def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """
        Break a complex query into sub-queries for multi-hop reasoning.

        The original query always comes first and the result is capped at
        five entries.

        Args:
            query: Original complex query
            analysis: Query analysis

        Returns:
            List of sub-queries
        """
        hops = [query]

        # Comparisons: retrieve information about each side separately.
        # Prefer entities; fall back to keywords when fewer than two entities
        # were detected.
        if analysis.intent == 'comparison':
            sides = analysis.entities if len(analysis.entities) >= 2 else analysis.keywords
            if len(sides) >= 2:
                hops.extend(f"Information about {side}" for side in sides[:2])

        # How-to questions: also look up prerequisites and the steps themselves.
        if analysis.intent == 'how_to' and len(analysis.keywords) >= 2:
            topic = ' '.join(analysis.keywords[:2])
            hops.append(f"Prerequisites for {topic}")
            hops.append(f"Steps to {topic}")

        # Long complex queries: split the keyword list in half and search
        # each half independently.
        if analysis.complexity == 'complex' and len(analysis.keywords) > 3:
            half = len(analysis.keywords) // 2
            hops.append(' '.join(analysis.keywords[:half]))
            hops.append(' '.join(analysis.keywords[half:]))

        return hops[:5]
|
|
|
|
class FallbackStrategies:
    """
    Implements fallback strategies for queries that don't retrieve good results.
    """

    @staticmethod
    def simplify_query(query: str) -> str:
        """Simplify a query by stripping question words and filler terms.

        Two removal passes (question words, then articles/auxiliaries) are
        applied, after which whitespace is collapsed.
        """
        noise_patterns = (
            r'\b(what|how|why|when|where|which|who|can|could|should|would)\b',
            r'\b(is|are|was|were|be|been|the|a|an)\b',
        )
        for pattern in noise_patterns:
            query = re.sub(pattern, '', query, flags=re.IGNORECASE)
        return re.sub(r'\s+', ' ', query).strip()

    @staticmethod
    def broaden_query(query: str, analysis: QueryAnalysis) -> str:
        """Broaden the query to increase recall.

        If any keywords were extracted, an "<keyword> overview" query is
        returned; otherwise narrowing modifiers are stripped from the
        original query.
        """
        if analysis.keywords:
            return f"{analysis.keywords[0]} overview"
        return re.sub(r'\b(specific|exactly|precisely|only|just)\b', '', query, flags=re.IGNORECASE)

    @staticmethod
    def focus_entities(analysis: QueryAnalysis) -> str:
        """Create an entity-focused query as a fallback.

        Falls back to the top three keywords when no entities exist; returns
        an empty string when neither is available.
        """
        return ' '.join(analysis.entities or analysis.keywords[:3])
|
|
|
|
| |
def expand_query_simple(
    query: str,
    strategy: str = "balanced",
    llm=None
) -> List[str]:
    """
    Expand a query into variations without dealing with the classes directly.

    Args:
        query: User's query to expand
        strategy: "quick", "balanced", or "comprehensive"
        llm: Optional LangChain LLM

    Returns:
        List of expanded query variations (the original query comes first)

    Raises:
        ValueError: If *strategy* is not a recognized strategy name.

    Example:
        >>> queries = expand_query_simple("How do I fix this error?", strategy="balanced")
    """
    chosen = QueryStrategy(strategy)  # ValueError on unknown strategy names
    return MultiQueryExpander(llm=llm).expand(query, strategy=chosen).variations
|
|
|
|
| |
if __name__ == "__main__":
    RULE = "=" * 60

    def _banner(title: str, first: bool = False) -> None:
        # Render a section header; all but the first are preceded by a blank line.
        print(RULE if first else "\n" + RULE)
        print(title)
        print(RULE)

    def _numbered(items: List[str]) -> None:
        # Print a 1-based numbered list, indented two spaces.
        for idx, item in enumerate(items, 1):
            print(f"  {idx}. {item}")

    # --- Example 1: the convenience wrapper --------------------------------
    _banner("Example 1: Simple Query Expansion", first=True)

    demo_query = "What is machine learning?"
    expanded = expand_query_simple(demo_query, strategy="balanced")

    print(f"\nOriginal: {demo_query}")
    print(f"\nExpanded queries ({len(expanded)}):")
    _numbered(expanded)

    # --- Example 2: full expander with analysis details --------------------
    _banner("Example 2: Complex Query with Analysis")

    expander = MultiQueryExpander()
    demo_query = "How do I compare the performance of different neural network architectures?"
    result = expander.expand(demo_query, strategy=QueryStrategy.COMPREHENSIVE)

    print(f"\nOriginal: {result.original}")
    print("\nAnalysis:")
    print(f"  Intent: {result.analysis.intent}")
    print(f"  Entities: {result.analysis.entities}")
    print(f"  Keywords: {result.analysis.keywords}")
    print(f"  Complexity: {result.analysis.complexity}")
    print(f"  Domain: {result.analysis.domain}")
    print(f"\nExpanded queries ({len(result.variations)}):")
    _numbered(result.variations)

    # --- Example 3: multi-hop sub-query generation -------------------------
    _banner("Example 3: Multi-Hop Reasoning")

    reasoner = MultiHopReasoner()
    analyzer = QueryAnalyzer()

    demo_query = "Compare Python and Java for web development"
    demo_analysis = analyzer.analyze(demo_query)
    sub_queries = reasoner.generate_sub_queries(demo_query, demo_analysis)

    print(f"\nOriginal: {demo_query}")
    print("\nSub-queries for multi-hop reasoning:")
    _numbered(sub_queries)

    # --- Example 4: fallback reformulations --------------------------------
    _banner("Example 4: Fallback Strategies")

    demo_query = "What is the specific difference between supervised and unsupervised learning?"
    demo_analysis = analyzer.analyze(demo_query)

    print(f"\nOriginal: {demo_query}")
    print(f"Simplified: {FallbackStrategies.simplify_query(demo_query)}")
    print(f"Broadened: {FallbackStrategies.broaden_query(demo_query, demo_analysis)}")
    print(f"Entity-focused: {FallbackStrategies.focus_entities(demo_analysis)}")
|