import logging
from typing import Any, Union, Optional, Tuple, Dict

import open_clip
from open_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
import torch
from torchvision import transforms
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
import cv2

from .gem_wrapper import GEMWrapper
_MODELS = {
    # B/32
    "ViT-B/32": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_e16",
        "laion2b_s34b_b79k",
    ],
    "ViT-B/32-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    # B/16
    "ViT-B/16": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_s34b_b88k",
    ],
    "ViT-B/16-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    "ViT-B/16-plus-240": [
        "laion400m_e31",
        "laion400m_e32",
    ],
    # L/14
    "ViT-L/14": [
        "openai",
        "laion400m_e31",
        "laion400m_e32",
        "laion2b_s32b_b82k",
    ],
    "ViT-L/14-quickgelu": [
        "metaclip_400m",
        "metaclip_fullcc",
    ],
    "ViT-L/14-336": [
        "openai",
    ],
}


def available_models() -> str:
    """Return a formatted listing of the available GEM-VL models and their pretrained weight tags."""
    return "".join(
        f"{key:<20}: {value}\n"
        for key, values in _MODELS.items()
        for value in values
    )
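
# Illustrative usage (the output format follows the f-string above):
#   >>> print(available_models())
#   ViT-B/32            : openai
#   ViT-B/32            : laion400m_e31
#   ...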


def get_tokenizer(
        model_name: str = '',
        context_length: Optional[int] = None,
        **kwargs,
):
    """Wrapper around open_clip's `get_tokenizer` function."""
    return open_clip.get_tokenizer(model_name=model_name, context_length=context_length, **kwargs)


def get_gem_img_transform(
        img_size: Union[int, Tuple[int, int]] = (448, 448),
        mean: Optional[Tuple[float, ...]] = None,
        std: Optional[Tuple[float, ...]] = None,
):
    """Build the GEM image preprocessing pipeline (bicubic resize, tensor conversion, CLIP normalization)."""
    mean = mean or OPENAI_DATASET_MEAN
    std = std or OPENAI_DATASET_STD
    transform = transforms.Compose([
        transforms.Resize(size=img_size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])
    return transform
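
# Illustrative usage; 'cat.jpg' is a hypothetical path:
#   >>> transform = get_gem_img_transform(img_size=(448, 448))
#   >>> img = transform(Image.open('cat.jpg').convert('RGB')).unsqueeze(0)  # shape (1, 3, 448, 448)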


def create_gem_model(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        ss_attn_iter: int = 1,
        ss_attn_temp: Optional[float] = None,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Create an open_clip model and wrap it with GEM's self-self attention (GEMWrapper)."""
    model_name = model_name.replace("/", "-")
    logging.info(f'Loading pretrained {model_name} from pretrained weights {pretrained}...')
    # Pass arguments by keyword so the call stays robust to open_clip signature changes.
    open_clip_model = open_clip.create_model(
        model_name, pretrained=pretrained, precision=precision, device=device, jit=jit,
        force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout, force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg, pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict,
        require_pretrained=require_pretrained, **model_kwargs,
    )
    tokenizer = open_clip.get_tokenizer(model_name=model_name)

    gem_model = GEMWrapper(model=open_clip_model, tokenizer=tokenizer, depth=gem_depth,
                           ss_attn_iter=ss_attn_iter, ss_attn_temp=ss_attn_temp)
    logging.info(f'Loaded GEM-{model_name} from pretrained weights {pretrained}!')
    return gem_model
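
# Illustrative usage with a model/weights pair from _MODELS above:
#   >>> gem_model = create_gem_model('ViT-B/16', pretrained='laion2b_s34b_b88k', device='cpu')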


def create_model_and_transforms(
        model_name: str,
        pretrained: Optional[str] = None,
        gem_depth: int = 7,
        precision: str = 'fp32',
        device: Union[str, torch.device] = 'cpu',
        jit: bool = False,
        force_quick_gelu: bool = False,
        force_custom_text: bool = False,
        force_patch_dropout: Optional[float] = None,
        force_image_size: Optional[Union[int, Tuple[int, int]]] = None,
        force_preprocess_cfg: Optional[Dict[str, Any]] = None,
        pretrained_image: bool = False,
        pretrained_hf: bool = True,
        cache_dir: Optional[str] = None,
        output_dict: Optional[bool] = None,
        require_pretrained: bool = False,
        **model_kwargs,
):
    """Create a GEM model together with its matching image transform."""
    # Forward by keyword: create_gem_model takes ss_attn_iter/ss_attn_temp between
    # gem_depth and precision, so positional forwarding would misalign the arguments.
    gem_model = create_gem_model(
        model_name=model_name, pretrained=pretrained, gem_depth=gem_depth, precision=precision,
        device=device, jit=jit, force_quick_gelu=force_quick_gelu, force_custom_text=force_custom_text,
        force_patch_dropout=force_patch_dropout, force_image_size=force_image_size,
        force_preprocess_cfg=force_preprocess_cfg, pretrained_image=pretrained_image,
        pretrained_hf=pretrained_hf, cache_dir=cache_dir, output_dict=output_dict,
        require_pretrained=require_pretrained, **model_kwargs,
    )
    # Only pass along the kwargs the transform understands (img_size/mean/std);
    # anything else in model_kwargs would raise a TypeError here.
    transform_kwargs = {k: v for k, v in model_kwargs.items() if k in ('img_size', 'mean', 'std')}
    transform = get_gem_img_transform(**transform_kwargs)
    return gem_model, transform
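
# Illustrative usage, returning the wrapped model and its matching preprocessing:
#   >>> gem_model, transform = create_model_and_transforms('ViT-B/16', pretrained='laion2b_s34b_b88k')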


def visualize(image, text, logits, alpha=0.6, save_path=None):
    """Overlay per-class heatmaps (`logits`, values in [0, 1]) on `image` and display them."""
    H, W = logits.shape[-2:]
    if isinstance(image, Image.Image):
        image = image.resize((W, H))  # PIL expects (width, height)
    elif isinstance(image, torch.Tensor):
        if image.ndim > 3:
            image = image.squeeze(0)
        # Undo the CLIP normalization, clamp to [0, 1], and convert to PIL.
        image_unormed = (image.detach().cpu() * torch.tensor(OPENAI_DATASET_STD)[:, None, None]
                         + torch.tensor(OPENAI_DATASET_MEAN)[:, None, None]).clamp(0, 1)
        image = Image.fromarray((image_unormed.permute(1, 2, 0).numpy() * 255).astype('uint8'))
    else:
        raise TypeError(f'image should be either of type PIL.Image.Image or torch.Tensor but found {type(image)}')

    # Plot the input image.
    plt.imshow(image)
    plt.axis('off')
    plt.tight_layout()
    plt.show()

    if logits.ndim > 3:
        logits = logits.squeeze(0)
    logits = logits.detach().cpu().numpy()

    img_cv = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    logits = (logits * 255).astype('uint8')
    heat_maps = [cv2.applyColorMap(logit, cv2.COLORMAP_JET) for logit in logits]
    vizs = [(1 - alpha) * img_cv + alpha * heat_map for heat_map in heat_maps]
    for viz, cls_name in zip(vizs, text):
        viz = cv2.cvtColor(viz.astype('uint8'), cv2.COLOR_BGR2RGB)
        plt.imshow(viz)
        plt.title(cls_name)
        plt.axis('off')
        plt.tight_layout()
        if save_path is not None:
            # Save before show(), which clears the figure; save_path currently acts
            # only as an on/off flag and files go to the working directory.
            plt.savefig(f'heatmap_{cls_name}.png')
        plt.show()
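

# Minimal end-to-end sketch (illustrative, not part of the public API). It assumes the
# GEMWrapper instance is callable as `gem_model(image_tensor, texts)` and returns per-class
# heatmaps of shape (1, len(texts), H, W); 'cat_dog.jpg' is a hypothetical path. Run via
# `python -m`, since this module uses a relative import.
if __name__ == '__main__':
    gem_model, transform = create_model_and_transforms('ViT-B/16', pretrained='laion2b_s34b_b88k')
    image = Image.open('cat_dog.jpg').convert('RGB')
    image_t = transform(image).unsqueeze(0)
    texts = ['cat', 'dog']
    with torch.no_grad():
        logits = gem_model(image_t, texts)  # assumed GEMWrapper forward signature
    visualize(image_t, texts, logits)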