| from transformers import PreTrainedModel | |
| import torch | |
| import torch.nn as nn | |
| from .resnetall import generate_model | |
| from .configuration_resnet3d import Resnet3DScrollprizeConfig | |
| import torch.nn.functional as F | |
class Decoder(nn.Module):
    """Top-down 2D decoder that fuses a feature pyramid into a 1-channel logit map.

    Starting from the coarsest level, each map is upsampled to the spatial size of
    the next-finer level, concatenated with it along the channel axis, and passed
    through a 3x3 conv -> BatchNorm -> ReLU block that reduces the channel count to
    that of the finer level. The finest fused map is projected to one channel by a
    1x1 conv and finally resized by ``upscale``.

    Args:
        encoder_dims: Channel count of each feature map, ordered the same way as the
            ``feature_maps`` list given to ``forward`` (index 0 = finest level).
        upscale: Scale factor applied to the final 1-channel logit map.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i - 1] fuses level i (upsampled) into level i - 1, so its input holds
        # encoder_dims[i] + encoder_dims[i - 1] channels and its output encoder_dims[i - 1].
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    encoder_dims[i] + encoder_dims[i - 1],
                    encoder_dims[i - 1],
                    3, 1, 1,
                    bias=False,
                ),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True),
            )
            for i in range(1, len(encoder_dims))
        ])
        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Decode a feature pyramid into a single-channel logit map.

        Args:
            feature_maps: Sequence of 4D tensors ``(N, C_i, H_i, W_i)`` with
                ``C_i == encoder_dims[i]``, finest level first. The sequence is not
                modified.

        Returns:
            Tensor ``(N, 1, H, W)`` where ``H, W`` are the finest level's spatial
            size scaled by ``upscale``. No activation is applied to the logits.
        """
        # Work on a copy so the caller's list is not overwritten with intermediate results.
        maps = list(feature_maps)
        for i in range(len(maps) - 1, 0, -1):
            finer = maps[i - 1]
            # Resize to the finer map's exact size instead of a fixed 2x: identical when
            # the finer level is exactly twice as large, and still shape-compatible for
            # torch.cat when an odd input size made the encoder round down.
            f_up = F.interpolate(
                maps[i], size=finer.shape[-2:], mode="bilinear", align_corners=False
            )
            maps[i - 1] = self.convs[i - 1](torch.cat([finer, f_up], dim=1))
        return self.up(self.logit(maps[0]))
class Resnet3DScrollprizeModel(PreTrainedModel):
    """Hugging Face model pairing a 3D ResNet backbone with a 2D ``Decoder``.

    The backbone's feature maps are collapsed along dim 2 with a max and then decoded
    into a single-channel logit map.
    """

    config_class = Resnet3DScrollprizeConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = generate_model(
            model_depth=config.model_depth,
            n_input_channels=1,
            forward_features=True,
            n_classes=config.n_classes,
        )
        self.decoder = Decoder(encoder_dims=self._probe_encoder_dims(), upscale=1)

    def _probe_encoder_dims(self):
        """Return the channel count (dim 1) of each backbone output, via a dummy forward pass."""
        # Build the dummy on the same device as the backbone's weights (default device
        # if the backbone has no parameters).
        param = next(self.backbone.parameters(), None)
        device = param.device if param is not None else None
        dummy = torch.rand(1, 1, 20, 256, 256, device=device)
        was_training = self.backbone.training
        # eval() + no_grad(): this shape probe must not update normalisation running
        # statistics (e.g. BatchNorm, if the backbone uses it) with random data, nor
        # keep an autograd graph of a full-size forward pass alive.
        self.backbone.eval()
        try:
            with torch.no_grad():
                return [f.size(1) for f in self.backbone(dummy)]
        finally:
            self.backbone.train(was_training)

    def forward(self, tensor):
        """Predict a single-channel logit map for a batch of volumes.

        Args:
            tensor: Input batch; presumably ``(N, 1, D, H, W)`` to match the
                single-channel backbone and the ``(1, 1, 20, 256, 256)`` probe used in
                ``__init__`` — TODO confirm against callers.

        Returns:
            The decoder's one-channel logit map (no sigmoid applied).
        """
        feat_maps = self.backbone(tensor)
        # Max over dim 2 (the depth axis under the 5D layout assumed above) turns each
        # 3D feature map into a 2D one for the 2D decoder.
        feat_maps_pooled = [f.max(dim=2).values for f in feat_maps]
        return self.decoder(feat_maps_pooled)