# kiko-embedding / app.py
# Last commit 415bbcd (verified) by vanifala:
#   "upgrade to nomic-v2-moe + dimensions support"
"""Embedding Server (sentence-transformers) for HuggingFace Spaces."""
import os

import numpy as np
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer
# Which checkpoint to serve; override with the MODEL_NAME environment variable.
MODEL_NAME = os.getenv("MODEL_NAME", "nomic-ai/nomic-embed-text-v2-moe")

print(f"[Embedding] Loading model: {MODEL_NAME}...", flush=True)
# trust_remote_code=True lets the model repository execute its own Python code
# on load. NOTE(review): MODEL_NAME comes from the environment, so only point it
# at repositories you trust.
model = SentenceTransformer(
    MODEL_NAME,
    trust_remote_code=True,
)
# Full output width of the loaded model; reported by /health.
NATIVE_DIMS = model.get_sentence_embedding_dimension()
print(f"[Embedding] Model loaded. Native dimensions: {NATIVE_DIMS}", flush=True)

# ASGI application; routes are attached below with @app decorators.
app = FastAPI()
class EmbedRequest(BaseModel):
    """JSON body accepted by ``/embed`` and ``/embed_batch``.

    Supply ``texts`` (a list) or ``text`` (a string or a list). When
    ``texts`` is non-empty it takes precedence over ``text``.
    """

    # A single input string, or a list of strings.
    text: str | list[str] | None = None
    # List of input strings; used in preference to ``text`` when non-empty.
    texts: list[str] | None = None
    # Accepted but not read by the handlers in this file; responses always
    # report the server's own MODEL_NAME.
    model: str | None = None
    # Forwarded to SentenceTransformer.encode(normalize_embeddings=...).
    normalize: bool = True
    # Prepended verbatim to every input text (presumably a task prefix such as
    # nomic's "search_query: " -- confirm against the model card).
    prefix: str | None = None
    # Keep only this many leading vector components; None/0, or a value >= the
    # model's output width, leaves the vectors at full width.
    dimensions: int | None = None
def _process_embeddings(embeddings: np.ndarray, dimensions: int | None) -> np.ndarray:
"""Truncate to target dimensions and re-normalize."""
if dimensions and dimensions < embeddings.shape[1]:
embeddings = embeddings[:, :dimensions]
norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
embeddings = embeddings / norms
return embeddings
def _encode(input_texts: list[str], req: EmbedRequest) -> dict:
if req.prefix:
input_texts = [req.prefix + t for t in input_texts]
embeddings = model.encode(input_texts, convert_to_numpy=True,
normalize_embeddings=req.normalize)
embeddings = _process_embeddings(embeddings, req.dimensions)
return {
"embeddings": embeddings.tolist(),
"model": MODEL_NAME,
"dimensions": embeddings.shape[1],
"tokens": len(input_texts) * 32,
}
@app.get("/health")
def health():
    """Liveness probe that also reports the served model and its full width."""
    payload = {"status": "ok"}
    # Both keys carry the same value; key insertion order matches the response.
    payload["model"] = payload["model_name"] = MODEL_NAME
    payload["native_dimensions"] = NATIVE_DIMS
    return payload
def _resolve_input_texts(req: EmbedRequest) -> list[str] | None:
    """Return the texts to embed from *req*, or ``None`` if none were given.

    A non-empty ``texts`` wins; otherwise ``text`` is used, with a bare string
    wrapped in a one-element list.
    """
    if req.texts:
        return req.texts
    if req.text:
        return [req.text] if isinstance(req.text, str) else req.text
    return None


def _bad_request(message: str) -> JSONResponse:
    """Build a real HTTP 400 response whose body keeps the ``{"error": ...}`` shape.

    FastAPI does not treat a returned ``(body, status)`` tuple as a status
    code; it would serialize the tuple as a JSON array with HTTP 200.
    """
    return JSONResponse(status_code=400, content={"error": message})


def _handle_embed(req: EmbedRequest) -> dict | JSONResponse:
    """Shared implementation behind ``/embed`` and ``/embed_batch``."""
    input_texts = _resolve_input_texts(req)
    if input_texts is None:
        return _bad_request("Provide 'text' or 'texts' field")
    if req.dimensions is not None and req.dimensions < 0:
        return _bad_request("'dimensions' must be non-negative")
    return _encode(input_texts, req)


@app.post("/embed")
def embed(req: EmbedRequest):
    """Embed a single text or a list of texts."""
    return _handle_embed(req)


@app.post("/embed_batch")
def embed_batch(req: EmbedRequest):
    """Same behavior as ``/embed``; accepts a single text or a list of texts."""
    return _handle_embed(req)