GokulRajaR committed on
Commit
d6f2dc1
·
verified ·
1 Parent(s): c7cb935

Update server.py

Browse files
Files changed (1) hide show
  1. server.py +32 -32
server.py CHANGED
@@ -1,32 +1,32 @@
1
- from sentence_transformers import SentenceTransformer
2
- import litserve as ls
3
- from fastapi import Depends, HTTPException
4
- from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
- import os
6
-
7
- class EmbeddingAPI(ls.LitAPI):
8
- def setup(self, device):
9
- self.model = SentenceTransformer(
10
- 'google/embeddinggemma-300m-qat-q8_0-unquantized',
11
- device=device,
12
- trust_remote_code=True,
13
- token=os.getenv("HF_TOKEN")
14
- )
15
-
16
- def decode_request(self, request):
17
- return request
18
-
19
- def predict(self, query):
20
- return self.model.encode(query)
21
-
22
- def encode_response(self, output):
23
- return output.tolist()
24
-
25
- def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
26
- if auth.scheme != "Bearer" or auth.credentials != os.getenv("auth_token"):
27
- raise HTTPException(status_code=401, detail="Bad token")
28
-
29
- if __name__ == "__main__":
30
- api = EmbeddingAPI()
31
- server = ls.LitServer(api, devices="cpu", accelerator="cpu")
32
- server.run(port=7860)
 
1
+ from sentence_transformers import SentenceTransformer
2
+ import litserve as ls
3
+ from fastapi import Depends, HTTPException
4
+ from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
5
+ import os
6
+
7
+ class EmbeddingAPI(ls.LitAPI):
8
+ def setup(self, device):
9
+ self.model = SentenceTransformer(
10
+ 'GokulRajaR/embeddinggemma-300m-qat-q8_0-unquantized',
11
+ device=device,
12
+ trust_remote_code=True,
13
+ token=os.getenv("HF_TOKEN")
14
+ )
15
+
16
+ def decode_request(self, request):
17
+ return request
18
+
19
+ def predict(self, query):
20
+ return self.model.encode(query)
21
+
22
+ def encode_response(self, output):
23
+ return output.tolist()
24
+
25
+ def authorize(self, auth: HTTPAuthorizationCredentials = Depends(HTTPBearer())):
26
+ if auth.scheme != "Bearer" or auth.credentials != os.getenv("auth_token"):
27
+ raise HTTPException(status_code=401, detail="Bad token")
28
+
29
+ if __name__ == "__main__":
30
+ api = EmbeddingAPI()
31
+ server = ls.LitServer(api, devices="cpu", accelerator="cpu")
32
+ server.run(port=7860)