File size: 629 Bytes
025b409
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import pipeline

# ASGI application object; the /predict route below is registered on it.
app = FastAPI()

# Model identifier handed to ``transformers.pipeline`` (a Hub repo id).
_TONE_MODEL_ID = "LokeshDevCreates/tone-baseline-v3"

# Built once at import time and reused by every request handled below.
# top_k=None asks the pipeline for the score of every label, not just the
# single best one.
classifier = pipeline(
    task="text-classification",
    model=_TONE_MODEL_ID,
    top_k=None,
)

class TextRequest(BaseModel):
    """JSON request body accepted by ``POST /predict``.

    Example payload: ``{"text": "I can't believe you did that!"}``
    """

    # Raw text whose tone should be classified.
    text: str

def _flatten_scores(raw):
    """Return the flat list of ``{"label": ..., "score": ...}`` dicts for one text.

    For a single string input, the text-classification pipeline may return
    the per-label scores wrapped in an extra list (``[[{...}, ...]]``) or
    flat (``[{...}, ...]``), depending on how ``top_k`` is supplied and on
    the transformers version. Accept both shapes so the endpoint does not
    break when either changes.
    """
    if raw and isinstance(raw[0], list):
        return raw[0]
    return raw


@app.post("/predict")
def predict_tone(req: TextRequest):
    """Classify the tone of ``req.text``.

    Returns a dict with:
      * ``detected_tone`` - the label with the highest score,
      * ``confidence`` - that label's score rounded to 4 decimals,
      * ``all_probs`` - every label mapped to its rounded score, ordered
        from most to least likely.

    Declared with ``def`` rather than ``async def`` so FastAPI runs the
    blocking model inference in its threadpool instead of on the event loop.
    """
    # truncation=True clips inputs longer than the tokenizer's maximum length
    # instead of feeding an over-long sequence to the model, which typically
    # raises and would surface to the client as a 500.
    raw = classifier(req.text, truncation=True)
    scores = sorted(_flatten_scores(raw), key=lambda r: r["score"], reverse=True)
    best = scores[0]

    return {
        "detected_tone": best["label"],
        "confidence": round(best["score"], 4),
        "all_probs": {r["label"]: round(r["score"], 4) for r in scores},
    }