| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| | |
| | import os |
| | import faiss |
| | import numpy as np |
| | import gradio as gr |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | from datasets import load_dataset |
| | from sentence_transformers import SentenceTransformer |
| | from groq import Groq |
| | import datetime |
| | from io import BytesIO |
| | from reportlab.lib.pagesizes import letter |
| | from reportlab.pdfgen import canvas |
| | from reportlab.lib.utils import ImageReader |
| | from PIL import Image |
| |
|
# Directory that receives the generated PDF reports; created on import if missing.
REPORTS_DIR = "reports"
os.makedirs(REPORTS_DIR, exist_ok=True)

# Groq chat-completions client. The key is read from the GROQ_API_KEY
# environment variable (None is passed through when the variable is unset).
client = Groq(api_key=os.getenv("GROQ_API_KEY"))
| |
|
| | |
| | |
| | |
# Source corpora: the first 1000 medical QA rows and the first 500 counseling rows.
medical_ds = load_dataset(
    "lavita/medical-qa-datasets",
    "all-processed",
    split="train[:1000]",
)
stress_ds = load_dataset(
    "Amod/mental_health_counseling_conversations",
    split="train[:500]",
)


def _medical_text(row):
    """Join a medical row's instruction, optional input, and output with spaces."""
    instruction = (row.get("instruction", "") or "").strip()
    extra_input = (row.get("input", "") or "").strip()
    answer = (row.get("output", "") or "").strip()
    # The input part (and its separating space) is only included when non-blank.
    middle = f" {extra_input}" if extra_input else ""
    return f"{instruction}{middle} {answer}"


def _stress_text(row):
    """Join a counseling row's Context and Response with a single space."""
    return (row.get("Context", "") or "") + " " + (row.get("Response", "") or "")


# Flat list of passages; a passage's position is its id in the vector index.
documents = [_medical_text(row) for row in medical_ds]
documents.extend(_stress_text(row) for row in stress_ds)
| |
|
| | |
| | |
| | |
# Sentence encoder used for both the corpus and incoming queries.
embedder = SentenceTransformer("all-MiniLM-L6-v2")

# One embedding row per entry of `documents`, in the same order.
embeddings = embedder.encode(
    documents,
    convert_to_numpy=True,
    show_progress_bar=True,
)

# Exact (brute-force) L2 index sized to the embedding width.
dimension = embeddings.shape[1]
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
| |
|
| | |
| | |
| | |
def retrieve_docs(query, k=5):
    """Return up to ``k`` corpus passages nearest to ``query``.

    Args:
        query: Free-text question to embed and search with.
        k: Maximum number of passages to return.

    Returns:
        A list of strings taken from ``documents``, nearest first. The list
        may be shorter than ``k`` when the index holds fewer vectors.
    """
    if k <= 0:
        return []

    # FAISS expects float32 input; make the dtype explicit rather than relying
    # on the encoder's default output type.
    query_embedding = np.asarray(embedder.encode([query]), dtype="float32")
    _distances, indices = index.search(query_embedding, k)

    # FAISS pads missing neighbours with -1; indexing `documents` with -1 would
    # silently return the last passage, so drop anything out of range.
    return [documents[i] for i in indices[0] if 0 <= i < len(documents)]
| |
|
def rag_answer(query):
    """Answer ``query`` with the chat model, grounded on retrieved passages.

    The passages returned by ``retrieve_docs`` are joined with blank lines and
    embedded in a single user message together with the question.

    Returns:
        The message content of the first completion choice.
    """
    context = "\n\n".join(retrieve_docs(query))

    prompt = f"""
You are a medical assistant.
Use ONLY the context below to answer.
Do NOT diagnose anyone.
Provide supportive and informative responses.

Context:
{context}

Question:
{query}
"""

    completion = client.chat.completions.create(
        messages=[{"role": "user", "content": prompt}],
        model="llama-3.3-70b-versatile",
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
| |
|
| | |
| | |
| | |
# Path of the CSV file that persists the daily entries between runs.
CSV_FILE = "daily_entries.csv"

# In-memory log of entries: loaded from the CSV (with `date` parsed) when the
# file exists, otherwise an empty frame with the expected columns.
df = (
    pd.read_csv(CSV_FILE, parse_dates=["date"])
    if os.path.exists(CSV_FILE)
    else pd.DataFrame(columns=["date", "user_id", "stress", "mood", "sleep_hours"])
)
| |
|
def add_daily_entry(user_id, stress, mood, sleep_hours):
    """Append today's entry for ``user_id`` and rewrite the CSV log.

    Args:
        user_id: Identifier stored in the ``user_id`` column.
        stress: Stress level for today.
        mood: Mood level for today.
        sleep_hours: Hours slept.

    Returns:
        A confirmation string containing today's date.

    Side effects:
        Rebinds the module-level ``df`` and overwrites ``CSV_FILE``.
    """
    global df
    today = datetime.date.today()
    new_row = pd.DataFrame([{
        # Stored as a Timestamp so it matches the datetime column produced by
        # read_csv(parse_dates=...) instead of mixing in plain date objects.
        "date": pd.Timestamp(today),
        "user_id": user_id,
        "stress": stress,
        "mood": mood,
        "sleep_hours": sleep_hours,
    }])

    # Concatenating onto an empty frame is deprecated in recent pandas and
    # leaves the columns as object dtype; start from the new row instead.
    if df.empty:
        df = new_row
    else:
        df = pd.concat([df, new_row], ignore_index=True)

    # Keep `date` a single datetime dtype so every row is written to the CSV
    # in the same text format and parses back cleanly on the next start.
    df["date"] = pd.to_datetime(df["date"])

    df.to_csv(CSV_FILE, index=False)
    return f"Entry for {today} saved!"
| |
|
| | |
| | |
| | |
def generate_weekly_report(user_id):
    """Build the weekly trend report (text, chart, PDF) for one user.

    Steps: filter the entry log to ``user_id``, aggregate stress / mood /
    sleep per ISO week, plot the weekly means, ask the chat model to explain
    the last four weeks, and write the explanation plus chart to a PDF under
    ``REPORTS_DIR``.

    Args:
        user_id: Value matched against the ``user_id`` column of ``df``.

    Returns:
        ``(explanation, chart_image, pdf_path)``, or
        ``("No data available yet.", None, None)`` when the user has no rows.
    """
    import textwrap
    import time

    global df
    df['date'] = pd.to_datetime(df['date'])

    # Work on a private copy so the helper column below is not assigned onto
    # a view of the shared frame; rows without a date cannot be bucketed.
    user_df = df[df['user_id'] == user_id].dropna(subset=['date']).copy()

    if user_df.empty:
        return "No data available yet.", None, None

    # Entries appended to an initially empty frame may be object dtype;
    # coerce so the aggregations below operate on numbers.
    metric_cols = ["stress", "mood", "sleep_hours"]
    user_df[metric_cols] = user_df[metric_cols].apply(pd.to_numeric, errors="coerce")

    # Key each row by ISO year *and* week ("2024-W01"). The bare week number
    # would merge the same week of different years and mis-order weeks
    # around New Year; this key sorts chronologically.
    iso = user_df['date'].dt.isocalendar()
    user_df['week'] = (
        iso['year'].astype(str) + "-W" + iso['week'].astype(str).str.zfill(2)
    )

    weekly_summary = user_df.groupby('week').agg({
        "stress": ["mean", "max"],
        "mood": ["mean", "min"],
        "sleep_hours": ["mean", "min"],
    })

    # Week-over-week change of each mean (NaN for the first week).
    weekly_summary['stress_change'] = weekly_summary['stress']['mean'].diff()
    weekly_summary['mood_change'] = weekly_summary['mood']['mean'].diff()
    weekly_summary['sleep_change'] = weekly_summary['sleep_hours']['mean'].diff()

    # ---- Chart -----------------------------------------------------------
    fig, ax = plt.subplots(3, 1, figsize=(8, 10))
    weekly_summary['stress']['mean'].plot(ax=ax[0], title="Weekly Avg Stress", marker="o")
    weekly_summary['mood']['mean'].plot(ax=ax[1], title="Weekly Avg Mood", marker="o")
    weekly_summary['sleep_hours']['mean'].plot(ax=ax[2], title="Weekly Avg Sleep Hours", marker="o")
    fig.tight_layout()

    chart_buf = BytesIO()
    fig.savefig(chart_buf, format="png")
    plt.close(fig)  # close this figure explicitly rather than "the current one"
    png_bytes = chart_buf.getvalue()

    # Give the UI image and the PDF reader their own streams so neither
    # depends on the other's read position; load the UI image eagerly.
    chart_image = Image.open(BytesIO(png_bytes))
    chart_image.load()

    # ---- LLM explanation -------------------------------------------------
    trend_prompt = f"""
You are a wellness data analyst AI.

Here is the weekly summary:
{weekly_summary.tail(4)}

Explain the trends in stress, mood, and sleep in simple, policymaker-friendly language.
"""

    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": trend_prompt}]
    )

    # Fall back to an empty string so the PDF loop below never sees None.
    explanation = response.choices[0].message.content or ""

    # ---- PDF -------------------------------------------------------------
    timestamp = int(time.time())
    pdf_path = f"{REPORTS_DIR}/weekly_report_user_{user_id}_{timestamp}.pdf"

    c = canvas.Canvas(pdf_path, pagesize=letter)
    width, height = letter

    c.setFont("Helvetica-Bold", 14)
    c.drawString(40, height - 40, "Weekly Mental Health Trend Report")

    body_font = ("Helvetica", 11)
    c.setFont(*body_font)
    y = height - 80
    for paragraph in explanation.split("\n"):
        # drawString does not wrap, so break long lines to fit the page
        # width; an empty paragraph still advances one line.
        for line in textwrap.wrap(paragraph, width=90) or [""]:
            c.drawString(40, y, line)
            y -= 14
            if y < 100:
                c.showPage()
                # showPage starts a fresh page state; restore the body font.
                c.setFont(*body_font)
                y = height - 40

    # The chart goes on its own page after the text.
    c.showPage()
    c.drawImage(ImageReader(BytesIO(png_bytes)), 50, 200, width=500, height=400)
    c.save()

    return explanation, chart_image, pdf_path
| |
|
| | |
| | |
| | |
# Gradio UI: three tabs wired to add_daily_entry, generate_weekly_report and
# rag_answer respectively.
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Medical & Stress RAG Assistant with Persistent Reports and PDF Export")

    with gr.Tab("Daily Entry"):
        gr.Markdown("Enter daily stress, mood, and sleep hours.")
        # Declared up front (instead of inline in the click inputs) so the
        # field sits with the other inputs and has a name.
        entry_user_id = gr.Number(value=1, label="User ID")
        stress = gr.Slider(0, 10, label="Stress Level")
        mood = gr.Slider(0, 10, label="Mood Level")
        sleep = gr.Number(label="Sleep Hours")
        submit = gr.Button("Save Entry")
        output_entry = gr.Textbox(label="Status")
        submit.click(
            add_daily_entry,
            [entry_user_id, stress, mood, sleep],
            output_entry,
        )

    with gr.Tab("Weekly Trend Report"):
        gr.Markdown("View weekly summary, trends, and export PDF.")
        user_id_input = gr.Number(value=1, label="User ID")
        report_output = gr.Textbox(label="Weekly Trend Explanation")
        chart_output = gr.Image(label="Trend Chart")
        pdf_output = gr.File(label="Download PDF")
        generate = gr.Button("Generate Report")
        generate.click(
            generate_weekly_report,
            [user_id_input],
            [report_output, chart_output, pdf_output],
        )

    with gr.Tab("Medical QA"):
        gr.Markdown("Ask questions about stress, mood, sleep, or general wellness.")

        chatbot = gr.Chatbot(label="Medical QA")
        msg = gr.Textbox(label="Your Question")
        clear = gr.Button("Clear Chat")

        def respond(message, history):
            """Append the question and its RAG answer to the chat history.

            Returns an empty string (to clear the textbox) and the updated
            history as a list of role/content dicts.
            """
            # NOTE(review): role/content dicts assume the Chatbot is using the
            # "messages" history format; confirm against the installed Gradio.
            history = history or []

            answer = rag_answer(message)

            history.append({
                "role": "user",
                "content": message
            })
            history.append({
                "role": "assistant",
                "content": answer
            })

            return "", history

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        clear.click(lambda: [], None, chatbot)

demo.launch()
| |
|