# Spaces status (page-capture residue, not part of the module): Runtime error
| """ | |
Query Processing Pipeline for Retrieval-based QA Chatbot
========================================================

This module handles:
  1. Query preprocessing
  2. Intent and sub-intent classification
  3. Named Entity Recognition (NER) using lightweight BioBERT

Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals)
| """ | |
| import re | |
| from typing import List, Tuple | |
| from transformers import pipeline | |
| import torch | |
# -------------------------------
# Initialize Lightweight NER Model
# -------------------------------
# The pipeline is built once, at import time. If anything in the setup fails,
# `ner_model` is left as None so callers can detect that and degrade
# gracefully instead of crashing.
print("[NER] Loading lightweight BioBERT NER model...")
try:
    # Use the first CUDA device when torch reports one, otherwise pass -1.
    _ner_device = 0 if torch.cuda.is_available() else -1
    # Checkpoint trained specifically for chemical/drug entity recognition.
    ner_model = pipeline(
        "ner",
        model="alvaroalon2/biobert_chemical_ner",
        aggregation_strategy="simple",
        device=_ner_device,
    )
    print("[NER] ✓ Model loaded successfully\n")
except Exception as load_error:
    print(f"[NER] ✗ Failed to load model: {load_error}")
    ner_model = None
# -------------------------------
# Named Entity Extraction
# -------------------------------
def extract_entities_BERT(question: str) -> List[str]:
    """
    Pull entity mentions out of a user query with the module-level NER model.

    Each prediction is expected to be a mapping with a 'score' and a 'word'
    entry. A prediction is kept only when:
      * its score is strictly greater than 0.7,
      * its text (with '##' sub-word markers removed and whitespace trimmed)
        is longer than two characters, and
      * its lower-cased text is not one of 'the', 'and', 'for', 'with'.

    Duplicates are dropped case-insensitively; the first spelling seen wins
    and the original order is preserved.

    Parameters:
        question (str): User query.

    Returns:
        List[str]: Cleaned, de-duplicated entity strings. An empty list is
        returned when the model is unavailable or when extraction raises.
    """
    if ner_model is None:
        print("[NER] Model not available, returning empty list")
        return []

    ignored_words = ['the', 'and', 'for', 'with']
    try:
        kept: List[str] = []
        seen_keys = set()
        for prediction in ner_model(question):
            # Skip anything that is not a high-confidence hit (>70%).
            if not (prediction['score'] > 0.7):
                continue

            # Strip sub-word markers ('##') left over from tokenization.
            text = prediction['word'].replace('##', '').strip()
            if len(text) <= 2:
                continue

            key = text.lower()
            if key in ignored_words or key in seen_keys:
                continue

            seen_keys.add(key)
            kept.append(text)
        return kept
    except Exception as failure:
        print(f"[NER] Extraction failed: {failure}")
        return []
# -------------------------------
# Rule-Based Intent Classification
# -------------------------------
# Ordered (intent, pattern) rules: the FIRST rule whose pattern matches wins,
# so earlier rules take priority over later ones. Patterns are matched against
# the lower-cased query and therefore must themselves be lower-case.
_INTENT_RULES = (
    ("description",
     re.compile(r"\bwhat is\b|\bused for\b|\bdefine\b")),
    ("before_using",
     re.compile(r"\bbefore using\b|\bshould i tell\b|\bdoctor know\b")),
    ("proper_use",
     re.compile(r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b")),
    # Stems such as "pregnan"/"breastfeed" are left open-ended (\w*) so that
    # "pregnant", "pregnancy", "breastfeeding" etc. are recognised.
    ("precautions",
     re.compile(r"\bprecautions?\b|\bpregnan\w*|\bbreastfeed\w*|\brisks?\b")),
    ("side_effects",
     re.compile(r"\bside effects?\b|\badverse\b|\bnausea\b|\bdizziness\b")),
)

# Intent returned when no rule matches.
_DEFAULT_INTENT = "description"


def classify_intent(question: str) -> str:
    """
    Classify the user's query into a high-level intent based on keywords.

    The query is lower-cased and tested against each rule in
    ``_INTENT_RULES`` in order; the intent of the first matching rule is
    returned. Because of that ordering, a query containing keywords from
    several groups (e.g. "take" and "pregnancy") resolves to the earliest
    group ("proper_use").

    Parameters:
        question (str): The user's question.

    Returns:
        str: One of ['description', 'before_using', 'proper_use',
        'precautions', 'side_effects']. Falls back to 'description' when
        no keyword matches.
    """
    q = question.lower()
    for intent, pattern in _INTENT_RULES:
        if pattern.search(q):
            return intent
    return _DEFAULT_INTENT  # default fallback
# -------------------------------
# Query Preprocessing Wrapper
# -------------------------------
def preprocess_query(raw_query: str) -> Tuple[str, List[str]]:
    """
    Main preprocessing function that extracts the intent and the named
    entities of a raw user question.

    The result is a flat 2-tuple ``(intent, entities)``. No sub-intent is
    computed, so callers should unpack it as
    ``intent, entities = preprocess_query(...)``.

    Parameters:
        raw_query (str): The raw user question.

    Returns:
        Tuple[str, List[str]]: ``(intent, entities)`` where ``intent`` is the
        label produced by ``classify_intent`` (or "" if it is falsy) and
        ``entities`` is the list produced by ``extract_entities_BERT``
        (an empty list when none were found). If any step raises, the
        error is printed and ``("", [])`` is returned; this function does
        not propagate exceptions.
    """
    try:
        intent = classify_intent(raw_query)
        entities = extract_entities_BERT(raw_query)
        if not entities:
            print("[NER fallback] No entities found. Using raw query.")
            return (intent or ""), []
        print(f"[Query Processed] Intent = {intent}| Entities = {entities}")
        return (intent or ""), entities
    except Exception as e:
        # Best-effort boundary: report the failure and hand back an empty
        # result in the same shape as the success path.
        print(f"[Preprocessing failed] {e}")
        return "", []
# -------------------------------
# Optional: Test Function
# -------------------------------
def _run_sample_queries() -> None:
    """Run preprocess_query over a few sample questions and print the results."""
    heavy_rule = "=" * 70
    light_rule = "-" * 70
    sample_queries = (
        "What are the side effects of Azithromycin?",
        "How much dosage of aspirin should I take for headache?",
        "Can I take Lisinopril during pregnancy?",
        "What is Metformin used for?",
        "Are there interactions between Warfarin and Ibuprofen?",
        "How should I store insulin?",
    )

    print("\n" + heavy_rule)
    print("TESTING LIGHTWEIGHT TRANSFORMER NER")
    print(heavy_rule + "\n")

    for number, sample in enumerate(sample_queries, start=1):
        print(f"[Test {number}] Query: {sample}")
        print(light_rule)
        # preprocess_query returns a flat (intent, entities) pair.
        detected_intent, found = preprocess_query(sample)
        print(f" Intent: {detected_intent}")
        print(f" Entities: {found or 'None detected'}")
        print(light_rule + "\n")

    print(heavy_rule)
    print("TESTING COMPLETE")
    print(heavy_rule)


if __name__ == "__main__":
    _run_sample_queries()