# DrugBot_Retrieval_Based / Scripts /Query_processing.py
# Niranjan Sathish
# Initial commit
# 6cf6a92
"""
Query Processing Pipeline for Retrieval-based QA Chatbot
========================================================
This module handles:
1. Query preprocessing
2. Intent and sub-intent classification
3. Named Entity Recognition (NER) using lightweight BioBERT
Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals)
"""
import re
from typing import List, Tuple
from transformers import pipeline
import torch
# -------------------------------
# Initialize Lightweight NER Model
# -------------------------------
print("[NER] Loading lightweight BioBERT NER model...")
try:
    # Checkpoint fine-tuned for chemical/drug mentions. The "simple"
    # aggregation strategy groups sub-word tokens into whole entity spans.
    # device=0 selects the first CUDA GPU; device=-1 keeps inference on CPU.
    ner_model = pipeline(
        task="ner",
        model="alvaroalon2/biobert_chemical_ner",
        aggregation_strategy="simple",
        device=(0 if torch.cuda.is_available() else -1),
    )
    print("[NER] ✓ Model loaded successfully\n")
except Exception as load_err:
    # Best-effort: any load failure disables NER instead of breaking the
    # import; extract_entities_BERT checks `ner_model is None` before use.
    print(f"[NER] ✗ Failed to load model: {load_err}")
    ner_model = None
# -------------------------------
# Named Entity Extraction
# -------------------------------
# Short function words that can survive the length filter but are never
# useful drug/chemical entities. A frozenset gives O(1) membership tests and
# is built once at import time instead of on every call.
_NER_STOPWORDS = frozenset({"the", "and", "for", "with"})


def extract_entities_BERT(question: str, *, min_score: float = 0.7, model=None) -> List[str]:
    """
    Extract biomedical entities using lightweight BioBERT NER.

    Parameters:
        question (str): User query.
        min_score (float): Keep only entities whose confidence score is
            strictly greater than this value (default 0.7, i.e. >70%).
        model: Optional NER callable used instead of the module-level
            ``ner_model``. It is called as ``model(question)`` and must return
            an iterable of dicts with ``'word'`` and ``'score'`` keys.
    Returns:
        List[str]: Extracted entities (drugs, chemicals, etc.), de-duplicated
        case-insensitively while keeping the first-seen spelling and order.
        An empty list is returned when no model is available, the query is
        blank, or extraction raises.
    """
    if model is None:
        model = ner_model
    if model is None:
        print("[NER] Model not available, returning empty list")
        return []
    # A blank query cannot contain entities; skip the model call entirely.
    if not question or not question.strip():
        return []
    try:
        entities = model(question)
        unique_entities: List[str] = []
        seen = set()
        for ent in entities:
            # Written as `not (score > min_score)` so a NaN score is dropped,
            # matching a plain "keep if score > threshold" rule.
            if not (ent['score'] > min_score):
                continue
            # Clean up any leftover sub-word markers (##) from the tokenizer.
            entity_text = ent['word'].replace('##', '').strip()
            # Drop very short fragments and common function words.
            if len(entity_text) <= 2 or entity_text.lower() in _NER_STOPWORDS:
                continue
            # Case-insensitive de-duplication that preserves first-seen order.
            key = entity_text.lower()
            if key not in seen:
                seen.add(key)
                unique_entities.append(entity_text)
        return unique_entities
    except Exception as e:
        print(f"[NER] Extraction failed: {e}")
        return []
# -------------------------------
# Rule-Based Intent Classification
# -------------------------------
# Ordered (intent, pattern) rules used by classify_intent. The first matching
# rule wins, so the specific clinical topics come before the generic cues
# ("take", "what is") that appear in almost any drug question. Patterns are
# matched against the lowercased question, so they must be lowercase, and
# word stems use `\w*` instead of a trailing `\b` so that plurals and
# inflections ("side effects", "pregnancy", "breastfeeding") still match.
_INTENT_RULES = (
    ("side_effects", re.compile(
        r"\bside[- ]?effects?\b|\badverse\b|\bnausea\b|\bdizz(?:y|iness)\b")),
    ("precautions", re.compile(
        r"\bprecautions?\b|\bpregnan\w*|\bbreast[- ]?fe(?:ed|d)\w*|\brisk\w*")),
    ("before_using", re.compile(
        r"\bbefore (?:using|taking)\b|\bshould i tell\b|\bdoctor know\b")),
    ("proper_use", re.compile(
        r"\bhow to\b|\bdos(?:e|age|ing)s?\b|\btak(?:e|es|en|ing)\b|\binstructions?\b")),
    ("description", re.compile(
        r"\bwhat is\b|\bused for\b|\bdefine\b")),
)


def classify_intent(question: str) -> str:
    """
    Classify the user's query into a high-level intent based on keywords.

    Rules are tried in priority order (side effects, precautions, before
    using, proper use, description) and the first match wins.

    Parameters:
        question (str): The user's question.
    Returns:
        str: One of ['description', 'before_using', 'proper_use',
        'precautions', 'side_effects']; 'description' when nothing matches.
    """
    q = question.lower()
    for intent, pattern in _INTENT_RULES:
        if pattern.search(q):
            return intent
    return "description"  # default fallback
# -------------------------------
# Query Preprocessing Wrapper
# -------------------------------
def preprocess_query(raw_query: str) -> Tuple[str, List[str]]:
    """
    Main preprocessing entry point: classify the query's intent and extract
    named entities from it.

    No sub-intent is computed; the result is a flat ``(intent, entities)``
    pair.

    Parameters:
        raw_query (str): The raw user question.
    Returns:
        Tuple[str, List[str]]: ``(intent, entities)`` where ``intent`` is a
        label from ``classify_intent`` and ``entities`` is the (possibly
        empty) list from ``extract_entities_BERT``. If anything raises,
        ``("", [])`` is returned instead.
    """
    try:
        intent = classify_intent(raw_query) or ""
        entities = extract_entities_BERT(raw_query)
        if entities:
            print(f"[Query Processed] Intent = {intent} | Entities = {entities}")
        else:
            print("[NER fallback] No entities found. Using raw query.")
            entities = []
        return intent, entities
    except Exception as e:
        print(f"[Preprocessing failed] {e}")
        return "", []
# -------------------------------
# Optional: Test Function
# -------------------------------
if __name__ == "__main__":
    # Manual smoke test: run the full preprocessing pipeline on a few
    # representative questions and print the detected intent and entities.
    banner = "=" * 70
    divider = "-" * 70
    sample_questions = (
        "What are the side effects of Azithromycin?",
        "How much dosage of aspirin should I take for headache?",
        "Can I take Lisinopril during pregnancy?",
        "What is Metformin used for?",
        "Are there interactions between Warfarin and Ibuprofen?",
        "How should I store insulin?",
    )

    print(f"\n{banner}")
    print("TESTING LIGHTWEIGHT TRANSFORMER NER")
    print(f"{banner}\n")

    for idx, question in enumerate(sample_questions, start=1):
        print(f"[Test {idx}] Query: {question}")
        print(divider)
        detected_intent, detected_entities = preprocess_query(question)
        print(f" Intent: {detected_intent}")
        shown = detected_entities or 'None detected'
        print(f" Entities: {shown}")
        print(f"{divider}\n")

    print(banner)
    print("TESTING COMPLETE")
    print(banner)