|
|
|
|
|
|
|
|
|
|
|
""" |
|
|
Cognitive Distortion Detection API |
|
|
=================================== |
|
|
Provides cognitive distortion detection through both a REST API (FastAPI) and a web interface (Gradio).
|
|
""" |
|
|
|
|
|
from typing import Dict, List, Optional

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import AutoModelForSequenceClassification, AutoTokenizer
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face Hub identifier of the fine-tuned classifier used for detection.
MODEL_NAME = "YureiYuri/empathist"

# Load the tokenizer and model once at import time; every request reuses them.
print("🤖 Loading cognitive distortion detector...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME)
model.eval()  # inference mode (disables dropout)

# Output index -> label name. The order must match the model's classification head.
id2label = {
    0: "overgeneralization",
    1: "catastrophizing",
    2: "black_and_white",
    3: "self_blame",
    4: "mind_reading",
}

# Fail fast if the checkpoint's head does not produce exactly one output per
# known label. Otherwise detection would raise KeyError on extra outputs, or
# silently never report labels the head does not produce.
if model.config.num_labels != len(id2label):
    raise RuntimeError(
        f"Model {MODEL_NAME!r} has {model.config.num_labels} output labels, "
        f"but id2label defines {len(id2label)}"
    )

# Human-readable explanation for each label, keyed by label name.
DESCRIPTIONS = {
    "overgeneralization": "Making broad interpretations from single events using words like 'always', 'never', 'everyone'",
    "catastrophizing": "Expecting the worst possible outcome using words like 'terrible', 'disaster', 'awful'",
    "black_and_white": "Seeing things in absolute terms with no middle ground",
    "self_blame": "Taking excessive responsibility for things outside your control",
    "mind_reading": "Assuming you know what others are thinking without evidence",
}

print("✅ Model loaded successfully!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class DetectionRequest(BaseModel):
    # Request body for POST /detect.
    text: str  # text to analyze
    # Probability cutoff: a label is reported only when its probability is
    # strictly greater than this value. Constrained to [0, 1], so out-of-range
    # values and an explicit null are rejected up front with a 422. Without the
    # constraint, null would reach the `prob > threshold` comparison and fail as a 500.
    threshold: float = Field(0.5, ge=0.0, le=1.0)
|
|
|
|
|
class DistortionResult(BaseModel):
    # One detected distortion, matching the per-item dicts built by detect_distortions().
    distortion: str  # label name (a value of id2label)
    confidence: float  # sigmoid probability, rounded to 4 decimal places
    description: str  # human-readable explanation taken from DESCRIPTIONS
|
|
|
|
|
class DetectionResponse(BaseModel):
    # Response body for POST /detect; mirrors the dict returned by detect_distortions().
    text: str  # the input text, echoed back unchanged
    distortions: List[DistortionResult]  # sorted by descending confidence
    has_distortions: bool  # True when at least one label exceeded the threshold
    summary: str  # short human-readable summary of the result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app = FastAPI(
    title="Cognitive Distortion Detector",
    description="CBT-based cognitive distortion detection API",
    version="1.0.0",
)

# The API is public: any origin may call it. None of the endpoints here read
# cookies or auth headers, so credentials stay disabled. The CORS spec forbids
# pairing a wildcard origin with credentials, and Starlette would otherwise
# reflect the caller's Origin, letting any site make credentialed requests.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def detect_distortions(text: str, threshold: float = 0.5) -> Dict:
    """Detect cognitive distortions in text.

    Runs the classifier once and reports every label whose sigmoid
    probability is strictly greater than ``threshold``.

    Args:
        text: Input text. Empty or whitespace-only input returns early
            without running the model.
        threshold: Probability cutoff. ``None`` falls back to 0.5.

    Returns:
        Dict with keys ``text`` (the input), ``distortions`` (list of dicts
        with ``distortion``, ``confidence`` rounded to 4 places and
        ``description``, sorted by descending confidence),
        ``has_distortions`` and ``summary``.
    """
    if not text.strip():
        return {
            "text": text,
            "distortions": [],
            "has_distortions": False,
            "summary": "No text provided",
        }

    # Callers may forward an unset threshold; use the default instead of
    # failing on a `prob > None` comparison.
    if threshold is None:
        threshold = 0.5

    # Inputs longer than 128 tokens are truncated, so only their beginning is analyzed.
    inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=128, padding=True)

    with torch.no_grad():
        logits = model(**inputs).logits

    # Each label gets an independent sigmoid probability (multi-label).
    # Take row 0 of the single-item batch rather than squeeze(), which would
    # collapse a one-label head into a non-iterable 0-d tensor.
    probabilities = torch.sigmoid(logits)[0].tolist()

    distortions = []
    for idx, prob in enumerate(probabilities):
        if prob > threshold:
            label = id2label[idx]
            distortions.append({
                "distortion": label,
                "confidence": round(prob, 4),
                "description": DESCRIPTIONS[label],
            })

    distortions.sort(key=lambda d: d["confidence"], reverse=True)

    if distortions:
        names = ", ".join(d["distortion"] for d in distortions)
        summary = f"Detected {len(distortions)} distortion(s): {names}"
    else:
        summary = "No significant cognitive distortions detected"

    return {
        "text": text,
        "distortions": distortions,
        "has_distortions": bool(distortions),
        "summary": summary,
    }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@app.get("/")
async def root():
    """Report that the service is up, along with its name, version and model id."""
    status_payload = dict(
        status="online",
        service="Cognitive Distortion Detector",
        version="1.0.0",
        model=MODEL_NAME,
    )
    return status_payload
|
|
|
|
|
@app.post("/detect", response_model=DetectionResponse)
def detect_endpoint(request: DetectionRequest):
    """
    Detect cognitive distortions in text

    Args:
        text: Input text to analyze
        threshold: Confidence threshold (0.0-1.0), default 0.5

    Returns:
        Detection results with distortions found
    """
    # Declared with plain `def` (not `async def`) so FastAPI runs it in its
    # threadpool. Model inference is blocking CPU work and would otherwise
    # stall the event loop that also serves the mounted Gradio UI.
    threshold = 0.5 if request.threshold is None else request.threshold
    try:
        result = detect_distortions(request.text, threshold)
        return DetectionResponse(**result)
    except Exception as e:
        print(f"❌ Detection error: {e}")
        raise HTTPException(status_code=500, detail=f"Detection failed: {e}") from e
|
|
|
|
|
@app.get("/distortions")
async def list_distortions():
    """Return every label the detector can emit, each paired with its description."""
    catalogue = []
    for name in id2label.values():
        catalogue.append({"name": name, "description": DESCRIPTIONS[name]})
    return {"distortions": catalogue}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def predict_gradio(text: str, threshold: float = 0.5):
    """Run detection for the Gradio UI and format the result for display.

    Returns:
        Tuple ``(summary_markdown, details_html)``. Empty or whitespace-only
        input yields a prompt message and an empty HTML string.
    """
    if not text.strip():
        return "Please enter some text to analyze.", ""

    result = detect_distortions(text, threshold)

    if not result["distortions"]:
        message = "✅ No significant cognitive distortions detected!"
        html = (
            "<div style='background: #c8e6c9; padding: 20px; border-radius: 8px; "
            "border-left: 4px solid #4caf50;'>"
            f"<h3 style='color: #2e7d32; margin: 0;'>{message}</h3></div>"
        )
        return message, html

    summary_lines = []
    cards = []
    for d in result["distortions"]:
        percentage = d["confidence"] * 100
        title = d["distortion"].replace("_", " ").title()
        summary_lines.append(f"⚠️ **{title}** ({percentage:.1f}%)")

        # Card colour deepens with confidence: >70% red, >50% salmon, else amber.
        color = "#ff6b6b" if percentage > 70 else "#ffa07a" if percentage > 50 else "#ffcc80"
        cards.append(
            f"<div style='background: {color}; padding: 15px; margin: 10px 0; "
            "border-radius: 8px; border-left: 4px solid #d32f2f;'>"
            f"<h3 style='margin: 0 0 5px 0; color: #1a1a1a;'>🚨 {title}</h3>"
            f"<p style='margin: 5px 0; color: #333; font-size: 14px;'>{d['description']}</p>"
            f"<p style='margin: 5px 0; font-weight: bold; color: #1a1a1a;'>Confidence: {percentage:.1f}%</p>"
            "</div>"
        )

    # gr.Markdown merges single newlines into one paragraph, so separate the
    # entries with blank lines to render one distortion per line.
    summary = "\n\n".join(summary_lines)
    html = "<div style='margin-top: 20px;'>" + "".join(cards) + "</div>"
    return summary, html
|
|
|
|
|
|
|
|
# Sample inputs for the UI as [text, threshold] rows, all at the default cutoff.
examples = [
    [sample, 0.5]
    for sample in (
        "I always mess everything up. This is a disaster!",
        "Everyone thinks I'm incompetent. I'll never succeed.",
        "It's all my fault that the project failed.",
        "They must think I'm stupid for asking that question.",
        "I'm having a challenging day, but I can work through it.",
    )
]
|
|
|
|
|
# Gradio web interface: text input + threshold slider, a Markdown summary,
# per-distortion HTML cards, clickable examples and a short reference section.
with gr.Blocks(theme=gr.themes.Soft(), title="CBT Distortion Detector") as demo:
    gr.Markdown(
        """
        # 🧠 CBT Cognitive Distortion Detector

        This tool analyzes text for common cognitive distortions based on Cognitive Behavioral Therapy (CBT) principles.
        It detects patterns like overgeneralization, catastrophizing, black-and-white thinking, self-blame, and mind-reading.

        **API Available**: Use the `/detect` endpoint for programmatic access. [API Docs](/docs)

        **Note**: This is an educational tool and should not replace professional mental health support.
        """
    )

    with gr.Row():
        with gr.Column(scale=2):
            text_input = gr.Textbox(
                label="Enter your text",
                placeholder="Type or paste text here to analyze for cognitive distortions...",
                lines=5,
            )
            threshold_slider = gr.Slider(
                minimum=0.1,
                maximum=0.9,
                value=0.5,
                step=0.05,
                label="Detection Threshold (higher = stricter)",
                info="Adjust sensitivity of detection",
            )
            analyze_btn = gr.Button("🔍 Analyze Text", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 📊 Detected Distortions")
            summary_output = gr.Markdown(label="Summary")

    with gr.Row():
        detailed_output = gr.HTML(label="Detailed Results")

    gr.Markdown("### 💡 Try These Examples")
    gr.Examples(
        examples=examples,
        inputs=[text_input, threshold_slider],
        outputs=[summary_output, detailed_output],
        fn=predict_gradio,
        cache_examples=False,
    )

    # f-string so the model link always matches MODEL_NAME (the hand-typed
    # link was misspelled); literal braces in the JSON example are doubled.
    gr.Markdown(
        f"""
        ---
        ### 📚 About Cognitive Distortions

        - **Overgeneralization**: Drawing broad conclusions from limited evidence
        - **Catastrophizing**: Expecting the worst-case scenario
        - **Black & White Thinking**: Viewing situations in extremes with no middle ground
        - **Self-Blame**: Taking responsibility for things beyond your control
        - **Mind Reading**: Assuming you know what others think without evidence

        ### 🔌 API Usage

        ```python
        import requests

        response = requests.post("https://your-space.hf.space/detect",
                                 json={{"text": "I always mess everything up", "threshold": 0.5}})
        print(response.json())
        ```

        **Model**: [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME})
        """
    )

    analyze_btn.click(
        fn=predict_gradio,
        inputs=[text_input, threshold_slider],
        outputs=[summary_output, detailed_output],
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Serve the Gradio UI under /ui. Starlette matches routes in registration
# order, so a mount at "/" has its index page shadowed by the GET "/"
# health-check route registered earlier. A sub-path keeps both reachable.
app = gr.mount_gradio_app(app, demo, path="/ui")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve the combined FastAPI + Gradio app on all interfaces; 7860 is the
    # conventional Gradio / Hugging Face Spaces port.
    print("\n🚀 Starting Cognitive Distortion Detector...")
    uvicorn.run(app, host="0.0.0.0", port=7860)