embedding-api / app.py
KJ24's picture
Update app.py
5b8ca02 verified
raw
history blame
802 Bytes
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModel
import torch
import torch.nn.functional as F
# FastAPI application instance; the /embed route below is registered on it.
app = FastAPI()

# Load the model from HF without going through SentenceTransformer
MODEL_NAME = "thenlper/gte-small"
# Tokenizer and model are loaded once at import time and reused by every request.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModel.from_pretrained(MODEL_NAME)
class EmbedInput(BaseModel):
    """Request body accepted by the /embed endpoint."""

    # The text to embed.
    text: str
def _mean_pool(last_hidden_state, attention_mask):
    """Average the token vectors of each sequence, ignoring padded positions.

    Args:
        last_hidden_state: Tensor indexed as (batch, token, feature).
        attention_mask: Tensor indexed as (batch, token); non-zero marks a
            real token, zero marks padding.

    Returns:
        Tensor indexed as (batch, feature) holding the masked mean.
    """
    mask = attention_mask.unsqueeze(-1).to(last_hidden_state.dtype)
    summed = (last_hidden_state * mask).sum(dim=1)
    # Guard against division by zero should a row ever be fully masked.
    counts = mask.sum(dim=1).clamp(min=1.0)
    return summed / counts


@app.post("/embed")
async def embed_text(payload: EmbedInput):
    """Return the L2-normalised embedding of ``payload.text``.

    Args:
        payload: Request body carrying the text to embed.

    Returns:
        A dict ``{"embedding": [...]}`` whose value is the embedding as a
        plain list of floats.
    """
    # NOTE(review): tokenisation and the forward pass are synchronous calls,
    # so they run on the event loop for the duration of the request.
    inputs = tokenizer(
        payload.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,  # explicit cap so truncation does not rely on tokenizer defaults
    )
    with torch.no_grad():
        outputs = model(**inputs)
    # Pool with an attention-masked mean over all tokens rather than taking
    # only the first (CLS) token: mean pooling is the documented way to build
    # sentence embeddings for this model (see the model card's usage example).
    pooled = _mean_pool(outputs.last_hidden_state, inputs["attention_mask"])
    normalized = F.normalize(pooled, p=2, dim=1)
    return {"embedding": normalized[0].tolist()}