# Spaces: Paused
# File size: 3,866 Bytes
import logging
import os
from typing import Any, Optional, Tuple

import torch
from omegaconf import DictConfig, OmegaConf
from safetensors.torch import load_model
# Length that the longest bounding-box side is scaled to by normalize_bbox.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> list:
    """
    Rescale bounding-box extents so the largest one equals BOUNDING_BOX_MAX_SIZE.

    Every extent is multiplied by the same factor, so the proportions between
    the axes are preserved.

    Args:
        bounding_box_xyz: Extents of the bounding box, one number per axis.
            Any iterable of numbers is accepted.

    Returns:
        list: The rescaled extents, in the same order as the input.

    Raises:
        ValueError: If no extents are given, or if the largest extent is not
            positive (there is nothing meaningful to scale against).
    """
    # Materialize once so one-shot iterables survive both max() and the loop.
    extents = tuple(bounding_box_xyz)
    if not extents:
        raise ValueError("bounding_box_xyz must contain at least one extent")

    max_l = max(extents)
    if max_l <= 0:
        raise ValueError(
            f"Largest bounding box extent must be positive, got {max_l!r}"
        )

    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in extents]
def load_config(cfg_path: str) -> Any:
    """
    Load a configuration file and resolve its interpolations.

    Args:
        cfg_path (str): The path to the configuration file.

    Returns:
        Any: The loaded and resolved configuration object (a DictConfig).

    Raises:
        AssertionError: If the loaded configuration is not an instance of
            DictConfig.
    """
    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)

    # Raised explicitly instead of via `assert` so the check is not stripped
    # when Python runs with -O. AssertionError is kept because it is the
    # exception type this function has always documented.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Expected '{cfg_path}' to load as a DictConfig, "
            f"got {type(cfg).__name__}"
        )
    return cfg
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured configuration from a configuration dictionary.

    The entries of `cfg` are passed to `cfg_type` as keyword arguments, and
    the resulting instance is wrapped with `OmegaConf.structured`.

    Args:
        cfg_type (Any): Callable (typically a class) that accepts the
            entries of `cfg` as keyword arguments.
        cfg (DictConfig): The configuration dictionary to be parsed.

    Returns:
        Any: The structured configuration object created from the dictionary.
    """
    typed_instance = cfg_type(**cfg)
    return OmegaConf.structured(typed_instance)
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place. Loading is non-strict: keys that exist
    only in the model or only in the checkpoint do not raise, but they are
    reported through a warning so a mismatched checkpoint is not silently
    accepted.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None

    Raises:
        AssertionError: If `ckpt_path` does not end with ".safetensors".
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # load_model returns the keys it could not match; with strict=False that
    # report is the only sign that part of the model was left untouched.
    missing, unexpected = load_model(model, ckpt_path, strict=False)
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Non-strict load of '%s': missing keys=%s, unexpected keys=%s",
            ckpt_path,
            sorted(missing),
            sorted(unexpected),
        )
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save model weights in safetensors format.

    Uses `safetensors.torch.save_model`, the counterpart of the `load_model`
    call used for loading in this module. Unlike
    `save_file(model.state_dict(), ...)`, it can save models whose
    parameters share memory (for example tied weights).

    Args:
        model: PyTorch model to save
        save_path: Output path (must end with .safetensors)

    Raises:
        AssertionError: If `save_path` does not end with ".safetensors".
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"

    # Imported locally, as the original did, to keep module import light.
    from safetensors.torch import save_model

    save_model(model, save_path)
def load_model_weights_adaption(model: torch.nn.Module, ckpt_path: str, adaption_path: str) -> torch.nn.Module:
    """
    Load base weights, attach a PEFT adapter and restore extra module weights.

    Steps:
      1. Load the safetensors checkpoint into `model` (non-strict).
      2. Wrap the model with `PeftModel.from_pretrained(model, adaption_path)`.
      3. Load the state dicts stored in `<adaption_path>/unfrozen_weights.pth`
         into the modules `ldr_proj`, `ldr_head`, `dte`, `rte`, `xte`, `yte`
         and `zte` of the wrapped model.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file
        adaption_path: Adapter location passed to `PeftModel.from_pretrained`;
            it must also be a local directory containing
            `unfrozen_weights.pth`.

    Returns:
        torch.nn.Module: The object returned by `PeftModel.from_pretrained`.
            Callers must use this return value rather than the `model` they
            passed in.

    Raises:
        AssertionError: If `ckpt_path` does not end with ".safetensors".
        KeyError: If `unfrozen_weights.pth` lacks one of the expected entries.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # Non-strict load: report skipped keys instead of discarding them.
    missing, unexpected = load_model(model, ckpt_path, strict=False)
    if missing or unexpected:
        logging.getLogger(__name__).warning(
            "Non-strict load of '%s': missing keys=%s, unexpected keys=%s",
            ckpt_path,
            sorted(missing),
            sorted(unexpected),
        )

    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)

    # Deserialize onto the CPU so this also works on hosts without CUDA;
    # load_state_dict copies the values onto each module's own device.
    # NOTE(security): torch.load unpickles the file, so only use adaption
    # directories that come from a trusted source.
    custom_weights = torch.load(
        os.path.join(adaption_path, "unfrozen_weights.pth"),
        map_location=torch.device("cpu"),
    )

    # Each entry of the file is a state dict for the same-named submodule.
    for module_name in ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte"):
        getattr(model, module_name).load_state_dict(custom_weights[module_name])

    return model
def select_device() -> torch.device:
    """
    Selects the appropriate PyTorch device for tensor allocation.

    Preference order: CUDA if available, otherwise MPS if available,
    otherwise the CPU.

    Returns:
        torch.device: The selected device.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")

    # `torch.backends.mps` does not exist on PyTorch builds that predate the
    # MPS backend; treat its absence as "not available" instead of raising
    # AttributeError.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        return torch.device("mps")

    return torch.device("cpu")
def decode_ldr(output_ids: torch.Tensor,):
    """
    Decode `output_ids` into an LDR file (not implemented).

    The original body returned a name that was never defined, so every call
    failed with NameError. The decoding logic is not present in this module;
    until it is added, the function fails explicitly.

    Args:
        output_ids: Token ids to decode.

    Returns:
        Decode ldr file

    Raises:
        NotImplementedError: Always, until the decoding logic is implemented.
    """
    raise NotImplementedError(
        "decode_ldr is a stub: decoding output_ids into an LDR file "
        "has not been implemented"
    )
# End of file.