{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "617aaf35-0071-42dc-be81-cd286853029a", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([1, 1, 1, 128, 128])\n", "Output shape: torch.Size([1, 1, 1, 128, 128])\n" ] } ], "source": [ "import torch\n", "import torch.nn as nn\n", "from torch.nn import LeakyReLU as LReLu\n", "\n", "class CNOBlock(nn.Module):\n", " def __init__(self,\n", " in_channels,\n", " out_channels,\n", " in_size_h,\n", " in_size_w,\n", " out_size_h,\n", " out_size_w,\n", " cutoff_den = 2.0001,\n", " conv_kernel = 3,\n", " filter_size = 6,\n", " lrelu_upsampling = 2,\n", " half_width_mult = 0.8,\n", " radial = False,\n", " batch_norm = True,\n", " activation = 'cno_lrelu'\n", " ):\n", " super(CNOBlock, self).__init__()\n", " \n", " self.in_channels = in_channels\n", " self.out_channels = out_channels\n", " self.in_size_h = in_size_h\n", " self.in_size_w = in_size_w\n", " self.out_size_h = out_size_h\n", " self.out_size_w = out_size_w\n", " self.conv_kernel = conv_kernel\n", " self.batch_norm_flag = batch_norm\n", " \n", " #---------- Filter properties -----------\n", " self.critically_sampled = False # We use w_c = s/2.0001 --> NOT critically sampled\n", "\n", " if cutoff_den == 2.0:\n", " self.critically_sampled = True\n", " self.in_cutoff_h = self.in_size_h / cutoff_den\n", " self.in_cutoff_w = self.in_size_w / cutoff_den\n", " self.out_cutoff_h = self.out_size_h / cutoff_den\n", " self.out_cutoff_w = self.out_size_w / cutoff_den\n", "\n", " self.in_halfwidth_h = half_width_mult*self.in_size_h - self.in_size_h / cutoff_den\n", " self.in_halfwidth_w = half_width_mult*self.in_size_w - self.in_size_w / cutoff_den\n", " self.out_halfwidth_h = half_width_mult*self.out_size_h - self.out_size_h / cutoff_den\n", " self.out_halfwidth_w = half_width_mult*self.out_size_w - self.out_size_w / cutoff_den\n", "\n", "\n", "\n", " pad = (self.conv_kernel - 1) // 2\n", " self.convolution = torch.nn.Conv2d(in_channels=self.in_channels, out_channels=self.out_channels,\n", " kernel_size=self.conv_kernel,\n", " padding=pad)\n", " \n", " if self.batch_norm_flag:\n", " self.batch_norm = nn.BatchNorm2d(self.out_channels)\n", " else:\n", " self.batch_norm = None\n", " self.activation = LReLu() #\n", " \n", " def forward(self, x):\n", " x = self.convolution(x)\n", " if self.batch_norm_flag:\n", " x = self.batch_norm(x)\n", " x = self.activation(x)\n", " return x\n", "\n", "class LiftProjectBlock(nn.Module):\n", " def __init__(self,\n", " in_channels,\n", " out_channels,\n", " in_size_h,\n", " in_size_w,\n", " out_size_h,\n", " out_size_w,\n", " latent_dim = 64,\n", " cutoff_den = 2.0001,\n", " conv_kernel = 3,\n", " filter_size = 6,\n", " lrelu_upsampling = 2,\n", " half_width_mult = 0.8,\n", " radial = False,\n", " batch_norm = True,\n", " activation = 'cno_lrelu'\n", " ):\n", " super(LiftProjectBlock, self).__init__()\n", " \n", " self.inter_CNOBlock = CNOBlock(in_channels=in_channels,\n", " out_channels=latent_dim,\n", " in_size_h=in_size_h,\n", " in_size_w=in_size_w,\n", " out_size_h=out_size_h,\n", " out_size_w=out_size_w,\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation)\n", " \n", " pad = (conv_kernel - 1) // 2\n", " self.convolution = torch.nn.Conv2d(in_channels=latent_dim, out_channels=out_channels,\n", " kernel_size=conv_kernel, 
stride=1,\n", " padding=pad)\n", " \n", " self.batch_norm_flag = batch_norm\n", " if self.batch_norm_flag:\n", " self.batch_norm = nn.BatchNorm2d(out_channels)\n", " else:\n", " self.batch_norm = None\n", " \n", " def forward(self, x):\n", " x = self.inter_CNOBlock(x)\n", " \n", " x = self.convolution(x)\n", " if self.batch_norm_flag:\n", " x = self.batch_norm(x)\n", " return x\n", "\n", "class ResidualBlock(nn.Module):\n", " def __init__(self,\n", " channels,\n", " size_h,\n", " size_w,\n", " cutoff_den = 2.0001,\n", " conv_kernel = 3,\n", " filter_size = 6,\n", " lrelu_upsampling = 2,\n", " half_width_mult = 0.8,\n", " radial = False,\n", " batch_norm = True,\n", " activation = 'cno_lrelu'\n", " ):\n", " super(ResidualBlock, self).__init__()\n", "\n", " self.channels = channels\n", " self.size_h = size_h\n", " self.size_w = size_w\n", " self.conv_kernel = conv_kernel\n", " self.batch_norm_flag = batch_norm\n", "\n", " #---------- Filter properties -----------\n", " self.critically_sampled = False # We use w_c = s/2.0001 --> NOT critically sampled\n", "\n", " if cutoff_den == 2.0:\n", " self.critically_sampled = True\n", " self.cutoff_h = self.size_h / cutoff_den \n", " self.cutoff_w = self.size_w / cutoff_den \n", " self.halfwidth_h = half_width_mult*self.size_h - self.size_h / cutoff_den\n", " self.halfwidth_w = half_width_mult*self.size_w - self.size_w / cutoff_den\n", "\n", " #-----------------------------------------\n", " \n", " pad = (self.conv_kernel - 1) // 2\n", " self.convolution1 = torch.nn.Conv2d(in_channels=self.channels, out_channels=self.channels,\n", " kernel_size=self.conv_kernel, stride=1,\n", " padding=pad)\n", " self.convolution2 = torch.nn.Conv2d(in_channels=self.channels, out_channels=self.channels,\n", " kernel_size=self.conv_kernel, stride=1,\n", " padding=pad)\n", " \n", " if self.batch_norm_flag:\n", " self.batch_norm1 = nn.BatchNorm2d(self.channels)\n", " self.batch_norm2 = nn.BatchNorm2d(self.channels)\n", " else:\n", " self.batch_norm1 = self.batch_norm2 = None\n", " self.activation = LReLu()\n", "\n", " def forward(self, x):\n", " out = self.convolution1(x)\n", " if self.batch_norm_flag:\n", " out = self.batch_norm1(out)\n", " out = self.activation(out)\n", " out = self.convolution2(out)\n", " if self.batch_norm_flag:\n", " out = self.batch_norm2(out)\n", " \n", " return x + out\n", "\n", "class CNO(nn.Module):\n", " def __init__(self, \n", " in_dim, \n", " in_size_h, \n", " in_size_w, \n", " N_layers, \n", " N_res = 1, \n", " N_res_neck = 6, \n", " channel_multiplier = 32, \n", " conv_kernel=3, \n", " cutoff_den = 2.0001, \n", " filter_size=6, \n", " lrelu_upsampling = 2, \n", " half_width_mult = 0.8, \n", " radial = False, \n", " batch_norm = True, \n", " out_dim = 10, \n", " out_size_h = 1, \n", " out_size_w = 1, \n", " expand_input = False, \n", " latent_lift_proj_dim = 64, \n", " add_inv = True, \n", " activation = 'cno_lrelu' \n", " ):\n", " \n", " super(CNO, self).__init__()\n", "\n", "\n", " self.N_layers = int(N_layers)\n", " \n", " self.lift_dim = channel_multiplier // 2 \n", " self.out_dim = out_dim\n", " \n", " self.add_inv = add_inv\n", " \n", " self.channel_multiplier = channel_multiplier \n", " \n", " if radial == 0:\n", " self.radial = False\n", " else:\n", " self.radial = True\n", " \n", "\n", " self.encoder_features = [self.lift_dim]\n", " for i in range(self.N_layers):\n", " self.encoder_features.append(2 ** i * self.channel_multiplier)\n", " \n", " self.decoder_features_in = self.encoder_features[1:]\n", " 
self.decoder_features_in.reverse()\n", " self.decoder_features_out = self.encoder_features[:-1]\n", " self.decoder_features_out.reverse()\n", "\n", " for i in range(1, self.N_layers):\n", " self.decoder_features_in[i] = 2 * self.decoder_features_in[i] \n", "\n", " self.inv_features = self.decoder_features_in.copy()\n", " self.inv_features.append(self.encoder_features[0] + self.decoder_features_out[-1]) \n", "\n", " \n", " if not expand_input:\n", " latent_size_h = in_size_h \n", " latent_size_w = in_size_w \n", " else:\n", " down_exponent = 2 ** N_layers\n", " latent_size_h = in_size_h - (in_size_h % down_exponent) + down_exponent\n", " latent_size_w = in_size_w - (in_size_w % down_exponent) + down_exponent\n", " \n", " if out_size_h == 1:\n", " latent_size_out_h = latent_size_h\n", " else:\n", " if not expand_input:\n", " latent_size_out_h = out_size_h \n", " else:\n", " down_exponent = 2 ** N_layers\n", " latent_size_out_h = out_size_h - (out_size_h % down_exponent) + down_exponent\n", "\n", " if out_size_w == 1:\n", " latent_size_out_w = latent_size_w\n", " else:\n", " if not expand_input:\n", " latent_size_out_w = out_size_w \n", " else:\n", " down_exponent = 2 ** N_layers\n", " latent_size_out_w = out_size_w - (out_size_w % down_exponent) + down_exponent\n", " \n", " self.encoder_sizes_h = []\n", " self.encoder_sizes_w = []\n", " self.decoder_sizes_h = []\n", " self.decoder_sizes_w = []\n", " for i in range(self.N_layers + 1):\n", " self.encoder_sizes_h.append(latent_size_h // (2 ** i))\n", " self.encoder_sizes_w.append(latent_size_w // (2 ** i))\n", " self.decoder_sizes_h.append(latent_size_out_h // 2 ** (self.N_layers - i))\n", " self.decoder_sizes_w.append(latent_size_out_w // 2 ** (self.N_layers - i))\n", " \n", " \n", " self.lift = LiftProjectBlock(in_channels=in_dim,\n", " out_channels=self.encoder_features[0],\n", " in_size_h=in_size_h,\n", " in_size_w=in_size_w,\n", " out_size_h=self.encoder_sizes_h[0],\n", " out_size_w=self.encoder_sizes_w[0],\n", " latent_dim=latent_lift_proj_dim,\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=False,\n", " activation=activation)\n", " _out_size_h = out_size_h\n", " _out_size_w = out_size_w\n", " if out_size_h == 1:\n", " _out_size_h = in_size_h\n", " if out_size_w == 1:\n", " _out_size_w = in_size_w\n", " \n", " self.project = LiftProjectBlock(in_channels=self.encoder_features[0] + self.decoder_features_out[-1],\n", " out_channels=out_dim,\n", " in_size_h=self.decoder_sizes_h[-1],\n", " in_size_w=self.decoder_sizes_w[-1],\n", " out_size_h=_out_size_h,\n", " out_size_w=_out_size_w,\n", " latent_dim=latent_lift_proj_dim,\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=False,\n", " activation=activation)\n", "\n", "\n", " self.encoder = nn.ModuleList([\n", " CNOBlock(\n", " in_channels=self.encoder_features[i],\n", " out_channels=self.encoder_features[i + 1],\n", " in_size_h=self.encoder_sizes_h[i],\n", " in_size_w=self.encoder_sizes_w[i],\n", " out_size_h=self.encoder_sizes_h[i + 1],\n", " out_size_w=self.encoder_sizes_w[i + 1],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " 
batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " for i in range(self.N_layers)\n", " ])\n", " \n", "\n", " self.ED_expansion = nn.ModuleList([\n", " CNOBlock(\n", " in_channels=self.encoder_features[i],\n", " out_channels=self.encoder_features[i],\n", " in_size_h=self.encoder_sizes_h[i],\n", " in_size_w=self.encoder_sizes_w[i],\n", " out_size_h=self.decoder_sizes_h[self.N_layers - i],\n", " out_size_w=self.decoder_sizes_w[self.N_layers - i],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " for i in range(self.N_layers + 1)\n", " ])\n", " \n", " self.decoder = nn.ModuleList([\n", " CNOBlock(\n", " in_channels=self.decoder_features_in[i],\n", " out_channels=self.decoder_features_out[i],\n", " in_size_h=self.decoder_sizes_h[i],\n", " in_size_w=self.decoder_sizes_w[i],\n", " out_size_h=self.decoder_sizes_h[i + 1],\n", " out_size_w=self.decoder_sizes_w[i + 1],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " for i in range(self.N_layers)\n", " ])\n", " \n", " self.decoder_inv = nn.ModuleList([\n", " CNOBlock(\n", " in_channels=self.inv_features[i],\n", " out_channels=self.inv_features[i],\n", " in_size_h=self.decoder_sizes_h[i],\n", " in_size_w=self.decoder_sizes_w[i],\n", " out_size_h=self.decoder_sizes_h[i],\n", " out_size_w=self.decoder_sizes_w[i],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " for i in range(self.N_layers + 1)\n", " ])\n", " \n", "\n", " self.res_nets = []\n", " self.N_res = int(N_res)\n", " self.N_res_neck = int(N_res_neck)\n", "\n", " for l in range(self.N_layers):\n", " for i in range(self.N_res):\n", " self.res_nets.append(\n", " ResidualBlock(\n", " channels=self.encoder_features[l],\n", " size_h=self.encoder_sizes_h[l],\n", " size_w=self.encoder_sizes_w[l],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " )\n", " for i in range(self.N_res_neck):\n", " self.res_nets.append(\n", " ResidualBlock(\n", " channels=self.encoder_features[self.N_layers],\n", " size_h=self.encoder_sizes_h[self.N_layers],\n", " size_w=self.encoder_sizes_w[self.N_layers],\n", " cutoff_den=cutoff_den,\n", " conv_kernel=conv_kernel,\n", " filter_size=filter_size,\n", " lrelu_upsampling=lrelu_upsampling,\n", " half_width_mult=half_width_mult,\n", " radial=radial,\n", " batch_norm=batch_norm,\n", " activation=activation\n", " )\n", " )\n", " \n", " self.res_nets = torch.nn.Sequential(*self.res_nets)\n", "\n", " def forward(self, x):\n", " b, t, c, h, w = x.shape\n", " x = x.reshape(b * t, c, h, w)\n", " x = self.lift(x)\n", " skip = []\n", " \n", " res_nets_idx = 0 \n", " for i in range(self.N_layers):\n", " \n", " y = x\n", " for j in range(self.N_res):\n", " y = self.res_nets[res_nets_idx](y)\n", " res_nets_idx += 1\n", " skip.append(y)\n", 
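 "                # keep the residual-refined features at this level; the decoder concatenates them back via torch.cat (U-Net style skip connection)\n",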
" \n", " x = self.encoder[i](x) \n", " \n", " #----------------------------------------------------------------------\n", " \n", " for j in range(self.N_res_neck):\n", " x = self.res_nets[res_nets_idx](x)\n", " res_nets_idx += 1\n", "\n", " for i in range(self.N_layers):\n", " \n", " if i == 0:\n", " x = self.ED_expansion[self.N_layers - i](x) \n", " else:\n", " x = torch.cat((x, self.ED_expansion[self.N_layers - i](skip[-i])), 1)\n", " \n", " if self.add_inv:\n", " x = self.decoder_inv[i](x)\n", " x = self.decoder[i](x)\n", " \n", " x = torch.cat((x, self.ED_expansion[0](skip[0])), 1)\n", " x = self.project(x)\n", " x = x.reshape(b, t, -1, x.shape[-2], x.shape[-1])\n", " \n", " del skip\n", " del y\n", " \n", " return x\n", "\n", " def get_n_params(self):\n", " pp = 0\n", " \n", " for p in list(self.parameters()):\n", " nn = 1\n", " for s in list(p.size()):\n", " nn = nn * s\n", " pp += nn\n", " return pp\n", "\n", " def print_size(self):\n", " nparams = 0\n", " nbytes = 0\n", "\n", " for param in self.parameters():\n", " nparams += param.numel()\n", " nbytes += param.data.element_size() * param.numel()\n", "\n", " print(f'{nparams} (~{nbytes / 1e6:.2f} MB)')\n", "\n", " return nparams\n", "\n", "if __name__ == '__main__':\n", "\n", " in_dim = 1\n", " in_size_h = 128 \n", " in_size_w = 128 \n", " N_layers = 4\n", " \n", " model = CNO(\n", " in_dim=in_dim,\n", " in_size_h=in_size_h,\n", " in_size_w=in_size_w,\n", " N_layers=N_layers,\n", " N_res=1,\n", " N_res_neck=6,\n", " channel_multiplier=32,\n", " conv_kernel=3,\n", " cutoff_den=2.0001,\n", " filter_size=6,\n", " lrelu_upsampling=2,\n", " half_width_mult=0.8,\n", " radial=False,\n", " batch_norm=True,\n", " out_dim=1,\n", " out_size_h=1,\n", " out_size_w=1,\n", " expand_input=False,\n", " latent_lift_proj_dim=64,\n", " add_inv=True,\n", " activation='cno_lrelu'\n", " )\n", " \n", " batch_size = 1\n", " time_steps = 1\n", " channels = in_dim\n", " height = in_size_h\n", " width = in_size_w\n", " \n", " x = torch.randn(batch_size, time_steps, channels, height, width)\n", " print(x.shape)\n", " output = model(x)\n", " \n", " print(f\"Output shape: {output.shape}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "7dbd36a2-4572-499b-92c1-d3f1e83437d9", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "e24f9871-d0e5-4bd8-8c6e-728d3b2aa259", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "be9d2b68-efc1-4a82-a3e9-5604b4d9abff", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.16" } }, "nbformat": 4, "nbformat_minor": 5 }