GokulRajaR commited on
Commit
fdf2329
·
verified ·
1 Parent(s): 0343921

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +8 -24
server.py CHANGED
@@ -1,37 +1,22 @@
1
- from transformers import AutoTokenizer, TFAutoModelForSequenceClassification
2
  import litserve as ls
3
  from fastapi import Depends, HTTPException
4
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
  import os
6
 
7
 
8
- class BiasAPI(ls.LitAPI):
9
  def setup(self, device):
10
- model_name = "d4data/bias-detection-model"
11
-
12
- tokenizer = AutoTokenizer.from_pretrained(
13
- model_name,
14
- token=os.getenv("HF_TOKEN"),
15
- trust_remote_code=True,
16
- )
17
- model = TFAutoModelForSequenceClassification.from_pretrained(
18
- model_name,
19
- token=os.getenv("HF_TOKEN"),
20
- trust_remote_code=True,
21
- )
22
-
23
- self.pipeline = TextClassificationPipeline(
24
- 'text-classification',
25
- model=model,
26
- tokenizer=tokenizer,
27
- device=-1 if device == "cpu" else 0, # cpu = -1, gpu = 0
28
- )
29
 
30
  def decode_request(self, request):
31
  return request
32
 
33
  def predict(self, query):
34
- return self.pipeline(query)
 
 
35
 
36
  def encode_response(self, output):
37
  return output
@@ -42,7 +27,6 @@ class BiasAPI(ls.LitAPI):
42
 
43
 
44
  if __name__ == "__main__":
45
- api = BiasAPI()
46
  server = ls.LitServer(api, devices="cpu", accelerator="cpu")
47
  server.run(port=7860)
48
-
 
1
+ from transformers import pipeline
2
  import litserve as ls
3
  from fastapi import Depends, HTTPException
4
  from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
  import os
6
 
7
 
8
+ class TopicBanningAPI(ls.LitAPI):
9
  def setup(self, device):
10
+ model_name = "facebook/bart-large-mnli"
11
+ self.pipeline = pipeline("zero-shot-classification", model=model, token=os.getenv("HF_TOKEN"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
  def decode_request(self, request):
14
  return request
15
 
16
  def predict(self, query):
17
+ text = query["text"]
18
+ topics = query["topics"]
19
+ return self.pipeline(text, topics)
20
 
21
  def encode_response(self, output):
22
  return output
 
27
 
28
 
29
  if __name__ == "__main__":
30
+ api = TopicBanningAPI()
31
  server = ls.LitServer(api, devices="cpu", accelerator="cpu")
32
  server.run(port=7860)