Spaces:
Running
on
Zero
Running
on
Zero
File size: 10,255 Bytes
c8b42eb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 |
"""
PCA Visualization of UniCeption Image Encoders
"""
import os
import random
from functools import lru_cache
from typing import Tuple
import numpy as np
import requests
import torch
import torch.nn.functional as F
from matplotlib import pyplot as plt
from PIL import Image
from sklearn.decomposition import PCA
from uniception.models.encoders import *
from uniception.models.encoders.image_normalizations import *
class TestEncoders:
    """
    Render PCA visualizations of UniCeption image encoder features.

    Each configured encoder is built with ``_make_encoder_test``, run on example
    images, and its feature map is projected onto 3 principal components with
    ``get_pca_map``. The projection is saved as a PNG next to the input image(s)
    under ``pca_save_folder``.
    """

    def __init__(self, pca_save_folder, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.pca_save_folder = pca_save_folder
        self.norm_types = IMAGE_NORMALIZATION_DICT.keys()
        self.encoders = [
            "croco",
            "dust3r_224",
            "dust3r_512",
            "dust3r_512_dpt",
            "mast3r_512",
            "dinov2_large",
            "dinov2_large_reg",
            "dinov2_large_dav2",
            "dinov2_giant",
            "dinov2_giant_reg",
            "radio_v2.5-b",
            "radio_v2.5-l",
            "e-radio_v2",
        ]
        # One independent config dict per encoder. `[{}] * n` would alias a single
        # dict, so editing one encoder's config would silently change all of them.
        self.encoder_configs = [{} for _ in self.encoders]

    def inference_encoder(self, encoder, input):
        """Run ``encoder`` on ``input`` and return its raw output.

        Gradients are disabled because the outputs are only visualized; this avoids
        keeping activations alive for an autograd graph that is never used.
        """
        with torch.no_grad():
            return encoder(input)

    def visualize_all_encoders(self, image_size: Tuple[int, int] = (224, 224)):
        """Save a two-image (shared PCA basis) visualization for every configured encoder.

        Args:
            image_size: (height, width) that the example images are resized to.
        """
        for encoder_name, encoder_config in zip(self.encoders, self.encoder_configs):
            encoder = _make_encoder_test(encoder_name, **encoder_config)
            self._visualize_encoder_features_consistency(encoder, image_size)

    @staticmethod
    def _check_features(features, encoder_name):
        """Raise TypeError unless the encoder produced a tensor of features."""
        if not isinstance(features, torch.Tensor):
            raise TypeError(
                f"Encoder {encoder_name} returned features of type {type(features).__name__}, expected torch.Tensor"
            )

    def _save_pca_figure(self, input_image, input_title, pca_image, encoder_name, filename):
        """Plot the input image(s) next to the PCA rendering and save the figure to ``pca_save_folder``."""
        fig, axs = plt.subplots(1, 2, figsize=(12, 6))
        axs[0].imshow(input_image)
        axs[0].set_title(input_title)
        axs[0].axis("off")
        axs[1].imshow(pca_image)
        axs[1].set_title(f"PCA Features of {encoder_name}")
        axs[1].axis("off")
        fig.savefig(os.path.join(self.pca_save_folder, filename), bbox_inches="tight")
        plt.close(fig)

    def _visualize_encoder_features(self, encoder, image_size: Tuple[int, int]):
        """Save the PCA visualization of ``encoder``'s features for a single example image.

        Args:
            encoder: Encoder exposing ``data_norm_type`` and ``name``, returning an object with ``.features``.
            image_size: (height, width) used for the input image and the PCA rendering.
        """
        img, viz_img = self._get_example_input(image_size, encoder.data_norm_type, return_viz_img=True)
        encoder_input: ViTEncoderInput = ViTEncoderInput(
            data_norm_type=encoder.data_norm_type,
            image=img,
        )
        features = self.inference_encoder(encoder, encoder_input).features
        self._check_features(features, encoder.name)
        # Features are channels-first (B, C, h, w); get_pca_map expects channels-last.
        pca_viz = get_pca_map(features.permute(0, 2, 3, 1), image_size, return_pca_stats=False)
        self._save_pca_figure(viz_img, "Input Image", pca_viz, encoder.name, f"pca_{encoder.name}.png")

    def _visualize_encoder_features_consistency(self, encoder, image_size: Tuple[int, int]):
        """Save a PCA visualization of two example images that share one PCA basis.

        Both feature maps are concatenated along their width before the PCA is fitted,
        so matching colors across the two images indicate similar features.

        Args:
            encoder: Encoder exposing ``data_norm_type`` and ``name``, returning an object with ``.features``.
            image_size: (height, width) of each example image.
        """
        viz_images, feature_maps = [], []
        for selection in (1, 2):
            img, viz_img = self._get_example_input(
                image_size, encoder.data_norm_type, img_selection=selection, return_viz_img=True
            )
            encoder_input: ViTEncoderInput = ViTEncoderInput(
                data_norm_type=encoder.data_norm_type,
                image=img,
            )
            features = self.inference_encoder(encoder, encoder_input).features
            self._check_features(features, encoder.name)
            viz_images.append(viz_img)
            feature_maps.append(features)

        # Join along width: dim 3 of the (B, C, h, w) features, dim 1 of the (H, W, 3) images.
        cat_feats = torch.cat(feature_maps, dim=3)
        pca_viz, _pca_stats = get_pca_map(
            cat_feats.permute(0, 2, 3, 1), (image_size[0], image_size[1] * 2), return_pca_stats=True
        )
        cat_imgs = torch.cat(viz_images, dim=1)
        self._save_pca_figure(cat_imgs, "Input Images", pca_viz, encoder.name, f"multi_pca_{encoder.name}.png")

    @staticmethod
    @lru_cache(maxsize=3)
    def _download_image_bytes(url: str) -> bytes:
        """Download ``url`` and return the response body, cached per URL.

        Caching immutable bytes (instead of caching tensors on a bound method) means
        callers always get fresh tensors and the cache never keeps an instance alive.
        """
        response = requests.get(url, timeout=60)
        response.raise_for_status()
        return response.content

    def _get_example_input(
        self,
        image_size: Tuple[int, int],
        image_norm_type: str = "dummy",
        img_selection: int = 1,
        return_viz_img: bool = False,
    ) -> "torch.Tensor | Tuple[torch.Tensor, torch.Tensor]":
        """Load a CroCo example image as a normalized (1, 3, H, W) tensor.

        Args:
            image_size: (height, width) to resize the image to.
            image_norm_type: Key into ``IMAGE_NORMALIZATION_DICT`` selecting mean/std.
            img_selection: Which ``Chateau{img_selection}.png`` example to load.
            return_viz_img: If True, also return the resized uint8 (H, W, 3) image for plotting.
        Returns:
            The normalized image tensor, or ``(image_tensor, viz_image)`` if ``return_viz_img``.
        """
        import io

        url = f"https://raw.githubusercontent.com/naver/croco/d3d0ab2858d44bcad54e5bfc24f565983fbe18d9/assets/Chateau{img_selection}.png"
        with Image.open(io.BytesIO(self._download_image_bytes(url))) as raw_image:
            # Convert before resizing so palette/grayscale images are resampled in RGB.
            image = raw_image.convert("RGB")
        # PIL expects (width, height) while image_size is (height, width).
        height, width = image_size
        image = image.resize((width, height))
        img = torch.from_numpy(np.array(image))
        viz_img = img.clone()
        # Normalize the image. NOTE(review): assumes mean/std broadcast over the trailing RGB axis.
        image_normalization = IMAGE_NORMALIZATION_DICT[image_norm_type]
        img_mean, img_std = image_normalization.mean, image_normalization.std
        img = (img.float() / 255.0 - img_mean) / img_std
        # HWC -> BCHW
        img = img.permute(2, 0, 1).unsqueeze(0)
        if return_viz_img:
            return img, viz_img
        return img
def render_pca_as_rgb(features):
    """
    Perform PCA on the given feature tensor and render the first 3 principal components as RGB.

    The PCA basis is fitted jointly over every pixel of every batch element, and each
    component is min-max scaled to [0, 1] over all of them. Only the first batch
    element is returned.

    Args:
        features (torch.Tensor): Feature tensor of shape (B, C, H, W).
    Returns:
        np.ndarray: RGB image of shape (H, W, 3) with values in [0, 1].
    """
    assert features.dim() == 4, "Input tensor must be 4D (B, C, H, W)"
    B, C, H, W = features.shape
    # (B, C, H, W) -> (B*H*W, C). detach() so tensors that track gradients can be converted to NumPy.
    flat_features = features.detach().permute(0, 2, 3, 1).reshape(-1, C).cpu().numpy()
    pca = PCA(n_components=3)
    principal_components = pca.fit_transform(flat_features)
    # Min-max rescale each component to [0, 1]; a constant component (zero span) maps to 0 instead of NaN.
    component_min = principal_components.min(axis=0)
    component_span = principal_components.max(axis=0) - component_min
    component_span = np.where(component_span > 0, component_span, 1.0)
    principal_components = (principal_components - component_min) / component_span
    # Back to (B, H, W, 3) and keep the first batch element.
    return principal_components.reshape(B, H, W, 3)[0]
def get_robust_pca(features: torch.Tensor, m: float = 2, remove_first_component=False):
    """
    Fit a 3-component PCA projection and robust per-channel color bounds.

    Args:
        features: (N, C) matrix of feature vectors.
        m: Outlier threshold. A projected value counts as an inlier for its channel
            when its absolute deviation from the channel median is less than ``m``
            times that channel's median absolute deviation.
        remove_first_component: If True, project onto an initial PCA basis, min-max
            normalize the first component, keep only rows where it is below 0.2,
            refit the PCA on those rows, and compute the bounds on them.
    Returns:
        (reduction_mat, rgb_min, rgb_max): the (C, 3) projection matrix and the
        per-channel (3,) lower/upper bounds of the projected inliers, cast to the
        dtype/device of ``reduction_mat``. A channel with no inliers (e.g. when its
        median absolute deviation is 0) falls back to that channel's full min/max
        over all rows.
    """
    assert len(features.shape) == 2, "features should be (N, C)"
    # torch.pca_lowrank returns (U, S, V); the columns of V are the principal directions.
    reduction_mat = torch.pca_lowrank(features, q=3, niter=20)[2]
    colors = features @ reduction_mat
    if remove_first_component:
        colors_min = colors.min(dim=0).values
        colors_max = colors.max(dim=0).values
        tmp_colors = (colors - colors_min) / (colors_max - colors_min)
        fg_mask = tmp_colors[..., 0] < 0.2
        reduction_mat = torch.pca_lowrank(features[fg_mask], q=3, niter=20)[2]
        colors = features @ reduction_mat
    else:
        fg_mask = torch.ones_like(colors[:, 0], dtype=torch.bool)

    fg_colors = colors[fg_mask]
    # Median-absolute-deviation outlier score for every value in every channel.
    deviation = torch.abs(fg_colors - torch.median(fg_colors, dim=0).values)
    mad = torch.median(deviation, dim=0).values
    # Where mad == 0 the score is NaN/inf, and such values never count as inliers.
    score = deviation / mad

    lows, highs = [], []
    for channel in range(colors.shape[1]):
        inliers = fg_colors[score[:, channel] < m, channel]
        if inliers.numel() == 0:
            # Use this channel's own full range. Previously every channel collapsed
            # onto the global min/max of the whole `colors` matrix.
            inliers = colors[:, channel]
        lows.append(inliers.min())
        highs.append(inliers.max())
    rgb_min = torch.stack(lows)
    rgb_max = torch.stack(highs)
    return reduction_mat, rgb_min.to(reduction_mat), rgb_max.to(reduction_mat)
def get_pca_map(
    feature_map: torch.Tensor,
    img_size,
    interpolation="bicubic",
    return_pca_stats=False,
    pca_stats=None,
):
    """
    Project a channels-last feature map onto 3 PCA components and render it as RGB.

    Args:
        feature_map: (h, w, C) for a single image, or (B, h, w, C). A 3D input is
            given a leading batch dimension.
        img_size: Output spatial size passed to ``F.interpolate`` as ``size``.
        interpolation: Interpolation mode for ``F.interpolate``.
        return_pca_stats: If True, also return ``(reduct_mat, color_min, color_max)``.
        pca_stats: Optional precomputed ``(reduct_mat, color_min, color_max)``, e.g. to
            color several feature maps with the same projection. If None, it is
            computed with ``get_robust_pca`` over all pixels of ``feature_map``.
    Returns:
        np.ndarray with values in [0, 1]: (H, W, 3) when the batch size is 1,
        otherwise (B, H, W, 3). If ``return_pca_stats``, the tuple
        ``(image, (reduct_mat, color_min, color_max))`` is returned instead.
    """
    if feature_map.dim() == 3:
        # (h, w, C) -> (1, h, w, C)
        feature_map = feature_map[None]
    if pca_stats is None:
        pca_stats = get_robust_pca(feature_map.reshape(-1, feature_map.shape[-1]))
    reduct_mat, color_min, color_max = pca_stats

    # A zero color range (constant channel) would give 0/0 = NaN; map it to 0 instead.
    color_range = color_max - color_min
    color_range = torch.where(color_range > 0, color_range, torch.ones_like(color_range))
    pca_color = (feature_map @ reduct_mat - color_min) / color_range
    pca_color = pca_color.clamp(0, 1)
    pca_color = F.interpolate(
        pca_color.permute(0, 3, 1, 2),
        size=img_size,
        mode=interpolation,
    ).permute(0, 2, 3, 1)
    # Bicubic interpolation can overshoot the [0, 1] range, so clamp again.
    pca_color = pca_color.clamp(0, 1)
    pca_color = pca_color.detach().cpu().numpy()
    if pca_color.shape[0] == 1:
        pca_color = pca_color[0]
    if return_pca_stats:
        return pca_color, (reduct_mat, color_min, color_max)
    return pca_color
def seed_everything(seed=42):
    """
    Seed Python's `random`, NumPy and PyTorch, and force deterministic cuDNN.

    The seed is also exported as the PYTHONHASHSEED environment variable. The
    running interpreter fixes its hash randomization at startup, so that export
    only affects processes launched afterwards.

    Parameters:
    - seed: A hashable seed value (default 42)
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    # Trade cuDNN autotuning speed for reproducible kernel selection.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    print(f"Seed set to: {seed} (type: {type(seed)})")
if __name__ == "__main__":
    # Disable XFormers so the encoders can run on CPU.
    # NOTE(review): this is set after the encoder modules are imported at the top of
    # the file; if the flag is only read at import time it has no effect here — confirm.
    os.environ["XFORMERS_DISABLED"] = "1"

    # Fixed seeds keep the visualizations reproducible, e.g. the randomized PCA.
    seed_everything()

    # Output folder for the PCA images, resolved relative to this script.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    pca_output_dir = os.path.join(script_dir, "../../../local/encoders/pca_images")
    os.makedirs(pca_output_dir, exist_ok=True)

    # Render and save the PCA visualization of every configured encoder.
    tester = TestEncoders(pca_save_folder=pca_output_dir)
    tester.visualize_all_encoders()
    print(f"The PCA visualizations of all encoders are saved successfully to {pca_output_dir}!")
|