KJ24 committed on
Commit
a0d25c9
·
verified ·
1 Parent(s): ac99cd4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -6
app.py CHANGED
@@ -5,15 +5,15 @@ import torch
5
  import torch.nn.functional as F
6
  import os
7
 
8
- # 👉 Rediriger le cache HF vers un dossier autorisé
9
  CACHE_DIR = "/data"
10
- os.environ['HF_HOME'] = CACHE_DIR
11
- os.environ['TRANSFORMERS_CACHE'] = CACHE_DIR
12
 
13
  app = FastAPI()
14
 
15
- # Charger modèle et tokenizer avec cache_dir précisé explicitement
16
- MODEL_NAME = "thenlper/gte-small"
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
18
  model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
19
 
@@ -27,4 +27,4 @@ async def embed_text(payload: EmbedInput):
27
  outputs = model(**inputs)
28
  embeddings = outputs.last_hidden_state[:, 0] # CLS token
29
  normalized = F.normalize(embeddings, p=2, dim=1)
30
- return {'embedding': normalized[0].tolist()}
 
5
  import torch.nn.functional as F
6
  import os
7
 
8
+ # 📁 Rediriger le cache HF vers un dossier autorisé
9
  CACHE_DIR = "/data"
10
+ os.environ["HF_HOME"] = CACHE_DIR
11
+ os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
12
 
13
  app = FastAPI()
14
 
15
+ # 📌 Nouveau modèle Jina Embeddings v3
16
+ MODEL_NAME = "jinaai/jina-embeddings-v3-base-en"
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
18
  model = AutoModel.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
19
 
 
27
  outputs = model(**inputs)
28
  embeddings = outputs.last_hidden_state[:, 0] # CLS token
29
  normalized = F.normalize(embeddings, p=2, dim=1)
30
+ return {"embedding": normalized[0].tolist()}