# ============================================ # app.py - FastAPI + Gradio Interface # ============================================ """ Cognitive Distortion Detection API =================================== Provides distortion detection with both API and web interface """ from fastapi import FastAPI, HTTPException from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel import torch from transformers import AutoTokenizer, AutoModelForSequenceClassification import gradio as gr from typing import Optional, List, Dict import uvicorn # ============================================ # CONFIGURATION # ============================================ MODEL_NAME = "YureiYuri/empathist" # ============================================ # LOAD MODEL # ============================================ print("š¤ Loading cognitive distortion detector...") tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME) model.eval() # Label mappings id2label = { 0: "overgeneralization", 1: "catastrophizing", 2: "black_and_white", 3: "self_blame", 4: "mind_reading" } DESCRIPTIONS = { "overgeneralization": "Making broad interpretations from single events using words like 'always', 'never', 'everyone'", "catastrophizing": "Expecting the worst possible outcome using words like 'terrible', 'disaster', 'awful'", "black_and_white": "Seeing things in absolute terms with no middle ground", "self_blame": "Taking excessive responsibility for things outside your control", "mind_reading": "Assuming you know what others are thinking without evidence" } print("ā Model loaded successfully!") # ============================================ # PYDANTIC MODELS # ============================================ class DetectionRequest(BaseModel): text: str threshold: Optional[float] = 0.5 class DistortionResult(BaseModel): distortion: str confidence: float description: str class DetectionResponse(BaseModel): text: str distortions: List[DistortionResult] has_distortions: bool summary: str # ============================================ # FASTAPI APP # ============================================ app = FastAPI( title="Cognitive Distortion Detector", description="CBT-based cognitive distortion detection API", version="1.0.0" ) app.add_middleware( CORSMiddleware, allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) # ============================================ # HELPER FUNCTIONS # ============================================ def detect_distortions(text: str, threshold: float = 0.5) -> Dict: """Detect cognitive distortions in text""" if not text.strip(): return { "text": text, "distortions": [], "has_distortions": False, "summary": "No text provided" } # Tokenize inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding=True) # Predict with torch.no_grad(): outputs = model(**inputs) probabilities = torch.sigmoid(outputs.logits).squeeze() # Extract distortions above threshold distortions = [] for idx, prob in enumerate(probabilities): if prob > threshold: label = id2label[idx] distortions.append({ "distortion": label, "confidence": round(prob.item(), 4), "description": DESCRIPTIONS[label] }) # Sort by confidence distortions.sort(key=lambda x: x["confidence"], reverse=True) # Create summary if distortions: summary = f"Detected {len(distortions)} distortion(s): " + ", ".join([d["distortion"] for d in distortions]) else: summary = "No significant cognitive distortions detected" return { "text": text, "distortions": distortions, "has_distortions": len(distortions) > 0, "summary": summary } # ============================================ # API ENDPOINTS # ============================================ @app.get("/") async def root(): """Health check endpoint""" return { "status": "online", "service": "Cognitive Distortion Detector", "version": "1.0.0", "model": MODEL_NAME } @app.post("/detect", response_model=DetectionResponse) async def detect_endpoint(request: DetectionRequest): """ Detect cognitive distortions in text Args: text: Input text to analyze threshold: Confidence threshold (0.0-1.0), default 0.5 Returns: Detection results with distortions found """ try: result = detect_distortions(request.text, request.threshold) return DetectionResponse(**result) except Exception as e: print(f"ā Detection error: {str(e)}") raise HTTPException(status_code=500, detail=f"Detection failed: {str(e)}") @app.get("/distortions") async def list_distortions(): """List all detectable distortion types with descriptions""" return { "distortions": [ {"name": label, "description": DESCRIPTIONS[label]} for label in id2label.values() ] } # ============================================ # GRADIO INTERFACE # ============================================ def predict_gradio(text: str, threshold: float = 0.5): """Gradio prediction function""" if not text.strip(): return "Please enter some text to analyze.", "" result = detect_distortions(text, threshold) # Format summary if not result["distortions"]: summary = "ā No significant cognitive distortions detected!" html = "
{d['description']}
Confidence: {percentage:.1f}%