mnaz26 commited on
Commit
22e78a2
·
verified ·
1 Parent(s): 5b4c8e7

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +216 -0
  2. requirements.txt +9 -0
app.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ===============================
2
+ # 1️⃣ Install dependencies (only in Colab, HF Space installs from requirements.txt)
3
+ # ===============================
4
+ # !pip install -q groq datasets sentence-transformers faiss-cpu gradio matplotlib pandas tqdm reportlab
5
+
6
+ # ===============================
7
+ # 2️⃣ Imports
8
+ # ===============================
9
+ import os
10
+ import faiss
11
+ import numpy as np
12
+ import gradio as gr
13
+ import pandas as pd
14
+ import matplotlib.pyplot as plt
15
+ from datasets import load_dataset
16
+ from sentence_transformers import SentenceTransformer
17
+ from groq import Groq
18
+ import datetime
19
+ from io import BytesIO
20
+ from reportlab.lib.pagesizes import letter
21
+ from reportlab.pdfgen import canvas
22
+ from reportlab.lib.utils import ImageReader
23
+
24
+ # ===============================
25
+ # 3️⃣ Groq Client
26
+ # ===============================
27
+ client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
28
+
29
+ # ===============================
30
+ # 4️⃣ Load datasets for RAG
31
+ # ===============================
32
+ medical_ds = load_dataset("lavita/medical-qa-datasets", "all-processed", split="train[:1000]")
33
+ stress_ds = load_dataset("Amod/mental_health_counseling_conversations", split="train[:500]")
34
+
35
+ # ===============================
36
+ # 5️⃣ Prepare documents
37
+ # ===============================
38
+ documents = []
39
+ for row in medical_ds:
40
+ instr = row.get("instruction","") or ""
41
+ inp = row.get("input","") or ""
42
+ out = row.get("output","") or ""
43
+ text = instr.strip()
44
+ if inp.strip(): text += " " + inp.strip()
45
+ text += " " + out.strip()
46
+ documents.append(text)
47
+ for row in stress_ds:
48
+ context = row.get("Context","") or ""
49
+ response = row.get("Response","") or ""
50
+ documents.append(context + " " + response)
51
+
52
+ # ===============================
53
+ # 6️⃣ Embeddings + FAISS
54
+ # ===============================
55
+ embedder = SentenceTransformer("all-MiniLM-L6-v2")
56
+ embeddings = embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
57
+ dimension = embeddings.shape[1]
58
+ index = faiss.IndexFlatL2(dimension)
59
+ index.add(embeddings)
60
+
61
+ # ===============================
62
+ # 7️⃣ RAG functions
63
+ # ===============================
64
+ def retrieve_docs(query,k=5):
65
+ query_embedding = embedder.encode([query])
66
+ distances, indices = index.search(query_embedding,k)
67
+ return [documents[i] for i in indices[0]]
68
+
69
+ def rag_answer(query):
70
+ retrieved = retrieve_docs(query)
71
+ context = "\n\n".join(retrieved)
72
+ prompt = f"""
73
+ You are a medical assistant.
74
+ Use ONLY the context below to answer.
75
+ Do NOT diagnose anyone.
76
+ Provide supportive and informative responses.
77
+
78
+ Context:
79
+ {context}
80
+
81
+ Question:
82
+ {query}
83
+ """
84
+ response = client.chat.completions.create(
85
+ model="llama-3.3-70b-versatile",
86
+ messages=[{"role":"user","content":prompt}],
87
+ )
88
+ return response.choices[0].message.content
89
+
90
+ # ===============================
91
+ # 8️⃣ CSV persistence
92
+ # ===============================
93
+ CSV_FILE = "daily_entries.csv"
94
+ if os.path.exists(CSV_FILE):
95
+ df = pd.read_csv(CSV_FILE, parse_dates=["date"])
96
+ else:
97
+ df = pd.DataFrame(columns=["date","user_id","stress","mood","sleep_hours"])
98
+
99
+ def add_daily_entry(user_id, stress, mood, sleep_hours):
100
+ global df
101
+ today = datetime.date.today()
102
+ new_row = pd.DataFrame([{
103
+ "date": today,
104
+ "user_id": user_id,
105
+ "stress": stress,
106
+ "mood": mood,
107
+ "sleep_hours": sleep_hours
108
+ }])
109
+ df = pd.concat([df,new_row], ignore_index=True)
110
+ df.to_csv(CSV_FILE,index=False)
111
+ return f"Entry for {today} saved!"
112
+
113
+ # ===============================
114
+ # 9️⃣ Weekly report + LLaMA + chart
115
+ # ===============================
116
+ def generate_weekly_report(user_id):
117
+ global df
118
+ df['date'] = pd.to_datetime(df['date'])
119
+ user_df = df[df['user_id']==user_id]
120
+ if user_df.empty:
121
+ return "No data available yet.", None, None
122
+ user_df['week'] = user_df['date'].dt.isocalendar().week
123
+
124
+ weekly_summary = user_df.groupby('week').agg({
125
+ "stress":["mean","max"],
126
+ "mood":["mean","min"],
127
+ "sleep_hours":["mean","min"]
128
+ })
129
+ weekly_summary['stress_change'] = weekly_summary['stress']['mean'].diff()
130
+ weekly_summary['mood_change'] = weekly_summary['mood']['mean'].diff()
131
+ weekly_summary['sleep_change'] = weekly_summary['sleep_hours']['mean'].diff()
132
+
133
+ # Charts
134
+ fig, ax = plt.subplots(3,1,figsize=(8,10))
135
+ weekly_summary['stress']['mean'].plot(ax=ax[0],title="Weekly Avg Stress",color='red',marker='o')
136
+ weekly_summary['mood']['mean'].plot(ax=ax[1],title="Weekly Avg Mood",color='blue',marker='o')
137
+ weekly_summary['sleep_hours']['mean'].plot(ax=ax[2],title="Weekly Avg Sleep Hours",color='green',marker='o')
138
+ plt.tight_layout()
139
+ chart_buf = BytesIO()
140
+ plt.savefig(chart_buf, format="png")
141
+ chart_buf.seek(0)
142
+
143
+ # LLaMA explanation
144
+ trend_prompt = f"""
145
+ You are a wellness data analyst AI.
146
+ Here is the weekly summary for user {user_id}:
147
+ {weekly_summary.tail(4)}
148
+
149
+ Explain in plain language the trends in stress, mood, and sleep over the past 4 weeks.
150
+ """
151
+ response = client.chat.completions.create(
152
+ model="llama-3.3-70b-versatile",
153
+ messages=[{"role":"user","content":trend_prompt}]
154
+ )
155
+ explanation = response.choices[0].message.content
156
+
157
+ # Generate PDF
158
+ pdf_buf = BytesIO()
159
+ c = canvas.Canvas(pdf_buf, pagesize=letter)
160
+ width, height = letter
161
+ c.setFont("Helvetica",12)
162
+ y = height - 40
163
+ c.drawString(30,y,f"Weekly Trend Report for User {user_id}")
164
+ y -= 30
165
+ for line in explanation.split("\n"):
166
+ c.drawString(30,y,line)
167
+ y -= 15
168
+ if y < 100:
169
+ c.showPage()
170
+ y = height - 40
171
+ # Add chart
172
+ img = ImageReader(chart_buf)
173
+ c.showPage()
174
+ c.drawImage(img,50,150,width=500,height=400)
175
+ c.save()
176
+ pdf_buf.seek(0)
177
+
178
+ return explanation, chart_buf, pdf_buf
179
+
180
+ # ===============================
181
+ # 🔟 Gradio interface
182
+ # ===============================
183
+ with gr.Blocks() as demo:
184
+ gr.Markdown("# 🧠 Medical & Stress RAG Assistant with Persistent Reports and PDF Export")
185
+
186
+ with gr.Tab("Daily Entry"):
187
+ gr.Markdown("Enter daily stress, mood, and sleep hours.")
188
+ stress = gr.Slider(0,10,label="Stress Level")
189
+ mood = gr.Slider(0,10,label="Mood Level")
190
+ sleep = gr.Number(label="Sleep Hours")
191
+ submit = gr.Button("Save Entry")
192
+ output_entry = gr.Textbox(label="Status")
193
+ submit.click(add_daily_entry,[gr.Number(value=1,label="User ID"),stress,mood,sleep],output_entry)
194
+
195
+ with gr.Tab("Weekly Trend Report"):
196
+ gr.Markdown("View weekly summary, trends, and export PDF.")
197
+ user_id_input = gr.Number(value=1,label="User ID")
198
+ report_output = gr.Textbox(label="Weekly Trend Explanation")
199
+ chart_output = gr.Image(label="Trend Chart")
200
+ pdf_output = gr.File(label="Download PDF")
201
+ generate = gr.Button("Generate Report")
202
+ generate.click(generate_weekly_report,[user_id_input],[report_output,chart_output,pdf_output])
203
+
204
+ with gr.Tab("Medical QA"):
205
+ gr.Markdown("Ask questions about stress, mood, sleep, or general wellness.")
206
+ chatbot = gr.Chatbot()
207
+ msg = gr.Textbox(label="Your Question")
208
+ clear = gr.Button("Clear Chat")
209
+ def respond(message,history):
210
+ answer = rag_answer(message)
211
+ history.append((message,answer))
212
+ return "",history
213
+ msg.submit(respond,[msg,chatbot],[msg,chatbot])
214
+ clear.click(lambda: None,None,chatbot,queue=False)
215
+
216
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ faiss-cpu
3
+ sentence-transformers
4
+ datasets
5
+ pandas
6
+ matplotlib
7
+ tqdm
8
+ groq
9
+ reportlab