import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
import torch.nn.functional as F
from einops import rearrange, repeat
from einops.layers.torch import Rearrange
import math
import numpy as np
import time
from torch import einsum
import json
import os
import argparse
from transformers import PretrainedConfig, PreTrainedModel


#################################################################################
#                                                                               #
#        PART 1: Model definition (from the file you provided)                  #
#                                                                               #
#################################################################################

def conv(in_channels, out_channels, kernel_size, bias=False, stride=1):
    return nn.Conv2d(
        in_channels, out_channels, kernel_size,
        padding=(kernel_size // 2), bias=bias, stride=stride)


class ConvBlock(nn.Module):
    def __init__(self, in_channel, out_channel, strides=1):
        super(ConvBlock, self).__init__()
        self.strides = strides
        self.in_channel = in_channel
        self.out_channel = out_channel
        self.block = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.Conv2d(out_channel, out_channel, kernel_size=3, stride=strides, padding=1),
            nn.LeakyReLU(inplace=True),
        )
        # 1x1 shortcut. Note the main branch applies `strides` twice, so the
        # residual sum only matches spatially when strides == 1 (the default).
        self.conv11 = nn.Conv2d(in_channel, out_channel, kernel_size=1, stride=strides, padding=0)

    def forward(self, x):
        out1 = self.block(x)
        out2 = self.conv11(x)
        out = out1 + out2
        return out


class LinearProjection(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0., bias=True):
        super().__init__()
        inner_dim = dim_head * heads
        self.heads = heads
        self.to_q = nn.Linear(dim, inner_dim, bias=bias)
        self.to_kv = nn.Linear(dim, inner_dim * 2, bias=bias)
        self.dim = dim
        self.inner_dim = inner_dim

    def forward(self, x, attn_kv=None):
        B_, N, C = x.shape
        if attn_kv is not None:
            attn_kv = attn_kv.unsqueeze(0).repeat(B_, 1, 1)
        else:
            attn_kv = x
        N_kv = attn_kv.size(1)
        q = self.to_q(x).reshape(B_, N, 1, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        kv = self.to_kv(attn_kv).reshape(B_, N_kv, 2, self.heads, C // self.heads).permute(2, 0, 3, 1, 4)
        q = q[0]
        k, v = kv[0], kv[1]
        return q, k, v


class WindowAttention(nn.Module):
    def __init__(self, dim, win_size, num_heads, token_projection='linear',
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.win_size = win_size  # (Wh, Ww)
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # Relative position bias: one learnable entry per relative offset per head.
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))

        coords_h = torch.arange(self.win_size[0])
        coords_w = torch.arange(self.win_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.win_size[0] - 1
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        if token_projection == 'linear':
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            raise Exception("Projection error!")

        self.token_projection = token_projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, attn_kv=None, mask=None):
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        # If cross-attention keys outnumber queries, tile the bias to match.
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
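# The demo below is not part of the original file: a hedged sanity check that
# WindowAttention maps (num_windows*B, win_h*win_w, C) token windows to the
# same shape. All sizes are arbitrary illustrative assumptions.
def _demo_window_attention_shapes():
    attn = WindowAttention(dim=32, win_size=(8, 8), num_heads=4)
    tokens = torch.randn(2, 8 * 8, 32)  # two 8x8 windows of 32-dim tokens
    out = attn(tokens)
    assert out.shape == tokens.shape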
class WindowAttention_sparse(nn.Module):
    def __init__(self, dim, win_size, num_heads, token_projection='linear',
                 qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.win_size = win_size
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * win_size[0] - 1) * (2 * win_size[1] - 1), num_heads))

        coords_h = torch.arange(self.win_size[0])
        coords_w = torch.arange(self.win_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
        coords_flatten = torch.flatten(coords, 1)
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
        relative_coords[:, :, 0] += self.win_size[0] - 1
        relative_coords[:, :, 1] += self.win_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.win_size[1] - 1
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)
        trunc_normal_(self.relative_position_bias_table, std=.02)

        if token_projection == 'linear':
            self.qkv = LinearProjection(dim, num_heads, dim // num_heads, bias=qkv_bias)
        else:
            raise Exception("Projection error!")

        self.token_projection = token_projection
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)
        self.softmax = nn.Softmax(dim=-1)
        self.relu = nn.ReLU()
        self.w = nn.Parameter(torch.ones(2))

    def forward(self, x, attn_kv=None, mask=None):
        B_, N, C = x.shape
        q, k, v = self.qkv(x, attn_kv)
        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.win_size[0] * self.win_size[1], self.win_size[0] * self.win_size[1], -1)
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()
        ratio = attn.size(-1) // relative_position_bias.size(-1)
        relative_position_bias = repeat(relative_position_bias, 'nH l c -> nH l (c d)', d=ratio)
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            mask = repeat(mask, 'nW m n -> nW m (n d)', d=ratio)
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N * ratio) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N * ratio)
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2
        else:
            attn0 = self.softmax(attn)
            attn1 = self.relu(attn) ** 2

        w1 = torch.exp(self.w[0]) / torch.sum(torch.exp(self.w))
        w2 = torch.exp(self.w[1]) / torch.sum(torch.exp(self.w))
        attn = attn0 * w1 + attn1 * w2
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
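# Not part of the original file: a hedged illustration that the two gate
# weights in WindowAttention_sparse form a 2-way softmax over self.w, so the
# dense (softmax) and sparse (squared-ReLU) attention maps are blended
# convexly, with the balance learned during training.
def _demo_sparse_attention_gate():
    attn = WindowAttention_sparse(dim=32, win_size=(8, 8), num_heads=4)
    w1 = torch.exp(attn.w[0]) / torch.sum(torch.exp(attn.w))
    w2 = torch.exp(attn.w[1]) / torch.sum(torch.exp(attn.w))
    assert torch.allclose(w1 + w2, torch.tensor(1.0))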
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class LeFF(nn.Module):
    def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim), act_layer())
        self.dwconv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
            act_layer())
        self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
        self.eca = nn.Identity()

    def forward(self, x):
        bs, hw, c = x.size()
        hh = int(math.sqrt(hw))  # assumes a square token grid
        x = self.linear1(x)
        x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
        x = self.dwconv(x)
        x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh)
        x = self.linear2(x)
        x = self.eca(x)
        return x


class FRFN(nn.Module):
    def __init__(self, dim=32, hidden_dim=128, act_layer=nn.GELU, drop=0., use_eca=False):
        super().__init__()
        self.linear1 = nn.Sequential(nn.Linear(dim, hidden_dim * 2), act_layer())
        self.dwconv = nn.Sequential(
            nn.Conv2d(hidden_dim, hidden_dim, groups=hidden_dim, kernel_size=3, stride=1, padding=1),
            act_layer())
        self.linear2 = nn.Sequential(nn.Linear(hidden_dim, dim))
        self.dim = dim
        self.hidden_dim = hidden_dim
        self.dim_conv = self.dim // 4
        self.dim_untouched = self.dim - self.dim_conv
        self.partial_conv3 = nn.Conv2d(self.dim_conv, self.dim_conv, 3, 1, 1, bias=False)

    def forward(self, x):
        bs, hw, c = x.size()
        hh = int(math.sqrt(hw))  # assumes a square token grid

        # Partial convolution: refine a quarter of the channels, pass the rest through.
        x = rearrange(x, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
        x1, x2 = torch.split(x, [self.dim_conv, self.dim_untouched], dim=1)
        x1 = self.partial_conv3(x1)
        x = torch.cat((x1, x2), 1)
        x = rearrange(x, ' b c h w -> b (h w) c', h=hh, w=hh)

        x = self.linear1(x)
        # Gating: one half goes through a depth-wise conv, the other half gates it.
        x_1, x_2 = x.chunk(2, dim=-1)
        x_1 = rearrange(x_1, ' b (h w) (c) -> b c h w ', h=hh, w=hh)
        x_1 = self.dwconv(x_1)
        x_1 = rearrange(x_1, ' b c h w -> b (h w) c', h=hh, w=hh)
        x = x_1 * x_2
        x = self.linear2(x)
        return x


def window_partition(x, win_size, dilation_rate=1):
    B, H, W, C = x.shape
    if dilation_rate != 1:
        x = x.permute(0, 3, 1, 2)  # B, C, H, W
        assert type(dilation_rate) is int, 'dilation_rate should be an int'
        x = F.unfold(x, kernel_size=win_size, dilation=dilation_rate,
                     padding=4 * (dilation_rate - 1), stride=win_size)
        windows = x.permute(0, 2, 1).contiguous().view(-1, C, win_size, win_size)
        windows = windows.permute(0, 2, 3, 1).contiguous()
    else:
        x = x.view(B, H // win_size, win_size, W // win_size, win_size, C)
        windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, win_size, win_size, C)
    return windows


def window_reverse(windows, win_size, H, W, dilation_rate=1):
    # windows: (num_windows*B, win_size, win_size, C)
    B = int(windows.shape[0] / (H * W / win_size / win_size))
    x = windows.view(B, H // win_size, W // win_size, win_size, win_size, -1)
    if dilation_rate != 1:
        # Dilated path (unused by this model). The original called
        # `windows.permute(...)` with six indices on a 4-D tensor, which
        # cannot run; permute the 6-D `x` instead and flatten to the
        # (B, C*k*k, L) layout expected by F.fold.
        x = x.permute(0, 5, 3, 4, 1, 2).contiguous()
        x = x.view(B, -1, (H // win_size) * (W // win_size))
        x = F.fold(x, (H, W), kernel_size=win_size, dilation=dilation_rate,
                   padding=4 * (dilation_rate - 1), stride=win_size)
    else:
        x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
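# Not part of the original file: a hedged round-trip check that
# window_partition followed by window_reverse reconstructs the input exactly
# when H and W are multiples of the window size (the only case this model uses).
def _demo_window_roundtrip():
    x = torch.randn(2, 16, 16, 4)              # (B, H, W, C)
    windows = window_partition(x, win_size=8)  # (2*2*2, 8, 8, 4)
    y = window_reverse(windows, win_size=8, H=16, W=16)
    assert torch.equal(x, y)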
class Downsample(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Downsample, self).__init__()
        self.conv = nn.Sequential(nn.Conv2d(in_channel, out_channel, kernel_size=4, stride=2, padding=1))

    def forward(self, x):
        B, L, C = x.shape
        # Tokens are assumed to form a square feature map.
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        x = x.transpose(1, 2).contiguous().view(B, C, H, W)
        out = self.conv(x).flatten(2).transpose(1, 2).contiguous()
        return out


class Upsample(nn.Module):
    def __init__(self, in_channel, out_channel):
        super(Upsample, self).__init__()
        self.deconv = nn.Sequential(nn.ConvTranspose2d(in_channel, out_channel, kernel_size=2, stride=2))

    def forward(self, x):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        x = x.transpose(1, 2).contiguous().view(B, C, H, W)
        out = self.deconv(x).flatten(2).transpose(1, 2).contiguous()
        return out


class InputProj(nn.Module):
    def __init__(self, in_channel=3, out_channel=64, kernel_size=3, stride=1,
                 norm_layer=None, act_layer=nn.LeakyReLU):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2),
            act_layer(inplace=True))
        self.norm = norm_layer(out_channel) if norm_layer is not None else None

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x).flatten(2).transpose(1, 2).contiguous()
        if self.norm is not None:
            x = self.norm(x)
        return x


class OutputProj(nn.Module):
    def __init__(self, in_channel=64, out_channel=3, kernel_size=3, stride=1,
                 norm_layer=None, act_layer=None):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Conv2d(in_channel, out_channel, kernel_size=3, stride=stride, padding=kernel_size // 2))
        if act_layer is not None:
            self.proj.add_module(str(len(self.proj)), act_layer(inplace=True))
        self.norm = norm_layer(out_channel) if norm_layer is not None else None

    def forward(self, x):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.proj(x)
        if self.norm is not None:
            x = self.norm(x)
        return x


class TransformerBlock(nn.Module):
    def __init__(self, dim, input_resolution, num_heads, win_size=8, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,
                 token_projection='linear', token_mlp='leff', att=True, sparseAtt=False):
        super().__init__()
        self.att = att
        self.sparseAtt = sparseAtt
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.win_size = win_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.win_size:
            # The window covers the whole feature map: no need to shift.
            self.shift_size = 0
            self.win_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.win_size, "shift_size must be in [0, win_size)"

        if self.att:
            self.norm1 = norm_layer(dim)
            if self.sparseAtt:
                self.attn = WindowAttention_sparse(
                    dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                    proj_drop=drop, token_projection=token_projection)
            else:
                self.attn = WindowAttention(
                    dim, win_size=to_2tuple(self.win_size), num_heads=num_heads,
                    qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop,
                    proj_drop=drop, token_projection=token_projection)

        # The FFN half of the block is used even when attention is disabled.
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        if token_mlp in ['ffn', 'mlp']:
            self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
        elif token_mlp == 'leff':
            self.mlp = LeFF(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        elif token_mlp == 'frfn':
            self.mlp = FRFN(dim, mlp_hidden_dim, act_layer=act_layer, drop=drop)
        else:
            raise Exception("FFN error!")

    def forward(self, x, mask=None):
        B, L, C = x.shape
        H = int(math.sqrt(L))
        W = int(math.sqrt(L))

        # Build the attention mask for shifted windows (cf. Swin Transformer).
        attn_mask = None
        if self.shift_size > 0:
            shift_mask = torch.zeros((1, H, W, 1), device=x.device)
            h_slices = (slice(0, -self.win_size),
                        slice(-self.win_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.win_size),
                        slice(-self.win_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    shift_mask[:, h, w, :] = cnt
                    cnt += 1
            shift_mask_windows = window_partition(shift_mask, self.win_size)
            shift_mask_windows = shift_mask_windows.view(-1, self.win_size * self.win_size)
            attn_mask = shift_mask_windows.unsqueeze(1) - shift_mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        shortcut = x
        if self.att:
            x = self.norm1(x)
            x = x.view(B, H, W, C)

            if self.shift_size > 0:
                shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
            else:
                shifted_x = x

            x_windows = window_partition(shifted_x, self.win_size)
            x_windows = x_windows.view(-1, self.win_size * self.win_size, C)
            attn_windows = self.attn(x_windows, mask=attn_mask)
            attn_windows = attn_windows.view(-1, self.win_size, self.win_size, C)
            shifted_x = window_reverse(attn_windows, self.win_size, H, W)

            if self.shift_size > 0:
                x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
            else:
                x = shifted_x
            x = x.view(B, H * W, C)
            x = shortcut + self.drop_path(x)

        x = x + self.drop_path(self.mlp(self.norm2(x)))
        return x
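# Not part of the original file: a hedged shape check showing that
# TransformerBlock consumes and produces flattened (B, H*W, C) tokens and that
# the shifted-window mask is built internally from shift_size alone.
def _demo_transformer_block_shapes():
    blk = TransformerBlock(dim=32, input_resolution=(16, 16), num_heads=4,
                           win_size=8, shift_size=4, token_mlp='leff')
    x = torch.randn(2, 16 * 16, 32)
    assert blk(x).shape == x.shape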
class BasicASTLayer(nn.Module):
    def __init__(self, dim, output_dim, input_resolution, depth, num_heads, win_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, use_checkpoint=False,
                 token_projection='linear', token_mlp='ffn', shift_flag=True,
                 att=False, sparseAtt=False):
        super().__init__()
        self.att = att
        self.sparseAtt = sparseAtt
        self.depth = depth
        self.use_checkpoint = use_checkpoint
        if shift_flag:
            # Alternate regular and shifted windows, as in Swin.
            self.blocks = nn.ModuleList([
                TransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, win_size=win_size,
                                 shift_size=0 if (i % 2 == 0) else win_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 token_projection=token_projection, token_mlp=token_mlp,
                                 att=self.att, sparseAtt=self.sparseAtt)
                for i in range(depth)])
        else:
            self.blocks = nn.ModuleList([
                TransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, win_size=win_size,
                                 shift_size=0,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer,
                                 token_projection=token_projection, token_mlp=token_mlp,
                                 att=self.att, sparseAtt=self.sparseAtt)
                for i in range(depth)])

    def forward(self, x, mask=None):
        for blk in self.blocks:
            if self.use_checkpoint:
                # torch.utils.checkpoint forwards arguments positionally; the
                # mask is dropped (passed as None) under checkpointing here.
                x = checkpoint.checkpoint(blk, x, None)
            else:
                x = blk(x, mask)
        return x
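# Not part of the original file: a hedged illustration of how AST (below)
# slices a linearly increasing stochastic-depth schedule per encoder stage;
# the numbers mirror enc_dpr in AST.__init__ with the default depths.
def _demo_drop_path_schedule():
    depths = [2, 2, 2, 2, 2, 2, 2, 2, 2]
    drop_path_rate = 0.1
    enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:4]))]
    stage0 = enc_dpr[sum(depths[:0]):sum(depths[:1])]  # rates for encoderlayer_0
    assert len(stage0) == depths[0] and stage0[0] == 0.0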
class AST(nn.Module):
    def __init__(self, img_size=256, in_chans=3, dd_in=3, embed_dim=32,
                 depths=[2, 2, 2, 2, 2, 2, 2, 2, 2],
                 num_heads=[1, 2, 4, 8, 16, 16, 8, 4, 2],
                 win_size=8, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.1,
                 norm_layer=nn.LayerNorm, patch_norm=True, use_checkpoint=False,
                 token_projection='linear', token_mlp='leff',
                 dowsample=Downsample, upsample=Upsample, shift_flag=True, **kwargs):
        super().__init__()
        self.num_enc_layers = len(depths) // 2
        self.num_dec_layers = len(depths) // 2
        self.embed_dim = embed_dim
        self.patch_norm = patch_norm
        self.mlp_ratio = mlp_ratio
        self.token_projection = token_projection
        self.mlp = token_mlp
        self.win_size = win_size
        self.reso = img_size
        self.pos_drop = nn.Dropout(p=drop_rate)
        self.dd_in = dd_in

        enc_dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths[:self.num_enc_layers]))]
        conv_dpr = [drop_path_rate] * depths[4]
        dec_dpr = enc_dpr[::-1]

        self.input_proj = InputProj(in_channel=dd_in, out_channel=embed_dim,
                                    kernel_size=3, stride=1, act_layer=nn.LeakyReLU)
        self.output_proj = OutputProj(in_channel=2 * embed_dim, out_channel=in_chans,
                                      kernel_size=3, stride=1)

        # Encoder
        self.encoderlayer_0 = BasicASTLayer(dim=embed_dim, output_dim=embed_dim,
                                            input_resolution=(img_size, img_size),
                                            depth=depths[0], num_heads=num_heads[0],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:0]):sum(depths[:1])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_0 = dowsample(embed_dim, embed_dim * 2)
        self.encoderlayer_1 = BasicASTLayer(dim=embed_dim * 2, output_dim=embed_dim * 2,
                                            input_resolution=(img_size // 2, img_size // 2),
                                            depth=depths[1], num_heads=num_heads[1],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:1]):sum(depths[:2])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_1 = dowsample(embed_dim * 2, embed_dim * 4)
        self.encoderlayer_2 = BasicASTLayer(dim=embed_dim * 4, output_dim=embed_dim * 4,
                                            input_resolution=(img_size // (2 ** 2), img_size // (2 ** 2)),
                                            depth=depths[2], num_heads=num_heads[2],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:2]):sum(depths[:3])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_2 = dowsample(embed_dim * 4, embed_dim * 8)
        self.encoderlayer_3 = BasicASTLayer(dim=embed_dim * 8, output_dim=embed_dim * 8,
                                            input_resolution=(img_size // (2 ** 3), img_size // (2 ** 3)),
                                            depth=depths[3], num_heads=num_heads[3],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=enc_dpr[sum(depths[:3]):sum(depths[:4])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=False, sparseAtt=False)
        self.dowsample_3 = dowsample(embed_dim * 8, embed_dim * 16)
        # Bottleneck
        self.conv = BasicASTLayer(dim=embed_dim * 16, output_dim=embed_dim * 16,
                                  input_resolution=(img_size // (2 ** 4), img_size // (2 ** 4)),
                                  depth=depths[4], num_heads=num_heads[4],
                                  win_size=win_size, mlp_ratio=self.mlp_ratio,
                                  qkv_bias=qkv_bias, qk_scale=qk_scale,
                                  drop=drop_rate, attn_drop=attn_drop_rate,
                                  drop_path=conv_dpr,
                                  norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                  token_projection=token_projection, token_mlp=token_mlp,
                                  shift_flag=shift_flag, att=True, sparseAtt=True)

        # Decoder
        self.upsample_0 = upsample(embed_dim * 16, embed_dim * 8)
        self.decoderlayer_0 = BasicASTLayer(dim=embed_dim * 16, output_dim=embed_dim * 16,
                                            input_resolution=(img_size // (2 ** 3), img_size // (2 ** 3)),
                                            depth=depths[5], num_heads=num_heads[5],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[:depths[5]],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_1 = upsample(embed_dim * 16, embed_dim * 4)
        self.decoderlayer_1 = BasicASTLayer(dim=embed_dim * 8, output_dim=embed_dim * 8,
                                            input_resolution=(img_size // (2 ** 2), img_size // (2 ** 2)),
                                            depth=depths[6], num_heads=num_heads[6],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:6]):sum(depths[5:7])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_2 = upsample(embed_dim * 8, embed_dim * 2)
        self.decoderlayer_2 = BasicASTLayer(dim=embed_dim * 4, output_dim=embed_dim * 4,
                                            input_resolution=(img_size // 2, img_size // 2),
                                            depth=depths[7], num_heads=num_heads[7],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:7]):sum(depths[5:8])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=True, sparseAtt=True)
        self.upsample_3 = upsample(embed_dim * 4, embed_dim)
        self.decoderlayer_3 = BasicASTLayer(dim=embed_dim * 2, output_dim=embed_dim * 2,
                                            input_resolution=(img_size, img_size),
                                            depth=depths[8], num_heads=num_heads[8],
                                            win_size=win_size, mlp_ratio=self.mlp_ratio,
                                            qkv_bias=qkv_bias, qk_scale=qk_scale,
                                            drop=drop_rate, attn_drop=attn_drop_rate,
                                            drop_path=dec_dpr[sum(depths[5:8]):sum(depths[5:9])],
                                            norm_layer=norm_layer, use_checkpoint=use_checkpoint,
                                            token_projection=token_projection, token_mlp=token_mlp,
                                            shift_flag=shift_flag, att=True, sparseAtt=True)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, mask=None):
        y = self.input_proj(x)
        y = self.pos_drop(y)

        # Encoder
        conv0 = self.encoderlayer_0(y, mask=mask)
        pool0 = self.dowsample_0(conv0)
        conv1 = self.encoderlayer_1(pool0, mask=mask)
        pool1 = self.dowsample_1(conv1)
        conv2 = self.encoderlayer_2(pool1, mask=mask)
        pool2 = self.dowsample_2(conv2)
        conv3 = self.encoderlayer_3(pool2, mask=mask)
        pool3 = self.dowsample_3(conv3)

        # Bottleneck
        conv4 = self.conv(pool3, mask=mask)

        # Decoder with skip connections (channel-wise concat on the last dim)
        up0 = self.upsample_0(conv4)
        deconv0 = torch.cat([up0, conv3], -1)
        deconv0 = self.decoderlayer_0(deconv0, mask=mask)
        up1 = self.upsample_1(deconv0)
        deconv1 = torch.cat([up1, conv2], -1)
        deconv1 = self.decoderlayer_1(deconv1, mask=mask)
        up2 = self.upsample_2(deconv1)
        deconv2 = torch.cat([up2, conv1], -1)
        deconv2 = self.decoderlayer_2(deconv2, mask=mask)
        up3 = self.upsample_3(deconv2)
        deconv3 = torch.cat([up3, conv0], -1)
        deconv3 = self.decoderlayer_3(deconv3, mask=mask)

        y = self.output_proj(deconv3)
        # Residual prediction when the input is RGB; direct prediction otherwise.
        return x + y if self.dd_in == 3 else y
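# Not part of the original file: a hedged smoke test for the full network.
# With dd_in == 3 the model predicts a residual added to the input, so the
# output and input shapes match; hyper-parameters below are the defaults.
def _demo_ast_forward():
    model = AST(img_size=256, embed_dim=32).eval()
    x = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        y = model(x)
    assert y.shape == x.shape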
#################################################################################
#                                                                               #
#        PART 2: The Hugging Face Wrapper Classes                               #
#                                                                               #
#################################################################################

class ASTConfig(PretrainedConfig):
    """
    This is the configuration class to store the configuration of an `AST` model.
    """
    model_type = "ast"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)


class ASTForRestoration(PreTrainedModel):
    """
    This is the main model class that will be loaded by Hugging Face.
    """
    config_class = ASTConfig

    def __init__(self, config: ASTConfig):
        super().__init__(config)
        self.model = AST(**config.to_dict())

    def forward(self, pixel_values):
        """
        The forward pass of the model.
        """
        return self.model(pixel_values)


#################################################################################
#                                                                               #
#        PART 3: Main Conversion Logic                                          #
#                                                                               #
#################################################################################

if __name__ == '__main__':
    # --- Use argparse so the script is reusable ---
    parser = argparse.ArgumentParser(description="Convert AST model .pth files to Hugging Face format.")
    parser.add_argument("--pth_path", type=str, required=True,
                        help="Path to the input .pth weight file.")
    parser.add_argument("--output_dir", type=str, required=True,
                        help="Directory to save the Hugging Face model.")
    parser.add_argument("--task_name", type=str, default="restoration",
                        help="Name of the task (e.g., 'dehazing', 'desnowing') for logging.")
    args = parser.parse_args()

    # --- Model architecture parameters (final corrected version) ---
    model_params = {
        "img_size": 256,
        "in_chans": 3,
        "dd_in": 3,
        "embed_dim": 32,
        "depths": [1, 2, 8, 8, 2, 8, 8, 2, 1],  # <--- the key final correction!
        "num_heads": [1, 2, 4, 8, 16, 16, 8, 4, 2],
        "win_size": 8,
        "mlp_ratio": 4.0,
        "qkv_bias": True,
        "qk_scale": None,
        "drop_rate": 0.0,
        "attn_drop_rate": 0.0,
        "drop_path_rate": 0.1,
        "patch_norm": True,
        "use_checkpoint": False,
        "token_projection": "linear",
        "token_mlp": "frfn",
        "shift_flag": True
    }

    # --- Run the conversion ---
    print(f"Task: {args.task_name.upper()} | Step 1/5: Creating the Hugging Face model instance (AST)...")
    hf_config = ASTConfig(**model_params)
    hf_model = ASTForRestoration(hf_config)
    print("Model instance created successfully!")

    print(f"Step 2/5: Loading weights from '{args.pth_path}'...")
    if not os.path.exists(args.pth_path):
        raise FileNotFoundError(f"Error: weight file '{args.pth_path}' not found. Please check the path.")
    state_dict = torch.load(args.pth_path, map_location='cpu')
    print("Weight file loaded successfully!")

    print("Step 3/5: Processing the state dict...")
    # Some checkpoints nest the weights under a common key; unwrap if so.
    if 'state_dict' in state_dict:
        state_dict = state_dict['state_dict']
    elif 'params_ema' in state_dict:
        state_dict = state_dict['params_ema']
    elif 'params' in state_dict:
        state_dict = state_dict['params']
    # Strip any 'module.' prefix left over from DataParallel training
    new_state_dict = {k.replace('module.', '', 1): v for k, v in state_dict.items()}

    # Load the weights
    hf_model.model.load_state_dict(new_state_dict)
    hf_model.eval()
    print("Weights loaded into the model successfully!")

    print(f"Step 4/5: Saving the model to '{args.output_dir}'...")
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    hf_model.save_pretrained(args.output_dir)
    print("Model and config.json saved!")

    # Create and save the image processor configuration
    image_processor_config = {
        "do_normalize": True,
        "image_mean": [0.5, 0.5, 0.5],
        "image_std": [0.5, 0.5, 0.5],
        "data_format": "channels_first"
    }
    with open(os.path.join(args.output_dir, 'preprocessor_config.json'), 'w') as f:
        json.dump(image_processor_config, f)
    print("Image processor configuration (preprocessor_config.json) saved!")

    print(f"\nTask '{args.task_name.upper()}' conversion complete!")
    print(f"See the output directory: {args.output_dir}")
    print("\nNext steps:")
    print(f"1. Copy this script itself into the output directory '{args.output_dir}' and rename it to `modeling_ast.py`.")
    print("2. Upload the whole output directory to your Hugging Face repository.")
    print("3. When loading the model from the Hub, make sure to pass `trust_remote_code=True`.")
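# Example usage (hedged; the script name, paths, and task below are placeholders):
#   python convert_ast_to_hf.py --pth_path ./weights/ast.pth \
#       --output_dir ./ast-hf --task_name deraining
#
# Reloading the converted checkpoint locally with the wrapper defined above:
#   from modeling_ast import ASTForRestoration
#   model = ASTForRestoration.from_pretrained("./ast-hf")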