| import math |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import einops |
| import torch |
| from PIL import Image |
| import torchvision |
| from PIL import Image |
| import models_dinov2 |
|
|
device = "cuda" if torch.cuda.is_available() else "cpu"
patch_size = 14
img_size = 224


# Build the DINOv2-style ViT-L/14 backbone from the project-local module.
# Derive every size from the variables above instead of repeating literals.
feat_extractor = getattr(models_dinov2, 'vit_large')
model = feat_extractor(img_size=img_size,
                       patch_size=patch_size,
                       init_values=1.0,
                       ffn_layer='mlp',
                       block_chunks=0,
                       num_register_tokens=0,
                       interpolate_antialias=False,
                       interpolate_offset=0.1)


checkpoint_path = '/data0/qiyp/Proteus-pytorch/pretrain/ckpt/proteus_vitl_backbone.pth'

# NOTE(review): torch.load on an untrusted checkpoint unpickles arbitrary
# objects; consider weights_only=True if the checkpoint permits it.
checkpoint = torch.load(checkpoint_path, map_location='cpu')


# Checkpoints come in several wrappings; unwrap to the raw parameter dict.
if 'state_dict' in checkpoint:
    pretrained_dict = checkpoint['state_dict']
elif 'model' in checkpoint:
    pretrained_dict = checkpoint['model']
else:
    pretrained_dict = checkpoint


# strict=False tolerates key mismatches; surface them instead of silently
# dropping weights (the original discarded the returned key lists).
missing, unexpected = model.load_state_dict(pretrained_dict, strict=False)
if missing or unexpected:
    print(f"load_state_dict: {len(missing)} missing, {len(unexpected)} unexpected keys")
model.to(device)
model.eval()  # inference only: disable dropout / stochastic depth

patch_h = img_size // patch_size
patch_w = img_size // patch_size
# ViT-Large embed dim is 1024 (the original value, 768, is ViT-Base).
# Unused below but kept for backward compatibility.
feat_dim = 1024
|
|
def visualize_features(features, output_path='./feature_visualization.png'):
    """Save a viridis heat-map grid of a batch of feature maps.

    Args:
        features: (batch, channels, height, width) tensor of patch features.
        output_path: destination image file written via torchvision.

    The per-pixel channel mean is min-max normalized to [0, 1] with a single
    scale shared across the whole batch, colorized with viridis, and saved.
    """
    batch_size, num_features, height, width = features.shape

    # Collapse channels to one saliency value per spatial location.
    vis = features.mean(dim=1)

    # Min-max normalize across the batch; clamp the denominator so a
    # constant feature map cannot divide by zero and produce NaNs.
    vis = vis - vis.min()
    vis = vis / vis.max().clamp_min(1e-12)

    vis = vis.cpu().detach().numpy()

    # Matplotlib colormaps apply elementwise to arrays of any shape and
    # return shape + (4,); drop the alpha channel. This replaces the
    # original per-image Python loop.
    vis_colored = plt.cm.viridis(vis)[..., :3]

    # (B, H, W, 3) -> (B, 3, H, W) for torchvision.
    vis_colored = torch.tensor(vis_colored).permute(0, 3, 1, 2)

    torchvision.utils.save_image(vis_colored, output_path, normalize=True)
|
|
from torchvision import transforms


# Standard ImageNet preprocessing: resize to the model's 224x224 input,
# convert to a [0, 1] float tensor, then normalize with the usual
# ImageNet channel mean/std (matches common ViT/DINOv2 pipelines).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
|
|
# Sample inputs: generated images plus ImageNet validation frames.
image_paths = [
    "/data0/qiyp/mae/imagenet-1k-samples/0-anime_boy_sticker__holding_kitten__happy.png",
    "/data0/qiyp/mae/imagenet-1k-samples/62-Deadpool_minion.png",
    "/data0/qiyp/mae/imagenet-1k-samples/79-with_Wooden_carved_bear__salmon_and_gold_mini_ball_surround_the_blank_signboard__illustrate.png",
    "/data0/qiyp/mae/imagenet-1k-samples/99-Akira_toriyama_motorbike__cheatah__puma__japanese_classic_car__collectable_figure__shiny_plastic_.png",
    "/data0/qiyp/mae/imagenet-1k-samples/124-crowded_1920s_Chicago_street_with_lots_of_model_T_cars_and_people.png",
    "/data0/qiyp/mae/imagenet-1k-samples/157-steampunk_girl_with_pink_hair_riding_in_a_hot_air_balloon__hot_air_balloon_resembles_gold_and_si.png",
    "/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00008636.png",
    "/data0/qiyp/mae/imagenet-1k-samples/ILSVRC2012_val_00010240.png",
]
# Force 3-channel RGB: PNGs may carry an alpha channel (RGBA) or be
# grayscale, either of which breaks Normalize's 3-channel mean/std
# and torch.stack on mismatched shapes.
images = [Image.open(p).convert("RGB") for p in image_paths]

tensors = [transform(img) for img in images]
batched_tensors = torch.stack(tensors).to(device)

# is_training=True appears to make the forward return the dict of token
# features rather than a pooled output — TODO confirm against models_dinov2.
with torch.no_grad():
    outputs = model(batched_tensors, is_training=True)
    features = outputs['x_norm_patchtokens']
    print(features.shape)


# (B, H*W, C) patch tokens -> (B, C, H, W) spatial feature maps.
features = features.view(-1, patch_h, patch_w, features.shape[2])
features = features.permute(0, 3, 1, 2)
visualize_features(features)
|
|
| |