{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "c18fdd15",
   "metadata": {},
   "source": [
    "# DriveTiTok\n",
    "\n",
    "Reconstructs the current frame of a frame pair with the 640-token DriveTiTok checkpoint (`titok_s640_jadd_prev_mask_query`, stage 2). The current frame is encoded into a 1D sequence of latent tokens, and the decoder rebuilds it conditioned on patch embeddings of the previous frame.\n",
    "\n",
    "Inputs: `drive_titok_s_640.bin` in the working directory and two sample frames under `assets/`. Outputs: `out/sample_recon.png` and a side-by-side plot of previous frame, current frame and reconstruction."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "99fd1987",
   "metadata": {},
   "source": [
    "## Setup\n",
    "\n",
    "Install the project environment, including all optional extras, with uv:\n",
    "\n",
    "```\n",
    "uv sync --all-extras\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "44bca1c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "from collections import OrderedDict\n",
    "from pathlib import Path\n",
    "from typing import Optional\n",
    "\n",
    "import einops\n",
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from PIL import Image\n",
    "from torchvision import transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "0b6f3e21",
   "metadata": {},
   "source": [
    "## Model definition\n",
    "\n",
    "A self-contained copy of the architecture, so the checkpoint can be loaded with `strict=True`. Several quirks (noted inline) are kept as-is to stay compatible with the pretrained weights."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1c7a4d32",
   "metadata": {},
   "source": [
    "### Pixel decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7e411458",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Conv2dSame(nn.Conv2d):\n",
    "    \"\"\"nn.Conv2d with TensorFlow-style 'SAME' padding, computed from the input size on every call.\"\"\"\n",
    "\n",
    "    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:\n",
    "        \"\"\"Total padding needed along one axis (input size i, kernel k, stride s, dilation d).\"\"\"\n",
    "        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)\n",
    "\n",
    "    def forward(self, x: torch.Tensor) -> torch.Tensor:\n",
    "        ih, iw = x.size()[-2:]\n",
    "        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])\n",
    "        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])\n",
    "        if pad_h > 0 or pad_w > 0:\n",
    "            # An odd total puts the extra pixel on the right/bottom side.\n",
    "            x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])\n",
    "        return super().forward(x)\n",
    "\n",
    "\n",
    "class ResnetBlock(nn.Module):\n",
    "    \"\"\"Residual block: (GroupNorm -> SiLU -> 3x3 conv) twice, plus a skip connection.\n",
    "\n",
    "    Args:\n",
    "        in_channels: channels of the input feature map.\n",
    "        out_channels: channels of the output; defaults to ``in_channels``.\n",
    "        dropout_prob: dropout applied before the second convolution.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_channels: int, out_channels: Optional[int] = None, dropout_prob: float = 0.0):\n",
    "        super().__init__()\n",
    "        self.in_channels = in_channels\n",
    "        self.out_channels = out_channels\n",
    "        self.out_channels_ = self.in_channels if self.out_channels is None else self.out_channels\n",
    "\n",
    "        self.norm1 = nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)\n",
    "        self.conv1 = Conv2dSame(self.in_channels, self.out_channels_, kernel_size=3, bias=False)\n",
    "\n",
    "        self.norm2 = nn.GroupNorm(num_groups=32, num_channels=self.out_channels_, eps=1e-6, affine=True)\n",
    "        self.dropout = nn.Dropout(dropout_prob)\n",
    "        self.conv2 = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=3, bias=False)\n",
    "\n",
    "        if self.in_channels != self.out_channels_:\n",
    "            # NOTE: this 1x1 shortcut maps out_channels_ -> out_channels_ and is applied to the\n",
    "            # block OUTPUT in forward, not to the input. Kept as-is: its weight shape must match\n",
    "            # the checkpoint, which is loaded with strict=True.\n",
    "            self.nin_shortcut = Conv2dSame(self.out_channels_, self.out_channels_, kernel_size=1, bias=False)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        residual = hidden_states\n",
    "        hidden_states = self.norm1(hidden_states)\n",
    "        hidden_states = F.silu(hidden_states)\n",
    "        hidden_states = self.conv1(hidden_states)\n",
    "\n",
    "        hidden_states = self.norm2(hidden_states)\n",
    "        hidden_states = F.silu(hidden_states)\n",
    "        hidden_states = self.dropout(hidden_states)\n",
    "        hidden_states = self.conv2(hidden_states)\n",
    "\n",
    "        if self.in_channels != self.out_channels_:\n",
    "            # The input cannot be added directly when channel counts differ; the skip term\n",
    "            # becomes a 1x1 projection of the block output (see the note in __init__).\n",
    "            residual = self.nin_shortcut(hidden_states)\n",
    "\n",
    "        return hidden_states + residual\n",
    "\n",
    "\n",
    "class UpsamplingBlock(nn.Module):\n",
    "    \"\"\"One pixel-decoder stage: ``num_res_blocks`` ResnetBlocks, then (unless block_idx == 0)\n",
    "    2x nearest-neighbour upsampling followed by a 3x3 conv.\n",
    "\n",
    "    The input width comes from the channel multiplier of stage ``block_idx + 1``, which runs\n",
    "    before this one in PixelDecoder.forward; the coarsest stage uses the last multiplier.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config, block_idx: int):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        self.block_idx = block_idx\n",
    "\n",
    "        if self.block_idx == self.config['num_resolutions'] - 1:\n",
    "            block_in = self.config['hidden_channels'] * self.config['channel_mult'][-1]\n",
    "        else:\n",
    "            block_in = self.config['hidden_channels'] * self.config['channel_mult'][self.block_idx + 1]\n",
    "\n",
    "        block_out = self.config['hidden_channels'] * self.config['channel_mult'][self.block_idx]\n",
    "\n",
    "        res_blocks = []\n",
    "        for _ in range(self.config['num_res_blocks']):\n",
    "            res_blocks.append(ResnetBlock(block_in, block_out, dropout_prob=self.config['dropout']))\n",
    "            block_in = block_out\n",
    "        self.block = nn.ModuleList(res_blocks)\n",
    "\n",
    "        self.add_upsample = self.block_idx != 0\n",
    "        if self.add_upsample:\n",
    "            self.upsample_conv = Conv2dSame(block_out, block_out, kernel_size=3)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        for res_block in self.block:\n",
    "            hidden_states = res_block(hidden_states)\n",
    "\n",
    "        if self.add_upsample:\n",
    "            hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode='nearest')\n",
    "            hidden_states = self.upsample_conv(hidden_states)\n",
    "\n",
    "        return hidden_states\n",
    "\n",
    "\n",
    "class PixelDecoder(nn.Module):\n",
    "    \"\"\"Convolutional decoder from a (B, z_channels, h, w) map to a (B, num_channels, H, W) image.\n",
    "\n",
    "    Every stage except the finest upsamples 2x, so H = h * 2 ** (num_resolutions - 1).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "\n",
    "        block_in = self.config['hidden_channels'] * self.config['channel_mult'][self.config['num_resolutions'] - 1]\n",
    "        self.conv_in = Conv2dSame(self.config['z_channels'], block_in, kernel_size=3)\n",
    "\n",
    "        res_blocks = nn.ModuleList()\n",
    "        for _ in range(self.config['num_res_blocks']):\n",
    "            res_blocks.append(ResnetBlock(block_in, block_in, dropout_prob=self.config['dropout']))\n",
    "        self.mid = res_blocks\n",
    "\n",
    "        upsample_blocks = []\n",
    "        for i_level in reversed(range(self.config['num_resolutions'])):\n",
    "            upsample_blocks.append(UpsamplingBlock(self.config, block_idx=i_level))\n",
    "        # self.up[i] is the stage with block_idx == i (the order matters for state_dict keys).\n",
    "        self.up = nn.ModuleList(list(reversed(upsample_blocks)))\n",
    "\n",
    "        block_out = self.config['hidden_channels'] * self.config['channel_mult'][0]\n",
    "        self.norm_out = nn.GroupNorm(num_groups=32, num_channels=block_out, eps=1e-6, affine=True)\n",
    "        self.conv_out = Conv2dSame(block_out, self.config['num_channels'], kernel_size=3)\n",
    "\n",
    "    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:\n",
    "        hidden_states = self.conv_in(hidden_states)\n",
    "        for block in self.mid:\n",
    "            hidden_states = block(hidden_states)\n",
    "        # Coarsest stage first: block_idx num_resolutions - 1 down to 0.\n",
    "        for block in reversed(self.up):\n",
    "            hidden_states = block(hidden_states)\n",
    "        hidden_states = self.norm_out(hidden_states)\n",
    "        hidden_states = F.silu(hidden_states)\n",
    "        hidden_states = self.conv_out(hidden_states)\n",
    "        return hidden_states"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2d8b5e43",
   "metadata": {},
   "source": [
    "### Quantizers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a3e9c154",
   "metadata": {},
   "outputs": [],
   "source": [
    "class PixelQuantizer(nn.Module):\n",
    "    \"\"\"Container for the pixel-token codebook.\n",
    "\n",
    "    Only ``embedding.weight`` is used (by TiTok.decode). The module has no forward, and\n",
    "    ``commitment_cost`` is stored but not used in this notebook.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_embeddings, embedding_dim, commitment_cost):\n",
    "        super().__init__()\n",
    "        self.num_embeddings = num_embeddings\n",
    "        self.embedding_dim = embedding_dim\n",
    "        self.commitment_cost = commitment_cost\n",
    "        self.embedding = nn.Embedding(num_embeddings, embedding_dim)\n",
    "        self.embedding.weight.data.uniform_(-1.0 / num_embeddings, 1.0 / num_embeddings)\n",
    "\n",
    "\n",
    "class VectorQuantizer(nn.Module):\n",
    "    \"\"\"Nearest-neighbour vector quantizer with a straight-through estimator.\n",
    "\n",
    "    With ``use_l2_norm=True``, inputs and codebook vectors are both L2-normalized, so the\n",
    "    nearest code is the one with the highest cosine similarity.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, codebook_size=1024, token_size=256, commitment_cost=0.25, use_l2_norm=False):\n",
    "        super().__init__()\n",
    "        self.commitment_cost = commitment_cost\n",
    "        self.embedding = nn.Embedding(codebook_size, token_size)\n",
    "        self.embedding.weight.data.uniform_(-1.0 / codebook_size, 1.0 / codebook_size)\n",
    "        self.use_l2_norm = use_l2_norm\n",
    "\n",
    "    # Keep the distance computation in fp32 under CUDA mixed precision. torch.autocast is the\n",
    "    # non-deprecated equivalent of the torch.cuda.amp.autocast decorator.\n",
    "    @torch.autocast(device_type='cuda', enabled=False)\n",
    "    def forward(self, z: torch.Tensor):\n",
    "        \"\"\"Quantize a (B, C, H, W) feature map, where C == token_size.\n",
    "\n",
    "        Returns:\n",
    "            z_quantized: (B, C, H, W) codebook vectors; gradients pass straight through to ``z``.\n",
    "            result_dict: ``quantizer_loss`` (commitment + codebook), ``commitment_loss``,\n",
    "                ``codebook_loss`` and ``min_encoding_indices`` of shape (B, H, W).\n",
    "        \"\"\"\n",
    "        z = z.float()\n",
    "        z = einops.rearrange(z, 'b c h w -> b h w c').contiguous()\n",
    "        z_flattened = einops.rearrange(z, 'b h w c -> (b h w) c')\n",
    "\n",
    "        if self.use_l2_norm:\n",
    "            z_flattened = F.normalize(z_flattened, dim=-1)\n",
    "            embedding = F.normalize(self.embedding.weight, dim=-1)\n",
    "        else:\n",
    "            embedding = self.embedding.weight\n",
    "\n",
    "        # Squared Euclidean distance ||z||^2 + ||e||^2 - 2 z.e for every (token, code) pair.\n",
    "        d = torch.sum(z_flattened**2, dim=1, keepdim=True) + torch.sum(embedding**2, dim=1) - 2 * torch.einsum('bd,dn->bn', z_flattened, embedding.T)\n",
    "        min_encoding_indices = torch.argmin(d, dim=1)\n",
    "        z_quantized = self.get_codebook_entry(min_encoding_indices).view(z.shape)\n",
    "\n",
    "        if self.use_l2_norm:\n",
    "            z = F.normalize(z, dim=-1)\n",
    "\n",
    "        commitment_loss = self.commitment_cost * torch.mean((z_quantized.detach() - z) ** 2)\n",
    "        codebook_loss = torch.mean((z_quantized - z.detach()) ** 2)\n",
    "        loss = commitment_loss + codebook_loss\n",
    "\n",
    "        # Straight-through estimator: the forward value is the code, the backward gradient flows to z.\n",
    "        z_quantized = z + (z_quantized - z).detach()\n",
    "        z_quantized = einops.rearrange(z_quantized, 'b h w c -> b c h w').contiguous()\n",
    "\n",
    "        result_dict = {\n",
    "            'quantizer_loss': loss,\n",
    "            'commitment_loss': commitment_loss,\n",
    "            'codebook_loss': codebook_loss,\n",
    "            'min_encoding_indices': min_encoding_indices.view(z_quantized.shape[0], z_quantized.shape[2], z_quantized.shape[3]),\n",
    "        }\n",
    "        return z_quantized, result_dict\n",
    "\n",
    "    def get_codebook_entry(self, indices):\n",
    "        \"\"\"Look up codebook vectors.\n",
    "\n",
    "        1-D ``indices`` are integer code ids. 2-D ``indices`` are treated as (N, codebook_size)\n",
    "        weights (e.g. one-hot rows) and multiplied with the codebook.\n",
    "\n",
    "        Raises:\n",
    "            NotImplementedError: for any other rank.\n",
    "        \"\"\"\n",
    "        if len(indices.shape) == 1:\n",
    "            z_quantized = self.embedding(indices)\n",
    "        elif len(indices.shape) == 2:\n",
    "            z_quantized = torch.einsum('bd,dn->bn', indices, self.embedding.weight)\n",
    "        else:\n",
    "            raise NotImplementedError\n",
    "        if self.use_l2_norm:\n",
    "            z_quantized = F.normalize(z_quantized, dim=-1)\n",
    "        return z_quantized"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3e9c6f54",
   "metadata": {},
   "source": [
    "### Transformer encoder and decoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4fa0d65",
   "metadata": {},
   "outputs": [],
   "source": [
    "class ResidualAttentionBlock(nn.Module):\n",
    "    \"\"\"Pre-LayerNorm transformer block: self-attention plus an optional MLP (mlp_ratio > 0).\n",
    "\n",
    "    nn.MultiheadAttention uses its default batch_first=False, so inputs are\n",
    "    (seq_len, batch, d_model); callers permute before and after the block stack.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model, n_head, mlp_ratio=4.0, act_layer=nn.GELU, norm_layer=nn.LayerNorm):\n",
    "        super().__init__()\n",
    "        self.ln_1 = norm_layer(d_model)\n",
    "        self.attn = nn.MultiheadAttention(d_model, n_head)\n",
    "        self.mlp_ratio = mlp_ratio\n",
    "        if mlp_ratio > 0:\n",
    "            self.ln_2 = norm_layer(d_model)\n",
    "            mlp_width = int(d_model * mlp_ratio)\n",
    "            self.mlp = nn.Sequential(OrderedDict([\n",
    "                ('c_fc', nn.Linear(d_model, mlp_width)),\n",
    "                ('gelu', act_layer()),\n",
    "                ('c_proj', nn.Linear(mlp_width, d_model)),\n",
    "            ]))\n",
    "\n",
    "    def attention(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):\n",
    "        return self.attn(x, x, x, attn_mask=attention_mask, need_weights=False)[0]\n",
    "\n",
    "    def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):\n",
    "        x = x + self.attention(x=self.ln_1(x), attention_mask=attention_mask)\n",
    "        if self.mlp_ratio > 0:\n",
    "            x = x + self.mlp(self.ln_2(x))\n",
    "        return x\n",
    "\n",
    "\n",
    "def _expand_token(token, batch_size: int):\n",
    "    \"\"\"Broadcast a (L, D) parameter to a (batch_size, L, D) view without copying.\"\"\"\n",
    "    return token.unsqueeze(0).expand(batch_size, -1, -1)\n",
    "\n",
    "\n",
    "class TiTokEncoder(nn.Module):\n",
    "    \"\"\"ViT encoder that appends learned latent tokens to the image patch sequence.\n",
    "\n",
    "    By default forward returns the latent-token outputs projected to ``token_size`` channels,\n",
    "    with shape (B, token_size, 1, num_latent_tokens).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        vq = config['model']['vq_model']\n",
    "        self.image_size = config['dataset']['preprocessing']['crop_size']\n",
    "        self.patch_size = vq['vit_enc_patch_size']\n",
    "        self.grid_size = self.image_size // self.patch_size\n",
    "        self.model_size = vq['vit_enc_model_size']\n",
    "        self.num_latent_tokens = vq['num_latent_tokens']\n",
    "        self.token_size = vq['token_size']\n",
    "        # Optional flag; the default True keeps the original width-reduction layout (see forward).\n",
    "        self.is_legacy = vq.get('is_legacy', True)\n",
    "\n",
    "        self.width = {'small': 512, 'base': 768, 'large': 1024}[self.model_size]\n",
    "        self.num_layers = {'small': 8, 'base': 12, 'large': 24}[self.model_size]\n",
    "        self.num_heads = {'small': 8, 'base': 12, 'large': 16}[self.model_size]\n",
    "\n",
    "        self.patch_embed = nn.Conv2d(3, self.width, kernel_size=self.patch_size, stride=self.patch_size, bias=True)\n",
    "\n",
    "        scale = self.width ** -0.5\n",
    "        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))\n",
    "        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size ** 2 + 1, self.width))\n",
    "        self.latent_token_positional_embedding = nn.Parameter(scale * torch.randn(self.num_latent_tokens, self.width))\n",
    "\n",
    "        self.ln_pre = nn.LayerNorm(self.width)\n",
    "        self.transformer = nn.ModuleList([ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) for _ in range(self.num_layers)])\n",
    "        self.ln_post = nn.LayerNorm(self.width)\n",
    "        self.conv_out = nn.Conv2d(self.width, self.token_size, kernel_size=1, bias=True)\n",
    "\n",
    "    def encode_patches(self, pixel_values):\n",
    "        \"\"\"Patchify (B, 3, H, W) images into (B, num_patches, width) embeddings (no positions added).\"\"\"\n",
    "        x = self.patch_embed(pixel_values)\n",
    "        x = x.reshape(x.shape[0], x.shape[1], -1)\n",
    "        return x.permute(0, 2, 1)\n",
    "\n",
    "    def forward(self, pixel_values, latent_tokens, needs_width_reduction=True):\n",
    "        \"\"\"Encode images into latent tokens.\n",
    "\n",
    "        Args:\n",
    "            pixel_values: (B, 3, image_size, image_size) images.\n",
    "            latent_tokens: (num_latent_tokens, width) learned queries shared across the batch.\n",
    "            needs_width_reduction: if False, return the (B, num_latent_tokens, width) outputs\n",
    "                after ``ln_post``, without the 1x1 projection.\n",
    "        \"\"\"\n",
    "        batch_size = pixel_values.shape[0]\n",
    "        x = self.encode_patches(pixel_values)\n",
    "\n",
    "        x = torch.cat([_expand_token(self.class_embedding, batch_size).to(x.dtype), x], dim=1)\n",
    "        x = x + self.positional_embedding.to(x.dtype)\n",
    "\n",
    "        latent_tokens = _expand_token(latent_tokens, batch_size).to(x.dtype)\n",
    "        latent_tokens = latent_tokens + self.latent_token_positional_embedding.to(x.dtype)\n",
    "        x = torch.cat([x, latent_tokens], dim=1)\n",
    "\n",
    "        x = self.ln_pre(x)\n",
    "        x = x.permute(1, 0, 2)  # (B, L, D) -> (L, B, D) for nn.MultiheadAttention\n",
    "        for block in self.transformer:\n",
    "            x = block(x)\n",
    "        x = x.permute(1, 0, 2)\n",
    "\n",
    "        # Drop the class token and patch tokens; keep only the latent-token outputs.\n",
    "        latent_tokens = x[:, 1 + self.grid_size ** 2:]\n",
    "        latent_tokens = self.ln_post(latent_tokens)\n",
    "        if needs_width_reduction:\n",
    "            if self.is_legacy:\n",
    "                # Legacy layout: reinterprets the (B, L, width) buffer as (B, width, L, 1) with a\n",
    "                # plain reshape, NOT a transpose, so values are mixed across tokens and channels.\n",
    "                # Presumably the pretrained conv_out was trained on this layout -- keep the default.\n",
    "                latent_tokens = latent_tokens.reshape(batch_size, self.width, self.num_latent_tokens, 1)\n",
    "            else:\n",
    "                latent_tokens = latent_tokens.reshape(batch_size, self.num_latent_tokens, self.width, 1).permute(0, 2, 1, 3)\n",
    "            latent_tokens = self.conv_out(latent_tokens)\n",
    "            latent_tokens = latent_tokens.reshape(batch_size, self.token_size, 1, self.num_latent_tokens)\n",
    "        return latent_tokens\n",
    "\n",
    "\n",
    "class TiTokDecoder(nn.Module):\n",
    "    \"\"\"Transformer decoder conditioned on patch embeddings of a context frame.\n",
    "\n",
    "    The sequence is [class token, projected context patches] (+ positional embedding), followed\n",
    "    by the embedded latent tokens. Outputs at the patch positions become per-patch logits over a\n",
    "    1024-entry pixel codebook, with shape (B, 1024, grid_size, grid_size).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        vq = config['model']['vq_model']\n",
    "        self.strict_length_assertion = vq.get('strict_length_assertion', True)\n",
    "\n",
    "        self.image_size = config['dataset']['preprocessing']['crop_size']\n",
    "        self.patch_size = vq['vit_dec_patch_size']\n",
    "        self.grid_size = self.image_size // self.patch_size\n",
    "        self.model_size = vq['vit_dec_model_size']\n",
    "        self.num_latent_tokens = vq['num_latent_tokens']\n",
    "        self.token_size = vq['token_size']\n",
    "\n",
    "        self.width = {'small': 512, 'base': 768, 'large': 1024}[self.model_size]\n",
    "        self.num_layers = {'small': 8, 'base': 12, 'large': 24}[self.model_size]\n",
    "        self.num_heads = {'small': 8, 'base': 12, 'large': 16}[self.model_size]\n",
    "\n",
    "        self.decoder_embed = nn.Linear(self.token_size, self.width, bias=True)\n",
    "        scale = self.width ** -0.5\n",
    "        self.class_embedding = nn.Parameter(scale * torch.randn(1, self.width))\n",
    "        self.positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size ** 2 + 1, self.width))\n",
    "        # NOTE(review): mask_token is frozen and not read in forward (the patch-position queries\n",
    "        # come from the context frame); presumably kept so the strict checkpoint load finds it.\n",
    "        self.mask_token = nn.Parameter(scale * torch.randn(1, 1, self.width))\n",
    "        self.mask_token.requires_grad_(False)\n",
    "        self.latent_token_positional_embedding = nn.Parameter(scale * torch.randn(self.num_latent_tokens, self.width))\n",
    "\n",
    "        # Context patches come from TiTokEncoder.encode_patches, i.e. at the encoder's width.\n",
    "        enc_width = {'small': 512, 'base': 768, 'large': 1024}[vq['vit_enc_model_size']]\n",
    "        self.context_patch_proj = nn.Linear(enc_width, self.width, bias=False)\n",
    "        # NOTE(review): not read in forward (context patches get positional_embedding instead);\n",
    "        # presumably kept so the strict checkpoint load finds it.\n",
    "        self.context_patch_positional_embedding = nn.Parameter(scale * torch.randn(self.grid_size ** 2, self.width))\n",
    "\n",
    "        self.ln_pre = nn.LayerNorm(self.width)\n",
    "        self.transformer = nn.ModuleList([ResidualAttentionBlock(self.width, self.num_heads, mlp_ratio=4.0) for _ in range(self.num_layers)])\n",
    "        self.ln_post = nn.LayerNorm(self.width)\n",
    "\n",
    "        # 1024 output channels must equal the pixel codebook size used in TiTok.decode.\n",
    "        self.ffn = nn.Sequential(\n",
    "            nn.Conv2d(self.width, 2 * self.width, 1, padding=0, bias=True),\n",
    "            nn.Tanh(),\n",
    "            nn.Conv2d(2 * self.width, 1024, 1, padding=0, bias=True),\n",
    "        )\n",
    "        self.conv_out = nn.Identity()\n",
    "\n",
    "    def forward(self, z_quantized, context_patches):\n",
    "        \"\"\"Decode latent tokens into pixel-codebook logits.\n",
    "\n",
    "        Args:\n",
    "            z_quantized: (B, token_size, 1, W) quantized latents with W <= num_latent_tokens\n",
    "                (W == num_latent_tokens when strict_length_assertion is set).\n",
    "            context_patches: (B, grid_size ** 2, encoder width) patch embeddings of the context frame.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``context_patches`` is None.\n",
    "        \"\"\"\n",
    "        if context_patches is None:\n",
    "            raise ValueError('context_patches is required for this model')\n",
    "\n",
    "        N, C, H, W = z_quantized.shape\n",
    "        if self.strict_length_assertion:\n",
    "            assert H == 1 and W == self.num_latent_tokens\n",
    "        else:\n",
    "            assert H == 1 and W <= self.num_latent_tokens\n",
    "\n",
    "        x = z_quantized.reshape(N, C * H, W).permute(0, 2, 1)\n",
    "        x = self.decoder_embed(x)\n",
    "        batchsize, seq_len, _ = x.shape\n",
    "\n",
    "        # Patch-position queries: projected context patches behind a class token.\n",
    "        context_patch_tokens_raw = self.context_patch_proj(context_patches)\n",
    "        mask_tokens = context_patch_tokens_raw.to(x.dtype)\n",
    "        mask_tokens = torch.cat([\n",
    "            _expand_token(self.class_embedding, mask_tokens.shape[0]).to(mask_tokens.dtype),\n",
    "            mask_tokens,\n",
    "        ], dim=1)\n",
    "        mask_tokens = mask_tokens + self.positional_embedding.to(mask_tokens.dtype)\n",
    "\n",
    "        # Truncated latent sequences use the first seq_len positional embeddings.\n",
    "        x = x + self.latent_token_positional_embedding[:seq_len]\n",
    "        x = torch.cat([mask_tokens, x], dim=1)\n",
    "\n",
    "        x = self.ln_pre(x)\n",
    "        x = x.permute(1, 0, 2)  # (B, L, D) -> (L, B, D) for nn.MultiheadAttention\n",
    "        for block in self.transformer:\n",
    "            x = block(x)\n",
    "        x = x.permute(1, 0, 2)\n",
    "\n",
    "        # Keep only the patch positions (skip the class token and the latent tokens).\n",
    "        x = x[:, 1:1 + self.grid_size ** 2]\n",
    "        x = self.ln_post(x)\n",
    "        x = x.permute(0, 2, 1).reshape(batchsize, self.width, self.grid_size, self.grid_size)\n",
    "        x = self.ffn(x.contiguous())\n",
    "        x = self.conv_out(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4fad7a65",
   "metadata": {},
   "source": [
    "### TiTok wrapper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c50b1e76",
   "metadata": {},
   "outputs": [],
   "source": [
    "class TiTok(nn.Module):\n",
    "    \"\"\"DriveTiTok tokenizer: TiTok encoder and quantizer plus a context-conditioned decoder.\n",
    "\n",
    "    With ``finetune_decoder`` (default True), the decoder's codebook logits are turned into an\n",
    "    image by ``pixel_decoder``. Unless ``no_freeze_encoder`` is set, the encoder, latent tokens\n",
    "    and quantizer are frozen, and their losses are zeroed in ``encode``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, config):\n",
    "        super().__init__()\n",
    "        self.config = config\n",
    "        vq = config['model']['vq_model']\n",
    "\n",
    "        self.finetune_decoder = vq.get('finetune_decoder', True)\n",
    "        self.no_freeze_encoder = vq.get('no_freeze_encoder', False)\n",
    "\n",
    "        self.encoder = TiTokEncoder(config)\n",
    "        self.decoder = TiTokDecoder(config)\n",
    "\n",
    "        self.num_latent_tokens = vq['num_latent_tokens']\n",
    "        scale = self.encoder.width ** -0.5\n",
    "        self.latent_tokens = nn.Parameter(scale * torch.randn(self.num_latent_tokens, self.encoder.width))\n",
    "\n",
    "        self.quantize = VectorQuantizer(\n",
    "            codebook_size=vq['codebook_size'],\n",
    "            token_size=vq['token_size'],\n",
    "            commitment_cost=vq['commitment_cost'],\n",
    "            use_l2_norm=vq['use_l2_norm'],\n",
    "        )\n",
    "\n",
    "        if self.finetune_decoder and not self.no_freeze_encoder:\n",
    "            self.latent_tokens.requires_grad_(False)\n",
    "            self.encoder.eval()\n",
    "            self.encoder.requires_grad_(False)\n",
    "            self.quantize.eval()\n",
    "            self.quantize.requires_grad_(False)\n",
    "\n",
    "        if self.finetune_decoder:\n",
    "            # num_embeddings must match the 1024 logits of TiTokDecoder.ffn, and embedding_dim\n",
    "            # must match the pixel decoder's z_channels (256).\n",
    "            self.pixel_quantize = PixelQuantizer(num_embeddings=1024, embedding_dim=256, commitment_cost=0.25)\n",
    "            # 5 resolutions -> 4 upsampling stages (16x): a 16x16 patch grid becomes 256x256 pixels.\n",
    "            self.pixel_decoder = PixelDecoder({\n",
    "                'channel_mult': [1, 1, 2, 2, 4],\n",
    "                'num_resolutions': 5,\n",
    "                'dropout': 0.0,\n",
    "                'hidden_channels': 128,\n",
    "                'num_channels': 3,\n",
    "                'num_res_blocks': 2,\n",
    "                'resolution': 256,\n",
    "                'z_channels': 256,\n",
    "            })\n",
    "\n",
    "    def encode(self, x):\n",
    "        \"\"\"Encode images to quantized latents; returns (z_quantized, result_dict).\n",
    "\n",
    "        When the encoder is frozen, this runs without gradients in eval mode, and the\n",
    "        quantizer losses are multiplied by zero.\n",
    "        \"\"\"\n",
    "        if self.finetune_decoder and not self.no_freeze_encoder:\n",
    "            with torch.no_grad():\n",
    "                # Re-assert eval mode in case model.train() was called on the wrapper.\n",
    "                self.encoder.eval()\n",
    "                self.quantize.eval()\n",
    "                z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)\n",
    "                z_quantized, result_dict = self.quantize(z)\n",
    "                result_dict['quantizer_loss'] *= 0\n",
    "                result_dict['commitment_loss'] *= 0\n",
    "                result_dict['codebook_loss'] *= 0\n",
    "        else:\n",
    "            z = self.encoder(pixel_values=x, latent_tokens=self.latent_tokens)\n",
    "            z_quantized, result_dict = self.quantize(z)\n",
    "        return z_quantized, result_dict\n",
    "\n",
    "    def decode(self, z_quantized, length=None, context_patches=None):\n",
    "        \"\"\"Decode latents, optionally truncated to the first ``length`` tokens, into images.\"\"\"\n",
    "        if length is not None:\n",
    "            z_quantized = z_quantized[:, :, :, :length]\n",
    "\n",
    "        decoded = self.decoder(z_quantized, context_patches=context_patches)\n",
    "\n",
    "        if self.finetune_decoder:\n",
    "            # Softmax-weighted mix of pixel codebook vectors: (N, 1024, h, w) -> (N, 256, h, w).\n",
    "            quantized_states = torch.einsum('nchw,cd->ndhw', decoded.softmax(1), self.pixel_quantize.embedding.weight)\n",
    "            decoded = self.pixel_decoder(quantized_states)\n",
    "        return decoded\n",
    "\n",
    "    def forward(self, x, length=None, context_pixel_values=None):\n",
    "        \"\"\"Reconstruct ``x`` conditioned on ``context_pixel_values``; returns (decoded, result_dict).\n",
    "\n",
    "        Raises:\n",
    "            ValueError: if ``context_pixel_values`` is None.\n",
    "        \"\"\"\n",
    "        if context_pixel_values is None:\n",
    "            raise ValueError('context_pixel_values is required for this model')\n",
    "\n",
    "        z_quantized, result_dict = self.encode(x)\n",
    "        context_patches = self.encoder.encode_patches(context_pixel_values)\n",
    "\n",
    "        decoded = self.decode(z_quantized, length=length, context_patches=context_patches)\n",
    "        return decoded, result_dict"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5abe8b76",
   "metadata": {},
   "source": [
    "## Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3437676d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hard-coded config and checkpoint for titok_s640_jadd_prev_mask_query stage2\n",
    "CHECKPOINT_PATH = Path('drive_titok_s_640.bin')\n",
    "OUTPUT_IMAGE_PATH = Path('out/sample_recon.png')\n",
    "\n",
    "CONFIG = {\n",
    "    'dataset': {\n",
    "        'preprocessing': {\n",
    "            'crop_size': 256,\n",
    "        }\n",
    "    },\n",
    "    'model': {\n",
    "        'vq_model': {\n",
    "            'codebook_size': 4096,\n",
    "            'token_size': 12,\n",
    "            'use_l2_norm': True,\n",
    "            'commitment_cost': 0.25,\n",
    "            'vit_enc_model_size': 'small',\n",
    "            'vit_dec_model_size': 'small',\n",
    "            'vit_enc_patch_size': 16,\n",
    "            'vit_dec_patch_size': 16,\n",
    "            'num_latent_tokens': 640,\n",
    "            'finetune_decoder': True,\n",
    "            'strict_length_assertion': False,\n",
    "        }\n",
    "    }\n",
    "}\n",
    "\n",
    "if not CHECKPOINT_PATH.exists():\n",
    "    raise FileNotFoundError(f'checkpoint not found: {CHECKPOINT_PATH}')\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "print(f'device: {device}')\n",
    "print(f'checkpoint: {CHECKPOINT_PATH}')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bcf9c87",
   "metadata": {},
   "source": [
    "## Load checkpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4ee9c8b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# weights_only=True restricts unpickling to tensors and plain containers, so a tampered\n",
    "# checkpoint cannot execute arbitrary code.\n",
    "state_dict = torch.load(CHECKPOINT_PATH, map_location='cpu', weights_only=True)\n",
    "model = TiTok(CONFIG)\n",
    "# strict=True raises on any missing or unexpected key, so a successful call means the\n",
    "# checkpoint matches the architecture exactly.\n",
    "model.load_state_dict(state_dict, strict=True)\n",
    "del state_dict\n",
    "\n",
    "model = model.to(device).eval()\n",
    "model.requires_grad_(False);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7cd0ad98",
   "metadata": {},
   "source": [
    "## Reconstruct the current frame\n",
    "\n",
    "The current frame is tokenized; the previous frame is only patch-embedded and passed to the decoder as context."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "15e97777",
   "metadata": {},
   "outputs": [],
   "source": [
    "# use two images: prev frame and current frame\n",
    "PREV_FRAME_PATH = Path('assets/sample_prev.png')\n",
    "CURR_FRAME_PATH = Path('assets/sample_cur.png')\n",
    "IMAGE_SIZE = CONFIG['dataset']['preprocessing']['crop_size']\n",
    "\n",
    "for p in [PREV_FRAME_PATH, CURR_FRAME_PATH]:\n",
    "    if not p.exists():\n",
    "        raise FileNotFoundError(f'image not found: {p}')\n",
    "\n",
    "transform = transforms.Compose([\n",
    "    transforms.Lambda(lambda x: x.convert('RGB')),\n",
    "    transforms.Resize(IMAGE_SIZE),\n",
    "    transforms.CenterCrop(IMAGE_SIZE),\n",
    "    transforms.ToTensor(),\n",
    "])\n",
    "\n",
    "\n",
    "def load_frame(path: Path) -> torch.Tensor:\n",
    "    \"\"\"Load an image as a (1, 3, IMAGE_SIZE, IMAGE_SIZE) float tensor in [0, 1] on `device`.\"\"\"\n",
    "    with Image.open(path) as img:\n",
    "        return transform(img).unsqueeze(0).to(device)\n",
    "\n",
    "\n",
    "prev_frame = load_frame(PREV_FRAME_PATH)\n",
    "curr_frame = load_frame(CURR_FRAME_PATH)\n",
    "print(prev_frame.shape, curr_frame.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d61c2f87",
   "metadata": {},
   "outputs": [],
   "source": [
    "with torch.no_grad():\n",
    "    reconstructed, _ = model(curr_frame, context_pixel_values=prev_frame)\n",
    "reconstructed = reconstructed.clamp(0, 1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e72d3098",
   "metadata": {},
   "outputs": [],
   "source": [
    "def to_uint8_image(image: torch.Tensor):\n",
    "    \"\"\"Convert a (3, H, W) tensor in [0, 1] to an (H, W, 3) uint8 numpy array, rounding to nearest.\"\"\"\n",
    "    return (image.clamp(0, 1) * 255).round().to(torch.uint8).permute(1, 2, 0).cpu().numpy()\n",
    "\n",
    "\n",
    "prev_np = to_uint8_image(prev_frame[0])\n",
    "curr_np = to_uint8_image(curr_frame[0])\n",
    "recon_np = to_uint8_image(reconstructed[0])\n",
    "\n",
    "OUTPUT_IMAGE_PATH.parent.mkdir(parents=True, exist_ok=True)\n",
    "Image.fromarray(recon_np).save(OUTPUT_IMAGE_PATH)\n",
    "print(f'saved: {OUTPUT_IMAGE_PATH}')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "7f30ac0f",
   "metadata": {},
   "outputs": [],
   "source": [
    "fig, axes = plt.subplots(1, 3, figsize=(15, 5))\n",
    "panels = [(prev_np, 'prev frame'), (curr_np, 'current frame'), (recon_np, 'reconstruction')]\n",
    "for ax, (image, title) in zip(axes, panels):\n",
    "    ax.imshow(image)\n",
    "    ax.set_title(title)\n",
    "    ax.axis('off')\n",
    "\n",
    "plt.tight_layout()\n",
    "plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "drivetitok",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}