Maxlegrec commited on
Commit
36c3f06
·
verified ·
1 Parent(s): faab955

Upload model.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. model.py +1 -26
model.py CHANGED
@@ -354,44 +354,19 @@ class BT4(nn.Module):
354
  is_hf_hub = "/" in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path)
355
 
356
  if is_hf_hub:
357
- # Download from HuggingFace Hub - try safetensors first
358
- print("DEBUG: Downloading safetensors from HuggingFace...")
359
  safetensors_path = hf_hub_download(
360
  repo_id=pretrained_model_name_or_path,
361
  filename="model.safetensors",
362
  cache_dir=kwargs.get("cache_dir", None),
363
  token=kwargs.get("token", None),
364
  )
365
- print(f"DEBUG: Loaded safetensors from {safetensors_path}")
366
  state_dict = load_file(safetensors_path)
367
- print(f"DEBUG: State dict has {len(state_dict)} keys")
368
-
369
- # Debug: check embedding weight before loading
370
- embedding_before = model.embedding.weight.sum().item()
371
- expected_embedding = state_dict['embedding.weight'].sum().item()
372
- print(f"DEBUG: Before loading - embedding: {embedding_before:.6f}, expected: {expected_embedding:.6f}")
373
-
374
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
375
- print(f"DEBUG: load_state_dict returned - missing: {len(missing_keys)}, unexpected: {len(unexpected_keys)}")
376
-
377
- # Debug: check embedding weight after loading
378
- embedding_after = model.embedding.weight.sum().item()
379
- print(f"DEBUG: After loading - embedding: {embedding_after:.6f}")
380
-
381
  if missing_keys:
382
  print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
383
  if unexpected_keys:
384
  print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")
385
-
386
- # Verify weights loaded
387
- if abs(embedding_after - expected_embedding) > 1e-5:
388
- print(f"ERROR: Weights did not load correctly!")
389
- print(f" Before: {embedding_before:.6f}, Expected: {expected_embedding:.6f}, After: {embedding_after:.6f}")
390
- # Force reload
391
- print("DEBUG: Attempting to reload weights...")
392
- model.load_state_dict(state_dict, strict=False)
393
- embedding_after2 = model.embedding.weight.sum().item()
394
- print(f" After reload: {embedding_after2:.6f}")
395
  else:
396
  # Local path - try safetensors first
397
  safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")
 
354
  is_hf_hub = "/" in pretrained_model_name_or_path and not os.path.isdir(pretrained_model_name_or_path)
355
 
356
  if is_hf_hub:
357
+ # Download from HuggingFace Hub
 
358
  safetensors_path = hf_hub_download(
359
  repo_id=pretrained_model_name_or_path,
360
  filename="model.safetensors",
361
  cache_dir=kwargs.get("cache_dir", None),
362
  token=kwargs.get("token", None),
363
  )
 
364
  state_dict = load_file(safetensors_path)
 
 
 
 
 
 
 
365
  missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
 
 
 
 
 
 
366
  if missing_keys:
367
  print(f"Warning: Missing keys when loading weights: {len(missing_keys)} keys")
368
  if unexpected_keys:
369
  print(f"Warning: Unexpected keys when loading weights: {len(unexpected_keys)} keys")
 
 
 
 
 
 
 
 
 
 
370
  else:
371
  # Local path - try safetensors first
372
  safetensors_path = os.path.join(pretrained_model_name_or_path, "model.safetensors")