import os
import math
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from sklearn.decomposition import PCA

# Expected to provide detect_newness_distance / detect_newness_two_sided (used below).
from model.modules.new_object_detection import *
class DIFTLatentStore:
    """Caches intermediate UNet up-block (DIFT) features at selected timesteps
    and layer indices, keyed as '{timestep}_{layer_index}'."""

    def __init__(self, steps: List[int], up_ft_indices: List[int]):
        self.steps = steps
        self.up_ft_indices = up_ft_indices
        self.dift_features = {}
        self.smoothed_dift_features = {}

    def __call__(self, features: torch.Tensor, t: int, layer_index: int):
        if t in self.steps and layer_index in self.up_ft_indices:
            self.dift_features[f'{int(t)}_{layer_index}'] = features

    def smooth(self, kernel_size=3, sigma=1):
        # Gaussian-smooth each (C, H, W) feature map in the batch, once per key.
        for key, value in self.dift_features.items():
            if key not in self.smoothed_dift_features:
                self.smoothed_dift_features[key] = torch.stack(
                    [gaussian_smooth(x, kernel_size=kernel_size, sigma=sigma) for x in value],
                    dim=0,
                )

    def copy(self):
        copy_dift = DIFTLatentStore(self.steps, self.up_ft_indices)
        for key, value in self.dift_features.items():
            copy_dift.dift_features[key] = value.clone()
        return copy_dift

    def reset(self):
        self.dift_features = {}
        self.smoothed_dift_features = {}

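# Usage sketch (hedged): the step value, layer index, and hook wiring below are
# illustrative placeholders, not part of this module.
#
#   latent_store = DIFTLatentStore(steps=[261], up_ft_indices=[1])
#   # inside a UNet up-block forward hook:
#   #   latent_store(features, t=timestep, layer_index=block_index)
#   latent_store.smooth(kernel_size=3, sigma=1)
#   feats = latent_store.dift_features['261_1']   # (B, C, H, W)
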
def gaussian_smooth(input_tensor, kernel_size=3, sigma=1):
    """Applies a 2D Gaussian blur to each channel of a (C, H, W) tensor."""
    # Build a normalized 2D Gaussian kernel.
    kernel = np.fromfunction(
        lambda x, y: (1 / (2 * np.pi * sigma ** 2)) *
                     np.exp(-((x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2) / (2 * sigma ** 2)),
        (kernel_size, kernel_size)
    )
    kernel = torch.Tensor(kernel / kernel.sum()).to(input_tensor.dtype).to(input_tensor.device)
    kernel = kernel.unsqueeze(0).unsqueeze(0)  # (1, 1, k, k) for conv2d
    # Convolve each channel independently; padding=kernel_size // 2 preserves
    # the spatial size for odd kernel sizes.
    smoothed_slices = []
    for i in range(input_tensor.size(0)):
        slice_tensor = input_tensor[i, :, :]
        slice_tensor = F.conv2d(slice_tensor.unsqueeze(0).unsqueeze(0), kernel, padding=kernel_size // 2)[0, 0]
        smoothed_slices.append(slice_tensor)
    smoothed_tensor = torch.stack(smoothed_slices, dim=0)
    return smoothed_tensor

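# Example (sketch, assumed shapes): blur each channel of a feature map.
#
#   feat = torch.randn(640, 32, 32)            # (C, H, W)
#   blurred = gaussian_smooth(feat, kernel_size=3, sigma=1)
#   assert blurred.shape == feat.shape         # odd kernel + same padding keeps H, W
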
def cos_dist(a, b):
    """Pairwise cosine distance: a is (N, D), b is (M, D) -> (N, M), in [0, 2]."""
    a_norm = F.normalize(a, dim=-1)
    b_norm = F.normalize(b, dim=-1)
    res = a_norm @ b_norm.T
    return 1 - res

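# Example (sketch): distance is 0 for identical directions, 2 for opposite ones.
#
#   a = torch.tensor([[1.0, 0.0]])
#   b = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])
#   cos_dist(a, b)  # tensor([[0., 2.]])
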
def extract_patches(feature_map: torch.Tensor, patch_size: int, stride: int) -> torch.Tensor:
    # feature_map is (C, H, W). Unfold requires (B, C, H, W).
    feature_map = feature_map.unsqueeze(0)  # (1, C, H, W)
    # Unfold: output shape will be (B, C * patch_size^2, num_patches)
    patches = F.unfold(
        feature_map,
        kernel_size=patch_size,
        stride=stride
    )
    # Now patches is (1, C*patch_size^2, num_patches).
    # Transpose to get shape (num_patches, C*patch_size^2).
    patches = patches.squeeze(0).transpose(0, 1)  # (num_patches, C*patch_size^2)
    return patches

def reassemble_patches(
    patches: torch.Tensor,
    out_shape: Tuple[int, int, int],
    patch_size: int,
    stride: int
) -> torch.Tensor:
    C, H, W = out_shape
    # 1) Convert from (num_patches, C*patch_size^2) to (1, C*patch_size^2, num_patches)
    patches_4d = patches.transpose(0, 1).unsqueeze(0)
    # 2) Fold: reassemble patches to (1, C, H, W)
    reassembled = F.fold(
        patches_4d,
        output_size=(H, W),
        kernel_size=patch_size,
        stride=stride
    )
    # 3) Create a divisor mask to account for overlapping regions.
    #    We do this by folding a "ones" tensor of the same shape as patches_4d.
    ones_input = torch.ones_like(patches_4d)
    overlap_count = F.fold(
        ones_input,
        output_size=(H, W),
        kernel_size=patch_size,
        stride=stride
    )
    # 4) Divide to normalize overlapping areas.
    reassembled = reassembled / overlap_count.clamp_min(1e-8)
    # 5) Remove the batch dimension -> (C, H, W)
    reassembled = reassembled.squeeze(0)
    return reassembled

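# Round-trip sketch: extract_patches followed by reassemble_patches recovers the
# original map, because overlapping contributions are identical and averaged out.
#
#   fm = torch.randn(4, 16, 16)
#   p = extract_patches(fm, patch_size=3, stride=1)    # (196, 36)
#   back = reassemble_patches(p, fm.shape, 3, 1)
#   torch.allclose(back, fm, atol=1e-5)                # True
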
def calculate_patch_distance(index1: int, index2: int, grid_size: int, stride: int, patch_size: int) -> float:
    """Euclidean distance between the centers of two patches on a row-major
    grid_size x grid_size patch grid."""
    row1, col1 = index1 // grid_size, index1 % grid_size
    row2, col2 = index2 // grid_size, index2 % grid_size
    y_center1, x_center1 = (row1 * stride) + (patch_size / 2), (col1 * stride) + (patch_size / 2)
    y_center2, x_center2 = (row2 * stride) + (patch_size / 2), (col2 * stride) + (patch_size / 2)
    return math.sqrt((x_center2 - x_center1) ** 2 + (y_center2 - y_center1) ** 2)

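# Worked example: on an 8x8 patch grid with stride 1 and patch_size 3, patches
# 0 and 9 sit one row and one column apart, so their centers are sqrt(2) apart.
#
#   calculate_patch_distance(0, 9, grid_size=8, stride=1, patch_size=3)  # ~1.414
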
def gen_nn_map(
    latent,
    src_features,
    tgt_features,
    device,
    kernel_size=3,
    stride=1,
    return_newness=False,
    **kwargs
):
    """For each target feature patch, finds the nearest source patch (cosine
    distance) and warps the source latent accordingly. `latent` may be a single
    (C, H, W) tensor or a list of such tensors sharing one spatial shape."""
    batch_size = kwargs.get("batch_size", None)
    timestep = kwargs.get("timestep", None)
    if kwargs.get("visualize", False):
        dift_visualization(src_features, tgt_features, filename_out=f"output/feat_colors_{timestep}.png")
    src_patches = extract_patches(src_features, kernel_size, stride)
    tgt_patches = extract_patches(tgt_features, kernel_size, stride)
    if isinstance(latent, list):
        latent_patches = [extract_patches(l, kernel_size, stride) for l in latent]
    else:
        latent_patches = extract_patches(latent, kernel_size, stride)
    num_tgt = tgt_patches.size(0)
    batch = batch_size or num_tgt
    nearest_neighbor_indices = torch.empty(num_tgt, dtype=torch.long, device=device)
    # Distances are real-valued, so this must not be an integer tensor.
    nearest_neighbor_distances = torch.empty(num_tgt, dtype=src_patches.dtype, device=device)
    dist_chunks = []
    # Chunk over target patches to bound the size of each distance matrix.
    for start in range(0, num_tgt, batch):
        dists = cos_dist(src_patches, tgt_patches[start : start + batch])  # (num_src, chunk)
        dist_chunks.append(dists)
        min_distances, best_idx = dists.min(0)  # best source patch per target patch
        nearest_neighbor_indices[start : start + batch] = best_idx
        nearest_neighbor_distances[start : start + batch] = min_distances
    # Warp the latent: replace each target patch with its nearest source patch.
    if not isinstance(latent, list):
        aligned_latent = latent_patches[nearest_neighbor_indices]
        aligned_latent = reassemble_patches(aligned_latent, latent.shape, kernel_size, stride)
    else:
        aligned_latent = [latent_patches[i][nearest_neighbor_indices] for i in range(len(latent_patches))]
        aligned_latent = [reassemble_patches(l, latent[0].shape, kernel_size, stride) for l in aligned_latent]
    if return_newness:
        # Each chunk is (num_src, chunk_size), so concatenate along the target axis.
        dist_matrix = torch.cat(dist_chunks, dim=1)
        newness_method = 'two_sided'
        # newness_method = 'distance'
        if newness_method.lower() == "distance":
            newness = detect_newness_distance(nearest_neighbor_distances, quantile=0.97)
        elif newness_method.lower() == "two_sided":
            newness = detect_newness_two_sided(dist_matrix, k=4)
        out_shape = latent[0].shape if isinstance(latent, list) else latent.shape
        out_shape = (1, out_shape[1], out_shape[2])
        # Broadcast the per-patch score across the patch footprint so that
        # F.fold receives 1 * kernel_size^2 values per patch.
        newness = newness.unsqueeze(-1).expand(-1, kernel_size ** 2)
        newness = reassemble_patches(newness, out_shape, kernel_size, stride)
    ######## visualization of changing source features to match target ########
    if False:  # disabled debug visualization; flip to True to inspect the warp
        updated_src_patches = src_patches[nearest_neighbor_indices]
        updated_src_patches = reassemble_patches(updated_src_patches, src_features.shape, kernel_size, stride)
        dift_visualization(
            updated_src_patches, tgt_features,
            filename_out=f"output/updated_feat_colors_{timestep}.png",
        )
    del src_patches, tgt_patches, latent_patches, nearest_neighbor_indices, nearest_neighbor_distances
    if return_newness:
        if isinstance(aligned_latent, list):
            aligned_latent.append(newness)  # newness travels as the last list entry
        else:
            return aligned_latent, newness
    return aligned_latent

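# Usage sketch (hedged): shapes are illustrative; in the pipeline the feature
# maps come from DIFTLatentStore and the latent from the diffusion sampler.
#
#   src_feat = torch.randn(640, 32, 32)
#   tgt_feat = torch.randn(640, 32, 32)
#   latent = torch.randn(4, 32, 32)
#   warped, newness = gen_nn_map(
#       latent, src_feat, tgt_feat, device='cpu',
#       kernel_size=3, stride=1, return_newness=True,
#   )
#   # warped: (4, 32, 32) latent warped to the target layout
#   # newness: (1, 32, 32) map of target regions with no good source match
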
def dift_visualization(
    src_feature: torch.Tensor,
    tgt_feature: torch.Tensor,
    filename_out: str,
    resize_to: Optional[Tuple[int, int]] = (512, 512)
):
    """Flatten the features, apply PCA for a 3D embedding, normalize to RGB,
    then reshape and save the two maps side by side as an image."""
    C, H_s, W_s = src_feature.shape
    _, H_t, W_t = tgt_feature.shape
    src_flat = src_feature.permute(1, 2, 0).reshape(-1, C)  # (H_s*W_s, C)
    tgt_flat = tgt_feature.permute(1, 2, 0).reshape(-1, C)  # (H_t*W_t, C)
    all_features = torch.cat([src_flat, tgt_flat], dim=0)  # (N_total, C)
    all_features_np = all_features.detach().cpu().numpy()
    # Fit a single PCA so source and target share one color space.
    num_components = 3
    pca = PCA(n_components=num_components)
    all_features_3d = pca.fit_transform(all_features_np)  # (N_total, 3)

    # Normalize each dimension to [0, 1] for use as RGB.
    def normalize_to_01(array_2d):
        min_vals = array_2d.min(axis=0)
        max_vals = array_2d.max(axis=0)
        denom = (max_vals - min_vals) + 1e-8
        return (array_2d - min_vals) / denom

    all_features_rgb = normalize_to_01(all_features_3d)
    N_src = H_s * W_s
    src_rgb_flat = all_features_rgb[:N_src]  # (N_src, 3)
    tgt_rgb_flat = all_features_rgb[N_src:]  # (N_tgt, 3)
    src_color_map = src_rgb_flat.reshape(H_s, W_s, 3)
    tgt_color_map = tgt_rgb_flat.reshape(H_t, W_t, 3)
    src_img = Image.fromarray((src_color_map * 255).astype(np.uint8))
    tgt_img = Image.fromarray((tgt_color_map * 255).astype(np.uint8))
    if resize_to is None:
        resize_to = (W_s, H_s)  # fall back to the source feature resolution
    src_img_resized = src_img.resize(resize_to, Image.Resampling.LANCZOS)
    tgt_img_resized = tgt_img.resize(resize_to, Image.Resampling.LANCZOS)
    combined_width = resize_to[0] * 2
    combined_height = resize_to[1]
    combined_img = Image.new("RGB", (combined_width, combined_height))
    combined_img.paste(src_img_resized, (0, 0))
    combined_img.paste(tgt_img_resized, (resize_to[0], 0))
    os.makedirs(os.path.dirname(filename_out), exist_ok=True)
    combined_img.save(filename_out)
    print(f"Saved visualization to {filename_out}")
