Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from typing import Optional | |
| import hydra | |
| import omegaconf | |
| import torch | |
| _VIP_IMPORT_ERROR = None | |
| try: | |
| import vip | |
| except ImportError as e: | |
| _VIP_IMPORT_ERROR = e | |
| from torch import nn | |
| from torchvision import transforms as T | |
| import uvd.utils as U | |
| from uvd.models.preprocessors.base import Preprocessor | |
| class VipPreprocessor(Preprocessor): | |
| def __init__( | |
| self, | |
| model_type: str | None = None, | |
| device: torch.device | str | None = None, | |
| remove_bn: bool = False, | |
| bn_to_gn: bool = False, | |
| remove_pool: bool = False, | |
| preprocess_with_fc: bool = False, | |
| save_fc: bool = False, | |
| random_crop: bool = False, | |
| ckpt: str | None = None, | |
| **kwargs, | |
| ): | |
| if _VIP_IMPORT_ERROR is not None: | |
| raise ImportError(_VIP_IMPORT_ERROR) | |
| model_type = model_type or "resnet50" | |
| self.random_crop = random_crop | |
| self.ckpt = ckpt | |
| super().__init__( | |
| model_type=model_type, | |
| device=device, | |
| remove_bn=remove_bn, | |
| bn_to_gn=bn_to_gn, | |
| remove_pool=remove_pool, | |
| preprocess_with_fc=preprocess_with_fc, | |
| save_fc=save_fc, | |
| **kwargs, | |
| ) | |
| def _get_model_and_transform( | |
| self, model_type: str | None = None | |
| ) -> tuple[vip.VIP, Optional[T]]: | |
| if model_type is not None: | |
| assert model_type == "resnet50", f"{model_type} not support" | |
| vip.device = self.device | |
| vip_ = load_vip(modelid="resnet50", ckpt_path=self.ckpt).module | |
| resnet = vip_.convnet.to(self.device) | |
| if self.remove_pool: | |
| # if self.save_fc: | |
| self._pool = U.freeze_module(resnet.avgpool) | |
| self._fc = U.freeze_module(resnet.fc) | |
| model = nn.Sequential(*(list(resnet.children())[:-2])) | |
| else: | |
| model = resnet | |
| # crop_transform = T.RandomCrop(224) if self.random_crop else T.CenterCrop(224) | |
| transform = ( | |
| # nn.Sequential(T.Resize(224), vip_.normlayer) | |
| T.Compose([T.Resize(224), vip_.normlayer]) | |
| if not self.random_crop | |
| # else nn.Sequential(T.Resize(232), T.RandomCrop(224), vip_.normlayer) | |
| else T.Compose([T.Resize(232), T.RandomCrop(224), vip_.normlayer]) | |
| ) | |
| return model, transform | |
| def load_vip(modelid: str = "resnet50", ckpt_path: str | None = None): | |
| if ckpt_path is None: | |
| return vip.load_vip(modelid) | |
| home = U.f_join("~/.vip") | |
| folderpath = U.f_mkdir(home, modelid) | |
| configpath = U.f_join(home, modelid, "config.yaml") | |
| if not U.f_exists(configpath): | |
| try: | |
| configurl = "https://pytorch.s3.amazonaws.com/models/rl/vip/config.yaml" | |
| vip.load_state_dict_from_url(configurl, folderpath) | |
| except: | |
| configurl = ( | |
| "https://drive.google.com/uc?id=1XSQE0gYm-djgueo8vwcNgAiYjwS43EG-" | |
| ) | |
| vip.gdown.download(configurl, configpath, quiet=False) | |
| modelcfg = omegaconf.OmegaConf.load(configpath) | |
| cleancfg = vip.cleanup_config(modelcfg) | |
| rep = hydra.utils.instantiate(cleancfg) | |
| rep = torch.nn.DataParallel(rep) | |
| vip_state_dict = torch.load(ckpt_path, map_location="cpu")["vip"] | |
| rep.load_state_dict(vip_state_dict) | |
| return rep | |