Maxlegrec commited on
Commit
faab955
·
verified ·
1 Parent(s): 09150f1

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +65 -19
model.py CHANGED
@@ -340,6 +340,9 @@ class BT4(nn.Module):
340
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
341
  """Load model from pretrained checkpoint (required by transformers)."""
342
  from transformers import AutoConfig
 
 
 
343
 
344
  # Load config
345
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
@@ -347,33 +350,76 @@ class BT4(nn.Module):
347
  # Create model with config
348
  model = cls(config=config)
349
 
350
- # Load weights if available
351
- try:
352
- from safetensors.torch import load_file
353
- import os
 
 
 
 
 
 
 
 
 
 
 
354
 
355
- # Try safetensors first
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
  safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
357
  if os.path.exists(safetensors_path):
358
  state_dict = load_file(safetensors_path)
359
- model.load_state_dict(state_dict)
 
 
 
 
360
  else:
361
  # Fall back to pytorch format
362
  pt_path = os.path.join(pretrained_model_name_or_path, "model.pt")
363
- if os.path.exists(pt_path):
364
- checkpoint = torch.load(pt_path, map_location="cpu")
365
- if isinstance(checkpoint, dict):
366
- if "state_dict" in checkpoint:
367
- model.load_state_dict(checkpoint["state_dict"])
368
- elif "model" in checkpoint:
369
- model.load_state_dict(checkpoint["model"])
370
- else:
371
- model.load_state_dict(checkpoint)
372
  else:
373
- model.load_state_dict(checkpoint)
374
- except Exception as e:
375
- # If weights don't exist or fail to load, return model without weights
376
- pass
 
 
 
 
377
 
378
  return model
379
 
 
340
  def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
341
  """Load model from pretrained checkpoint (required by transformers)."""
342
  from transformers import AutoConfig
343
+ from huggingface_hub import hf_hub_download
344
+ from safetensors.torch import load_file
345
+ import os
346
 
347
  # Load config
348
  config = AutoConfig.from_pretrained(pretrained_model_name_or_path, trust_remote_code=True)
 
350
  # Create model with config
351
  model = cls(config=config)
352
 
353
+ # Check if it's a HuggingFace Hub path or local path
354
+ is_hf_hub = "/" in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path)
355
+
356
+ if is_hf_hub:
357
+ # Download from HuggingFace Hub - try safetensors first
358
+ print("DEBUG: Downloading safetensors from HuggingFace...")
359
+ safetensors_path = hf_hub_download(
360
+ repo_id=pretrained_model_name_or_path,
361
+ filename="model.safetensors",
362
+ cache_dir=kwargs.get("cache_dir", None),
363
+ token=kwargs.get("token", None),
364
+ )
365
+ print(f"DEBUG: Loaded safetensors from {safetensors_path}")
366
+ state_dict = load_file(safetensors_path)
367
+ print(f"DEBUG: State dict has {len(state_dict)} keys")
368
 
369
+ # Debug: check embedding weight before loading
370
+ embedding_before = model.embedding.weight.sum().item()
371
+ expected_embedding = state_dict['embedding.weight'].sum().item()
372
+ print(f"DEBUG: Before loading - embedding: {embedding_before:.6f}, expected: {expected_embedding:.6f}")
373
+
374
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
375
+ print(f"DEBUG: load_state_dict returned - missing: {len(missing_keys)}, unexpected: {len(unexpected_keys)}")
376
+
377
+ # Debug: check embedding weight after loading
378
+ embedding_after = model.embedding.weight.sum().item()
379
+ print(f"DEBUG: After loading - embedding: {embedding_after:.6f}")
380
+
381
+ if missing_keys:
382
+ print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
383
+ if unexpected_keys:
384
+ print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")
385
+
386
+ # Verify weights loaded
387
+ if abs(embedding_after - expected_embedding) > 1e-5:
388
+ print(f"ERROR: Weights did not load correctly!")
389
+ print(f" Before: {embedding_before:.6f}, Expected: {expected_embedding:.6f}, After: {embedding_after:.6f}")
390
+ # Force reload
391
+ print("DEBUG: Attempting to reload weights...")
392
+ model.load_state_dict(state_dict, strict=False)
393
+ embedding_after2 = model.embedding.weight.sum().item()
394
+ print(f" After reload: {embedding_after2:.6f}")
395
+ else:
396
+ # Local path - try safetensors first
397
  safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
398
  if os.path.exists(safetensors_path):
399
  state_dict = load_file(safetensors_path)
400
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
401
+ if missing_keys:
402
+ print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
403
+ if unexpected_keys:
404
+ print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")
405
  else:
406
  # Fall back to pytorch format
407
  pt_path = os.path.join(pretrained_model_name_or_path, "model.pt")
408
+ checkpoint = torch.load(pt_path, map_location="cpu")
409
+ if isinstance(checkpoint, dict):
410
+ if "state_dict" in checkpoint:
411
+ state_dict = checkpoint["state_dict"]
412
+ elif "model" in checkpoint:
413
+ state_dict = checkpoint["model"]
 
 
 
414
  else:
415
+ state_dict = checkpoint
416
+ else:
417
+ state_dict = checkpoint
418
+ missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
419
+ if missing_keys:
420
+ print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
421
+ if unexpected_keys:
422
+ print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")
423
 
424
  return model
425