{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Triton: encoder / spatio-temporal evolution / decoder model\n",
    "\n",
    "Defines the `Triton` network for sequences of multi-channel 2-D fields shaped `(B, T, C, H, W)`:\n",
    "\n",
    "1. **Atmospheric encoder**: a per-frame 3x3 conv stack in which every second layer halves `H` and `W`.\n",
    "2. **Spatio-temporal evolution**: the `T` latent frames are stacked along the channel axis and processed by `Conv` (depthwise-conv) and `MHSA` (self-attention) blocks.\n",
    "3. **Atmospheric decoder**: a transposed-conv stack with a skip connection from the first encoder layer, followed by a 1x1 readout.\n",
    "\n",
    "The last cell runs a forward pass on random data as a shape check."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "import torch\n",
    "from torch import nn\n",
    "from timm.layers import DropPath, trunc_normal_"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Building blocks\n",
    "\n",
    "MLPs, multi-head self-attention and the two residual block types (`Conv`, `MHSA`) used by the evolution stack."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def stride_generator(N, reverse=False):\n",
    "    \"\"\"Return the per-layer strides of an ``N``-layer conv encoder/decoder.\n",
    "\n",
    "    Strides alternate ``1, 2, 1, 2, ...`` (the first layer never downsamples),\n",
    "    so the stack contains ``N // 2`` stride-2 layers.\n",
    "\n",
    "    Args:\n",
    "        N: number of layers.\n",
    "        reverse: return the strides in reverse order (used by the decoder).\n",
    "\n",
    "    Returns:\n",
    "        A list of ``N`` ints.\n",
    "    \"\"\"\n",
    "    # Built per index so any N works (a fixed ``[1, 2] * 10`` list would cap N at 20).\n",
    "    strides = [1 if i % 2 == 0 else 2 for i in range(N)]\n",
    "    return strides[::-1] if reverse else strides\n",
    "\n",
    "\n",
    "class MLP(nn.Module):\n",
    "    \"\"\"Token-wise MLP: Linear -> activation -> dropout -> Linear -> dropout.\n",
    "\n",
    "    Args:\n",
    "        in_features: size of the last input dimension.\n",
    "        hidden_features: hidden width (defaults to ``in_features``).\n",
    "        out_features: output width (defaults to ``in_features``).\n",
    "        act_layer: activation class, instantiated without arguments.\n",
    "        drop: dropout probability, applied after the activation and after ``fc2``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n",
    "        super(MLP, self).__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Linear(in_features, hidden_features)\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Linear(hidden_features, out_features)\n",
    "        self.drop = nn.Dropout(drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class ConvMLP(nn.Module):\n",
    "    \"\"\"Channel-wise MLP on feature maps, built from two 1x1 convolutions.\n",
    "\n",
    "    Same structure and arguments as ``MLP``, but it mixes the channel axis of a\n",
    "    ``(B, C, H, W)`` tensor instead of the last axis.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):\n",
    "        super(ConvMLP, self).__init__()\n",
    "        out_features = out_features or in_features\n",
    "        hidden_features = hidden_features or in_features\n",
    "        self.fc1 = nn.Conv2d(in_features, hidden_features, 1)\n",
    "        self.act = act_layer()\n",
    "        self.fc2 = nn.Conv2d(hidden_features, out_features, 1)\n",
    "        self.drop = nn.Dropout(drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = self.fc1(x)\n",
    "        x = self.act(x)\n",
    "        x = self.drop(x)\n",
    "        x = self.fc2(x)\n",
    "        x = self.drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    \"\"\"Multi-head self-attention over tokens of shape ``(B, N, C)``.\n",
    "\n",
    "    Args:\n",
    "        dim: token width ``C``; must be divisible by ``num_heads``.\n",
    "        num_heads: number of attention heads.\n",
    "        qkv_bias: add a bias to the fused q/k/v projection.\n",
    "        qk_scale: logit scale override (default ``head_dim ** -0.5``).\n",
    "        attn_drop: dropout on the attention weights.\n",
    "        proj_drop: dropout after the output projection.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if ``dim`` is not divisible by ``num_heads``.\n",
    "\n",
    "    The full ``(B, num_heads, N, N)`` attention matrix is materialised, so memory\n",
    "    grows quadratically with the token count ``N``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):\n",
    "        super(Attention, self).__init__()\n",
    "        if dim % num_heads != 0:\n",
    "            # forward() splits C into num_heads equal head slices via reshape.\n",
    "            raise ValueError(f'dim ({dim}) must be divisible by num_heads ({num_heads})')\n",
    "        self.num_heads = num_heads\n",
    "        head_dim = dim // num_heads\n",
    "        self.scale = qk_scale or head_dim ** -0.5\n",
    "\n",
    "        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)\n",
    "        self.attn_drop = nn.Dropout(attn_drop)\n",
    "        self.proj = nn.Linear(dim, dim)\n",
    "        self.proj_drop = nn.Dropout(proj_drop)\n",
    "\n",
    "    def forward(self, x):\n",
    "        B, N, C = x.shape\n",
    "        # (B, N, 3*C) -> (3, B, heads, N, head_dim)\n",
    "        qkv = (\n",
    "            self.qkv(x)\n",
    "            .reshape(B, N, 3, self.num_heads, C // self.num_heads)\n",
    "            .permute(2, 0, 3, 1, 4)\n",
    "        )\n",
    "        q, k, v = qkv[0], qkv[1], qkv[2]\n",
    "\n",
    "        attn = (q @ k.transpose(-2, -1)) * self.scale\n",
    "        attn = attn.softmax(dim=-1)\n",
    "        attn = self.attn_drop(attn)\n",
    "\n",
    "        x = (attn @ v).transpose(1, 2).reshape(B, N, C)\n",
    "        x = self.proj(x)\n",
    "        x = self.proj_drop(x)\n",
    "        return x\n",
    "\n",
    "\n",
    "class ConvBlock(nn.Module):\n",
    "    \"\"\"Convolutional residual block on ``(B, C, H, W)`` feature maps.\n",
    "\n",
    "    Computes::\n",
    "\n",
    "        x = x + pos_embed(x)                            # depthwise 3x3\n",
    "        x = x + drop_path(conv2(attn(conv1(norm1(x))))) # 1x1 -> depthwise 5x5 -> 1x1\n",
    "        x = x + drop_path(mlp(norm2(x)))                # BatchNorm + ConvMLP\n",
    "\n",
    "    ``num_heads``, ``qkv_bias``, ``qk_scale``, ``attn_drop`` and ``norm_layer``\n",
    "    are accepted but not used by this block.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads=4,\n",
    "        mlp_ratio=4.,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop=0.,\n",
    "        attn_drop=0.,\n",
    "        drop_path=0.,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm\n",
    "    ):\n",
    "        super(ConvBlock, self).__init__()\n",
    "        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n",
    "        self.norm1 = nn.BatchNorm2d(dim)\n",
    "        self.conv1 = nn.Conv2d(dim, dim, 1)\n",
    "        self.conv2 = nn.Conv2d(dim, dim, 1)\n",
    "        self.attn = nn.Conv2d(dim, dim, 5, padding=2, groups=dim)\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n",
    "        self.norm2 = nn.BatchNorm2d(dim)\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = ConvMLP(\n",
    "            in_features=dim,\n",
    "            hidden_features=mlp_hidden_dim,\n",
    "            act_layer=act_layer,\n",
    "            drop=drop\n",
    "        )\n",
    "\n",
    "        self.apply(self._init_weights)\n",
    "\n",
    "    def _init_weights(self, m):\n",
    "        # Norms: weight 1, bias 0.\n",
    "        # Convs: normal(0, sqrt(2 / fan_out)) with fan_out = kh * kw * out_channels / groups; zero bias.\n",
    "        if isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):\n",
    "            nn.init.constant_(m.bias, 0)\n",
    "            nn.init.constant_(m.weight, 1.0)\n",
    "        elif isinstance(m, nn.Conv2d):\n",
    "            fan_out = (\n",
    "                m.kernel_size[0] * m.kernel_size[1] * m.out_channels\n",
    "            )\n",
    "            fan_out //= m.groups\n",
    "            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))\n",
    "            if m.bias is not None:\n",
    "                m.bias.data.zero_()\n",
    "\n",
    "    @torch.jit.ignore\n",
    "    def no_weight_decay(self):\n",
    "        return {}\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pos_embed(x)\n",
    "        x = x + self.drop_path(\n",
    "            self.conv2(self.attn(self.conv1(self.norm1(x))))\n",
    "        )\n",
    "        x = x + self.drop_path(self.mlp(self.norm2(x)))\n",
    "        return x\n",
    "\n",
    "\n",
    "class SelfAttentionBlock(nn.Module):\n",
    "    \"\"\"Pre-norm transformer block applied to a ``(B, C, H, W)`` feature map.\n",
    "\n",
    "    After a depthwise 3x3 positional-conv residual, the map is flattened to\n",
    "    ``H * W`` tokens of width ``C``. The attention and MLP residual branches are\n",
    "    scaled per channel by the learnable ``gamma_1`` / ``gamma_2`` (initialised to\n",
    "    ``init_value``) and passed through stochastic depth. The tokens are then\n",
    "    reshaped back to ``(B, C, H, W)``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        dim,\n",
    "        num_heads,\n",
    "        mlp_ratio=4.,\n",
    "        qkv_bias=False,\n",
    "        qk_scale=None,\n",
    "        drop=0.,\n",
    "        attn_drop=0.,\n",
    "        drop_path=0.,\n",
    "        init_value=1e-6,\n",
    "        act_layer=nn.GELU,\n",
    "        norm_layer=nn.LayerNorm\n",
    "    ):\n",
    "        super(SelfAttentionBlock, self).__init__()\n",
    "        self.pos_embed = nn.Conv2d(dim, dim, 3, padding=1, groups=dim)\n",
    "        self.norm1 = norm_layer(dim)\n",
    "        self.attn = Attention(\n",
    "            dim,\n",
    "            num_heads=num_heads,\n",
    "            qkv_bias=qkv_bias,\n",
    "            qk_scale=qk_scale,\n",
    "            attn_drop=attn_drop,\n",
    "            proj_drop=drop\n",
    "        )\n",
    "        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()\n",
    "        self.norm2 = norm_layer(dim)\n",
    "        mlp_hidden_dim = int(dim * mlp_ratio)\n",
    "        self.mlp = MLP(\n",
    "            in_features=dim,\n",
    "            hidden_features=mlp_hidden_dim,\n",
    "            act_layer=act_layer,\n",
    "            drop=drop\n",
    "        )\n",
    "        self.gamma_1 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)\n",
    "        self.gamma_2 = nn.Parameter(init_value * torch.ones((dim)), requires_grad=True)\n",
    "\n",
    "        self.apply(self._init_weights)\n",
    "\n",
    "    def _init_weights(self, m):\n",
    "        # Linear: truncated normal (std 0.02), zero bias. Norms: weight 1, bias 0.\n",
    "        # Other modules (e.g. pos_embed) keep PyTorch's default initialisation.\n",
    "        if isinstance(m, nn.Linear):\n",
    "            trunc_normal_(m.weight, std=.02)\n",
    "        if isinstance(m, nn.Linear) and m.bias is not None:\n",
    "            nn.init.constant_(m.bias, 0)\n",
    "        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):\n",
    "            nn.init.constant_(m.bias, 0)\n",
    "            nn.init.constant_(m.weight, 1.0)\n",
    "\n",
    "    @torch.jit.ignore\n",
    "    def no_weight_decay(self):\n",
    "        return {'gamma_1', 'gamma_2'}\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pos_embed(x)\n",
    "        B, C, H, W = x.shape\n",
    "        x = x.flatten(2).transpose(1, 2)  # (B, H*W, C)\n",
    "        x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))\n",
    "        x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))\n",
    "        x = x.transpose(1, 2).reshape(B, C, H, W)\n",
    "        return x\n",
    "\n",
    "\n",
    "def UniformerSubBlock(\n",
    "    embed_dims,\n",
    "    mlp_ratio=4.,\n",
    "    drop=0.,\n",
    "    drop_path=0.,\n",
    "    init_value=1e-6,\n",
    "    block_type='Conv'\n",
    "):\n",
    "    \"\"\"Build one evolution sub-block.\n",
    "\n",
    "    ``block_type='Conv'`` returns a ``ConvBlock``. ``'MHSA'`` returns an 8-head\n",
    "    ``SelfAttentionBlock`` with q/k/v bias, so ``embed_dims`` must then be\n",
    "    divisible by 8.\n",
    "    \"\"\"\n",
    "    assert block_type in ['Conv', 'MHSA']\n",
    "    if block_type == 'Conv':\n",
    "        return ConvBlock(dim=embed_dims, mlp_ratio=mlp_ratio, drop=drop, drop_path=drop_path)\n",
    "    else:\n",
    "        return SelfAttentionBlock(\n",
    "            dim=embed_dims,\n",
    "            num_heads=8,\n",
    "            mlp_ratio=mlp_ratio,\n",
    "            qkv_bias=True,\n",
    "            drop=drop,\n",
    "            drop_path=drop_path,\n",
    "            init_value=init_value\n",
    "        )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spatio-temporal evolution\n",
    "\n",
    "The `T` latent frames are stacked along the channel axis and evolved jointly by a stack of sub-blocks."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class SpatioTemporalEvolutionBlock(nn.Module):\n",
    "    \"\"\"One evolution layer: a sub-block, then an optional 1x1 channel projection.\n",
    "\n",
    "    The sub-block always runs at ``in_channels``. It is an ``MHSA`` block when\n",
    "    ``in_channels == out_channels`` and ``layer_i > 0``, and a ``Conv`` block\n",
    "    otherwise. When the widths differ, a 1x1 conv maps to ``out_channels``.\n",
    "    ``input_resolution`` is accepted but not used.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        input_resolution=None,\n",
    "        mlp_ratio=8.,\n",
    "        drop=0.0,\n",
    "        drop_path=0.0,\n",
    "        layer_i=0\n",
    "    ):\n",
    "        super(SpatioTemporalEvolutionBlock, self).__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        block_type = 'MHSA' if in_channels == out_channels and layer_i > 0 else 'Conv'\n",
    "        self.block = UniformerSubBlock(\n",
    "            in_channels,\n",
    "            mlp_ratio=mlp_ratio,\n",
    "            drop=drop,\n",
    "            drop_path=drop_path,\n",
    "            block_type=block_type\n",
    "        )\n",
    "\n",
    "        if in_channels != out_channels:\n",
    "            self.reduction = nn.Conv2d(\n",
    "                in_channels,\n",
    "                out_channels,\n",
    "                kernel_size=1,\n",
    "                stride=1,\n",
    "                padding=0\n",
    "            )\n",
    "\n",
    "    def forward(self, x):\n",
    "        z = self.block(x)\n",
    "        if self.in_channels != self.out_channels:\n",
    "            z = self.reduction(z)\n",
    "        return z\n",
    "\n",
    "\n",
    "class SpatioTemporalEvolution(nn.Module):\n",
    "    \"\"\"Stack of ``N2`` evolution layers applied to time-stacked latent frames.\n",
    "\n",
    "    Layout: ``channel_in -> channel_hid``, then ``N2 - 2`` layers at\n",
    "    ``channel_hid``, then ``channel_hid -> channel_in``. See\n",
    "    ``SpatioTemporalEvolutionBlock`` for how each layer picks Conv or MHSA.\n",
    "    Stochastic-depth rates are linearly spaced from 0.01 to ``drop_path``.\n",
    "\n",
    "    Args:\n",
    "        channel_in: ``T * C`` of the incoming ``(B, T, C, H, W)`` tensor.\n",
    "        channel_hid: hidden width; must be divisible by 8 when ``N2 > 2``.\n",
    "        N2: number of layers (>= 2).\n",
    "        input_resolution: forwarded to the blocks (unused there).\n",
    "        mlp_ratio: MLP expansion ratio (> 1).\n",
    "        drop: dropout probability.\n",
    "        drop_path: stochastic-depth rate of the last layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        channel_in,\n",
    "        channel_hid,\n",
    "        N2,\n",
    "        input_resolution=None,\n",
    "        mlp_ratio=4.,\n",
    "        drop=0.0,\n",
    "        drop_path=0.1\n",
    "    ):\n",
    "        super(SpatioTemporalEvolution, self).__init__()\n",
    "        assert N2 >= 2 and mlp_ratio > 1\n",
    "        self.N2 = N2\n",
    "        dpr = [x.item() for x in torch.linspace(1e-2, drop_path, self.N2)]\n",
    "\n",
    "        evolution_layers = [SpatioTemporalEvolutionBlock(\n",
    "            channel_in,\n",
    "            channel_hid,\n",
    "            input_resolution,\n",
    "            mlp_ratio=mlp_ratio,\n",
    "            drop=drop,\n",
    "            drop_path=dpr[0],\n",
    "            layer_i=0\n",
    "        )]\n",
    "\n",
    "        for i in range(1, N2 - 1):\n",
    "            evolution_layers.append(SpatioTemporalEvolutionBlock(\n",
    "                channel_hid,\n",
    "                channel_hid,\n",
    "                input_resolution,\n",
    "                mlp_ratio=mlp_ratio,\n",
    "                drop=drop,\n",
    "                drop_path=dpr[i],\n",
    "                layer_i=i\n",
    "            ))\n",
    "\n",
    "        evolution_layers.append(SpatioTemporalEvolutionBlock(\n",
    "            channel_hid,\n",
    "            channel_in,\n",
    "            input_resolution,\n",
    "            mlp_ratio=mlp_ratio,\n",
    "            drop=drop,\n",
    "            drop_path=drop_path,\n",
    "            layer_i=N2 - 1\n",
    "        ))\n",
    "        self.enc = nn.Sequential(*evolution_layers)\n",
    "\n",
    "    def forward(self, x):\n",
    "        \"\"\"Evolve ``x`` of shape ``(B, T, C, H, W)`` (``T * C == channel_in``); same output shape.\"\"\"\n",
    "        B, T, C, H, W = x.shape\n",
    "        z = x.reshape(B, T * C, H, W)\n",
    "        for layer in self.enc:\n",
    "            z = layer(z)\n",
    "        return z.reshape(B, T, C, H, W)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Spatial encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class BasicConv2d(nn.Module):\n",
    "    \"\"\"``Conv2d`` (or ``ConvTranspose2d``), optionally followed by GroupNorm(2) and LeakyReLU(0.2).\n",
    "\n",
    "    With ``transpose=True`` the layer uses ``output_padding=stride // 2``. For\n",
    "    kernel 3, padding 1 and stride 2 this upsamples ``H`` and ``W`` by exactly 2.\n",
    "    The norm layer is always constructed, so ``out_channels`` must be divisible\n",
    "    by 2 regardless of ``act_norm``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        kernel_size,\n",
    "        stride,\n",
    "        padding,\n",
    "        transpose=False,\n",
    "        act_norm=False\n",
    "    ):\n",
    "        super(BasicConv2d, self).__init__()\n",
    "        self.act_norm = act_norm\n",
    "        if not transpose:\n",
    "            self.conv = nn.Conv2d(\n",
    "                in_channels,\n",
    "                out_channels,\n",
    "                kernel_size=kernel_size,\n",
    "                stride=stride,\n",
    "                padding=padding\n",
    "            )\n",
    "        else:\n",
    "            self.conv = nn.ConvTranspose2d(\n",
    "                in_channels,\n",
    "                out_channels,\n",
    "                kernel_size=kernel_size,\n",
    "                stride=stride,\n",
    "                padding=padding,\n",
    "                output_padding=stride // 2\n",
    "            )\n",
    "        self.norm = nn.GroupNorm(2, out_channels)\n",
    "        self.act = nn.LeakyReLU(0.2, inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.conv(x)\n",
    "        if self.act_norm:\n",
    "            y = self.act(self.norm(y))\n",
    "        return y\n",
    "\n",
    "\n",
    "class ConvDynamicsLayer(nn.Module):\n",
    "    \"\"\"3x3 ``BasicConv2d`` with padding 1 (GroupNorm + LeakyReLU by default).\n",
    "\n",
    "    ``transpose`` is ignored when ``stride == 1``, because a regular conv already\n",
    "    preserves the spatial size.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, C_in, C_out, stride, transpose=False, act_norm=True):\n",
    "        super(ConvDynamicsLayer, self).__init__()\n",
    "        if stride == 1:\n",
    "            transpose = False\n",
    "        self.conv = BasicConv2d(\n",
    "            C_in,\n",
    "            C_out,\n",
    "            kernel_size=3,\n",
    "            stride=stride,\n",
    "            padding=1,\n",
    "            transpose=transpose,\n",
    "            act_norm=act_norm\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.conv(x)\n",
    "        return y\n",
    "\n",
    "\n",
    "class MultiGroupConv2d(nn.Module):\n",
    "    \"\"\"Grouped ``Conv2d``, optionally followed by GroupNorm(groups) and LeakyReLU(0.2).\n",
    "\n",
    "    Falls back to ``groups=1`` when either channel count is not divisible by\n",
    "    ``groups``, since ``nn.Conv2d`` and ``nn.GroupNorm`` reject such configurations.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        in_channels,\n",
    "        out_channels,\n",
    "        kernel_size,\n",
    "        stride,\n",
    "        padding,\n",
    "        groups,\n",
    "        act_norm=False\n",
    "    ):\n",
    "        super(MultiGroupConv2d, self).__init__()\n",
    "        self.act_norm = act_norm\n",
    "        # Both channel counts must be divisible by groups (checking only\n",
    "        # in_channels still crashed for e.g. in=8, out=6, groups=4).\n",
    "        if in_channels % groups != 0 or out_channels % groups != 0:\n",
    "            groups = 1\n",
    "        self.conv = nn.Conv2d(\n",
    "            in_channels,\n",
    "            out_channels,\n",
    "            kernel_size=kernel_size,\n",
    "            stride=stride,\n",
    "            padding=padding,\n",
    "            groups=groups\n",
    "        )\n",
    "        self.norm = nn.GroupNorm(groups, out_channels)\n",
    "        self.activate = nn.LeakyReLU(0.2, inplace=True)\n",
    "\n",
    "    def forward(self, x):\n",
    "        y = self.conv(x)\n",
    "        if self.act_norm:\n",
    "            y = self.activate(self.norm(y))\n",
    "        return y\n",
    "\n",
    "\n",
    "class AtmosphericEncoder(nn.Module):\n",
    "    \"\"\"Stack of ``num_spatial_layers`` 3x3 conv layers with strides from ``stride_generator``.\n",
    "\n",
    "    ``forward`` returns ``(latent, enc1)``. ``latent`` is the final feature map.\n",
    "    ``enc1`` is the output of the first layer (stride 1, so full input\n",
    "    resolution), which the decoder uses as its skip connection.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, C_in, spatial_hidden_dim, num_spatial_layers):\n",
    "        super(AtmosphericEncoder, self).__init__()\n",
    "        strides = stride_generator(num_spatial_layers)\n",
    "        self.enc = nn.Sequential(\n",
    "            ConvDynamicsLayer(C_in, spatial_hidden_dim, stride=strides[0]),\n",
    "            *[ConvDynamicsLayer(spatial_hidden_dim, spatial_hidden_dim, stride=s) for s in strides[1:]]\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        enc1 = self.enc[0](x)\n",
    "        latent = enc1\n",
    "        for layer in self.enc[1:]:\n",
    "            latent = layer(latent)\n",
    "        return latent, enc1\n",
    "\n",
    "\n",
    "class AtmosphericDecoder(nn.Module):\n",
    "    \"\"\"Mirror of ``AtmosphericEncoder``; its stride-2 layers are transposed convs.\n",
    "\n",
    "    The last layer consumes the concatenation of the upsampled features and the\n",
    "    encoder skip ``enc1`` (hence ``2 * spatial_hidden_dim`` input channels). A\n",
    "    1x1 ``readout`` then maps to ``C_out`` channels.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, spatial_hidden_dim, C_out, num_spatial_layers):\n",
    "        super(AtmosphericDecoder, self).__init__()\n",
    "        strides = stride_generator(num_spatial_layers, reverse=True)\n",
    "        self.dec = nn.Sequential(\n",
    "            *[ConvDynamicsLayer(spatial_hidden_dim, spatial_hidden_dim, stride=s, transpose=True) for s in strides[:-1]],\n",
    "            ConvDynamicsLayer(2 * spatial_hidden_dim, spatial_hidden_dim, stride=strides[-1], transpose=True)\n",
    "        )\n",
    "        self.readout = nn.Conv2d(spatial_hidden_dim, C_out, 1)\n",
    "\n",
    "    def forward(self, hid, enc1=None):\n",
    "        \"\"\"Decode ``hid`` using the encoder skip connection ``enc1``.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``enc1`` is None. The final layer is built for\n",
    "                ``2 * spatial_hidden_dim`` input channels, so the skip is mandatory.\n",
    "        \"\"\"\n",
    "        if enc1 is None:\n",
    "            raise ValueError('AtmosphericDecoder.forward requires the encoder skip connection enc1')\n",
    "        for layer in self.dec[:-1]:\n",
    "            hid = layer(hid)\n",
    "        Y = self.dec[-1](torch.cat([hid, enc1], dim=1))\n",
    "        Y = self.readout(Y)\n",
    "        return Y"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Triton model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Triton(nn.Module):\n",
    "    \"\"\"Encoder -> spatio-temporal evolution -> decoder model.\n",
    "\n",
    "    An input ``(B, T, C, H, W)`` is encoded frame by frame. The ``T`` latent\n",
    "    frames are stacked along channels and evolved jointly. Each frame is then\n",
    "    decoded back to ``(H, W)`` with a skip connection from the first encoder layer.\n",
    "\n",
    "    Args:\n",
    "        shape_in: ``(T, C, H, W)`` of one input sample.\n",
    "        spatial_hidden_dim: channel width of the encoder and decoder.\n",
    "        output_channels: channels of the decoded output per frame.\n",
    "        temporal_hidden_dim: channel width inside the evolution stack; must be\n",
    "            divisible by 8 (the MHSA head count) when ``num_temporal_layers > 2``.\n",
    "        num_spatial_layers: encoder/decoder depth. Every second layer downsamples\n",
    "            by 2, so ``H`` and ``W`` must be divisible by ``2 ** (num_spatial_layers // 2)``.\n",
    "        num_temporal_layers: number of evolution layers (>= 2).\n",
    "        in_time_seq_length, out_time_seq_length: stored as attributes only; the\n",
    "            output has the same number of frames as the input.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(\n",
    "        self,\n",
    "        shape_in,\n",
    "        spatial_hidden_dim=64,\n",
    "        output_channels=4,\n",
    "        temporal_hidden_dim=128,\n",
    "        num_spatial_layers=4,\n",
    "        num_temporal_layers=8,\n",
    "        in_time_seq_length=10,\n",
    "        out_time_seq_length=10\n",
    "    ):\n",
    "        super(Triton, self).__init__()\n",
    "        T, C, H, W = shape_in\n",
    "        strides = stride_generator(num_spatial_layers)\n",
    "        # Overall encoder down-sampling factor: one factor of 2 per stride-2 layer.\n",
    "        self._downsample_factor = 2 ** strides.count(2)\n",
    "        # Latent resolution. Each 3x3 / padding-1 conv maps n -> ceil(n / stride),\n",
    "        # so e.g. H=720 -> 180 and H=721 -> 181 for num_spatial_layers=4.\n",
    "        h1, w1 = H, W\n",
    "        for s in strides:\n",
    "            h1 = (h1 + s - 1) // s\n",
    "            w1 = (w1 + s - 1) // s\n",
    "        self.H1 = h1\n",
    "        self.W1 = w1\n",
    "        self.output_dim = output_channels\n",
    "        # Stored for reference only: forward() returns as many frames as it receives.\n",
    "        self.input_time_seq_length = in_time_seq_length\n",
    "        self.output_time_seq_length = out_time_seq_length\n",
    "\n",
    "        self.atmospheric_encoder = AtmosphericEncoder(C, spatial_hidden_dim, num_spatial_layers)\n",
    "        self.temporal_evolution = SpatioTemporalEvolution(\n",
    "            T * spatial_hidden_dim,\n",
    "            temporal_hidden_dim,\n",
    "            num_temporal_layers,\n",
    "            input_resolution=[self.H1, self.W1],\n",
    "            mlp_ratio=4.0,\n",
    "            drop_path=0.1\n",
    "        )\n",
    "        self.atmospheric_decoder = AtmosphericDecoder(spatial_hidden_dim, self.output_dim, num_spatial_layers)\n",
    "\n",
    "    def forward(self, input_state):\n",
    "        \"\"\"Run the model on ``input_state`` of shape ``(B, T, C, H, W)``.\n",
    "\n",
    "        1. Fold time into the batch so the encoder sees ``(B*T, C, H, W)`` frames.\n",
    "        2. Encode each frame, keeping the first encoder layer's output as the skip connection.\n",
    "        3. Evolve the ``T`` latent frames jointly (stacked along channels).\n",
    "        4. Decode each evolved frame and restore the ``(B, T, C_out, H, W)`` layout.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``H`` or ``W`` is not divisible by the encoder's overall\n",
    "                down-sampling factor (the decoder could not restore the size).\n",
    "        \"\"\"\n",
    "        batch_size, temporal_length, channels, height, width = input_state.shape\n",
    "        factor = self._downsample_factor\n",
    "        if height % factor or width % factor:\n",
    "            raise ValueError(\n",
    "                f'input height/width ({height}, {width}) must be divisible by {factor} '\n",
    "                'so the decoder can restore the original resolution'\n",
    "            )\n",
    "        # reshape (not view) so non-contiguous inputs are accepted as well.\n",
    "        frames = input_state.reshape(batch_size * temporal_length, channels, height, width)\n",
    "\n",
    "        encoded, skip_connection = self.atmospheric_encoder(frames)\n",
    "        _, encoded_channels, encoded_height, encoded_width = encoded.shape\n",
    "        encoded = encoded.view(batch_size, temporal_length, encoded_channels, encoded_height, encoded_width)\n",
    "\n",
    "        evolved = self.temporal_evolution(encoded)\n",
    "        evolved = evolved.reshape(batch_size * temporal_length, encoded_channels, encoded_height, encoded_width)\n",
    "\n",
    "        decoded = self.atmospheric_decoder(evolved, skip_connection)\n",
    "        return decoded.view(batch_size, temporal_length, -1, height, width)\n",
    "\n",
    "\n",
    "def count_parameters(model):\n",
    "    \"\"\"Return the number of trainable (``requires_grad``) parameters of ``model``.\"\"\"\n",
    "    return sum(p.numel() for p in model.parameters() if p.requires_grad)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Shape smoke test\n",
    "\n",
    "Forward pass on random data using a small `64 x 128` grid. A full `720 x 1440` grid is impractical here: with `num_spatial_layers=4` its latent grid is `180 x 360 = 64,800` tokens, and each MHSA layer would materialise a `(1, 8, 64800, 64800)` float32 attention matrix (roughly 134 GB)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "if __name__ == '__main__':\n",
    "    torch.manual_seed(0)\n",
    "    # H and W must be divisible by 4 for num_spatial_layers=4.\n",
    "    T_IN, C_IN, H_IN, W_IN = 10, 7, 64, 128\n",
    "    inputs = torch.randn(1, T_IN, C_IN, H_IN, W_IN)\n",
    "    model = Triton(\n",
    "        shape_in=(T_IN, C_IN, H_IN, W_IN),\n",
    "        spatial_hidden_dim=32,\n",
    "        output_channels=1,\n",
    "        temporal_hidden_dim=64,\n",
    "        num_spatial_layers=4,\n",
    "        num_temporal_layers=8)\n",
    "    model.eval()\n",
    "    print(f'trainable parameters: {count_parameters(model):,}')\n",
    "    with torch.no_grad():\n",
    "        output = model(inputs)\n",
    "    print(output.shape)  # torch.Size([1, 10, 1, 64, 128])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "envwu",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}