Spaces: Runtime error
```python
# app.py
import os

# ✅ Set a writable cache directory inside the container.
# These must be set before transformers / sentence_transformers are imported,
# since the libraries read the cache location at import time.
os.environ["HF_HOME"] = "/app/cache"
os.environ["TRANSFORMERS_CACHE"] = "/app/cache"
os.environ["SENTENCE_TRANSFORMERS_HOME"] = "/app/cache"

import requests
import json
import logging
import pandas as pd
import faiss
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
# ✅ Initialize FastAPI
app = FastAPI()

# ✅ Securely fetch the OpenRouter API key
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
    raise ValueError("❌ OPENROUTER_API_KEY is missing. Set it as an environment variable.")

OPENROUTER_BASE_URL = "https://openrouter.ai/api/v1/chat/completions"
# ✅ Load AI models with explicit caching & remote-code trust
# NOTE: the summarization model/tokenizer are loaded here but never used below;
# the /summarize_chat endpoint calls DeepSeek instead.
try:
    embedding_model = SentenceTransformer(
        "sentence-transformers/all-MiniLM-L6-v2",
        cache_folder="/app/cache",
        trust_remote_code=True,  # avoids loading issues with some model revisions
    )
    summarization_model = AutoModelForSeq2SeqLM.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
        trust_remote_code=True,
    )
    summarization_tokenizer = AutoTokenizer.from_pretrained(
        "google/long-t5-tglobal-base",
        cache_dir="/app/cache",
        trust_remote_code=True,
    )
    print("✅ Models loaded successfully!")
except Exception as e:
    print(f"❌ Model loading error: {e}")
    raise  # fail fast: the app cannot serve requests without these models
# ✅ API health check
@app.get("/")  # route path assumed; the decorator was missing from the paste
def health_check():
    return {"status": "FastAPI is running!"}
# ✅ Load datasets
try:
    recommendations_df = pd.read_csv("treatment_recommendations.csv")
    questions_df = pd.read_csv("symptom_questions.csv")
    print("✅ Datasets loaded successfully!")
except FileNotFoundError as e:
    logging.error(f"❌ Missing dataset file: {e}")
    # HTTPException only makes sense inside a request handler; at import time,
    # fail with a plain error so startup aborts cleanly.
    raise RuntimeError(f"Dataset file not found: {e}")
# ✅ Create FAISS indexes
question_embeddings = embedding_model.encode(questions_df["Questions"].tolist(), convert_to_numpy=True)
question_index = faiss.IndexFlatL2(question_embeddings.shape[1])
question_index.add(question_embeddings)

# Inner-product search only approximates cosine similarity when the vectors are
# L2-normalized, so normalize before indexing (queries must be normalized too).
treatment_embeddings = embedding_model.encode(
    recommendations_df["Disorder"].tolist(),
    convert_to_numpy=True,
    normalize_embeddings=True,
)
index = faiss.IndexFlatIP(treatment_embeddings.shape[1])
index.add(treatment_embeddings)
# ✅ In-memory chat history (shared across all clients; acceptable for a demo)
chat_history = []

# ✅ Request models
class ChatRequest(BaseModel):
    message: str

class SummaryRequest(BaseModel):
    chat_history: list
# ✅ Call DeepSeek via OpenRouter
def deepseek_request(prompt, max_tokens=300):
    """Send a request to OpenRouter's DeepSeek model."""
    headers = {
        "Authorization": f"Bearer {OPENROUTER_API_KEY}",
        "Content-Type": "application/json",
    }
    payload = {
        "model": "deepseek/deepseek-r1-distill-llama-8b",
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": max_tokens,
        "temperature": 0.8,
    }
    try:
        response = requests.post(OPENROUTER_BASE_URL, headers=headers, json=payload, timeout=60)
        response.raise_for_status()
        response_json = response.json()
        if "choices" in response_json and response_json["choices"]:
            return response_json["choices"][0].get("message", {}).get("content", "").strip()
    except Exception as e:
        logging.error(f"OpenRouter DeepSeek API error: {e}")
    # Fallback for both API errors and empty/malformed responses
    return "I'm here to support you. Can you share more about what you're feeling?"
# ✅ Retrieve the most relevant diagnostic question (helper; not currently wired to an endpoint)
def retrieve_relevant_question(user_input):
    """Find the most relevant diagnostic question from the dataset using FAISS."""
    input_embedding = embedding_model.encode([user_input], convert_to_numpy=True)
    _, indices = question_index.search(input_embedding, 1)
    if indices[0][0] == -1:
        return "I'm here to listen. Can you tell me more about your symptoms?"
    return questions_df["Questions"].iloc[indices[0][0]]
# ✅ API endpoint: chat interaction
@app.post("/chat")  # route path assumed; the decorator was missing from the paste
def chat(request: ChatRequest):
    """Patient enters data, AI responds and stores the conversation."""
    user_message = request.message
    chat_history.append(user_message)

    # Construct the DeepSeek prompt
    prompt = f"""
You are an AI psychiatrist conducting a mental health consultation.
Engage in a supportive, natural conversation, maintaining an empathetic tone.
- Always provide a thoughtful and compassionate response.
- If a user shares distressing emotions, acknowledge their feelings and ask relevant follow-up questions.

Previous conversation:
{chat_history}

User input:
"{user_message}"

Generate:
- An empathetic response.
- A related follow-up question.

Ensure your response is meaningful and NEVER empty.
"""

    # Call the DeepSeek API
    ai_response = deepseek_request(prompt, max_tokens=250)

    # ✅ Ensure the response is never empty
    if not ai_response or not ai_response.strip():
        ai_response = "I'm here to listen. Can you tell me more about how you're feeling? Maybe I can help."

    chat_history.append(ai_response)
    return {"response": ai_response}
# ✅ API endpoint: detect disorders from chat history
@app.get("/detect_disorders")  # route path assumed
def detect_disorders():
    """Detect psychiatric disorders based on the full chat history."""
    full_chat_text = " ".join(chat_history)
    # Normalize the query so inner-product search matches the normalized index
    text_embedding = embedding_model.encode([full_chat_text], convert_to_numpy=True, normalize_embeddings=True)
    distances, indices = index.search(text_embedding, 3)
    if indices[0][0] == -1:
        return {"disorders": ["No matching disorder found."]}
    disorders = [recommendations_df["Disorder"].iloc[i] for i in indices[0]]
    return {"disorders": disorders}
# ✅ API endpoint: get treatment recommendations
@app.get("/get_treatment")  # route path assumed
def get_treatment():
    """Retrieve treatment recommendations based on detected disorders."""
    detected_disorders = detect_disorders()["disorders"]
    treatments = {}
    for disorder in detected_disorders:
        if disorder in recommendations_df["Disorder"].values:
            treatments[disorder] = recommendations_df[
                recommendations_df["Disorder"] == disorder
            ]["Treatment Recommendation"].values[0]
        else:
            # Generate a treatment plan if the disorder is not in the dataset
            treatment_prompt = f"""
The user has been diagnosed with {disorder}. Provide a structured treatment plan including:
- **Therapy options** (CBT, psychotherapy, etc.).
- **Medications** (if applicable).
- **Lifestyle strategies** (exercise, mindfulness, etc.).
- **When to seek professional help**.
- **Encouragement**.
Ensure your response is clear and medically sound.
"""
            treatments[disorder] = deepseek_request(treatment_prompt, max_tokens=250)
    return {"treatments": treatments}
# ✅ API endpoint: summarize the chat
@app.get("/summarize_chat")  # route path assumed
def summarize_chat():
    """Summarize the full chat session using DeepSeek."""
    chat_text = " ".join(chat_history)
    summary_prompt = (
        "The following is a conversation between a patient and an AI psychiatrist. "
        f"Summarize it clearly:\n{chat_text}"
    )
    summary = deepseek_request(summary_prompt, max_tokens=500)
    return {"summary": summary}
```