File size: 1,788 Bytes
fa175ec 4910447 fa175ec 7b59bd3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
from transformers import PreTrainedModel
import torch
import torch.nn as nn
from .resnetall import generate_model
from .configuration_resnet3d import Resnet3DScrollprizeConfig
import torch.nn.functional as F
class Decoder(nn.Module):
    """U-Net style decoder that fuses a pyramid of 2D feature maps into one logit map.

    Feature maps are merged top-down: the deepest map is bilinearly resized
    to the spatial size of the next-shallower map, concatenated with it along
    the channel axis (shallower map first), and fused by a 3x3 conv +
    BatchNorm + ReLU block. The fused result replaces the shallower level and
    the process repeats until level 0. Level 0 is then projected to a single
    channel by a 1x1 conv and resized by ``upscale``.

    Args:
        encoder_dims: Channel count of each input feature map, ordered from
            the shallowest / highest-resolution level (index 0) to the
            deepest level (last index).
        upscale: Scale factor applied to the logit map by the final bilinear
            upsampling.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i - 1] fuses level i (resized) with level i - 1 and emits
        # encoder_dims[i - 1] channels, so its output has the same channel
        # count as the original level i - 1 and can be fused again with
        # level i - 2 on the next step.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    encoder_dims[i] + encoder_dims[i - 1],
                    encoder_dims[i - 1],
                    3, 1, 1,
                    bias=False,
                ),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True),
            )
            for i in range(1, len(encoder_dims))
        ])
        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        # align_corners=False is PyTorch's default for bilinear; spelled out
        # to make the behaviour explicit and silence legacy warnings.
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear", align_corners=False)

    def forward(self, feature_maps):
        """Fuse ``feature_maps`` top-down and return a single-channel logit map.

        Args:
            feature_maps: Sequence of 4D tensors ``(N, C_k, H_k, W_k)`` ordered
                from shallowest (index 0) to deepest, where ``C_k`` equals
                ``encoder_dims[k]``. The sequence itself is not modified.

        Returns:
            Tensor of shape ``(N, 1, H_0', W_0')`` where ``H_0'``/``W_0'`` are
            the level-0 spatial size scaled by ``upscale``.
        """
        # Work on a shallow copy so the caller's list is left untouched.
        maps = list(feature_maps)
        for i in range(len(maps) - 1, 0, -1):
            # Resize to the exact shape of the shallower level rather than a
            # fixed x2 so pyramids with odd sizes still line up for the
            # concatenation; for exact 2x pyramids this matches scale_factor=2.
            f_up = F.interpolate(
                maps[i],
                size=maps[i - 1].shape[-2:],
                mode="bilinear",
                align_corners=False,
            )
            fused = torch.cat([maps[i - 1], f_up], dim=1)
            maps[i - 1] = self.convs[i - 1](fused)
        x = self.logit(maps[0])
        return self.up(x)
class Resnet3DScrollprizeModel(PreTrainedModel):
    """Hugging Face model pairing a 3D ResNet encoder with a 2D ``Decoder``.

    The backbone (built by ``generate_model`` with ``forward_features=True``)
    returns a sequence of feature maps. Each map is max-reduced over
    dimension 2 to obtain a 2D map, and the resulting pyramid is fused by
    ``Decoder`` into a single-channel logit map.
    """

    config_class = Resnet3DScrollprizeConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = generate_model(
            model_depth=config.model_depth,
            n_input_channels=1,
            forward_features=True,
            n_classes=config.n_classes,
        )
        self.decoder = Decoder(encoder_dims=self._probe_encoder_dims(), upscale=1)

    def _probe_encoder_dims(self):
        """Return the channel count of each backbone feature map via a dummy forward pass."""
        # Only the channel sizes are needed, so run without autograd and in
        # eval mode: this avoids building a graph over a large random volume
        # and keeps any BatchNorm running statistics from being updated with
        # noise. The previous train/eval mode is restored afterwards.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                # Probe shape kept from the original code: (batch, channel=1, 20, 256, 256).
                feats = self.backbone(torch.rand(1, 1, 20, 256, 256))
        finally:
            self.backbone.train(was_training)
        return [f.size(1) for f in feats]

    def forward(self, tensor):
        """Predict a single-channel logit mask.

        Args:
            tensor: Input volume; presumably ``(N, 1, D, H, W)`` like the
                init-time probe — TODO confirm against the data pipeline.

        Returns:
            Logit tensor ``(N, 1, h, w)`` at the spatial resolution of the
            first backbone feature map (the decoder uses ``upscale=1``).
        """
        feat_maps = self.backbone(tensor)
        # Collapse dimension 2 of each (assumed 5D) feature map by max so
        # the 2D decoder can consume it.
        feat_maps_pooled = [torch.max(f, dim=2)[0] for f in feat_maps]
        return self.decoder(feat_maps_pooled)