# resnet-test4 / modeling_resnet3d.py
# YoussefMoNader's picture
# Upload model
# 7b59bd3 verified
from transformers import PreTrainedModel
import torch
import torch.nn as nn
from .resnetall import generate_model
from .configuration_resnet3d import Resnet3DScrollprizeConfig
import torch.nn.functional as F
class Decoder(nn.Module):
    """Top-down 2D decoder that fuses a feature pyramid into a 1-channel mask.

    Starting from the last (deepest) feature map, each step bilinearly
    upsamples the running result to the spatial size of the next shallower
    map, concatenates the two along the channel axis, and passes them through
    a 3x3 conv + BatchNorm + ReLU block. A 1x1 conv then produces one logit
    channel, which is finally resized by ``upscale``.

    Args:
        encoder_dims: Channel count of each feature map, ordered from the
            first (index 0) to the last pyramid level.
        upscale: Scale factor applied to the logit map by the final
            bilinear ``nn.Upsample``.
    """

    def __init__(self, encoder_dims, upscale):
        super().__init__()
        # convs[i - 1] fuses level i (after upsampling) with level i - 1 and
        # emits encoder_dims[i - 1] channels, so the result can be fused with
        # the next shallower level in turn.
        self.convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    encoder_dims[i] + encoder_dims[i - 1],
                    encoder_dims[i - 1],
                    3, 1, 1,
                    bias=False,
                ),
                nn.BatchNorm2d(encoder_dims[i - 1]),
                nn.ReLU(inplace=True),
            )
            for i in range(1, len(encoder_dims))
        ])
        self.logit = nn.Conv2d(encoder_dims[0], 1, 1, 1, 0)
        self.up = nn.Upsample(scale_factor=upscale, mode="bilinear")

    def forward(self, feature_maps):
        """Decode a sequence of 2D feature maps into a mask of logits.

        Args:
            feature_maps: Sequence (list or tuple) of 4D tensors whose channel
                counts match ``encoder_dims``, ordered the same way. The
                sequence is only read, never modified.

        Returns:
            Tensor with one channel: the logit map of the first pyramid level
            resized by the ``upscale`` factor.
        """
        # Keep the running result in a local instead of writing it back into
        # the caller's container: the input stays intact and tuples work too.
        fused = feature_maps[-1]
        for level in range(len(feature_maps) - 1, 0, -1):
            skip = feature_maps[level - 1]
            # Resize to the skip map's own spatial size rather than a fixed
            # 2x factor. For exact 2x pyramids this is the same operation;
            # for odd-sized maps it keeps the concatenation shape-compatible.
            upsampled = F.interpolate(fused, size=skip.shape[-2:], mode="bilinear")
            fused = self.convs[level - 1](torch.cat([skip, upsampled], dim=1))
        logits = self.logit(fused)
        return self.up(logits)
class Resnet3DScrollprizeModel(PreTrainedModel):
    """3D backbone plus 2D decoder that maps a volume to a 1-channel mask.

    The backbone (built by ``generate_model`` with ``forward_features=True``)
    is called on the input and its outputs are treated as a list of feature
    maps. Each map is max-reduced over dimension 2 and the resulting 2D maps
    are fused by ``Decoder`` into a single-channel logit map.

    Args:
        config: ``Resnet3DScrollprizeConfig`` providing ``model_depth`` and
            ``n_classes`` for the backbone.
    """

    config_class = Resnet3DScrollprizeConfig

    def __init__(self, config):
        super().__init__(config)
        self.backbone = generate_model(
            model_depth=config.model_depth,
            n_input_channels=1,
            forward_features=True,
            n_classes=config.n_classes,
        )
        self.decoder = Decoder(encoder_dims=self._probe_encoder_dims(), upscale=1)

    def _probe_encoder_dims(self):
        """Return the channel count of every backbone feature map.

        The counts are found by pushing one random single-channel volume
        through the backbone. The probe runs without autograd and in eval
        mode so that it neither builds a graph nor lets normalisation layers
        absorb statistics from random data; the previous train/eval mode is
        restored afterwards.
        """
        # NOTE(review): assumes generate_model returns an nn.Module (it is
        # registered on this module and called like one) - confirm.
        was_training = self.backbone.training
        self.backbone.eval()
        try:
            with torch.no_grad():
                probe = torch.rand(1, 1, 20, 256, 256)
                features = self.backbone(probe)
        finally:
            self.backbone.train(was_training)
        return [feature.size(1) for feature in features]

    def forward(self, tensor):
        """Predict a mask for a batch of volumes.

        Args:
            tensor: Input passed straight to the backbone. Presumably shaped
                (batch, 1, depth, height, width), matching the probe volume
                used at construction time - confirm against callers.

        Returns:
            The decoder output: a tensor with a single channel of logits.
        """
        feat_maps = self.backbone(tensor)
        # Collapse dimension 2 (the depth axis in the probe layout) with a
        # max so the decoder receives 2D feature maps.
        feat_maps_pooled = [torch.max(f, dim=2)[0] for f in feat_maps]
        return self.decoder(feat_maps_pooled)