# MindAI / app.py — uploaded by mnaz26 (commit ba7ceaf, verified)
# ===============================
# 1️⃣ Install dependencies (only in Colab, HF Space installs from requirements.txt)
# ===============================
# !pip install -q groq datasets sentence-transformers faiss-cpu gradio matplotlib pandas tqdm reportlab
# ===============================
# 2️⃣ Imports
# ===============================
import os
import faiss
import numpy as np
import gradio as gr
import pandas as pd
import matplotlib.pyplot as plt
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from groq import Groq
import datetime
from io import BytesIO
from reportlab.lib.pagesizes import letter
from reportlab.pdfgen import canvas
from reportlab.lib.utils import ImageReader
from PIL import Image
# Directory where generated weekly PDF reports are written.
REPORTS_DIR = "reports"
os.makedirs(REPORTS_DIR, exist_ok=True)
# ===============================
# 3️⃣ Groq Client
# ===============================
# Groq API client. GROQ_API_KEY must be provided via the environment
# (e.g. a Space secret); if it is unset, api_key is None and the first
# chat completion call will fail at runtime, not here.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))
# ===============================
# 4️⃣ Load datasets for RAG
# ===============================
# Retrieval corpora, truncated at load time (first 1000 / 500 rows) to keep
# startup cost bounded. Downloaded from the Hugging Face Hub on first run,
# so network access is required at import time.
medical_ds = load_dataset("lavita/medical-qa-datasets", "all-processed", split="train[:1000]")
stress_ds = load_dataset("Amod/mental_health_counseling_conversations", split="train[:500]")
# ===============================
# 5️⃣ Prepare documents
# ===============================
# Flatten both datasets into one list of plain-text passages for retrieval.
documents = []

for row in medical_ds:
    # Passage = instruction [+ input] + output, joined with single spaces.
    instruction = (row.get("instruction", "") or "").strip()
    extra_input = (row.get("input", "") or "").strip()
    answer = (row.get("output", "") or "").strip()
    passage = instruction
    if extra_input:
        passage = passage + " " + extra_input
    passage = passage + " " + answer
    documents.append(passage)

for row in stress_ds:
    # Counseling data: one passage per (Context, Response) pair.
    ctx = row.get("Context", "") or ""
    reply = row.get("Response", "") or ""
    documents.append(ctx + " " + reply)
# ===============================
# 6️⃣ Embeddings + FAISS
# ===============================
# Sentence embedder used for both corpus indexing and query encoding
# (must be the same model for the distances to be meaningful).
embedder = SentenceTransformer("all-MiniLM-L6-v2")
# Encode the whole corpus once at startup.
embeddings = embedder.encode(documents, convert_to_numpy=True, show_progress_bar=True)
dimension = embeddings.shape[1]
# Exact brute-force L2 index — fine at this corpus size (~1.5k docs).
index = faiss.IndexFlatL2(dimension)
index.add(embeddings)
# ===============================
# 7️⃣ RAG functions
# ===============================
def retrieve_docs(query, k=5):
    """Return the k corpus passages nearest to *query* (L2 over embeddings)."""
    query_vec = embedder.encode([query])
    _, hit_ids = index.search(query_vec, k)
    return [documents[doc_id] for doc_id in hit_ids[0]]
def rag_answer(query):
    """Answer *query* with a Groq LLaMA call grounded only in retrieved context."""
    passages = retrieve_docs(query)
    grounding = "\n\n".join(passages)
    prompt = f"""
You are a medical assistant.
Use ONLY the context below to answer.
Do NOT diagnose anyone.
Provide supportive and informative responses.
Context:
{grounding}
Question:
{query}
"""
    completion = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": prompt}],
    )
    return completion.choices[0].message.content
# ===============================
# 8️⃣ CSV persistence
# ===============================
# Flat-file store of daily entries, loaded into a module-level DataFrame.
# NOTE(review): on ephemeral hosts this file may not survive a container
# rebuild — confirm persistence expectations for the deployment target.
CSV_FILE = "daily_entries.csv"
if os.path.exists(CSV_FILE):
    # Re-parse 'date' so datetime operations work after a restart.
    df = pd.read_csv(CSV_FILE, parse_dates=["date"])
else:
    # First run: empty frame with the expected schema.
    df = pd.DataFrame(columns=["date","user_id","stress","mood","sleep_hours"])
def add_daily_entry(user_id, stress, mood, sleep_hours):
    """Append today's wellness entry for *user_id* and persist the CSV store."""
    global df
    entry_date = datetime.date.today()
    entry = {
        "date": entry_date,
        "user_id": user_id,
        "stress": stress,
        "mood": mood,
        "sleep_hours": sleep_hours,
    }
    # Append in memory, then write the whole table back to disk.
    df = pd.concat([df, pd.DataFrame([entry])], ignore_index=True)
    df.to_csv(CSV_FILE, index=False)
    return f"Entry for {entry_date} saved!"
# ===============================
# 9️⃣ Weekly report + LLaMA + chart
# ===============================
def generate_weekly_report(user_id):
    """Build a weekly trend report for *user_id*.

    Returns (explanation text, PIL chart image, PDF file path), or
    ("No data available yet.", None, None) when the user has no entries.
    Side effects: normalizes the global df's 'date' column to datetime and
    writes a timestamped PDF under REPORTS_DIR.
    """
    global df
    df['date'] = pd.to_datetime(df['date'])
    # .copy() so the derived 'week' column below cannot write through to the
    # global df (SettingWithCopyWarning / accidental global mutation).
    user_df = df[df['user_id'] == user_id].copy()
    if user_df.empty:
        return "No data available yet.", None, None

    # Label weeks with their ISO year as well, so week 2 of different years
    # is not pooled into a single bucket.
    iso = user_df['date'].dt.isocalendar()
    user_df['week'] = iso.year.astype(str) + "-W" + iso.week.astype(str).str.zfill(2)
    weekly_summary = user_df.groupby('week').agg({
        "stress": ["mean", "max"],
        "mood": ["mean", "min"],
        "sleep_hours": ["mean", "min"]
    })
    # Week-over-week deltas of the means (first row is NaN by construction).
    weekly_summary['stress_change'] = weekly_summary['stress']['mean'].diff()
    weekly_summary['mood_change'] = weekly_summary['mood']['mean'].diff()
    weekly_summary['sleep_change'] = weekly_summary['sleep_hours']['mean'].diff()

    # ---- Create chart ----
    fig, ax = plt.subplots(3, 1, figsize=(8, 10))
    weekly_summary['stress']['mean'].plot(ax=ax[0], title="Weekly Avg Stress", marker="o")
    weekly_summary['mood']['mean'].plot(ax=ax[1], title="Weekly Avg Mood", marker="o")
    weekly_summary['sleep_hours']['mean'].plot(ax=ax[2], title="Weekly Avg Sleep Hours", marker="o")
    plt.tight_layout()
    chart_buf = BytesIO()
    fig.savefig(chart_buf, format="png")
    plt.close(fig)  # close the figure we created, not just "the current" one
    chart_buf.seek(0)
    chart_image = Image.open(chart_buf)
    chart_image.load()  # force a full read so the buffer can be reused below

    # ---- LLaMA explanation ----
    trend_prompt = f"""
You are a wellness data analyst AI.
Here is the weekly summary:
{weekly_summary.tail(4)}
Explain the trends in stress, mood, and sleep in simple, policymaker-friendly language.
"""
    response = client.chat.completions.create(
        model="llama-3.3-70b-versatile",
        messages=[{"role": "user", "content": trend_prompt}]
    )
    explanation = response.choices[0].message.content

    # ---- PDF generation ----
    import time
    timestamp = int(time.time())
    pdf_path = f"{REPORTS_DIR}/weekly_report_user_{user_id}_{timestamp}.pdf"
    c = canvas.Canvas(pdf_path, pagesize=letter)
    width, height = letter
    c.setFont("Helvetica-Bold", 14)
    c.drawString(40, height - 40, "Weekly Mental Health Trend Report")
    c.setFont("Helvetica", 11)
    y = height - 80
    for line in explanation.split("\n"):
        c.drawString(40, y, line)
        y -= 14
        if y < 100:
            c.showPage()
            c.setFont("Helvetica", 11)  # showPage() resets the font state
            y = height - 40
    c.showPage()
    # Rewind: Image.open()/load() above left the stream at EOF and
    # ImageReader reads from the current position.
    chart_buf.seek(0)
    c.drawImage(ImageReader(chart_buf), 50, 200, width=500, height=400)
    c.save()
    return explanation, chart_image, pdf_path
# ===============================
# 🔟 Gradio interface
# ===============================
with gr.Blocks() as demo:
    gr.Markdown("# 🧠 Medical & Stress RAG Assistant with Persistent Reports and PDF Export")

    with gr.Tab("Daily Entry"):
        gr.Markdown("Enter daily stress, mood, and sleep hours.")
        # Create the User ID field explicitly so it renders before the other
        # inputs (previously it was instantiated inline inside .click(),
        # which placed it after the status box).
        entry_user_id = gr.Number(value=1, label="User ID")
        stress = gr.Slider(0, 10, label="Stress Level")
        mood = gr.Slider(0, 10, label="Mood Level")
        sleep = gr.Number(label="Sleep Hours")
        submit = gr.Button("Save Entry")
        output_entry = gr.Textbox(label="Status")
        submit.click(add_daily_entry, [entry_user_id, stress, mood, sleep], output_entry)

    with gr.Tab("Weekly Trend Report"):
        gr.Markdown("View weekly summary, trends, and export PDF.")
        user_id_input = gr.Number(value=1, label="User ID")
        report_output = gr.Textbox(label="Weekly Trend Explanation")
        chart_output = gr.Image(label="Trend Chart")
        pdf_output = gr.File(label="Download PDF")
        generate = gr.Button("Generate Report")
        generate.click(generate_weekly_report, [user_id_input], [report_output, chart_output, pdf_output])

    with gr.Tab("Medical QA"):
        gr.Markdown("Ask questions about stress, mood, sleep, or general wellness.")
        # type="messages" matches the openai-style role/content dicts that
        # respond() appends; the default tuple format cannot render them.
        chatbot = gr.Chatbot(label="Medical QA", type="messages")
        msg = gr.Textbox(label="Your Question")
        clear = gr.Button("Clear Chat")

        def respond(message, history):
            """Answer one chat turn via RAG and append both sides to history."""
            history = history or []
            answer = rag_answer(message)
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": answer})
            return "", history  # clear the input box, refresh the chat

        msg.submit(respond, [msg, chatbot], [msg, chatbot])
        clear.click(lambda: [], None, chatbot)

demo.launch()