# UVD/uvd/models/preprocessors/clip_preprocessor.py
# ryanhoangt's picture
# Upload folder using huggingface_hub
# c456c14 verified
from __future__ import annotations
_CLIP_IMPORT_ERROR = None
try:
import clip
from clip.model import CLIP
except ImportError as e:
_CLIP_IMPORT_ERROR = e
import torch
from torchvision import transforms as T
from uvd.models.preprocessors.base import Preprocessor
import uvd.utils as U
class ClipPreprocessor(Preprocessor):
    """Preprocessor that uses the visual tower of a CLIP model as its backbone.

    Only ``clip.load(...).visual`` is kept; the rest of the CLIP model is
    discarded. The accompanying transform resizes (optionally random-crops)
    and normalizes its input.
    """

    def __init__(
        self,
        model_type: str = "RN50",
        device: torch.device | str | None = None,
        random_crop: bool = False,
        remove_bn: bool = False,
        bn_to_gn: bool = False,
        remove_pool: bool = False,
        **kwargs,
    ):
        """Build the preprocessor.

        Args:
            model_type: CLIP model name passed to ``clip.load``. Any
                ``"resnet"`` substring is rewritten to ``"RN"`` (so
                ``"resnet50"`` becomes ``"RN50"``).
            device: Forwarded to the base class unchanged.
            random_crop: If true, the transform resizes to 232 and takes a
                random 224 crop instead of a plain resize to 224.
            remove_bn: Forwarded to the base class unchanged.
            bn_to_gn: Forwarded to the base class unchanged.
            remove_pool: Forwarded to the base class; when true, the last
                child module of the visual tower is stripped in
                ``_get_model_and_transform``.
            **kwargs: Extra keyword arguments forwarded to the base class.
                ``preprocess_with_fc`` and ``save_fc`` are dropped if present,
                because both are always passed as ``False``.

        Raises:
            ImportError: If the optional ``clip`` package failed to import
                when this module was loaded.
        """
        if _CLIP_IMPORT_ERROR is not None:
            # Chain the original failure so its traceback is not lost.
            raise ImportError(_CLIP_IMPORT_ERROR) from _CLIP_IMPORT_ERROR

        # Assigned before super().__init__() on purpose: presumably the base
        # initializer calls _get_model_and_transform(), which reads this flag
        # -- TODO confirm against Preprocessor.__init__.
        self.random_crop = random_crop

        model_type = model_type.replace("resnet", "RN")

        # Caller-supplied values for these two options are discarded so they
        # cannot clash with the explicit ``False`` values passed below.
        kwargs.pop("preprocess_with_fc", None)
        kwargs.pop("save_fc", None)

        super().__init__(
            model_type=model_type,
            device=device,
            remove_bn=remove_bn,
            bn_to_gn=bn_to_gn,
            remove_pool=remove_pool,
            preprocess_with_fc=False,
            save_fc=False,
            **kwargs,
        )

    def _get_model_and_transform(
        self, model_type: str
    ) -> tuple[torch.nn.Module, T.Compose]:
        """Load the CLIP visual tower and build the input transform.

        Args:
            model_type: CLIP model name passed to ``clip.load``.

        Returns:
            A ``(model, transform)`` pair. ``model`` is the ``visual``
            sub-module of the loaded CLIP model, or, when ``self.remove_pool``
            is true, a ``torch.nn.Sequential`` of all its children except the
            last. ``transform`` is a torchvision ``Compose`` of resize
            (+ optional random crop) and normalization.
        """
        # The transform returned by clip.load is not used; a custom one is
        # built below.
        model, _ = clip.load(model_type, device=self.device)
        model = model.visual

        if self.remove_pool:
            # NOTE(review): assumes the visual tower exposes ``attnpool`` as
            # its last child (looks like the CLIP ResNet variants) -- confirm
            # before using remove_pool with other architectures.
            # The pooling module is kept on the instance; freeze_module
            # presumably disables gradients on it -- verify in uvd.utils.
            self._pool = U.freeze_module(model.attnpool)
            model = torch.nn.Sequential(*(list(model.children())[:-1]))

        # Per-channel normalization statistics (these appear to be the CLIP
        # image-normalization constants).
        normlayer = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        # No ToTensor step is included, so inputs are presumably already
        # tensors -- TODO confirm with callers.
        transform = (
            T.Compose([T.Resize(224), normlayer])
            if not self.random_crop
            else T.Compose([T.Resize(232), T.RandomCrop(224), normlayer])
        )
        return model, transform