"""FastAPI service exposing a BART sequence classifier as a safe/unsafe labeler.

POST /predict with a JSON body ``{"text": "..."}`` returns
``{"label": "safe_response"}`` for model label ``LABEL_0``,
``{"label": "unsafe_response"}`` for ``LABEL_1``, and
``{"label": "error"}`` for any other label the model emits.
"""

import os

import torch  # noqa: F401  # NOTE(review): not referenced in this module; kept in case something relies on the import
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import BartForSequenceClassification, BartTokenizer, pipeline

app = FastAPI()

# Hugging Face access token (None when HF_TOKEN is unset). It must be passed to
# every Hub download: the pipeline below receives already-loaded objects, so
# its own ``token`` argument does not cover the tokenizer/model downloads.
tokens = os.getenv("HF_TOKEN")

model_name = "iconcube/BART-large_classifier"

classifier_tokenizer = BartTokenizer.from_pretrained(model_name, token=tokens)
classifier_model = BartForSequenceClassification.from_pretrained(
    model_name, token=tokens
)
classifier = pipeline(
    "text-classification",
    model=classifier_model,
    tokenizer=classifier_tokenizer,
    token=tokens,
)

# Raw model label -> label returned by the API; unknown labels map to "error".
_LABEL_TO_MESSAGE = {
    "LABEL_0": "safe_response",
    "LABEL_1": "unsafe_response",
}


class RequestText(BaseModel):
    """Request body for ``/predict``: the text to classify."""

    text: str


class ResponseLabel(BaseModel):
    """Response body for ``/predict``: the mapped classification label."""

    label: str


@app.post("/predict", response_model=ResponseLabel)
def predict(request: RequestText) -> ResponseLabel:
    """Classify ``request.text`` and return the mapped label.

    Declared as a plain ``def`` rather than ``async def``: the pipeline call is
    blocking inference work, and FastAPI runs sync endpoints in a worker
    threadpool instead of on the event loop, so one inference no longer
    stalls all other in-flight requests.

    Returns:
        ``ResponseLabel`` with ``"safe_response"``, ``"unsafe_response"``, or
        ``"error"`` when the model emits a label not in the mapping.
    """
    # The pipeline returns a list with one dict per input; we send one input.
    result = classifier(request.text)[0]
    return ResponseLabel(label=_LABEL_TO_MESSAGE.get(result["label"], "error"))