# NOTE: scraped from a Hugging Face Space file viewer; page chrome
# (Space status "Sleeping", file size 1,465 bytes, commit hashes,
# line-number gutter) stripped so the file parses as Python.
import os
import secrets

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline

import litserve as ls
class ToxicAPI(ls.LitAPI):
    """LitServe API that scores text for toxicity.

    Wraps the `martin-ha/toxic-comment-model` sequence-classification
    checkpoint in a transformers `TextClassificationPipeline` and guards
    every request with a bearer-token check.
    """

    def setup(self, device):
        """Load the tokenizer and model once at server startup.

        Args:
            device: Accelerator string resolved by LitServe ("cpu" or a
                GPU identifier).
        """
        model_name = "martin-ha/toxic-comment-model"
        # Hoisted so the env var is read once; None is fine for this
        # public checkpoint (a token is only needed for gated models).
        hf_token = os.getenv("HF_TOKEN")
        # NOTE(review): trust_remote_code=True executes model-repo code;
        # keep only if this exact repo is trusted/pinned.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=hf_token,
            trust_remote_code=True,
        )
        self.pipeline = TextClassificationPipeline(
            model=model,
            tokenizer=tokenizer,
            device=-1 if device == "cpu" else 0,  # cpu = -1, gpu = 0
        )

    def decode_request(self, request):
        """Pass the raw request payload straight through to `predict`."""
        return request

    def predict(self, query):
        """Run the classification pipeline on the query text."""
        return self.pipeline(query)

    def encode_response(self, output):
        """Return the pipeline output unchanged as the HTTP response body."""
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match `auth_token`.

        Raises:
            HTTPException: 401 when the scheme is not Bearer, the
                server-side token is unset (fail closed), or the token
                does not match.
        """
        expected = os.getenv("auth_token")
        # secrets.compare_digest gives a constant-time comparison, so the
        # check does not leak token contents via response timing (the
        # original `!=` comparison did).
        if (
            auth.scheme != "Bearer"
            or expected is None
            or not secrets.compare_digest(auth.credentials, expected)
        ):
            raise HTTPException(status_code=401, detail="Bad token")
if __name__ == "__main__":
    # CPU-only deployment; port 7860 is the Hugging Face Spaces default.
    server = ls.LitServer(ToxicAPI(), devices="cpu", accelerator="cpu")
    server.run(port=7860)