Spaces:
Sleeping
Sleeping
Fix: Add get_embedding function and caching
Browse files- src/data_pipeline.py +17 -7
src/data_pipeline.py
CHANGED
|
@@ -6,18 +6,21 @@ from typing import List, Union
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
|
| 10 |
"""
|
| 11 |
Loads the specified model and generates embeddings for the given texts.
|
| 12 |
Handles 'nomic' and 'qwen' specific requirements (trust_remote_code).
|
| 13 |
"""
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
trust_remote_code = False
|
| 17 |
-
if "nomic" in model_name or "qwen" in model_name:
|
| 18 |
-
trust_remote_code = True
|
| 19 |
-
|
| 20 |
-
model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device='cpu')
|
| 21 |
|
| 22 |
# Generate embeddings
|
| 23 |
# Convert to numpy array if it returns a tensor or list
|
|
@@ -25,6 +28,13 @@ def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
|
|
| 25 |
|
| 26 |
return embeddings
|
| 27 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 28 |
def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
|
| 29 |
"""
|
| 30 |
Slices the vectors to the specified dimensions AND applies L2 normalization *after* slicing.
|
|
|
|
| 6 |
import torch
|
| 7 |
import torch.nn.functional as F
|
| 8 |
|
| 9 |
+
# Process-wide cache so each embedding model is loaded at most once.
_MODEL_CACHE = {}


def get_model(model_name: str):
    """Fetch the SentenceTransformer for *model_name*, loading and caching it on first request."""
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        print(f"Loading embedding model: {model_name}...")
        # These model families ship custom modeling code, so remote code must be trusted.
        needs_remote = any(tag in model_name for tag in ("nomic", "qwen"))
        cached = SentenceTransformer(model_name, trust_remote_code=needs_remote, device='cpu')
        _MODEL_CACHE[model_name] = cached
    return cached
|
| 17 |
+
|
| 18 |
def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
|
| 19 |
"""
|
| 20 |
Loads the specified model and generates embeddings for the given texts.
|
| 21 |
Handles 'nomic' and 'qwen' specific requirements (trust_remote_code).
|
| 22 |
"""
|
| 23 |
+
model = get_model(model_name)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
# Generate embeddings
|
| 26 |
# Convert to numpy array if it returns a tensor or list
|
|
|
|
| 28 |
|
| 29 |
return embeddings
|
| 30 |
|
| 31 |
+
def get_embedding(text: str, model_name: str = "nomic-ai/nomic-embed-text-v1.5") -> np.ndarray:
    """Embed a single query string and return its vector.

    Delegates to the batch path with a one-element list, then unwraps
    the sole row of the resulting array.
    """
    return get_embeddings(model_name, [text])[0]
|
| 37 |
+
|
| 38 |
def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
|
| 39 |
"""
|
| 40 |
Slices the vectors to the specified dimensions AND applies L2 normalization *after* slicing.
|