# Header captured from the hosting page's file view (not part of the module):
#   0xZohar - "Fix: Add embedder key remapping for backward compatibility"
#   commit daa3ea5 (verified), file size 4.76 kB
import logging
from typing import Any, Optional, Tuple
import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model, load_file
# Length that the largest extent of a normalized bounding box is scaled to.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> list:
    """
    Uniformly rescale a bounding box so its largest extent equals
    BOUNDING_BOX_MAX_SIZE.

    Args:
        bounding_box_xyz (Tuple[float, ...]): Extents of the box, one value per axis.
    Returns:
        list: The rescaled extents, in the same order as the input.
    Raises:
        ValueError: If the box has no components, or if its largest extent is
            not strictly positive (the scale factor would be undefined or
            would produce a box larger than the limit).
    """
    # Materialize once so one-shot iterables are not exhausted by max().
    extents = tuple(bounding_box_xyz)
    if not extents:
        raise ValueError("bounding_box_xyz must contain at least one extent")

    largest = max(extents)
    if largest <= 0:
        raise ValueError(
            f"bounding box must have a positive largest extent, got {extents}"
        )

    # Same operation order as before (scale * elem / max) to keep results bit-identical.
    return [BOUNDING_BOX_MAX_SIZE * extent / largest for extent in extents]
def load_config(cfg_path: str) -> Any:
    """
    Read a configuration file and return it in resolved form.

    Args:
        cfg_path (str): Location of the configuration file on disk.
    Returns:
        Any: The configuration object after `OmegaConf.resolve` has been
            applied to it.
    Raises:
        AssertionError: If what was loaded is not a DictConfig.
    """
    config = OmegaConf.load(cfg_path)

    # resolve() is called for its effect on `config`; its return value is unused.
    OmegaConf.resolve(config)

    # Only mapping-style configurations are accepted by callers of this helper.
    assert isinstance(config, DictConfig)
    return config
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured configuration of the given type from a config mapping.

    Args:
        cfg_type (Any): Callable (typically a config class) that accepts the
            entries of `cfg` as keyword arguments.
        cfg (DictConfig): Mapping whose entries are unpacked into `cfg_type`.
    Returns:
        Any: The result of `OmegaConf.structured` applied to the constructed
            `cfg_type` instance.
    """
    # Instantiate first so that unknown or missing fields fail in cfg_type itself.
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.
    The model is updated in place.

    Handles backward compatibility for embedder weight key naming:
    - Old format: 'embedder.weight'
    - New format: 'encoder.embedder.weight', 'occupancy_decoder.embedder.weight'

    Loading is non-strict, so mismatched keys do not raise. Instead, any keys
    the model expected but did not receive, and any checkpoint keys the model
    did not consume, are reported as warnings through this module's logger.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
    Returns:
        None
    Raises:
        AssertionError: If `ckpt_path` does not end with ".safetensors".
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Load checkpoint as a plain dictionary so keys can be remapped before loading.
    checkpoint = load_file(ckpt_path)

    # Backward compatibility: old checkpoints store a single 'embedder.weight';
    # the current model expects a copy under each of the two prefixed names.
    # Keys already present in the checkpoint are never overwritten.
    legacy_key = "embedder.weight"
    remapped = legacy_key in checkpoint
    if remapped:
        legacy_weight = checkpoint[legacy_key]
        for current_key in (
            "encoder.embedder.weight",
            "occupancy_decoder.embedder.weight",
        ):
            checkpoint.setdefault(current_key, legacy_weight)

    # strict=False keeps loading flexible, but the mismatch report it returns
    # must not be dropped: otherwise a wrong checkpoint loads nothing, silently.
    result = model.load_state_dict(checkpoint, strict=False)

    log = logging.getLogger(__name__)
    missing = list(getattr(result, "missing_keys", ()))
    # The legacy key itself is expected to be left over after remapping.
    unexpected = [
        key
        for key in getattr(result, "unexpected_keys", ())
        if not (remapped and key == legacy_key)
    ]
    if missing:
        log.warning(
            "Checkpoint '%s' did not provide %d model key(s): %s",
            ckpt_path,
            len(missing),
            missing,
        )
    if unexpected:
        log.warning(
            "Checkpoint '%s' contains %d key(s) not used by the model: %s",
            ckpt_path,
            len(unexpected),
            unexpected,
        )
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Write the model's state dict to disk as a safetensors file.

    Args:
        model: PyTorch model whose `state_dict()` is saved
        save_path: Destination file; must end with ".safetensors"
    Raises:
        AssertionError: If `save_path` does not end with ".safetensors".
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"

    # Imported here, after validation, as in the rest of this module's lazy imports.
    from safetensors.torch import save_file

    save_file(model.state_dict(), save_path)
def load_model_weights_adaption(model: torch.nn.Module, ckpt_path: str, adaption_path: str) -> torch.nn.Module:
    """
    Load a safetensors checkpoint into a model, attach a PEFT adapter, and
    restore the separately saved sub-module weights.

    Steps:
    1. Load `ckpt_path` into `model` (non-strict), updating it in place.
    2. Wrap the model with `PeftModel.from_pretrained(model, adaption_path)`.
    3. Load "unfrozen_weights.pth" from `adaption_path` and apply its entries
       to the `ldr_proj`, `ldr_head`, `dte`, `rte`, `xte`, `yte` and `zte`
       sub-modules.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Directory handed to `PeftModel.from_pretrained`; it must
            also contain "unfrozen_weights.pth"
    Returns:
        The PEFT-wrapped model (a new wrapper object, not the input instance).
    Raises:
        AssertionError: If `ckpt_path` does not end with ".safetensors".
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    load_model(model, ckpt_path, strict=False)

    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)

    # Stage the tensors on CPU rather than a hard-coded 'cuda:0': load_state_dict
    # copies values into each parameter's existing device, so the result is the
    # same, and this also works on hosts without CUDA.
    # NOTE(security): torch.load unpickles the file; only use trusted adapters.
    custom_weights = torch.load(
        f"{adaption_path}/unfrozen_weights.pth",
        map_location=torch.device("cpu"),
    )

    # Each entry of the file is the state dict of the same-named sub-module.
    for submodule_name in ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte"):
        getattr(model, submodule_name).load_state_dict(custom_weights[submodule_name])

    return model
def select_device() -> Any:
    """
    Pick the PyTorch device to allocate tensors on.

    Preference order: CUDA if available, otherwise MPS if available,
    otherwise the CPU.

    Returns:
        Any: The chosen `torch.device` object.
    """
    if torch.cuda.is_available():
        backend = "cuda"
    elif torch.backends.mps.is_available():
        backend = "mps"
    else:
        backend = "cpu"
    return torch.device(backend)
def decode_ldr(output_ids: torch.Tensor,):
    """
    Decode model output ids into an LDR file.

    NOTE(review): the body only returns the name `ldr`, which is not assigned
    in this function or anywhere in the visible part of the file, and
    `output_ids` is never read. Unless a module-level `ldr` exists elsewhere,
    calling this raises NameError - the decoding logic appears to be missing;
    confirm against the original implementation.

    Args:
        output_ids (torch.Tensor): Presumably the generated token ids to
            decode - TODO confirm; currently unused.
    Returns:
        Decode ldr file
    """
    return ldr