UVD / uvd /models /preprocessors /vip_preprocessor.py
ryanhoangt's picture
Upload folder using huggingface_hub
c456c14 verified
from __future__ import annotations
from typing import Optional
import hydra
import omegaconf
import torch
_VIP_IMPORT_ERROR = None
try:
import vip
except ImportError as e:
_VIP_IMPORT_ERROR = e
from torch import nn
from torchvision import transforms as T
import uvd.utils as U
from uvd.models.preprocessors.base import Preprocessor
class VipPreprocessor(Preprocessor):
def __init__(
self,
model_type: str | None = None,
device: torch.device | str | None = None,
remove_bn: bool = False,
bn_to_gn: bool = False,
remove_pool: bool = False,
preprocess_with_fc: bool = False,
save_fc: bool = False,
random_crop: bool = False,
ckpt: str | None = None,
**kwargs,
):
if _VIP_IMPORT_ERROR is not None:
raise ImportError(_VIP_IMPORT_ERROR)
model_type = model_type or "resnet50"
self.random_crop = random_crop
self.ckpt = ckpt
super().__init__(
model_type=model_type,
device=device,
remove_bn=remove_bn,
bn_to_gn=bn_to_gn,
remove_pool=remove_pool,
preprocess_with_fc=preprocess_with_fc,
save_fc=save_fc,
**kwargs,
)
def _get_model_and_transform(
self, model_type: str | None = None
) -> tuple[vip.VIP, Optional[T]]:
if model_type is not None:
assert model_type == "resnet50", f"{model_type} not support"
vip.device = self.device
vip_ = load_vip(modelid="resnet50", ckpt_path=self.ckpt).module
resnet = vip_.convnet.to(self.device)
if self.remove_pool:
# if self.save_fc:
self._pool = U.freeze_module(resnet.avgpool)
self._fc = U.freeze_module(resnet.fc)
model = nn.Sequential(*(list(resnet.children())[:-2]))
else:
model = resnet
# crop_transform = T.RandomCrop(224) if self.random_crop else T.CenterCrop(224)
transform = (
# nn.Sequential(T.Resize(224), vip_.normlayer)
T.Compose([T.Resize(224), vip_.normlayer])
if not self.random_crop
# else nn.Sequential(T.Resize(232), T.RandomCrop(224), vip_.normlayer)
else T.Compose([T.Resize(232), T.RandomCrop(224), vip_.normlayer])
)
return model, transform
def load_vip(modelid: str = "resnet50", ckpt_path: str | None = None):
if ckpt_path is None:
return vip.load_vip(modelid)
home = U.f_join("~/.vip")
folderpath = U.f_mkdir(home, modelid)
configpath = U.f_join(home, modelid, "config.yaml")
if not U.f_exists(configpath):
try:
configurl = "https://pytorch.s3.amazonaws.com/models/rl/vip/config.yaml"
vip.load_state_dict_from_url(configurl, folderpath)
except:
configurl = (
"https://drive.google.com/uc?id=1XSQE0gYm-djgueo8vwcNgAiYjwS43EG-"
)
vip.gdown.download(configurl, configpath, quiet=False)
modelcfg = omegaconf.OmegaConf.load(configpath)
cleancfg = vip.cleanup_config(modelcfg)
rep = hydra.utils.instantiate(cleancfg)
rep = torch.nn.DataParallel(rep)
vip_state_dict = torch.load(ckpt_path, map_location="cpu")["vip"]
rep.load_state_dict(vip_state_dict)
return rep