# AST_DeRainDrop / modeling_ast.py
# Author: 孙聪聪
# Initial upload of AST deraindrop model
# Commit: be36716
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
import numpy as np
import time
from torch import einsum
import json
import os
import argparse
from transformers import PretrainedConfig, PreTrainedModel
#################################################################################
# #
# PART 1: 您的模型定义 (From the file you provided) #
# #
#################################################################################
def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    """Build a 2-D convolution with 'same'-style padding (keeps H/W when stride=1)."""
    pad = kernel_size // 2
    return nn.Conv2d(in_channels, out_channels, kernel_size,
                     stride=stride, padding=pad, bias=bias)
class ConvBlock(nn.Module):
    """Residual double-conv block: two 3x3 convs with LeakyReLU plus a 1x1 shortcut.

    Bug fix: the original applied `strides` to BOTH 3x3 convs while the 1x1
    shortcut applied it only once, so any `strides != 1` crashed with a shape
    mismatch at the residual add. The second 3x3 conv now uses stride 1, so the
    main path downsamples exactly once, matching the shortcut. Behavior for
    `strides == 1` (the only configuration that previously worked) is unchanged,
    and parameter shapes are identical, so existing checkpoints still load.
    """

    def __init__(self, in_channel, out_channel, strides=1):
        super(ConvBlock, self).__init__()
        self.strides = strides
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
            # stride=1 here (fix): downsample once on the main path so the
            # residual add matches the 1x1 shortcut below
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)

    def forward(self, x):
        """x: (B, in_channel, H, W) -> (B, out_channel, H/strides, W/strides)."""
        out1 = self.block(x)
        out2 = self.conv11(x)
        out = out1 + out2
        return out
class LinearProjection(nn.Module):
    """Project tokens into per-head q, k, v tensors via two linear layers."""

    def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias=bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.dim = dim
        self.inner_dim = inner_dim

    def forward(self, x, attn_kv=None):
        """x: (B, N, C). Returns q, k, v each shaped (B, heads, N_tokens, C // heads)."""
        batch, n_q, channels = x.shape
        head_dim = channels // self.heads
        # Cross-attention: broadcast the shared k/v tokens over the batch;
        # plain self-attention otherwise.
        kv_src = x if attn_kv is None else attn_kv.unsqueeze(0).repeat(batch, 1, 1)
        n_kv = kv_src.size(1)
        q = (self.to_q(x)
             .reshape(batch, n_q, 1, self.heads, head_dim)
             .permute(2, 0, 3, 1, 4))[0]
        kv = (self.to_kv(kv_src)
              .reshape(batch, n_kv, 2, self.heads, head_dim)
              .permute(2, 0, 3, 1, 4))
        return q, kv[0], kv[1]
class WindowAttention(nn.Module):
    """Window-based multi-head self-attention with a learned relative position bias.

    Operates on tokens of one (win_h, win_w) window at a time; an optional
    additive mask implements shifted-window attention (Swin-style).
    """

    def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.win_size = win_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 1/sqrt(head_dim) unless overridden
        # One learnable bias per (relative offset, head) pair.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))
        # Precompute, for every ordered token pair inside a window, the index
        # into the bias table (standard Swin relative-position indexing).
        coords_h = torch.arange(self.win_size[0])
        coords_w = torch.arange(self.win_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.win_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1  # row-major flattening
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        if token_projection == 'linear':
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            raise Exception("Projection error!")
        self.token_projection = token_projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, attn_kv=None, mask=None):
        """x: (num_windows*B, N, C); mask: (num_windows, N, N) additive mask or None."""
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # ratio > 1 when k/v carry more tokens than one window (cross-attention
        # via attn_kv); the bias and mask are tiled to match the key length.
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class WindowAttention_sparse(nn.Module):
    """Window attention with an adaptive mix of softmax and sparse (ReLU^2) scores.

    Identical to WindowAttention except the attention map is a learned convex
    combination of softmax(attn) and relu(attn)^2, weighted by the 2-element
    parameter `w` (normalized via softmax over its exponentials).
    """

    def __init__(self, dim, win_size, num_heads, token_projection='linear', qkv_bias=True, qk_scale=None, attn_drop=0.,
                 proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.win_size = win_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5  # 1/sqrt(head_dim) unless overridden
        # One learnable bias per (relative offset, head) pair.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))
        # Precompute the relative-position index table (same as WindowAttention).
        coords_h = torch.arange(self.win_size[0])
        coords_w = torch.arange(self.win_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.win_size[0] - 1  # shift offsets to start at 0
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1  # row-major flattening
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)
        if token_projection == 'linear':
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            raise Exception("Projection error!")
        self.token_projection = token_projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        # Learned mixing logits for the softmax / ReLU^2 branches.
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, x, attn_kv=None, mask=None):
        """x: (num_windows*B, N, C); mask: (num_windows, N, N) additive mask or None."""
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))
        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # ratio > 1 when k/v carry more tokens than one window; tile bias/mask to match.
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
        attn = attn + relative_position_bias.unsqueeze(0)
        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2  # sparse branch: zero for non-positive scores
        else:
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2
        # Softmax-normalized mixing weights over the two branches.
        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        attn = attn0 * w1 + attn1 * w2
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class Mlp(nn.Module):
    """Standard transformer feed-forward: linear -> activation -> linear,
    with dropout applied after each stage."""

    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
class LeFF(nn.Module):
    """Locally-enhanced feed-forward: linear expansion, 3x3 depthwise conv on the
    2-D token grid, then linear projection back. Assumes a square token grid."""

    def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer())
        self.dwconv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1), act_layer())
        self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
        self.eca = nn.Identity()

    def forward(self, x):
        batch, n_tokens, _ = x.size()
        side = int(math.sqrt(n_tokens))
        x = self.linear1(x)
        # tokens (b, hw, c) -> grid (b, c, h, w) for the depthwise conv
        x = x.transpose(1, 2).reshape(batch, -1, side, side)
        x = self.dwconv(x)
        # grid back to tokens
        x = x.flatten(2).transpose(1, 2)
        x = self.linear2(x)
        return self.eca(x)
class FRFN(nn.Module):
    """Feature-refinement feed-forward network.

    A 3x3 conv refines the first quarter of the channels on the 2-D token grid
    (partial conv), then a gated feed-forward combines a depthwise-conv branch
    with a linear gate. Assumes a square token grid.
    """

    def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2),
                                     act_layer())
        self.dwconv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
            act_layer())
        self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dim_conv = self.dim // 4
        self.dim_untouched = self.dim - self.dim_conv
        self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)

    def forward(self, x):
        batch, n_tokens, channels = x.size()
        side = int(math.sqrt(n_tokens))
        # tokens -> grid, refine a quarter of the channels, grid -> tokens
        grid = x.transpose(1, 2).reshape(batch, channels, side, side)
        refined, untouched = torch.split(grid, [self.dim_conv, self.dim_untouched], dim=1)
        grid = torch.cat((self.partial_conv3(refined), untouched), 1)
        tokens = grid.flatten(2).transpose(1, 2)
        # gated feed-forward: depthwise-conv branch * linear gate branch
        conv_in, gate = self.linear1(tokens).chunk(2, dim=-1)
        conv_out = conv_in.transpose(1, 2).reshape(batch, self.hidden_dim, side, side)
        conv_out = self.dwconv(conv_out).flatten(2).transpose(1, 2)
        return self.linear2(conv_out * gate)
def window_partition(x, win_size, dilation_rate=1):
    """Split a (B, H, W, C) map into (num_windows*B, win_size, win_size, C) tiles.

    With dilation_rate > 1 the windows are extracted via F.unfold with dilated
    sampling (the padding term assumes win_size == 8).
    """
    B, H, W, C = x.shape
    if dilation_rate == 1:
        tiles = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
        return tiles.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
    assert type(dilation_rate) is int, 'dilation_rate should be a int'
    cols = F.unfold(x.permute(0, 3, 1, 2), kernel_size=win_size, dilation=dilation_rate,
                    padding=4 * (dilation_rate - 1), stride=win_size)
    windows = cols.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size)
    return windows.permute(0, 2, 3, 1).contiguous()
def window_reverse(windows, win_size, H, W, dilation_rate=1):
    """Inverse of window_partition: merge (num_windows*B, win, win, C) tiles
    back into a full map.

    Bug fix: the dilated branch permuted `windows` (a 4-D tensor) with six
    indices and then passed a non-3-D tensor to F.fold, so it unconditionally
    raised. It now permutes the 6-D view `x` and flattens it to the
    (B, C*win*win, L) layout F.fold requires.

    NOTE(review): the dilated branch returns channels-first (B, C, H, W) via
    F.fold while the plain branch returns (B, H, W, C) — this asymmetry
    matches the original code's structure; confirm against callers if the
    dilated path is ever exercised.
    """
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    if dilation_rate != 1:
        # (B, C, win, win, H//win, W//win) -> (B, C*win*win, L) for F.fold
        x = x.permute(0, 5, 3, 4, 1, 2).contiguous()
        x = x.view(B, -1, (H // win_size) * (W // win_size))
        x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate, padding=4 * (dilation_rate - 1),
                   stride=win_size)
    else:
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
class Downsample(nn.Module):
    """Halve a square token grid with a stride-2 4x4 convolution.

    Input/output are token sequences: (B, L, C_in) -> (B, L/4, C_out).
    """

    def __init__(self, in_channel, out_channel):
        super(Downsample, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1))

    def forward(self, x):
        batch, n_tokens, channels = x.shape
        side = int(math.sqrt(n_tokens))
        grid = x.transpose(1, 2).contiguous().view(batch, channels, side, side)
        return self.conv(grid).flatten(2).transpose(1, 2).contiguous()
class Upsample(nn.Module):
    """Double a square token grid with a stride-2 transposed convolution.

    Input/output are token sequences: (B, L, C_in) -> (B, 4*L, C_out).
    """

    def __init__(self, in_channel, out_channel):
        super(Upsample, self).__init__()
        self.deconv = nn.Sequential(nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2))

    def forward(self, x):
        batch, n_tokens, channels = x.shape
        side = int(math.sqrt(n_tokens))
        grid = x.transpose(1, 2).contiguous().view(batch, channels, side, side)
        return self.deconv(grid).flatten(2).transpose(1, 2).contiguous()
class InputProj(nn.Module):
    """Embed an image (B, C, H, W) into a token sequence (B, H*W, out_channel).

    NOTE(review): the conv kernel is fixed at 3 while padding follows
    `kernel_size` — mirrors the original implementation.
    """

    def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1, norm_layer=None, act_layer=nn.LeakyReLU):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
            act_layer(inplace=True))
        self.norm = norm_layer(out_channel) if norm_layer is not None else None

    def forward(self, x):
        tokens = self.proj(x).flatten(2).transpose(1, 2).contiguous()
        return tokens if self.norm is None else self.norm(tokens)
class OutputProj(nn.Module):
    """Project a token sequence (B, L, C) back to an image (B, out_channel, H, W).

    Assumes a square token grid (H = W = sqrt(L)).
    """

    def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1, norm_layer=None, act_layer=None):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2))
        if act_layer is not None:
            self.proj.add_module(str(len(self.proj)), act_layer(inplace=True))
        self.norm = norm_layer(out_channel) if norm_layer is not None else None

    def forward(self, x):
        batch, n_tokens, channels = x.shape
        side = int(math.sqrt(n_tokens))
        img = self.proj(x.transpose(1, 2).view(batch, channels, side, side))
        return img if self.norm is None else self.norm(img)
class TransformerBlock(nn.Module):
    """One (optionally shifted) window-attention transformer block.

    Layout: [window attention (if att)] with residual, then FFN (mlp/leff/frfn)
    with residual. `shift_size > 0` enables Swin-style shifted windows with the
    corresponding additive attention mask. Assumes a square token grid.
    """

    def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 token_projection='linear', token_mlp='leff', att=True, sparseAtt=False):
        super().__init__()
        self.att = att
        self.sparseAtt = sparseAtt
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.win_size = win_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        # If the feature map is no larger than one window, windowing is trivial:
        # disable shifting and shrink the window to the full resolution.
        if min(self.input_resolution) <= self.win_size:
            self.shift_size = 0
            self.win_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.win_size, "shift_size must in 0-win_size"
        if self.att:
            self.norm1 = norm_layer(dim)
            if self.sparseAtt:
                self.attn = WindowAttention_sparse(dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
                                                   qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                                                   proj_drop=drop, token_projection=token_projection)
            else:
                self.attn = WindowAttention(dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop,
                                            token_projection=token_projection)
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if token_mlp in ['ffn', 'mlp']:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        elif token_mlp == 'leff':
            self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        elif token_mlp == 'frfn':
            self.mlp = FRFN(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            raise Exception("FFN error!")

    def forward(self, x, mask=None):
        """x: (B, L, C) with L a perfect square; `mask` is accepted but unused
        beyond this signature (the attention mask is built internally)."""
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        attn_mask = None
        if self.shift_size > 0:
            # Build the standard Swin shift mask: label the 3x3 grid of shifted
            # regions, then forbid attention across differently-labeled cells.
            shift_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.win_size), slice(-self.win_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.win_size), slice(-self.win_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    shift_mask[:, h, w, :] = cnt
                    cnt += 1
            shift_mask_windows = window_partition(shift_mask, self.win_size)
            shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size)
            attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2)
            # -100 acts as -inf before softmax: cross-region pairs get ~zero weight.
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
        shortcut = x
        if self.att:
            x = self.norm1(x)
            x = x.view(B, H, W, C)
            # Cyclic shift, window, attend, un-window, reverse shift.
            if self.shift_size > 0:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:
                shifted_x = x
            x_windows = window_partition(shifted_x, self.win_size)
            x_windows = x_windows.view(-1, self.win_size * self.win_size, C)
            attn_windows = self.attn(x_windows, mask=attn_mask)
            attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
            shifted_x = window_reverse(attn_windows, self.win_size, H, W)
            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
            x = x.view(B, H * W, C)
            x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
class BasicASTLayer(nn.Module):
    """A stack of `depth` TransformerBlocks at a single resolution.

    With `shift_flag`, odd-indexed blocks use a half-window shift (Swin
    style); otherwise every block is unshifted.
    """

    def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size, mlp_ratio=4., qkv_bias=True,
                 qk_scale=None, drop=0., attn_drop=0., drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
                 token_projection='linear', token_mlp='ffn', shift_flag=True, att=False, sparseAtt=False):
        super().__init__()
        self.att = att
        self.sparseAtt = sparseAtt
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        def _shift(i):
            # Half-window shift on odd blocks, only when shifting is enabled.
            return win_size // 2 if (shift_flag and i % 2 == 1) else 0

        self.blocks = nn.ModuleList([
            TransformerBlock(dim=dim, input_resolution=input_resolution, num_heads=num_heads, win_size=win_size,
                             shift_size=_shift(i), mlp_ratio=mlp_ratio,
                             qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop, attn_drop=attn_drop,
                             drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                             norm_layer=norm_layer, token_projection=token_projection, token_mlp=token_mlp,
                             att=self.att, sparseAtt=self.sparseAtt)
            for i in range(depth)])

    def forward(self, x, mask=None):
        for blk in self.blocks:
            if self.use_checkpoint:
                # Note: checkpoint doesn't support mask argument, so we pass it as None
                x = checkpoint.checkpoint(blk, x, None)
            else:
                x = blk(x, mask)
        return x
class AST(nn.Module):
    """U-shaped window-transformer for image restoration (deraining).

    Architecture: 4 encoder stages (plain FFN blocks, no attention), a sparse-
    attention bottleneck, and 4 decoder stages (sparse attention) with skip
    connections concatenated on the channel dimension. When dd_in == 3 the
    network predicts a residual that is added to the input.

    NOTE(review): `depths`/`num_heads` are mutable default arguments — safe
    here because they are only read, but callers should pass their own lists.
    Assumes square inputs of size `img_size`.
    """

    def __init__(self, img_size=256, in_chans=3, dd_in=3, embed_dim=32, depths=[2, 2, 2, 2, 2, 2, 2, 2, 2],
                 num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2], win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1, norm_layer=nn.LayerNorm, patch_norm=True,
                 use_checkpoint=False, token_projection='linear', token_mlp='leff', dowsample=Downsample,
                 upsample=Upsample, shift_flag=True, **kwargs):
        super().__init__()
        self.num_enc_layers = len(depths) // 2
        self.num_dec_layers = len(depths) // 2
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.mlp_ratio = mlp_ratio
        self.token_projection = token_projection
        self.mlp = token_mlp  # name of the FFN variant ('leff'/'ffn'/'frfn')
        self.win_size = win_size
        self.reso = img_size
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.dd_in = dd_in
        # Stochastic-depth schedule: linear ramp over encoder blocks, constant
        # in the bottleneck, and the encoder ramp reversed for the decoder.
        enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
        conv_dpr = [drop_path_rate] * depths[4]
        dec_dpr = enc_dpr[::-1]
        self.input_proj = InputProj(in_channel=dd_in, out_channel=embed_dim, kernel_size=3, stride=1,
                                    act_layer=nn.LeakyReLU)
        # 2*embed_dim because the last decoder stage carries a skip concat.
        self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=in_chans, kernel_size=3, stride=1)
        # Encoder
        self.encoderlayer_0 = BasicASTLayer(dim=embed_dim, output_dim=embed_dim, input_resolution=(img_size, img_size),
                                            depth=depths[0], num_heads=num_heads[0], win_size=win_size,
                                            mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_0 = dowsample(embed_dim, embed_dim * 2)
        self.encoderlayer_1 = BasicASTLayer(dim=embed_dim * 2, output_dim=embed_dim * 2,
                                            input_resolution=(img_size // 2, img_size // 2), depth=depths[1],
                                            num_heads=num_heads[1], win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                            attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_1 = dowsample(embed_dim * 2, embed_dim * 4)
        self.encoderlayer_2 = BasicASTLayer(dim=embed_dim * 4, output_dim=embed_dim * 4,
                                            input_resolution=(img_size // (2 ** 2), img_size // (2 ** 2)),
                                            depth=depths[2], num_heads=num_heads[2], win_size=win_size,
                                            mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_2 = dowsample(embed_dim * 4, embed_dim * 8)
        self.encoderlayer_3 = BasicASTLayer(dim=embed_dim * 8, output_dim=embed_dim * 8,
                                            input_resolution=(img_size // (2 ** 3), img_size // (2 ** 3)),
                                            depth=depths[3], num_heads=num_heads[3], win_size=win_size,
                                            mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_3 = dowsample(embed_dim * 8, embed_dim * 16)
        # Bottleneck: the only stage with (sparse) attention enabled besides the decoder.
        self.conv = BasicASTLayer(dim=embed_dim * 16, output_dim=embed_dim * 16,
                                  input_resolution=(img_size // (2 ** 4), img_size // (2 ** 4)), depth=depths[4],
                                  num_heads=num_heads[4], win_size=win_size, mlp_ratio=self.mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate,
                                  drop_path=conv_dpr, norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                  token_projection=token_projection, token_mlp=token_mlp, shift_flag=shift_flag,
                                  att=True, sparseAtt=True)
        # Decoder: each stage upsamples, concatenates the matching encoder skip
        # (doubling channels), and runs sparse-attention blocks.
        self.upsample_0 = upsample(embed_dim * 16, embed_dim * 8)
        self.decoderlayer_0 = BasicASTLayer(dim=embed_dim * 16, output_dim=embed_dim * 16,
                                            input_resolution=(img_size // (2 ** 3), img_size // (2 ** 3)),
                                            depth=depths[5], num_heads=num_heads[5], win_size=win_size,
                                            mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dec_dpr[:depths[5]],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_1 = upsample(embed_dim * 16, embed_dim * 4)
        self.decoderlayer_1 = BasicASTLayer(dim=embed_dim * 8, output_dim=embed_dim * 8,
                                            input_resolution=(img_size // (2 ** 2), img_size // (2 ** 2)),
                                            depth=depths[6], num_heads=num_heads[6], win_size=win_size,
                                            mlp_ratio=self.mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_2 = upsample(embed_dim * 8, embed_dim * 2)
        self.decoderlayer_2 = BasicASTLayer(dim=embed_dim * 4, output_dim=embed_dim * 4,
                                            input_resolution=(img_size // 2, img_size // 2), depth=depths[7],
                                            num_heads=num_heads[7], win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                            attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_3 = upsample(embed_dim * 4, embed_dim)
        self.decoderlayer_3 = BasicASTLayer(dim=embed_dim * 2, output_dim=embed_dim * 2,
                                            input_resolution=(img_size, img_size), depth=depths[8],
                                            num_heads=num_heads[8], win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale, drop=drop_rate,
                                            attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])], norm_layer=norm_layer,
                                            use_checkpoint=use_checkpoint, token_projection=token_projection,
                                            token_mlp=token_mlp, shift_flag=shift_flag, att=True, sparseAtt=True)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Truncated-normal init for linear weights; unit/zero init for LayerNorm."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask=None):
        """x: (B, dd_in, H, W). Returns x + residual when dd_in == 3, else the
        raw network output. `mask` is forwarded to every stage."""
        y = self.input_proj(x)
        y = self.pos_drop(y)
        # Encoder with progressive 2x downsampling.
        conv0 = self.encoderlayer_0(y, mask=mask)
        pool0 = self.dowsample_0(conv0)
        conv1 = self.encoderlayer_1(pool0, mask=mask)
        pool1 = self.dowsample_1(conv1)
        conv2 = self.encoderlayer_2(pool1, mask=mask)
        pool2 = self.dowsample_2(conv2)
        conv3 = self.encoderlayer_3(pool2, mask=mask)
        pool3 = self.dowsample_3(conv3)
        # Bottleneck.
        conv4 = self.conv(pool3, mask=mask)
        # Decoder with skip concatenation on the channel (last) dimension.
        up0 = self.upsample_0(conv4)
        deconv0 = torch.cat([up0, conv3], -1)
        deconv0 = self.decoderlayer_0(deconv0, mask=mask)
        up1 = self.upsample_1(deconv0)
        deconv1 = torch.cat([up1, conv2], -1)
        deconv1 = self.decoderlayer_1(deconv1, mask=mask)
        up2 = self.upsample_2(deconv1)
        deconv2 = torch.cat([up2, conv1], -1)
        deconv2 = self.decoderlayer_2(deconv2, mask=mask)
        up3 = self.upsample_3(deconv2)
        deconv3 = torch.cat([up3, conv0], -1)
        deconv3 = self.decoderlayer_3(deconv3, mask=mask)
        y = self.output_proj(deconv3)
        # Residual prediction only when input and output are both RGB.
        return x + y if self.dd_in == 3 else y
#################################################################################
# #
# PART 2: Hugging Face 包装类 (The Hugging Face Wrapper Classes) #
# #
#################################################################################
class ASTConfig(PretrainedConfig):
    """Hugging Face configuration for the AST restoration model.

    All architecture hyper-parameters are passed through ``kwargs`` and stored
    on the instance by ``PretrainedConfig``.
    """

    model_type = "ast"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
class ASTForRestoration(PreTrainedModel):
    """Hugging Face wrapper that exposes the raw AST network via PreTrainedModel.

    This is the class loaded from the Hub (with ``trust_remote_code=True``).
    """

    config_class = ASTConfig

    def __init__(self, config: ASTConfig):
        super().__init__(config)
        # Rebuild the raw AST network from the stored hyper-parameters.
        self.model = AST(**config.to_dict())

    def forward(self, pixel_values):
        """Run restoration on a batch of images (B, C, H, W)."""
        return self.model(pixel_values)
#################################################################################
# #
# PART 3: 主转换逻辑 (Main Conversion Logic) #
# #
#################################################################################
if __name__ == '__main__':
    # --- Use argparse so the script is reusable across tasks ---
    parser = argparse.ArgumentParser(description="Convert AST model .pth files to Hugging Face format.")
    parser.add_argument("--pth_path", type=str, required=True, help="Path to the input .pth weight file.")
    parser.add_argument("--output_dir", type=str, required=True, help="Directory to save the Hugging Face model.")
    parser.add_argument("--task_name", type=str, default="restoration",
                        help="Name of the task (e.g., 'dehazing', 'desnowing') for logging.")
    args = parser.parse_args()
    # --- Model architecture parameters (final corrected version) ---
    model_params = {
        "img_size": 256,
        "in_chans": 3,
        "dd_in": 3,
        "embed_dim": 32,
        "depths": [1, 2, 8, 8, 2, 8, 8, 2, 1],  # key correction: matches released checkpoints
        "num_heads": [1, 2, 4, 8, 16, 16, 8, 4, 2],
        "win_size": 8,
        "mlp_ratio": 4.0,
        "qkv_bias": True,
        "qk_scale": None,
        "drop_rate": 0.0,
        "attn_drop_rate": 0.0,
        "drop_path_rate": 0.1,
        "patch_norm": True,
        "use_checkpoint": False,
        "token_projection": "linear",
        "token_mlp": "frfn",
        "shift_flag": True
    }
    # --- Run the conversion ---
    print(f" 任务: {args.task_name.upper()} | 步骤 1/5: 正在创建 Hugging Face 模型实例 (AST)...")
    hf_config = ASTConfig(**model_params)
    hf_model = ASTForRestoration(hf_config)
    print("模型实例创建成功!")
    print(f"步骤 2/5: 正在从 '{args.pth_path}' 加载权重...")
    if not os.path.exists(args.pth_path):
        raise FileNotFoundError(f"错误: 找不到权重文件 '{args.pth_path}'。请检查路径是否正确。")
    # SECURITY NOTE: torch.load unpickles arbitrary objects — only convert
    # checkpoints from trusted sources (weights_only=True is safer where the
    # installed torch version supports it).
    state_dict = torch.load(args.pth_path, map_location='cpu')
    print("权重文件加载成功!")
    print("步骤 3/5: 正在处理权重字典...")
    # Unwrap common checkpoint containers.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    elif 'params_ema' in state_dict:
        state_dict = state_dict['params_ema']
    elif 'params' in state_dict:
        state_dict = state_dict['params']
    # Strip a leading 'module.' prefix left by DataParallel/DDP.
    new_state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}
    # Load the weights into the wrapped network.
    hf_model.model.load_state_dict(new_state_dict)
    hf_model.eval()
    print("权重成功加载到模型中!")
    print(f"步骤 4/5: 正在将模型保存到 '{args.output_dir}'...")
    # Fix: exist_ok=True replaces the racy exists()-then-makedirs check.
    os.makedirs(args.output_dir, exist_ok=True)
    hf_model.save_pretrained(args.output_dir)
    print(f"模型和 config.json 已保存!")
    # Create and save the image-processor configuration.
    image_processor_config = {
        "do_normalize": True,
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
        "data_format": "channels_first"
    }
    with open(os.path.join(args.output_dir, 'preprocessor_config.json'), 'w') as f:
        json.dump(image_processor_config, f)
    print(f"图像处理器配置 (preprocessor_config.json) 已保存!")
    print(f"\n任务 '{args.task_name.upper()}' 转换完成!")
    print(f"查看输出目录: {args.output_dir}")
    print("\n下一步操作:")
    print(f"1. 将此脚本文件本身复制到输出目录 '{args.output_dir}' 中,并重命名为 `modeling_ast.py`。")
    print("2. 将整个输出目录上传到您的 Hugging Face 仓库。")
    print("3. 在 Hub 上加载模型时,请确保使用 `trust_remote_code=True`。")