# DiffICM/1_feature_extractor/fast_vis_proteus_feats.py
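"""Visualize patch features from a Proteus-distilled DINOv2 ViT.

Loads a ViT-L/14 backbone from a local checkpoint, runs a small batch of
sample images through it, and saves the channel-mean of the normalized
patch tokens as a colormapped image grid.
"""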
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision
from PIL import Image
from torchvision import transforms

import models_dinov2  # local DINOv2 model definitions
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 14
# feat_extractor = getattr(models_dinov2, 'vit_base')
feat_extractor = getattr(models_dinov2, 'vit_large')
model = feat_extractor(img_size=224,
                       patch_size=14,
                       init_values=1.0,
                       ffn_layer='mlp',
                       block_chunks=0,
                       num_register_tokens=0,
                       interpolate_antialias=False,
                       interpolate_offset=0.1)
# checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitb_backbone.pth'  # replace with your actual checkpoint path
checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitl_backbone.pth'  # replace with your actual checkpoint path
# Load the checkpoint; the weights may sit at the top level or under a wrapper key
checkpoint = torch.load(checkpoint_path, map_location='cpu')
if 'state_dict' in checkpoint:
    pretrained_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
    pretrained_dict = checkpoint['model']
else:
    pretrained_dict = checkpoint
# strict=False: only load the weights that match the student backbone
model.load_state_dict(pretrained_dict, strict=False)
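# Optional sanity check (commented out): load_state_dict returns
# (missing_keys, unexpected_keys). With strict=False, mismatched weights are
# skipped silently, so inspecting these lists can catch a wrong checkpoint.
# incompatible = model.load_state_dict(pretrained_dict, strict=False)
# print('missing:', incompatible.missing_keys)
# print('unexpected:', incompatible.unexpected_keys)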
model.to(device)
model.eval()  # inference only
patch_h = 224 // patch_size  # 16 patches per side
patch_w = 224 // patch_size
feat_dim = 1024  # ViT-L/14 embedding dim (768 for the ViT-B variant)
def visualize_features(features, output_path='./feature_visualization.png'):
    """Save a colormapped image grid of the channel-mean of a (B, C, H, W) feature map."""
    batch_size, num_features, height, width = features.shape
    # Average over the channel dimension, then normalize to [0, 1]
    # (normalization is shared across the whole batch)
    vis = features.mean(dim=1, keepdim=True)
    vis = vis - vis.min()
    vis = vis / vis.max()
    # Squeeze the channel dimension
    vis = vis.squeeze(1).cpu().detach().numpy()
    # Apply a colormap (viridis) to convert the scalar maps to RGB
    vis_colored = np.zeros((batch_size, height, width, 3))
    for i in range(batch_size):
        vis_colored[i] = plt.cm.viridis(vis[i])[:, :, :3]  # drop the alpha channel
    # Convert to a (batch, channels, height, width) tensor and save with torchvision
    vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)
    torchvision.utils.save_image(vis_colored, output_path, normalize=True)
# Standard ImageNet preprocessing
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # resize to the model input size
    transforms.ToTensor(),          # convert to a tensor in [0, 1]
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
])
images = [
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png"),
    Image.open("/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png"),
]
# inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
tensors = [transform(img.convert("RGB")) for img in images]  # force 3 channels; some PNGs carry an alpha channel
batched_tensors = torch.stack(tensors).to(device)
with torch.no_grad():
    outputs = model(batched_tensors, is_training=True)
    features = outputs['x_norm_patchtokens']  # (batch_size, num_patches, feat_dim)
    print(features.shape)
    # Fold the patch tokens back into a spatial grid: (B, N, C) -> (B, h, w, C) -> (B, C, h, w)
    features = features.view(-1, patch_h, patch_w, features.shape[2])
    features = features.permute(0, 3, 1, 2)
    visualize_features(features)
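# The visualization is written as an image grid to ./feature_visualization.png,
# one 16x16 colormapped patch-feature map per input image.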
# pooled_output = outputs['x_norm_clstoken']  # pooled CLS states, if needed