import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel

from .configuration_resnet3d import Resnet3DScrollprizeConfig
from .resnetall import generate_model


class Decoder(nn.Module):
    """2D decoder that fuses a pyramid of feature maps into a single-channel logit map.

    Starting from the deepest feature map, each level is upsampled 2x, concatenated
    with the next-shallower map along the channel axis (dim 1), and fused by a
    3x3 conv + BatchNorm + ReLU. The shallowest fused map is projected to one
    channel by a 1x1 conv and finally resized by ``upscale``.

    Args:
        encoder_dims: Channel count of each feature map, shallowest level first.
        upscale: Scale factor applied (bilinearly) to the final logit map.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i - 1] fuses level i (after 2x upsampling) into level i - 1 and
        # reduces the concatenated channels back to encoder_dims[i - 1].
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    encoder_dims[i] + encoder_dims[i - 1],
                    encoder_dims[i - 1],
                    3, 1, 1,
                    bias=False,
                ),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True),
            )
            for i in range(1, len(encoder_dims))
        ])
        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Decode a feature pyramid into a single-channel logit map.

        Args:
            feature_maps: Sequence of 4D tensors (required by ``nn.Conv2d``),
                shallowest first. Each map's spatial size must be exactly twice
                that of the next deeper map, otherwise the channel concatenation
                fails. The caller's sequence is not modified.

        Returns:
            Logits with one channel, at the shallowest map's resolution scaled
            by ``upscale``. No activation is applied.
        """
        # Work on a local copy: the loop below replaces entries with fused maps,
        # which must not leak back into the caller's list (and tuples must work).
        feature_maps = list(feature_maps)
        for i in range(len(feature_maps) - 1, 0, -1):
            f_up = F.interpolate(feature_maps[i], scale_factor=2, mode="bilinear")
            f = torch.cat([feature_maps[i - 1], f_up], dim=1)
            feature_maps[i - 1] = self.convs[i - 1](f)
        x = self.logit(feature_maps[0])
        return self.up(x)


class Resnet3DScrollprizeModel(PreTrainedModel):
    """3D ResNet backbone + 2D decoder producing a single-channel logit map.

    Backbone feature maps are max-reduced over axis 2 and then decoded in 2D.
    """

    config_class = Resnet3DScrollprizeConfig

    # Shape of the dummy single-channel volume pushed through the backbone once at
    # construction time, only to discover the channel count of each feature map.
    _PROBE_SHAPE = (1, 1, 20, 256, 256)

    def __init__(self, config):
        super().__init__(config)
        self.backbone = generate_model(
            model_depth=config.model_depth,
            n_input_channels=1,
            forward_features=True,
            n_classes=config.n_classes,
        )
        self.decoder = Decoder(encoder_dims=self._probe_encoder_dims(), upscale=1)

    def _probe_encoder_dims(self):
        """Return the channel count (dim 1) of each backbone feature map.

        Runs one dummy forward pass with autograd disabled and the backbone in
        eval mode. The probe therefore builds no graph and does not let
        normalization layers (if the backbone has any) update running statistics
        from random data. The backbone's previous train/eval mode is restored
        afterwards, even if the probe raises.
        """
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                feats = self.backbone(torch.rand(*self._PROBE_SHAPE))
        finally:
            self.backbone.train(was_training)
        # NOTE(review): assumes the backbone (forward_features=True) returns an
        # iterable of feature tensors -- confirm against resnetall.generate_model.
        return [f.size(1) for f in feats]

    def forward(self, tensor):
        """Predict a single-channel logit map for an input volume.

        Args:
            tensor: Input volume. Presumably (N, 1, D, H, W), matching the
                single-channel backbone and ``_PROBE_SHAPE`` -- TODO confirm.

        Returns:
            The decoder's logits (no sigmoid/activation is applied here).
        """
        feat_maps = self.backbone(tensor)
        # Max-reduce axis 2 of every feature map so the 2D decoder can consume it.
        feat_maps_pooled = [torch.max(f, dim=2)[0] for f in feat_maps]
        return self.decoder(feat_maps_pooled)