Lev Israel
committed on
Commit
·
d1c390a
1
Parent(s):
1787d7f
Embedding Gemma
Browse files
models.py
CHANGED
|
@@ -56,6 +56,14 @@ CURATED_MODELS = {
|
|
| 56 |
"query_prefix": "",
|
| 57 |
"passage_prefix": "",
|
| 58 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
}
|
| 60 |
|
| 61 |
# API-based models
|
|
@@ -233,6 +241,7 @@ class EmbeddingModel(BaseEmbeddingModel):
|
|
| 233 |
model_id: str,
|
| 234 |
device: Optional[str] = None,
|
| 235 |
max_length: int = 512,
|
|
|
|
| 236 |
):
|
| 237 |
"""
|
| 238 |
Initialize the embedding model.
|
|
@@ -241,12 +250,12 @@ class EmbeddingModel(BaseEmbeddingModel):
|
|
| 241 |
model_id: Hugging Face model ID
|
| 242 |
device: Device to use ('cuda', 'cpu', or None for auto)
|
| 243 |
max_length: Maximum sequence length for tokenization
|
|
|
|
| 244 |
"""
|
| 245 |
from sentence_transformers import SentenceTransformer
|
| 246 |
import torch
|
| 247 |
|
| 248 |
self.model_id = model_id
|
| 249 |
-
self.max_length = max_length
|
| 250 |
|
| 251 |
# Auto-detect device
|
| 252 |
if device is None:
|
|
@@ -262,12 +271,18 @@ class EmbeddingModel(BaseEmbeddingModel):
|
|
| 262 |
"passage_prefix": "",
|
| 263 |
})
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
# Load the model with float16 on CUDA to save VRAM
|
| 266 |
# (12B model: float32 = 48GB, float16 = 24GB)
|
| 267 |
print(f"Loading model: {model_id} on {device}")
|
| 268 |
|
| 269 |
# Only trust remote code from known publishers (security measure)
|
| 270 |
-
trusted_publishers = ["nvidia/"]
|
| 271 |
trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)
|
| 272 |
|
| 273 |
if device == "cuda":
|
|
@@ -276,13 +291,19 @@ class EmbeddingModel(BaseEmbeddingModel):
|
|
| 276 |
device=device,
|
| 277 |
model_kwargs={"torch_dtype": torch.float16},
|
| 278 |
trust_remote_code=trust_remote_code,
|
|
|
|
| 279 |
)
|
| 280 |
else:
|
| 281 |
-
self.model = SentenceTransformer(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 282 |
|
| 283 |
# Set max sequence length if supported
|
| 284 |
if hasattr(self.model, "max_seq_length"):
|
| 285 |
-
self.model.max_seq_length = min(max_length, self.model.max_seq_length)
|
| 286 |
|
| 287 |
self.embedding_dim = self.model.get_sentence_embedding_dimension()
|
| 288 |
print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
|
|
@@ -980,6 +1001,7 @@ def load_model(
|
|
| 980 |
model_id: str,
|
| 981 |
device: Optional[str] = None,
|
| 982 |
api_key: Optional[str] = None,
|
|
|
|
| 983 |
) -> BaseEmbeddingModel:
|
| 984 |
"""
|
| 985 |
Load an embedding model by ID.
|
|
@@ -988,6 +1010,7 @@ def load_model(
|
|
| 988 |
model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
|
| 989 |
device: Device to use (for local models only)
|
| 990 |
api_key: API key (for API-based models, or uses environment variable)
|
|
|
|
| 991 |
|
| 992 |
Returns:
|
| 993 |
Loaded embedding model instance
|
|
@@ -1012,7 +1035,7 @@ def load_model(
|
|
| 1012 |
raise ValueError(f"Unknown API model type: {model_id}")
|
| 1013 |
|
| 1014 |
# Otherwise, load as a local sentence-transformer model
|
| 1015 |
-
return EmbeddingModel(model_id, device=device)
|
| 1016 |
|
| 1017 |
|
| 1018 |
def validate_model_id(model_id: str) -> tuple[bool, str]:
|
|
|
|
| 56 |
"query_prefix": "",
|
| 57 |
"passage_prefix": "",
|
| 58 |
},
|
| 59 |
+
"google/embeddinggemma-300m": {
|
| 60 |
+
"name": "EmbeddingGemma",
|
| 61 |
+
"description": "Google's 300M param embedding model, 100+ languages, 768d (requires HF token + license)",
|
| 62 |
+
"type": "local",
|
| 63 |
+
"query_prefix": "task: search result | query: ",
|
| 64 |
+
"passage_prefix": "title: none | text: ",
|
| 65 |
+
"max_length": 2048,
|
| 66 |
+
},
|
| 67 |
}
|
| 68 |
|
| 69 |
# API-based models
|
|
|
|
| 241 |
model_id: str,
|
| 242 |
device: Optional[str] = None,
|
| 243 |
max_length: int = 512,
|
| 244 |
+
hf_token: Optional[str] = None,
|
| 245 |
):
|
| 246 |
"""
|
| 247 |
Initialize the embedding model.
|
|
|
|
| 250 |
model_id: Hugging Face model ID
|
| 251 |
device: Device to use ('cuda', 'cpu', or None for auto)
|
| 252 |
max_length: Maximum sequence length for tokenization
|
| 253 |
+
hf_token: HuggingFace token for gated models (or uses HF_TOKEN env var)
|
| 254 |
"""
|
| 255 |
from sentence_transformers import SentenceTransformer
|
| 256 |
import torch
|
| 257 |
|
| 258 |
self.model_id = model_id
|
|
|
|
| 259 |
|
| 260 |
# Auto-detect device
|
| 261 |
if device is None:
|
|
|
|
| 271 |
"passage_prefix": "",
|
| 272 |
})
|
| 273 |
|
| 274 |
+
# Use config max_length if available, otherwise use parameter
|
| 275 |
+
self.max_length = self.config.get("max_length", max_length)
|
| 276 |
+
|
| 277 |
+
# Get HF token from parameter or environment (for gated models like EmbeddingGemma)
|
| 278 |
+
hf_token = hf_token or os.environ.get("HF_TOKEN")
|
| 279 |
+
|
| 280 |
# Load the model with float16 on CUDA to save VRAM
|
| 281 |
# (12B model: float32 = 48GB, float16 = 24GB)
|
| 282 |
print(f"Loading model: {model_id} on {device}")
|
| 283 |
|
| 284 |
# Only trust remote code from known publishers (security measure)
|
| 285 |
+
trusted_publishers = ["nvidia/", "google/"]
|
| 286 |
trust_remote_code = any(model_id.startswith(pub) for pub in trusted_publishers)
|
| 287 |
|
| 288 |
if device == "cuda":
|
|
|
|
| 291 |
device=device,
|
| 292 |
model_kwargs={"torch_dtype": torch.float16},
|
| 293 |
trust_remote_code=trust_remote_code,
|
| 294 |
+
token=hf_token,
|
| 295 |
)
|
| 296 |
else:
|
| 297 |
+
self.model = SentenceTransformer(
|
| 298 |
+
model_id,
|
| 299 |
+
device=device,
|
| 300 |
+
trust_remote_code=trust_remote_code,
|
| 301 |
+
token=hf_token,
|
| 302 |
+
)
|
| 303 |
|
| 304 |
# Set max sequence length if supported
|
| 305 |
if hasattr(self.model, "max_seq_length"):
|
| 306 |
+
self.model.max_seq_length = min(self.max_length, self.model.max_seq_length)
|
| 307 |
|
| 308 |
self.embedding_dim = self.model.get_sentence_embedding_dimension()
|
| 309 |
print(f"Model loaded. Embedding dimension: {self.embedding_dim}")
|
|
|
|
| 1001 |
model_id: str,
|
| 1002 |
device: Optional[str] = None,
|
| 1003 |
api_key: Optional[str] = None,
|
| 1004 |
+
hf_token: Optional[str] = None,
|
| 1005 |
) -> BaseEmbeddingModel:
|
| 1006 |
"""
|
| 1007 |
Load an embedding model by ID.
|
|
|
|
| 1010 |
model_id: Model ID (HuggingFace model ID or API model like 'openai/text-embedding-3-large')
|
| 1011 |
device: Device to use (for local models only)
|
| 1012 |
api_key: API key (for API-based models, or uses environment variable)
|
| 1013 |
+
hf_token: HuggingFace token for gated local models (or uses HF_TOKEN env var)
|
| 1014 |
|
| 1015 |
Returns:
|
| 1016 |
Loaded embedding model instance
|
|
|
|
| 1035 |
raise ValueError(f"Unknown API model type: {model_id}")
|
| 1036 |
|
| 1037 |
# Otherwise, load as a local sentence-transformer model
|
| 1038 |
+
return EmbeddingModel(model_id, device=device, hf_token=hf_token)
|
| 1039 |
|
| 1040 |
|
| 1041 |
def validate_model_id(model_id: str) -> tuple[bool, str]:
|