""" Query Expansion Module for Medical Linguistic Variability This module provides intelligent query expansion to handle: - Medical term variations and synonyms - Abbreviation expansion - Spelling variations (US/UK/International) - Specialty-specific terminology - Multi-query retrieval strategies """ import re from typing import List, Dict, Set, Tuple, Optional from langchain.schema import Document from .medical_terminology import ( normalize_query, expand_query_with_variations, get_synonyms, expand_abbreviations, extract_medical_entities, is_medical_abbreviation, get_abbreviation_expansion, ) from .config import logger class QueryExpansionStrategy: """ Intelligent query expansion strategy that adapts based on query characteristics. """ def __init__(self): self.expansion_cache = {} def expand(self, query: str, strategy: str = "adaptive") -> List[str]: """ Expand query using specified strategy. Args: query: Original query string strategy: Expansion strategy - "adaptive", "aggressive", "conservative", "abbreviation_focused" Returns: List of expanded query variations """ # Check cache cache_key = f"{query}_{strategy}" if cache_key in self.expansion_cache: return self.expansion_cache[cache_key] if strategy == "adaptive": expansions = self._adaptive_expansion(query) elif strategy == "aggressive": expansions = self._aggressive_expansion(query) elif strategy == "conservative": expansions = self._conservative_expansion(query) elif strategy == "abbreviation_focused": expansions = self._abbreviation_focused_expansion(query) else: expansions = [query] # Cache result self.expansion_cache[cache_key] = expansions return expansions def _adaptive_expansion(self, query: str) -> List[str]: """ Adaptive expansion that adjusts based on query characteristics. - Short queries (< 5 words): More aggressive expansion - Long queries: More conservative - Queries with abbreviations: Focus on abbreviation expansion """ words = query.split() word_count = len(words) # Detect if query contains abbreviations has_abbrev = any(is_medical_abbreviation(word) for word in words) if has_abbrev: # Focus on abbreviation expansion return self._abbreviation_focused_expansion(query) elif word_count <= 3: # Short query - aggressive expansion return self._aggressive_expansion(query) elif word_count <= 7: # Medium query - balanced expansion return expand_query_with_variations(query, max_variations=5) else: # Long query - conservative expansion return self._conservative_expansion(query) def _aggressive_expansion(self, query: str) -> List[str]: """ Aggressive expansion with more variations. Useful for short queries that need more context. """ expansions = [] normalized = normalize_query(query) expansions.append(normalized) # 1. Abbreviation expansion abbrev_expansions = expand_abbreviations(normalized) expansions.extend(abbrev_expansions) # 2. Synonym expansion for each word words = normalized.split() for i, word in enumerate(words): synonyms = get_synonyms(word) for syn in list(synonyms)[:3]: # Top 3 synonyms new_query = ' '.join(words[:i] + [syn] + words[i+1:]) expansions.append(new_query) # 3. Multi-word phrase synonyms from .medical_terminology import MEDICAL_SYNONYMS for term, syn_list in MEDICAL_SYNONYMS.items(): if term in normalized: for syn in syn_list[:3]: expansions.append(normalized.replace(term, syn)) # 4. Spelling variations from .medical_terminology import SPELLING_VARIATIONS for us_spelling, uk_variants in SPELLING_VARIATIONS.items(): if us_spelling in normalized: for uk_spelling in uk_variants: expansions.append(normalized.replace(us_spelling, uk_spelling)) # Remove duplicates return list(dict.fromkeys(expansions))[:10] def _conservative_expansion(self, query: str) -> List[str]: """ Conservative expansion with fewer variations. Useful for specific, well-formed queries. """ expansions = [] normalized = normalize_query(query) expansions.append(normalized) # Only expand obvious abbreviations words = normalized.split() for word in words: if is_medical_abbreviation(word): abbrev_expansions = expand_abbreviations(word) for exp in abbrev_expansions[:2]: # Limit to 2 new_query = normalized.replace(word, exp) expansions.append(new_query) # Remove duplicates return list(dict.fromkeys(expansions))[:5] def _abbreviation_focused_expansion(self, query: str) -> List[str]: """ Expansion focused on abbreviation handling. Expands all abbreviations to their full forms. """ expansions = [] normalized = normalize_query(query) expansions.append(normalized) # Identify and expand all abbreviations words = normalized.split() current_query = normalized for word in words: if is_medical_abbreviation(word): full_forms = get_abbreviation_expansion(word) for full_form in full_forms: expanded = current_query.replace(word, full_form) expansions.append(expanded) # Also try with the expanded form as base for further expansion current_query = expanded # Remove duplicates return list(dict.fromkeys(expansions))[:8] class MultiQueryRetriever: """ Retrieves documents using multiple query variations and merges results. """ def __init__(self, base_retriever_func): """ Args: base_retriever_func: Function that takes (query, **kwargs) and returns List[Document] """ self.base_retriever = base_retriever_func self.query_expander = QueryExpansionStrategy() def retrieve( self, query: str, expansion_strategy: str = "adaptive", merge_strategy: str = "weighted", **retriever_kwargs ) -> List[Document]: """ Retrieve documents using multiple query variations. Args: query: Original query expansion_strategy: How to expand the query merge_strategy: How to merge results - "weighted", "union", "intersection" **retriever_kwargs: Additional arguments for base retriever Returns: Merged list of documents """ # Expand query query_variations = self.query_expander.expand(query, strategy=expansion_strategy) logger.info(f"Expanded query into {len(query_variations)} variations") logger.debug(f"Query variations: {query_variations}") # Retrieve for each variation all_results = [] for i, var_query in enumerate(query_variations): try: docs = self.base_retriever(var_query, **retriever_kwargs) # Tag documents with query variation rank for doc in docs: if not hasattr(doc, 'metadata'): doc.metadata = {} doc.metadata['query_variation_rank'] = i doc.metadata['query_variation'] = var_query all_results.append((var_query, docs)) except Exception as e: logger.warning(f"Retrieval failed for variation '{var_query}': {e}") # Merge results if merge_strategy == "weighted": merged = self._weighted_merge(all_results) elif merge_strategy == "union": merged = self._union_merge(all_results) elif merge_strategy == "intersection": merged = self._intersection_merge(all_results) else: # Default to weighted merged = self._weighted_merge(all_results) logger.info(f"Retrieved {len(merged)} unique documents after merging") return merged def _weighted_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]: """ Merge results with weighted scoring. Earlier query variations get higher weight. """ doc_scores = {} # doc_id -> (doc, score) for query_idx, (query_var, docs) in enumerate(results): # Weight decreases with query variation rank query_weight = 1.0 / (query_idx + 1) for doc_idx, doc in enumerate(docs): # Create unique doc identifier doc_id = self._get_doc_id(doc) # Position score (earlier is better) position_score = 1.0 / (doc_idx + 1) # Combined score score = query_weight * position_score if doc_id in doc_scores: # Document appeared in multiple variations - boost score existing_doc, existing_score = doc_scores[doc_id] doc_scores[doc_id] = (existing_doc, existing_score + score) else: doc_scores[doc_id] = (doc, score) # Sort by score and return documents sorted_docs = sorted(doc_scores.values(), key=lambda x: x[1], reverse=True) return [doc for doc, score in sorted_docs] def _union_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]: """ Merge results using union (all unique documents). Preserves order from first appearance. """ seen_ids = set() merged = [] for query_var, docs in results: for doc in docs: doc_id = self._get_doc_id(doc) if doc_id not in seen_ids: seen_ids.add(doc_id) merged.append(doc) return merged def _intersection_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]: """ Merge results using intersection (only documents in all variations). Useful for high-precision retrieval. """ if not results: return [] # Get doc IDs from first variation first_docs = {self._get_doc_id(doc): doc for doc in results[0][1]} common_ids = set(first_docs.keys()) # Intersect with other variations for query_var, docs in results[1:]: current_ids = {self._get_doc_id(doc) for doc in docs} common_ids &= current_ids # Return documents that appear in all variations return [first_docs[doc_id] for doc_id in common_ids if doc_id in first_docs] def _get_doc_id(self, doc: Document) -> str: """ Generate unique identifier for a document. Uses source, page number, and content hash. """ source = doc.metadata.get('source', 'unknown') page = doc.metadata.get('page_number', 'unknown') content_hash = hash(doc.page_content[:200]) # Hash first 200 chars return f"{source}_{page}_{content_hash}" class SemanticQueryExpander: """ Expands queries using semantic understanding. Uses context and co-occurrence patterns. """ def __init__(self): from .medical_terminology import get_terminology_learner self.learner = get_terminology_learner() def expand_with_context(self, query: str, context: Optional[str] = None) -> List[str]: """ Expand query using contextual information. Args: query: Original query context: Additional context (e.g., previous queries, conversation history) Returns: List of contextually expanded queries """ expansions = [query] normalized = normalize_query(query) # Extract key terms entities = extract_medical_entities(normalized) # Get related terms from learned patterns for entity, entity_type in entities: related = self.learner.get_related_terms(entity) for related_term in list(related)[:3]: expanded = normalized.replace(entity, related_term) expansions.append(expanded) # If context provided, extract relevant terms if context: context_entities = extract_medical_entities(normalize_query(context)) # Add context terms to query for entity, _ in context_entities[:2]: expansions.append(f"{normalized} {entity}") return list(dict.fromkeys(expansions))[:7] def expand_with_specialization(self, query: str, specialty: Optional[str] = None) -> List[str]: """ Expand query with specialty-specific terminology. Args: query: Original query specialty: Medical specialty (e.g., "oncology", "radiology") Returns: List of specialty-aware expanded queries """ expansions = [query] # Specialty-specific term mappings specialty_terms = { "oncology": ["cancer", "tumor", "malignancy", "neoplasm", "carcinoma"], "radiology": ["imaging", "scan", "ct", "mri", "pet"], "pathology": ["biopsy", "histology", "cytology", "tissue"], "surgery": ["resection", "operative", "surgical", "procedure"], } if specialty and specialty.lower() in specialty_terms: # Add specialty context to query for term in specialty_terms[specialty.lower()][:2]: if term not in query.lower(): expansions.append(f"{query} {term}") return expansions # ============================================================================ # CONVENIENCE FUNCTIONS # ============================================================================ def expand_medical_query( query: str, strategy: str = "adaptive", max_variations: int = 5 ) -> List[str]: """ Convenience function to expand a medical query. Args: query: Original query strategy: Expansion strategy max_variations: Maximum number of variations Returns: List of query variations """ expander = QueryExpansionStrategy() variations = expander.expand(query, strategy=strategy) return variations[:max_variations] def create_multi_query_retriever(base_retriever_func): """ Create a multi-query retriever instance. Args: base_retriever_func: Base retrieval function Returns: MultiQueryRetriever instance """ return MultiQueryRetriever(base_retriever_func)