0xZohar commited on
Commit
daa3ea5
·
verified ·
1 Parent(s): f1260ee

Fix: Add embedder key remapping for backward compatibility

Browse files

- Modified load_model_weights to handle old checkpoint format
- Old format: embedder.weight
- New format: encoder.embedder.weight + occupancy_decoder.embedder.weight
- Checkpoint is loaded as dict first, keys are remapped, then loaded to model
- This fixes the AssertionError on HuggingFace Space deployment
- Maintains backward compatibility with both checkpoint formats

Files changed (1) hide show
  1. code/cube3d/inference/utils.py +20 -4
code/cube3d/inference/utils.py CHANGED
@@ -3,7 +3,7 @@ from typing import Any, Optional, Tuple
3
 
4
  import torch
5
  from omegaconf import DictConfig, OmegaConf
6
- from safetensors.torch import load_model
7
 
8
  BOUNDING_BOX_MAX_SIZE = 1.925
9
 
@@ -49,6 +49,10 @@ def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
49
  Load a safetensors checkpoint into a PyTorch model.
50
  The model is updated in place.
51
 
 
 
 
 
52
  Args:
53
  model: PyTorch model to load weights into
54
  ckpt_path: Path to the safetensors checkpoint file
@@ -59,9 +63,21 @@ def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
59
  assert ckpt_path.endswith(
60
  ".safetensors"
61
  ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
62
-
63
- #load_model(model, ckpt_path)
64
- load_model(model, ckpt_path, strict=False)
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
 
67
  def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
 
3
 
4
  import torch
5
  from omegaconf import DictConfig, OmegaConf
6
+ from safetensors.torch import load_model, load_file
7
 
8
  BOUNDING_BOX_MAX_SIZE = 1.925
9
 
 
49
  Load a safetensors checkpoint into a PyTorch model.
50
  The model is updated in place.
51
 
52
+ Handles backward compatibility for embedder weight key naming:
53
+ - Old format: 'embedder.weight'
54
+ - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'
55
+
56
  Args:
57
  model: PyTorch model to load weights into
58
  ckpt_path: Path to the safetensors checkpoint file
 
63
  assert ckpt_path.endswith(
64
  ".safetensors"
65
  ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"
66
+
67
+ # Load checkpoint as dictionary for key remapping
68
+ checkpoint = load_file(ckpt_path)
69
+
70
+ # Backward compatibility: remap old embedder key format to new format
71
+ # This handles cases where checkpoint has 'embedder.weight' but model expects
72
+ # 'encoder.embedder.weight' and 'occupancy_decoder.embedder.weight'
73
+ if 'embedder.weight' in checkpoint:
74
+ if 'encoder.embedder.weight' not in checkpoint:
75
+ checkpoint['encoder.embedder.weight'] = checkpoint['embedder.weight']
76
+ if 'occupancy_decoder.embedder.weight' not in checkpoint:
77
+ checkpoint['occupancy_decoder.embedder.weight'] = checkpoint['embedder.weight']
78
+
79
+ # Load remapped checkpoint into model with strict=False for flexibility
80
+ model.load_state_dict(checkpoint, strict=False)
81
 
82
 
83
  def save_model_weights(model: torch.nn.Module, save_path: str) -> None: