# Provenance (Hugging Face): astrosbd — "Update models/dtd.py", commit e1e9631 (verified)
import os
# Import compatibility fixes - handle both relative and absolute imports
try:
from . import fix_imports
except ImportError:
import fix_imports
import cv2
import lmdb
import torch
# import jpegio # Removed - using scipy instead
import numpy as np
import torch.nn as nn
import gc
import math
import time
import copy
import logging
import torch.optim as optim
import torch.distributed as dist
import random
import pickle
import six
from glob import glob
from PIL import Image
from tqdm import tqdm
from torch.autograd import Variable
from torch.cuda.amp import autocast
import segmentation_models_pytorch as smp
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler#need pytorch>1.6
# from losses import DiceLoss,FocalLoss,SoftCrossEntropyLoss,LovaszLoss # Only needed for training
# Import models - handle both relative and absolute imports
try:
from models.fph import FPH
from models.swins import *
except ImportError:
from fph import FPH
from swins import *
import albumentations as A
from albumentations.pytorch import ToTensorV2
import torchvision
import torch.nn.functional as F
try:
from timm.models.layers import trunc_normal_, DropPath
except ImportError:
from timm.layers import trunc_normal_, DropPath
from functools import partial
from segmentation_models_pytorch.base import modules as md
from typing import Optional, Union, List
from segmentation_models_pytorch.base import SegmentationModel
# Custom GELU for compatibility
class GELU(nn.Module):
    """Minimal GELU activation module, defined locally for compatibility."""

    def forward(self, x):
        # Exact (erf-based) GELU via the functional API.
        return nn.functional.gelu(x)
class LayerNorm(nn.Module):
    """LayerNorm supporting channels-last (..., C) and channels-first (N, C, H, W) inputs.

    channels_last delegates to F.layer_norm; channels_first normalizes over
    dim 1 manually with a biased variance estimate (ConvNeXt convention).
    """

    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        # Learnable affine parameters (gamma / beta).
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        # channels_first: normalize each position over the channel dimension.
        mean = x.mean(1, keepdim=True)
        var = (x - mean).pow(2).mean(1, keepdim=True)
        normed = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight[:, None, None] * normed + self.bias[:, None, None]
class SCSEModule(nn.Module):
    """Concurrent spatial and channel squeeze-and-excitation (scSE) attention."""

    def __init__(self, in_channels, reduction=16):
        super().__init__()
        squeezed = in_channels // reduction
        # Channel gate: global pool -> bottleneck 1x1 convs -> sigmoid.
        self.cSE = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(in_channels, squeezed, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(squeezed, in_channels, 1),
            nn.Sigmoid(),
        )
        # Spatial gate: 1x1 conv to a single-channel sigmoid mask.
        self.sSE = nn.Sequential(nn.Conv2d(in_channels, 1, 1), nn.Sigmoid())

    def forward(self, x):
        channel_gated = x * self.cSE(x)
        spatial_gated = x * self.sSE(x)
        return channel_gated + spatial_gated
class ConvBlock(nn.Module):
    """ConvNeXt-style residual block.

    Depthwise 7x7 conv -> LayerNorm -> pointwise MLP (Linear layers applied in
    channels-last layout) with optional per-channel layer scale and stochastic
    depth on the residual branch.
    """

    def __init__(self, dim, drop_path=0., layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(dim, dim, kernel_size=7, padding=3, groups=dim)  # depthwise
        self.norm = LayerNorm(dim, eps=1e-6)
        self.pwconv1 = nn.Linear(dim, 4 * dim)  # pointwise expansion
        self.act = GELU()
        self.pwconv2 = nn.Linear(4 * dim, dim)  # pointwise projection
        # Learnable layer scale; disabled when the init value is non-positive.
        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones(dim), requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()

    def forward(self, x):
        residual = x
        y = self.dwconv(x)
        y = y.permute(0, 2, 3, 1)  # (N, C, H, W) -> (N, H, W, C) for Linear layers
        y = self.pwconv2(self.act(self.pwconv1(self.norm(y))))
        if self.gamma is not None:
            y = self.gamma * y
        y = y.permute(0, 3, 1, 2)  # back to (N, C, H, W)
        return residual + self.drop_path(y)
class AddCoords(nn.Module):
    """CoordConv-style layer: appends normalized x/y coordinate channels (and
    optionally a radius channel) to the input, so 3-channel input becomes
    6-channel when with_r=True."""

    def __init__(self, with_r=True):
        super().__init__()
        self.with_r = with_r

    def forward(self, input_tensor):
        batch_size, _, x_dim, y_dim = input_tensor.size()
        # indexing="ij" matches the historical torch.meshgrid default and
        # silences the deprecation warning on torch >= 1.10.
        xx_c, yy_c = torch.meshgrid(
            torch.arange(x_dim, dtype=input_tensor.dtype),
            torch.arange(y_dim, dtype=input_tensor.dtype),
            indexing="ij",
        )
        # Normalize coordinates to [-1, 1].
        xx_c = xx_c.to(input_tensor.device) / (x_dim - 1) * 2 - 1
        yy_c = yy_c.to(input_tensor.device) / (y_dim - 1) * 2 - 1
        xx_c = xx_c.expand(batch_size, 1, x_dim, y_dim)
        yy_c = yy_c.expand(batch_size, 1, x_dim, y_dim)
        ret = torch.cat((input_tensor, xx_c, yy_c), dim=1)
        if self.with_r:
            # NOTE(review): the radius is measured from (0.5, 0.5) in the
            # [-1, 1] coordinate frame, i.e. NOT the image center. Kept as-is
            # to stay consistent with the pretrained checkpoints.
            rr = torch.sqrt(torch.pow(xx_c - 0.5, 2) + torch.pow(yy_c - 0.5, 2))
            ret = torch.cat([ret, rr], dim=1)
        return ret
class VPH(nn.Module):
    """Visual Perception Head: a two-stage ConvNeXt-style feature extractor.

    Takes a 6-channel input (image plus AddCoords channels) and returns two
    pyramid features at strides 4 and 8. Also owns the per-level channels-first
    LayerNorms norm0..norm3; norm2/norm3 are applied externally (by DTD) to
    the Swin-stage outputs.

    BUGFIX: the original __init__ referenced an undefined ``depths``, indexed
    ``dims[2]`` with a 2-element default, gave the second downsample layer
    channel counts inconsistent with stage 1 (stage-1 blocks are dims[1]-wide,
    so the transition must map dims[0] -> dims[1]), never stored ``self.dims``
    (read by ``initnorm``), and never called ``initnorm()`` although
    ``forward`` uses norm0/norm1 — the class could not be instantiated.
    Defaults now match the geometry implied by the rest of the file; loading
    pickled checkpoints is unaffected (unpickling bypasses __init__).
    """

    def __init__(self, dims=(96, 192, 384, 768), depths=(3, 3), drop_path_rate=0.4, layer_scale_init_value=1e-6):
        super().__init__()
        self.dims = dims  # consumed by initnorm()
        # Per-block stochastic-depth rates, linearly increasing across blocks.
        dp_rates = [r.item() for r in torch.linspace(0, drop_path_rate, sum(depths))]
        self.downsample_layers = nn.ModuleList([
            # Stem: 6ch -> dims[0], stride 4.
            nn.Sequential(
                nn.Conv2d(6, dims[0], kernel_size=4, stride=4),
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
            ),
            # Stage transition: dims[0] -> dims[1], stride 2 (norm then conv).
            nn.Sequential(
                LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
                nn.Conv2d(dims[0], dims[1], kernel_size=2, stride=2),
            ),
        ])
        self.stages = nn.ModuleList([
            nn.Sequential(*[ConvBlock(dim=dims[0], drop_path=dp_rates[j],
                                      layer_scale_init_value=layer_scale_init_value)
                            for j in range(depths[0])]),
            nn.Sequential(*[ConvBlock(dim=dims[1], drop_path=dp_rates[depths[0] + j],
                                      layer_scale_init_value=layer_scale_init_value)
                            for j in range(depths[1])]),
        ])
        self.apply(self._init_weights)
        self.initnorm()  # creates norm0..norm3 used by forward() and by DTD

    def initnorm(self):
        """Attach one channels-first LayerNorm per pyramid level (norm0..norm3)."""
        norm_layer = partial(LayerNorm, eps=1e-6, data_format="channels_first")
        for i_layer in range(4):
            layer = norm_layer(self.dims[i_layer])
            layer_name = f'norm{i_layer}'
            self.add_module(layer_name, layer)

    def _init_weights(self, m):
        # Truncated-normal init for conv/linear weights; zero biases when present.
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:  # BUGFIX: bias-free convs crashed here
                nn.init.constant_(m.bias, 0)

    def init_weights(self, pretrained=None):
        """Re-initialize Linear and LayerNorm modules (``pretrained`` is unused)."""
        def _init_weights(m):
            if isinstance(m, nn.Linear):
                trunc_normal_(m.weight, std=.02)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.LayerNorm):
                nn.init.constant_(m.bias, 0)
                nn.init.constant_(m.weight, 1.0)
        self.apply(_init_weights)

    def forward(self, x):
        """Return [norm0(stage0 out), norm1(stage1 out)] — strides 4 and 8."""
        x = self.stages[0](self.downsample_layers[0](x))
        outs = [self.norm0(x)]
        x = self.stages[1](self.downsample_layers[1](x))
        outs.append(self.norm1(x))
        return outs
class SegmentationHead(nn.Sequential):
    """Prediction head: 3x3 conv -> optional bilinear upsample -> optional activation."""

    def __init__(self, in_channels, out_channels, kernel_size=3, activation=None, upsampling=1):
        conv2d = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, padding=kernel_size // 2)
        # Identity when no upsampling is requested (upsampling <= 1).
        upsample = nn.UpsamplingBilinear2d(scale_factor=upsampling) if upsampling > 1 else nn.Identity()
        super().__init__(conv2d, upsample, md.Activation(activation))
class DecoderBlock(nn.Module):
    """Nearest-neighbour upsample by 2, optional skip concat, then two 3x3 conv-BN-ReLU layers."""

    def __init__(self, cin, cadd, cout):
        super().__init__()
        self.cin = cin + cadd
        self.cout = cout

        def _conv_bn_relu(c_in, c_out):
            # 3x3 same-padding conv without bias (BN supplies the shift).
            return nn.Sequential(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
            )

        self.conv1 = _conv_bn_relu(self.cin, self.cout)
        self.conv2 = _conv_bn_relu(self.cout, self.cout)

    def forward(self, x1, x2=None):
        up = F.interpolate(x1, scale_factor=2.0, mode="nearest")
        if x2 is not None:
            up = torch.cat([up, x2], dim=1)
        # Truncate to the expected channel count before the first conv.
        return self.conv2(self.conv1(up[:, :self.cin]))
class ConvBNReLU(nn.Module):
    """Same-padding Conv2d, optionally followed by BN+ReLU, with optional residual add."""

    def __init__(self, in_c, out_c, ks, stride=1, norm=True, res=False):
        super(ConvBNReLU, self).__init__()
        conv = nn.Conv2d(in_c, out_c, kernel_size=ks, padding=ks // 2, stride=stride, bias=False)
        # With norm=False the bare conv is exposed directly.
        self.conv = nn.Sequential(conv, nn.BatchNorm2d(out_c), nn.ReLU(True)) if norm else conv
        self.res = res

    def forward(self, x):
        out = self.conv(x)
        # Residual add requires in_c == out_c and stride 1 (caller's responsibility).
        return x + out if self.res else out
class FUSE1(nn.Module):
    """Top-down fusion across four pyramid levels (nearest-neighbour resize).

    Each finer level is refined by a 3x3 conv over the sum of its own 1x1
    projection and the upsampled 1x1 projection of the coarser level.
    """

    def __init__(self, in_channels_list=(96, 192, 384, 768)):
        super(FUSE1, self).__init__()
        c0, c1, c2, c3 = in_channels_list
        # Level 2 <- level 3
        self.c31 = ConvBNReLU(c2, c2, 1)
        self.c32 = ConvBNReLU(c3, c2, 1)
        self.c33 = ConvBNReLU(c2, c2, 3)
        # Level 1 <- level 2
        self.c21 = ConvBNReLU(c1, c1, 1)
        self.c22 = ConvBNReLU(c2, c1, 1)
        self.c23 = ConvBNReLU(c1, c1, 3)
        # Level 0 <- level 1
        self.c11 = ConvBNReLU(c0, c0, 1)
        self.c12 = ConvBNReLU(c1, c0, 1)
        self.c13 = ConvBNReLU(c0, c0, 3)

    def forward(self, x):
        f0, f1, f2, f3 = x
        f2 = self.c33(F.interpolate(self.c32(f3), size=f2.shape[-2:]) + self.c31(f2))
        f1 = self.c23(F.interpolate(self.c22(f2), size=f1.shape[-2:]) + self.c21(f1))
        f0 = self.c13(F.interpolate(self.c12(f1), size=f0.shape[-2:]) + self.c11(f0))
        return f0, f1, f2, f3
class FUSE2(nn.Module):
    """Top-down fusion across three pyramid levels (bilinear, align_corners=True)."""

    def __init__(self, in_channels_list=(96, 192, 384)):
        super(FUSE2, self).__init__()
        c0, c1, c2 = in_channels_list
        # Level 1 <- level 2
        self.c21 = ConvBNReLU(c1, c1, 1)
        self.c22 = ConvBNReLU(c2, c1, 1)
        self.c23 = ConvBNReLU(c1, c1, 3)
        # Level 0 <- level 1
        self.c11 = ConvBNReLU(c0, c0, 1)
        self.c12 = ConvBNReLU(c1, c0, 1)
        self.c13 = ConvBNReLU(c0, c0, 3)

    def forward(self, x):
        f0, f1, f2 = x
        up = F.interpolate(self.c22(f2), size=f1.shape[-2:], mode='bilinear', align_corners=True)
        f1 = self.c23(up + self.c21(f1))
        up = F.interpolate(self.c12(f1), size=f0.shape[-2:], mode='bilinear', align_corners=True)
        f0 = self.c13(up + self.c11(f0))
        return f0, f1, f2
class FUSE3(nn.Module):
    """Fuses the finer of two pyramid levels with its coarser neighbour
    (bilinear, align_corners=True)."""

    def __init__(self, in_channels_list=(96, 192)):
        super(FUSE3, self).__init__()
        c0, c1 = in_channels_list
        self.c11 = ConvBNReLU(c0, c0, 1)
        self.c12 = ConvBNReLU(c1, c0, 1)
        self.c13 = ConvBNReLU(c0, c0, 3)

    def forward(self, x):
        f0, f1 = x
        up = F.interpolate(self.c12(f1), size=f0.shape[-2:], mode='bilinear', align_corners=True)
        f0 = self.c13(up + self.c11(f0))
        return f0, f1
class MID(nn.Module):
    """Dense (UNet++-style) decoder with inter-step multi-scale fusion.

    Decoder nodes are named ``x_{depth}_{layer}``; FUSE1 fuses the encoder
    pyramid before decoding and FUSE2/FUSE3 fuse intermediate decoder outputs
    between decoding steps.
    """
    def __init__(self, encoder_channels, decoder_channels):
        super().__init__()
        # Drop the first encoder level and reverse to deepest-first order.
        encoder_channels = encoder_channels[1:][::-1]
        # Input channels of the first node per step; later steps consume the
        # previous step's decoder output.
        self.in_channels = [encoder_channels[0]] + list(decoder_channels[:-1])
        # Skip channels per step; the trailing 96 matches the shallowest
        # encoder width — presumably intentional, confirm against encoder.
        self.add_channels = list(encoder_channels[1:]) + [96]
        self.out_channels = decoder_channels
        self.fuse1 = FUSE1()
        self.fuse2 = FUSE2()
        self.fuse3 = FUSE3()
        decoder_convs = {}
        # Dense grid of decoder blocks: node (depth_idx, layer_idx) receives
        # (layer_idx + 1 - depth_idx) stacked skip tensors, hence the
        # multiplied skip-channel counts.
        for layer_idx in range(len(self.in_channels) - 1):
            for depth_idx in range(layer_idx + 1):
                if depth_idx == 0:
                    in_ch = self.in_channels[layer_idx]
                    skip_ch = self.add_channels[layer_idx] * (layer_idx + 1)
                    out_ch = self.out_channels[layer_idx]
                else:
                    out_ch = self.add_channels[layer_idx]
                    skip_ch = self.add_channels[layer_idx] * (layer_idx + 1 - depth_idx)
                    in_ch = self.add_channels[layer_idx - 1]
                decoder_convs[f"x_{depth_idx}_{layer_idx}"] = DecoderBlock(in_ch, skip_ch, out_ch)
        # Final block of the last decoding step takes no skip input.
        decoder_convs[f"x_{0}_{len(self.in_channels)-1}"] = DecoderBlock(self.in_channels[-1], 0, self.out_channels[-1])
        self.decoder_convs = nn.ModuleDict(decoder_convs)
    def forward(self, *features):
        decoder_features = {}
        # Fuse the encoder pyramid top-down, then reverse to deepest-first.
        features = self.fuse1(features)[::-1]
        # Step 1: each node upsamples a deeper feature and concatenates the
        # next-shallower fused encoder feature.
        decoder_features["x_0_0"] = self.decoder_convs["x_0_0"](features[0],features[1])
        decoder_features["x_1_1"] = self.decoder_convs["x_1_1"](features[1],features[2])
        decoder_features["x_2_2"] = self.decoder_convs["x_2_2"](features[2],features[3])
        # Fuse the three step-1 outputs (x_2_2 is the finest here).
        decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"] = self.fuse2((decoder_features["x_2_2"], decoder_features["x_1_1"], decoder_features["x_0_0"]))
        # Step 2: dense skips — concat earlier decoder outputs with encoder features.
        decoder_features["x_0_1"] = self.decoder_convs["x_0_1"](decoder_features["x_0_0"], torch.cat((decoder_features["x_1_1"], features[2]),1))
        decoder_features["x_1_2"] = self.decoder_convs["x_1_2"](decoder_features["x_1_1"], torch.cat((decoder_features["x_2_2"], features[3]),1))
        decoder_features["x_1_2"], decoder_features["x_0_1"] = self.fuse3((decoder_features["x_1_2"], decoder_features["x_0_1"]))
        # Step 3 and final head node (x_0_3, no skip input).
        decoder_features["x_0_2"] = self.decoder_convs["x_0_2"](decoder_features["x_0_1"], torch.cat((decoder_features["x_1_2"], decoder_features["x_2_2"], features[3]),1))
        return self.decoder_convs["x_0_3"](torch.cat((decoder_features["x_0_2"], decoder_features["x_1_2"], decoder_features["x_2_2"]),1))
class DTD(SegmentationModel):
    """Document Tampering Detector.

    Combines a ConvNeXt-style visual branch (VPH, loaded from a pickled
    checkpoint), a frequency branch (FPH, consuming DCT coefficients and a
    quantization table), Swin stages (also loaded pickled), and a dense MID
    decoder with a segmentation head.
    """

    def __init__(self, encoder_name="resnet18", decoder_channels=(384, 192, 96, 64), classes=1, device='cpu'):
        super().__init__()
        import os
        import sys
        # Alias the package-qualified modules under their bare names so that
        # checkpoints pickled against `dtd`/`fph`/`swins` can be deserialized.
        # BUGFIX: guarded — the original unconditionally indexed
        # sys.modules['models.dtd'], which raises KeyError when this file is
        # imported as a top-level module (the absolute-import fallback above).
        for bare, qualified in (('dtd', 'models.dtd'), ('fph', 'models.fph'), ('swins', 'models.swins')):
            if bare not in sys.modules and qualified in sys.modules:
                sys.modules[bare] = sys.modules[qualified]
        model_dir = os.path.dirname(os.path.abspath(__file__))
        vph_path = os.path.join(model_dir, '..', 'checkpoints', 'vph_imagenet.pt')
        swin_path = os.path.join(model_dir, '..', 'checkpoints', 'swin_imagenet.pt')
        # MPS cannot be a deserialization target here; stage tensors on CPU.
        map_location = torch.device('cpu') if device == 'mps' else device
        # These checkpoints are full pickled modules, so weights_only=False is
        # required (torch >= 2.6 defaults weights_only to True, which would
        # fail here). SECURITY: pickle can execute arbitrary code — only load
        # trusted checkpoint files.
        self.vph = torch.load(vph_path, map_location=map_location, weights_only=False)
        self.swin = torch.load(swin_path, map_location=map_location, weights_only=False)
        self.fph = FPH()
        self.decoder = MID(encoder_channels=(96, 192, 384, 768), decoder_channels=decoder_channels)
        self.segmentation_head = SegmentationHead(in_channels=decoder_channels[-1], out_channels=classes, upsampling=2.0)
        self.addcoords = AddCoords()
        # Fuse the visual feature with the frequency feature: 448 input
        # channels (192 visual + presumably 256 from FPH — confirm in fph.py),
        # projected back to 192.
        self.FU = nn.Sequential(SCSEModule(448), nn.Conv2d(448, 192, 3, 1, 1), nn.BatchNorm2d(192), nn.ReLU(True))
        self.classification_head = None
        self.initialize()

    def forward(self, x, dct, qt):
        """Run detection.

        Args:
            x: image batch; must be 3-channel (AddCoords appends 3 channels to
               match VPH's 6-channel stem).
            dct: DCT-coefficient input for the frequency branch.
            qt: quantization-table input for the frequency branch.
        Returns:
            Segmentation logits from the segmentation head.
        """
        # Visual branch: coordinate channels then the ConvNeXt-style stem.
        features = self.vph(self.addcoords(x))
        # Inject frequency-domain features into the stride-8 visual level.
        features[1] = self.FU(torch.cat((features[1], self.fph(dct, qt)), 1))
        # Swin stages consume token sequences: (N, C, H, W) -> (N, H*W, C).
        rst = self.swin[0](features[1].flatten(2).transpose(1, 2).contiguous())
        N, L, C = rst.shape
        H = W = int(L ** (1 / 2))  # assumes a square token grid
        features.append(self.vph.norm2(rst.transpose(1, 2).contiguous().view(N, C, H, W)))
        features.append(self.vph.norm3(self.swin[2](self.swin[1](rst)).transpose(1, 2).contiguous().view(N, C * 2, H // 2, W // 2)))
        decoder_output = self.decoder(*features)
        return self.segmentation_head(decoder_output)
class seg_dtd(nn.Module):
    """Thin wrapper around DTD that enables autocast only on CUDA."""

    def __init__(self, model_name='resnet18', n_class=1, device='cpu'):
        super().__init__()
        self.model = DTD(encoder_name=model_name, classes=n_class, device=device)
        self.device = device

    def forward(self, x, dct, qt):
        # Mixed precision only on CUDA; MPS/CPU run in full precision.
        if self.device != 'cuda':
            return self.model(x, dct, qt)
        with autocast():
            return self.model(x, dct, qt)