File size: 15,725 Bytes
ddc9c77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
"""
Query Expansion Module for Medical Linguistic Variability

This module provides intelligent query expansion to handle:
- Medical term variations and synonyms
- Abbreviation expansion
- Spelling variations (US/UK/International)
- Specialty-specific terminology
- Multi-query retrieval strategies
"""

import re
from typing import List, Dict, Set, Tuple, Optional
from langchain.schema import Document
from .medical_terminology import (
    normalize_query,
    expand_query_with_variations,
    get_synonyms,
    expand_abbreviations,
    extract_medical_entities,
    is_medical_abbreviation,
    get_abbreviation_expansion,
)
from .config import logger


class QueryExpansionStrategy:
    """
    Intelligent query expansion strategy that adapts based on query characteristics.

    Results are memoised per instance in ``expansion_cache`` under the key
    ``(query, strategy)``. Nothing in this class evicts entries, so a
    long-lived instance grows with the number of distinct queries it sees.
    """

    def __init__(self):
        # (query, strategy) -> list of expansions computed for that pair.
        self.expansion_cache = {}

    def expand(self, query: str, strategy: str = "adaptive") -> List[str]:
        """
        Expand query using specified strategy.

        Args:
            query: Original query string
            strategy: Expansion strategy - "adaptive", "aggressive",
                "conservative", "abbreviation_focused". Any other value
                yields ``[query]`` unchanged.

        Returns:
            List of expanded query variations. The list is a fresh copy, so
            callers may mutate it without affecting later cache hits.
        """
        # A tuple key cannot collide the way a "query_strategy" string can
        # (e.g. ("x", "foo_bar") vs ("x_foo", "bar")).
        cache_key = (query, strategy)
        cached = self.expansion_cache.get(cache_key)
        if cached is not None:
            return list(cached)

        handlers = {
            "adaptive": self._adaptive_expansion,
            "aggressive": self._aggressive_expansion,
            "conservative": self._conservative_expansion,
            "abbreviation_focused": self._abbreviation_focused_expansion,
        }
        handler = handlers.get(strategy)
        expansions = handler(query) if handler is not None else [query]

        self.expansion_cache[cache_key] = expansions
        return list(expansions)

    @staticmethod
    def _replace_token(text: str, token: str, replacement: str) -> str:
        """
        Replace every whitespace-delimited occurrence of ``token`` in ``text``.

        Unlike ``str.replace`` this never rewrites ``token`` when it merely
        occurs inside a longer word (e.g. "ct" inside "fracture").
        """
        pattern = r"(?<!\S)" + re.escape(token) + r"(?!\S)"
        # A callable replacement keeps backslashes in ``replacement`` literal.
        return re.sub(pattern, lambda _match: replacement, text)

    def _adaptive_expansion(self, query: str) -> List[str]:
        """
        Adaptive expansion that adjusts based on query characteristics.
        - Queries containing a medical abbreviation: abbreviation-focused expansion
        - Short queries (<= 3 words): aggressive expansion
        - Medium queries (4-7 words): expand_query_with_variations, max 5 variations
        - Long queries (8+ words): conservative expansion
        """
        words = query.split()

        # Abbreviations take priority over length-based selection.
        if any(is_medical_abbreviation(word) for word in words):
            return self._abbreviation_focused_expansion(query)

        word_count = len(words)
        if word_count <= 3:
            return self._aggressive_expansion(query)
        if word_count <= 7:
            return expand_query_with_variations(query, max_variations=5)
        return self._conservative_expansion(query)

    def _aggressive_expansion(self, query: str) -> List[str]:
        """
        Aggressive expansion with more variations.
        Useful for short queries that need more context.

        Returns:
            Up to 10 unique variations, the normalized query first.
        """
        normalized = normalize_query(query)
        expansions = [normalized]

        # 1. Abbreviation expansion
        expansions.extend(expand_abbreviations(normalized))

        # 2. Synonym expansion for each word (at most 3 synonyms per word)
        words = normalized.split()
        for i, word in enumerate(words):
            for syn in list(get_synonyms(word))[:3]:
                expansions.append(' '.join(words[:i] + [syn] + words[i+1:]))

        # 3. Multi-word phrase synonyms
        # NOTE(review): matching here is by substring, so a term also matches
        # inside longer words; confirm against MEDICAL_SYNONYMS whether that
        # is intended (e.g. to cover plurals).
        from .medical_terminology import MEDICAL_SYNONYMS
        for term, syn_list in MEDICAL_SYNONYMS.items():
            if term in normalized:
                for syn in syn_list[:3]:
                    expansions.append(normalized.replace(term, syn))

        # 4. Spelling variations (substring match as well)
        from .medical_terminology import SPELLING_VARIATIONS
        for us_spelling, uk_variants in SPELLING_VARIATIONS.items():
            if us_spelling in normalized:
                for uk_spelling in uk_variants:
                    expansions.append(normalized.replace(us_spelling, uk_spelling))

        # Remove duplicates while keeping first-seen order
        return list(dict.fromkeys(expansions))[:10]

    def _conservative_expansion(self, query: str) -> List[str]:
        """
        Conservative expansion with fewer variations.
        Useful for specific, well-formed queries.

        Returns:
            Up to 5 unique variations, the normalized query first.
        """
        normalized = normalize_query(query)
        expansions = [normalized]

        # Only expand obvious abbreviations, at most 2 expansions each
        for word in normalized.split():
            if not is_medical_abbreviation(word):
                continue
            for exp in expand_abbreviations(word)[:2]:
                # Whole-token replacement: do not touch other words that
                # happen to contain the abbreviation.
                expansions.append(self._replace_token(normalized, word, exp))

        # Remove duplicates while keeping first-seen order
        return list(dict.fromkeys(expansions))[:5]

    def _abbreviation_focused_expansion(self, query: str) -> List[str]:
        """
        Expansion focused on abbreviation handling.

        For each abbreviation token, one variation is emitted per full form.
        The variation built from the first full form then becomes the base
        for the next abbreviation, so later variations have earlier
        abbreviations already expanded.

        Returns:
            Up to 8 unique variations, the normalized query first.
        """
        normalized = normalize_query(query)
        expansions = [normalized]
        base_query = normalized

        for word in normalized.split():
            if not is_medical_abbreviation(word):
                continue

            full_forms = get_abbreviation_expansion(word)
            # The helper's return type is not visible from this module:
            # tolerate "no expansion" and a single bare string.
            if not full_forms:
                continue
            if isinstance(full_forms, str):
                full_forms = [full_forms]

            # Build every alternative from the same base; advancing the base
            # inside this loop would erase the abbreviation after the first
            # full form and turn the remaining alternatives into duplicates.
            variants = [
                self._replace_token(base_query, word, full_form)
                for full_form in full_forms
            ]
            expansions.extend(variants)
            if variants:
                base_query = variants[0]

        # Remove duplicates while keeping first-seen order
        return list(dict.fromkeys(expansions))[:8]


class MultiQueryRetriever:
    """
    Retrieves documents using multiple query variations and merges results.

    The query is expanded with a QueryExpansionStrategy, the base retriever
    is called once per variation, and the per-variation result lists are
    combined with one of the merge strategies below.
    """

    def __init__(self, base_retriever_func):
        """
        Args:
            base_retriever_func: Function that takes (query, **kwargs) and returns List[Document]
        """
        self.base_retriever = base_retriever_func
        self.query_expander = QueryExpansionStrategy()

    def retrieve(
        self,
        query: str,
        expansion_strategy: str = "adaptive",
        merge_strategy: str = "weighted",
        **retriever_kwargs
    ) -> List[Document]:
        """
        Retrieve documents using multiple query variations.

        Each returned document is tagged in place with
        ``metadata['query_variation_rank']`` and ``metadata['query_variation']``.
        A variation whose retrieval raises is logged and skipped rather than
        aborting the whole call.

        Args:
            query: Original query
            expansion_strategy: How to expand the query
            merge_strategy: How to merge results - "weighted", "union", "intersection".
                Any other value falls back to "weighted".
            **retriever_kwargs: Additional arguments for base retriever

        Returns:
            Merged list of documents
        """
        query_variations = self.query_expander.expand(query, strategy=expansion_strategy)

        logger.info("Expanded query into %d variations", len(query_variations))
        logger.debug("Query variations: %s", query_variations)

        # Retrieve for each variation
        all_results = []
        for rank, var_query in enumerate(query_variations):
            try:
                # Materialise first: the result is iterated here for tagging
                # and again during merging, which would exhaust an iterator.
                docs = list(self.base_retriever(var_query, **retriever_kwargs))
                for doc in docs:
                    if getattr(doc, 'metadata', None) is None:
                        doc.metadata = {}
                    doc.metadata['query_variation_rank'] = rank
                    doc.metadata['query_variation'] = var_query
                all_results.append((var_query, docs))
            except Exception as e:
                # Deliberately best-effort: one failing variation must not
                # lose the results of the others.
                logger.warning("Retrieval failed for variation '%s': %s", var_query, e)

        mergers = {
            "union": self._union_merge,
            "intersection": self._intersection_merge,
        }
        # "weighted" and any unrecognised strategy use the weighted merge.
        merged = mergers.get(merge_strategy, self._weighted_merge)(all_results)

        logger.info("Retrieved %d unique documents after merging", len(merged))
        return merged

    def _weighted_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]:
        """
        Merge results with weighted scoring.

        Each hit contributes ``1/(variation_index + 1) * 1/(position + 1)``;
        a document found by several variations accumulates the contributions.
        Documents are returned by descending total score, ties keeping
        first-seen order. The first-seen Document object is the one returned.
        """
        doc_scores = {}  # doc_id -> (doc, score)

        for query_idx, (_query_var, docs) in enumerate(results):
            # Weight decreases with query variation rank
            query_weight = 1.0 / (query_idx + 1)

            for doc_idx, doc in enumerate(docs):
                doc_id = self._get_doc_id(doc)
                # Position score (earlier is better)
                score = query_weight * (1.0 / (doc_idx + 1))

                if doc_id in doc_scores:
                    # Document appeared in multiple variations - boost score
                    existing_doc, existing_score = doc_scores[doc_id]
                    doc_scores[doc_id] = (existing_doc, existing_score + score)
                else:
                    doc_scores[doc_id] = (doc, score)

        ranked = sorted(doc_scores.values(), key=lambda pair: pair[1], reverse=True)
        return [doc for doc, _score in ranked]

    def _union_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]:
        """
        Merge results using union (all unique documents).
        Preserves order from first appearance.
        """
        seen_ids = set()
        merged = []

        for _query_var, docs in results:
            for doc in docs:
                doc_id = self._get_doc_id(doc)
                if doc_id not in seen_ids:
                    seen_ids.add(doc_id)
                    merged.append(doc)

        return merged

    def _intersection_merge(self, results: List[Tuple[str, List[Document]]]) -> List[Document]:
        """
        Merge results using intersection (only documents in all variations).
        Useful for high-precision retrieval.

        Returns:
            Documents present in every variation's results, in the order the
            first variation ranked them. Empty when ``results`` is empty.
        """
        if not results:
            return []

        # Insertion-ordered map of the first variation's documents; the
        # first Document seen for an id is the one kept.
        first_docs = {}
        for doc in results[0][1]:
            first_docs.setdefault(self._get_doc_id(doc), doc)
        common_ids = set(first_docs)

        # Intersect with other variations
        for _query_var, docs in results[1:]:
            common_ids &= {self._get_doc_id(doc) for doc in docs}

        # Walk the ordered dict, not the set, so ranking order is preserved.
        return [doc for doc_id, doc in first_docs.items() if doc_id in common_ids]

    def _get_doc_id(self, doc: Document) -> str:
        """
        Generate unique identifier for a document.
        Uses source, page number, and content hash.

        The content part is the built-in ``hash`` of the first 200 characters
        of ``page_content``. String hashes are salted per interpreter process
        by default, so ids are only comparable within one process, and two
        documents sharing source, page and first 200 characters get the same id.
        """
        source = doc.metadata.get('source', 'unknown')
        page = doc.metadata.get('page_number', 'unknown')
        content_hash = hash(doc.page_content[:200])  # Hash first 200 chars
        return f"{source}_{page}_{content_hash}"


class SemanticQueryExpander:
    """
    Query expander driven by learned term relations, optional conversation
    context, and an optional medical specialty.
    """

    def __init__(self):
        # Imported lazily at construction time rather than at module import.
        from .medical_terminology import get_terminology_learner
        self.learner = get_terminology_learner()

    def expand_with_context(self, query: str, context: Optional[str] = None) -> List[str]:
        """
        Build query variations from related terms and surrounding context.

        Args:
            query: Original query
            context: Additional context (e.g., previous queries, conversation history)

        Returns:
            Up to 7 unique variations: the original query first, then the
            normalized query with each extracted entity swapped for up to 3
            related terms, then the normalized query suffixed with up to 2
            entities extracted from ``context``.
        """
        normalized = normalize_query(query)
        candidates = [query]

        # Swap each entity found in the query for terms the learner relates to it.
        for entity, _entity_type in extract_medical_entities(normalized):
            related_terms = list(self.learner.get_related_terms(entity))[:3]
            candidates.extend(
                normalized.replace(entity, term) for term in related_terms
            )

        # Append leading entities from the context, when any was supplied.
        if context:
            context_entities = extract_medical_entities(normalize_query(context))
            candidates.extend(
                f"{normalized} {entity}" for entity, _ in context_entities[:2]
            )

        unique_candidates = list(dict.fromkeys(candidates))
        return unique_candidates[:7]

    def expand_with_specialization(self, query: str, specialty: Optional[str] = None) -> List[str]:
        """
        Add specialty vocabulary to a query.

        Args:
            query: Original query
            specialty: Medical specialty (e.g., "oncology", "radiology"),
                matched case-insensitively

        Returns:
            The original query, followed by the query suffixed with each of
            the specialty's first two terms that does not already occur
            (as a lowercase substring) in the query. Unknown or missing
            specialties yield just ``[query]``.
        """
        vocabulary = {
            "oncology": ["cancer", "tumor", "malignancy", "neoplasm", "carcinoma"],
            "radiology": ["imaging", "scan", "ct", "mri", "pet"],
            "pathology": ["biopsy", "histology", "cytology", "tissue"],
            "surgery": ["resection", "operative", "surgical", "procedure"],
        }
        variants = [query]

        if not specialty:
            return variants

        terms = vocabulary.get(specialty.lower())
        if terms is None:
            return variants

        lowered_query = query.lower()
        variants.extend(
            f"{query} {term}" for term in terms[:2] if term not in lowered_query
        )
        return variants


# ============================================================================
# CONVENIENCE FUNCTIONS
# ============================================================================

def expand_medical_query(
    query: str,
    strategy: str = "adaptive",
    max_variations: int = 5
) -> List[str]:
    """
    One-shot helper: expand a medical query without managing an expander.

    A new QueryExpansionStrategy is created on every call, so nothing is
    cached between calls.

    Args:
        query: Original query
        strategy: Expansion strategy passed through to QueryExpansionStrategy.expand
        max_variations: Slice bound applied to the expansion list

    Returns:
        The first ``max_variations`` query variations
    """
    all_variations = QueryExpansionStrategy().expand(query, strategy=strategy)
    return all_variations[:max_variations]


def create_multi_query_retriever(base_retriever_func):
    """
    Factory for MultiQueryRetriever.

    Args:
        base_retriever_func: Base retrieval function the new retriever will
            call once per query variation

    Returns:
        MultiQueryRetriever instance wrapping ``base_retriever_func``
    """
    retriever = MultiQueryRetriever(base_retriever_func)
    return retriever