vanifala commited on
Commit
831ad2d
·
verified ·
1 Parent(s): e808998

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -13,12 +13,9 @@ app = FastAPI()
13
 
14
 
15
  class EmbedRequest(BaseModel):
16
- text: str | list[str]
17
- normalize: bool = True
18
-
19
-
20
- class BatchEmbedRequest(BaseModel):
21
- texts: list[str]
22
  normalize: bool = True
23
 
24
 
@@ -29,20 +26,34 @@ def health():
29
 
30
  @app.post("/embed")
31
  def embed(req: EmbedRequest):
32
- texts = [req.text] if isinstance(req.text, str) else req.text
33
- embeddings = model.encode(texts, normalize_embeddings=req.normalize)
 
 
 
 
 
 
34
  return {
35
  "embeddings": embeddings.tolist(),
36
  "model": MODEL_NAME,
37
  "dimensions": embeddings.shape[1],
 
38
  }
39
 
40
 
41
  @app.post("/embed_batch")
42
- def embed_batch(req: BatchEmbedRequest):
43
- embeddings = model.encode(req.texts, normalize_embeddings=req.normalize)
 
 
 
 
 
 
44
  return {
45
  "embeddings": embeddings.tolist(),
46
  "model": MODEL_NAME,
47
  "dimensions": embeddings.shape[1],
 
48
  }
 
13
 
14
 
15
  class EmbedRequest(BaseModel):
16
+ text: str | list[str] | None = None
17
+ texts: list[str] | None = None
18
+ model: str | None = None
 
 
 
19
  normalize: bool = True
20
 
21
 
 
26
 
27
  @app.post("/embed")
28
  def embed(req: EmbedRequest):
29
+ # Accept both "text" (single/list) and "texts" (list) fields
30
+ if req.texts:
31
+ input_texts = req.texts
32
+ elif req.text:
33
+ input_texts = [req.text] if isinstance(req.text, str) else req.text
34
+ else:
35
+ return {"error": "Provide 'text' or 'texts' field"}, 400
36
+ embeddings = model.encode(input_texts, normalize_embeddings=req.normalize)
37
  return {
38
  "embeddings": embeddings.tolist(),
39
  "model": MODEL_NAME,
40
  "dimensions": embeddings.shape[1],
41
+ "tokens": len(input_texts) * 32,
42
  }
43
 
44
 
45
  @app.post("/embed_batch")
46
+ def embed_batch(req: EmbedRequest):
47
+ if req.texts:
48
+ input_texts = req.texts
49
+ elif req.text:
50
+ input_texts = [req.text] if isinstance(req.text, str) else req.text
51
+ else:
52
+ return {"error": "Provide 'text' or 'texts' field"}, 400
53
+ embeddings = model.encode(input_texts, normalize_embeddings=req.normalize)
54
  return {
55
  "embeddings": embeddings.tolist(),
56
  "model": MODEL_NAME,
57
  "dimensions": embeddings.shape[1],
58
+ "tokens": len(input_texts) * 32,
59
  }