VietCat committed on
Commit
a03dcc1
·
1 Parent(s): 3eaf7a6

fix ValueError: Out of range float values are not JSON compliant

Browse files
Files changed (2) hide show
  1. app/main.py +18 -12
  2. app/model_loader.py +4 -3
app/main.py CHANGED
@@ -42,10 +42,23 @@ async def startup_event():
42
  llm = await asyncio.to_thread(load_model, model_path)
43
  logging.info("✅ Đã tải mô hình thành công.")
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  @app.post("/embed")
46
  async def embed(request: Request):
47
  """Trả về nhiều vector (mảng 2D) - phù hợp RAG"""
48
- global llm
49
  data = await request.json()
50
  text = data.get("text")
51
  if not text:
@@ -54,11 +67,7 @@ async def embed(request: Request):
54
  start_time = time.time()
55
  logging.info(f"📥 Nhận request /embed lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
56
 
57
- token_ids = llm.tokenize(text.encode("utf-8"))
58
- logging.info(f"🧩 Số token đầu vào: {len(token_ids)}")
59
-
60
- embedding = await asyncio.to_thread(llm.embed, text)
61
- logging.info(f"📊 Số vector trả về: {len(embedding)}")
62
 
63
  end_time = time.time()
64
  duration_ms = round((end_time - start_time) * 1000, 2)
@@ -66,10 +75,10 @@ async def embed(request: Request):
66
 
67
  return {"embedding": embedding}
68
 
 
69
  @app.post("/embed/mean")
70
  async def embed_mean(request: Request):
71
  """Trả về 1 vector duy nhất (mean pooling) - phù hợp semantic search"""
72
- global llm
73
  data = await request.json()
74
  text = data.get("text")
75
  if not text:
@@ -78,11 +87,7 @@ async def embed_mean(request: Request):
78
  start_time = time.time()
79
  logging.info(f"📥 Nhận request /embed/mean lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
80
 
81
- token_ids = llm.tokenize(text.encode("utf-8"))
82
- logging.info(f"🧩 Số token đầu vào: {len(token_ids)}")
83
-
84
- raw_embedding = await asyncio.to_thread(llm.embed, text)
85
- logging.info(f"📊 Số vector (trước pooling) trả về: {len(raw_embedding)}")
86
 
87
  if isinstance(raw_embedding, list) and isinstance(raw_embedding[0], list):
88
  embedding = np.mean(raw_embedding, axis=0).tolist()
@@ -97,6 +102,7 @@ async def embed_mean(request: Request):
97
 
98
  return {"embedding": embedding}
99
 
 
100
  @app.get("/")
101
  def root():
102
  return {"message": "Qwen3Embedding4BQ4KM embedding API is running."}
 
42
  llm = await asyncio.to_thread(load_model, model_path)
43
  logging.info("✅ Đã tải mô hình thành công.")
44
 
45
+
46
+ def generate_embedding(text: str) -> list:
47
+ """Gọi embedding và đảm bảo kết quả JSON-safe (không NaN/Inf)"""
48
+ global llm
49
+ token_ids = llm.tokenize(text.encode("utf-8"))
50
+ logging.info(f"🧩 Số token đầu vào: {len(token_ids)}")
51
+
52
+ raw_embedding = llm.embed(text)
53
+ logging.info(f"📊 Số vector trả về: {len(raw_embedding)}")
54
+
55
+ cleaned = np.nan_to_num(raw_embedding).tolist()
56
+ return cleaned
57
+
58
+
59
  @app.post("/embed")
60
  async def embed(request: Request):
61
  """Trả về nhiều vector (mảng 2D) - phù hợp RAG"""
 
62
  data = await request.json()
63
  text = data.get("text")
64
  if not text:
 
67
  start_time = time.time()
68
  logging.info(f"📥 Nhận request /embed lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
69
 
70
+ embedding = await asyncio.to_thread(generate_embedding, text)
 
 
 
 
71
 
72
  end_time = time.time()
73
  duration_ms = round((end_time - start_time) * 1000, 2)
 
75
 
76
  return {"embedding": embedding}
77
 
78
+
79
  @app.post("/embed/mean")
80
  async def embed_mean(request: Request):
81
  """Trả về 1 vector duy nhất (mean pooling) - phù hợp semantic search"""
 
82
  data = await request.json()
83
  text = data.get("text")
84
  if not text:
 
87
  start_time = time.time()
88
  logging.info(f"📥 Nhận request /embed/mean lúc {time.strftime('%Y-%m-%d %H:%M:%S')}")
89
 
90
+ raw_embedding = await asyncio.to_thread(generate_embedding, text)
 
 
 
 
91
 
92
  if isinstance(raw_embedding, list) and isinstance(raw_embedding[0], list):
93
  embedding = np.mean(raw_embedding, axis=0).tolist()
 
102
 
103
  return {"embedding": embedding}
104
 
105
+
106
  @app.get("/")
107
  def root():
108
  return {"message": "Qwen3Embedding4BQ4KM embedding API is running."}
app/model_loader.py CHANGED
@@ -16,11 +16,12 @@ def load_model(model_path: str):
16
 
17
  model = Llama(
18
  model_path=model_path,
19
- embedding=True, # ✅ QUAN TRỌNG: bật chế độ embedding
20
- n_ctx=1024, # đủ để xử lý hầu hết đoạn văn
21
- n_batch=64,
22
  n_threads=4,
23
  n_threads_batch=2,
 
24
  logits_all=False,
25
  use_mlock=False,
26
  verbose=False
 
16
 
17
  model = Llama(
18
  model_path=model_path,
19
+ embedding=True,
20
+ n_ctx=1024,
21
+ n_batch=16, # ✅ Giảm batch size để tránh lỗi bộ nhớ
22
  n_threads=4,
23
  n_threads_batch=2,
24
+ n_gpu_layers=0, # ✅ Chạy thuần CPU để tránh crash nếu không có GPU
25
  logits_all=False,
26
  use_mlock=False,
27
  verbose=False