Exquisique committed on
Commit
ad1fef5
·
verified ·
1 Parent(s): cefd7e8

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +50 -1
model.py CHANGED
@@ -310,9 +310,9 @@ class GemmaForCausalLM(PreTrainedModel):
310
  AutoModelForCausalLM and the transformers ecosystem.
311
  """
312
  config_class = GemmaConfig
313
- base_model_prefix = "model"
314
  supports_gradient_checkpointing = True
315
  _no_split_modules = ["GemmaDecoderLayer"]
 
316
 
317
  def __init__(self, config: GemmaConfig):
318
  super().__init__(config)
@@ -326,6 +326,55 @@ class GemmaForCausalLM(PreTrainedModel):
326
 
327
  self.post_init()
328
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
329
  def get_input_embeddings(self):
330
  return self.model.embed_tokens
331
 
 
310
  AutoModelForCausalLM and the transformers ecosystem.
311
  """
312
  config_class = GemmaConfig
 
313
  supports_gradient_checkpointing = True
314
  _no_split_modules = ["GemmaDecoderLayer"]
315
+ _supports_param_buffer_assignment = False # Fix for accelerate weight loading
316
 
317
  def __init__(self, config: GemmaConfig):
318
  super().__init__(config)
 
326
 
327
  self.post_init()
328
 
329
@classmethod
def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    """Build the model and load its weights from ``pytorch_model.bin``.

    This replaces the stock loading path with a minimal one: load the
    config, instantiate the model, read a single ``pytorch_model.bin``
    checkpoint onto the CPU, then optionally cast and move the model.

    Args:
        pretrained_model_name_or_path: Either a local directory that
            contains ``pytorch_model.bin`` or a Hub repository id.
        *args: Accepted for signature compatibility; not used.
        **kwargs: ``trust_remote_code`` (default ``True``), ``torch_dtype``
            and ``device_map`` are consumed here. Everything else is
            forwarded to ``config_class.from_pretrained``. The download
            options ``revision``, ``cache_dir``, ``token``,
            ``force_download`` and ``local_files_only`` are additionally
            forwarded to ``hf_hub_download`` so that the config and the
            weights are resolved from the same place.

    Returns:
        The model in evaluation mode with the checkpoint weights loaded.

    Raises:
        FileNotFoundError: If a local directory is given but it does not
            contain ``pytorch_model.bin``.
    """
    import logging
    import os

    log = logging.getLogger(__name__)

    # Options handled by this method; they must not leak into the config.
    trust_remote_code = kwargs.pop("trust_remote_code", True)
    torch_dtype = kwargs.pop("torch_dtype", None)
    device_map = kwargs.pop("device_map", None)

    config = cls.config_class.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=trust_remote_code,
        **kwargs,
    )
    model = cls(config)

    # Locate the checkpoint: local directory first, otherwise the Hub.
    if os.path.isdir(pretrained_model_name_or_path):
        weight_file = os.path.join(pretrained_model_name_or_path, "pytorch_model.bin")
        if not os.path.isfile(weight_file):
            raise FileNotFoundError(
                f"No pytorch_model.bin found in directory "
                f"{pretrained_model_name_or_path!r}"
            )
    else:
        # Imported lazily so purely local loading does not need the Hub client.
        from huggingface_hub import hf_hub_download

        # Reuse the caller's download options; without this a pinned
        # `revision` would apply to the config but not to the weights.
        hub_option_names = (
            "revision",
            "cache_dir",
            "token",
            "force_download",
            "local_files_only",
        )
        hub_options = {
            name: kwargs[name] for name in hub_option_names if name in kwargs
        }
        weight_file = hf_hub_download(
            repo_id=pretrained_model_name_or_path,
            filename="pytorch_model.bin",
            **hub_options,
        )

    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a downloaded checkpoint cannot execute arbitrary code on load.
    state_dict = torch.load(weight_file, map_location="cpu", weights_only=True)

    # Loading stays non-strict (tied or renamed weights may legitimately be
    # absent), but mismatches are reported instead of silently leaving
    # parameters at their random initialisation.
    load_result = model.load_state_dict(state_dict, strict=False)
    if load_result.missing_keys:
        log.warning(
            "Weights missing from %s (left at initial values): %s",
            weight_file,
            load_result.missing_keys,
        )
    if load_result.unexpected_keys:
        log.warning(
            "Unused weights in %s: %s",
            weight_file,
            load_result.unexpected_keys,
        )

    # Stock `from_pretrained` returns models in eval mode; do the same so
    # dropout is not active during inference.
    model.eval()

    # Handle dtype and device.
    if torch_dtype is not None:
        model = model.to(torch_dtype)
    if device_map == "auto":
        # "auto" means: use the GPU when there is one, otherwise stay on CPU.
        if torch.cuda.is_available():
            model = model.to("cuda")
    elif isinstance(device_map, dict) and set(device_map) == {""}:
        # {"": device} is the "whole model on one device" form of a device map.
        model = model.to(device_map[""])
    elif device_map is not None:
        model = model.to(device_map)

    return model
377
+
378
  def get_input_embeddings(self):
379
  return self.model.embed_tokens
380