0xZohar committed on
Commit
25cd1c4
·
verified ·
1 Parent(s): cfbc796

Add missing code/cube3d/inference/utils.py

Browse files
Files changed (1) hide show
  1. code/cube3d/inference/utils.py +139 -0
code/cube3d/inference/utils.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from typing import Any, Optional, Tuple
3
+
4
+ import torch
5
+ from omegaconf import DictConfig, OmegaConf
6
+ from safetensors.torch import load_model
7
+
8
# Largest edge length (in model units) a generated bounding box may have.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]):
    """Scale a bounding box so its longest edge equals BOUNDING_BOX_MAX_SIZE.

    The aspect ratio of the box is preserved: every edge is multiplied by
    ``BOUNDING_BOX_MAX_SIZE / max(bounding_box_xyz)``.

    Args:
        bounding_box_xyz: Edge lengths (x, y, z) of the bounding box.
            All values must be positive.

    Returns:
        list[float]: The rescaled edge lengths, in the same order.

    Raises:
        ValueError: If the box is empty or its largest edge is not positive
            (the original code would have raised ZeroDivisionError or
            silently produced sign-flipped output).
    """
    if not bounding_box_xyz:
        raise ValueError("bounding_box_xyz must not be empty")
    max_l = max(bounding_box_xyz)
    if max_l <= 0:
        raise ValueError(f"largest bounding box edge must be positive, got {max_l}")
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
14
+
15
+
16
def load_config(cfg_path: str) -> Any:
    """Read an OmegaConf configuration file and resolve its interpolations.

    Args:
        cfg_path (str): Path to the configuration file on disk.

    Returns:
        Any: The fully resolved configuration object.

    Raises:
        AssertionError: If the loaded configuration is not a DictConfig.
    """
    loaded = OmegaConf.load(cfg_path)
    # Resolve ${...} interpolations in place before handing the config out.
    OmegaConf.resolve(loaded)
    assert isinstance(loaded, DictConfig)
    return loaded
31
+
32
+
33
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """Convert a plain DictConfig into a typed, structured config object.

    The entries of ``cfg`` are unpacked as keyword arguments into
    ``cfg_type`` and the result is wrapped by ``OmegaConf.structured`` so
    downstream code gets schema-validated access.

    Args:
        cfg_type (Any): Dataclass-like type describing the expected schema.
        cfg (DictConfig): Raw configuration values to parse.

    Returns:
        Any: The structured configuration object.
    """
    return OmegaConf.structured(cfg_type(**cfg))
45
+
46
+
47
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """Load a safetensors checkpoint into a PyTorch model.

    The model is updated in place.

    Args:
        model: PyTorch model to load weights into.
        ckpt_path: Path to the safetensors checkpoint file.

    Returns:
        None

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # NOTE(review): strict=False silently skips missing/unexpected keys;
    # presumably intentional for partially-adapted checkpoints — confirm.
    load_model(model, ckpt_path, strict=False)
65
+
66
+
67
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """Serialize a model's weights in safetensors format.

    Args:
        model: PyTorch model whose state dict is written out.
        save_path: Output path (must end with .safetensors).

    Raises:
        AssertionError: If ``save_path`` lacks the .safetensors suffix.
    """
    assert save_path.endswith(".safetensors"), "Path must be .safetensors"

    # Imported lazily so merely importing this module does not require
    # the safetensors save machinery.
    from safetensors.torch import save_file

    save_file(model.state_dict(), save_path)
82
+
83
+
84
def load_model_weights_adaption(model: torch.nn.Module, ckpt_path: str, adaption_path: str) -> torch.nn.Module:
    """Load base safetensors weights, then apply a PEFT adapter and the
    unfrozen custom-module weights saved alongside it.

    Args:
        model: Base PyTorch model to load weights into.
        ckpt_path: Path to the base safetensors checkpoint file.
        adaption_path: Directory holding the PEFT adapter and an
            ``unfrozen_weights.pth`` file with state dicts for the
            modules kept unfrozen during adaptation.

    Returns:
        torch.nn.Module: The PEFT-wrapped model with adapter and custom
        weights applied.

    Raises:
        AssertionError: If ``ckpt_path`` does not end with ``.safetensors``.
    """
    assert ckpt_path.endswith(
        ".safetensors"
    ), f"Checkpoint path '{ckpt_path}' is not a safetensors file"

    # strict=False: the base checkpoint may omit the adapter/custom modules.
    load_model(model, ckpt_path, strict=False)

    from peft import PeftModel

    model = PeftModel.from_pretrained(model, adaption_path)

    # Load onto CPU so this also works on CUDA-less machines (the original
    # hardcoded 'cuda:0'); load_state_dict copies tensors onto each
    # parameter's own device regardless of where they were deserialized.
    # NOTE(review): torch.load unpickles arbitrary objects — only load
    # trusted checkpoints (consider weights_only=True where available).
    custom_weights = torch.load(
        f"{adaption_path}/unfrozen_weights.pth", map_location="cpu"
    )

    # Restore the modules that stayed unfrozen during adaptation.
    for module_name in ("ldr_proj", "ldr_head", "dte", "rte", "xte", "yte", "zte"):
        getattr(model, module_name).load_state_dict(custom_weights[module_name])

    return model
115
+
116
+
117
def select_device() -> Any:
    """Pick the best available PyTorch device.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        Any: The selected ``torch.device`` object.
    """
    if torch.cuda.is_available():
        device_name = "cuda"
    elif torch.backends.mps.is_available():
        device_name = "mps"
    else:
        device_name = "cpu"
    return torch.device(device_name)
131
+
132
def decode_ldr(output_ids: torch.Tensor,):
    """Decode generated token ids into an LDR document.

    Args:
        output_ids: Token ids produced by the generative model.

    Raises:
        NotImplementedError: Always — the decoding logic is not written yet.
    """
    # The original body was `return ldr` with `ldr` never defined, which is
    # a guaranteed NameError. Fail loudly with a clear message instead so
    # callers see the function is an unimplemented stub, not a typo.
    raise NotImplementedError("decode_ldr: LDR decoding is not implemented yet")