Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from functools import partial | |
| from typing import Callable | |
| import torch | |
| import torchvision.models | |
| from torchvision.models._api import WeightsEnum | |
| from torchvision.transforms._presets import ImageClassification | |
| import uvd.utils as U | |
| from uvd.models.preprocessors.base import Preprocessor | |
class ResNetPreprocessor(Preprocessor):
    """Preprocessor whose encoder is a torchvision ResNet-family network.

    The network and its input transform are produced by
    ``_get_model_and_transform``. NOTE(review): that hook is presumably
    invoked by the ``Preprocessor`` base during ``super().__init__`` — confirm
    against the base class.
    """

    def __init__(
        self,
        model_type: str = "resnet50",
        from_pretrained: bool = True,
        device: torch.device | str | None = None,
        random_crop: bool = False,
        remove_bn: bool = False,
        bn_to_gn: bool = False,
        remove_pool: bool = False,
    ):
        """Configure the preprocessor and delegate construction to the base.

        Args:
            model_type: ResNet-family builder name, forwarded to the base.
                The names accepted by ``get_resnet_builder_and_weight`` are
                presumably the valid values.
            from_pretrained: when truthy, load the builder's ``DEFAULT``
                weights and use their bundled transform; otherwise use random
                init and a plain 224-crop ``ImageClassification`` preset.
            device: forwarded to the base. The model is moved to
                ``self.device``, which is assumed to be set by the base.
            random_crop: stored on the instance. It is not read anywhere in
                this class.
            remove_bn: forwarded to the base unchanged.
            bn_to_gn: forwarded to the base unchanged.
            remove_pool: forwarded to the base. When ``self.remove_pool`` is
                truthy, the pooling/classifier tail is stripped from the model.
        """
        # Both attributes must exist before the base constructor runs, because
        # ``_get_model_and_transform`` reads ``self.from_pretrained``.
        self.from_pretrained = from_pretrained
        self.random_crop = random_crop

        base_kwargs = dict(
            model_type=model_type,
            device=device,
            remove_bn=remove_bn,
            bn_to_gn=bn_to_gn,
            remove_pool=remove_pool,
            # Both fc-related flags are always disabled for this subclass;
            # their exact meaning is defined in the base class.
            preprocess_with_fc=False,
            save_fc=False,
        )
        super().__init__(**base_kwargs)

    def _get_model_and_transform(
        self, model_type: str
    ) -> tuple[torch.nn.Module, ImageClassification]:
        """Build the backbone for ``model_type`` and its input transform.

        Args:
            model_type: ResNet-family builder name understood by
                ``get_resnet_builder_and_weight``.

        Returns:
            ``(model, transform)``. ``model`` has been moved to
            ``self.device``. ``transform`` is either the pretrained weights'
            own preset or a default ``ImageClassification(crop_size=224)``.
        """
        builder, weights_cls = get_resnet_builder_and_weight(model_type=model_type)

        chosen_weights = None
        if self.from_pretrained:
            chosen_weights = weights_cls.DEFAULT  # type: ignore
        backbone = builder(weights=chosen_weights).to(self.device)

        if self.remove_pool:
            # Keep a frozen handle to the average pool before discarding it,
            # then drop the last two children. In torchvision's ResNet these
            # are ``avgpool`` and ``fc``.
            self._pool = U.freeze_module(backbone.avgpool)
            trunk_layers = list(backbone.children())[:-2]
            backbone = torch.nn.Sequential(*trunk_layers)

        if self.from_pretrained:
            preprocess = chosen_weights.transforms()
        else:
            preprocess = ImageClassification(crop_size=224)
        return backbone, preprocess
# Supported torchvision ResNet-family builder names, each mapped to the name of
# its weights enum in ``torchvision.models``. Insertion order is preserved and
# is the order shown in the "not supported" error message.
_RESNET_WEIGHTS: dict[str, str] = {
    "resnet18": "ResNet18_Weights",
    "resnet34": "ResNet34_Weights",
    "resnet50": "ResNet50_Weights",
    "resnet101": "ResNet101_Weights",
    "resnet152": "ResNet152_Weights",
    "resnext50_32x4d": "ResNeXt50_32X4D_Weights",
    "resnext101_32x8d": "ResNeXt101_32X8D_Weights",
    "resnext101_64x4d": "ResNeXt101_64X4D_Weights",
    "wide_resnet50_2": "Wide_ResNet50_2_Weights",
    "wide_resnet101_2": "Wide_ResNet101_2_Weights",
}


class _UnsupportedModelTypeError(ValueError, AssertionError):
    """Raised when ``model_type`` is not a key of ``_RESNET_WEIGHTS``.

    It is also an ``AssertionError`` so that callers who caught the
    ``AssertionError`` raised by the former ``assert`` statement keep working.
    """


def get_resnet_builder_and_weight(
    model_type: str,
) -> tuple[Callable[..., torch.nn.Module], type[WeightsEnum]]:
    """Return the torchvision builder and weights enum class for ``model_type``.

    Args:
        model_type: builder name, one of the keys of ``_RESNET_WEIGHTS``
            (e.g. ``"resnet50"``, ``"resnext50_32x4d"``).

    Returns:
        ``(builder, weights_cls)``. ``builder`` is the factory function from
        ``torchvision.models``. ``weights_cls`` is the matching weights enum
        *class*, not a member, so callers select e.g.
        ``weights_cls.DEFAULT`` themselves.

    Raises:
        ValueError: if ``model_type`` is not supported. The exception is also
            an ``AssertionError`` for backward compatibility. Unlike the
            previous bare ``assert``, this check still runs under
            ``python -O``.
    """
    if model_type not in _RESNET_WEIGHTS:
        raise _UnsupportedModelTypeError(
            f"{model_type} not in {list(_RESNET_WEIGHTS)}"
        )
    builder = getattr(torchvision.models, model_type)
    weights_cls = getattr(torchvision.models, _RESNET_WEIGHTS[model_type])
    return builder, weights_cls