File size: 629 Bytes
025b409 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 |
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline
# ASGI application object exposing the HTTP endpoints defined below.
app = FastAPI()

# Hugging Face Hub id of the tone classification model.
MODEL_ID = "LokeshDevCreates/tone-baseline-v3"

# Text-classification pipeline, built once at import time so every request
# reuses the same loaded model. ``top_k=None`` asks the pipeline for the
# score of every label rather than only the best one.
classifier = pipeline(
    task="text-classification",
    model=MODEL_ID,
    top_k=None,
)
class TextRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``."""

    # Raw text whose tone should be classified.
    text: str
@app.post("/predict")
def predict_tone(req: TextRequest):
    """Classify the tone of ``req.text`` and report every label's score.

    Returns a JSON object with:
      - ``detected_tone``: the label with the highest score.
      - ``confidence``: that label's score, rounded to 4 decimals.
      - ``all_probs``: every label mapped to its score (rounded to 4
        decimals), inserted in descending score order.
    """
    # ``truncation=True`` is forwarded to the tokenizer so that text longer
    # than the model's maximum sequence length is cut down instead of failing
    # inside the model (which would surface to the client as an HTTP 500).
    # The pipeline yields one list of {"label", "score"} dicts per input;
    # ``[0]`` selects the list for this single input.
    scores = classifier(req.text, truncation=True)[0]

    # Sort explicitly rather than relying on the pipeline's output order.
    ranked = sorted(scores, key=lambda r: r["score"], reverse=True)
    best = ranked[0]
    return {
        "detected_tone": best["label"],
        "confidence": round(best["score"], 4),
        "all_probs": {r["label"]: round(r["score"], 4) for r in ranked},
    }
|