import os

# Point the Hugging Face caches at a writable path inside the container.
# These must be set before `transformers` is imported, since the library
# reads the cache locations at import time.
os.environ['HF_HOME'] = '/app/.cache'
os.environ['TRANSFORMERS_CACHE'] = '/app/.cache'
os.environ['HF_DATASETS_CACHE'] = '/app/.cache'

from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
from typing import Dict, Any
import uvicorn
INTENT_MODEL = "rasouzadev/medgo-intent-classifier"
HATE_MODEL = "unitary/multilingual-toxic-xlm-roberta"

# Load both models once at startup so each request only pays inference cost.
print("Loading pipelines (this may take a minute)...")
intent_pipe = pipeline("text-classification", model=INTENT_MODEL, truncation=True, max_length=512)
hate_pipe = pipeline("text-classification", model=HATE_MODEL, truncation=True, max_length=512, top_k=None)
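# Note on the hate pipeline's output, as this code consumes it below:
# with top_k=None, hate_pipe(text)[0] yields one {"label", "score"} dict per
# model label, e.g. (labels and scores illustrative, not from the source):
#   [{"label": "toxicity", "score": 0.02}, {"label": "insult", "score": 0.01}, ...]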
app = FastAPI(title="MedGo - Intent & Hate Detector API")


class InputText(BaseModel):
    text: str


class PredictionResponse(BaseModel):
    intent: str
    intent_score: float
    hate_label: str | None
    hate_score: float
    note: str | None
@app.get("/")
def root():
    return {
        "message": "MedGo API - Intent & Hate Detector",
        "endpoints": {
            "predict": "/predict",
            "health": "/health",
            "docs": "/docs"
        }
    }


@app.get("/health")
def health():
    return {"status": "ok"}
@app.post("/predict", response_model=PredictionResponse)
def classify(input_data: InputText) -> Dict[str, Any]:
    text = input_data.text

    # Run the hate model first; returns the scores for every label.
    hate_results = hate_pipe(text)[0]
    toxicity_result = next((r for r in hate_results if r.get("label") == "toxicity"), None)
    if toxicity_result:
        hate_score = float(toxicity_result.get("score", 0.0))
        is_toxic = hate_score >= 0.5
        hate_label = "toxic" if is_toxic else "non-toxic"
    else:
        # Fallback: no explicit "toxicity" label, so take the highest-scoring
        # label and treat it as toxic only if it is a known toxic class.
        best_result = max(hate_results, key=lambda x: x.get("score", 0.0))
        hate_score = float(best_result.get("score", 0.0))
        hate_label = best_result.get("label", "unknown")
        is_toxic = hate_score >= 0.5 and hate_label in ["toxicity", "severe_toxicity", "obscene", "threat", "insult", "identity_attack"]

    # Toxic input short-circuits and never reaches the intent classifier.
    if is_toxic:
        return {
            "intent": "HateSpeech",
            "intent_score": hate_score,
            "hate_label": hate_label,
            "hate_score": hate_score,
            "note": "flagged_by_hate_model"
        }

    intent_res = intent_pipe(text)
    intent_label = intent_res[0].get("label") if intent_res and isinstance(intent_res, list) else None
    intent_score = float(intent_res[0].get("score", 0.0)) if intent_res and isinstance(intent_res, list) else 0.0
    return {
        "intent": intent_label or "Unknown",
        "intent_score": intent_score,
        "hate_label": hate_label,
        "hate_score": hate_score,
        "note": None
    }
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
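# For reference, a minimal client call against a locally running instance
# (port 7860, matching uvicorn.run above) might look like the sketch below;
# the `requests` dependency and the sample text are assumptions, not part
# of this app:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"text": "Where is the nearest hospital?"},
#       timeout=60,
#   )
#   print(resp.json())
#   # Illustrative response shape:
#   # {"intent": "...", "intent_score": 0.97, "hate_label": "non-toxic",
#   #  "hate_score": 0.01, "note": None}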