File size: 5,507 Bytes
6cf6a92
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
"""
Query Processing Pipeline for Retrieval-based QA Chatbot
========================================================

This module handles:
1. Query preprocessing
2. Intent classification (rule-based keyword matching; no separate sub-intent is produced)
3. Named Entity Recognition (NER) using lightweight BioBERT

Uses: alvaroalon2/biobert_chemical_ner (~140MB, optimized for drugs/chemicals)
"""

import re
from typing import List, Tuple
from transformers import pipeline
import torch

# -------------------------------
# Initialize Lightweight NER Model
# -------------------------------

# NOTE: everything in this section runs at import time, so importing this
# module prints progress messages and attempts to construct the NER pipeline.
print("[NER] Loading lightweight BioBERT NER model...")

try:
    # This model is specifically trained for chemical/drug entity recognition
    # aggregation_strategy="simple" asks the pipeline to group token-level
    # predictions into entity spans (see the transformers pipeline docs).
    # device is 0 when torch reports CUDA as available, otherwise -1
    # (presumably CPU, per the transformers device convention — confirm).
    ner_model = pipeline(
        "ner",
        model="alvaroalon2/biobert_chemical_ner",
        aggregation_strategy="simple",
        device=0 if torch.cuda.is_available() else -1
    )
    print("[NER] ✓ Model loaded successfully\n")
except Exception as e:
    # Any failure inside the try block is reported and then swallowed:
    # ner_model is set to None and extract_entities_BERT checks for that
    # before using the pipeline.
    print(f"[NER] ✗ Failed to load model: {e}")
    ner_model = None

# -------------------------------
# Named Entity Extraction
# -------------------------------

def extract_entities_BERT(question: str) -> List[str]:
    """
    Run the module-level NER pipeline on a question and return entity strings.

    Parameters:
        question (str): User query

    Returns:
        List[str]: Entity texts whose score is above 0.7, longer than two
        characters, not one of a few common words, de-duplicated
        case-insensitively in first-seen order. An empty list is returned
        when the model is unavailable or when anything raises during
        extraction.
    """
    if ner_model is None:
        print("[NER] Model not available, returning empty list")
        return []

    ignored_words = ('the', 'and', 'for', 'with')

    try:
        # Maps lower-cased text -> first spelling seen; dict insertion order
        # gives the order-preserving, case-insensitive de-duplication.
        first_spelling = {}

        for prediction in ner_model(question):
            # Skip anything that is not strictly above the 70% confidence bar
            if not prediction['score'] > 0.7:
                continue

            # Strip leftover subword markers (##) and surrounding whitespace
            text = prediction['word'].replace('##', '').strip()
            lowered = text.lower()

            # Drop very short fragments and common filler words
            if len(text) <= 2 or lowered in ignored_words:
                continue

            first_spelling.setdefault(lowered, text)

        return list(first_spelling.values())

    except Exception as e:
        print(f"[NER] Extraction failed: {e}")
        return []


# -------------------------------
# Rule-Based Intent Classification
# -------------------------------

# Ordered (intent, pattern) pairs; the first pattern that matches wins, so the
# order below is part of the behaviour. Patterns are applied to the
# lower-cased question and therefore must be written in lower case. Stems
# such as "pregnan" use \w* instead of a trailing \b so that inflected forms
# ("pregnancy", "pregnant", "breastfeeding") match, and plural forms are
# accepted where users commonly write them ("side effects", "precautions").
_INTENT_PATTERNS = (
    ("description",
     re.compile(r"\bwhat is\b|\bused for\b|\bdefine\b")),
    ("before_using",
     re.compile(r"\bbefore using\b|\bshould i tell\b|\bdoctor know\b")),
    ("proper_use",
     re.compile(r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b")),
    ("precautions",
     re.compile(r"\bprecautions?\b|\bpregnan\w*|\bbreastfeed\w*|\brisks?\b")),
    ("side_effects",
     re.compile(r"\bside[- ]effects?\b|\badverse\b|\bnausea\b|\bdizziness\b")),
)


def classify_intent(question: str) -> str:
    """
    Classify the user's query into a high-level intent based on keywords.

    The question is lower-cased and tested against each keyword pattern in
    a fixed order; the first match decides the intent.

    Parameters:
        question (str): The user's question.

    Returns:
        str: One of ['description', 'before_using', 'proper_use',
        'precautions', 'side_effects']. 'description' is also the fallback
        when no pattern matches.
    """
    q = question.lower()

    for intent, pattern in _INTENT_PATTERNS:
        if pattern.search(q):
            return intent

    return "description"  # default fallback


# -------------------------------
# Query Preprocessing Wrapper
# -------------------------------

def preprocess_query(raw_query: str) -> Tuple[str, List[str]]:
    """
    Main preprocessing function that extracts:
    - Intent
    - Named Entities

    The return value is a flat ``(intent, entities)`` pair. No sub-intent is
    computed, so the first element is a plain string, not a nested tuple.

    Parameters:
        raw_query (str): The raw user question.

    Returns:
        Tuple[str, List[str]]: (intent, list of entities). The entity list is
        empty when none were found. If anything raises during processing the
        error is printed and ``("", [])`` is returned.
    """
    try:
        intent = classify_intent(raw_query)
        entities = extract_entities_BERT(raw_query)

        if entities:
            print(f"[Query Processed] Intent = {intent}| Entities = {entities}")
            return intent or "", entities

        print("[NER fallback] No entities found. Using raw query.")
        return intent or "", []

    except Exception as e:
        # Best-effort boundary: callers always get a well-formed pair back.
        print(f"[Preprocessing failed] {e}")
        return "", []


# -------------------------------
# Optional: Test Function
# -------------------------------

if __name__ == "__main__":
    # Smoke test: push a few sample questions through preprocess_query and
    # print the intent and entities reported for each one.
    heavy_rule = "="*70
    light_rule = "-" * 70

    sample_questions = (
        "What are the side effects of Azithromycin?",
        "How much dosage of aspirin should I take for headache?",
        "Can I take Lisinopril during pregnancy?",
        "What is Metformin used for?",
        "Are there interactions between Warfarin and Ibuprofen?",
        "How should I store insulin?",
    )

    print("\n" + heavy_rule)
    print("TESTING LIGHTWEIGHT TRANSFORMER NER")
    print(heavy_rule + "\n")

    for number, question in enumerate(sample_questions, start=1):
        print(f"[Test {number}] Query: {question}")
        print(light_rule)

        detected_intent, found_entities = preprocess_query(question)

        print(f"  Intent: {detected_intent}")
        print(f"  Entities: {found_entities or 'None detected'}")
        print(light_rule + "\n")

    print(heavy_rule)
    print("TESTING COMPLETE")
    print(heavy_rule)