{ "cells": [ { "cell_type": "code", "execution_count": null, "id": "1a160d17", "metadata": {}, "outputs": [], "source": [ "import os\n", "import torch\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from PIL import Image\n", "import torchvision.transforms as T\n", "from glob import glob\n", "import tqdm\n", "from sklearn.preprocessing import StandardScaler\n", "from scipy.stats import binned_statistic_2d\n", "import umap\n", "from matplotlib.offsetbox import OffsetImage, AnnotationBbox\n", "\n", "# --- 1. MAE Model Definition (Assumes models_mae is available) ---\n", "# NOTE: In a public release, you must ensure 'models_mae.py' or equivalent \n", "# model definitions (e.g., from the official MAE repository) are accessible.\n", "try:\n", " # This line assumes you have a models_mae.py file defining the MAE architecture\n", " import models_mae \n", "except ImportError:\n", " print(\"Warning: 'models_mae.py' not found. Please ensure it's in the path or replace this section with your model definition.\")\n", " # Define a dummy module for demonstration if models_mae is missing\n", " class DummyModel:\n", " def __init__(self):\n", " pass\n", " def to(self, device):\n", " return self\n", " def load_state_dict(self, state_dict, strict):\n", " pass\n", " def eval(self):\n", " pass\n", " def forward_encoder(self, x, mask_ratio):\n", " # Simulate a feature tensor of shape (1, 197, 768) for base model\n", " B, C, H, W = x.shape\n", " dummy_features = torch.randn(B, 197, 768) \n", " return dummy_features, None, None # Features, mask, ids_restore\n", " \n", " class DummyModelsMae:\n", " def mae_vit_base_patch16(self):\n", " return DummyModel()\n", "\n", " models_mae = DummyModelsMae()\n", "\n", "\n", "def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):\n", " \"\"\"\n", " Loads the MAE model and its pre-trained weights.\n", " \"\"\"\n", " # 1. Instantiate the model architecture\n", " model = getattr(models_mae, arch)()\n", " \n", " # 2. 
Load the checkpoint\n", " # Note: Using map_location='cpu' ensures it loads even if a GPU is not available initially.\n", " checkpoint = torch.load(chkpt_dir, map_location='cpu')\n", " \n", " # 3. Clean up the state dictionary keys (e.g., removing 'module.' prefix from DataParallel)\n", " state_dict = {k.replace('module.', ''): v for k, v in checkpoint['model'].items()}\n", "\n", " # 4. Move model to GPU (if available) and load weights\n", " device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", " print(f\"Using device: {device}\")\n", " model = model.to(device)\n", " model.load_state_dict(state_dict, strict=False)\n", " \n", " # 5. Set model to evaluation mode\n", " model.eval()\n", " return model, device\n", "\n", "# --- 2. Image Loading and Preprocessing ---\n", "\n", "def load_image(image_path, scale_size=256, target_size=224):\n", " \"\"\"\n", " Loads and preprocesses an image for MAE input.\n", " \n", " Note: The normalization values here [0.1689, 0.1536, 0.1516] and [0.1284, 0.0963, 0.1051]\n", " appear to be custom statistics, likely calculated from a specific dataset\n", " (like an astronomical image dataset, given the chkpt_dir path).\n", " Standard ImageNet stats are typically [0.485, 0.456, 0.406] and [0.229, 0.224, 0.225].\n", " \"\"\"\n", " transform = T.Compose([\n", " T.Resize(scale_size, interpolation=T.InterpolationMode.BICUBIC), # Using BICUBIC as specified by '3'\n", " T.CenterCrop(target_size),\n", " T.ToTensor(),\n", " # Custom normalization based on the dataset the MAE was trained on\n", " T.Normalize(mean=[0.1689, 0.1536, 0.1516], std=[0.1284, 0.0963, 0.1051])\n", " ])\n", " image = Image.open(image_path).convert('RGB')\n", " return transform(image)\n", "\n", "# --- 3. 
# --- 3. Feature Extraction ---

def extract_features(model, image_path, device):
    """
    Run one image through the MAE encoder and return its CLS-token feature.

    Returns a 1-D numpy array (e.g. 768-dim for ViT-Base).
    """
    img = load_image(image_path).unsqueeze(0)  # add batch dimension
    img = img.to(device)

    with torch.no_grad():
        # forward_encoder returns (features, mask, ids_restore);
        # mask_ratio=0 ensures the full, unmasked image is encoded.
        x, _, _ = model.forward_encoder(img.float(), mask_ratio=0)

    # x has shape [B, L, D] (B=batch, L=tokens e.g. 197, D=dim e.g. 768).
    # The CLS token is the first token (index 0).
    return x[0, 0, :].cpu().numpy()

# --- 4. Dimensionality Reduction (UMAP) ---

def perform_umap(features, n_components=2, random_state=42):
    """
    Standard-scale the features, then reduce them to `n_components`
    dimensions with UMAP (seeded for reproducibility).
    """
    print("Scaling features...")
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    print("Applying UMAP...")
    reducer = umap.UMAP(n_components=n_components, random_state=random_state)
    return reducer.fit_transform(features_scaled)

# --- 5. Main Workflow ---

def embedding_feature(chkpt_dir, image_dir):
    """
    Load the model, extract CLS features for every *.jpg in `image_dir`,
    and return (2-D UMAP embedding, list of successfully processed paths).

    Raises RuntimeError if no image could be processed.
    """
    # 1. Find all image files.
    image_path_list = glob(os.path.join(image_dir, "*.jpg"))
    print(f"Number of images to process: {len(image_path_list)}")

    # 2. Prepare the model and device.
    model, device = prepare_model(chkpt_dir)

    # 3. Extract features.
    features, image_paths = [], []
    print("Extracting features...")
    for image_path in tqdm.tqdm(image_path_list):
        try:
            features.append(extract_features(model, image_path, device))
            image_paths.append(image_path)
        except Exception as e:
            # Skip unreadable/corrupt files rather than aborting the run.
            print(f"Skipping image {image_path} due to error: {e}")
            continue

    # ROBUSTNESS: fail with a clear message instead of letting UMAP crash
    # on an empty array when nothing could be processed.
    if not features:
        raise RuntimeError(f"No features extracted from '{image_dir}' — check the path and image files.")

    features = np.array(features)

    # 4. Perform UMAP.
    embedding = perform_umap(features)

    return embedding, image_paths

# --- 6. Visualization Functions (Grid Plotting) ---

def plt_image(components, image_paths, save_path='umap_test.png', nx=100, ny=100):
    """
    Tile plot ('datamap'): bin the 2-D embedding into an nx-by-ny grid and
    show one randomly selected image per populated bin.
    """
    print(f"Creating a {nx}x{ny} tile plot visualization...")
    z_emb = components
    iseed = 13579  # fixed seed for reproducible random selection

    # 1. Bounds of the UMAP space.
    xmin, xmax = z_emb[:, 0].min(), z_emb[:, 0].max()
    ymin, ymax = z_emb[:, 1].min(), z_emb[:, 1].max()

    # 2. Bin edges.
    binx = np.linspace(xmin, xmax, nx + 1)
    biny = np.linspace(ymin, ymax, ny + 1)

    # 3. Bin the UMAP coordinates (x and y).
    ret = binned_statistic_2d(z_emb[:, 0], z_emb[:, 1], z_emb[:, 1], 'count',
                              bins=[binx, biny], expand_binnumbers=True)
    z_emb_bins = ret.binnumber.T  # (ix, iy) bin numbers per point, 1-based

    inds_lin = np.arange(z_emb.shape[0])
    inds_used = []  # indices of images selected for plotting
    plotq = []      # 0-based subplot positions for the selected images

    # 4. Select one image per populated bin.
    for ix in tqdm.tqdm(range(nx), desc="Selecting images for grid"):
        for iy in range(ny):
            # binned_statistic_2d bin numbers are 1-based, so compare to +1.
            dm = (z_emb_bins[:, 0] == ix + 1) & (z_emb_bins[:, 1] == iy + 1)
            inds = inds_lin[dm]

            if len(inds) > 0:
                # Seed per bin so the random choice is reproducible.
                np.random.seed(ix * ny + iy + iseed)
                ind_plt = np.random.choice(inds)

                inds_used.append(ind_plt)
                # iy is the grid row and ix the column, so the row-major
                # position is row * n_cols + col = iy * nx + ix.
                plotq.append(iy * nx + ix)

    # 5. Create the plot.
    print(f"Plotting {len(inds_used)} images...")
    fig = plt.figure(figsize=(20, 20))

    for index, i in enumerate(inds_used):
        image_path = image_paths[i]
        if os.path.exists(image_path):
            try:
                img_jpg = plt.imread(image_path)
                # FIX: plotq encodes positions on a grid with ny rows and
                # nx columns (row=iy, col=ix), so the subplot grid must be
                # (ny, nx). The original (nx, ny) only worked when nx == ny.
                plt.subplot(ny, nx, plotq[index] + 1)
                plt.xticks([])
                plt.yticks([])
                plt.imshow(img_jpg)
            except Exception as e:
                print(f"Error loading or plotting image {image_path}: {e}")
        else:
            print(f"File does not exist: {image_path}")

    plt.suptitle(f"MAE Features Embedded with UMAP ({len(inds_used)} images in {nx}x{ny} grid)", fontsize=24)
    plt.tight_layout(rect=[0, 0, 1, 0.98])  # leave room for the suptitle
    plt.savefig(save_path)
    print(f"Visualization saved to {save_path}")
    plt.close(fig)  # free figure memory in looped / non-interactive use
    # Note: plt.show() is intentionally omitted for command-line use.
# --- 7. Execution Block ---

if __name__ == '__main__':
    # Configuration - REPLACE THESE WITH YOUR ACTUAL PATHS
    # The checkpoint of the pre-trained MAE model.
    chkpt_dir = './ckpt/norm_ckpt/base/weights/best/epoch_777_loss_0.7236/ckpt.pth'
    # The directory containing the images (*.jpg) for feature extraction.
    image_dir = './dataset'

    if not os.path.exists(chkpt_dir):
        print(f"Error: Checkpoint file not found at {chkpt_dir}. Please update the chkpt_dir variable.")
    elif not os.path.exists(image_dir):
        print(f"Error: Image directory not found at {image_dir}. Please update the image_dir variable.")
    else:
        # Perform feature extraction and UMAP reduction.
        embedding, image_paths = embedding_feature(chkpt_dir, image_dir)

        # Create and save the UMAP visualization.
        plt_image(embedding, image_paths, save_path='mae_umap_visualization.png')


# ============================ second cell ============================

import os
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
# Assumes the models_mae module contains your MAE model definitions.
import models_mae

# Custom normalization values (must match those used during training).
CUSTOM_MEAN = [0.1689, 0.1536, 0.1516]
CUSTOM_STD = [0.1284, 0.0963, 0.1051]
TARGET_SIZE = 224  # default ViT-Base input size

# --- 1. Model preparation ---
def prepare_model(chkpt_dir, arch='mae_vit_base_patch16'):
    """
    Load the MAE architecture `arch` and its weights from `chkpt_dir`.

    Returns (model, device), or (None, None) when the architecture is not
    found in models_mae.
    """
    try:
        model = getattr(models_mae, arch)()
    except AttributeError:
        print(f"错误: models_mae 模块中找不到架构 '{arch}'。请确保 models_mae.py 文件存在。")
        return None, None

    checkpoint = torch.load(chkpt_dir, map_location='cpu')
    # ROBUSTNESS: support checkpoints that store the state dict directly
    # (no 'model' key) as well as the usual {'model': state_dict} layout;
    # also strip DataParallel's 'module.' prefix.
    raw_state = checkpoint.get('model', checkpoint) if isinstance(checkpoint, dict) else checkpoint
    state_dict = {k.replace('module.', ''): v for k, v in raw_state.items()}

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"使用的设备: {device}")
    model = model.to(device)

    model.load_state_dict(state_dict, strict=False)
    model.eval()
    return model, device

# --- 2. Image loading and preprocessing (key modification) ---
def load_and_preprocess_image(image_path, target_size=TARGET_SIZE):
    """
    Modified image loader:
    - No scale-to-scale_size resize.
    - No center crop.
    - Instead, the image is resized directly to (target_size, target_size),
      which may stretch/distort the aspect ratio.

    Raises FileNotFoundError if `image_path` does not exist.
    """
    transform = T.Compose([
        # Force the image to (target_size, target_size); may stretch it.
        T.Resize((target_size, target_size), interpolation=T.InterpolationMode.BICUBIC),
        T.ToTensor(),
        T.Normalize(mean=CUSTOM_MEAN, std=CUSTOM_STD)
    ])

    if not os.path.exists(image_path):
        raise FileNotFoundError(f"图像文件未找到: {image_path}")

    image = Image.open(image_path).convert('RGB')
    return transform(image)

# --- 3. Core feature extraction ---
def extract_single_feature(model, image_path, device):
    """
    Extract the CLS-token feature vector for a single image.
    """
    img_tensor = load_and_preprocess_image(image_path).unsqueeze(0)  # [1, C, H, W]
    img_tensor = img_tensor.to(device)

    with torch.no_grad():
        # mask_ratio=0 encodes the full image; take the CLS token (index 0).
        x, _, _ = model.forward_encoder(img_tensor.float(), mask_ratio=0)
        feature_vector = x[0, 0, :].cpu().numpy()

    return feature_vector
# --- 4. Example usage ---
if __name__ == '__main__':
    # Replace these with your own model / image paths.
    CHKPT_DIR = './ckpt/norm_ckpt/base/weights/best/epoch_777_loss_0.7236/ckpt.pth'
    SINGLE_IMAGE_PATH = './dataset/example_image_001.jpg'

    print("--- MAE 单张图像特征提取开始 (无中心裁剪/预缩放) ---")

    try:
        # 1. Prepare the model.
        model, device = prepare_model(CHKPT_DIR)
        if model is None:
            # FIX: bare `exit()` is injected by the `site` module and is not
            # guaranteed to exist in every interpreter; it also exited with
            # status 0 on this failure path. Raise SystemExit(1) instead
            # (still uncaught below, since SystemExit is a BaseException).
            raise SystemExit(1)

        # 2. Extract the feature vector.
        feature = extract_single_feature(model, SINGLE_IMAGE_PATH, device)

        # 3. Display the result.
        print("\n--- 特征提取结果 ---")
        print(f"图像路径: {SINGLE_IMAGE_PATH}")
        print(f"ViT 输入尺寸: {TARGET_SIZE}x{TARGET_SIZE} (通过强制缩放获得)")
        print(f"特征向量形状: {feature.shape}")
        print(f"特征向量(前5个值): {feature[:5]}")
        print("--- 特征提取完成 ---")

    except FileNotFoundError as e:
        print(f"致命错误: {e}")
    except Exception as e:
        print(f"发生其他错误: {e}")