UVD / uvd /models /preprocessors /resnet_preprocessor.py
ryanhoangt's picture
Upload folder using huggingface_hub
c456c14 verified
from __future__ import annotations
from functools import partial
from typing import Callable
import torch
import torchvision.models
from torchvision.models._api import WeightsEnum
from torchvision.transforms._presets import ImageClassification
import uvd.utils as U
from uvd.models.preprocessors.base import Preprocessor
class ResNetPreprocessor(Preprocessor):
    """Preprocessor built on a torchvision ResNet-family model.

    The model constructor and its weights enum are looked up by name through
    ``get_resnet_builder_and_weight``.
    """

    def __init__(
        self,
        model_type: str = "resnet50",
        from_pretrained: bool = True,
        device: torch.device | str | None = None,
        random_crop: bool = False,
        remove_bn: bool = False,
        bn_to_gn: bool = False,
        remove_pool: bool = False,
    ):
        """Record the ResNet-specific options, then run the base initializer.

        Args:
            model_type: Architecture name understood by
                ``get_resnet_builder_and_weight``.
            from_pretrained: When true, the weights enum's ``DEFAULT`` entry
                is loaded; otherwise the model is built with ``weights=None``.
            device: Forwarded unchanged to the base initializer.
            random_crop: Stored on the instance; not read within this class.
            remove_bn: Forwarded unchanged to the base initializer.
            bn_to_gn: Forwarded unchanged to the base initializer.
            remove_pool: Forwarded unchanged to the base initializer.
        """
        # These are assigned before super().__init__ -- presumably because the
        # base initializer invokes `_get_model_and_transform`, which reads
        # `from_pretrained`. TODO confirm against `Preprocessor`.
        self.from_pretrained = from_pretrained
        self.random_crop = random_crop
        super().__init__(
            model_type=model_type,
            device=device,
            remove_bn=remove_bn,
            bn_to_gn=bn_to_gn,
            remove_pool=remove_pool,
            preprocess_with_fc=False,
            save_fc=False,
        )

    def _get_model_and_transform(
        self, model_type: str
    ) -> tuple[torch.nn.Module, ImageClassification]:
        """Build the backbone for ``model_type`` and its input transform.

        Reads ``self.from_pretrained``, ``self.device`` and
        ``self.remove_pool``; the latter two are not set in this class and are
        presumably provided by the base class (verify in ``Preprocessor``).

        Returns:
            A ``(model, transforms)`` pair.
        """
        builder, weights_cls = get_resnet_builder_and_weight(model_type=model_type)

        if self.from_pretrained:
            chosen_weights = weights_cls.DEFAULT  # type: ignore
        else:
            chosen_weights = None

        backbone = builder(weights=chosen_weights).to(self.device)

        if self.remove_pool:
            # Keep a frozen handle on the pooling layer, then rebuild the
            # network without its last two children (for torchvision ResNets
            # these are expected to be `avgpool` and `fc`).
            self._pool = U.freeze_module(backbone.avgpool)
            kept_layers = list(backbone.children())[:-2]
            backbone = torch.nn.Sequential(*kept_layers)

        if self.from_pretrained:
            # Use the preprocessing preset that ships with the chosen weights.
            preprocessing = chosen_weights.transforms()
        else:
            preprocessing = ImageClassification(crop_size=224)

        return backbone, preprocessing
def get_resnet_builder_and_weight(
    model_type: str,
) -> tuple[Callable[..., torch.nn.Module], WeightsEnum]:
    """Look up the torchvision builder and weights enum for ``model_type``.

    Args:
        model_type: One of the builder names listed in the table below.

    Returns:
        A ``(builder, weights_enum)`` pair, both fetched by attribute name
        from ``torchvision.models``.

    Raises:
        AssertionError: If ``model_type`` is not one of the listed names.
    """
    # Each row pairs a builder name with the name of its weights enum.
    name_pairs = (
        ("resnet18", "ResNet18_Weights"),
        ("resnet34", "ResNet34_Weights"),
        ("resnet50", "ResNet50_Weights"),
        ("resnet101", "ResNet101_Weights"),
        ("resnet152", "ResNet152_Weights"),
        ("resnext50_32x4d", "ResNeXt50_32X4D_Weights"),
        ("resnext101_32x8d", "ResNeXt101_32X8D_Weights"),
        ("resnext101_64x4d", "ResNeXt101_64X4D_Weights"),
        ("wide_resnet50_2", "Wide_ResNet50_2_Weights"),
        ("wide_resnet101_2", "Wide_ResNet101_2_Weights"),
    )
    supported = [builder_name for builder_name, _ in name_pairs]
    assert model_type in supported, f"{model_type} not in {supported}"

    builder = getattr(torchvision.models, model_type)
    _, weights_name = name_pairs[supported.index(model_type)]
    weights_enum = getattr(torchvision.models, weights_name)
    return builder, weights_enum