VietCat commited on
Commit
c58569a
·
1 Parent(s): b1de237

fix permission issue for cache, and remove pooling

Browse files
Files changed (2) hide show
  1. app.py +18 -11
  2. requirements.txt +2 -1
app.py CHANGED
@@ -1,20 +1,27 @@
1
- from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
 
4
 
5
  app = FastAPI()
6
 
7
- # Dùng bản hỗ trợ sentence-transformers
8
- model = SentenceTransformer("BAAI/bge-m3-v2")
 
 
9
 
10
  class InputText(BaseModel):
11
  text: str
12
 
13
- @app.post("/embed")
14
- def embed_text(data: InputText):
15
- vector = model.encode(data.text, normalize_embeddings=True)
16
- return {"embedding": vector.tolist()}
17
-
18
  @app.get("/")
19
- def read_root():
20
- return {"message": "BAAI/bge-m3 Sentence Embedding API is running."}
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
 
6
  app = FastAPI()
7
 
8
+ # Load model
9
+ model_name = "BAAI/bge-m3"
10
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
11
+ model = AutoModel.from_pretrained(model_name)
12
 
13
  class InputText(BaseModel):
14
  text: str
15
 
 
 
 
 
 
16
  @app.get("/")
17
+ def root():
18
+ return {"message": "BAAI/bge-m3 embedding API is running."}
19
+
20
+ @app.post("/embed")
21
+ def get_embedding(data: InputText):
22
+ inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
23
+ with torch.no_grad():
24
+ outputs = model(**inputs)
25
+ # Get CLS token or use pooling method
26
+ embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()
27
+ return {"embedding": embedding}
requirements.txt CHANGED
@@ -1,3 +1,4 @@
1
- sentence-transformers
 
2
  fastapi
3
  uvicorn
 
1
+ transformers==4.41.0
2
+ torch
3
  fastapi
4
  uvicorn