VietCat committed on
Commit
79b3c25
·
1 Parent(s): e6e96a7

fix permission issue for cache, and add pooling

Browse files
Files changed (3) hide show
  1. Dockerfile +7 -0
  2. app.py +7 -13
  3. requirements.txt +1 -2
Dockerfile CHANGED
@@ -2,11 +2,18 @@ FROM python:3.10-slim
2
 
3
  WORKDIR /app
4
 
 
 
 
 
5
  COPY requirements.txt .
6
  RUN pip install --no-cache-dir -r requirements.txt
7
 
8
  COPY app.py .
9
 
 
 
 
10
  EXPOSE 7860
11
 
12
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
2
 
3
  WORKDIR /app
4
 
5
+ # Tạo thư mục cache riêng
6
+ ENV TRANSFORMERS_CACHE=/app/cache
7
+ ENV HF_HOME=/app/hf_home
8
+
9
  COPY requirements.txt .
10
  RUN pip install --no-cache-dir -r requirements.txt
11
 
12
  COPY app.py .
13
 
14
+ # Tạo sẵn thư mục cache (tránh lỗi lần đầu chạy)
15
+ RUN mkdir -p /app/cache && mkdir -p /app/hf_home
16
+
17
  EXPOSE 7860
18
 
19
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,27 +1,21 @@
1
- from fastapi import FastAPI, Request
2
  from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModel
4
- import torch
5
 
6
  app = FastAPI()
7
 
8
  # Load model
9
- model_name = "BAAI/bge-m3"
10
- tokenizer = AutoTokenizer.from_pretrained(model_name)
11
- model = AutoModel.from_pretrained(model_name)
12
 
13
  class InputText(BaseModel):
14
  text: str
15
 
16
  @app.get("/")
17
- def root():
18
- return {"message": "BAAI/bge-m3 embedding API is running."}
19
 
20
  @app.post("/embed")
21
  def get_embedding(data: InputText):
22
- inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
23
- with torch.no_grad():
24
- outputs = model(**inputs)
25
- # Get CLS token or use pooling method
26
- embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()
27
  return {"embedding": embedding}
 
1
+ from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from sentence_transformers import SentenceTransformer
4
+ import uvicorn
5
 
6
  app = FastAPI()
7
 
8
  # Load model
9
+ model = SentenceTransformer("BAAI/bge-m3")
 
 
10
 
11
  class InputText(BaseModel):
12
  text: str
13
 
14
  @app.get("/")
15
+ def read_root():
16
+ return {"message": "BAAI/bge-m3 Sentence Embedding API is running."}
17
 
18
  @app.post("/embed")
19
  def get_embedding(data: InputText):
20
+ embedding = model.encode(data.text, normalize_embeddings=True).tolist()
 
 
 
 
21
  return {"embedding": embedding}
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- transformers==4.41.0
2
- torch
3
  fastapi
4
  uvicorn
 
1
+ sentence-transformers
 
2
  fastapi
3
  uvicorn