# NOTE: the lines "Spaces: / Sleeping / Sleeping" were HuggingFace Spaces
# page-scrape residue, not code; kept here only as a comment so the file parses.
from fastapi import FastAPI
from pydantic import BaseModel
import os
import torch
from transformers import BartTokenizer, BartForSequenceClassification, pipeline

# FastAPI application exposing the BART-based text classifier.
app = FastAPI()

# Hugging Face access token; None when the HF_TOKEN env var is unset.
tokens = os.getenv("HF_TOKEN")
model_name = "iconcube/BART-large_classifier"

# FIX: the token must be supplied to from_pretrained(), which is what actually
# downloads the artifacts. The original passed token= only to pipeline(), but
# by then the model and tokenizer were already fetched without it, so a gated
# or private repo would have failed on the from_pretrained() calls.
classifier_tokenizer = BartTokenizer.from_pretrained(model_name, token=tokens)
classifier_model = BartForSequenceClassification.from_pretrained(model_name, token=tokens)

# token is no longer needed here: the model/tokenizer objects are preloaded.
classifier = pipeline(
    "text-classification",
    model=classifier_model,
    tokenizer=classifier_tokenizer,
)
class RequestText(BaseModel):
    """Request body for classification: the raw text to score."""

    # The input string forwarded to the classification pipeline.
    text: str
class ResponseLabel(BaseModel):
    """Response body: the human-readable classification label."""

    # One of "safe_response", "unsafe_response", or "error".
    label: str
# NOTE(review): no @app.post(...) decorator is visible on this function, so as
# written the route is never registered with the FastAPI app — it may have been
# stripped when this file was scraped. Confirm and restore the decorator.
async def predict(request: RequestText):
    """Classify ``request.text`` and translate the raw model label.

    Maps LABEL_0 -> "safe_response", LABEL_1 -> "unsafe_response", and any
    other label to "error", returning the result wrapped in a ResponseLabel.
    """
    # The pipeline returns a list of prediction dicts; take the top one.
    top_prediction = classifier(request.text)[0]
    raw_label = top_prediction["label"]
    # Dispatch table replaces the original if/elif chain; unknown labels
    # fall through to "error", exactly as the else branch did.
    label_map = {
        "LABEL_0": "safe_response",
        "LABEL_1": "unsafe_response",
    }
    return ResponseLabel(label=label_map.get(raw_label, "error"))