vanifala commited on
Commit
415bbcd
·
verified ·
1 Parent(s): 1c863bc

upgrade to nomic-v2-moe + dimensions support

Browse files
Files changed (1) hide show
  1. app.py +38 -25
app.py CHANGED
@@ -1,13 +1,15 @@
1
  """Embedding Server (sentence-transformers) for HuggingFace Spaces."""
2
  import os
 
3
  from fastapi import FastAPI
4
  from pydantic import BaseModel
5
  from sentence_transformers import SentenceTransformer
6
 
7
- MODEL_NAME = os.environ.get("MODEL_NAME", "intfloat/multilingual-e5-small")
8
  print(f"[Embedding] Loading model: {MODEL_NAME}...", flush=True)
9
- model = SentenceTransformer(MODEL_NAME)
10
- print("[Embedding] Model loaded.", flush=True)
 
11
 
12
  app = FastAPI()
13
 
@@ -17,33 +19,52 @@ class EmbedRequest(BaseModel):
17
  texts: list[str] | None = None
18
  model: str | None = None
19
  normalize: bool = True
20
- prefix: str | None = None # e5 models require "query: " or "passage: " prefix
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
 
23
  @app.get("/health")
24
  def health():
25
- return {"status": "ok", "model": MODEL_NAME}
 
 
 
 
 
26
 
27
 
28
  @app.post("/embed")
29
  def embed(req: EmbedRequest):
30
- # Accept both "text" (single/list) and "texts" (list) fields
31
  if req.texts:
32
  input_texts = req.texts
33
  elif req.text:
34
  input_texts = [req.text] if isinstance(req.text, str) else req.text
35
  else:
36
  return {"error": "Provide 'text' or 'texts' field"}, 400
37
- # Apply prefix if provided (e5 models require "query: " or "passage: ")
38
- if req.prefix:
39
- input_texts = [req.prefix + t for t in input_texts]
40
- embeddings = model.encode(input_texts, normalize_embeddings=req.normalize)
41
- return {
42
- "embeddings": embeddings.tolist(),
43
- "model": MODEL_NAME,
44
- "dimensions": embeddings.shape[1],
45
- "tokens": len(input_texts) * 32,
46
- }
47
 
48
 
49
  @app.post("/embed_batch")
@@ -54,12 +75,4 @@ def embed_batch(req: EmbedRequest):
54
  input_texts = [req.text] if isinstance(req.text, str) else req.text
55
  else:
56
  return {"error": "Provide 'text' or 'texts' field"}, 400
57
- if req.prefix:
58
- input_texts = [req.prefix + t for t in input_texts]
59
- embeddings = model.encode(input_texts, normalize_embeddings=req.normalize)
60
- return {
61
- "embeddings": embeddings.tolist(),
62
- "model": MODEL_NAME,
63
- "dimensions": embeddings.shape[1],
64
- "tokens": len(input_texts) * 32,
65
- }
 
1
  """Embedding Server (sentence-transformers) for HuggingFace Spaces."""
2
  import os
3
+ import numpy as np
4
  from fastapi import FastAPI
5
  from pydantic import BaseModel
6
  from sentence_transformers import SentenceTransformer
7
 
8
+ MODEL_NAME = os.environ.get("MODEL_NAME", "nomic-ai/nomic-embed-text-v2-moe")
9
  print(f"[Embedding] Loading model: {MODEL_NAME}...", flush=True)
10
+ model = SentenceTransformer(MODEL_NAME, trust_remote_code=True)
11
+ NATIVE_DIMS = model.get_sentence_embedding_dimension()
12
+ print(f"[Embedding] Model loaded. Native dimensions: {NATIVE_DIMS}", flush=True)
13
 
14
  app = FastAPI()
15
 
 
19
  texts: list[str] | None = None
20
  model: str | None = None
21
  normalize: bool = True
22
+ prefix: str | None = None
23
+ dimensions: int | None = None
24
+
25
+
26
+ def _process_embeddings(embeddings: np.ndarray, dimensions: int | None) -> np.ndarray:
27
+ """Truncate to target dimensions and re-normalize."""
28
+ if dimensions and dimensions < embeddings.shape[1]:
29
+ embeddings = embeddings[:, :dimensions]
30
+ norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
31
+ embeddings = embeddings / norms
32
+ return embeddings
33
+
34
+
35
+ def _encode(input_texts: list[str], req: EmbedRequest) -> dict:
36
+ if req.prefix:
37
+ input_texts = [req.prefix + t for t in input_texts]
38
+ embeddings = model.encode(input_texts, convert_to_numpy=True,
39
+ normalize_embeddings=req.normalize)
40
+ embeddings = _process_embeddings(embeddings, req.dimensions)
41
+ return {
42
+ "embeddings": embeddings.tolist(),
43
+ "model": MODEL_NAME,
44
+ "dimensions": embeddings.shape[1],
45
+ "tokens": len(input_texts) * 32,
46
+ }
47
 
48
 
49
  @app.get("/health")
50
  def health():
51
+ return {
52
+ "status": "ok",
53
+ "model": MODEL_NAME,
54
+ "model_name": MODEL_NAME,
55
+ "native_dimensions": NATIVE_DIMS,
56
+ }
57
 
58
 
59
  @app.post("/embed")
60
  def embed(req: EmbedRequest):
 
61
  if req.texts:
62
  input_texts = req.texts
63
  elif req.text:
64
  input_texts = [req.text] if isinstance(req.text, str) else req.text
65
  else:
66
  return {"error": "Provide 'text' or 'texts' field"}, 400
67
+ return _encode(input_texts, req)
 
 
 
 
 
 
 
 
 
68
 
69
 
70
  @app.post("/embed_batch")
 
75
  input_texts = [req.text] if isinstance(req.text, str) else req.text
76
  else:
77
  return {"error": "Provide 'text' or 'texts' field"}, 400
78
+ return _encode(input_texts, req)