File size: 1,465 Bytes
78ba2f6
d6f2dc1
 
 
 
 
362a175
 
d6f2dc1
362a175
 
78ba2f6
362a175
97f5db8
d6f2dc1
 
362a175
 
97f5db8
78ba2f6
 
362a175
 
 
 
 
 
d6f2dc1
 
 
 
 
362a175
d6f2dc1
 
97f5db8
362a175
d6f2dc1
 
 
 
362a175
d6f2dc1
362a175
d6f2dc1
 
362a175
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
import os
import secrets

from fastapi import Depends, HTTPException
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
import litserve as ls
from transformers import AutoModelForSequenceClassification, AutoTokenizer, TextClassificationPipeline


class ToxicAPI(ls.LitAPI):
    """LitServe API serving a HuggingFace toxic-comment classifier.

    Requests are guarded by a bearer-token check (``authorize``); the
    expected token is read from the ``auth_token`` environment variable.
    """

    def setup(self, device):
        """Load tokenizer + model once per worker and build the pipeline.

        Args:
            device: Device string supplied by LitServe (e.g. "cpu", "cuda:0").
        """
        model_name = "martin-ha/toxic-comment-model"

        # HF_TOKEN is only required for gated/private repos; None is fine otherwise.
        tokenizer = AutoTokenizer.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN"),
            trust_remote_code=True,
        )
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            token=os.getenv("HF_TOKEN"),
            trust_remote_code=True,
        )

        # transformers pipelines use -1 for CPU and a CUDA ordinal for GPU.
        # Parse the ordinal out of strings like "cuda:1" instead of always
        # forcing device 0 (the original hard-coded 0 for every non-CPU device).
        if device == "cpu":
            pipeline_device = -1
        else:
            _, _, gpu_index = str(device).partition(":")
            pipeline_device = int(gpu_index) if gpu_index.isdigit() else 0

        self.pipeline = TextClassificationPipeline(
            model=model,
            tokenizer=tokenizer,
            device=pipeline_device,
        )

    def decode_request(self, request):
        # The raw payload is forwarded to the pipeline unchanged; clients are
        # expected to POST text the pipeline can consume directly.
        return request

    def predict(self, query):
        """Run the toxicity classifier on *query* and return its raw output."""
        return self.pipeline(query)

    def encode_response(self, output):
        # Pipeline output (list of {"label", "score"} dicts) is
        # JSON-serializable as-is, so no re-encoding is needed.
        return output

    def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
        """Reject requests whose bearer token does not match ``auth_token``.

        Raises:
            HTTPException: 401 when the scheme or token is wrong or missing.
        """
        expected = os.getenv("auth_token", "")
        # compare_digest runs in constant time, so an attacker cannot recover
        # the token byte-by-byte from response-latency differences (the
        # original plain `!=` comparison leaked timing information).
        if auth.scheme != "Bearer" or not secrets.compare_digest(
            auth.credentials, expected
        ):
            raise HTTPException(status_code=401, detail="Bad token")


if __name__ == "__main__":
    api = ToxicAPI()
    server = ls.LitServer(api, devices="cpu", accelerator="cpu")
    server.run(port=7860)