LancetRobotics committed on
Commit
fcec7f0
·
verified ·
1 Parent(s): b2ac48c

Delete visualize_attention.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. visualize_attention.py +0 -178
visualize_attention.py DELETED
@@ -1,178 +0,0 @@
1
import os

# Export endpoint/allocator settings BEFORE importing the libraries that read
# them. huggingface_hub (pulled in by transformers) resolves HF_ENDPOINT when it
# is first imported, so assigning it after `from transformers import ...` has
# no effect. PYTORCH_CUDA_ALLOC_CONF is set up front as well so it is in place
# before any CUDA allocator initialisation.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import warnings

import cv2
import decord
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.v2 as T
from transformers import AutoConfig, AutoModel

# Silence every warning category for the lifetime of the script.
warnings.filterwarnings("ignore")
15
-
16
# ================= Configuration =================
import glob

# glob returns matches in arbitrary (filesystem-dependent) order, so sort the
# list to make the selected sample reproducible across runs and machines.
video_files = sorted(glob.glob("/root/hri30/train/*/*.avi"))

# Pick the clip at index 50 when there are enough clips, otherwise the last
# one. An empty string signals that no video was found.
VIDEO_PATH = video_files[min(50, len(video_files) - 1)] if video_files else ""

CKPT_PATH = "/root/autodl-tmp/checkpoints_final/final_sota_best.pth"
MODEL_ID = "OpenGVLab/VideoMAEv2-giant"
CACHE_DIR = "/root/autodl-tmp/hf_cache"

# Number of frames sampled per clip and the square side they are resized to.
NUM_FRAMES = 16
IMG_SIZE = 224
31
-
32
# ================= Model definition (attention hook) =================
class DualHeadMAE(torch.nn.Module):
    """Wrap the pretrained backbone and capture one module's output per forward.

    The backbone is loaded from ``MODEL_ID`` with ``trust_remote_code=True``.
    A forward hook is attached to the last submodule whose qualified name
    contains ``"attn_drop"``; its detached output is stored in
    ``attention_map``.

    NOTE(review): that output is presumably the attention probabilities of the
    final transformer block -- confirm against the remote model code. If the
    backbone never calls that submodule, ``attention_map`` stays ``None``.
    """

    def __init__(self):
        super().__init__()
        v_config = AutoConfig.from_pretrained(
            MODEL_ID, trust_remote_code=True, cache_dir=CACHE_DIR
        )
        v_config.use_cache = False
        self.visual = AutoModel.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            config=v_config,
            cache_dir=CACHE_DIR,
            torch_dtype=torch.float32,
        )
        # Most recent hooked output; None until a forward pass fires the hook.
        self.attention_map = None
        # Handle returned by register_forward_hook, kept so the hook can be removed.
        self._hook_handle = None
        self._register_hooks()

    def _register_hooks(self):
        """Attach the capture hook to the last ``attn_drop`` submodule, if any."""

        def hook_fn(module, inputs, output):
            self.attention_map = output.detach()

        target_module = None
        # No early break on purpose: the final match in traversal order wins.
        for name, module in self.visual.named_modules():
            if "attn_drop" in name:
                target_module = module

        if target_module is None:
            # Previously silent; callers would only see attention_map == None.
            print("⚠️ No 'attn_drop' module found; attention_map will stay None")
            return

        self._hook_handle = target_module.register_forward_hook(hook_fn)
        print("✅ Hooked Attention Layer")

    def forward(self, x):
        """Run the backbone on ``x`` and return the captured output (or None)."""
        # Reset first so a pass that does not fire the hook cannot hand back a
        # stale tensor captured during an earlier call.
        self.attention_map = None
        _ = self.visual(x)
        return self.attention_map
59
-
60
# ================= Image processing =================
def get_attention_map(model, video_tensor):
    """Run ``model`` on ``video_tensor`` and return a normalised attention row.

    The model is switched to eval mode and called under ``torch.no_grad()``;
    the result is read from ``model.attention_map``.

    Returns:
        ``None`` when the model captured nothing. Otherwise the row at index 0
        of the (head-averaged) attention matrix, min-max scaled to [0, 1] over
        the whole tensor, with shape ``[B, N]``.

    NOTE(review): row 0 is treated as the [CLS] query. Whether this backbone
    actually has a CLS token is not established here -- confirm.
    """
    model.eval()
    with torch.no_grad():
        model(video_tensor)

    att_mat = model.attention_map
    if att_mat is None:
        return None

    # [B, heads, N, N] -> average over heads -> [B, N, N]
    if att_mat.dim() == 4:
        att_mat = att_mat.mean(dim=1)

    # Attention paid by token 0 to every token (itself included): [B, N]
    cls_attn = att_mat[:, 0, :]

    # Min-max normalise. The clamp leaves the usual case untouched and turns a
    # constant map into zeros instead of a 0/0 NaN.
    low = cls_attn.min()
    span = (cls_attn.max() - low).clamp_min(1e-8)
    return (cls_attn - low) / span
86
-
87
def _select_visual_weights(state_dict):
    """Return the checkpoint entries destined for the ``visual`` submodule.

    Keys containing "visual" are kept as-is, keys containing "backbone" have
    "backbone." rewritten to "visual.", and every other key is dropped.
    """
    selected = {}
    for key, value in state_dict.items():
        if "visual" in key:
            selected[key] = value
        elif "backbone" in key:
            selected[key.replace("backbone.", "visual.")] = value
    return selected


def _load_checkpoint_weights(model, ckpt_path):
    """Best-effort load of checkpoint weights into ``model``.

    Any failure is reported and the model keeps the weights it was built with.
    """
    try:
        # NOTE(review): torch.load unpickles the file -- only use trusted checkpoints.
        state_dict = torch.load(ckpt_path, map_location="cpu")
        selected = _select_visual_weights(state_dict)
        if not selected:
            # strict=False would "succeed" on an empty dict; do not claim a load.
            print("⚠️ Random Weights (no 'visual'/'backbone' keys in checkpoint)")
            return
        model.load_state_dict(selected, strict=False)
    except Exception as exc:  # deliberate best-effort boundary; reason is reported
        print(f"⚠️ Random Weights ({exc})")
        return
    print("✅ Weights Loaded")


def _tokens_to_volume(attn_score, num_temporal, fallback_side=14):
    """Reshape a ``[1, N]`` token score row into ``[num_temporal, H, W]``.

    If N is not divisible by ``num_temporal`` the first token is dropped
    (assumed to be a CLS token). When the remaining tokens do not form a square
    grid per temporal slice, the row is linearly resampled to
    ``num_temporal * fallback_side ** 2`` values instead.
    """
    num_tokens = attn_score.shape[1]
    print(f"Tokens: {num_tokens}")

    if num_tokens % num_temporal != 0:
        attn_score = attn_score[:, 1:]
        num_tokens -= 1

    side = int(np.sqrt(num_tokens // num_temporal))
    print(f"Reshaping to [{num_temporal}, {side}, {side}]")

    # Explicit size check replaces a bare try/except around reshape.
    if num_temporal * side * side == num_tokens:
        return attn_score.reshape(num_temporal, side, side)

    print("⚠️ Shape mismatch, forcing interpolation...")
    resampled = F.interpolate(
        attn_score.unsqueeze(0),
        size=num_temporal * fallback_side * fallback_side,
        mode='linear',
    )
    return resampled.reshape(num_temporal, fallback_side, fallback_side)


def _save_overlay_figure(frames_rgb, attn_volume, save_path, frame_indices=(2, 6, 10, 14)):
    """Save a 2-row figure: selected frames on top, attention overlays below.

    ``frames_rgb`` is a uint8 ``[T, H, W, 3]`` RGB array and ``attn_volume`` a
    float ``[T, H, W]`` array; both are indexed with ``frame_indices``.
    """
    fig, axes = plt.subplots(2, len(frame_indices), figsize=(16, 8), squeeze=False)

    for col, frame_idx in enumerate(frame_indices):
        img = np.ascontiguousarray(frames_rgb[frame_idx])
        heat = attn_volume[frame_idx]
        heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)
        heat = np.uint8(255 * heat)
        # applyColorMap yields BGR; convert it to RGB *before* blending with the
        # RGB frame. Converting the blended result instead would swap the
        # frame's red and blue channels.
        heat_rgb = cv2.cvtColor(cv2.applyColorMap(heat, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB)
        overlay = cv2.addWeighted(img, 0.6, heat_rgb, 0.4, 0)

        axes[0, col].imshow(img)
        axes[0, col].axis('off')
        axes[0, col].set_title(f"Frame {frame_idx}")

        axes[1, col].imshow(overlay)
        axes[1, col].axis('off')
        axes[1, col].set_title("Attention")

    plt.tight_layout()
    plt.savefig(save_path)
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)


def visualize(video_path, save_path="attention_vis.png"):
    """Render attention overlays for one video and save them as an image.

    Samples ``NUM_FRAMES`` evenly spaced frames, runs them through
    ``DualHeadMAE`` (with checkpoint weights when they load), turns the
    captured attention row into a ``[NUM_FRAMES, IMG_SIZE, IMG_SIZE]`` volume,
    and writes a figure to ``save_path``.

    Returns ``None`` in all cases. Returns early, with a message, when the
    video is missing or empty, or when no attention was captured.
    """
    if not os.path.exists(video_path):
        print(f"⚠️ Video not found: {video_path}")
        return
    print(f"🎥 Video: {video_path}")

    # Read evenly spaced frames.
    vr = decord.VideoReader(video_path)
    if len(vr) == 0:
        print("⚠️ Video has no frames")
        return
    idx = torch.linspace(0, len(vr) - 1, NUM_FRAMES).long().tolist()
    batch = vr.get_batch(idx).asnumpy()

    # Preprocess: [T, H, W, C] -> [T, C, H, W], resize, ImageNet mean/std.
    frames = torch.from_numpy(batch).permute(0, 3, 1, 2).float()
    resized = T.Resize((IMG_SIZE, IMG_SIZE), antialias=True)(frames)
    mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
    normalized = (resized / 255.0 - mean) / std

    # Use CUDA when present; fall back to CPU rather than crashing.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # [T, C, H, W] -> [1, C, T, H, W]
    input_tensor = normalized.permute(1, 0, 2, 3).unsqueeze(0).to(device)

    # Inference.
    model = DualHeadMAE().to(device)
    _load_checkpoint_weights(model, CKPT_PATH)

    attn_score = get_attention_map(model, input_tensor)  # [1, N] or None
    if attn_score is None:
        print("⚠️ No attention captured; nothing to visualize")
        return

    # Temporal token count is taken as NUM_FRAMES / 2.
    # NOTE(review): assumes a tubelet size of 2 -- confirm against the model config.
    volume = _tokens_to_volume(attn_score, NUM_FRAMES // 2)

    # Upsample the token grid back to the clip resolution.
    volume = F.interpolate(
        volume[None, None],
        size=(NUM_FRAMES, IMG_SIZE, IMG_SIZE),
        mode='trilinear',
    )[0, 0]
    volume = volume.cpu().numpy()

    # Frames for display, resized to match the attention volume.
    display = (
        F.interpolate(frames, size=(IMG_SIZE, IMG_SIZE))
        .permute(0, 2, 3, 1)
        .numpy()
        .astype(np.uint8)
    )

    _save_overlay_figure(display, volume, save_path)
    print(f"✅ Saved: {save_path}")
176
-
177
if __name__ == "__main__":
    # VIDEO_PATH is an empty string when no input video was found.
    if VIDEO_PATH:
        visualize(VIDEO_PATH)
    else:
        print("⚠️ VIDEO_PATH is empty; nothing to visualize")