"""LitServe app serving the ``martin-ha/toxic-comment-model`` text classifier.

The ``authorize`` hook requires an ``Authorization: Bearer <token>`` header
whose token matches the ``auth_token`` environment variable.
"""

import hmac
import os

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import litserve as ls
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    TextClassificationPipeline,
)


class ToxicAPI(ls.LitAPI):
    """Toxic-comment classification endpoint backed by a Hugging Face pipeline."""

    # Hugging Face Hub repository the tokenizer and model are loaded from.
    MODEL_NAME = "martin-ha/toxic-comment-model"

    def setup(self, device):
        """Load the tokenizer and model and build the classification pipeline.

        Args:
            device: Device identifier supplied by LitServe (e.g. ``"cpu"`` or
                ``"cuda:1"``). It is forwarded unchanged to the pipeline so the
                model is placed on exactly that device.
        """
        hf_token = os.getenv("HF_TOKEN")
        # NOTE(review): trust_remote_code=True allows the Hub repo to execute
        # arbitrary Python at load time. Confirm this model really needs it;
        # if it does not, remove it.
        tokenizer = AutoTokenizer.from_pretrained(
            self.MODEL_NAME,
            token=hf_token,
            trust_remote_code=True,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            self.MODEL_NAME,
            token=hf_token,
            trust_remote_code=True,
        )
        self.pipeline = TextClassificationPipeline(
            model=model,
            tokenizer=tokenizer,
            # Pass the device string straight through. Mapping every non-CPU
            # device to index 0 would put the model on cuda:0 even when
            # LitServe assigned this worker a different GPU.
            device=device,
        )

    def decode_request(self, request):
        """Return the decoded request body unchanged for ``predict``.

        NOTE(review): the body is handed to the pipeline as-is, so it
        presumably is a string or a dict like ``{"text": "..."}``. Confirm
        against clients.
        """
        return request

    def predict(self, query):
        """Run the text-classification pipeline on ``query`` and return its output."""
        return self.pipeline(query)

    def encode_response(self, output):
        """Return the pipeline output unchanged as the response payload."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match ``auth_token``.

        Raises:
            HTTPException: 401 if the scheme is not Bearer, if the
                ``auth_token`` environment variable is unset or empty, or if
                the presented token differs from it.
        """
        expected = os.getenv("auth_token")
        # Auth schemes are case-insensitive (RFC 7235), so accept "bearer" too.
        scheme_ok = auth.scheme.lower() == "bearer"
        # Fail closed when no token is configured. Compare the tokens as bytes
        # in constant time, so response timing does not reveal how much of
        # the token matched and non-ASCII input cannot raise TypeError.
        token_ok = bool(expected) and hmac.compare_digest(
            auth.credentials.encode("utf-8"), expected.encode("utf-8")
        )
        if not (scheme_ok and token_ok):
            raise HTTPException(status_code=401, detail="Bad token")


if __name__ == "__main__":
    api = ToxicAPI()
    # CPU-only deployment.
    server = ls.LitServer(api, devices="cpu", accelerator="cpu")
    server.run(port=7860)