File size: 2,512 Bytes
4fdabcd
 
415bbcd
4fdabcd
 
 
 
415bbcd
4fdabcd
415bbcd
 
 
4fdabcd
 
 
 
 
831ad2d
 
 
4fdabcd
415bbcd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fdabcd
 
 
 
415bbcd
 
 
 
 
 
4fdabcd
 
 
 
831ad2d
 
 
 
 
 
415bbcd
4fdabcd
 
 
831ad2d
 
 
 
 
 
 
415bbcd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
"""Embedding Server (sentence-transformers) for HuggingFace Spaces."""
import os

import numpy as np
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from sentence_transformers import SentenceTransformer

# Model is selectable via the MODEL_NAME env var; defaults to Nomic's MoE embedder.
MODEL_NAME = os.environ.get("MODEL_NAME", "nomic-ai/nomic-embed-text-v2-moe")
print(f"[Embedding] Loading model: {MODEL_NAME}...", flush=True)
# Loaded once at import time so all requests share a single in-memory model.
# NOTE(review): trust_remote_code executes code shipped with the model repo —
# acceptable only because MODEL_NAME is operator-controlled, not user input.
model = SentenceTransformer(MODEL_NAME, trust_remote_code=True)
# Embedding width the model natively produces, before any client-requested truncation.
NATIVE_DIMS = model.get_sentence_embedding_dimension()
print(f"[Embedding] Model loaded. Native dimensions: {NATIVE_DIMS}", flush=True)

app = FastAPI()


class EmbedRequest(BaseModel):
    """Request payload accepted by both /embed and /embed_batch."""

    # Single string or list of strings to embed; alternative to `texts`.
    text: str | list[str] | None = None
    # Batch form; takes precedence over `text` when both are supplied.
    texts: list[str] | None = None
    # Accepted for API compatibility but ignored — the server always uses
    # the model loaded at startup (see module setup).
    model: str | None = None
    # Passed through to model.encode as normalize_embeddings.
    normalize: bool = True
    # Optional task prefix prepended verbatim to every input string.
    prefix: str | None = None
    # Optional truncation target; ignored when falsy or >= native dimensions.
    dimensions: int | None = None


def _process_embeddings(embeddings: np.ndarray, dimensions: int | None) -> np.ndarray:
    """Truncate to target dimensions and re-normalize."""
    if dimensions and dimensions < embeddings.shape[1]:
        embeddings = embeddings[:, :dimensions]
        norms = np.linalg.norm(embeddings, axis=1, keepdims=True)
        embeddings = embeddings / norms
    return embeddings


def _encode(input_texts: list[str], req: EmbedRequest) -> dict:
    """Embed *input_texts* per the request options and build the JSON payload.

    Applies the optional prefix, runs the model, applies optional
    dimension truncation, and returns embeddings plus metadata.
    """
    batch = [req.prefix + t for t in input_texts] if req.prefix else input_texts
    vectors = model.encode(
        batch,
        convert_to_numpy=True,
        normalize_embeddings=req.normalize,
    )
    vectors = _process_embeddings(vectors, req.dimensions)
    return {
        "embeddings": vectors.tolist(),
        "model": MODEL_NAME,
        "dimensions": vectors.shape[1],
        # Crude flat estimate (32 tokens per input); no tokenizer pass is run.
        "tokens": len(input_texts) * 32,
    }


@app.get("/health")
def health():
    """Liveness probe: report the loaded model and its native embedding width."""
    payload = {"status": "ok"}
    payload["model"] = MODEL_NAME
    # Duplicated under a second key for compatibility with existing clients.
    payload["model_name"] = MODEL_NAME
    payload["native_dimensions"] = NATIVE_DIMS
    return payload


@app.post("/embed")
def embed(req: EmbedRequest):
    """Embed the given text(s).

    Accepts either `texts` (batch, takes precedence) or `text`
    (string or list). Raises HTTP 400 when neither is provided.
    """
    if req.texts:
        input_texts = req.texts
    elif req.text:
        input_texts = [req.text] if isinstance(req.text, str) else req.text
    else:
        # BUG FIX: returning a `(dict, 400)` tuple from a FastAPI handler does
        # NOT set the status code — the tuple is serialized into a 200 body.
        # HTTPException is the correct way to produce an error response.
        raise HTTPException(status_code=400, detail="Provide 'text' or 'texts' field")
    return _encode(input_texts, req)


@app.post("/embed_batch")
def embed_batch(req: EmbedRequest):
    """Batch alias of /embed — identical input handling and response shape.

    Accepts either `texts` (batch, takes precedence) or `text`
    (string or list). Raises HTTP 400 when neither is provided.
    """
    if req.texts:
        input_texts = req.texts
    elif req.text:
        input_texts = [req.text] if isinstance(req.text, str) else req.text
    else:
        # BUG FIX: the original returned `(dict, 400)`, which FastAPI
        # serializes as a 200 JSON body; raise HTTPException instead.
        raise HTTPException(status_code=400, detail="Provide 'text' or 'texts' field")
    return _encode(input_texts, req)