KJ24 committed on
Commit
5b8ca02
·
verified ·
1 Parent(s): b9aed3c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -5
app.py CHANGED
@@ -1,16 +1,24 @@
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
- from sentence_transformers import SentenceTransformer
4
- import os
 
5
 
6
  app = FastAPI()
7
 
8
- model = SentenceTransformer("thenlper/gte-small")
 
 
 
9
 
10
  class EmbedInput(BaseModel):
11
  text: str
12
 
13
  @app.post("/embed")
14
  async def embed_text(payload: EmbedInput):
15
- embedding = model.encode(payload.text, normalize_embeddings=True)
16
- return {"embedding": embedding.tolist()}
 
 
 
 
 
1
  from fastapi import FastAPI
2
  from pydantic import BaseModel
3
+ from transformers import AutoTokenizer, AutoModel
4
+ import torch
5
+ import torch.nn.functional as F
6
 
7
  app = FastAPI()
8
 
9
+ # Load the model from HF without going through SentenceTransformer
10
+ MODEL_NAME = "thenlper/gte-small"
11
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
12
+ model = AutoModel.from_pretrained(MODEL_NAME)
13
 
14
  class EmbedInput(BaseModel):
15
  text: str
16
 
17
  @app.post("/embed")
18
  async def embed_text(payload: EmbedInput):
19
+ inputs = tokenizer(payload.text, return_tensors="pt", padding=True, truncation=True)
20
+ with torch.no_grad():
21
+ outputs = model(**inputs)
22
+ embeddings = outputs.last_hidden_state[:, 0] # CLS token
23
+ normalized = F.normalize(embeddings, p=2, dim=1)
24
+ return {"embedding": normalized[0].tolist()}