from typing import Tuple

import torch
import torch.nn as nn
import torchvision


def create_resnet50_model(
    num_classes: int = 2, seed: int = 42
) -> Tuple[nn.Module, torchvision.transforms.Compose]:
    """Create a frozen, ImageNet-pretrained ResNet50 with a new classifier head.

    Every parameter of the pretrained network is frozen
    (``requires_grad = False``); the original ``fc`` layer is then replaced
    by a freshly initialised ``Dropout -> Linear`` head, which is the only
    trainable part of the returned model.

    Note:
        ``torch.manual_seed(seed)`` is called, so the *global* PyTorch RNG
        state is reset as a side effect of calling this function.

    Args:
        num_classes (int, optional): Number of output units in the
            classifier head. Defaults to 2.
        seed (int, optional): Seed set just before the head is created so
            its initial weights are reproducible. Defaults to 42.

    Returns:
        A ``(model, transforms)`` tuple where ``model`` is the ResNet50 with
        the replaced head and ``transforms`` is the preprocessing pipeline
        (resize to 256, center-crop to 224, convert to tensor, normalise
        with the ImageNet mean/std).
    """
    # 1. Preprocessing pipeline (resize, center-crop, to-tensor, normalise).
    transforms = torchvision.transforms.Compose([
        torchvision.transforms.Resize(256),
        torchvision.transforms.CenterCrop(224),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])

    # 2. Build the pretrained model ONCE. Constructing a pretrained model
    #    only to copy its state_dict into a second, randomly initialised
    #    model doubles construction time and peak memory for no benefit.
    #    Prefer the `weights=` API (torchvision >= 0.13); IMAGENET1K_V1 is
    #    the checkpoint that the legacy `pretrained=True` flag refers to.
    #    Fall back to the legacy flag on older torchvision releases.
    weights_enum = getattr(torchvision.models, "ResNet50_Weights", None)
    if weights_enum is not None:
        model = torchvision.models.resnet50(weights=weights_enum.IMAGENET1K_V1)
    else:
        model = torchvision.models.resnet50(pretrained=True)

    # 3. Freeze all layers in the base model.
    for param in model.parameters():
        param.requires_grad = False

    # 4. Replace the classifier head; seed first so its init is reproducible.
    #    The new layers are created after the freeze, so they stay trainable.
    torch.manual_seed(seed)
    num_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(p=0.3, inplace=True),
        nn.Linear(in_features=num_features, out_features=num_classes),
    )

    return model, transforms