kramachan commited on
Commit
c138f03
·
verified ·
1 Parent(s): 3cb42e2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +202 -33
src/streamlit_app.py CHANGED
@@ -1,40 +1,209 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  """
7
- # Welcome to Streamlit!
8
 
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- In the meantime, below is an example of what you can do with just a few lines of code:
 
 
 
14
  """
15
 
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
2
+ import logging
3
+ from dotenv import load_dotenv
4
  import streamlit as st
5
+ from langchain_chroma import Chroma
6
+ from langchain_huggingface import HuggingFaceEmbeddings
7
+ from langchain_openai import ChatOpenAI
8
 
9
+ # Get a logger for this module
10
+ logger = logging.getLogger(__name__)
11
+
12
+ logger.info("Design Page...")
13
+ # -------------------------------
14
+ # PAGE CONFIG (MUST BE FIRST)
15
+ # -------------------------------
16
+ PORT = int(os.environ.get("PORT", 8501))
17
+
18
+ st.markdown("""
19
+ <style>
20
+ .main-title {
21
+ font-size: 52px;
22
+ font-weight: 800;
23
+ text-align: center;
24
+ color: #0B5ED7;
25
+ margin-bottom: 5px;
26
+ }
27
+
28
+ .sub-title {
29
+ font-size: 20px;
30
+ text-align: center;
31
+ color: #555555;
32
+ margin-bottom: 30px;
33
+ }
34
+ </style>
35
+ """, unsafe_allow_html=True)
36
+
37
+ st.markdown(
38
+ '<div class="main-title">💊 AI Medical Labelling System</div>',
39
+ unsafe_allow_html=True
40
+ )
41
+
42
+ st.markdown(
43
+ '<div class="sub-title">Simplifying FDA Drug Safety Information using Generative AI & RAG</div>',
44
+ unsafe_allow_html=True
45
+ )
46
+
47
+ # -------------------------------
48
+ # CUSTOM CSS (FANCY DESIGN)
49
+ # -------------------------------
50
+ st.markdown("""
51
+ <style>
52
+ .main {
53
+ background-color: #f7f9fc;
54
+ }
55
+
56
+ .big-title {
57
+ font-size:40px;
58
+ font-weight:700;
59
+ color:#1f4e79;
60
+ }
61
+
62
+ .subtitle {
63
+ font-size:18px;
64
+ color:#555;
65
+ }
66
+
67
+ .result-card {
68
+ background-color:white;
69
+ padding:20px;
70
+ border-radius:12px;
71
+ box-shadow:0px 2px 10px rgba(0,0,0,0.08);
72
+ margin-top:15px;
73
+ }
74
+ </style>
75
+ """, unsafe_allow_html=True)
76
+
77
+ # -------------------------------
78
+ # HEADER
79
+ # -------------------------------
80
+
81
+ st.divider()
82
+
83
+ # -------------------------------
84
+ # SIDEBAR CONTROLS
85
+ # -------------------------------
86
+ with st.sidebar:
87
+ st.header("⚙️ Search Options")
88
+
89
+ drug_name = st.text_input(
90
+ "Drug Name",
91
+ placeholder="PHENYTOIN SODIUM"
92
+ )
93
+
94
+ selected_results = st.radio(
95
+ "Information Type",
96
+ ["Side Effects", "Warnings", "Both"]
97
+ )
98
+
99
+ run_button = st.button("🔍 Generate Explanation")
100
+
101
+ # -------------------------------
102
+ # LOAD ENV + MODELS
103
+ # -------------------------------
104
+
105
+ logger.info("Loading HuggingFace embedding model...")
106
+
107
+ load_dotenv()
108
+
109
+ working_dir = os.path.dirname(os.path.abspath(__file__))
110
+
111
+ embeddings = HuggingFaceEmbeddings(
112
+ model_name="sentence-transformers/all-MiniLM-L6-v2"
113
+ )
114
+
115
+ vectordb = Chroma(
116
+ persist_directory=os.path.join(working_dir, "Chroma_db"),
117
+ embedding_function=embeddings
118
+ )
119
+
120
+ logger.info("Calling OpenAI model gpt-4o-mini...")
121
+
122
+ llm = ChatOpenAI(
123
+ model="gpt-4o-mini",
124
+ temperature=0
125
+ )
126
+
127
+ # -------------------------------
128
+ # RAG FUNCTION
129
+ # -------------------------------
130
+ def generate_section(drug_name, section, rules):
131
+
132
+ results = vectordb.get(
133
+ where={
134
+ "$and": [
135
+ {"generic_name": drug_name},
136
+ {"section": section}
137
+ ]
138
+ }
139
+ )
140
+
141
+ documents = results.get("documents", [])
142
+
143
+ if not documents:
144
+ st.warning(f"No data found for {section}")
145
+ return
146
+
147
+ context = "\n".join(set(documents))
148
+
149
+ prompt = f"""
150
+ You are a medical assistant.
151
+
152
+ Rewrite the FDA drug information into simplified,
153
+ easy-to-understand language.
154
+
155
+ Rules:
156
+ {rules}
157
+
158
+ Drug: {drug_name}
159
+
160
+ FDA TEXT:
161
+ {context}
162
  """
 
163
 
164
+ with st.spinner("🧠 AI is analysing FDA data..."):
165
+ response = llm.invoke(prompt)
166
+
167
+ st.markdown(
168
+ f'<div class="result-card">{response.content}</div>',
169
+ unsafe_allow_html=True
170
+ )
171
+
172
+ logger.info("Configuring prompt..")
173
+ # -------------------------------
174
+ # RULES
175
+ # -------------------------------
176
+ SIDE_EFFECT_RULES = """
177
+ - Use simple English
178
+ - Bullet points (max 7)
179
+ - Group similar side effects
180
+ - Separate common vs serious
181
+ """
182
 
183
+ WARNING_RULES = """
184
+ - Use simple English
185
+ - Bullet points (max 7)
186
+ - Group warnings clearly
187
  """
188
 
189
+ SECTION_MAP = {
190
+ "Side Effects": [("adverse_reactions", SIDE_EFFECT_RULES)],
191
+ "Warnings": [("warnings_and_cautions", WARNING_RULES)],
192
+ "Both": [
193
+ ("adverse_reactions", SIDE_EFFECT_RULES),
194
+ ("warnings_and_cautions", WARNING_RULES),
195
+ ],
196
+ }
197
+
198
+ # -------------------------------
199
+ # MAIN ACTION
200
+ # -------------------------------
201
+ if run_button and drug_name:
202
+
203
+ st.subheader(f"Results for: {drug_name.upper()}")
204
+
205
+ for section, rules in SECTION_MAP[selected_results]:
206
+ generate_section(drug_name, section, rules)
207
+
208
+ elif run_button:
209
+ st.warning("Please enter a drug name.")