# Topic_Banning / server.py
# Author: GokulRajaR
# Last commit: 0c25e5d "Update server.py" (verified)
import os
import secrets

import litserve as ls
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from transformers import pipeline
class TopicBanningAPI(ls.LitAPI):
    """LitServe API that scores a text against caller-supplied topics.

    Wraps a Hugging Face ``zero-shot-classification`` pipeline. Each request is
    a JSON object with ``text`` (the string to classify) and ``topics`` (the
    candidate labels); the response is the pipeline's output, unchanged.
    """

    # Hub model id loaded in ``setup``; override in a subclass to serve another model.
    model_name = "facebook/bart-large-mnli"

    def setup(self, device):
        """Load the zero-shot pipeline onto the device LitServe assigned to this worker."""
        self.pipeline = pipeline(
            "zero-shot-classification",
            model=self.model_name,
            device=device,
            token=os.getenv("HF_TOKEN"),  # optional Hub token taken from the environment
        )

    def decode_request(self, request):
        """Validate the JSON body and pass it through unchanged to ``predict``.

        Expected shape: ``{"text": str, "topics": list[str] | str,
        "multi_label": bool (optional)}``. ``topics`` may also be a single
        comma-separated string, which the zero-shot pipeline accepts as-is.

        Raises:
            HTTPException: 400 for a malformed body, so bad input is reported to
                the client instead of failing as a KeyError/ValueError (HTTP 500)
                inside ``predict``.
        """
        if not isinstance(request, dict):
            raise HTTPException(status_code=400, detail="Request body must be a JSON object")
        if not isinstance(request.get("text"), str):
            raise HTTPException(status_code=400, detail="'text' must be a string")
        topics = request.get("topics")
        valid_topics = (isinstance(topics, str) and bool(topics.strip())) or (
            isinstance(topics, list)
            and bool(topics)
            and all(isinstance(topic, str) for topic in topics)
        )
        if not valid_topics:
            # The pipeline refuses an empty label set, so reject it up front.
            raise HTTPException(
                status_code=400,
                detail="'topics' must be a non-empty list of strings or a comma-separated string",
            )
        if not isinstance(request.get("multi_label", False), bool):
            raise HTTPException(status_code=400, detail="'multi_label' must be a boolean")
        return request

    def predict(self, query):
        """Classify ``query["text"]`` against ``query["topics"]``.

        ``query["multi_label"]`` (optional, default ``False``, the pipeline's own
        default) lets callers score each topic independently instead of
        normalizing the scores across all topics.
        """
        return self.pipeline(
            query["text"],
            candidate_labels=query["topics"],
            multi_label=query.get("multi_label", False),
        )

    def encode_response(self, output):
        """Return the pipeline output unchanged as the JSON response."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match the ``auth_token`` env var.

        Raises:
            HTTPException: 401 if no ``auth_token`` is configured, the scheme is
                not Bearer, or the supplied token does not match.
        """
        expected = os.getenv("auth_token")
        # Fail closed when no token is configured (compare_digest cannot take None).
        # Auth schemes are case-insensitive (RFC 7235), so "bearer" is accepted too.
        # compare_digest runs in constant time, so response timing does not reveal
        # how much of the token prefix matched. Comparing bytes avoids its TypeError
        # on non-ASCII str input.
        if (
            not expected
            or auth.scheme.lower() != "bearer"
            or not secrets.compare_digest(auth.credentials.encode(), expected.encode())
        ):
            raise HTTPException(status_code=401, detail="Bad token")
if __name__ == "__main__":
    api = TopicBanningAPI()
    # Force CPU inference. `devices` is a device count or "auto", not a device
    # type, so leave it at "auto" and let `accelerator` select the CPU.
    server = ls.LitServer(api, accelerator="cpu", devices="auto")
    # Default 7860 kept from the original (presumably the Hugging Face Spaces
    # app port); PORT lets other deployments override it.
    server.run(port=int(os.getenv("PORT", "7860")))