Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| _CLIP_IMPORT_ERROR = None | |
| try: | |
| import clip | |
| from clip.model import CLIP | |
| except ImportError as e: | |
| _CLIP_IMPORT_ERROR = e | |
| import torch | |
| from torchvision import transforms as T | |
| from uvd.models.preprocessors.base import Preprocessor | |
| import uvd.utils as U | |
class ClipPreprocessor(Preprocessor):
    """Preprocessor built on the visual sub-module of a model loaded via ``clip.load``.

    Only ``model.visual`` is kept as the backbone. The image transform is built
    locally (resize, optional random crop, normalization) rather than reusing
    the one returned by ``clip.load``.
    """

    def __init__(
        self,
        model_type: str = "RN50",
        device: torch.device | str | None = None,
        random_crop: bool = False,
        remove_bn: bool = False,
        bn_to_gn: bool = False,
        remove_pool: bool = False,
        **kwargs,
    ):
        """Configure the preprocessor and delegate construction to the base class.

        Args:
            model_type: Model name handed to ``clip.load``. Any ``"resnet"``
                substring is rewritten to ``"RN"`` (e.g. ``"resnet50"`` becomes
                ``"RN50"``).
            device: Forwarded to the base class unchanged.
            random_crop: When true, the transform resizes to 232 and takes a
                random 224 crop; otherwise it only resizes to 224.
            remove_bn: Forwarded to the base class unchanged.
            bn_to_gn: Forwarded to the base class unchanged.
            remove_pool: Forwarded to the base class unchanged. When
                ``self.remove_pool`` is truthy at model-build time, the
                ``attnpool`` child is split off the backbone.
            **kwargs: Extra options forwarded to the base class.
                ``preprocess_with_fc`` and ``save_fc`` are dropped if present,
                because both are always passed as ``False``.

        Raises:
            ImportError: If the ``clip`` package could not be imported when
                this module was loaded. The original import failure is
                attached as the cause.
        """
        if _CLIP_IMPORT_ERROR is not None:
            # Chain the stored failure so its traceback is visible to the caller.
            raise ImportError(_CLIP_IMPORT_ERROR) from _CLIP_IMPORT_ERROR
        # NOTE(review): assigned before super().__init__() - presumably the base
        # constructor calls _get_model_and_transform(), which reads this flag.
        self.random_crop = random_crop
        model_type = model_type.replace("resnet", "RN")
        # Drop caller-supplied values so the explicit False arguments below do
        # not collide with duplicates inside **kwargs.
        kwargs.pop("preprocess_with_fc", None)
        kwargs.pop("save_fc", None)
        super().__init__(
            model_type=model_type,
            device=device,
            remove_bn=remove_bn,
            bn_to_gn=bn_to_gn,
            remove_pool=remove_pool,
            preprocess_with_fc=False,
            save_fc=False,
            **kwargs,
        )

    def _get_model_and_transform(
        self, model_type: str
    ) -> tuple[torch.nn.Module, T.Compose]:
        """Load the visual backbone and build the matching image transform.

        Args:
            model_type: Model name passed straight to ``clip.load``.

        Returns:
            A ``(model, transform)`` pair. ``model`` is the loaded model's
            ``visual`` sub-module, or, when ``self.remove_pool`` is truthy, a
            ``torch.nn.Sequential`` of all its children except the last one.
            ``transform`` is a ``T.Compose`` of resize (plus random crop when
            ``self.random_crop`` is set) followed by normalization.
        """
        # The preprocessing pipeline returned by clip.load is intentionally
        # discarded; a custom transform is assembled below.
        model, _ = clip.load(model_type, device=self.device)
        model = model.visual
        if self.remove_pool:
            # Keep a frozen handle to the pooling head on the instance.
            self._pool = U.freeze_module(model.attnpool)
            # NOTE(review): assumes attnpool is the last registered child and
            # that running the remaining children in registration order
            # reproduces the backbone's forward pass - confirm for the
            # installed clip version and for non-ResNet model types.
            model = torch.nn.Sequential(*list(model.children())[:-1])
        # NOTE(review): constants appear to be CLIP's published per-channel
        # image mean/std - confirm against the clip package.
        normlayer = T.Normalize(
            mean=(0.48145466, 0.4578275, 0.40821073),
            std=(0.26862954, 0.26130258, 0.27577711),
        )
        if self.random_crop:
            steps = [T.Resize(232), T.RandomCrop(224), normlayer]
        else:
            steps = [T.Resize(224), normlayer]
        return model, T.Compose(steps)