GokulRajaR commited on
Commit
362a175
·
verified ·
1 Parent(s): 78ba2f6

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +21 -13
server.py CHANGED
@@ -4,36 +4,44 @@ from fastapi import Depends, HTTPException
4
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
  import os
6
 
7
- class EmbeddingAPI(ls.LitAPI):
 
8
  def setup(self, device):
 
 
9
  tokenizer = AutoTokenizer.from_pretrained(
10
- "martin-ha/toxic-comment-model",
11
- device=device,
12
  trust_remote_code=True,
13
- token=os.getenv("HF_TOKEN")
14
  )
15
- model = AutoTokenizer.from_pretrained(
16
- "martin-ha/toxic-comment-model",
17
- device=device,
18
  trust_remote_code=True,
19
- token=os.getenv("HF_TOKEN")
20
  )
21
- self.pipeline = TextClassificationPipeline(model=model, tokenizer=tokenizer)
 
 
 
 
 
22
 
23
  def decode_request(self, request):
24
  return request
25
 
26
  def predict(self, query):
27
- return self.pipeline(query)
28
 
29
  def encode_response(self, output):
30
- return output.tolist()
31
-
32
  def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
33
  if auth.scheme != "Bearer" or auth.credentials != os.getenv("auth_token"):
34
  raise HTTPException(status_code=401, detail="Bad token")
35
 
 
36
  if __name__ == "__main__":
37
- api = EmbeddingAPI()
38
  server = ls.LitServer(api, devices="cpu", accelerator="cpu")
39
  server.run(port=7860)
 
 
4
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
  import os
6
 
7
+
8
+ class ToxicAPI(ls.LitAPI):
9
  def setup(self, device):
10
+ model_name = "martin-ha/toxic-comment-model"
11
+
12
  tokenizer = AutoTokenizer.from_pretrained(
13
+ model_name,
14
+ use_auth_token=os.getenv("HF_TOKEN"),
15
  trust_remote_code=True,
 
16
  )
17
+ model = AutoModelForSequenceClassification.from_pretrained(
18
+ model_name,
19
+ use_auth_token=os.getenv("HF_TOKEN"),
20
  trust_remote_code=True,
 
21
  )
22
+
23
+ self.pipeline = TextClassificationPipeline(
24
+ model=model,
25
+ tokenizer=tokenizer,
26
+ device=-1 if device == "cpu" else 0, # cpu = -1, gpu = 0
27
+ )
28
 
29
  def decode_request(self, request):
30
  return request
31
 
32
  def predict(self, query):
33
+ return self.pipeline(query)
34
 
35
  def encode_response(self, output):
36
+ return output # already a JSON-serializable list of dicts
37
+
38
  def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
39
  if auth.scheme != "Bearer" or auth.credentials != os.getenv("auth_token"):
40
  raise HTTPException(status_code=401, detail="Bad token")
41
 
42
+
43
  if __name__ == "__main__":
44
+ api = ToxicAPI()
45
  server = ls.LitServer(api, devices="cpu", accelerator="cpu")
46
  server.run(port=7860)
47
+