{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build a toy 2-D dataset: points drawn uniformly from the \"on\" squares of a checkerboard\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "import math\n",
    "import tqdm\n",
    "import torch\n",
    "from torch import nn\n",
    "from matplotlib.colors import ListedColormap\n",
    "\n",
    "SEED = 0  # fixed seed so Restart & Run All reproduces the same dataset\n",
    "rng = np.random.default_rng(SEED)\n",
    "\n",
    "N = 1000  # number of points to sample\n",
    "x_min, x_max = -4, 4\n",
    "y_min, y_max = -4, 4\n",
    "resolution = 100  # resolution of the grid\n",
    "\n",
    "x = np.linspace(x_min, x_max, resolution)\n",
    "y = np.linspace(y_min, y_max, resolution)\n",
    "X, Y = np.meshgrid(x, y)\n",
    "\n",
    "length = 4  # the board has length x length squares\n",
    "checkerboard = np.indices((length, length)).sum(axis=0) % 2\n",
    "\n",
    "\n",
    "def sample_checkerboard(n, board, x_range, y_range, rng):\n",
    "    \"\"\"Rejection-sample `n` points uniformly from the squares of `board` equal to 1.\n",
    "\n",
    "    Args:\n",
    "        n: number of points to return.\n",
    "        board: 2-D array of 0/1 values, indexed as board[row, col]; columns run\n",
    "            along x and rows along y.\n",
    "        x_range: (min, max) extent of the board along x.\n",
    "        y_range: (min, max) extent of the board along y.\n",
    "        rng: a numpy random Generator used for the proposals.\n",
    "\n",
    "    Returns:\n",
    "        Array of shape (n, 2) holding (x, y) pairs.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if `board` has no square equal to 1 (sampling would never finish).\n",
    "    \"\"\"\n",
    "    if n <= 0:\n",
    "        return np.empty((0, 2))\n",
    "    if not (board == 1).any():\n",
    "        raise ValueError(\"board has no square equal to 1 to sample from\")\n",
    "    n_rows, n_cols = board.shape\n",
    "    lows = np.array([x_range[0], y_range[0]], dtype=float)\n",
    "    highs = np.array([x_range[1], y_range[1]], dtype=float)\n",
    "    accepted, n_accepted = [], 0\n",
    "    while n_accepted < n:\n",
    "        # Propose a whole batch per iteration instead of one point per Python loop\n",
    "        proposals = rng.uniform(lows, highs, size=(2 * n, 2))\n",
    "        frac = (proposals - lows) / (highs - lows)\n",
    "        # Clamp the indices: floating-point rounding can let uniform() return the\n",
    "        # upper bound, which would otherwise index one past the last row/column\n",
    "        col = np.minimum((frac[:, 0] * n_cols).astype(int), n_cols - 1)\n",
    "        row = np.minimum((frac[:, 1] * n_rows).astype(int), n_rows - 1)\n",
    "        batch = proposals[board[row, col] == 1]\n",
    "        accepted.append(batch)\n",
    "        n_accepted += len(batch)\n",
    "    return np.concatenate(accepted)[:n]\n",
    "\n",
    "\n",
    "# sampled points is our x1\n",
    "sampled_points = sample_checkerboard(N, checkerboard, (x_min, x_max), (y_min, y_max), rng)\n",
    "assert sampled_points.shape == (N, 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the data (red), Gaussian noise (blue) and their linear interpolation at time t (green)\n",
    "t = 0.5\n",
    "noise = rng.standard_normal((N, 2))\n",
    "interpolated = (1 - t) * noise + t * sampled_points\n",
    "\n",
    "fig, ax = plt.subplots(figsize=(6, 6))\n",
    "ax.scatter(sampled_points[:, 0], sampled_points[:, 1], color=\"red\", marker=\"o\", label=\"data $x_1$\")\n",
    "ax.scatter(noise[:, 0], noise[:, 1], color=\"blue\", marker=\"o\", label=\"noise $x_0$\")\n",
    "ax.scatter(interpolated[:, 0], interpolated[:, 1], color=\"green\", marker=\"o\", label=f\"$x_t$ (t = {t})\")\n",
    "ax.set(xlabel=\"x\", ylabel=\"y\", title=\"Data, noise and their interpolation\")\n",
    "ax.legend()\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#Model\n",
    "class Block(nn.Module):\n",
    "    \"\"\"Hidden layer of the MLP: a square linear map followed by a ReLU.\"\"\"\n",
    "\n",
    "    def __init__(self, channels=512):\n",
    "        super().__init__()\n",
    "        self.ff = nn.Linear(channels, channels)\n",
    "        self.act = nn.ReLU()\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.act(self.ff(x))\n",
    "\n",
"class MLP(nn.Module):\n",
    "    \"\"\"Time-conditioned MLP: maps a batch of points and their times to one output vector per point.\n",
    "\n",
    "    Args:\n",
    "        channels_data: width of the input points and of the output.\n",
    "        layers: number of hidden `Block` layers.\n",
    "        channels: hidden width.\n",
    "        channels_t: width of the sinusoidal time embedding.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, channels_data=2, layers=5, channels=512, channels_t=512):\n",
    "        super().__init__()\n",
    "        self.channels_t = channels_t\n",
    "        self.in_projection = nn.Linear(channels_data, channels)\n",
    "        self.t_projection = nn.Linear(channels_t, channels)\n",
    "        self.blocks = nn.Sequential(*[\n",
    "            Block(channels) for _ in range(layers)\n",
    "        ])\n",
    "        self.out_projection = nn.Linear(channels, channels_data)\n",
    "\n",
    "    def gen_t_embedding(self, t, max_positions=10000):\n",
    "        \"\"\"Return the sinusoidal embedding of `t`, shape (len(t), channels_t).\n",
    "\n",
    "        `t` must be a 1-D tensor. It is scaled by `max_positions` and multiplied by\n",
    "        geometrically spaced frequencies; the first half of the output holds the\n",
    "        sines, the second half the cosines (plus one zero column if channels_t is odd).\n",
    "        \"\"\"\n",
    "        t = t * max_positions\n",
    "        half_dim = self.channels_t // 2\n",
    "        # max(..., 1): with channels_t of 2 or 3, half_dim - 1 is 0 and the plain\n",
    "        # division raised ZeroDivisionError\n",
    "        emb = math.log(max_positions) / max(half_dim - 1, 1)\n",
    "        emb = torch.arange(half_dim, device=t.device).float().mul(-emb).exp()\n",
    "        emb = t[:, None] * emb[None, :]\n",
    "        emb = torch.cat([emb.sin(), emb.cos()], dim=1)\n",
    "        if self.channels_t % 2 == 1:  # zero pad\n",
    "            emb = nn.functional.pad(emb, (0, 1), mode='constant')\n",
    "        return emb\n",
    "\n",
    "    def forward(self, x, t):\n",
    "        \"\"\"Map points `x` (batch, channels_data) and times `t` (batch,) to (batch, channels_data).\"\"\"\n",
    "        x = self.in_projection(x)\n",
    "        t = self.gen_t_embedding(t)\n",
    "        t = self.t_projection(t)\n",
    "        x = x + t  # condition on time by adding its projected embedding\n",
    "        x = self.blocks(x)\n",
    "        x = self.out_projection(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Training: regress the model output onto x1 - x0 at a random point of the\n",
    "# straight line between a noise sample x0 and a data sample x1\n",
    "torch.manual_seed(0)  # reproducible initialisation and batches\n",
    "model = MLP(layers=5, channels=512)\n",
    "optim = torch.optim.AdamW(model.parameters(), lr=1e-4)\n",
    "\n",
    "# float32 copy of the dataset (torch.Tensor(ndarray) is the legacy constructor)\n",
    "data = torch.as_tensor(sampled_points, dtype=torch.float32)\n",
    "training_steps = 100_000\n",
    "batch_size = 64\n",
    "pbar = tqdm.tqdm(range(training_steps))\n",
    "losses = []\n",
    "for i in pbar:\n",
    "    x1 = data[torch.randint(data.size(0), (batch_size,))]  # random data batch\n",
    "    x0 = torch.randn_like(x1)  # matching noise batch\n",
    "    target = x1 - x0\n",
    "    t = torch.rand(x1.size(0))  # one time in [0, 1) per sample\n",
    "    xt = (1 - t[:, None]) * x0 + t[:, None] * x1\n",
    "    pred = model(xt, t)\n",
    "    loss = ((target - pred)**2).mean()\n",
    "    loss.backward()\n",
    "    optim.step()\n",
    "    optim.zero_grad()\n",
    "    loss_value = loss.item()  # read the scalar once, use it twice\n",
    "    pbar.set_postfix(loss=loss_value)\n",
    "    losses.append(loss_value)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata":
{},
   "outputs": [],
   "source": [
    "#Sampling\n",
    "torch.manual_seed(42)\n",
    "n_samples = 1000\n",
    "steps = 1000\n",
    "plot_every = 100\n",
    "dt = 1 / steps\n",
    "\n",
    "model.eval()\n",
    "try:\n",
    "    # no_grad instead of switching requires_grad off on the parameters: nothing is\n",
    "    # left frozen if the loop is interrupted\n",
    "    with torch.no_grad():\n",
    "        # --- sampling logic starts here ---\n",
    "        xt = torch.randn(n_samples, 2)\n",
    "        # Forward Euler: step k evaluates the model at t_k = k * dt, so the time grid\n",
    "        # is 0, dt, ..., 1 - dt and matches the step size dt\n",
    "        for i, t in enumerate(torch.arange(steps) * dt, start=1):\n",
    "            pred = model(xt, t.expand(xt.size(0)))\n",
    "            xt = xt + dt * pred\n",
    "            # --- sampling logic ends here ---\n",
    "            # In this case it moves random noise points into an organized checkerboard.\n",
    "            # The same loop is applied to anything from images to videos: the goal is\n",
    "            # always to move each noise sample to its target location.\n",
    "            if i % plot_every == 0:\n",
    "                fig, ax = plt.subplots(figsize=(6, 6))\n",
    "                ax.scatter(sampled_points[:, 0], sampled_points[:, 1], color=\"red\", marker=\"o\", label=\"data\")\n",
    "                ax.scatter(xt[:, 0].numpy(), xt[:, 1].numpy(), color=\"green\", marker=\"o\", label=\"samples\")\n",
    "                ax.set(xlabel=\"x\", ylabel=\"y\", title=f\"Euler step {i}/{steps}\")\n",
    "                ax.legend()\n",
    "                plt.show()\n",
    "                plt.close(fig)  # do not keep one open figure per snapshot\n",
    "finally:\n",
    "    model.train()  # back to training mode, even if sampling failed"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}