justmotes committed on
Commit
bf79e3e
·
1 Parent(s): 0b669a9

Fix: Add get_embedding function and caching

Browse files
Files changed (1) hide show
  1. src/data_pipeline.py +17 -7
src/data_pipeline.py CHANGED
@@ -6,18 +6,21 @@ from typing import List, Union
6
  import torch
7
  import torch.nn.functional as F
8
 
 
 
 
 
 
 
 
 
 
9
  def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
10
  """
11
  Loads the specified model and generates embeddings for the given texts.
12
  Handles 'nomic' and 'qwen' specific requirements (trust_remote_code).
13
  """
14
- print(f"Loading embedding model: {model_name}...")
15
-
16
- trust_remote_code = False
17
- if "nomic" in model_name or "qwen" in model_name:
18
- trust_remote_code = True
19
-
20
- model = SentenceTransformer(model_name, trust_remote_code=trust_remote_code, device='cpu')
21
 
22
  # Generate embeddings
23
  # Convert to numpy array if it returns a tensor or list
@@ -25,6 +28,13 @@ def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
25
 
26
  return embeddings
27
 
 
 
 
 
 
 
 
28
  def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
29
  """
30
  Slices the vectors to the specified dimensions AND applies L2 normalization *after* slicing.
 
6
  import torch
7
  import torch.nn.functional as F
8
 
9
# Module-level cache so each embedding model is loaded at most once per process.
_MODEL_CACHE = {}

def get_model(model_name: str):
    """
    Return a cached SentenceTransformer for *model_name*, loading it on first use.

    Models whose names contain 'nomic' or 'qwen' require trust_remote_code=True;
    everything runs on CPU.
    """
    cached = _MODEL_CACHE.get(model_name)
    if cached is None:
        print(f"Loading embedding model: {model_name}...")
        # These model families ship custom code that must be explicitly trusted.
        needs_remote_code = any(tag in model_name for tag in ("nomic", "qwen"))
        cached = SentenceTransformer(model_name, trust_remote_code=needs_remote_code, device='cpu')
        _MODEL_CACHE[model_name] = cached
    return cached
17
+
18
  def get_embeddings(model_name: str, texts: List[str]) -> np.ndarray:
19
  """
20
  Loads the specified model and generates embeddings for the given texts.
21
  Handles 'nomic' and 'qwen' specific requirements (trust_remote_code).
22
  """
23
+ model = get_model(model_name)
 
 
 
 
 
 
24
 
25
  # Generate embeddings
26
  # Convert to numpy array if it returns a tensor or list
 
28
 
29
  return embeddings
30
 
31
def get_embedding(text: str, model_name: str = "nomic-ai/nomic-embed-text-v1.5") -> np.ndarray:
    """
    Embed a single query string and return its vector.

    Thin convenience wrapper around get_embeddings: wraps *text* in a
    one-element list and returns the first (only) row of the result.
    """
    return get_embeddings(model_name, [text])[0]
37
+
38
  def mrl_slice(vectors: np.ndarray, dims: int) -> np.ndarray:
39
  """
40
  Slices the vectors to the specified dimensions AND applies L2 normalization *after* slicing.