""" Query Processing Pipeline for Retrieval-based QA Chatbot ======================================================== This module handles: 1. Query preprocessing 2. Intent and sub-intent classification 3. Named Entity Recognition (NER) using lightweight BioBERT Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals) """ import re from typing import List, Tuple from transformers import pipeline import torch # ------------------------------- # Initialize Lightweight NER Model # ------------------------------- print("[NER] Loading lightweight BioBERT NER model...") try: # This model is specifically trained for chemical/drug entity recognition ner_model = pipeline( "ner", model="alvaroalon2/biobert_chemical_ner", aggregation_strategy="simple", device=0 if torch.cuda.is_available() else -1 ) print("[NER] ✓ Model loaded successfully\n") except Exception as e: print(f"[NER] ✗ Failed to load model: {e}") ner_model = None # ------------------------------- # Named Entity Extraction # ------------------------------- def extract_entities_BERT(question: str) -> List[str]: """ Extract biomedical entities using lightweight BioBERT NER. Parameters: question (str): User query Returns: List[str]: Extracted entities (drugs, chemicals, etc.) """ if ner_model is None: print("[NER] Model not available, returning empty list") return [] try: # Run NER pipeline entities = ner_model(question) # Filter and clean entities extracted = [] for ent in entities: # Only keep high-confidence entities (>70%) if ent['score'] > 0.7: # Clean up subword tokens (remove ##) entity_text = ent['word'].replace('##', '').strip() # Filter out very short entities and common words if len(entity_text) > 2 and entity_text.lower() not in ['the', 'and', 'for', 'with']: extracted.append(entity_text) # Remove duplicates while preserving order unique_entities = [] seen = set() for ent in extracted: ent_lower = ent.lower() if ent_lower not in seen: seen.add(ent_lower) unique_entities.append(ent) return unique_entities except Exception as e: print(f"[NER] Extraction failed: {e}") return [] # ------------------------------- # Rule-Based Intent Classification # ------------------------------- def classify_intent(question: str) -> str: """ Classify the user's query into a high-level intent based on keywords. Parameters: question (str): The user's question. Returns: str: One of ['description', 'before_using', 'proper_use', 'precautions', 'side_effects'] """ q = question.lower() if re.search(r"\bwhat is\b|\bused for\b|\bdefine\b", q): return "description" elif re.search(r"\bbefore using\b|\bshould I tell\b|\bdoctor know\b", q): return "before_using" elif re.search(r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b", q): return "proper_use" elif re.search(r"\bprecaution\b|\bpregnan\b|\bbreastfeed\b|\brisk\b", q): return "precautions" elif re.search(r"\bside effect\b|\badverse\b|\bnausea\b|\bdizziness\b", q): return "side_effects" else: return "description" # default fallback # ------------------------------- # Query Preprocessing Wrapper # ------------------------------- def preprocess_query(raw_query: str) -> Tuple[Tuple[str, str], List[str]]: """ Main preprocessing function that extracts: - Intent - Subsection - Named Entities Parameters: raw_query (str): The raw user question. Returns: Tuple[Tuple[str, str], List[str]]: ((intent, sub_intent), list of entities) """ try: intent = classify_intent(raw_query) entities = extract_entities_BERT(raw_query) if not entities: print("[NER fallback] No entities found. Using raw query.") return (intent or ""), [] print(f"[Query Processed] Intent = {intent}| Entities = {entities}") return (intent or ""), entities except Exception as e: print(f"[Preprocessing failed] {e}") return (""), [] # ------------------------------- # Optional: Test Function # ------------------------------- if __name__ == "__main__": """Test the NER with sample queries.""" test_queries = [ "What are the side effects of Azithromycin?", "How much dosage of aspirin should I take for headache?", "Can I take Lisinopril during pregnancy?", "What is Metformin used for?", "Are there interactions between Warfarin and Ibuprofen?", "How should I store insulin?", ] print("\n" + "="*70) print("TESTING LIGHTWEIGHT TRANSFORMER NER") print("="*70 + "\n") for i, query in enumerate(test_queries, 1): print(f"[Test {i}] Query: {query}") print("-" * 70) (intent), entities = preprocess_query(query) print(f" Intent: {intent}") print(f" Entities: {entities if entities else 'None detected'}") print("-" * 70 + "\n") print("="*70) print("TESTING COMPLETE") print("="*70)