NiranjanSathish committed on
Commit
5e9bfb5
·
verified ·
1 Parent(s): 965e103

Upload 12 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Datasets/Dataset.json filter=lfs diff=lfs merge=lfs -text
37
+ Datasets/flattened_drug_dataset_cleaned.csv filter=lfs diff=lfs merge=lfs -text
38
+ Vectors/faiss_index.idx filter=lfs diff=lfs merge=lfs -text
Datasets/Dataset.json ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cc38f7e5bfad6d7c2865ed7c94d483c8b9b887a47853e4a3c16ce957ce1f06a0
3
+ size 35120734
Datasets/flattened_drug_dataset_cleaned.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0669d5d7366973a342a3cc35321366a02837c66ac5e7c28c3bf0569897db5b84
3
+ size 31338099
Scripts/Answer_Generation.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Answer Generation Module for Retrieval-based Medical QA Chatbot
3
+ =================================================================
4
+ This module handles:
5
+ 1. Building prompts for LLMs
6
+ 2. Querying the Groq API with selected context
7
+ 3. Generating a final answer based on retrieved chunks
8
+ """
9
+
10
+ from openai import OpenAI
11
+ import os
12
+ from Retrieval import Retrieval_averagedQP
13
+
14
+ # -------------------------------
15
+ # Groq API Client Setup
16
+ # -------------------------------
17
+
18
# OpenAI-compatible client pointed at Groq's API endpoint.
# Reads the key from the GROQ_API_KEY environment variable; if the variable
# is unset, api_key is None and requests will fail at call time rather than
# at import time.
client = OpenAI(
    api_key=os.environ.get("GROQ_API_KEY"),
    base_url="https://api.groq.com/openai/v1"
)
22
+
23
+ # -------------------------------
24
+ # Function: Query Groq API
25
+ # -------------------------------
26
+
27
def query_groq(prompt, model="meta-llama/llama-4-scout-17b-16e-instruct", max_tokens=300, temperature=0.7):
    """
    Send a prompt to the Groq API and return the generated response.

    Parameters:
        prompt (str): The text prompt for the model.
        model (str): Model name deployed on the Groq API.
        max_tokens (int): Maximum tokens allowed in the output.
        temperature (float): Sampling temperature. Defaults to 0.7, the value
            that was previously hard-coded; lower values give more
            deterministic answers.

    Returns:
        str: Model-generated response text, stripped of surrounding whitespace.
    """
    # `client` is the module-level Groq client configured above.
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": "You are a biomedical assistant."},
            {"role": "user", "content": prompt}
        ],
        temperature=temperature,
        max_tokens=max_tokens
    )
    return response.choices[0].message.content.strip()
47
+
48
+ # -------------------------------
49
+ # Function: Build Prompt
50
+ # -------------------------------
51
+
52
def build_prompt(question, context):
    """
    Construct an LLM prompt combining the user question and retrieved context.

    Parameters:
        question (str): User's question.
        context (str): Retrieved relevant text chunks, newline-joined.

    Returns:
        str: Complete prompt text instructing the model to answer strictly
        from the supplied context.
    """
    # NOTE(review): the exact indentation of the template's continuation lines
    # is not recoverable from this diff view; the indented form below is the
    # reconstruction — confirm against the original file, since the leading
    # whitespace becomes part of the prompt sent to the model.
    return f"""Strictly based on the following information, answer the question: {question}
    Do not explain the context, just provide a direct answer.
    Context:
    {context}
    """
66
+
67
+ # -------------------------------
68
+ # Function: Answer Generation
69
+ # -------------------------------
70
+
71
def answer_generation(question, top_chunks, top_k=3):
    """
    Generate an answer from the top retrieved chunks.

    Parameters:
        question (str): User's question.
        top_chunks (DataFrame): Retrieved chunks; must contain a
            'chunk_text' column.
        top_k (int): Number of highest-ranked chunks to feed the model.

    Returns:
        str: Final generated answer from the Groq-hosted model.
    """
    # Keep only the k best-ranked rows.
    selected = top_chunks.head(top_k)
    print("[Answer Generation] Top chunks selected for generation.")

    # Concatenate the chunk texts into a single context string.
    joined_context = "\n".join(selected["chunk_text"].tolist())

    # Build the prompt and delegate generation to the Groq API.
    return query_groq(build_prompt(question, joined_context))
93
+
94
+ # -------------------------------
95
+ # Example Usage (Uncomment to Test)
96
+ # -------------------------------
97
+
98
+ # question = "How is Aztreonam inhalation used?"
99
+ # answer = answer_generation(question, top_chunks)
100
+ # print("Generated Answer:", answer)
Scripts/Query_processing.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Query Processing Pipeline for Retrieval-based QA Chatbot
3
+ ========================================================
4
+
5
+ This module handles:
6
+ 1. Query preprocessing
7
+ 2. Intent and sub-intent classification
8
+ 3. Named Entity Recognition (NER) using SciSpaCy
9
+
10
+ """
11
+
12
+ import spacy
13
+ import re
14
+ from typing import List, Tuple
15
+
16
# Load the pre-trained SciSpaCy model for biomedical NER once at import time;
# en_core_sci_md must be installed separately (it is not a PyPI spaCy model).
ner_model = spacy.load("en_core_sci_md")
18
+
19
+ # -------------------------------
20
+ # Rule-Based Intent Classification
21
+ # -------------------------------
22
+
23
+ def classify_intent(question: str) -> str:
24
+ """
25
+ Classify the user's query into a high-level intent based on keywords.
26
+ Replace this rule-based system with ML-based intent detection for scalability.
27
+
28
+ Parameters:
29
+ question (str): The user's question.
30
+
31
+ Returns:
32
+ str: One of ['description', 'before_using', 'proper_use', 'precautions', 'side_effects']
33
+ """
34
+ q = question.lower()
35
+
36
+ if re.search(r"\bwhat is\b|\bused for\b|\bdefine\b", q):
37
+ return "description"
38
+ elif re.search(r"\bbefore using\b|\bshould I tell\b|\bdoctor know\b", q):
39
+ return "before_using"
40
+ elif re.search(r"\bhow to\b|\bdosage\b|\btake\b|\binstructions\b", q):
41
+ return "proper_use"
42
+ elif re.search(r"\bprecaution\b|\bpregnan\b|\bbreastfeed\b|\brisk\b", q):
43
+ return "precautions"
44
+ elif re.search(r"\bside effect\b|\badverse\b|\bnausea\b|\bdizziness\b", q):
45
+ return "side_effects"
46
+ else:
47
+ return "description" # default fallback
48
+
49
+
50
+ # -------------------------------
51
+ # Subsection Classification
52
+ # -------------------------------
53
+
54
+ def classify_subsection(question: str) -> str:
55
+ """
56
+ Identify more granular subtopics within each main intent.
57
+
58
+ Parameters:
59
+ question (str): The user's question.
60
+
61
+ Returns:
62
+ str: Sub-intent such as 'more common', 'incidence not known', etc.
63
+ """
64
+ q = question.lower()
65
+
66
+ if re.search(r"\bcommon side effects\b|\busual symptoms\b", q):
67
+ return "more common"
68
+ elif re.search(r"\bunknown\b|\brare\b|\bincidence\b", q):
69
+ return "incidence not known"
70
+ elif re.search(r"\bchildren\b|\bpediatric\b|\bkids\b", q):
71
+ return "pediatric"
72
+ elif re.search(r"\bbreastfeed\b|\bnursing\b|\blactation\b", q):
73
+ return "breastfeeding"
74
+ elif re.search(r"\belderly\b|\bgeriatric\b", q):
75
+ return "geriatric"
76
+ elif re.search(r"\binteract\b|\bcombination\b|\bcontraindications\b", q):
77
+ return "drug interactions"
78
+ else:
79
+ return ""
80
+
81
+
82
+ # -------------------------------
83
+ # Named Entity Extraction
84
+ # -------------------------------
85
+
86
def extract_entities_spacy(question: str) -> List[str]:
    """
    Use the SciSpaCy NER model to extract biomedical entities.

    Parameters:
        question (str): User query.

    Returns:
        List[str]: Unique extracted entity strings, in order of first
        appearance in the query.
    """
    doc = ner_model(question)
    # dict.fromkeys de-duplicates while preserving first-occurrence order.
    # The previous list(set(...)) returned entities in nondeterministic order,
    # which made the downstream log output and entity-boosted retrieval
    # non-reproducible between runs.
    return list(dict.fromkeys(ent.text for ent in doc.ents))
98
+
99
+
100
+ # -------------------------------
101
+ # Query Preprocessing Wrapper
102
+ # -------------------------------
103
+
104
def preprocess_query(raw_query: str) -> Tuple[Tuple[str, str], List[str]]:
    """
    Run the full query-preprocessing pipeline over a raw user question:
    intent classification, subsection classification, and NER.

    Parameters:
        raw_query (str): The raw user question.

    Returns:
        Tuple[Tuple[str, str], List[str]]: ((intent, sub_intent), entities).
        On any internal failure the result degrades to (("", ""), []).
    """
    try:
        intent = classify_intent(raw_query)
        sub_intent = classify_subsection(raw_query)
        entities = extract_entities_spacy(raw_query)

        labels = (intent or "", sub_intent or "")

        # Without entities the caller falls back to the raw query text.
        if not entities:
            print("[NER fallback] No entities found. Using raw query.")
            return labels, []

        print(f"[Query Processed] Intent = {intent} | Subsection = {sub_intent} | Entities = {entities}")
        return labels, entities

    except Exception as e:
        # Best-effort pipeline: never let preprocessing crash the chatbot.
        print(f"[Preprocessing failed] {e}")
        return ("", ""), []
Scripts/Retrieval.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Retrieval and FAISS Embedding Module for Medical QA Chatbot
3
+ ============================================================
4
+
5
+ This module handles:
6
+ 1. Embedding documents
7
+ 2. Building and saving FAISS index
8
+ 3. Retrieval with initial FAISS search + reranking using BioBERT similarity
9
+ """
10
+
11
+ import faiss
12
+ import pandas as pd
13
+ import numpy as np
14
+ import torch
15
+ from sentence_transformers import SentenceTransformer, util
16
+ from sklearn.preprocessing import normalize
17
+ from Query_processing import preprocess_query
18
+ import os
19
+
20
+ # -------------------------------
21
+ # File Paths
22
+ # -------------------------------
23
+
24
# Get the project root directory (one level up from this script's directory).
project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# Absolute paths for the dataset and the persisted retrieval artifacts.
csv_path = os.path.join(project_root, 'Datasets', 'flattened_drug_dataset_cleaned.csv')
faiss_index_path = os.path.join(project_root, 'Vectors', 'faiss_index.idx')
doc_metadata_path = os.path.join(project_root, 'Vectors', 'doc_metadata.pkl')
doc_vectors_path = os.path.join(project_root, 'Vectors', 'doc_vectors.npy')

# Load the dataset, dropping rows with no chunk text.
# NOTE(review): dropna() keeps the original (now gapped) integer index, so
# positional and label-based indexing of this frame differ — relevant to any
# code that indexes it by FAISS result position.
df = pd.read_csv(csv_path).dropna(subset=['chunk_text'])

# -------------------------------
# Model Initialization
# -------------------------------

# Lightweight sentence embedder used to build and query the FAISS index.
fast_embedder = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
# Biomedical sentence model used only for reranking FAISS candidates.
biobert = SentenceTransformer('pritamdeka/BioBERT-mnli-snli-scinli-scitail-mednli-stsb')
42
+
43
+ # -------------------------------
44
+ # Function: Embed and Build FAISS Index
45
+ # -------------------------------
46
+
47
def Embed_and_FAISS():
    """
    Embed the drug dataset and build a FAISS index for fast retrieval.

    Side effects:
        - Adds a 'full_text' column to the module-level DataFrame ``df``.
        - Writes the FAISS index, the metadata pickle, and the raw document
          vectors to the module-level paths.
    """
    print("Embedding document chunks using fast embedder...")

    # Build full context strings: drug name, section > subsection, chunk text.
    # This mutates the module-level df; the same column is later read by the
    # retrieval code from the pickled metadata.
    df['full_text'] = df.apply(lambda x: f"{x['drug_name']} | {x['section']} > {x['subsection']} | {x['chunk_text']}", axis=1)

    full_texts = df['full_text'].tolist()
    doc_embeddings = fast_embedder.encode(full_texts, convert_to_numpy=True, show_progress_bar=True)

    # L2-normalize so inner product (IndexFlatIP) behaves as cosine similarity.
    doc_embeddings = normalize(doc_embeddings, axis=1, norm='l2')
    dimension = doc_embeddings.shape[1]
    index = faiss.IndexFlatIP(dimension)
    index.add(doc_embeddings)

    # Persist the index plus row-aligned metadata and vectors so retrieval
    # can run without re-embedding the corpus.
    faiss.write_index(index, faiss_index_path)
    df.to_pickle(doc_metadata_path)
    np.save(doc_vectors_path, doc_embeddings)

    print("FAISS index built and saved successfully.")
72
+
73
+ # -------------------------------
74
+ # Function: Retrieve with Context and Averaged Embeddings
75
+ # -------------------------------
76
+
77
def retrieve_with_context_averagedembeddings(query, top_k=10, predicted_intent=None, detected_entities=None, alpha=0.8):
    """
    Retrieve top chunks via FAISS, then rerank them with BioBERT similarity.

    Parameters:
        query (str): User query text.
        top_k (int): Number of top results to retrieve.
        predicted_intent (str | tuple, optional): Detected intent; steers the
            query embedding and boosts chunks whose section matches.
        detected_entities (list, optional): Named entities; chunks mentioning
            any of them get a score boost.
        alpha (float): Weight of the query embedding when averaging it with
            the intent embedding (1.0 = query only).

    Returns:
        pd.DataFrame: Retrieved chunks with metadata, FAISS scores, and
        boosted BioBERT similarity scores, ordered by similarity (pre-boost).
    """
    print(f"[Retrieval Pipeline Started] Query: {query}")

    # The intent may arrive as a tuple such as (intent, sub_intent); unwrap it
    # once up front instead of re-checking inside the reranking loop (the old
    # code mutated the parameter on every iteration).
    if isinstance(predicted_intent, tuple):
        predicted_intent = predicted_intent[0]

    # Embed the query; optionally blend in the intent embedding.
    query_vec = fast_embedder.encode([query], convert_to_numpy=True)
    if predicted_intent:
        intent_vec = fast_embedder.encode([predicted_intent], convert_to_numpy=True)
        query_vec = alpha * query_vec + (1 - alpha) * intent_vec
    # Always L2-normalize: the index stores normalized vectors, so this keeps
    # IndexFlatIP scores on a true cosine scale. (Previously normalization
    # happened only when an intent was supplied; FAISS ranking was unaffected,
    # but the reported faiss_score scale was inconsistent between calls.)
    query_vec = normalize(query_vec, axis=1)

    index = faiss.read_index(faiss_index_path)
    D, I = index.search(query_vec, top_k)

    # Bug fix: FAISS returns *positions* into the indexed vectors, which map
    # to row positions of the metadata frame. Use iloc, not loc — the pickled
    # frame keeps its gapped post-dropna() index, so label-based lookup could
    # select the wrong rows or raise KeyError.
    df_meta = pd.read_pickle(doc_metadata_path)
    retrieved_df = df_meta.iloc[I[0]].copy()
    retrieved_df['faiss_score'] = D[0]

    # Rerank the FAISS candidates with BioBERT cosine similarity.
    query_emb = biobert.encode(query, convert_to_tensor=True)
    chunk_embs = biobert.encode(retrieved_df['full_text'].tolist(), convert_to_tensor=True)
    cos_scores = util.pytorch_cos_sim(query_emb, chunk_embs)[0]
    reranked_idx = torch.argsort(cos_scores, descending=True)

    # Apply additive boosts for intent/section match, subsection match, and
    # entity mentions; magnitudes are small so BioBERT ordering dominates.
    results = []
    for idx in reranked_idx:
        idx = int(idx)
        row = retrieved_df.iloc[idx]
        score = cos_scores[idx].item()

        section = row['section'][0] if isinstance(row['section'], tuple) else row['section']
        subsection = row['subsection'][0] if isinstance(row['subsection'], tuple) else row['subsection']

        if predicted_intent and section.strip().lower() == predicted_intent.strip().lower():
            score += 0.05
        if predicted_intent and predicted_intent.lower() in subsection.strip().lower():
            score += 0.03
        if detected_entities and any(ent.lower() in row['chunk_text'].lower() for ent in detected_entities):
            score += 0.1

        results.append({
            'chunk_id': row['chunk_id'],
            'drug_name': row['drug_name'],
            'section': row['section'],
            'subsection': row['subsection'],
            'chunk_text': row['chunk_text'],
            'faiss_score': row['faiss_score'],
            'semantic_similarity_score': score
        })

    return pd.DataFrame(results)
145
+
146
+ # -------------------------------
147
+ # Function: Retrieval Wrapper
148
+ # -------------------------------
149
+
150
def Retrieval_averagedQP(raw_query, intent, entities, top_k=10, alpha=0.8):
    """
    Wrapper that retrieves the top-k chunks for a raw user query.

    Parameters:
        raw_query (str): The user query.
        intent (str): Predicted intent from query processing.
        entities (list): Detected biomedical entities.
        top_k (int): Number of top results to return.
        alpha (float): Weighting between query and intent embeddings.

    Returns:
        pd.DataFrame: Top retrieved chunks restricted to the public columns.
    """
    public_columns = ['chunk_id', 'drug_name', 'section', 'subsection', 'chunk_text', 'faiss_score', 'semantic_similarity_score']
    hits = retrieve_with_context_averagedembeddings(
        raw_query,
        top_k=top_k,
        predicted_intent=intent,
        detected_entities=entities,
        alpha=alpha,
    )
    return hits[public_columns]
Scripts/__pycache__/Answer_Generation.cpython-311.pyc ADDED
Binary file (3.37 kB). View file
 
Scripts/__pycache__/Query_processing.cpython-311.pyc ADDED
Binary file (5.19 kB). View file
 
Scripts/__pycache__/Retrieval.cpython-311.pyc ADDED
Binary file (8.62 kB). View file
 
Scripts/demo.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Main Execution Script for Retrieval-based Medical QA Chatbot
3
+ ============================================================
4
+
5
+ This script handles:
6
+ 1. Query preprocessing
7
+ 2. Information retrieval
8
+ 3. Answer generation
9
+ """
10
+
11
+ import warnings
12
+ warnings.filterwarnings("ignore", category=UserWarning)
13
+
14
+
15
+ from dotenv import load_dotenv
16
+ load_dotenv()
17
+
18
+
19
+ from Query_processing import preprocess_query
20
+ from Retrieval import Retrieval_averagedQP
21
+ from Answer_Generation import answer_generation
22
+ from Retrieval import Embed_and_FAISS
23
+
24
+ # -------------------------------
25
+ # Optional: Embed and Store FAISS Index
26
+ # -------------------------------
27
+ # Uncomment the below line to generate embeddings and build the FAISS index if not already done.
28
+ # Embed_and_FAISS()
29
+
30
+ # -------------------------------
31
+ # Define User Question
32
+ # -------------------------------
33
+
34
+ Question = input("Enter your question: ")
35
+
36
+ # -------------------------------
37
+ # Step 1: Query Preprocessing
38
+ # -------------------------------
39
+
40
+ (intent, sub_intent), entities = preprocess_query(Question)
41
+
42
+ # -------------------------------
43
+ # Step 2: Retrieve Relevant Chunks
44
+ # -------------------------------
45
+
46
+ top_chunks = Retrieval_averagedQP(Question, intent, entities, top_k=10, alpha=0.8)
47
+
48
+ # -------------------------------
49
+ # Step 3: Answer Generation
50
+ # -------------------------------
51
+
52
+ Generated_answer = answer_generation(Question, top_chunks, top_k=3)
53
+
54
+ # -------------------------------
55
+ # Display Generated Answer
56
+ # -------------------------------
57
+
58
+ print("Generated Answer:", Generated_answer)
Vectors/doc_metadata.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:800157a95b50080634fdce730014af49a8e0cf01d2dbb484785b15936dc9abff
3
+ size 53368209
Vectors/doc_vectors.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f54da3cd890cf384fdc3b7abcd6ed5f840c0f53da30615fd417fc8256fd1b5ca
3
+ size 70190720
Vectors/faiss_index.idx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:58d68a5ccb27c94e357ab12eec21d5d54d903949ae37648202643eb33387156b
3
+ size 70190637