# GokulRajaR's picture
# Update server.py
# 97f5db8 verified
import os
import secrets

import litserve as ls
from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline
class ToxicAPI(ls.LitAPI):
    """LitServe API wrapping the "martin-ha/toxic-comment-model" text classifier."""

    @staticmethod
    def _pipeline_device(device):
        """Translate the device given to ``setup`` into a pipeline device index.

        Returns -1 for "cpu", N for a "cuda:N" string, and 0 for anything else
        (the value the previous implementation used for every non-CPU device).
        """
        if device == "cpu":
            return -1
        # Assumes LitServe hands over strings such as "cuda:1" -- TODO confirm
        # against the installed litserve version.
        if isinstance(device, str) and device.startswith("cuda:"):
            index = device.partition(":")[2]
            if index.isdigit():
                return int(index)
        return 0

    def setup(self, device):
        """Load the tokenizer and model and build the classification pipeline.

        Args:
            device: Device identifier supplied by the server, e.g. "cpu".
        """
        model_name = "martin-ha/toxic-comment-model"
        # Read the Hugging Face token once; it is None when HF_TOKEN is unset.
        hf_token = os.getenv("HF_TOKEN")
        # NOTE(review): trust_remote_code=True lets the model repository run
        # arbitrary Python on this host. Kept as in the original; confirm it is
        # actually required for this model, and drop it if not.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        self.pipeline = TextClassificationPipeline(
            model=model,
            tokenizer=tokenizer,
            device=self._pipeline_device(device),  # cpu = -1, gpu N = N
        )

    def decode_request(self, request):
        """Return the incoming request payload unchanged."""
        return request

    def predict(self, query):
        """Run the classification pipeline on ``query`` and return its result."""
        return self.pipeline(query)

    def encode_response(self, output):
        """Return the pipeline output unchanged as the response payload."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match ``auth_token``.

        Raises:
            HTTPException: 401 when the scheme is not "Bearer", when the
                ``auth_token`` environment variable is unset or empty, or when
                the presented credentials do not match it.
        """
        expected = os.getenv("auth_token")
        # Fail closed when no token is configured instead of comparing to None.
        if auth.scheme != "Bearer" or not expected:
            raise HTTPException(status_code=401, detail="Bad token")
        # Compare as bytes in constant time: avoids leaking the token through
        # timing, and compare_digest rejects non-ASCII str arguments.
        presented = auth.credentials.encode("utf-8")
        if not secrets.compare_digest(presented, expected.encode("utf-8")):
            raise HTTPException(status_code=401, detail="Bad token")
if __name__ == "__main__":
    # Script entry point: serve ToxicAPI with the CPU accelerator on port 7860.
    server = ls.LitServer(ToxicAPI(), accelerator="cpu", devices="cpu")
    server.run(port=7860)