Lev Israel committed on
Commit
d1c390a
·
1 Parent(s): 1787d7f

Embedding Gemma

Browse files
Files changed (1) hide show
  1. models.py +28 -5
models.py CHANGED
@@ -56,6 +56,14 @@ CURATED_MODELS = {
56
  "query_prefix": "",
57
  "passage_prefix": "",
58
  },
 
 
 
 
 
 
 
 
59
  }
60
 
61
  # API-based models
@@ -233,6 +241,7 @@ class EmbeddingModel(BaseEmbeddingModel):
233
  model_id: str,
234
  device: Optional[str] = None,
235
  max_length: int = 512,
 
236
  ):
237
  """
238
  Initialize the embedding model.
@@ -241,12 +250,12 @@ class EmbeddingModel(BaseEmbeddingModel):
241
  model_id: Hugging Face model ID
242
  device: Device to use ('cuda', 'cpu', or None for auto)
243
  max_length: Maximum sequence length for tokenization
 
244
  """
245
  from sentence_transformers import SentenceTransformer
246
  import torch
247
 
248
  self.model_id = model_id
249
- self.max_length = max_length
250
 
251
  # Auto-detect device
252
  if device is None:
@@ -262,12 +271,18 @@ class EmbeddingModel(BaseEmbeddingModel):
262
  "passage_prefix": "",
263
  })
264
 
 
 
 
 
 
 
265
  # Load the model with float16 on CUDA to save VRAM
266
  # (12B model: float32 = 48GB, float16 = 24GB)
267
  print(f"Loading model: {model_id} on {device}")
268
 
269
  # Only trust remote code from known publishers (security measure)
270
- trusted_publishers = ["nvidia/"]
271
  trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)
272
 
273
  if device == "cuda":
@@ -276,13 +291,19 @@ class EmbeddingModel(BaseEmbeddingModel):
276
  device=device,
277
  model_kwargs={"torch_dtype": torch.float16},
278
  trust_remote_code=trust_remote_code,
 
279
  )
280
  else:
281
- self.model = SentenceTransformer(model_id, device=device, trust_remote_code=trust_remote_code)
 
 
 
 
 
282
 
283
  # Set max sequence length if supported
284
  if hasattr(self.model, "max_seq_length"):
285
- self.model.max_seq_length = min(max_length, self.model.max_seq_length)
286
 
287
  self.embedding_dim = self.model.get_sentence_embedding_dimension()
288
  print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
@@ -980,6 +1001,7 @@ def load_model(
980
  model_id: str,
981
  device: Optional[str] = None,
982
  api_key: Optional[str] = None,
 
983
  ) -> BaseEmbeddingModel:
984
  """
985
  Load an embedding model by ID.
@@ -988,6 +1010,7 @@ def load_model(
988
  model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
989
  device: Device to use (for local models only)
990
  api_key: API key (for API-based models, or uses environment variable)
 
991
 
992
  Returns:
993
  Loaded embedding model instance
@@ -1012,7 +1035,7 @@ def load_model(
1012
  raise ValueError(f"Unknown API model type: {model_id}")
1013
 
1014
  # Otherwise, load as a local sentence-transformer model
1015
- return EmbeddingModel(model_id, device=device)
1016
 
1017
 
1018
  def validate_model_id(model_id: str) -> tuple[bool, str]:
 
56
  "query_prefix": "",
57
  "passage_prefix": "",
58
  },
59
+ "google/embeddinggemma-300m": {
60
+ "name": "EmbeddingGemma",
61
+ "description": "Google's 300M param embedding model, 100+ languages, 768d (requires HF token + license)",
62
+ "type": "local",
63
+ "query_prefix": "task: search result | query: ",
64
+ "passage_prefix": "title: none | text: ",
65
+ "max_length": 2048,
66
+ },
67
  }
68
 
69
  # API-based models
 
241
  model_id: str,
242
  device: Optional[str] = None,
243
  max_length: int = 512,
244
+ hf_token: Optional[str] = None,
245
  ):
246
  """
247
  Initialize the embedding model.
 
250
  model_id: Hugging Face model ID
251
  device: Device to use ('cuda', 'cpu', or None for auto)
252
  max_length: Maximum sequence length for tokenization
253
+ hf_token: HuggingFace token for gated models (or uses HF_TOKEN env var)
254
  """
255
  from sentence_transformers import SentenceTransformer
256
  import torch
257
 
258
  self.model_id = model_id
 
259
 
260
  # Auto-detect device
261
  if device is None:
 
271
  "passage_prefix": "",
272
  })
273
 
274
+ # Use config max_length if available, otherwise use parameter
275
+ self.max_length = self.config.get("max_length", max_length)
276
+
277
+ # Get HF token from parameter or environment (for gated models like EmbeddingGemma)
278
+ hf_token = hf_token or os.environ.get("HF_TOKEN")
279
+
280
  # Load the model with float16 on CUDA to save VRAM
281
  # (12B model: float32 = 48GB, float16 = 24GB)
282
  print(f"Loading model: {model_id} on {device}")
283
 
284
  # Only trust remote code from known publishers (security measure)
285
+ trusted_publishers = ["nvidia/", "google/"]
286
  trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)
287
 
288
  if device == "cuda":
 
291
  device=device,
292
  model_kwargs={"torch_dtype": torch.float16},
293
  trust_remote_code=trust_remote_code,
294
+ token=hf_token,
295
  )
296
  else:
297
+ self.model = SentenceTransformer(
298
+ model_id,
299
+ device=device,
300
+ trust_remote_code=trust_remote_code,
301
+ token=hf_token,
302
+ )
303
 
304
  # Set max sequence length if supported
305
  if hasattr(self.model, "max_seq_length"):
306
+ self.model.max_seq_length = min(self.max_length, self.model.max_seq_length)
307
 
308
  self.embedding_dim = self.model.get_sentence_embedding_dimension()
309
  print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
 
1001
  model_id: str,
1002
  device: Optional[str] = None,
1003
  api_key: Optional[str] = None,
1004
+ hf_token: Optional[str] = None,
1005
  ) -> BaseEmbeddingModel:
1006
  """
1007
  Load an embedding model by ID.
 
1010
  model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
1011
  device: Device to use (for local models only)
1012
  api_key: API key (for API-based models, or uses environment variable)
1013
+ hf_token: HuggingFace token for gated local models (or uses HF_TOKEN env var)
1014
 
1015
  Returns:
1016
  Loaded embedding model instance
 
1035
  raise ValueError(f"Unknown API model type: {model_id}")
1036
 
1037
  # Otherwise, load as a local sentence-transformer model
1038
+ return EmbeddingModel(model_id, device=device, hf_token=hf_token)
1039
 
1040
 
1041
  def validate_model_id(model_id: str) -> tuple[bool, str]: