import streamlit as st
import requests
import os
import json
import torch
from langgraph.graph import StateGraph, END
from typing import TypedDict
from unsloth import FastLanguageModel
from peft import PeftModel
from transformers import AutoTokenizer

# ---------- Custom Styling ----------
st.markdown("""
""", unsafe_allow_html=True)

# ---------- Header ----------
st.title("🩺 DDCBot - Disease Diagnosis Chat Bot")
st.caption("🧠 AI-Powered Disease Diagnosis with Reasoning Capabilities")

# ---------- Sidebar ----------
with st.sidebar:
    st.divider()
    st.markdown("""
### CSE3085 - Predictive Analytics with Case Studies

**📌 Project Review**

**📜 TITLE:** LangGraph and RAG assisted LLMs for disease diagnostics with reasoning capabilities

**👨‍💻 By**
- **22MIA1064** Yasir Ahmad
- **22MIA1049** Naveen Nidadavolu
- **22MIA1034** Namansh Singh Maurya
- **22MIA1044** Etash Ashwin

**🎓 M.Tech CSE (Business Analytics)**

**📖 Submitted to**
Dr. Jaya Mala D
Professor Senior, SCOPE, VIT, Chennai
""", unsafe_allow_html=True)
    st.divider()


# ---------- Load Model Once ----------
@st.cache_resource(show_spinner="🔄 Loading model...")
def load_model():
    """Load the 4-bit quantized base model plus the fine-tuned LoRA adapter.

    Cached by Streamlit (`st.cache_resource`) so the expensive download and
    initialization run only once per server process.

    Returns:
        tuple: (model, tokenizer) prepared for inference.
    """
    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=2048,
        dtype=None,          # let unsloth choose the best dtype for the hardware
        load_in_4bit=True,   # 4-bit quantization so the 8B model fits in memory
        device_map="auto",
    )
    # Attach the task-specific LoRA adapter on top of the base weights.
    adapter_repo = "Navi004/deepseek-r1-finetuned_lora-adapter-Batch10"
    model.load_adapter(adapter_repo)
    FastLanguageModel.for_inference(model)
    return model, tokenizer


model, tokenizer = load_model()


# ---------- Define State ----------
class MentalHealthState(TypedDict):
    """State dictionary passed between the LangGraph pipeline nodes."""
    user_post: str      # raw text entered by the user
    context: str        # context text extracted from the retriever response
    retrieved_raw: str  # raw response body returned by the retriever endpoint
    diagnosis: str      # model-predicted category label, e.g. "2: depression"


# ---------- Retriever Node ----------
def retriever_node(state: MentalHealthState) -> MentalHealthState:
    """Send the user's post to the Vext RAG endpoint and store the context.

    Populates ``retrieved_raw`` with the raw response text and ``context``
    with the ``text`` field of the JSON payload.
    """
    # NOTE(review): "$(HF_deployment)" is literal text inside this f-string,
    # NOT interpolated by Python — presumably the Vext channel-token
    # placeholder format; confirm this is intended and not a typo for
    # {os.getenv('HF_deployment')}.
    url = f"https://payload.vextapp.com/hook/{os.getenv('endpoint_id')}/catch/$(HF_deployment)"
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Apikey": f"Api-Key {os.getenv('API_KEY')}",
    }
    payload = {"payload": state["user_post"]}
    # Fix: requests has no default timeout — without one a slow or
    # unreachable endpoint would hang the Streamlit worker indefinitely.
    res = requests.post(url, json=payload, headers=headers, timeout=60)
    state["retrieved_raw"] = res.text
    # Idiomatic: res.json() instead of json.loads(res.text).
    state["context"] = res.json()["text"]
    return state


# ---------- Inference Node ----------
def inference_node(state: MentalHealthState) -> MentalHealthState:
    """Classify the retrieved context with the fine-tuned LLM.

    Formats the classification prompt, generates a short completion, and
    stores the text after the "### Response:" marker in ``diagnosis``.
    """
    prompt_style = """You are a mental health analysis assistant that specializes in interpreting user-generated text on social media.
Your task is to carefully analyze the post context and identify the most relevant mental health condition reflected in the writing.
Classify the post into one of the following categories based on the emotional tone, context, and language used:
0: BPD
1: bipolar
2: depression
3: Anxiety
4: schizophrenia
5: mentalillness
6: normal
7: ptsd

### Instructions:
Given the following post, identify the most relevant mental health condition:
- Focus on the tone, word choice, context, and implied emotional state.
- Identify key symptoms that match known patterns of the above disorders.
- If symptoms are vague or too general, assign to "5: Unspecified Mental Illness".
- If the post does not reflect any significant mental health issue, assign "6: normal".
Be as accurate and empathetic as possible.
Only return the **most relevant category label and disorder name** for each post, in the following format:
`: ` (e.g., `2: depression`)

### Post:
{}

### Response:
{}"""
    prompt = prompt_style.format(state["context"], "")
    # Fix: use the model's actual device instead of hard-coding "cuda" —
    # the model was loaded with device_map="auto" and may not be on cuda:0.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=10,  # the expected label is short, e.g. "2: depression"
        use_cache=True,
    )
    # Fix: skip special tokens so the displayed diagnosis does not carry a
    # trailing end-of-sequence marker.
    decoded_output = tokenizer.batch_decode(outputs, skip_special_tokens=True)[0]
    state["diagnosis"] = decoded_output.split("### Response:")[1].strip()
    return state


# ---------- Output Node ----------
def output_node(state: MentalHealthState) -> MentalHealthState:
    """Terminal pass-through node; the final state is rendered by the UI."""
    return state


# ---------- Graph Pipeline ----------
@st.cache_resource
def build_graph():
    """Compile the retriever -> inference -> output LangGraph pipeline.

    Cached so the graph is compiled once per server process.
    """
    builder = StateGraph(MentalHealthState)
    builder.add_node("retriever", retriever_node)
    builder.add_node("inference", inference_node)
    builder.add_node("output", output_node)
    builder.set_entry_point("retriever")
    builder.add_edge("retriever", "inference")
    builder.add_edge("inference", "output")
    builder.add_edge("output", END)
    return builder.compile()


graph = build_graph()

# ---------- User Interaction ----------
user_post = st.text_area("📝 Enter the user's post here for diagnosis:", height=200)

if st.button("🧠 Analyze Post"):
    # Idiomatic emptiness check instead of comparing with "".
    if not user_post.strip():
        st.warning("Please enter a post to analyze.")
    else:
        with st.spinner("Analyzing..."):
            final_state = graph.invoke({"user_post": user_post})
            st.subheader("🔍 Retrieved Context")
            st.write(final_state["context"])
            st.subheader("🧠 Diagnosis")
            st.success(final_state["diagnosis"])
            with st.expander("📄 Raw Retriever Output (JSON)"):
                st.json(final_state["retrieved_raw"])