# Copyright 2024 EPFL and Apple Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
import torch
import torchvision.transforms.functional as TF
from einops import rearrange
import textwrap
import matplotlib.pyplot as plt
from PIL import Image
import cv2
from itertools import groupby

# For visualizing CLIP feature maps
from sklearn.decomposition import PCA

# Detectron2 is optional and only used for semantic segmentation visualizations
try:
    from detectron2.utils.visualizer import ColorMode, Visualizer
    from detectron2.data import MetadataCatalog
    coco_metadata = MetadataCatalog.get("coco_2017_val_panoptic")
    USE_DETECTRON = True
except Exception as e:
    print(e)
    print("Detectron2 is not installed. Semantic segmentation plots will fall back to matplotlib. "
          "Install detectron2 to enable the detectron2-based visualizations.")
    USE_DETECTRON = False

from fourm.data.modality_transforms import get_transform_key, get_transform_resolution, MetadataTransform
from fourm.utils.data_constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, COCO_SEMSEG_NUM_CLASSES
from fourm.utils import denormalize, get_sentinel_to_id_mapping, merge_span_masking
from fourm.utils.generation import unbatch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

def tensor_to_images(tensor):
    """
    Converts a (B, C, H, W) tensor to numpy arrays.
    If B = 1, the tensor is unbatched and converted to a single image.
    If C = 1, the channel dimension is removed.

    Args:
        tensor (torch.Tensor): Tensor to convert to images.
    """
    B, C, H, W = tensor.shape
    if B == 1:
        img = rearrange(unbatch(tensor), "c h w -> h w c")
    else:
        img = rearrange(tensor, "b c h w -> b h w c")
    if C == 1:
        img = img[..., 0]
    return img.detach().cpu().numpy()

def pca_visualize(features, n_components=3):
    """
    Visualizes a feature map using PCA.

    Args:
        features (torch.Tensor): CxHxW feature map to visualize.
        n_components (int): Number of PCA components to use.
    """
    C, H, W = features.shape
    features_flat = rearrange(features.float(), 'c h w -> (h w) c').detach().cpu().numpy()
    pca = PCA(n_components=n_components)
    img_pca = rearrange(pca.fit_transform(features_flat), '(h w) c -> h w c', h=H, w=W)
    img_pca = (img_pca - img_pca.min()) / (img_pca.max() - img_pca.min())
    return img_pca
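
# Example (illustrative only): project a hypothetical 768x14x14 feature map onto its first
# three principal components and display the result as an RGB image.
#
#   feats = torch.randn(768, 14, 14)      # stand-in for a real CLIP/DINOv2 feature map
#   pca_img = pca_visualize(feats)        # (14, 14, 3) array, min-max scaled to [0, 1]
#   plt.imshow(pca_img); plt.axis('off'); plt.show()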

def np_squeeze(array, axis=0):
    """
    Squeezes a numpy array along a given axis if that axis is one-dimensional.
    Otherwise, it returns the same array.

    Args:
        array (numpy.ndarray): Array to squeeze.
        axis (int): Axis to squeeze.
    """
    if array.shape[axis] == 1:
        return np.squeeze(array, axis=axis)
    else:
        return array

def decode_input_rgb(mod_dict, key='rgb'):
    """
    Decodes (denormalizes) an RGB image from a model dictionary.

    Args:
        mod_dict (dict): Model output dictionary.
        key (str): Key of the RGB modality to decode.
    """
    img = denormalize(mod_dict[key]['tensor'])
    return tensor_to_images(img)

def decode_tok_rgb(mod_dict, tokenizers, key='tok_rgb', image_size=224, patch_size=16, t=25, verbose=False):
    """
    Decodes a sequence of RGB tokens from a model dictionary into an RGB image.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized RGB modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the tokenizer diffusion model (if applicable).
        verbose (bool): Whether to print the decoding progress.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok, timesteps=t, image_size=image_size, verbose=verbose)
    rec = denormalize(rec, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).clamp(0, 1)
    return tensor_to_images(rec)
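
# Shape flow of decode_tok_rgb (and the similar token decoders below) with the default
# 224/16 settings: the flat token sequence of length (224/16)**2 = 196 is reshaped into a
# 14x14 grid, the tokenizer's decoder turns it into a (B, 3, 224, 224) tensor, and the
# result is denormalized from [-1, 1] back to [0, 1] before conversion to numpy.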

def decode_tok_rgb_controlnet(mod_dict, tokenizers, key='tok_rgb', image_size=224, patch_size=16,
                              t=25, guidance_scale=2.5, cond_scale=0.8, verbose=False):
    """
    Decodes a sequence of RGB tokens from a model dictionary into an RGB image using a ControlNet.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers. Needs to contain the key 'controlnet'.
        key (str): Key of the tokenized RGB modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the ControlNet.
        guidance_scale (float): Classifier-free guidance scale.
        cond_scale (float): ControlNet conditioning scale.
        verbose (bool): Whether to print the decoding progress.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers['controlnet'].decode_tokens(
        img_tok, timesteps=t, guidance_scale=guidance_scale, cond_scale=cond_scale, verbose=verbose
    )
    rec = tokenizers['controlnet'].vae_decode(rec)
    rec = denormalize(rec, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).clamp(0, 1)
    return tensor_to_images(rec)

def decode_tok_normal(mod_dict, tokenizers, key='tok_normal', image_size=224, patch_size=16, t=25, verbose=False):
    """
    Decodes a sequence of surface normal tokens from a model dictionary into an RGB image.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized normal modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the tokenizer diffusion model (if applicable).
        verbose (bool): Whether to print the decoding progress.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok, timesteps=t, image_size=image_size, verbose=verbose)
    rec = denormalize(rec, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5)).clamp(0, 1)
    return tensor_to_images(rec)

def decode_tok_canny_edge(mod_dict, tokenizers, key='tok_canny_edge', image_size=224, patch_size=16, t=10, verbose=False):
    """
    Decodes a sequence of Canny edge tokens from a model dictionary into an RGB image.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized Canny edge modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the tokenizer diffusion model (if applicable).
        verbose (bool): Whether to print the decoding progress.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok, timesteps=t, image_size=image_size, verbose=verbose)
    rec = (0.5 * (rec + 1)).clamp(0, 1)  # Rescale from [-1, 1] to [0, 1]
    return tensor_to_images(rec)

def decode_tok_sam_edge(mod_dict, tokenizers, key='tok_sam_edge', image_size=224, patch_size=16, t=10, verbose=False):
    """
    Decodes a sequence of SAM edge tokens from a model dictionary into an RGB image.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized SAM edges modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the tokenizer diffusion model (if applicable).
        verbose (bool): Whether to print the decoding progress.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok, timesteps=t, image_size=image_size, verbose=verbose)
    rec = (0.5 * (rec + 1)).clamp(0, 1)  # Rescale from [-1, 1] to [0, 1]
    return tensor_to_images(rec)

def decode_tok_depth(mod_dict, tokenizers, key='tok_depth', image_size=224, patch_size=16, t=25, verbose=False, cmap='turbo'):
    """
    Decodes a sequence of depth tokens from a model dictionary into an RGB image.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized depth modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        t (int): Number of timesteps to decode using the tokenizer diffusion model (if applicable).
        verbose (bool): Whether to print the decoding progress.
        cmap (str): Colormap to use for the depth image. If None, the raw depth maps are returned.
    """
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok, timesteps=t, image_size=image_size, verbose=verbose)
    rec = rec.detach().cpu().numpy()[:, 0]

    if cmap is None:
        return rec

    colormap = plt.get_cmap(cmap)
    imgs = []
    for img in rec:
        # Normalize each depth map to [0, 1] before applying the colormap
        img_norm = (img - np.min(img)) / (np.max(img) - np.min(img))
        rgb_image = colormap(img_norm)[..., :3]
        imgs.append(rgb_image)
    rgb_image = np_squeeze(np.stack(imgs), axis=0)
    return rgb_image

def decode_tok_semseg(rgb_img, mod_dict, tokenizers, key='tok_semseg', image_size=224, patch_size=16, use_detectron=True, return_logits=False):
    """
    Decodes a sequence of semantic segmentation tokens from a model dictionary into an RGB image.

    Args:
        rgb_img (torch.Tensor): RGB image to overlay the semantic segmentation on.
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized semantic segmentation modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
        use_detectron (bool): Uses detectron2's visualization for the semseg output.
        return_logits (bool): If True, returns the raw per-class logits instead of a visualization.
    """
    tokens = mod_dict[key]['tensor']
    tokens = tokens.unsqueeze(0) if tokens.ndim == 1 else tokens
    img_tok = rearrange(tokens, "b (nh nw) -> b nh nw", nh=image_size//patch_size, nw=image_size//patch_size)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok).detach().cpu()
    if return_logits:
        return rec
    semsegs = rec.argmax(1)
    B, H, W = semsegs.shape

    if not use_detectron:
        return semsegs if B > 1 else semsegs[0]
    else:
        rgb_imgs = [rgb_img] * B
        imgs = []
        for rgb, semseg in zip(rgb_imgs, semsegs):
            if USE_DETECTRON:
                v = Visualizer(255 * rgb, coco_metadata, scale=1.2, instance_mode=ColorMode.IMAGE_BW)
                img = v.draw_sem_seg((semseg - 1).cpu()).get_image() / 255.0
            else:
                colormap = plt.get_cmap('viridis')
                img = colormap(semseg.cpu())[..., :3]
            imgs.append(img)
        imgs = np_squeeze(np.stack(imgs), axis=0)
        return imgs

def decode_tok_clip(mod_dict, tokenizers, key='tok_clip', image_size=224, patch_size=16):
    """
    Decodes a sequence of CLIP tokens from a model dictionary into a PCA visualization of the feature map.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized CLIP modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
    """
    n_patches = image_size // patch_size
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=n_patches, nw=n_patches)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok)
    pca_viz = [pca_visualize(feat) for feat in rec]
    pca_viz = np_squeeze(np.stack(pca_viz), axis=0)
    return pca_viz

def decode_tok_dinov2(mod_dict, tokenizers, key='tok_dinov2', image_size=224, patch_size=14):
    """
    Decodes a sequence of DINOv2 spatial tokens from a model dictionary into a PCA visualization of the feature map.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized DINOv2 modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
    """
    patch_size = 14  # DINOv2 features use a patch size of 14, so the passed-in value is overridden
    n_patches = image_size // patch_size
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=n_patches, nw=n_patches)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok)
    pca_viz = [pca_visualize(feat) for feat in rec]
    pca_viz = np_squeeze(np.stack(pca_viz), axis=0)
    return pca_viz

def decode_tok_imagebind(mod_dict, tokenizers, key='tok_imagebind', image_size=224, patch_size=14):
    """
    Decodes a sequence of ImageBind spatial tokens from a model dictionary into a PCA visualization of the feature map.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized ImageBind modality to decode.
        image_size (int): Size of the image.
        patch_size (int): Size of the patches.
    """
    patch_size = 14  # ImageBind features use a patch size of 14, so the passed-in value is overridden
    n_patches = image_size // patch_size
    img_tok = rearrange(mod_dict[key]['tensor'], "b (nh nw) -> b nh nw", nh=n_patches, nw=n_patches)
    rec = tokenizers[get_transform_key(key)].decode_tokens(img_tok)
    pca_viz = [pca_visualize(feat) for feat in rec]
    pca_viz = np_squeeze(np.stack(pca_viz), axis=0)
    return pca_viz

def decode_tok_dinov2_global(mod_dict, tokenizers, key='tok_dinov2_global'):
    """
    Decodes a sequence of DINOv2 global tokens from a model dictionary.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized DINOv2 global token modality to decode.
    """
    toks = rearrange(mod_dict[key]['tensor'].long(), 'b n -> b n 1 1')
    rec = tokenizers[get_transform_key(key)].decode_tokens(toks)
    return rec.squeeze()

def decode_tok_imagebind_global(mod_dict, tokenizers, key='tok_imagebind_global'):
    """
    Decodes a sequence of ImageBind global tokens from a model dictionary.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers.
        key (str): Key of the tokenized ImageBind global token modality to decode.
    """
    toks = rearrange(mod_dict[key]['tensor'].long(), 'b n -> b n 1 1')
    rec = tokenizers[get_transform_key(key)].decode_tokens(toks)
    return rec.squeeze()
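
# Note: unlike the spatial token decoders above, the two *_global decoders return the
# reconstructed global feature (as produced by the tokenizer's decoder) rather than an
# image visualization, so they are not passed to the plotting helpers below.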

def decode_color_palette(mod_dict, text_tokenizer, key='color_palette'):
    """
    Decodes a sequence of color palette tokens from a model dictionary into palette images.

    Args:
        mod_dict (dict): Model output dictionary.
        text_tokenizer (tokenizers.Tokenizer): Text tokenizer.
        key (str): Key of the color palette modality to decode.
    """
    decoded = decode_text(mod_dict, key, text_tokenizer)[2]
    all_decoded = decoded if isinstance(decoded, list) else [decoded]
    all_decoded = [d.replace(' [EOS]', '') for d in all_decoded]
    all_decoded = [visualize_palettes_multi(d) for d in all_decoded]
    all_decoded = all_decoded[0] if len(all_decoded) == 1 else all_decoded
    return all_decoded

def decode_human_poses(mod_dict, tokenizers, text_tokenizer, key='human_poses'):
    """
    Decodes human poses that were tokenized as text + BMLP tokens into rendered pose images.

    Args:
        mod_dict (dict): Model output dictionary.
        tokenizers (dict): Dictionary of tokenizers. Needs to contain the human poses tokenizer under `key`.
        text_tokenizer (tokenizers.Tokenizer): Text tokenizer.
        key (str): Key of the human poses modality to decode.
    """
    decoded = decode_text(mod_dict, key, text_tokenizer)[2]
    all_decoded = decoded if isinstance(decoded, list) else [decoded]
    all_decoded = [d.replace(' [EOS]', '') for d in all_decoded]
    imgs = []
    for decoded in all_decoded:
        img = np.ones((224, 224, 4))
        if decoded != 'none':
            try:
                img = visualize_human_poses(decoded, tokenizers[key], mod_dict)
            except Exception as e:
                print('Error in decoding human poses. Packages required for plotting may not be installed. Trace:')
                print(e)
        imgs.append(img)
    imgs = np_squeeze(np.stack(imgs), axis=0)
    return imgs

metadata_transform = MetadataTransform(shuffle=False, random_trunc=False, return_chunks=False)


def _split_metadata_string(input_string):
    result = []
    current_subseq = []

    for part in input_string.split():
        # If we encounter a "v1" and there's already a subsequence being built,
        # we add it to the result and start a new one
        if 'v1' in part and current_subseq:
            result.append(current_subseq)
            current_subseq = []
        current_subseq.append(part)

    # Append any remaining subsequence to the result
    if current_subseq:
        result.append(current_subseq)

    return result
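
# Example (hypothetical values): "v1=5 v0=12 v1=8 v0=3" is split into
# [['v1=5', 'v0=12'], ['v1=8', 'v0=3']], i.e. one (metadata id, metadata value) pair
# per sub-list, which decode_metadata below then parses.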

def decode_metadata(mod_dict, text_tokenizer, key='metadata'):
    """
    Decodes a sequence of metadata tokens into a dictionary of metadata.

    Args:
        mod_dict (dict): Model output dictionary.
        text_tokenizer (tokenizers.Tokenizer): Text tokenizer.
        key (str): Key of the metadata modality to decode.
    """
    decoded = decode_text(mod_dict, key, text_tokenizer)[2]
    all_decoded = decoded if isinstance(decoded, list) else [decoded]
    all_decoded = [d.replace(' [EOS]', '').replace(' [PAD]', '') for d in all_decoded]

    all_metadata = []
    for decoded in all_decoded:
        parts = _split_metadata_string(decoded)
        invalid_parts = []
        metadata_dict = {}

        for part in parts:
            # Check if part has been parsed correctly
            if len(part) != 2:
                invalid_parts.append(str(part))
                continue
            metadata_id, metadata_value = part
            if (not metadata_id.startswith('v1=') or
                not metadata_value.startswith('v0=') or
                metadata_id not in metadata_transform.id_metadata_map):
                invalid_parts.append(str(part))
                continue

            # Parse metadata type and value
            metadata_type = metadata_transform.id_metadata_map[metadata_id]
            metadata_value = int(metadata_value.split('=')[1])
            if metadata_type in metadata_transform.image_dim_modalities:
                metadata_value *= metadata_transform.image_dim_bin_size
            elif metadata_type in metadata_transform.metadata_min_max_bins:
                vmin, vmax, bins = metadata_transform.metadata_min_max_bins[metadata_type]
                metadata_value = (vmax - vmin) * (metadata_value / bins) + vmin

            metadata_dict[metadata_type] = metadata_value

        # Preserve the canonical metadata ordering
        metadata_dict = {k: metadata_dict[k] for k in metadata_transform.metadata_id_map if k in metadata_dict}
        all_metadata.append(metadata_dict)

    all_metadata = all_metadata[0] if len(all_metadata) == 1 else all_metadata
    return all_metadata

def decode_text(mod_dict, key, text_tokenizer):
    """
    Decodes a text sequence from a model dictionary.

    Args:
        mod_dict (dict): Model output dictionary.
        key (str): Key of the text modality to decode.
        text_tokenizer (tokenizers.Tokenizer): Text tokenizer.
    """
    input_texts, target_texts, merged_texts = [], [], []
    sentinel_ids = set(get_sentinel_to_id_mapping(text_tokenizer).values())
    B = mod_dict[key]['tensor'].shape[0]

    for i in range(B):
        input_seq = mod_dict[key]['tensor'][i]
        input_seq = input_seq[mod_dict[key]['input_mask'][i] == 0]
        input_seq = input_seq.tolist()

        target_seq = mod_dict[key]['tensor'][i]
        target_seq = target_seq[mod_dict[key]['target_mask'][i] == 0]
        target_seq = target_seq.tolist()

        merged_seq = merge_span_masking(input_seq, target_seq, sentinel_ids=sentinel_ids)

        input_text = text_tokenizer.decode(input_seq, skip_special_tokens=False)
        target_text = text_tokenizer.decode(target_seq, skip_special_tokens=False)
        merged_text = text_tokenizer.decode(merged_seq, skip_special_tokens=False)

        input_texts.append(input_text)
        target_texts.append(target_text)
        merged_texts.append(merged_text)

    if B == 1:
        input_texts, target_texts, merged_texts = input_texts[0], target_texts[0], merged_texts[0]

    return input_texts, target_texts, merged_texts
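
# Note on usage: the other text-based decoders in this file (color palette, human poses,
# metadata, SAM instances, captions) keep only the third element of the returned triple,
# i.e. the merged text in which sentinel spans of the input are filled with the predicted
# target spans.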
| def decode_sam_instances(mod_dict, tokenizers, text_tokenizer, key='sam_instance', image_size=224, token_len=16): | |
| ''' | |
| Decodes a sequence of SAM instance tokens into the instance representation. | |
| Args: | |
| mod_dict (dict): Model output dictionary. | |
key (str): Key of the tokenized SAM instances modality to decode.
| tokenizers (dict): Dictionary of tokenizers. | |
| text_tokenizer (tokenizers.Tokenizer): Text tokenizer. | |
| image_size (int): Size of the image. | |
| token_len (int): Tokenized SAM instance token length. | |
| ''' | |
| assert image_size == 224, 'SAM instance decoding only supports 224x224 images' | |
| decoded = decode_text(mod_dict, key, text_tokenizer)[2] | |
| all_decoded = decoded if isinstance(decoded, list) else [decoded] | |
| all_decoded = [d.replace(' [EOS]', '') for d in all_decoded] | |
| # Generate deterministic SAM color palette | |
| rng = np.random.default_rng(seed=0) | |
| sam_palette = [rng.integers(0, 255, size=3) for i in range(1000)] | |
| def group_by_identifier(input_list, identifier): | |
| ''' | |
| Groups the input_list [a,b,c,a,d,d,c,..] using the identifier a, in the following format: | |
| [[b,c], [d,d,c], ...] | |
| ''' | |
| return [list(group) for key, group in groupby(input_list, lambda x: x == identifier) if not key] | |
| def map_locations(inp, tokens=False): | |
| ''' | |
| Converts v0, v1, v2, v3 textual representation into int. | |
| When tokens=True, inp is mapped to its corresponding token id. | |
| ''' | |
| if '=' not in inp: | |
| return None | |
| axis, position = inp.split("=") | |
| try: | |
| position = int(position) | |
| except: | |
| return None | |
| if tokens: | |
| if axis == 'v0': | |
| return position | |
| else: | |
| return position + 512 | |
| return position | |
| def iou(box1, box2): | |
| ''' | |
| Calculates iou of the input bounding boxes | |
| ''' | |
| # Calculate the coordinates of the intersection rectangle | |
| x1 = max(box1[0], box2[0]) | |
| y1 = max(box1[1], box2[1]) | |
| x2 = min(box1[2], box2[2]) | |
| y2 = min(box1[3], box2[3]) | |
| # Calculate the area of the intersection | |
| intersection_area = max(0, x2 - x1) * max(0, y2 - y1) | |
| # Calculate the areas of the individual bounding boxes | |
| area_box1 = (box1[2] - box1[0]) * (box1[3] - box1[1]) | |
| area_box2 = (box2[2] - box2[0]) * (box2[3] - box2[1]) | |
| # Calculate the union area | |
| union_area = area_box1 + area_box2 - intersection_area | |
| # Calculate and return the IoU | |
| return intersection_area / union_area | |
| all_sam_instances = [] | |
| for decoded in all_decoded: | |
| tokens_per_sample = [] | |
| bboxes_per_sample = [] | |
| areas_per_sample = [] | |
| parts = decoded.split() | |
| for part in group_by_identifier(parts, identifier='point'): | |
| instances = part[2:] | |
| # Ignore 'none' cases | |
| if len(instances) <= 1: | |
| continue | |
| for positions in group_by_identifier(part, identifier='polygon'): | |
| # Ignore incomplete polygons | |
| if len(positions) != token_len + 4: | |
| continue | |
| bbox, tokens = positions[:4], positions[4:] | |
| min_w, min_h, max_w, max_h = map(map_locations, bbox) | |
# Ignore cases where the bounding box prediction is in an incorrect format
| if None in [min_w, max_w, min_h, max_h] or (min_w >= max_w or min_h >= max_h): | |
| continue | |
| bbox = np.array([min_h, min_w, max_h, max_w]) | |
| tokens = list(map(lambda x: map_locations(x, tokens=True), tokens)) | |
| if None in tokens: | |
| continue | |
| tokens = np.array(tokens) | |
| tokens_per_sample.append(tokens) | |
| bboxes_per_sample.append(bbox) | |
| areas_per_sample.append((max_w - min_w) * (max_h - min_h)) | |
| final_instances = np.zeros((image_size, image_size, 3), dtype=np.uint8) | |
| if len(areas_per_sample) == 0: | |
| return final_instances | |
| # Sort the instance masks by area | |
| areas_per_sample = np.array(areas_per_sample) | |
| sorted_idx = np.argsort(-areas_per_sample) | |
| tokens_per_sample = np.stack(tokens_per_sample)[sorted_idx] | |
| bboxes_per_sample = np.stack(bboxes_per_sample)[sorted_idx] | |
| # Decoded tokens | |
| tokens_per_sample = torch.LongTensor(tokens_per_sample).reshape(-1, 4, 4).to(device) | |
| decoded_tokens = tokenizers[key].decode_tokens(tokens_per_sample) | |
| instances = torch.sigmoid(decoded_tokens).squeeze(1).cpu().detach().numpy() | |
| # Filter and group instances | |
| representive_masks = [] | |
| representive_bboxes = [] | |
| for (mask, bbox) in zip(instances, bboxes_per_sample): | |
| # Filter out unusual masks | |
| if (mask.max() - mask.min()) < 0.9: | |
| continue | |
| # Groups the duplicated instance masks | |
| duplicated_flag = False | |
| for rms, rbs in zip(representive_masks, representive_bboxes): | |
| rm, rb = rms[0], rbs[0] | |
| sim_score = 2 * ((rm * mask).sum() + 0.01) / (rm.sum() + mask.sum() + 0.01) | |
| box_iou = iou(rb, bbox) | |
| # If the similarity and IoU are high, consider them as the same instance and group them | |
| if sim_score > 0.8 and box_iou > 0.9: | |
| # Add the mask to its corresponding group | |
| rms.append(mask) | |
| rbs.append(bbox) | |
| duplicated_flag = True | |
| break | |
| if not duplicated_flag: | |
| representive_masks.append([mask]) | |
| representive_bboxes.append([bbox]) | |
| # Plot the instances | |
| for i, (rms, rbs) in enumerate(zip(representive_masks, representive_bboxes)): | |
| mask = np.mean(rms, axis=0) | |
| bbox = np.mean(rbs, axis=0).astype(np.int32) | |
| min_h, min_w, max_h, max_w = bbox.tolist() | |
| mask = cv2.resize(mask, (max_w - min_w, max_h - min_h), interpolation=cv2.INTER_CUBIC) | |
| max_w, max_h = min(max_w, final_instances.shape[1]), min(max_h, final_instances.shape[0]) | |
| mask = mask[:max_h - min_h,:max_w - min_w] > 0.5 | |
| final_instances[min_h:max_h, min_w:max_w, :][mask] = sam_palette[i] | |
| all_sam_instances.append(final_instances) | |
| all_sam_instances = all_sam_instances[0] if len(all_sam_instances) == 1 else np.stack(all_sam_instances) | |
| return all_sam_instances | |
| def decode_dict(mod_dict, tokenizers, text_tokenizer, image_size=224, patch_size=16, | |
| decoding_steps=25, activate_controlnet=False, controlnet_guidance_scale=2.5, controlnet_cond_scale=0.8, | |
| to_rgb=True, seed=None): | |
| """ | |
| Decodes the model output dictionary into a dictionary of images and text. | |
| Args: | |
| mod_dict (dict): Model output dictionary. | |
| tokenizers (dict): Dictionary of tokenizers. | |
| text_tokenizer (tokenizers.Tokenizer): Text tokenizer. | |
| image_size (int): Image size. | |
| patch_size (int): Patch size. | |
| decoding_steps (int): Number of diffusion decoding steps (if applicable). | |
| activate_controlnet (bool): Whether to activate the RGB ControlNet and override the RGB detokenizer. | |
| controlnet_guidance_scale (float): Classifier-free guidance scale for the ControlNet. | |
controlnet_cond_scale (float): ControlNet conditioning scale.
to_rgb (bool): If True, depth is decoded to a colormapped image and semseg to a segmentation visualization; if False, raw depth values and semseg logits are returned instead.
seed (int): Random seed (currently unused here).
"""
| dec_dict = {} | |
| for key in mod_dict: | |
| k, res = get_transform_key(key), get_transform_resolution(key, image_size, to_tuple=False) | |
| if k == 'rgb': | |
| decoded = decode_input_rgb(mod_dict, key=key) | |
| elif k == 'tok_rgb': | |
| if not activate_controlnet or 'controlnet' not in tokenizers: | |
| decoded = decode_tok_rgb( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, verbose=False | |
| ) | |
| else: | |
| decoded = decode_tok_rgb_controlnet( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, guidance_scale=controlnet_guidance_scale, | |
| cond_scale=controlnet_cond_scale, verbose=False | |
| ) | |
| elif k == 'tok_canny_edge': | |
| decoded = decode_tok_canny_edge( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, verbose=False | |
| ) | |
| elif k == 'tok_sam_edge': | |
| decoded = decode_tok_sam_edge( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, verbose=False | |
| ) | |
| elif k == 'tok_normal': | |
| decoded = decode_tok_normal( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, verbose=False | |
| ) | |
| elif k == 'tok_depth': | |
| decoded = decode_tok_depth( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, | |
| t=decoding_steps, verbose=False, cmap='turbo' if to_rgb else None | |
| ) | |
| elif k == 'tok_semseg': | |
| decoded = decode_tok_semseg( | |
| np.ones((res, res, 3)), mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size, return_logits=not to_rgb | |
| ) | |
| elif k == 'tok_clip': | |
| decoded = decode_tok_clip( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size | |
| ) | |
| elif k == 'tok_dinov2': | |
| decoded = decode_tok_dinov2( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size | |
| ) | |
| elif k == 'tok_dinov2_global': | |
| decoded = decode_tok_dinov2_global( | |
| mod_dict, tokenizers, key=key | |
| ) | |
| elif k == 'tok_imagebind': | |
| decoded = decode_tok_imagebind( | |
| mod_dict, tokenizers, key=key, | |
| image_size=res, patch_size=patch_size | |
| ) | |
| elif k == 'tok_imagebind_global': | |
| decoded = decode_tok_imagebind_global( | |
| mod_dict, tokenizers, key=key | |
| ) | |
| elif k == 'color_palette': | |
| decoded = decode_color_palette( | |
| mod_dict, text_tokenizer, key=key | |
| ) | |
| elif k == 'human_poses': | |
| decoded = decode_human_poses( | |
| mod_dict, tokenizers, text_tokenizer, key=key | |
| ) | |
| elif k in ['caption', 'det']: | |
| decoded = decode_text(mod_dict, key, text_tokenizer)[2] | |
| decoded = decoded if isinstance(decoded, list) else [decoded] | |
| decoded = [d.replace(' [EOS]', '') for d in decoded] | |
| elif k in ['metadata']: | |
| decoded = decode_metadata( | |
| mod_dict, text_tokenizer, key=key | |
| ) | |
| elif k == 'sam_instance': | |
| decoded = decode_sam_instances( | |
| mod_dict, tokenizers, text_tokenizer, | |
| key=key, image_size=224, | |
| ) | |
| elif k in ['t5_caption']: | |
| if 'ascii_tensor' in mod_dict[key]: | |
| decoded = [] | |
| for ascii_tensor in mod_dict[key]['ascii_tensor']: | |
| ascii_values = ascii_tensor.flatten().tolist() | |
| decoded_text = ''.join(chr(val) for val in ascii_values if val != 0) | |
| decoded.append(f"T5-XXL embedding of: {decoded_text}") | |
| decoded = decoded[0] if len(decoded) == 1 else decoded | |
| else: | |
| decoded = "T5-XXL embedding" | |
| dec_dict[key] = decoded | |
| return dec_dict | |
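
# Example (illustrative): decode every modality of a sampled 4M dictionary. It assumes
# `sample` is a modality dictionary, and `toks` / `text_tok` are the corresponding
# tokenizers loaded elsewhere (the names are placeholders).
#
#   dec = decode_dict(sample, toks, text_tok, image_size=224, patch_size=16, decoding_steps=25)
#   # dec maps each modality key to a numpy image, a string (e.g. captions), or a metadata
#   # dict, ready to be passed to plot_conds_and_targets() or save_conds_and_targets() below.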
| # Plotting utils | |
| MOD_PRINT_NAMES = { | |
| 'rgb': 'RGB', | |
| 'tok_rgb': 'RGB (tok)', | |
| 'tok_normal': 'Normal (tok)', | |
| 'tok_depth': 'Depth (tok)', | |
| 'tok_semseg': 'Semseg (tok)', | |
| 'tok_clip': 'CLIP (tok)', | |
| 'tok_canny': 'Canny (tok)', | |
| 'tok_sam': 'SAM (tok)', | |
| 'sam_instance': 'SAM Instances (tok)', | |
| 'rgb@224': 'RGB@224', | |
| 'tok_rgb@224': 'RGB@224 (tok)', | |
| 'tok_normal@224': 'Normal@224 (tok)', | |
| 'tok_depth@224': 'Depth@224 (tok)', | |
| 'tok_semseg@224': 'Semseg@224 (tok)', | |
| 'tok_clip@224': 'CLIP@224 (tok)', | |
| 'rgb@448': 'RGB@448', | |
| 'tok_rgb@448': 'RGB@448 (tok)', | |
| 'tok_normal@448': 'Normal@448 (tok)', | |
| 'tok_depth@448': 'Depth@448 (tok)', | |
| 'tok_semseg@448': 'Semseg@448 (tok)', | |
| 'tok_clip@448': 'CLIP@448 (tok)', | |
| 'caption': 'Caption', | |
| 'det': 'Detection', | |
| 't5_caption': 'T5 XXL', | |
| 'metadata': 'Metadata', | |
| 'human_poses': 'Human poses', | |
| 'color_palette': 'Color palette', | |
| 'tok_dinov2': 'DINOv2 (tok)', | |
| 'tok_dinov2_global': 'DINOv2 global (tok)', | |
| 'tok_imagebind': 'ImageBind (tok)', | |
| 'tok_imagebind_global': 'ImageBind global (tok)', | |
| } | |
| def remove_ticks_and_labels(ax): | |
| """ | |
| Remove the axis ticks and labels | |
| Args: | |
| ax (matplotlib.axes.Axes): Axis to remove ticks and labels from | |
| """ | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| ax.set_xticklabels([]) | |
| ax.set_yticklabels([]) | |
| def remove_spines(ax): | |
| """ | |
| Removes the spines from the given axis. | |
| Args: | |
| ax (matplotlib.axes.Axes): Axis to remove spines from | |
| """ | |
| ax.spines['top'].set_visible(False) | |
| ax.spines['right'].set_visible(False) | |
| ax.spines['bottom'].set_visible(False) | |
| ax.spines['left'].set_visible(False) | |
| def convert_string_to_bboxes(bboxes_str, bins=1000): | |
| """ | |
| Converts a string of bboxes to a list of bboxes. | |
| Args: | |
| bboxes_str (str): String of bboxes | |
| bins (int): Number of bins (default: 1000) | |
| """ | |
| bboxes_str = bboxes_str.split(" ") | |
| bboxes = [] | |
| for token in bboxes_str: | |
| if "=" in token: | |
| coord = token.split("=")[1] | |
| coord = float(coord) / (bins - 1) | |
| if token.startswith("v0="): | |
| bboxes.append([coord,]) | |
| else: | |
| bboxes[-1].append(coord) | |
| elif len(bboxes[-1]) == 4: | |
| bboxes[-1].append(token) | |
| else: | |
| bboxes[-1][4] = " ".join([bboxes[-1][4], token]) | |
| bboxes = [bbox for bbox in bboxes if len(bbox) == 5] | |
| return bboxes | |
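
# Example of the expected string format (hypothetical values): "v0=100 v1=50 v2=800 v3=600 dog"
# decodes to a single box approximately [0.100, 0.050, 0.801, 0.601, 'dog'], with coordinates
# normalized by (bins - 1) = 999.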

def visualize_palettes_multi(palettes):
    """
    Converts a decoded color palette text sequence into a palette image, drawing one vertical
    color stripe per group of three consecutive values (R, G, B).
    """
    palettes = palettes.split()
    palettes = palettes[1:]  # Drop the leading token, which is not parsed as a color value
    all_colors = []
    for ii in range(len(palettes)):
        all_colors.append(int(palettes[ii][3:]))
    w = h = 25

    # Construct the palette image: one w-pixel-wide stripe per color
    o = Image.new("RGB", size=(w * len(palettes) // 3, h * len(palettes) // 3))
    arr = np.asarray(o).copy()
    for ii in range(len(palettes) // 3):
        arr[:, ii * h:(ii + 1) * h, :] = all_colors[ii * 3:(ii + 1) * 3]
    final_palette = arr / 255
    return final_palette
| BOX_COLOR = (255, 0, 0) # Red | |
| TEXT_COLOR = (255, 255, 255) # White | |
| try: | |
| from fourm.utils.hmr2_utils.hmr2.models.smpl_wrapper import SMPL | |
| from fourm.utils.hmr2_utils.hmr2.utils.renderer import Renderer, cam_crop_to_full | |
| import pickle as pkl | |
| LIGHT_BLUE=(0.65098039, 0.74117647, 0.85882353) | |
| with open('./fourm/utils/hmr2_utils/model_cfg.pkl','rb') as f: | |
| pose_model_cfg = pkl.load(f) | |
| # Instantiate SMPL model | |
| smpl_cfg = {k.lower(): v for k,v in dict(pose_model_cfg.SMPL).items()} | |
| smpl_cfg['model_path'] = './fourm/utils/hmr2_utils/data/smpl' | |
| smpl_cfg['joint_regressor_extra'] = './fourm/utils/hmr2_utils/data/SMPL_to_J19.pkl' | |
| smpl_cfg['mean_params'] = './fourm/utils/hmr2_utils/data/smpl_mean_params.npz' | |
| smpl = SMPL(**smpl_cfg) | |
| # Setup the renderer | |
| renderer = Renderer(pose_model_cfg, faces=smpl.faces) | |
| except Exception as e: | |
| print(e) | |
| print('Human pose dependencies are not installed, hence poses will not be visualized. To visualize them (optional), you can do the following: \n' \ | |
| '1) Install via `pip install timm yacs smplx pyrender pyopengl==3.1.4` \n' \ | |
| ' You may need to follow the pyrender install instructions: https://pyrender.readthedocs.io/en/latest/install/index.html \n' \ | |
| '2) Download SMPL data from https://smpl.is.tue.mpg.de/. See https://github.com/shubham-goel/4D-Humans/ for an example. \n' \ | |
| '3) Copy the required SMPL files (smpl_mean_params.npz, SMPL_to_J19.pkl, smpl/SMPL_NEUTRAL.pkl) to fourm/utils/hmr2_utils/data .') | |
| def visualize_human_poses(pose, poses_tokenizer, mod_dict): | |
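"""
Renders decoded SMPL human pose predictions on top of the input RGB image.
`pose` is the decoded text representation of the poses, `poses_tokenizer` is the tokenizer
used to decode the body pose tokens, and `mod_dict` provides the RGB image to overlay on.
Requires the optional SMPL/pyrender dependencies set up in the try/except block above.
"""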
| full_gts = pose | |
| full_gts = full_gts.split() | |
| num_instances = len(full_gts) // 39 # total length of a pose instance seq is 39 | |
| all_verts = [] | |
| all_cam_t = [] | |
| for inst in range(num_instances): | |
| try: | |
| full_gt = full_gts[inst*39:(inst+1)*39] | |
| ##create the pose params dict | |
| all_params = {} | |
| all_params['bbox_xyxy'] = torch.Tensor((int(full_gt[1][3:])/999*224, int(full_gt[2][3:])/999*224, int(full_gt[3][3:])/999*224, int(full_gt[4][3:])/999*224)) | |
| all_params["box_center"] = torch.cat(( ((all_params["bbox_xyxy"][0] + all_params["bbox_xyxy"][2]) / 2.).unsqueeze(0).unsqueeze(1) , ( (all_params["bbox_xyxy"][1] + all_params["bbox_xyxy"][3]) / 2.).unsqueeze(0).unsqueeze(1) ), dim = 1) | |
| all_params["box_size"] = torch.max((all_params["box_center"][0,0] - all_params["bbox_xyxy"][0]) * 2 , (all_params["box_center"][0,1] - all_params["bbox_xyxy"][1]) * 2 ) | |
| all_params["img_size"] = torch.Tensor([224., 224.]) | |
| all_params["img_size"] = all_params["img_size"].unsqueeze(0) | |
| all_params["focal_length"] = torch.Tensor([5000., 5000.]) | |
| for ii in range(len(full_gt)): | |
| if full_gt[ii] == 'camera': | |
| all_params['pred_cam'] = torch.Tensor([ (int(full_gt[ii+1][3:])-49.95)/49.95, (int(full_gt[ii+2][3:])-49.95)/49.95, (int(full_gt[ii+3][3:])-49.95)/49.95 ]) | |
| break | |
| all_params['pred_cam'] = all_params['pred_cam'].unsqueeze(0) | |
| all_params['pred_smpl_params'] = {} | |
| for ii in range(len(full_gt)): | |
| if full_gt[ii] == 'shape': | |
| all_params['pred_smpl_params']['betas'] = torch.Tensor([ (int(full_gt[ii+1][3:])-499.5)/166.5, (int(full_gt[ii+2][3:])-499.5)/166.5, (int(full_gt[ii+3][3:])-499.5)/166.5, (int(full_gt[ii+4][3:])-499.5)/166.5, (int(full_gt[ii+5][3:])-499.5)/166.5, (int(full_gt[ii+6][3:])-499.5)/166.5, (int(full_gt[ii+7][3:])-499.5)/166.5, (int(full_gt[ii+8][3:])-499.5)/166.5, (int(full_gt[ii+9][3:])-499.5)/166.5, (int(full_gt[ii+10][3:])-499.5)/166.5 ]) | |
| break | |
| all_params['pred_smpl_params']['betas'] = all_params['pred_smpl_params']['betas'].unsqueeze(0) | |
| for ii in range(len(full_gt)): | |
| if full_gt[ii] == 'global': | |
| all_params['pred_smpl_params']['global_orient'] = torch.Tensor( [ [(int(full_gt[ii+1][3:])-499.5)/499.5, (int(full_gt[ii+2][3:])-499.5)/499.5, (int(full_gt[ii+3][3:])-499.5)/499.5 ] , [ (int(full_gt[ii+4][3:])-499.5)/499.5, (int(full_gt[ii+5][3:])-499.5)/499.5, (int(full_gt[ii+6][3:])-499.5)/499.5], [(int(full_gt[ii+7][3:])-499.5)/499.5, (int(full_gt[ii+8][3:])-499.5)/499.5, (int(full_gt[ii+9][3:])-499.5)/499.5 ] ] ) | |
| break | |
| all_params['pred_smpl_params']['global_orient'] = all_params['pred_smpl_params']['global_orient'].unsqueeze(0).unsqueeze(0) | |
| body_poses = torch.FloatTensor() | |
| for ii in range(len(full_gt)): | |
| if full_gt[ii] == 'pose': | |
| pose_start = ii | |
| break | |
| for ii in range(8): | |
| pose_curr = ii + pose_start + 1 | |
| if 'v1' in full_gt[pose_curr]: | |
| poses_curr = torch.Tensor([int(full_gt[pose_curr][3:])+512]) | |
| else: | |
| poses_curr = torch.Tensor([int(full_gt[pose_curr][3:])]) | |
| poses_curr = poses_curr | |
| body_poses = torch.cat((body_poses,poses_curr), dim=0) | |
| body_poses = body_poses.long() | |
| body_poses = body_poses.unsqueeze(0).unsqueeze(2).unsqueeze(2).to(device) | |
| body_poses = poses_tokenizer.decode_tokens(body_poses).squeeze(2).squeeze().reshape(1,23,3,3).cpu() | |
| all_params['pred_smpl_params']['body_pose'] = body_poses | |
| smpl_params = (all_params['pred_smpl_params']) | |
| smpl_output = smpl(**{k: v.float().cpu() for k,v in smpl_params.items()}, pose2rot=False) | |
| for n in range(smpl_output.vertices.size(0)): | |
| # Add all verts and cams to list | |
| verts = smpl_output.vertices[n].detach().cpu().numpy() | |
| img_size = all_params["img_size"].float() | |
| pred_cam = all_params['pred_cam'] | |
| box_center = all_params["box_center"].float() | |
| box_size = all_params["box_size"].float() | |
| scaled_focal_length = pose_model_cfg.EXTRA.FOCAL_LENGTH / pose_model_cfg.MODEL.IMAGE_SIZE * img_size.max() | |
| pred_cam_t_full = cam_crop_to_full(pred_cam, box_center, box_size, img_size, scaled_focal_length).detach().cpu().numpy() | |
| cam_t = pred_cam_t_full[n] | |
| all_verts.append(verts) | |
| all_cam_t.append(cam_t) | |
| except Exception as e: | |
| print('Error in decoding human poses: ', end='') | |
| print(e) | |
| continue | |
| try: | |
| input_img = denormalize(mod_dict['rgb@224']['tensor'].squeeze(), mean=(IMAGENET_DEFAULT_MEAN), std=IMAGENET_DEFAULT_STD).permute(1,2,0).cpu() | |
| except Exception as e: | |
| print(e) | |
| input_img = 1. | |
| if 'tok_rgb' in mod_dict: | |
| input_img = decode_tok_rgb(mod_dict, toks, key='tok_rgb') | |
| # Render front view | |
| input_img_overlay = 0.5* input_img[:,:,:3] | |
| if len(all_verts) > 0: | |
| misc_args = dict( | |
| mesh_base_color=LIGHT_BLUE, | |
| scene_bg_color=(1, 1, 1), | |
| focal_length=scaled_focal_length, | |
| ) | |
| cam_view = renderer.render_rgba_multiple(all_verts, cam_t=all_cam_t, render_res=img_size[n], **misc_args) | |
| mask = (cam_view[:,:,0]<1.).astype(int)[:,:,None] | |
| input_img_overlay = 0.5* input_img[:,:,:3] * (1-mask) + cam_view[:,:,:3] * mask | |
| return input_img_overlay | |
| def visualize_bboxes(img, bboxes_str, color=BOX_COLOR, thickness=2): | |
| """ | |
| Visualizes bounding boxes on the image. | |
| Args: | |
| img (np.array): Image to draw bounding boxes on. | |
| bboxes_str (str): String containing bounding boxes in the format: | |
| v0=1 v1=2 v2=3 v3=4 class_name ..., where | |
| v0 is xmin, v1 is ymin, v2 is xmax, v3 is ymax | |
| color (tuple): Color of the bounding box. | |
| thickness (int): Thickness of the bounding box. | |
| """ | |
| if img is None: | |
| img = 255 * np.ones((256,256,3), dtype=np.int32) | |
| img = img.copy() | |
| bboxes_str = bboxes_str.replace('[PAD]', '') | |
| if len(bboxes_str.replace('[EOS]', '')) == 0: | |
| return img | |
| try: | |
| bboxes = convert_string_to_bboxes(bboxes_str.replace(' [EOS]', '')) | |
| except: | |
| return img | |
| for bbox in bboxes: | |
| x_min, y_min, x_max, y_max, class_name = bbox | |
| img_h, img_w = img.shape[0], img.shape[1] | |
| x_min, x_max, y_min, y_max = int(x_min * img_w), int(x_max * img_w), int(y_min * img_h), int(y_max * img_h) | |
| cv2.rectangle(img, (x_min, y_min), (x_max, y_max), color=color, thickness=thickness) | |
| ((text_width, text_height), _) = cv2.getTextSize(class_name.rstrip(), cv2.FONT_HERSHEY_SIMPLEX, 0.35, 1) | |
| cv2.rectangle(img, (x_min, y_min - int(1.3 * text_height)), (x_min + text_width, y_min), BOX_COLOR, -1) | |
| cv2.putText( | |
| img, | |
| text=f"{class_name}", | |
| org=(x_min, y_min - int(0.3 * text_height)), | |
| fontFace=cv2.FONT_HERSHEY_SIMPLEX, | |
| fontScale=0.35, | |
| color=TEXT_COLOR, | |
| lineType=cv2.LINE_AA, | |
| ) | |
| return img | |
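
# Example (illustrative): draw hypothetical detections on a blank canvas.
#
#   det_str = "v0=100 v1=50 v2=800 v3=600 dog v0=300 v1=300 v2=900 v3=950 person [EOS]"
#   canvas = visualize_bboxes(np.ones((224, 224, 3)), det_str, thickness=2)
#   plt.imshow(canvas.clip(0, 1)); plt.axis('off'); plt.show()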
| def plot_text_in_square(ax, text, padding=0.5, fontsize=14, wrap_width=50): | |
| """ | |
| Plots text in a square. | |
| Args: | |
| ax (matplotlib.axes.Axes): Matplotlib axis to plot on | |
| text (str): Text to plot | |
| padding (float): Padding around the text | |
| fontsize (int): Font size of the text | |
| wrap_width (int): Width of the text to wrap | |
| """ | |
| ax.set_xlim(0, 1) | |
| ax.set_ylim(0, 1) | |
| if isinstance(text, list): | |
| text = text[0] | |
| text = text.replace('[PAD]', '') | |
| # Wrap the text if necessary | |
| wrapped_text = textwrap.fill(text, int(wrap_width)) | |
| # Add the padding | |
| bbox_props = dict(boxstyle="square,pad=" + str(padding), facecolor="white", edgecolor="black") | |
| # Add the text to the plot | |
| ax.text(0.5, 0.5, wrapped_text, ha='center', va='center', fontsize=fontsize, bbox=bbox_props) | |
| remove_ticks_and_labels(ax) | |
| remove_spines(ax) | |
| def text_to_pil_image(text, padding=0.5, fontsize=14, wrap_width=40, image_size=(512, 512)): | |
| """ | |
| Converts text to a PIL image. | |
| Args: | |
| text (str): Text to convert to image | |
| padding (float): Padding around the text | |
| fontsize (int): Font size of the text | |
| wrap_width (int): Width of the text to wrap | |
| image_size (tuple): Size of the output image (width, height) | |
| Returns: | |
| PIL.Image.Image: Generated image with the text | |
| """ | |
| fig, ax = plt.subplots(figsize=(image_size[0] / 100, image_size[1] / 100), dpi=100) | |
| ax.set_xlim(0, 1) | |
| ax.set_ylim(0, 1) | |
| if isinstance(text, list): | |
| text = text[0] | |
| text = text.replace('[PAD]', '') | |
| # Wrap the text if necessary | |
| wrapped_text = textwrap.fill(text, wrap_width) | |
| # Add the padding | |
| bbox_props = dict(boxstyle="square,pad=" + str(padding), facecolor="white", edgecolor="black") | |
| # Add the text to the plot | |
| ax.text(0.5, 0.5, wrapped_text, ha='center', va='center', fontsize=fontsize, bbox=bbox_props) | |
| # Remove ticks, labels, and spines | |
| ax.set_xticks([]) | |
| ax.set_yticks([]) | |
| for spine in ax.spines.values(): | |
| spine.set_visible(False) | |
| # Convert the plot to a PIL image | |
| fig.canvas.draw() | |
| image = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
| image = image.reshape(fig.canvas.get_width_height()[::-1] + (3,)) | |
| plt.close(fig) | |
| return Image.fromarray(image) | |
| def plot_modality(dec_dict, key, ax, figscale=4.0): | |
| """ | |
Plots a single decoded modality on the given matplotlib axis.
| Args: | |
| dec_dict (dict): Dictionary of decoded modalities | |
| key (str): Key of the modality to plot | |
| ax (matplotlib.axes.Axes): Matplotlib axis to plot on | |
| figscale (float): Scaling factor for the figure (used to scale the caption box) | |
| """ | |
| modality = dec_dict[key] | |
| k = get_transform_key(key) | |
| if 'tok' in k or k == 'rgb' or k == 'human_poses' or k == 'color_palette': | |
| ax.imshow(modality.clip(0,1)) | |
| elif k == 'caption': | |
| plot_text_in_square(ax, modality, wrap_width=max(1,int(7*figscale))) # 7*figscale turns out to make caption box fit nicely | |
| elif k == 't5_caption': | |
| plot_text_in_square(ax, modality, wrap_width=max(1,int(7*figscale))) # 7*figscale turns out to make caption box fit nicely | |
| elif k == 'metadata': | |
| modality = ',\n'.join([f'{k}: {v:.2f}' if isinstance(v, float) else f'{k}: {v}' for k, v in modality.items()]) | |
| plot_text_in_square(ax, modality, wrap_width=max(1,int(7*figscale)), fontsize=11) | |
| elif k == 'det': | |
| bbox_img = visualize_bboxes(np.ones((224,224,3)), modality, thickness=2) | |
| ax.imshow(bbox_img.clip(0,1)) | |
| def plot_conds_and_targets(cond_domains, target_domains, dec_dicts, save_path=None, fs_titles=15, figscale=4.0, dpi=100): | |
| """ | |
| Plots the conditioning and target modalities for a batch of samples. | |
| Args: | |
| cond_domains (list of str): List of conditioning domains | |
| target_domains (list of str): List of target domains | |
| dec_dicts (list of dicts): List of dictionaries containing the decoded conditioning and target modalities | |
| save_path (str): Path to save the figure. If None, the figure is not saved but plotted instead. | |
| fs_titles (int): Font size of the titles | |
| figscale (float): Scaling factor for the figure size (minimum 4.0 for good results) | |
| dpi (float): Dots per inch for the saved figure | |
| """ | |
| n_cond = len(cond_domains) | |
| n_target = len(target_domains) | |
| n_samples = len(dec_dicts) | |
| ncols = n_samples + 1 if n_cond > 0 else n_samples | |
| nrows = max(n_cond, n_target) | |
| fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(ncols*figscale, nrows*figscale), facecolor='white') | |
| if nrows == 1 and ncols == 1: | |
| ax = np.array([[ax]]) | |
| elif nrows == 1: | |
| ax = np.expand_dims(ax, axis=0) | |
| elif ncols == 1: | |
| ax = np.expand_dims(ax, axis=1) | |
| for cond_idx, cond_domain in enumerate(cond_domains): | |
| axi = ax[cond_idx, 0] | |
| plot_modality(dec_dicts[0], key=cond_domain, ax=axi) | |
| axi.set_title(f'Conditioning: {MOD_PRINT_NAMES[cond_domain]}', fontsize=fs_titles) | |
| # Remove spines that are not needed | |
| if n_cond > 0: | |
| for i in range(n_cond, nrows, 1): | |
| remove_spines(ax[i, 0]) | |
| offset = 0 if n_cond == 0 else 1 | |
| for sample_idx, dec_dict in enumerate(dec_dicts): | |
| for target_idx, target_domain in enumerate(target_domains): | |
| axi = ax[target_idx, sample_idx+offset] | |
| plot_modality(dec_dict, key=target_domain, ax=axi) | |
| axi.set_title(f'{sample_idx+1}.{target_idx+1}: {MOD_PRINT_NAMES[target_domain]}', fontsize=fs_titles) | |
| # Remove spines that are not needed | |
| for i in range(n_target, nrows, 1): | |
| remove_spines(ax[i, sample_idx+offset]) | |
| for ax in fig.axes: | |
| remove_ticks_and_labels(ax) | |
| plt.tight_layout() | |
| if save_path is not None: | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| plt.savefig(save_path, bbox_inches='tight', dpi=dpi) #, pil_kwargs={'quality': 30}) | |
| plt.close() | |
| else: | |
| plt.show() | |
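
# Example (illustrative): plot one RGB conditioning column next to two sampled variants,
# assuming `dec_dicts` holds two decode_dict() outputs for the same input.
#
#   plot_conds_and_targets(cond_domains=['rgb@224'], target_domains=['caption', 'tok_depth@224'],
#                          dec_dicts=dec_dicts, save_path=None)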
| def save_conds_and_targets(cond_domains, target_domains, dec_dicts, save_dir, sample_idx, suffix=None, vis_det=False): | |
| """ | |
| Saves the conditioning and target modalities for a batch of samples. | |
| Args: | |
| cond_domains (list of str): List of conditioning domains | |
| target_domains (list of str): List of target domains | |
| dec_dicts (list of dicts): List of dictionaries containing the decoded conditioning and target modalities | |
| save_dir (str): Path to save the modalities | |
| sample_idx (int): Unique index of the dataset sample | |
| suffix (str): Suffix to append to the saved file names | |
| vis_det (bool): Whether to visualize detection | |
| """ | |
| for variant_idx, dec_dict in enumerate(dec_dicts): | |
| for domain in cond_domains + target_domains: | |
| if variant_idx != 0 and domain in cond_domains: | |
| continue | |
| variant_suffix = f'_{variant_idx}' if domain in target_domains else '' | |
| if suffix is not None: | |
| variant_suffix += f'_{suffix}' | |
| domain_save_dir = os.path.join(save_dir, 'conds' if domain in cond_domains else 'targets', domain) | |
| os.makedirs(domain_save_dir, exist_ok=True) | |
| if 'tok' in domain or domain in ['rgb', 'human_poses', 'color_palette']: | |
| img = Image.fromarray((255 * dec_dict[domain]).astype(np.uint8)) | |
| if domain in ['tok_clip', 'tok_dinov2', 'tok_imagebind']: | |
| img = img.resize((224,224), resample=Image.NEAREST) | |
| save_path = os.path.join(domain_save_dir, f'{sample_idx:06d}{variant_suffix}.png') | |
| img.save(save_path) | |
| elif domain in ['caption', 'det', 'metadata']: | |
| if vis_det: | |
| save_path = os.path.join(domain_save_dir, f'{sample_idx:06d}{variant_suffix}.png') | |
| bbox_img = visualize_bboxes(np.ones((512,512,3)), dec_dict[domain], thickness=2) | |
| bbox_img = Image.fromarray((255 * bbox_img.clip(0,1)).astype(np.uint8)) | |
| bbox_img.save(save_path) | |
| else: | |
| # Save caption as text file | |
| save_path = os.path.join(domain_save_dir, f'{sample_idx:06d}{variant_suffix}.txt') | |
| with open(save_path, 'w') as f: | |
| f.write(dec_dict[domain]) | |
| def plot_images_with_captions(images, captions, save_path=None, dpi=100, wrap_length=40, figscale=4.0): | |
| """ | |
| Plots images with their corresponding captions. | |
Args:
images (torch.Tensor): A tensor of shape Bx3xHxW with images.
captions (list): A list of B captions.
| """ | |
| assert len(images) == len(captions), "Number of images must match number of captions!" | |
| B = len(images) | |
| sqrt_B = int(B**0.5) | |
| # Determine the number of rows and columns for subplots | |
| nrows = sqrt_B | |
| ncols = (B + nrows - 1) // nrows | |
| fig, axarr = plt.subplots(nrows=nrows, ncols=ncols, figsize=(figscale*ncols, figscale*nrows)) | |
| axarr = np.array([axarr]) if nrows == 1 and ncols == 1 else axarr.ravel() | |
| for i, ax in enumerate(axarr): | |
| if i < B: | |
| # Convert tensor image to numpy | |
| image_np = images[i].permute(1, 2, 0).cpu().float().numpy() | |
| ax.imshow(image_np) | |
| # Place caption below the image | |
| caption_wrapped = textwrap.fill(captions[i], width=wrap_length) | |
| ax.text(0.5, -0.1, caption_wrapped, ha='center', va='top', transform=ax.transAxes, wrap=True) | |
| ax.axis("off") | |
| else: | |
| ax.axis("off") # Hide any additional subplots | |
| plt.subplots_adjust(hspace=0.6) | |
| plt.tight_layout() | |
| if save_path is not None: | |
| os.makedirs(os.path.dirname(save_path), exist_ok=True) | |
| plt.savefig(save_path, bbox_inches='tight', dpi=dpi) | |
| plt.close() | |
| else: | |
| plt.show() |