VietCat committed on
Commit
c09c72a
·
1 Parent(s): a692f28

split into threads

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  import time
6
  import logging
7
  from datetime import datetime
 
8
 
9
  # Cấu hình logging
10
  logging.basicConfig(
@@ -12,6 +13,9 @@ logging.basicConfig(
12
  level=logging.INFO
13
  )
14
 
 
 
 
15
  app = FastAPI()
16
 
17
  # Load model
@@ -19,6 +23,8 @@ model_name = "AITeamVN/Vietnamese_Embedding_v2"
19
  logging.info(f"Loading model: {model_name}")
20
  tokenizer = AutoTokenizer.from_pretrained(model_name)
21
  model = AutoModel.from_pretrained(model_name)
 
 
22
  logging.info("Model loaded successfully.")
23
 
24
  class InputText(BaseModel):
@@ -27,26 +33,31 @@ class InputText(BaseModel):
27
  @app.get("/")
28
  def root():
29
  now = datetime.now().isoformat()
30
- logging.info(f"[GET /] Received health check at {now}")
31
- return {"message": "AITeamVN/Vietnamese_Embedding_v2 embedding API is running."}
32
 
33
- @app.post("/embed")
34
- def get_embedding(data: InputText):
35
  start_time = time.time()
36
  start_ts = datetime.now().isoformat()
37
 
38
- # Tokenize input
39
- inputs = tokenizer(data.text, return_tensors="pt", padding=True, truncation=True)
40
- input_token_count = inputs["input_ids"].shape[1]
41
- logging.info(f"[POST /embed] Start at {start_ts} | Input text: '{data.text[:50]}'... | Tokens: {input_token_count}")
42
 
43
- # Run model inference
44
  with torch.no_grad():
45
  outputs = model(**inputs)
46
  embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()
47
 
48
  end_ts = datetime.now().isoformat()
49
  duration_ms = (time.time() - start_time) * 1000
50
- logging.info(f"[POST /embed] Done at {end_ts} | Embedding size: {len(embedding)} | Time: {duration_ms:.2f} ms")
51
 
 
 
 
 
 
 
52
  return {"embedding": embedding}
 
5
  import time
6
  import logging
7
  from datetime import datetime
8
+ from concurrent.futures import ThreadPoolExecutor
9
 
10
  # Configure logging
11
  logging.basicConfig(
 
13
  level=logging.INFO
14
  )
15
 
16
+ # Limit the pool to a single worker thread so the free HFS CPU is not overloaded
17
+ executor = ThreadPoolExecutor(max_workers=1)
18
+
19
  app = FastAPI()
20
 
21
  # Load model
 
23
  logging.info(f"Loading model: {model_name}")
24
  tokenizer = AutoTokenizer.from_pretrained(model_name)
25
  model = AutoModel.from_pretrained(model_name)
26
+ model.eval()
27
+ torch.set_num_threads(1)
28
  logging.info("Model loaded successfully.")
29
 
30
  class InputText(BaseModel):
 
@app.get("/")
def root():
    """Health check: log the current local time and return a status message."""
    timestamp = datetime.now().isoformat()
    logging.info(f"[GET /] Health check at {timestamp}")
    status = {"message": "Vietnamese Embedding API is running."}
    return status
38
 
39
def compute_embedding(text: str) -> list:
    """Compute the embedding of *text*; kept separate from the HTTP handler.

    The text is tokenized (with padding and truncation) by the module-level
    ``tokenizer``, passed through the module-level ``model`` under
    ``torch.no_grad()``, and the hidden state at sequence position 0 of
    ``last_hidden_state`` is squeezed and converted to a plain Python list.
    Start and end timestamps, the token count, the embedding size and the
    elapsed time are logged at INFO level.

    Args:
        text: The input string to embed. Only its first 50 characters are
            written to the log.

    Returns:
        The position-0 hidden state as a (nested) Python list.
    """
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments the way a time.time() difference can.
    start_time = time.perf_counter()
    start_ts = datetime.now().isoformat()

    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
    # Second dimension of input_ids is the (padded/truncated) sequence length.
    token_count = inputs["input_ids"].shape[1]

    logging.info(
        "[EMBED] Start: %s | Input: '%s'... | Tokens: %s",
        start_ts,
        text[:50],
        token_count,
    )

    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
        embedding = outputs.last_hidden_state[:, 0, :].squeeze().tolist()

    end_ts = datetime.now().isoformat()
    duration_ms = (time.perf_counter() - start_time) * 1000
    logging.info(
        "[EMBED] Done: %s | Embedding size: %s | Time: %.2f ms",
        end_ts,
        len(embedding),
        duration_ms,
    )

    return embedding
58
+
59
@app.post("/embed")
def get_embedding(data: InputText):
    """Return the embedding of ``data.text`` as ``{"embedding": ...}``.

    The computation is submitted to the module-level ``executor`` and this
    handler blocks until the submitted call has finished.
    """
    # Hand the work to the thread pool; result() waits until it is done.
    future = executor.submit(compute_embedding, data.text)
    return {"embedding": future.result()}