Spaces:
Runtime error
Runtime error
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import requests
|
| 3 |
+
import json
|
| 4 |
+
import pandas as pd
|
| 5 |
+
from fastapi import FastAPI, HTTPException
|
| 6 |
+
from pydantic import BaseModel
|
| 7 |
+
|
| 8 |
+
# Load DSM-5 Dataset
|
| 9 |
+
file_path = "dsm5_final_cleaned.csv"
|
| 10 |
+
df = pd.read_csv(file_path)
|
| 11 |
+
|
| 12 |
+
# OpenRouter API Configuration (DeepSeek Model)
|
| 13 |
+
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY") # Use environment variable for security
|
| 14 |
+
if not OPENROUTER_API_KEY:
|
| 15 |
+
raise ValueError("OPENROUTER_API_KEY is missing. Set it as an environment variable.")
|
| 16 |
+
OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
|
| 17 |
+
|
| 18 |
+
# Initialize FastAPI
|
| 19 |
+
app = FastAPI()
|
| 20 |
+
|
| 21 |
+
# Pydantic Models
|
| 22 |
+
class ChatRequest(BaseModel):
|
| 23 |
+
message: str
|
| 24 |
+
|
| 25 |
+
class SummaryRequest(BaseModel):
|
| 26 |
+
chat_history: list
|
| 27 |
+
|
| 28 |
+
def deepseek_request(prompt, max_tokens=300):
|
| 29 |
+
"""Helper function to send a request to DeepSeek API and handle response."""
|
| 30 |
+
headers = {"Authorization": f"Bearer {OPENROUTER_API_KEY}", "Content-Type": "application/json"}
|
| 31 |
+
payload = {
|
| 32 |
+
"model": "deepseek/deepseek-r1-distill-llama-8b",
|
| 33 |
+
"messages": [{"role": "user", "content": prompt}],
|
| 34 |
+
"max_tokens": max_tokens,
|
| 35 |
+
"temperature": 0.7
|
| 36 |
+
}
|
| 37 |
+
response = requests.post(OPENROUTER_BASE_URL, headers=headers, data=json.dumps(payload))
|
| 38 |
+
if response.status_code == 200:
|
| 39 |
+
response_json = response.json()
|
| 40 |
+
if "choices" in response_json and response_json["choices"]:
|
| 41 |
+
return response_json["choices"][0].get("message", {}).get("content", "").strip()
|
| 42 |
+
return "Error: Unable to process the request."
|
| 43 |
+
|
| 44 |
+
def match_disorders(chat_history):
|
| 45 |
+
"""Match user symptoms with DSM-5 disorders based on keyword occurrence."""
|
| 46 |
+
disorder_scores = {}
|
| 47 |
+
for _, row in df.iterrows():
|
| 48 |
+
disorder = row["Disorder"]
|
| 49 |
+
keywords = row["Criteria"].split(", ")
|
| 50 |
+
match_count = sum(1 for word in keywords if word in chat_history.lower())
|
| 51 |
+
if match_count > 0:
|
| 52 |
+
disorder_scores[disorder] = match_count
|
| 53 |
+
sorted_disorders = sorted(disorder_scores, key=disorder_scores.get, reverse=True)
|
| 54 |
+
return sorted_disorders[:3] if sorted_disorders else []
|
| 55 |
+
|
| 56 |
+
@app.post("/detect_disorders")
|
| 57 |
+
def detect_disorders(request: SummaryRequest):
|
| 58 |
+
"""Detect psychiatric disorders using DSM-5 keyword matching + DeepSeek validation."""
|
| 59 |
+
full_chat = " ".join(request.chat_history)
|
| 60 |
+
matched_disorders = match_disorders(full_chat)
|
| 61 |
+
|
| 62 |
+
prompt = f"""
|
| 63 |
+
The following is a psychiatric conversation:
|
| 64 |
+
{full_chat}
|
| 65 |
+
|
| 66 |
+
Based on DSM-5 diagnostic criteria, analyze the symptoms and determine the most probable psychiatric disorders.
|
| 67 |
+
Here are possible disorder matches from DSM-5 keyword analysis: {', '.join(matched_disorders) if matched_disorders else 'None found'}.
|
| 68 |
+
If no clear matches exist, diagnose based purely on symptom patterns and clinical reasoning.
|
| 69 |
+
Return a **list** of disorders, separated by commas, without extra text.
|
| 70 |
+
"""
|
| 71 |
+
|
| 72 |
+
response = deepseek_request(prompt, max_tokens=150)
|
| 73 |
+
disorders = [disorder.strip() for disorder in response.split(",")] if response and response.lower() != "unspecified disorder" else matched_disorders
|
| 74 |
+
return {"disorders": disorders if disorders else ["Unspecified Disorder"]}
|
| 75 |
+
|
| 76 |
+
@app.post("/get_treatment")
|
| 77 |
+
def get_treatment(request: SummaryRequest):
|
| 78 |
+
"""Retrieve structured treatment recommendations based on detected disorders."""
|
| 79 |
+
detected_disorders = detect_disorders(request)["disorders"]
|
| 80 |
+
disorders_text = ", ".join(detected_disorders)
|
| 81 |
+
prompt = f"""
|
| 82 |
+
The user has been diagnosed with: {disorders_text}.
|
| 83 |
+
Provide a structured, evidence-based treatment plan including:
|
| 84 |
+
- Therapy recommendations (e.g., CBT, DBT, psychotherapy).
|
| 85 |
+
- Possible medications if applicable (e.g., SSRIs, anxiolytics, sleep aids).
|
| 86 |
+
- Lifestyle and self-care strategies (e.g., sleep hygiene, mindfulness, exercise).
|
| 87 |
+
If the user has suicidal thoughts, emphasize **immediate crisis intervention and emergency medical support.**
|
| 88 |
+
"""
|
| 89 |
+
treatment_response = deepseek_request(prompt, max_tokens=200)
|
| 90 |
+
return {"treatments": treatment_response}
|
| 91 |
+
|
| 92 |
+
@app.post("/summarize_chat")
|
| 93 |
+
def summarize_chat(request: SummaryRequest):
|
| 94 |
+
"""Generate a structured summary of the psychiatric consultation."""
|
| 95 |
+
full_chat = " ".join(request.chat_history)
|
| 96 |
+
detected_disorders = detect_disorders(request)["disorders"]
|
| 97 |
+
treatment_response = get_treatment(request)["treatments"]
|
| 98 |
+
prompt = f"""
|
| 99 |
+
Summarize the following psychiatric conversation:
|
| 100 |
+
{full_chat}
|
| 101 |
+
|
| 102 |
+
- **Detected Disorders:** {', '.join(detected_disorders)}
|
| 103 |
+
- **Suggested Treatments:** {treatment_response}
|
| 104 |
+
|
| 105 |
+
The summary should include:
|
| 106 |
+
- Main concerns reported by the user.
|
| 107 |
+
- Key symptoms observed.
|
| 108 |
+
- Possible underlying psychological conditions.
|
| 109 |
+
- Recommended next steps, including professional consultation and self-care strategies.
|
| 110 |
+
If suicidal thoughts were mentioned, highlight the **need for immediate crisis intervention and professional support.**
|
| 111 |
+
"""
|
| 112 |
+
summary = deepseek_request(prompt, max_tokens=300)
|
| 113 |
+
return {"summary": summary}
|
| 114 |
+
|
| 115 |
+
@app.post("/chat")
|
| 116 |
+
def chat(request: ChatRequest):
|
| 117 |
+
"""Generate AI psychiatric response for user input."""
|
| 118 |
+
prompt = f"""
|
| 119 |
+
You are an AI psychiatrist conducting a mental health consultation.
|
| 120 |
+
The user is discussing their concerns and symptoms. Engage in a supportive conversation,
|
| 121 |
+
ask relevant follow-up questions, and maintain an empathetic tone.
|
| 122 |
+
|
| 123 |
+
User input:
|
| 124 |
+
{request.message}
|
| 125 |
+
|
| 126 |
+
Provide a meaningful response and a follow-up question if necessary.
|
| 127 |
+
If the user mentions suicidal thoughts, respond with an urgent and compassionate tone,
|
| 128 |
+
suggesting that they seek immediate professional help while providing emotional support.
|
| 129 |
+
"""
|
| 130 |
+
response = deepseek_request(prompt, max_tokens=200)
|
| 131 |
+
return {"response": response}
|