# NOTE(review): the next two words are residue of the Hugging Face Spaces
# status banner ("Spaces: Sleeping") captured when this app's page was scraped;
# they are not part of the program.
# Standard library imports.
import json
import os
from typing import TypedDict

# Third-party imports.
import requests
import streamlit as st
import torch
from langgraph.graph import StateGraph, END
from peft import PeftModel
from transformers import AutoTokenizer
from unsloth import FastLanguageModel
# ---------- Custom Styling ----------
# Dark-theme CSS overrides injected into the page; targets the main pane,
# sidebar, text inputs, and selectbox widgets.
_DARK_THEME_CSS = """
<style>
.main {
    background-color: #1a1a1a;
    color: #ffffff;
}
.sidebar .sidebar-content {
    background-color: #2d2d2d;
}
.stTextInput textarea {
    color: #ffffff !important;
}
.stSelectbox div[data-baseweb="select"] {
    color: white !important;
    background-color: #3d3d3d !important;
}
.stSelectbox svg {
    fill: white !important;
}
.stSelectbox option {
    background-color: #2d2d2d !important;
    color: white !important;
}
div[role="listbox"] div {
    background-color: #2d2d2d !important;
    color: white !important;
}
</style>
"""
st.markdown(_DARK_THEME_CSS, unsafe_allow_html=True)
# ---------- Header ----------
# NOTE(review): the leading symbols in these strings are mojibake'd emoji
# from the page scrape; kept verbatim to preserve runtime output.
st.title("π©Ί DDCBot - Disease Diagnosis Chat Bot")
st.caption("π§ AI-Powered Disease Diagnosis with Reasoning Capabilities")

# ---------- Sidebar ----------
# Static course / project attribution panel.
with st.sidebar:
    st.divider()
    st.markdown(
        """
### CSE3085 - Predictive Analytics with Case Studies
**π Project Review**
**π TITLE:** LangGraph and RAG assisted LLMs for disease diagnostics with reasoning capabilities
**π¨βπ» By**
- **22MIA1064** Yasir Ahmad
- **22MIA1049** Naveen Nidadavolu
- **22MIA1034** Namansh Singh Maurya
- **22MIA1044** Etash Ashwin
**π M.Tech CSE (Business Analytics)**
**π Submitted to**
Dr. Jaya Mala D
Professor Senior, SCOPE, VIT, Chennai
""",
        unsafe_allow_html=True,
    )
    st.divider()
# ---------- Load Model Once ----------
@st.cache_resource
def load_model():
    """Load the 4-bit quantized base model and attach the fine-tuned LoRA adapter.

    BUG FIX: the section comment promised "Load Model Once", but Streamlit
    re-executes the whole script on every widget interaction, so without
    caching the 8B model was reloaded on each rerun. ``st.cache_resource``
    makes the load happen once per server process.

    Returns:
        tuple: ``(model, tokenizer)`` ready for inference.
    """
    base_model_name = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=base_model_name,
        max_seq_length=2048,
        dtype=None,          # let unsloth choose the best dtype for the hardware
        load_in_4bit=True,   # 4-bit quantization so the 8B model fits in memory
        device_map="auto",
    )
    # Layer the fine-tuned LoRA adapter on top of the base weights.
    adapter_repo = "Navi004/deepseek-r1-finetuned_lora-adapter-Batch10"
    model.load_adapter(adapter_repo)
    # Switch unsloth into its optimized inference mode.
    FastLanguageModel.for_inference(model)
    return model, tokenizer


model, tokenizer = load_model()
# ---------- Define State ----------
class MentalHealthState(TypedDict):
    """State dictionary threaded through the LangGraph pipeline nodes."""

    user_post: str      # raw text the user submitted for analysis
    context: str        # context text extracted from the retrieval response
    retrieved_raw: str  # raw response body returned by the retriever endpoint
    diagnosis: str      # model output, e.g. "2: depression"
# ---------- Retriever Node ----------
def retriever_node(state: MentalHealthState) -> MentalHealthState:
    """Fetch RAG context for the user's post from the Vext retrieval endpoint.

    Populates ``state["retrieved_raw"]`` with the raw response body and
    ``state["context"]`` with the retrieved ``text`` field, then returns
    the (mutated) state.
    """
    # BUG FIX: the original URL contained the literal text "$(HF_deployment)"
    # — a shell-style placeholder that an f-string never expands — so every
    # request hit a nonexistent channel. Read the channel token from the
    # environment, like endpoint_id.
    # NOTE(review): confirm 'HF_deployment' is the intended env var name.
    url = (
        "https://payload.vextapp.com/hook/"
        f"{os.getenv('endpoint_id')}/catch/{os.getenv('HF_deployment')}"
    )
    headers = {
        "accept": "application/json",
        "content-type": "application/json",
        "Apikey": f"Api-Key {os.getenv('API_KEY')}",
    }
    payload = {"payload": state["user_post"]}
    # Timeout so a hung retrieval service cannot freeze the Streamlit app.
    res = requests.post(url, json=payload, headers=headers, timeout=60)
    res.raise_for_status()  # surface HTTP errors instead of a confusing JSON parse error
    state["retrieved_raw"] = res.text
    state["context"] = res.json()["text"]
    return state
# ---------- Inference Node ----------
def inference_node(state: MentalHealthState) -> MentalHealthState:
    """Classify the retrieved context into one of 8 mental-health categories.

    Formats the context into the fine-tuned classification prompt, generates
    a short label with the LoRA-adapted model, and stores the text following
    the "### Response:" marker in ``state["diagnosis"]``.
    """
    prompt_style = """You are a mental health analysis assistant that specializes in interpreting user-generated text on social media. Your task is to carefully analyze the post context and identify the most relevant mental health condition reflected in the writing.
Classify the post into one of the following categories based on the emotional tone, context, and language used:
0: BPD
1: bipolar
2: depression
3: Anxiety
4: schizophrenia
5: mentalillness
6: normal
7: ptsd
### Instructions:
Given the following post, identify the most relevant mental health condition:
- Focus on the tone, word choice, context, and implied emotional state.
- Identify key symptoms that match known patterns of the above disorders.
- If symptoms are vague or too general, assign to "5: Unspecified Mental Illness".
- If the post does not reflect any significant mental health issue, assign "6: normal".
Be as accurate and empathetic as possible. Only return the **most relevant category label and disorder name** for each post, in the following format:
`<label_number>: <disorder_name>`
(e.g., `2: depression`)
### Post:
{}
### Response:
{}"""
    prompt = prompt_style.format(state["context"], "")
    # BUG FIX: the original hard-coded .to("cuda"), which crashes on CPU-only
    # hosts and can mismatch the placement chosen by device_map="auto".
    # Use the model's own device instead.
    inputs = tokenizer([prompt], return_tensors="pt").to(model.device)
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=10,  # the label output is short, e.g. "2: depression"
        use_cache=True,
    )
    decoded_output = tokenizer.batch_decode(outputs)[0]
    # Keep only the text the model generated after the "### Response:" marker.
    state["diagnosis"] = decoded_output.split("### Response:")[1].strip()
    return state
# ---------- Output Node ----------
def output_node(state: "MentalHealthState") -> "MentalHealthState":
    """Terminal pass-through node: hand the accumulated state back unchanged."""
    return state
# ---------- Graph Pipeline ----------
def build_graph():
    """Assemble and compile the linear retriever -> inference -> output pipeline."""
    g = StateGraph(MentalHealthState)
    # Register the three pipeline stages in execution order.
    for node_name, node_fn in (
        ("retriever", retriever_node),
        ("inference", inference_node),
        ("output", output_node),
    ):
        g.add_node(node_name, node_fn)
    g.set_entry_point("retriever")
    g.add_edge("retriever", "inference")
    g.add_edge("inference", "output")
    g.add_edge("output", END)
    return g.compile()


graph = build_graph()
# ---------- User Interaction ----------
user_post = st.text_area("π Enter the user's post here for diagnosis:", height=200)

if st.button("π§ Analyze Post"):
    # Guard against empty / whitespace-only submissions.
    if not user_post.strip():
        st.warning("Please enter a post to analyze.")
    else:
        with st.spinner("Analyzing..."):
            # Run the full retriever -> inference -> output pipeline.
            final_state = graph.invoke({"user_post": user_post})

        st.subheader("π Retrieved Context")
        st.write(final_state["context"])

        st.subheader("π§ Diagnosis")
        st.success(final_state["diagnosis"])

        with st.expander("π Raw Retriever Output (JSON)"):
            st.json(final_state["retrieved_raw"])