kramachan committed on
Commit
0fc1003
·
verified ·
1 Parent(s): 39626d1

Upload 2 files

Browse files
Files changed (2) hide show
  1. src/ETL_VectorDB.py +102 -0
  2. src/app.py +209 -0
src/ETL_VectorDB.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from langchain_core.documents import Document
4
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
5
+ from langchain_huggingface import HuggingFaceEmbeddings
6
+ from langchain_chroma import Chroma
7
+ from langchain_openai import ChatOpenAI
8
+ import json
9
+
10
# Root logger: INFO and above, each record stamped with time and level.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT)

# Logger named after this module; used by the rest of the script.
logger = logging.getLogger(__name__)

# Directory containing this script, and the directory one level above it.
working_dir = os.path.split(os.path.abspath(__file__))[0]
parent_dir = os.path.split(working_dir)[0]

# The JSON inputs are looked up directly in the parent directory.
data_dir = parent_dir + "/"
vector_db_dir = parent_dir + "/vector_db"
23
+
24
+
25
logger.info("Reading Files Process Started...")

# Collect the records from every *.json file located directly in data_dir.
all_records = []
json_file_names = [
    entry for entry in os.listdir(data_dir) if entry.endswith(".json")
]

for file_name in json_file_names:
    file_path = os.path.join(data_dir, file_name)

    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    # A file holds either a list of records or one single record object.
    all_records += data if isinstance(data, list) else [data]

print("Total drug records:", len(all_records))
43
+
44
# Label sections copied into the vector store (choose the sections you want
# in RAG). Defined once: the list does not depend on the record.
sections = [
    "indications_and_usage",
    "warnings_and_cautions",
    "adverse_reactions",
    "drug_interactions"
]

documents = []

# Iterate over every record gathered from ALL JSON files. Looping over the
# last parsed file only would silently drop the records of every other file
# and fail with a NameError when no JSON file was found at all.
for record in all_records:

    # "generic_name" is read as a list of names; fall back to "UNKNOWN" when
    # the key is missing or its value is empty, and tolerate a bare string.
    names = record.get("generic_name") or ["UNKNOWN"]
    if isinstance(names, str):
        names = [names]
    drug = names[0].upper()

    for section in sections:
        # A missing or empty section simply contributes no documents.
        texts = record.get(section) or []
        if isinstance(texts, str):
            # Wrap a bare string so it is not iterated character by character.
            texts = [texts]

        for text in texts:
            documents.append(
                Document(
                    page_content=text,
                    metadata={
                        "generic_name": drug,
                        "section": section
                    }
                )
            )

print("Documents created:", len(documents))
74
+
75
logger.info("Split chunk Files Process Started...")

# Splitter settings: chunk_size of 800 with a chunk_overlap of 150.
_chunk_settings = {"chunk_size": 800, "chunk_overlap": 150}
splitter = RecursiveCharacterTextSplitter(**_chunk_settings)

# Break every document into chunks using the splitter configured above.
chunked_docs = splitter.split_documents(documents)
print("Chunks created:", len(chunked_docs))
84
+
85
logger.info("Embeddings Files Process Started...")

embeddings = HuggingFaceEmbeddings(
    model_name="sentence-transformers/all-MiniLM-L6-v2"
)
#%%
print("Chroma ready ✅")

logger.info(" VectorDB Process Started...")

# Persist into "<directory of this script>/Chroma_db". src/app.py opens
# exactly that directory (same location, same capitalisation). A path
# relative to the current working directory, or one spelled "chroma_db",
# would leave the app reading a different, empty store on case-sensitive
# file systems.
chroma_dir = os.path.join(working_dir, "Chroma_db")

vectordb = Chroma.from_documents(
    documents=chunked_docs,
    embedding=embeddings,
    persist_directory=chroma_dir
)

print("Vector DB created successfully ✅")
logger.info("VectorDB Process Completed...")
102
+
src/app.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from dotenv import load_dotenv
4
+ import streamlit as st
5
+ from langchain_chroma import Chroma
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_openai import ChatOpenAI
8
+
9
# Configure the root logger so the INFO-level progress messages emitted by
# this module are actually shown; with no configuration they are discarded.
# basicConfig does nothing when the root logger already has handlers.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Get a logger for this module
logger = logging.getLogger(__name__)

logger.info("Design Page...")
# -------------------------------
# PAGE CONFIG
# -------------------------------
# Port read from the PORT environment variable, 8501 when it is not set.
# NOTE(review): nothing else in the visible part of this file reads PORT.
PORT = int(os.environ.get("PORT", 8501))
17
+
18
# Style rules for the page title and the tagline rendered right below.
_TITLE_CSS = """
<style>
.main-title {
font-size: 52px;
font-weight: 800;
text-align: center;
color: #0B5ED7;
margin-bottom: 5px;
}

.sub-title {
font-size: 20px;
text-align: center;
color: #555555;
margin-bottom: 30px;
}
</style>
"""
st.markdown(_TITLE_CSS, unsafe_allow_html=True)

# Title first, then the tagline, each as its own raw-HTML markdown element.
_BANNER_ROWS = (
    '<div class="main-title">💊 AI Medical Labelling System</div>',
    '<div class="sub-title">Simplifying FDA Drug Safety Information using Generative AI & RAG</div>',
)
for _banner_row in _BANNER_ROWS:
    st.markdown(_banner_row, unsafe_allow_html=True)
46
+
47
+ # -------------------------------
48
+ # CUSTOM CSS (FANCY DESIGN)
49
+ # -------------------------------
50
# Page-wide style rules; ".result-card" is the class wrapped around each
# generated answer further down in this file.
_PAGE_CSS = """
<style>
.main {
background-color: #f7f9fc;
}

.big-title {
font-size:40px;
font-weight:700;
color:#1f4e79;
}

.subtitle {
font-size:18px;
color:#555;
}

.result-card {
background-color:white;
padding:20px;
border-radius:12px;
box-shadow:0px 2px 10px rgba(0,0,0,0.08);
margin-top:15px;
}
</style>
"""
st.markdown(_PAGE_CSS, unsafe_allow_html=True)

# -------------------------------
# HEADER
# -------------------------------
# Horizontal rule separating the banner from the content below it.
st.divider()
82
+
83
+ # -------------------------------
84
+ # SIDEBAR CONTROLS
85
+ # -------------------------------
86
# Build the search controls in the sidebar through the sidebar object itself.
_sidebar = st.sidebar
_sidebar.header("⚙️ Search Options")

# Free-text drug name; the placeholder shows an example value.
drug_name = _sidebar.text_input("Drug Name", placeholder="PHENYTOIN SODIUM")

# Which label information to explain; these options are the SECTION_MAP keys.
_info_choices = ["Side Effects", "Warnings", "Both"]
selected_results = _sidebar.radio("Information Type", _info_choices)

# True only on the rerun triggered by clicking the button.
run_button = _sidebar.button("🔍 Generate Explanation")
100
+
101
+ # -------------------------------
102
+ # LOAD ENV + MODELS
103
+ # -------------------------------
104
+
105
# Read variables from a .env file (if any) into the environment before any
# client is constructed.
load_dotenv()

working_dir = os.path.dirname(os.path.abspath(__file__))


@st.cache_resource
def _load_resources(chroma_dir):
    """Create the embedding model, the Chroma store and the chat model once.

    Streamlit re-executes this script on every widget interaction. Caching
    the heavy objects avoids reloading the embedding model and reopening the
    vector store on each rerun.

    Args:
        chroma_dir: Directory of the persisted Chroma collection.

    Returns:
        Tuple ``(embeddings, vectordb, llm)``.
    """
    logger.info("Loading HuggingFace embedding model...")
    embedder = HuggingFaceEmbeddings(
        model_name="sentence-transformers/all-MiniLM-L6-v2"
    )

    store = Chroma(
        persist_directory=chroma_dir,
        embedding_function=embedder
    )

    logger.info("Calling OpenAI model gpt-4o-mini...")
    chat_model = ChatOpenAI(
        model="gpt-4o-mini",
        temperature=0
    )
    return embedder, store, chat_model


# Same module-level names as before, so the rest of the file is unaffected.
embeddings, vectordb, llm = _load_resources(
    os.path.join(working_dir, "Chroma_db")
)
126
+
127
+ # -------------------------------
128
+ # RAG FUNCTION
129
+ # -------------------------------
130
def generate_section(drug_name, section, rules):
    """Fetch one label section for a drug and render a simplified rewrite.

    Args:
        drug_name: Generic drug name as typed by the user. It is trimmed and
            upper-cased before the lookup, because the ingestion script
            stores the ``generic_name`` metadata upper-cased.
        section: Value matched against the ``section`` metadata field,
            e.g. ``"adverse_reactions"``.
        rules: Text inserted under ``Rules:`` in the prompt.

    Returns:
        None. A warning is shown when nothing matches; otherwise the model's
        answer is rendered on the page inside a ``result-card`` div.
    """
    # Normalise to the stored form so "phenytoin sodium" matches the
    # "PHENYTOIN SODIUM" metadata value.
    drug_name = drug_name.strip().upper()

    results = vectordb.get(
        where={
            "$and": [
                {"generic_name": drug_name},
                {"section": section}
            ]
        }
    )

    documents = results.get("documents", [])

    if not documents:
        st.warning(f"No data found for {section}")
        return

    # Drop duplicate chunks but keep their stored order. A set() has no
    # stable order for strings across runs, which would change the prompt
    # (and potentially the answer) from one run to the next.
    context = "\n".join(dict.fromkeys(documents))

    prompt = f"""
You are a medical assistant.

Rewrite the FDA drug information into simplified,
easy-to-understand language.

Rules:
{rules}

Drug: {drug_name}

FDA TEXT:
{context}
"""

    with st.spinner("🧠 AI is analysing FDA data..."):
        response = llm.invoke(prompt)

    # NOTE(review): the model output is inserted as raw HTML; consider
    # escaping it if the indexed text can ever contain markup.
    st.markdown(
        f'<div class="result-card">{response.content}</div>',
        unsafe_allow_html=True
    )
171
+
172
+ logger.info("Configuring prompt..")
173
+ # -------------------------------
174
+ # RULES
175
+ # -------------------------------
176
def _as_rule_block(*rule_lines):
    """Join rule lines into a block that starts and ends with a newline."""
    return "\n" + "\n".join(rule_lines) + "\n"


# Prompt rules used when simplifying the adverse-reactions text.
SIDE_EFFECT_RULES = _as_rule_block(
    "- Use simple English",
    "- Bullet points (max 7)",
    "- Group similar side effects",
    "- Separate common vs serious",
)

# Prompt rules used when simplifying the warnings text.
WARNING_RULES = _as_rule_block(
    "- Use simple English",
    "- Bullet points (max 7)",
    "- Group warnings clearly",
)

# (section key, rules) pairs shared by the sidebar choices below.
_SIDE_EFFECT_SPEC = ("adverse_reactions", SIDE_EFFECT_RULES)
_WARNING_SPEC = ("warnings_and_cautions", WARNING_RULES)

# Sidebar choice -> ordered list of (section key, rules) pairs to generate.
SECTION_MAP = {
    "Side Effects": [_SIDE_EFFECT_SPEC],
    "Warnings": [_WARNING_SPEC],
    "Both": [_SIDE_EFFECT_SPEC, _WARNING_SPEC],
}
197
+
198
+ # -------------------------------
199
+ # MAIN ACTION
200
+ # -------------------------------
201
# Nothing happens until the button is clicked on this rerun.
if run_button:
    if not drug_name:
        # Clicked with an empty drug-name field.
        st.warning("Please enter a drug name.")
    else:
        st.subheader(f"Results for: {drug_name.upper()}")

        # Generate each section configured for the selected information type.
        for section_key, section_rules in SECTION_MAP[selected_results]:
            generate_section(drug_name, section_key, section_rules)