0xZohar committed on
Commit
d0d37cd
·
verified ·
1 Parent(s): 92f98bd

Add code/cube3d/training/utils.py

Browse files
Files changed (1) hide show
  1. code/cube3d/training/utils.py +341 -0
code/cube3d/training/utils.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import logging
3
+ from typing import Any, Optional, Tuple
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from omegaconf import DictConfig, OmegaConf
8
+ from safetensors.torch import load_model, save_model
9
+ import matplotlib.pyplot as plt
10
+ import numpy as np
11
+ import seaborn as sns
12
+ from matplotlib.ticker import MaxNLocator
13
+
14
# Length that the longest bounding-box side is scaled to by the helpers below.
# NOTE(review): the origin of the value 1.925 is not visible in this file.
BOUNDING_BOX_MAX_SIZE = 1.925


def normalize_bbox(bounding_box_xyz: Tuple[float, ...]) -> list:
    """
    Rescale a bounding box so that its largest extent equals BOUNDING_BOX_MAX_SIZE.
    Args:
        bounding_box_xyz (Tuple[float, ...]): Extent of the box along each axis.
    Returns:
        list: The rescaled extents, in the same order as the input.
    Raises:
        ValueError: If the input is empty or its largest extent is zero.
    """
    max_l = max(bounding_box_xyz)
    # A zero-sized box cannot be rescaled; fail with a clear message instead of
    # a bare ZeroDivisionError (or silent NaN/inf for NumPy scalars).
    if max_l == 0:
        raise ValueError(
            f"Cannot normalize bounding box {tuple(bounding_box_xyz)}: largest extent is zero"
        )
    return [BOUNDING_BOX_MAX_SIZE * elem / max_l for elem in bounding_box_xyz]
21
+
22
def normalize_bboxs(bounding_box_xyz, max_xyz):
    """
    Rescale a tensor of bounding boxes by caller-supplied maximum extents.

    Computes BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / max_xyz, with max_xyz
    placed on the same device as bounding_box_xyz.
    Args:
        bounding_box_xyz: Tensor of box extents (must expose `.device`).
        max_xyz: Divisor(s); a tensor or anything `torch.tensor` accepts
            (number, list, ...). Must broadcast against bounding_box_xyz.
    Returns:
        The rescaled tensor.
    """
    device = bounding_box_xyz.device
    if isinstance(max_xyz, torch.Tensor):
        # torch.tensor(<tensor>) emits a copy-construct UserWarning; detaching
        # and moving the existing tensor yields the same values without it.
        scale = max_xyz.detach().to(device)
    else:
        scale = torch.tensor(max_xyz, device=device)
    return BOUNDING_BOX_MAX_SIZE * bounding_box_xyz / scale
26
+
27
def load_config(cfg_path: str) -> Any:
    """
    Load and resolve a configuration file.
    Args:
        cfg_path (str): The path to the configuration file.
    Returns:
        Any: The loaded and resolved configuration object.
    Raises:
        AssertionError: If the loaded configuration is not an instance of DictConfig.
    """

    cfg = OmegaConf.load(cfg_path)
    OmegaConf.resolve(cfg)
    # Raised explicitly rather than via `assert`, which is stripped under
    # `python -O`; the exception type is kept so the documented contract holds.
    if not isinstance(cfg, DictConfig):
        raise AssertionError(
            f"Expected a DictConfig from '{cfg_path}', got {type(cfg).__name__}"
        )
    return cfg
42
+
43
+
44
def parse_structured(cfg_type: Any, cfg: DictConfig) -> Any:
    """
    Build a structured OmegaConf config from a plain configuration mapping.
    Args:
        cfg_type (Any): Callable (typically a config class) invoked with the
            entries of `cfg` as keyword arguments.
        cfg (DictConfig): The configuration mapping to convert.
    Returns:
        Any: The result of `OmegaConf.structured` applied to the new instance.
    """
    # Instantiate first so a bad or missing key fails in the constructor call.
    instance = cfg_type(**cfg)
    return OmegaConf.structured(instance)
56
+
57
+
58
def load_model_weights(model: torch.nn.Module, ckpt_path: str) -> None:
    """
    Load a safetensors checkpoint into a PyTorch model.
    The model is updated in place.

    Args:
        model: PyTorch model to load weights into
        ckpt_path: Path to the safetensors checkpoint file

    Returns:
        None

    Raises:
        AssertionError: If ckpt_path does not end with ".safetensors".
    """
    # Explicit check instead of `assert` so the validation also runs under
    # `python -O`; AssertionError is kept for backward compatibility.
    if not ckpt_path.endswith(".safetensors"):
        raise AssertionError(
            f"Checkpoint path '{ckpt_path}' is not a safetensors file"
        )

    load_model(model, ckpt_path)
75
+
76
+
77
def save_model_weights(model: torch.nn.Module, save_path: str) -> None:
    """
    Save a PyTorch model as safetensors format.

    Args:
        model: PyTorch model to save
        save_path: Output path (must end with .safetensors)

    Raises:
        AssertionError: If save_path does not end with ".safetensors", or if
            the file does not exist after saving.
    """
    # Explicit checks instead of `assert` so they also run under `python -O`;
    # AssertionError is kept for backward compatibility.
    if not save_path.endswith(".safetensors"):
        raise AssertionError("Path must end with .safetensors")

    # A bare filename has an empty dirname, and os.makedirs("") raises
    # FileNotFoundError; only create the parent when there is one.
    parent_dir = os.path.dirname(save_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    save_model(model, save_path)

    if not os.path.exists(save_path):
        raise AssertionError(f"Failed to save to {save_path}")
92
+
93
+
94
def select_device() -> Any:
    """
    Pick the PyTorch device to allocate tensors on.

    Preference order: CUDA, then Apple MPS, then CPU.

    Returns:
        Any: The chosen `torch.device` object.
    """
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")
108
+
109
def mask_cross_entropy(p_st, p_ed, p_max, logits, target, shift):
    """
    Cross-entropy over the class window [p_st, p_ed) with per-row class masking.

    For batch row b, every class in the window whose offset from p_st is
    greater than p_max[b] has its logit set to -inf (on a copy of `logits`)
    before the loss is computed.

    Args:
        p_st: First class index of the window.
        p_ed: One past the last class index of the window.
        p_max: 1-D tensor with one entry per batch row: the largest allowed
            class offset (inclusive) inside the window.
        logits: Tensor indexed as [batch, step, class]; not modified.
        target: Tensor indexed as [batch, step, class]; the label of each step
            is the argmax over the window (presumably one-hot - TODO confirm).
        shift: Number of leading target steps to drop.

    Returns:
        Scalar tensor: cross-entropy with the default reduction of
        `F.cross_entropy`.
    """
    # NOTE(review): logits always drop their last step while the target drops
    # its first `shift` steps, so the two only line up for matching lengths
    # (e.g. shift == 1 with equal sequence lengths) - confirm with callers.
    class_ids = torch.arange(p_st, p_ed, device=logits.device)
    upper_bound = p_max.unsqueeze(1) + p_st

    # (batch, window) via broadcasting, then replicated across every step.
    allowed = class_ids.unsqueeze(0) <= upper_bound
    allowed_per_step = allowed.unsqueeze(1).expand(-1, logits.shape[1], -1)

    masked = logits.clone()
    window_view = masked[:, :, p_st:p_ed]  # basic slice -> view into `masked`
    window_view[~allowed_per_step] = float('-inf')

    predictions = masked[:, :-1, p_st:p_ed].permute(0, 2, 1)
    labels = target[:, shift:, p_st:p_ed].argmax(-1)
    return F.cross_entropy(predictions, labels)
124
+
125
def positional_encoding(x, num_freqs):
    """
    Encode values with sines and cosines at power-of-two frequencies.

    Each element of `x` is multiplied by 2**0, 2**1, ..., 2**(num_freqs - 1);
    the sines of those products are followed by their cosines, and the last
    two axes are then merged.

    Args:
        x: Tensor with at least one dimension, shape (..., D).
        num_freqs: Number of frequency bands.

    Returns:
        Tensor of shape (..., D * 2 * num_freqs).
    """
    exponents = torch.arange(num_freqs, device=x.device)
    bands = 2.0 ** exponents  # [num_freqs]
    phase = x[..., None] * bands  # [..., D, num_freqs]
    features = torch.cat((torch.sin(phase), torch.cos(phase)), dim=-1)  # [..., D, 2*num_freqs]
    return torch.flatten(features, start_dim=-2)
131
+
132
def visualize_token_probabilities(
    probs,
    cut_idx,
    sample_idx=0,
    tokens_per_page=10,  # number of tokens drawn on each page
    figsize=(12, 20),  # page size; only the width (figsize[0]) is used
    save_dir=None  # directory for saved images (None shows them instead)
):
    """
    Plot the class-probability distribution of every valid token, page by page.

    Each page stacks up to `tokens_per_page` bar charts, one token per row.
    The most probable class is highlighted and every class with probability
    above 5% is labelled with its value.

    Args:
        probs: Probabilities indexed as (batch_size, seq_len, num_classes);
            a torch.Tensor or a NumPy array.
        cut_idx: End (exclusive) of the valid region. Either a scalar, or a
            per-sample sequence/array/tensor indexed with `sample_idx`.
        sample_idx: Index of the batch sample to visualize.
        tokens_per_page: Number of tokens shown on each page (must be >= 1).
        figsize: Page size; only the width is used, the height is fixed at
            2 units per token row.
        save_dir: Directory the pages are saved to as PNG files. If None the
            pages are shown with `plt.show()` instead.

    Returns:
        list | None: One figure per page (each already closed), or None when
        there is no valid token to show.

    Raises:
        ValueError: If `tokens_per_page` is smaller than 1.
    """
    if tokens_per_page < 1:
        raise ValueError(f"tokens_per_page must be >= 1, got {tokens_per_page}")

    # Convert to a NumPy array
    if isinstance(probs, torch.Tensor):
        probs = probs.cpu().detach().numpy()

    # Probability distribution of the selected sample
    sample_probs = probs[sample_idx]  # (seq_len, num_classes)
    seq_len, num_classes = sample_probs.shape

    # Resolve cut_idx into the length of the valid region
    if isinstance(cut_idx, torch.Tensor):
        cut_idx = cut_idx.cpu().detach().numpy()
    # np.ndim treats Python scalars, NumPy scalars and 0-d arrays alike, so a
    # 0-d tensor no longer fails on indexing.
    cutoff = cut_idx if np.ndim(cut_idx) == 0 else cut_idx[sample_idx]
    valid_length = min(int(cutoff), seq_len)
    valid_probs = sample_probs[:valid_length, :]  # tokens inside the valid region only
    num_valid_tokens = valid_probs.shape[0]

    if num_valid_tokens == 0:
        print(f"警告:没有有效token可显示(有效区域长度:{valid_length})")
        return None

    # Create the output directory if needed (exist_ok avoids a check/create race)
    if save_dir is not None:
        os.makedirs(save_dir, exist_ok=True)

    # Total number of pages (ceiling division)
    total_pages = (num_valid_tokens + tokens_per_page - 1) // tokens_per_page
    print(f"共{num_valid_tokens}个有效token,分为{total_pages}页展示")

    class_indices = np.arange(num_classes)  # identical for every token

    figures = []
    for page in range(total_pages):
        # Token range covered by this page
        start = page * tokens_per_page
        end = min(start + tokens_per_page, num_valid_tokens)
        page_tokens = end - start

        # squeeze=False always yields a 2-D axes array, so a page holding a
        # single token needs no special case.
        fig, axes = plt.subplots(
            page_tokens, 1, figsize=(figsize[0], 2 * page_tokens), squeeze=False
        )
        fig.suptitle(
            f'Token Probability Distributions (Sample {sample_idx}) - Page {page+1}/{total_pages}',
            fontsize=16,
            y=1.02
        )

        for row, token_idx in enumerate(range(start, end)):
            ax = axes[row, 0]
            # Index with the absolute token position: the page-relative row
            # would re-plot the first page's tokens on every page.
            token_probs = valid_probs[token_idx]

            bars = ax.bar(class_indices, token_probs, width=0.8, color='skyblue', edgecolor='black')

            # Highlight the most probable class
            max_prob_idx = np.argmax(token_probs)
            max_prob_value = token_probs[max_prob_idx]
            bars[max_prob_idx].set_color('orange')

            # Label every class whose probability exceeds 5%
            for bar, prob in zip(bars, token_probs):
                if prob > 0.05:
                    ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01,
                            f'{prob:.2f}', ha='center', va='bottom', fontsize=9)

            ax.set_title(
                f'Token {token_idx} (Max: Class {max_prob_idx} = {max_prob_value:.2f})',
                fontsize=11
            )
            # Only the bottom subplot carries the x-axis label
            if row == page_tokens - 1:
                ax.set_xlabel('Class Index')
            ax.set_ylabel('Probability')
            ax.set_ylim(0, 1.1)
            ax.xaxis.set_major_locator(MaxNLocator(integer=True))
            ax.grid(True, alpha=0.3, axis='y')

        fig.tight_layout()
        figures.append(fig)

        # Save or show the page
        if save_dir is not None:
            save_path = os.path.join(save_dir, f'token_probs_page_{page+1}.png')
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"已保存第{page+1}页至: {save_path}")
        else:
            plt.show()
        plt.close(fig)  # close the page to release memory

    return figures
244
+
245
def visualize_max_prob_distribution(
    probs,
    cut_idx=None,  # unused: the input is expected to be filtered already
    sample_idx=0,
    bins=20,
    figsize=(12, 6)
):
    """
    Draw a histogram of the per-token maximum class probability of one sample.

    Args:
        probs: Probabilities indexed as (batch, token, class); a torch.Tensor
            or a NumPy array.
        cut_idx: Accepted for interface compatibility; not used.
        sample_idx: Index of the batch sample to plot.
        bins: Histogram bins, forwarded to `Axes.hist`.
        figsize: Figure size, forwarded to `plt.subplots`.

    Returns:
        The matplotlib figure containing the histogram.
    """
    # Convert to a NumPy array
    if isinstance(probs, torch.Tensor):
        probs = probs.cpu().detach().numpy()

    # Highest class probability of every token of the chosen sample
    peak_probs = np.max(probs[sample_idx], axis=1)

    fig, ax = plt.subplots(figsize=figsize)

    # Histogram over the fixed range [0, 1]
    counts, _edges, rects = ax.hist(
        peak_probs,
        bins=bins,
        range=(0, 1),
        edgecolor='black',
        alpha=0.7,
        color='skyblue'
    )

    # Write the count above every non-empty bin
    for count, rect in zip(counts, rects):
        bar_top = rect.get_height()
        if bar_top <= 0:
            continue
        ax.text(
            rect.get_x() + rect.get_width()/2.,
            bar_top + 0.5,
            f'{int(count)}',
            ha='center',
            va='bottom',
            fontsize=9
        )

    # Summary statistics shown in the title / used for the y-range
    mean_prob = np.mean(peak_probs)
    median_prob = np.median(peak_probs)
    tallest_bin = int(np.max(counts)) if len(counts) > 0 else 0

    ax.set_title(
        f'Distribution of Maximum Probabilities (All Valid Tokens from 5 Iterations)\n'
        f'Total tokens: {len(peak_probs)} | Mean: {mean_prob:.2f} | Median: {median_prob:.2f}',
        fontsize=14
    )
    ax.set_xlabel('Maximum Probability Value (0-1)')
    ax.set_ylabel('Number of Tokens (Frequency)')
    ax.set_xlim(0, 1)
    ax.set_ylim(0, tallest_bin + 2)  # headroom for the count labels
    ax.xaxis.set_major_locator(MaxNLocator(nbins=11))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.grid(True, alpha=0.3, axis='y')

    plt.tight_layout()
    return fig
307
+
308
+
309
def top_k_prob_mask(probs, cut_idx, top_percent=0.15, visualize=False):
    """
    Mark the most confident tokens inside the valid region of a batch.

    The confidence of a token is its highest class probability. Tokens at
    positions below `cut_idx` are valid; among ALL valid tokens of the batch
    taken together (not per sample), the top `top_percent` fraction by
    confidence (at least one token) is selected.

    Args:
        probs: Tensor indexed as (batch_size, seq_len, num_classes).
        cut_idx: Exclusive end of the valid region; a Python number applied
            to every sample, or a 1-D tensor with one entry per sample.
        top_percent: Fraction of valid tokens to select.
        visualize: Accepted but not used by this function.

    Returns:
        Tuple of two bool tensors of shape (batch_size, seq_len):
        the selected tokens, and the valid tokens that were NOT selected.
        Both are False outside the valid region.
    """
    confidence = probs.max(dim=2).values  # (batch_size, seq_len)
    num_rows, num_steps = confidence.shape
    device = confidence.device

    # 1. Valid region: True before cut_idx, False from cut_idx onwards
    if isinstance(cut_idx, (int, float)):
        cut_idx = torch.tensor([cut_idx] * num_rows, device=device)
    positions = torch.arange(num_steps, device=device)
    in_cut = positions.unsqueeze(0) < cut_idx.unsqueeze(1)

    # Nothing valid anywhere: both masks are all-False
    if in_cut.sum().item() == 0:
        empty_mask = torch.zeros_like(confidence, dtype=torch.bool)
        return empty_mask, empty_mask

    # 2. Pick the k most confident valid tokens (k clamped to [1, n_valid])
    candidates = confidence[in_cut]
    n_valid = candidates.numel()
    k = max(min(int(n_valid * top_percent), n_valid), 1)
    top_positions = torch.topk(candidates, k).indices

    chosen = torch.zeros(n_valid, dtype=torch.bool, device=device)
    chosen[top_positions] = True

    # Scatter the flat selection back into the (batch, seq) layout
    original_mask = torch.zeros_like(confidence, dtype=torch.bool)
    original_mask[in_cut] = chosen

    # 3. Complement inside the valid region; outside it stays False
    reverse_mask = torch.zeros_like(confidence, dtype=torch.bool)
    reverse_mask[in_cut] = ~chosen

    return original_mask, reverse_mask