rasouzadev's picture
Update app.py
3f6515d verified
import os
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from typing import Dict, Any
import uvicorn
# Cache location for everything downloaded from the Hugging Face Hub.
_CACHE_DIR = '/app/.cache'
os.environ['HF_HOME'] = _CACHE_DIR
os.environ['TRANSFORMERS_CACHE'] = _CACHE_DIR
os.environ['HF_DATASETS_CACHE'] = _CACHE_DIR
# NOTE: `transformers` is already imported at the top of this file, and the
# Hugging Face libraries resolve these cache variables when first imported,
# so the assignments above may not redirect downloads made by this process.
# `cache_dir` is therefore also passed explicitly to each pipeline below.

INTENT_MODEL = "rasouzadev/medgo-intent-classifier"
HATE_MODEL = "unitary/multilingual-toxic-xlm-roberta"

print("Loading pipelines (this may take a minute)...")
# truncation/max_length: inputs are cut to at most 512 tokens before inference.
intent_pipe = pipeline(
    "text-classification",
    model=INTENT_MODEL,
    truncation=True,
    max_length=512,
    model_kwargs={"cache_dir": _CACHE_DIR},
)
# top_k=None: return the score of every label, not just the top one.
hate_pipe = pipeline(
    "text-classification",
    model=HATE_MODEL,
    truncation=True,
    max_length=512,
    top_k=None,
    model_kwargs={"cache_dir": _CACHE_DIR},
)

app = FastAPI(title="MedGo - Intent & Hate Detector API")
# Request body accepted by POST /predict.
class InputText(BaseModel):
    text: str  # raw text to screen for hate speech and classify by intent
# Response body returned by POST /predict.
class PredictionResponse(BaseModel):
    intent: str  # intent label, "HateSpeech" when flagged, or "Unknown"
    intent_score: float  # confidence attached to `intent`
    hate_label: str | None  # label derived from the hate model's output
    hate_score: float  # score derived from the hate model's output
    note: str | None  # optional annotation (e.g. "flagged_by_hate_model")
@app.get("/")
def root():
    # Landing route: service name plus a map of the available paths.
    paths = {
        "predict": "/predict",
        "health": "/health",
        "docs": "/docs",
    }
    banner = "MedGo API - Intent & Hate Detector"
    return {"message": banner, "endpoints": paths}
@app.get("/health")
def health():
    # Liveness check: answers with a constant payload without calling either pipeline.
    return dict(status="ok")
# Labels from the hate model that flag the text when they are the top prediction.
_TOXIC_LABELS = frozenset(
    {"toxicity", "severe_toxicity", "obscene", "threat", "insult", "identity_attack"}
)
# Minimum score at which a toxic label flags the text.
_TOXIC_THRESHOLD = 0.5


def _flatten_scores(raw: Any) -> list:
    """Return the list of ``{"label", "score"}`` dicts from a pipeline call.

    For a single string input the text-classification pipeline may wrap the
    per-label list in an extra list (``[[{...}, ...]]``) or return it flat
    (``[{...}, ...]``); both shapes are accepted. Anything else yields ``[]``.
    """
    if not isinstance(raw, list) or not raw:
        return []
    first = raw[0]
    if isinstance(first, list):
        return first
    if isinstance(first, dict):
        return raw
    return []


def _assess_hate(results: list) -> tuple[str | None, float, bool]:
    """Return ``(hate_label, hate_score, is_toxic)`` for the hate model's scores.

    If a ``toxicity`` label is present, its score decides the outcome and the
    label is reported as ``toxic`` / ``non-toxic``. Otherwise the
    highest-scoring label is reported; it flags the text only if it is in
    ``_TOXIC_LABELS`` and scores at least ``_TOXIC_THRESHOLD``. An empty result
    list yields ``(None, 0.0, False)`` instead of raising.
    """
    toxicity = next((r for r in results if r.get("label") == "toxicity"), None)
    if toxicity is not None:
        score = float(toxicity.get("score", 0.0))
        is_toxic = score >= _TOXIC_THRESHOLD
        return ("toxic" if is_toxic else "non-toxic"), score, is_toxic

    if not results:
        # Nothing to judge; previously max() raised ValueError here (HTTP 500).
        return None, 0.0, False

    best = max(results, key=lambda r: r.get("score", 0.0))
    score = float(best.get("score", 0.0))
    label = best.get("label", "unknown")
    return label, score, score >= _TOXIC_THRESHOLD and label in _TOXIC_LABELS


@app.post("/predict", response_model=PredictionResponse)
def classify(input_data: InputText) -> Dict[str, Any]:
    """Screen text for hate speech, then classify its intent.

    If the hate model flags the text, the intent is reported as ``HateSpeech``
    and the intent classifier is not run.
    """
    text = input_data.text

    # hate_pipe was built with top_k=None, so it yields a score for every label.
    hate_label, hate_score, is_toxic = _assess_hate(_flatten_scores(hate_pipe(text)))
    if is_toxic:
        # For flagged text the hate score doubles as the intent confidence.
        return {
            "intent": "HateSpeech",
            "intent_score": hate_score,
            "hate_label": hate_label,
            "hate_score": hate_score,
            "note": "flagged_by_hate_model",
        }

    intent_results = _flatten_scores(intent_pipe(text))
    top_intent = intent_results[0] if intent_results else {}
    return {
        "intent": top_intent.get("label") or "Unknown",
        "intent_score": float(top_intent.get("score", 0.0)),
        "hate_label": hate_label,
        "hate_score": hate_score,
        "note": None,
    }
if __name__ == "__main__":
    # Serve the API with uvicorn when this module is executed directly,
    # listening on all interfaces at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")