| from torchvision.models import resnet50, ResNet50_Weights |
| import torch.nn as nn |
|
|
|
|
class Encoder(nn.Module):
    """Frozen ResNet-50 backbone that exposes the output of every stage.

    The ImageNet-pretrained ``resnet50`` is split into five sequential stages
    (``block1`` .. ``block5``); its classifier head (``avgpool`` / ``fc``) is
    not kept. All backbone parameters are frozen, and the backbone is held in
    eval mode so its BatchNorm layers use -- and never overwrite -- their
    pretrained running statistics, even when the enclosing model is switched
    to training mode.
    """

    def __init__(self):
        super().__init__()
        resnet = resnet50(weights=ResNet50_Weights.IMAGENET1K_V1)

        # Freeze the backbone weights: no gradients are accumulated for them.
        for param in resnet.parameters():
            param.requires_grad_(False)

        # ModuleDict preserves insertion order, so forward() runs the stages
        # in exactly the order listed here.
        self.stages = nn.ModuleDict(
            {
                "block1": nn.Sequential(resnet.conv1, resnet.bn1, resnet.relu),
                "block2": nn.Sequential(resnet.maxpool, resnet.layer1),
                "block3": resnet.layer2,
                "block4": resnet.layer3,
                "block5": resnet.layer4,
            }
        )
        # Frozen BatchNorm layers must run in eval mode; see train().
        self.stages.eval()

    def train(self, mode: bool = True) -> "Encoder":
        """Set the module's training flag while keeping the backbone in eval mode.

        ``requires_grad_(False)`` alone does not freeze BatchNorm: in training
        mode it normalizes with per-batch statistics and updates its running
        mean/variance buffers. Forcing the frozen stages back into eval mode
        keeps the pretrained statistics intact. ``self.training`` still
        reflects ``mode``.

        Args:
            mode: Training-mode flag forwarded to ``nn.Module.train``.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        self.stages.eval()
        return self

    def forward(self, x):
        """Run ``x`` through every stage in order.

        Args:
            x: Input accepted by ``resnet.conv1`` -- presumably an
                ImageNet-normalized ``(N, 3, H, W)`` image batch; TODO confirm
                against callers.

        Returns:
            A tuple ``(out, features)``: ``out`` is the output of the last
            stage, and ``features`` maps each stage name (``"block1"`` ..
            ``"block5"``) to that stage's output, in execution order.
        """
        features = {}
        for name, stage in self.stages.items():
            x = stage(x)
            features[name] = x
        return x, features
|
|