lvjiameng committed
Commit d593c77 · verified · 1 Parent(s): 0055b11

Upload 10 files

Get_feature.ipynb ADDED
@@ -0,0 +1,395 @@
+ {
+  "cells": [
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "1a160d17",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import torch\n",
+     "import numpy as np\n",
+     "import matplotlib.pyplot as plt\n",
+     "from PIL import Image\n",
+     "import torchvision.transforms as T\n",
+     "from glob import glob\n",
+     "import tqdm\n",
+     "from sklearn.preprocessing import StandardScaler\n",
+     "from scipy.stats import binned_statistic_2d\n",
+     "import umap\n",
+     "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n",
+     "\n",
+     "# --- 1. MAE Model Definition (Assumes models_mae is available) ---\n",
+     "# NOTE: In a public release, you must ensure 'models_mae.py' or equivalent\n",
+     "# model definitions (e.g., from the official MAE repository) are accessible.\n",
+     "try:\n",
+     "    # This line assumes you have a models_mae.py file defining the MAE architecture\n",
+     "    import models_mae\n",
+     "except ImportError:\n",
+     "    print(\"Warning: 'models_mae.py' not found. Please ensure it's in the path or replace this section with your model definition.\")\n",
+     "    # Define a dummy module for demonstration if models_mae is missing\n",
+     "    class DummyModel:\n",
+     "        def __init__(self):\n",
+     "            pass\n",
+     "        def to(self, device):\n",
+     "            return self\n",
+     "        def load_state_dict(self, state_dict, strict):\n",
+     "            pass\n",
+     "        def eval(self):\n",
+     "            pass\n",
+     "        def forward_encoder(self, x, mask_ratio):\n",
+     "            # Simulate a feature tensor of shape (1, 197, 768) for the base model\n",
+     "            B, C, H, W = x.shape\n",
+     "            dummy_features = torch.randn(B, 197, 768)\n",
+     "            return dummy_features, None, None  # Features, mask, ids_restore\n",
+     "\n",
+     "    class DummyModelsMae:\n",
+     "        def mae_vit_base_patch16(self):\n",
+     "            return DummyModel()\n",
+     "\n",
+     "    models_mae = DummyModelsMae()\n",
+     "\n",
+     "\n",
+     "def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):\n",
+     "    \"\"\"\n",
+     "    Loads the MAE model and its pre-trained weights.\n",
+     "    \"\"\"\n",
+     "    # 1. Instantiate the model architecture\n",
+     "    model = getattr(models_mae, arch)()\n",
+     "\n",
+     "    # 2. Load the checkpoint\n",
+     "    # Note: Using map_location='cpu' ensures it loads even if a GPU is not available initially.\n",
+     "    checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
+     "\n",
+     "    # 3. Clean up the state dictionary keys (e.g., removing 'module.' prefix from DataParallel)\n",
+     "    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}\n",
+     "\n",
+     "    # 4. Move model to GPU (if available) and load weights\n",
+     "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+     "    print(f\"Using device: {device}\")\n",
+     "    model = model.to(device)\n",
+     "    model.load_state_dict(state_dict, strict=False)\n",
+     "\n",
+     "    # 5. Set model to evaluation mode\n",
+     "    model.eval()\n",
+     "    return model, device\n",
+     "\n",
+     "# --- 2. Image Loading and Preprocessing ---\n",
+     "\n",
+     "def load_image(image_path, scale_size=256, target_size=224):\n",
+     "    \"\"\"\n",
+     "    Loads and preprocesses an image for MAE input.\n",
+     "\n",
+     "    Note: The normalization values here [0.1689, 0.1536, 0.1516] and [0.1284, 0.0963, 0.1051]\n",
+     "    appear to be custom statistics, likely calculated from a specific dataset\n",
+     "    (like an astronomical image dataset, given the chkpt_dir path).\n",
+     "    Standard ImageNet stats are typically [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225].\n",
+     "    \"\"\"\n",
+     "    transform = T.Compose([\n",
+     "        T.Resize(scale_size, interpolation=T.InterpolationMode.BICUBIC),\n",
+     "        T.CenterCrop(target_size),\n",
+     "        T.ToTensor(),\n",
+     "        # Custom normalization based on the dataset the MAE was trained on\n",
+     "        T.Normalize(mean=[0.1689, 0.1536, 0.1516], std=[0.1284, 0.0963, 0.1051])\n",
+     "    ])\n",
+     "    image = Image.open(image_path).convert('RGB')\n",
+     "    return transform(image)\n",
+     "\n",
+     "# --- 3. Feature Extraction ---\n",
+     "\n",
+     "def extract_features(model, image_path, device):\n",
+     "    \"\"\"\n",
+     "    Passes an image through the MAE encoder to get the feature representation.\n",
+     "\n",
+     "    Returns the Class Token (CLS) feature vector.\n",
+     "    \"\"\"\n",
+     "    img = load_image(image_path).unsqueeze(0)  # Add batch dimension\n",
+     "    img = img.to(device)\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        # model.forward_encoder returns (features, mask, ids_restore)\n",
+     "        # We set mask_ratio=0 to ensure we encode the full image\n",
+     "        x, _, _ = model.forward_encoder(img.float(), mask_ratio=0)\n",
+     "\n",
+     "    # x is typically of shape [B, L, D] where B=Batch, L=Num_Tokens (e.g., 197), D=Feature_Dim (e.g., 768)\n",
+     "    # The CLS token is the first token (index 0).\n",
+     "    # We extract the CLS token feature vector x[0, 0, :]\n",
+     "    return x[0, 0, :].cpu().numpy()\n",
+     "\n",
+     "# --- 4. Dimensionality Reduction (UMAP) ---\n",
+     "\n",
+     "def perform_umap(features, n_components=2, random_state=42):\n",
+     "    \"\"\"\n",
+     "    Scales features and applies UMAP to reduce dimensions to 2D.\n",
+     "    \"\"\"\n",
+     "    print(\"Scaling features...\")\n",
+     "    scaler = StandardScaler()\n",
+     "    features_scaled = scaler.fit_transform(features)\n",
+     "\n",
+     "    print(\"Applying UMAP...\")\n",
+     "    reducer = umap.UMAP(n_components=n_components, random_state=random_state)\n",
+     "    return reducer.fit_transform(features_scaled)\n",
+     "\n",
+     "# --- 5. Main Workflow ---\n",
+     "\n",
+     "def embedding_feature(chkpt_dir, image_dir):\n",
+     "    \"\"\"\n",
+     "    Main function to load the model, extract features, and perform UMAP.\n",
+     "    \"\"\"\n",
+     "    # 1. Find all image files\n",
+     "    image_path_list = glob(os.path.join(image_dir, \"*.jpg\"))\n",
+     "    print(f\"Number of images to process: {len(image_path_list)}\")\n",
+     "\n",
+     "    # 2. Prepare the model and device\n",
+     "    model, device = prepare_model(chkpt_dir)\n",
+     "\n",
+     "    # 3. Extract features\n",
+     "    features, image_paths = [], []\n",
+     "    print(\"Extracting features...\")\n",
+     "    for image_path in tqdm.tqdm(image_path_list):\n",
+     "        try:\n",
+     "            features.append(extract_features(model, image_path, device))\n",
+     "            image_paths.append(image_path)\n",
+     "        except Exception as e:\n",
+     "            # Handle potential file read errors or other issues\n",
+     "            print(f\"Skipping image {image_path} due to error: {e}\")\n",
+     "            continue\n",
+     "\n",
+     "    features = np.array(features)\n",
+     "\n",
+     "    # 4. Perform UMAP\n",
+     "    embedding = perform_umap(features)\n",
+     "\n",
+     "    return embedding, image_paths\n",
+     "\n",
+     "# --- 6. Visualization Functions (Grid Plotting) ---\n",
+     "\n",
+     "def plt_image(components, image_paths, save_path='umap_test.png', nx=100, ny=100):\n",
+     "    \"\"\"\n",
+     "    Creates a detailed visualization where each bin in the UMAP space\n",
+     "    is represented by a single, randomly selected image from that bin.\n",
+     "\n",
+     "    This is often called a 'datamap' or 'tile plot'.\n",
+     "    \"\"\"\n",
+     "    print(f\"Creating a {nx}x{ny} tile plot visualization...\")\n",
+     "    z_emb = components\n",
+     "    iseed = 13579  # Fixed seed for reproducible random selection\n",
+     "\n",
+     "    # 1. Define bounds for the UMAP space\n",
+     "    xmin, xmax = z_emb[:, 0].min(), z_emb[:, 0].max()\n",
+     "    ymin, ymax = z_emb[:, 1].min(), z_emb[:, 1].max()\n",
+     "\n",
+     "    # 2. Define bins\n",
+     "    binx = np.linspace(xmin, xmax, nx + 1)\n",
+     "    biny = np.linspace(ymin, ymax, ny + 1)\n",
+     "\n",
+     "    # 3. Bin the UMAP coordinates (x and y)\n",
+     "    ret = binned_statistic_2d(z_emb[:, 0], z_emb[:, 1], z_emb[:, 1], 'count',\n",
+     "                              bins=[binx, biny], expand_binnumbers=True)\n",
+     "    z_emb_bins = ret.binnumber.T  # Transposed bin numbers (ix, iy) for each point\n",
+     "\n",
+     "    inds_lin = np.arange(z_emb.shape[0])\n",
+     "    inds_used = []  # Indices of images selected for plotting\n",
+     "    plotq = []  # Subplot positions for the selected images\n",
+     "\n",
+     "    # 4. Select one image per populated bin\n",
+     "    for ix in tqdm.tqdm(range(nx), desc=\"Selecting images for grid\"):\n",
+     "        for iy in range(ny):\n",
+     "            # Find all image indices (inds) that fall into the current bin (ix, iy)\n",
+     "            # Bin numbers are 1-based, so we check for ix+1 and iy+1\n",
+     "            dm = (z_emb_bins[:, 0] == ix + 1) & (z_emb_bins[:, 1] == iy + 1)\n",
+     "            inds = inds_lin[dm]\n",
+     "\n",
+     "            if len(inds) > 0:\n",
+     "                # Use a fixed seed based on bin location for reproducible random choice\n",
+     "                np.random.seed(ix * ny + iy + iseed)\n",
+     "                ind_plt = np.random.choice(inds)\n",
+     "\n",
+     "                inds_used.append(ind_plt)\n",
+     "                # Calculate the 1D index for the subplot: (row * num_cols) + col + 1\n",
+     "                plotq.append(iy * nx + ix)  # Adjusted from original: ix + iy * nx\n",
+     "\n",
+     "    # 5. Create the plot\n",
+     "    print(f\"Plotting {len(inds_used)} images...\")\n",
+     "    fig = plt.figure(figsize=(20, 20))\n",
+     "\n",
+     "    for index, i in enumerate(inds_used):\n",
+     "        image_path = image_paths[i]\n",
+     "        if os.path.exists(image_path):\n",
+     "            try:\n",
+     "                img_jpg = plt.imread(image_path)\n",
+     "                # Subplot index is 1-based. Use plotq[index] + 1\n",
+     "                plt.subplot(nx, ny, plotq[index] + 1)\n",
+     "                plt.xticks([])\n",
+     "                plt.yticks([])\n",
+     "                plt.imshow(img_jpg)\n",
+     "            except Exception as e:\n",
+     "                print(f\"Error loading or plotting image {image_path}: {e}\")\n",
+     "        else:\n",
+     "            print(f\"File does not exist: {image_path}\")\n",
+     "\n",
+     "    plt.suptitle(f\"MAE Features Embedded with UMAP ({len(inds_used)} images in {nx}x{ny} grid)\", fontsize=24)\n",
+     "    plt.tight_layout(rect=[0, 0, 1, 0.98])  # Adjust layout for suptitle\n",
+     "    plt.savefig(save_path)\n",
+     "    print(f\"Visualization saved to {save_path}\")\n",
+     "    plt.close(fig)  # Use close(fig) instead of just close() if running in a loop or non-interactive environment\n",
+     "    # Note: plt.show() is omitted for standard command-line use.\n",
+     "\n",
+     "# --- 7. Execution Block ---\n",
+     "\n",
+     "if __name__ == '__main__':\n",
+     "    # Configuration - REPLACE THESE WITH YOUR ACTUAL PATHS\n",
+     "    # The checkpoint path of the pre-trained MAE model\n",
+     "    chkpt_dir = './ckpt/norm_ckpt/base/weights/best/epoch_777_loss_0.7236/ckpt.pth'\n",
+     "    # The directory containing the images (*.jpg) for feature extraction\n",
+     "    image_dir = './dataset'\n",
+     "\n",
+     "    if not os.path.exists(chkpt_dir):\n",
+     "        print(f\"Error: Checkpoint file not found at {chkpt_dir}. Please update the chkpt_dir variable.\")\n",
+     "    elif not os.path.exists(image_dir):\n",
+     "        print(f\"Error: Image directory not found at {image_dir}. Please update the image_dir variable.\")\n",
+     "    else:\n",
+     "        # Perform feature extraction and UMAP reduction\n",
+     "        embedding, image_paths = embedding_feature(chkpt_dir, image_dir)\n",
+     "\n",
+     "        # Create and save the UMAP visualization\n",
+     "        plt_image(embedding, image_paths, save_path='mae_umap_visualization.png')"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "e5928015",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "import os\n",
+     "import torch\n",
+     "import numpy as np\n",
+     "from PIL import Image\n",
+     "import torchvision.transforms as T\n",
+     "# Assumes the models_mae module contains your MAE model definition\n",
+     "import models_mae\n",
+     "\n",
+     "# Custom normalization values (keep these consistent with your training)\n",
+     "CUSTOM_MEAN = [0.1689, 0.1536, 0.1516]\n",
+     "CUSTOM_STD = [0.1284, 0.0963, 0.1051]\n",
+     "TARGET_SIZE = 224  # Default input size for ViT-Base\n",
+     "\n",
+     "# --- 1. Model preparation (same as the original code) ---\n",
+     "def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):\n",
+     "    # ... (identical to the prepare_model function in the original code) ...\n",
+     "    try:\n",
+     "        model = getattr(models_mae, arch)()\n",
+     "    except AttributeError:\n",
+     "        print(f\"Error: architecture '{arch}' not found in the models_mae module. Please make sure models_mae.py exists.\")\n",
+     "        return None, None\n",
+     "\n",
+     "    checkpoint = torch.load(chkpt_dir, map_location='cpu')\n",
+     "    state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}\n",
+     "\n",
+     "    device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+     "    print(f\"Using device: {device}\")\n",
+     "    model = model.to(device)\n",
+     "\n",
+     "    model.load_state_dict(state_dict, strict=False)\n",
+     "    model.eval()\n",
+     "    return model, device\n",
+     "\n",
+     "# --- 2. Image loading and preprocessing (the key modification) ---\n",
+     "def load_and_preprocess_image(image_path, target_size=TARGET_SIZE):\n",
+     "    \"\"\"\n",
+     "    Modified image loading function:\n",
+     "    - no rescaling to scale_size (Resize)\n",
+     "    - no center crop (CenterCrop)\n",
+     "    - instead, the image is resized directly to the model input size (target_size x target_size).\n",
+     "    \"\"\"\n",
+     "    transform = T.Compose([\n",
+     "        # Force the image to (target_size, target_size); this may stretch it\n",
+     "        T.Resize((target_size, target_size), interpolation=T.InterpolationMode.BICUBIC),\n",
+     "        T.ToTensor(),\n",
+     "        T.Normalize(mean=CUSTOM_MEAN, std=CUSTOM_STD)\n",
+     "    ])\n",
+     "\n",
+     "    if not os.path.exists(image_path):\n",
+     "        raise FileNotFoundError(f\"Image file not found: {image_path}\")\n",
+     "\n",
+     "    image = Image.open(image_path).convert('RGB')\n",
+     "    return transform(image)\n",
+     "\n",
+     "# --- 3. Core feature-extraction function (same as the original code) ---\n",
+     "def extract_single_feature(model, image_path, device):\n",
+     "    \"\"\"\n",
+     "    Extracts the feature vector (CLS token) of a single image.\n",
+     "    \"\"\"\n",
+     "    img_tensor = load_and_preprocess_image(image_path).unsqueeze(0)  # [1, C, H, W]\n",
+     "    img_tensor = img_tensor.to(device)\n",
+     "\n",
+     "    with torch.no_grad():\n",
+     "        x, _, _ = model.forward_encoder(img_tensor.float(), mask_ratio=0)\n",
+     "        feature_vector = x[0, 0, :].cpu().numpy()\n",
+     "\n",
+     "    return feature_vector\n",
+     "\n",
+     "# --- 4. Example usage ---\n",
+     "if __name__ == '__main__':\n",
+     "    # Replace with your own model and image paths\n",
+     "    CHKPT_DIR = './ckpt/norm_ckpt/base/weights/best/epoch_777_loss_0.7236/ckpt.pth'\n",
+     "    SINGLE_IMAGE_PATH = './dataset/example_image_001.jpg'\n",
+     "\n",
+     "    print(\"--- MAE single-image feature extraction started (no center crop / pre-resize) ---\")\n",
+     "\n",
+     "    try:\n",
+     "        # 1. Prepare the model\n",
+     "        model, device = prepare_model(CHKPT_DIR)\n",
+     "        if model is None:\n",
+     "            exit()\n",
+     "\n",
+     "        # 2. Extract the feature\n",
+     "        feature = extract_single_feature(model, SINGLE_IMAGE_PATH, device)\n",
+     "\n",
+     "        # 3. Show the results\n",
+     "        print(\"\\n--- Feature extraction results ---\")\n",
+     "        print(f\"Image path: {SINGLE_IMAGE_PATH}\")\n",
+     "        print(f\"ViT input size: {TARGET_SIZE}x{TARGET_SIZE} (obtained by direct resizing)\")\n",
+     "        print(f\"Feature vector shape: {feature.shape}\")\n",
+     "        print(f\"Feature vector (first 5 values): {feature[:5]}\")\n",
+     "        print(\"--- Feature extraction finished ---\")\n",
+     "\n",
+     "    except FileNotFoundError as e:\n",
+     "        print(f\"Fatal error: {e}\")\n",
+     "    except Exception as e:\n",
+     "        print(f\"An unexpected error occurred: {e}\")"
+    ]
+   },
+   {
+    "cell_type": "markdown",
+    "id": "f641d493",
+    "metadata": {},
+    "source": []
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "base",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.12.4"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
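
Note on the custom normalization: load_image above normalizes with dataset-specific per-channel statistics (mean [0.1689, 0.1536, 0.1516], std [0.1284, 0.0963, 0.1051]) rather than the ImageNet values. A minimal sketch of how such statistics could be recomputed from a folder of training JPEGs is shown below; the helper name compute_channel_stats and the ./dataset path are illustrative assumptions, not part of this upload.

import numpy as np
from glob import glob
from PIL import Image

def compute_channel_stats(image_dir, pattern="*.jpg"):
    # Accumulate per-channel pixel sums (values scaled to [0, 1]) over all images.
    n_pixels = 0
    s = np.zeros(3, dtype=np.float64)
    s2 = np.zeros(3, dtype=np.float64)
    for path in glob(f"{image_dir}/{pattern}"):
        arr = np.asarray(Image.open(path).convert("RGB"), dtype=np.float64) / 255.0
        n_pixels += arr.shape[0] * arr.shape[1]
        s += arr.sum(axis=(0, 1))
        s2 += (arr ** 2).sum(axis=(0, 1))
    mean = s / n_pixels
    std = np.sqrt(s2 / n_pixels - mean ** 2)
    return mean, std

# Hypothetical usage: mean, std = compute_channel_stats("./dataset")
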
models_mae.py ADDED
@@ -0,0 +1,242 @@
+ import torch
+ import torch.nn as nn
+ from functools import partial
+ from timm.models.vision_transformer import PatchEmbed, Block
+ from util.pos_embed import get_2d_sincos_pos_embed
+
+
+ class MaskedAutoEncoderViT(nn.Module):
+     """ Masked Autoencoder with VisionTransformer backbone
+     """
+     def __init__(self, img_size=224, patch_size=16, in_chans=3,
+                  embed_dim=1024, depth=24, num_heads=16,
+                  decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+                  mlp_ratio=4.0, norm_layer=nn.LayerNorm, norm_pix_loss=False):
+         super().__init__()
+
+         # MAE encoder specifics
+         self.patch_embed = PatchEmbed(img_size, patch_size, in_chans, embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.blocks = nn.ModuleList([
+             Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(depth)])
+         self.norm = norm_layer(embed_dim)
+
+         # MAE decoder specifics
+         self.decoder_embed = nn.Linear(embed_dim, decoder_embed_dim, bias=True)
+
+         self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
+
+         self.decoder_pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, decoder_embed_dim), requires_grad=False)  # fixed sin-cos embedding
+
+         self.decoder_blocks = nn.ModuleList([
+             Block(decoder_embed_dim, decoder_num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer)
+             for i in range(decoder_depth)
+         ])
+
+         self.decoder_norm = norm_layer(decoder_embed_dim)
+         self.decoder_pred = nn.Linear(decoder_embed_dim, patch_size**2 * in_chans, bias=True)  # decoder to patch
+
+         self.norm_pix_loss = norm_pix_loss
+
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
+
+         decoder_pos_embed = get_2d_sincos_pos_embed(self.decoder_pos_embed.shape[-1], int(self.patch_embed.num_patches**.5), cls_token=True)
+         self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
+
+         w = self.patch_embed.proj.weight.data
+         torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
+
+         torch.nn.init.normal_(self.cls_token, std=.02)
+         torch.nn.init.normal_(self.mask_token, std=.02)
+
+         self.apply(self._init_weights)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             torch.nn.init.xavier_uniform_(m.weight)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+         elif isinstance(m, nn.LayerNorm):
+             nn.init.constant_(m.bias, 0)
+             nn.init.constant_(m.weight, 1.0)
+
+     def random_masking(self, x, mask_ratio):
+         """
+         Perform per-sample random masking by per-sample shuffling.
+         Per-sample shuffling is done by argsort random noise.
+         x: [N, L, D], sequence
+         """
+         N, L, D = x.shape  # batch, length, dim
+         len_keep = int(L * (1 - mask_ratio))
+
+         noise = torch.rand(N, L, device=x.device)  # noise in [0, 1]
+
+         ids_shuffle = torch.argsort(noise, dim=1)  # ascend: small is keep, large is remove
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         mask = torch.ones([N, L], device=x.device)
+         mask[:, :len_keep] = 0
+         mask = torch.gather(mask, dim=1, index=ids_restore)
+
+         return x_masked, mask, ids_restore
+
+     def patchify(self, imgs):
+         """
+         imgs: (N, 3, H, W)
+         x: (N, L, patch_size**2 *3)
+         """
+         p = self.patch_embed.patch_size[0]
+         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
+
+         h = w = imgs.shape[2] // p
+         x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+         x = torch.einsum('nchpwq->nhwpqc', x)
+         x = x.reshape(shape=(imgs.shape[0], h*w, p**2*3))
+         return x
+
+     def unpatchify(self, x):
+         """
+         x: (N, L, patch_size**2 *3)
+         imgs: (N, 3, H, W)
+         """
+         p = self.patch_embed.patch_size[0]
+         h = w = int(x.shape[1]**0.5)
+         assert h * w == x.shape[1]
+
+         x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+         x = torch.einsum('nhwpqc->nchpwq', x)
+         imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+         return imgs
+
+     def forward_encoder(self, x, mask_ratio):
+         x = self.patch_embed(x)
+
+         x = x + self.pos_embed[:, 1:, :]
+
+         x, mask, ids_restore = self.random_masking(x, mask_ratio)
+
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x), dim=1)
+
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+
+         return x, mask, ids_restore
+
+     def forward_decoder(self, x, ids_restore):
+         x = self.decoder_embed(x)
+
+         mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
+         x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1)  # no cls token
+         x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]))  # unshuffle
+         x = torch.cat([x[:, :1, :], x_], dim=1)  # append cls token
+
+         x = x + self.decoder_pos_embed
+
+         for blk in self.decoder_blocks:
+             x = blk(x)
+         x = self.decoder_norm(x)
+
+         x = self.decoder_pred(x)
+
+         x = x[:, 1:, :]
+
+         return x
+
+     def forward_loss(self, imgs, pred, mask):
+         """
+         imgs: [N, 3, H, W]
+         pred: [N, L, p*p*3]
+         mask: [N, L], 0 is keep, 1 is remove.
+         """
+         target = self.patchify(imgs)
+         if self.norm_pix_loss:
+             mean = target.mean(dim=-1, keepdim=True)
+             var = target.var(dim=-1, keepdim=True)
+             target = (target - mean) / (var + 1.e-6)**0.5
+
+         loss = (pred - target) ** 2
+         loss = loss.mean(dim=-1)  # [N, L], mean loss per patch
+
+         loss = (loss * mask).sum() / mask.sum()  # mean loss on removed patches
+         return loss
+
+     def forward(self, imgs, mask_ratio=0.75):
+         latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
+         pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
+         loss = self.forward_loss(imgs, pred, mask)
+         return loss, pred, mask
+
+     def forward_encoder_with_given_mask(self, x, given_patch_mask):
+         # Encode with a user-supplied patch mask (1 = masked) instead of random masking.
+         x = self.patch_embed(x)  # (N, L, D)
+
+         x = x + self.pos_embed[:, 1:, :]  # (N, L, D)
+
+         N, L, D = x.shape
+         noise = torch.rand(N, L, device=x.device)
+
+         mask_float = given_patch_mask.float()
+         ids_shuffle = torch.argsort(mask_float * (noise.max() + 1) + noise, dim=1)  # (N, L)
+         ids_restore = torch.argsort(ids_shuffle, dim=1)
+
+         len_keep = L - given_patch_mask.sum(dim=1).max().int().item()
+         ids_keep = ids_shuffle[:, :len_keep]
+         x_masked = torch.gather(x, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, D))
+
+         cls_token = self.cls_token + self.pos_embed[:, :1, :]
+         cls_tokens = cls_token.expand(x.shape[0], -1, -1)
+         x = torch.cat((cls_tokens, x_masked), dim=1)
+
+         for blk in self.blocks:
+             x = blk(x)
+         x = self.norm(x)
+
+         return x, given_patch_mask, ids_restore
+
+     def forward_with_given_mask(self, imgs, given_patch_mask):
+         latent, mask, ids_restore = self.forward_encoder_with_given_mask(imgs, given_patch_mask)
+         pred = self.forward_decoder(latent, ids_restore)
+         loss = self.forward_loss(imgs, pred, mask)
+         return loss, pred, mask
+
+
+ def mae_vit_base_patch16(**kwargs):
+     model = MaskedAutoEncoderViT(
+         patch_size=16, embed_dim=768, depth=12, num_heads=12,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def mae_vit_large_patch16(**kwargs):
+     model = MaskedAutoEncoderViT(
+         patch_size=16, embed_dim=1024, depth=24, num_heads=16,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
+
+
+ def mae_vit_huge_patch14(**kwargs):
+     model = MaskedAutoEncoderViT(
+         patch_size=14, embed_dim=1280, depth=32, num_heads=16,
+         decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
+         mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
+     return model
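
For reference, a minimal sketch of how Get_feature.ipynb drives this module: build the ViT-Base MAE, run the encoder with mask_ratio=0 so no patches are dropped, and read the CLS token. Shapes assume the default 224x224 input (196 patches plus 1 CLS token), and the snippet assumes timm and util/pos_embed.py from this upload are importable.

import torch
import models_mae

model = models_mae.mae_vit_base_patch16()
model.eval()

x = torch.randn(1, 3, 224, 224)  # dummy batch with one RGB image
with torch.no_grad():
    tokens, mask, ids_restore = model.forward_encoder(x, mask_ratio=0)

print(tokens.shape)            # torch.Size([1, 197, 768])
cls_feature = tokens[0, 0, :]  # the CLS feature vector used by the notebook
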
util/__pycache__/lr_sched.cpython-38.pyc ADDED
Binary file (2.5 kB).
 
util/__pycache__/misc.cpython-38.pyc ADDED
Binary file (1.35 kB).
 
util/__pycache__/pos_embed.cpython-312.pyc ADDED
Binary file (4.04 kB).
 
util/__pycache__/pos_embed.cpython-38.pyc ADDED
Binary file (2.4 kB).
 
util/__pycache__/pos_embed.cpython-39.pyc ADDED
Binary file (2.37 kB).
 
util/lr_sched.py ADDED
@@ -0,0 +1,98 @@
+ import torch, math
+ import torch.nn as nn
+
+ def adjust_learning_rate(optimizer, epoch, args):
+     """Decay the learning rate with half-cycle cosine after warmup"""
+     if epoch < args.warmup_epochs:
+         lr = args.lr * epoch / args.warmup_epochs
+     else:
+         lr = args.min_lr + (args.lr - args.min_lr) * 0.5 * \
+             (1. + math.cos(math.pi * (epoch - args.warmup_epochs) / (args.epochs - args.warmup_epochs)))
+     for param_group in optimizer.param_groups:
+         if "lr_scale" in param_group:
+             param_group["lr"] = lr * param_group["lr_scale"]
+         else:
+             param_group["lr"] = lr
+     return lr
+
+
+ def param_groups_weight_decay(model: nn.Module, weight_decay=1e-5, no_weight_decay_list=()):
+     no_weight_decay_list = set(no_weight_decay_list)
+     decay = []
+     no_decay = []
+     for name, param in model.named_parameters():
+         if not param.requires_grad:
+             continue
+
+         if param.ndim <= 1 or name.endswith(".bias") or name in no_weight_decay_list:
+             no_decay.append(param)
+         else:
+             decay.append(param)
+
+     return [
+         {'params': no_decay, 'weight_decay': 0.},
+         {'params': decay, 'weight_decay': weight_decay}]
+
+
+ def param_groups_lrd(model, weight_decay=0.05, no_weight_decay_list=[], layer_decay=.75):
+     """
+     Parameter groups for layer-wise lr decay
+     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L58
+     """
+     param_group_names = {}
+     param_groups = {}
+
+     num_layers = len(model.blocks) + 1
+
+     layer_scales = list(layer_decay ** (num_layers - i) for i in range(num_layers + 1))
+
+     for n, p in model.named_parameters():
+         if not p.requires_grad:
+             continue
+
+         # no decay: all 1D parameters and model specific ones
+         if p.ndim == 1 or n in no_weight_decay_list:
+             g_decay = "no_decay"
+             this_decay = 0.
+         else:
+             g_decay = "decay"
+             this_decay = weight_decay
+
+         layer_id = get_layer_id_for_vit(n, num_layers)
+         group_name = "layer_%d_%s" % (layer_id, g_decay)
+
+         if group_name not in param_group_names:
+             this_scale = layer_scales[layer_id]
+
+             param_group_names[group_name] = {
+                 "lr_scale": this_scale,
+                 "weight_decay": this_decay,
+                 "params": [],
+             }
+             param_groups[group_name] = {
+                 "lr_scale": this_scale,
+                 "weight_decay": this_decay,
+                 "params": [],
+             }
+
+         param_group_names[group_name]["params"].append(n)
+         param_groups[group_name]["params"].append(p)
+
+     # print("parameter groups: \n%s" % json.dumps(param_group_names, indent=2))
+
+     return list(param_groups.values())
+
+
+ def get_layer_id_for_vit(name, num_layers):
+     """
+     Assign a parameter with its layer id
+     Following BEiT: https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py#L33
+     """
+     if name in ['cls_token', 'pos_embed']:
+         return 0
+     elif name.startswith('patch_embed'):
+         return 0
+     elif name.startswith('blocks'):
+         return int(name.split('.')[1]) + 1
+     else:
+         return num_layers
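
A small usage sketch for adjust_learning_rate: it expects an args object carrying lr, min_lr, warmup_epochs, and epochs, warms the learning rate up linearly, and then decays it with a half-cycle cosine. The values below are illustrative, not taken from this repository's training configuration.

import argparse
import torch
from util.lr_sched import adjust_learning_rate

args = argparse.Namespace(lr=1.5e-4, min_lr=0.0, warmup_epochs=40, epochs=800)
optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=args.lr)

for epoch in [0, 20, 40, 400, 799]:
    lr = adjust_learning_rate(optimizer, epoch, args)
    print(epoch, f"{lr:.2e}")  # rises to args.lr by epoch 40, then follows the cosine down toward min_lr
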
util/misc.py ADDED
@@ -0,0 +1,40 @@
+ import os, shutil
+ import torch, math
+
+ def colorstr(*input):
+     # Colors a string https://en.wikipedia.org/wiki/ANSI_escape_code, i.e. colorstr('blue', 'hello world')
+     *args, string = input if len(input) > 1 else ('blue', 'bold', input[0])  # color arguments, string
+     colors = {'black': '\033[30m',  # basic colors
+               'red': '\033[31m',
+               'green': '\033[32m',
+               'yellow': '\033[33m',
+               'blue': '\033[34m',
+               'magenta': '\033[35m',
+               'cyan': '\033[36m',
+               'white': '\033[37m',
+               'bright_black': '\033[90m',  # bright colors
+               'bright_red': '\033[91m',
+               'bright_green': '\033[92m',
+               'bright_yellow': '\033[93m',
+               'bright_blue': '\033[94m',
+               'bright_magenta': '\033[95m',
+               'bright_cyan': '\033[96m',
+               'bright_white': '\033[97m',
+               'end': '\033[0m',  # misc
+               'bold': '\033[1m',
+               'underline': '\033[4m'}
+     return ''.join(colors[x] for x in args) + f'{string}' + colors['end']
+
+
+ def SaveCheckpoint(state, last, last_path, best, best_path, is_best):
+     if os.path.exists(last):
+         shutil.rmtree(last)
+     last_path.mkdir(parents=True, exist_ok=True)
+     torch.save(state, os.path.join(last_path, 'ckpt.pth'))
+
+     if is_best:
+         if os.path.exists(best):
+             shutil.rmtree(best)
+         best_path.mkdir(parents=True, exist_ok=True)
+         torch.save(state, os.path.join(best_path, 'ckpt.pth'))
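
A usage sketch for these helpers; the directory names are illustrative. Note that SaveCheckpoint takes both a plain path (last, best), which it removes if present, and a pathlib.Path (last_path, best_path), which it recreates before writing ckpt.pth, so both arguments should point at the same directory.

from pathlib import Path
from util.misc import colorstr, SaveCheckpoint

print(colorstr('green', 'bold', 'training started'))  # bold green text on ANSI terminals

state = {'model': {}, 'epoch': 0}  # placeholder checkpoint state
last_dir, best_dir = Path('weights/last'), Path('weights/best')
SaveCheckpoint(state, str(last_dir), last_dir, str(best_dir), best_dir, is_best=True)
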
util/pos_embed.py ADDED
@@ -0,0 +1,85 @@
+ import numpy as np
+ import torch
+ # --------------------------------------------------------
+ # 2D sine-cosine position embedding
+ # References:
+ # Transformer: https://github.com/tensorflow/models/blob/master/official/nlp/transformer/model_utils.py
+ # MoCo v3: https://github.com/facebookresearch/moco-v3
+ # --------------------------------------------------------
+ def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
+     """
+     grid_size: int of the grid height and width
+     return:
+     pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
+     """
+     grid_h = np.arange(grid_size, dtype=np.float32)
+     grid_w = np.arange(grid_size, dtype=np.float32)
+     grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+     grid = np.stack(grid, axis=0)
+
+     grid = grid.reshape([2, 1, grid_size, grid_size])
+     pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
+     if cls_token:
+         pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
+     return pos_embed
+
+
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
+     assert embed_dim % 2 == 0
+
+     # use half of dimensions to encode grid_h
+     emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
+     emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)
+
+     emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
+     return emb
+
+
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
+     """
+     embed_dim: output dimension for each position
+     pos: a list of positions to be encoded: size (M,)
+     out: (M, D)
+     """
+     assert embed_dim % 2 == 0
+     omega = np.arange(embed_dim // 2, dtype=float)
+     omega /= embed_dim / 2.
+     omega = 1. / 10000**omega  # (D/2,)
+
+     pos = pos.reshape(-1)  # (M,)
+     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
+
+     emb_sin = np.sin(out)  # (M, D/2)
+     emb_cos = np.cos(out)  # (M, D/2)
+
+     emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
+     return emb
+
+
+ # --------------------------------------------------------
+ # Interpolate position embeddings for high-resolution
+ # References:
+ # DeiT: https://github.com/facebookresearch/deit
+ # --------------------------------------------------------
+ def interpolate_pos_embed(model, checkpoint_model):
+     if 'pos_embed' in checkpoint_model:
+         pos_embed_checkpoint = checkpoint_model['pos_embed']
+         embedding_size = pos_embed_checkpoint.shape[-1]
+         num_patches = model.patch_embed.num_patches
+         num_extra_tokens = model.pos_embed.shape[-2] - num_patches
+         # height (== width) for the checkpoint position embedding
+         orig_size = int((pos_embed_checkpoint.shape[-2] - num_extra_tokens) ** 0.5)
+         # height (== width) for the new position embedding
+         new_size = int(num_patches ** 0.5)
+         # class_token and dist_token are kept unchanged
+         if orig_size != new_size:
+             print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
+             extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
+             # only the position tokens are interpolated
+             pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
+             pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
+             pos_tokens = torch.nn.functional.interpolate(
+                 pos_tokens, size=(new_size, new_size), mode='bicubic', align_corners=False)
+             pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
+             new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)
+             checkpoint_model['pos_embed'] = new_pos_embed
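
A quick shape check for the sin-cos table consumed by models_mae.py: ViT-Base with a 224x224 input and 16x16 patches gives a 14x14 grid, so with cls_token=True the embedding has 196 position rows plus one all-zero row reserved for the CLS token.

import numpy as np
from util.pos_embed import get_2d_sincos_pos_embed

pos = get_2d_sincos_pos_embed(embed_dim=768, grid_size=14, cls_token=True)
print(pos.shape)               # (197, 768)
print(np.allclose(pos[0], 0))  # True: the CLS slot is zero-initialized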