{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "9074d4b4",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from torch.utils.data import Dataset, DataLoader\n",
    "import scipy.io\n",
    "from torchvision import transforms, utils\n",
    "import os\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "c85ff158",
   "metadata": {},
   "outputs": [],
   "source": [
    "class GravityDataset(Dataset):\n",
    "    \"\"\"Single-channel 2-D fields stored in MATLAB `.mat` files, served as tensors scaled to [-1, 1].\n",
    "\n",
    "    Each item is read from `folder`, reshaped to `raw_shape`, augmented (resize up by 12%,\n",
    "    random crop back to `image_size`, random horizontal flip), converted with `ToTensor`\n",
    "    and min-max scaled to [-1, 1] using that sample's own range.\n",
    "\n",
    "    Args:\n",
    "        folder: directory scanned (non-recursively) for data files.\n",
    "        image_size: side length of the square crop returned by `__getitem__`.\n",
    "        exts: file extensions to include, with or without the leading dot.\n",
    "        key: name of the variable read from each `.mat` file.\n",
    "        raw_shape: shape the stored array is reshaped to before augmentation.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, folder, image_size, exts=('mat',), key='d', raw_shape=(32, 32)):\n",
    "        super().__init__()\n",
    "        self.folder = folder\n",
    "        self.image_size = image_size\n",
    "        self.key = key\n",
    "        self.raw_shape = tuple(raw_shape)\n",
    "        # Honour `exts` and sort, so the index -> file mapping does not depend on\n",
    "        # the arbitrary ordering returned by os.listdir.\n",
    "        suffixes = tuple('.' + ext.lstrip('.') for ext in exts)\n",
    "        self.paths = sorted(\n",
    "            os.path.join(folder, name) for name in os.listdir(folder) if name.endswith(suffixes)\n",
    "        )\n",
    "\n",
    "        # Augmentations; the per-sample [-1, 1] scaling is applied afterwards in __getitem__.\n",
    "        enlarged = int(image_size * 1.12)\n",
    "        self.transform = transforms.Compose([\n",
    "            transforms.Resize((enlarged, enlarged)),  # resize slightly larger...\n",
    "            transforms.RandomCrop(image_size),        # ...then randomly crop back to the target size\n",
    "            transforms.RandomHorizontalFlip(),\n",
    "            transforms.ToTensor(),\n",
    "        ])\n",
    "\n",
    "    def scale_to_minus1_1(self, tensor):\n",
    "        \"\"\"Min-max scale `tensor` to [-1, 1] using its own min and max.\n",
    "\n",
    "        A constant tensor has no range to stretch, so it is mapped to all zeros (the centre of\n",
    "        [-1, 1]) instead of being returned unscaled.\n",
    "        \"\"\"\n",
    "        min_val = tensor.min()\n",
    "        max_val = tensor.max()\n",
    "        if max_val > min_val:\n",
    "            return 2 * ((tensor - min_val) / (max_val - min_val)) - 1\n",
    "        return tensor - min_val  # all zeros when max == min; same shape and dtype as the input\n",
    "\n",
    "    def __len__(self):\n",
    "        return len(self.paths)\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Load the sample at `index`, augment it and scale it to [-1, 1].\"\"\"\n",
    "        data_val = scipy.io.loadmat(self.paths[index])[self.key]\n",
    "        img = Image.fromarray(data_val.reshape(self.raw_shape))\n",
    "        if self.transform:\n",
    "            img = self.transform(img)\n",
    "        return self.scale_to_minus1_1(img)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "c16dc0d1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Quick-look loader used by the preview plot; the training loader is built in the Load Dataset section.\n",
    "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
    "dataset = GravityDataset(data, image_size=32)\n",
    "dataloader = DataLoader(dataset, batch_size=64, shuffle=True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "a8bd15c0",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Loading pipeline components...: 100%|██████████| 7/7 [00:00<00:00, 28.56it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "4"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from diffusers import StableDiffusionPipeline\n",
    "\n",
    "pipeline = StableDiffusionPipeline.from_pretrained(\"stable-diffusion-v1-5/stable-diffusion-v1-5\", use_safetensors=True)\n",
    "# Last expression, so Jupyter displays the pretrained UNet's input-channel count.\n",
    "pipeline.unet.config[\"in_channels\"]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "72b20338",
   "metadata": {},
   "source": [
    "**Training Configuration**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "e01c3fae",
   "metadata": {},
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "from typing import Optional\n",
    "\n",
    "\n",
    "@dataclass\n",
    "class TrainingConfig:\n",
    "    # Annotated so these are real dataclass fields, overridable via e.g. TrainingConfig(num_epochs=10).\n",
    "    image_size: int = 32  # the generated image resolution\n",
    "    train_batch_size: int = 32\n",
    "    eval_batch_size: int = 16  # how many images to sample during evaluation\n",
    "    num_epochs: int = 50\n",
    "    gradient_accumulation_steps: int = 1\n",
    "    learning_rate: float = 1e-4\n",
    "    lr_warmup_steps: int = 500\n",
    "    save_image_epochs: int = 10\n",
    "    save_model_epochs: int = 30\n",
    "    mixed_precision: str = \"fp16\"  # `no` for float32, `fp16` for automatic mixed precision\n",
    "    # NOTE(review): this directory name looks inherited from a 128px butterflies example; consider renaming.\n",
    "    output_dir: str = \"ddpm-butterflies-128\"  # the model name locally and on the HF Hub\n",
    "\n",
    "    push_to_hub: bool = True  # whether to upload the saved model to the HF Hub\n",
    "    hub_model_id: str = \"jainadarsh/trial\"  # the name of the repository to create on the HF Hub\n",
    "    hub_private_repo: Optional[bool] = None\n",
    "    overwrite_output_dir: bool = True  # overwrite the old model when re-running the notebook\n",
    "    seed: int = 0\n",
    "\n",
    "\n",
    "config = TrainingConfig()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9f52a27e",
   "metadata": {},
   "source": [
    "**Load Dataset**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "b4a98567",
   "metadata": {},
   "outputs": [],
   "source": [
    "data = '/mnt/drive/adarsh/DC_cold3/Experiment-14/neg_int'\n",
    "dataset = GravityDataset(data, image_size=config.image_size)\n",
    "# Take the batch size from the config instead of a separate hard-coded value.\n",
    "train_dataloader = DataLoader(dataset, batch_size=config.train_batch_size, shuffle=True, num_workers=4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "e3aed193",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png":
"iVBORw0KGgoAAAANSUhEUgAABOwAAAEhCAYAAADMCz9IAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAISNJREFUeJzt2smSI7e1MGAkZ7JY1ZNaLdmK8MZb77T2+/hF/RDeOMJ2eAi1ZXWruyYO+S/+xb03bONASlQSZH3f9qAOkCCABE+x6/u+TwAAAABAEyanHgAAAAAA8D8U7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49gBb1fX/qIaSU6oyjhRwlfx+1aSXH8Xh88j7G0nXdoHiJ7XY7OEdr3r17l43/61//CnPsdrtsfDKJ/5eyWq2y8c1mE+aIPp+XL19m419++WXYxy9/+cts/Fe/+lWY45tvvhk8juhZStZqNOfz+Twbn06nYR819l10xhwOhzBH1CZawyV9RDmieEop7ff7weOI5qvkM4n27G9/+9swxzmqsV55nqyd5yu6S5+jkvdVC1rZd2OM4zk9aw3nMs7nyC/sAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj2A56rv+1MPoZroWUqetUaO4/E4KF7Szxh91FgbXdcNblOSYzJ5fjX/7Xabje/3+zDH4+NjNj6dTsMcq9UqG4/GmVJK19fX2firV6+y8ZcvX4Z93NzcZOPr9TrMsVgssvHZLH6V1VjvkaHnR8k4apyFh8MhzBGt42gNR/GUUnp4eBgUL+mnZD/WOHNL1iBckhpnJlyKc/leVzLOoXu75O/H+K4zxrOWiMbhLCXy/L5tAwAAAEDDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYBT6Pv+1EOoosZzlOSI2gyNl7Q5Ho9hjqhNSY7D4TAoR/T3JW1qfK5d14VtJpN8vX46nQ7OcYlev36djZfMyePj4+Acm80mG99ut2GOly9fZuOvXr3Kxr/44ouwj2i+bm5uwhzr9Tobn8/nYY7ZLP+6K9kzQ411FkZnzH6/D3NEa/Th4SEbv7u7C/uI2pTkuL+/z8Z3u12YI5qvkv1YsgbhXIxxHj4n5vPyXcp3y5TiZ4nWc8lcRDlK9szQcZbkKDF0f9eYrxpaGQf/7vl92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AC5V3/enHkJKaZxxRH2UjOF4PA6Kl7Q5HA5hjv1+n43vdrtBf1/SpmSckckkrsXPZvntP51Owxzz+bx4TJfi7du32fhyuQxzPDw8ZOMln996vc7Gr6+vwxyvXr0aFH/z5k3Yx+vXr7Pxm5ubMMdqtcrGF4tFmCNa7yVz3nVdNl7jvI1y1DjHHh8fwxzRGr29vR0UTymlT58+DYqX9BM9R0rxnJasjWiNwliic4r/y3xRQ8n3lHNhT/yPkrmI7m2XNJ/P6Vlb4hd2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD6C2vu9PPYRqxniWkj6GjqPk74/H4+Ach8MhG9/v92GO3W6XjT88PGTjd3d3YR9RjpJxRmazeGsvl8tsfLVaDR7HJfrqq6+y8aurqzDH4+NjNt51XZgj+vyur6/DHC9fvnzSeEmb7XYb5ojmNJqLlOI9MZ1OwxzR5xLFa5yF0TmXUnyGROdcSind399n47e3t9n4x48fwz5+/PHHJ88RPUdK8ZyWnKfr9TpsAzWUvB+eC3Pxf5mP0yl5N5+LyST/e56hd6GW1Li3jfG80Tj
Oac756fzCDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGjI7NQDOEd93596CCml8cYR9TM0XuJ4PA5uczgcwhyPj4/Z+N3dXTb+448/hn18+vRp0BhKLJfLsM12u83GS+a8lb0wpl/84hfZePT5plTnM14sFtn4ZrMJc7x48SIbv7m5ycavr6/DPqIcV1dXYY7VapWNz+fzMMdsln/dTSbx/6+6rgvbDBXtqZJ9ud/vs/GHh4cwx/39fTZ+e3ubjZfsgx9++GFQPKX4zI3O7JTiOY32Wkop7Xa7sA1ExjhjWnFJz3pJz8LPU/IdI1LjTh2txZK1Go0jylHjPlUyzjH2XY0+hs5njT5q9TN0HM7Kn8cv7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7Ac9X3fRP91BhHlON4PIY5ojYlOQ6HQza+3+/DHLvdLhu/v7/Pxj9//hz28fHjx8E5IldXV4NzTKfTKm0uzVdffZWN397ehjkeHx8Hj2OxWGTjq9UqzBGtkyi+3W7DPjabTTa+Xq/DHMvlMhuP5iKllGaz/OtuMon/fxW1qXHeRudYFE8pPsdK1t/d3V02Hp1TP/74Y9hHdBZ++PAhzPHp06dsPDqzS5TMebS+IKWUuq479RBGcy7PapzUEH3HGOt7X7ROStZRdNeJcpQ8a5Tjkr5fjDFfXDa/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaMjv1AH6qvu9PPYRnJ5rzGp9JlON4PIY5ojaHwyHMEbXZ7/fZ+G63C/t4fHwcnCOar+l0GuZYLpfZ+Gq1CnPMZmd3hAz29u3bbPzh4SHMEa2jEtFnvFgswhzRZ7xerwf9fUmbaB2mlNJ8Ps/GS9ZhNF+TyXn8/6rkHKtxTkXr+Pb2Nhv//Plz2MenT58G9VHSpmS+os++5P3TdV3Yhst2SWvgXJ6llXG2Mo7IuYzzHEV3+5LvSiXvmkiNu07UJuqj5FmjHCXv7pLvOkPV2DPRfIy1L1sYR8nacE79u/P4hgIAAAAAz4SCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIbNTD4D/ru/7QfFaOYYaa5zH43FQvKRNNI6u68I+ptPpoHhKw8eZUkqHwyEb3+12YY6SNpfmzZs32fjj42OYI5r7EtFam8/nYY7FYjEox3K5fPI+UkppNsu/qiaT+H9PUZuSvRspOWMiNc7CGnv74eEhG7+/v8/G7+7uwj6iHCV7KZqPkrURrdHVahXmKGnDeatxRrSihWdpYQwpjTOO5/Ssz1X0Piu5H9T4Tha980q+Y0RtontbSR+Rknd3pMY4anwmNfZdje+fnC+/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoALer7/tRDaEo0H8fjcdDfl7QpyVFjHFGOruuy8el0GvYxn8+z8cViEebY7/fZ+GQS1+KjZ436KG1zaW5ubrLxkjmJ5r6G2Sw+3qP1Gq3Vkj6iNiV7JmpTst6jvVsiOkNq9DF0DCVtDodDmCNax7vdblC8ZBwl81ljja5Wq2z8+vo6zFHShraNsX/H0MpzjDGOsZ51aD/nMk5+vs+fP2fjJXfDGneM6L5U8k6M3qvRHTb6+xIl44zmo+SuE81XKzWBVu6XzpjT8As7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw/gUvV9f+ohFGllnNE4aoxzjGedTOIa+HQ6zcZns+HbsmQckePxGLbZ7/eD+zk3m80mGy9ZZyVzO1TJGui
6LhuP1mK0lkvalIwzalNjvZd8bkPPqWi+S9ucg5LPZD6fZ+Ml+ySar+VyGeaI9vT19XWY48WLF2EbTudS9lVK7TzLGOOo0UcLOVoYQ2v9XJoffvghG394eAhzHA6HbLzGe7XknbharbLx9Xqdjbfy3bLkjlrj+2cLe+Zcxlmixl360viFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1AOgfX3fn/Tva+Wooeu6QfGUUppM8nXyKF7SpmQc0Zwej8fBOS7RZrPJxmvMW8m81liLUZtW1uoY671E9NlGfdSYrxo5ptNpmGM2y18PFotFNr5arcI+DofDoD5Sip+lZBzb7XZQPKWUbm5uwjY8nZIz4By08hw1xjHGs4w1zqH9jNFHSY5W1tcl+tvf/paNf/r0Kcyx3++z8ZL3/3K5zMavrq7CHNfX19n4ixcvsvGxvj+McQ+ucb+07xjKL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPQBIKaW+7588xxh9jKXGsx4Oh2x8t9uFOfb7fdjm0iwWi2y8ZO7HWEdd11VpkzOZxP/zifoYY5wlxjgfSuYrajObxa/tqE20hlNKabVaZeObzSYbj86XlOJxHo/HwTmi50gpfpbr6+swx9XVVdgGxjjLStQYx9AcLYyhNMfQ99gYfZS2GSPHc/TnP/85G3///n2Y4/HxMRsvuUOs1+tsvOR99ubNm2w8GmfJ+7/GnSuajxp3rhp3+lbu/Jwvv7ADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANmZ16AP9b3/enHsJoxnrWS5nTGs9RkiNqE8WPx2PYx+FwyMb3+32YY7fbZeMlzxqNdTKJ6/nROC7RYrHIxmuss7F0XTco3kofJYbu7RLRnppOp2GO2Sz/Wo7iKaW0XC6z8ZJ9u9lssvHoHCt51qiPks8kmo9oLlJKab1eZ+PROEvb8POMdUYM1co4xzi3W8kx1jiHjqPkPlXjWaN+xnjW5+r9+/fZ+N///vcwx93d3eBxzOfzbLzkXfX9999n458/f87Gb29vwz6++OKLbPzVq1dhjkjJvovuKiXf60r6uRTRvcz58DSezwoDAAAAgDOgYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaMjs1APg/PV9PyheK8fxeHzyHFH8cDiEfez3+2z88fExzHF/fz94HJNJvl4fjfO5ms3yx2bJOqsh6qfruicfwxh91FLjjIlEa6NEdMbM5/MwR7T/1+v14HFE58dyuQz7iM6Yks9kOp1m44vFIswRjXW1WoU5StpADTXO3RZylPx91OZcckTnZUmbVnKc03t/TNF7t+ReHt3tHx4ewhwfPnzIxt+/fx/m2G632fjbt2+z8a+//jrs45tvvsnGf/nLX4Y5vv3222y85P0f3alK9gw8NasQAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0ZHbqAfDz9X0/So7j8Ti4n6HjKBln1KbkOaI2h8NhUDyllPb7fTb+8PAQ5vj06VM2fnt7G+aIxjqbxcfDy5cvwzaXJpqXGmv1knRdNzjHGGddSR9Dn2U6nYZtxjgLS0TPOp/Ps/H1eh32EZ1BJc8RzWnJORY9y2KxCHOUtOE/q3FGjGGMcZ7LXJSInqXkWVvJMZnkf98wNJ5SfJaVvD9qnIc1noV/V7LOonde9P0hpZS+//77bPwPf/hDmCNaJ9Hd/927d2Ef3333XTb+4cOHMMevf/3rweNo4X451rnfyjj46Zy6AAAAANAQBTsAAAAAaIiCHQAAAAA
0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx4A/13f96ceQpFonCXPEbU5Ho9hjqjNfr8Pc0RthsZL2jw8PIQ5Pn78mI2/f/8+zPHdd99l43/5y1/CHG/fvs3Gf/e734U5zs1kMvz/HOeyt1tRY77GOKda0XXdoHhKKU2n02x8sVhk4yVn4eFwCNtEomeJniOllGaz/FUoiqeU0nw+D9vAGEr2d40cQ8+ZMfooaVPyTo/aRPGxzqGozXK5DHNEZ1nJszxH//znP7Pxv/71r2GO6G5f8v3g/v4+G1+tVmGOaA1E8ZJ9Gd0RdrtdmCP63ncudzaI+IUdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA5ev7fnCb4/EY5tjv99n44XB48hwl46zxrI+Pj9n4x48fwxx/+tOfsvHf//73YY63b9+GbS7NZPL0/+coWQPPSdd1g3NE+67GOVVD9Kwlc1Ejx2yWvx7M5/NsvOS8HWOdT6fTwW1q5HiuauzdsYwx1hp9jJGjlc+txjijd3bJOz1qE+3/6DxNKT5TF4tFmGO1Wg2Kp5TSer3OxqNxPlf/+Mc/svE//vGPYY67u7tsvGS9R2vxm2++CXNE6+Tq6iobv7m5Cft4+fJlNh6tw5TifVWyt52nnAO/sAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/1vXdWGbvu9HGAk/RfSZlHxmx+MxGz8cDmGO/X6fjT8+PoY5drvdoD5KTCb5OvlisQhzrFarbHy9Xoc5rq6usvF3796FOd68eRO2uTTR51ci2hPT6XRwH2No5TyOzo+U4vfLGO+fkj5qjDNqU7KGozmdzfLXh5LPpMb6qTFf0XyUzFeNcwFqKFnzY/TTyllWY39H7+ToPJzP52Ef0d0vuvelFN/9rq+vwxw3NzeDx/EcRXfmh4eHMEf0HaNkHUWfz2azCXNE6yjqY7vdhn1EbV6/fh3miNZqtC9TGu+8hCHcMAEAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGzE49AC5f3/dhm8PhkI3v9/swx8PDQzZ+d3cX5oja7Ha7MEdksVhk45vNJszx5s2bweOYz+fZ+Ndffx3muLm5GTyOc9N1XRM5SvbVU2vlOabT6eB+jsdjmCN63skk/z+wkmeN+iiZ86HjTCmej9ksf30omc+hY0gpfpYa81WSo2ROL1GNM2AMY4yzlXdDjX7G2Dcle2aMvVny/ojaROdhSR/RnWy1WoU5rq+vs/HXr1+HOaI22+02zPEcffvtt9n4b37zmzBHtJ6Xy2WYY71eZ+Ml3zGitRb1UTLOaBwl44zWYrSnUorPh+f6bv9vzuWdf2msQgAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0JDZqQfA+ev7Phs/Ho9hjqjNfr8Pc9zd3WXjHz58GJxjOp1m4/P5POxjsVhk47NZvC2jHMvlMsyx2Wyy8c+fP4c5VqtV2ObSTCb5/3NE+6GWrusG5xhrrDklzzHGOKO9XTKO6Byr8awlOaI20RouGcfQeGmboWrMV40cUEsra23ovil5juisKjnLauSI3g/RvS26s6UU36fW63WYY7vdZuOvXr0Kc7x79y4bf/36dZijhbvF2KJ5KfkuFK2zkjt3tE5KckTfD2p8B4najPV9aoz3vzsGQ/mFHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQEMU7AAAAACgIQp2AAAAANCQ2akH8Fx1XTc4R9/3o+SI2gyNp5TS8XjMxne7XZjj9vY2G//+++/DHHd3d9n4ZrPJxq+vr8M
+lstlNj6dTsMcq9UqG4/GmVJK2+02G398fAxzLBaLsM2lqbF3GV/0uZWcU1GOGn1EZ2HJ+qsxjhrn+tA+aqgxX7X64emMMf/nsk5qrPkxnrXGOCeT+HcHUZuSO9dslv+6FMVL7kpRm/V6HeaI7nUvXrwIc3zxxRfZ+JdffhnmGONsb00099G7PaWU5vN5Nh7d/Uva1MgRrdXoOUraRHuqpE3J3o7Oh7HuEJDjF3YAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIYo2AEAAABAQxTsAAAAAKAhs1MPgKfV9/3J+zgej2GOqM3hcAhz7Ha7bPzu7i7McXt7m413XZeNr1arsI/IcrkM28xm+a1bMl/X19fZeMnnFo3jOYrWCM9XydqI2pSc6dHerbFGx3i3jMWepRWtrMUaZ9XQeEopTSb53xVE8ZRSmk6ng+IpxXedKL5YLMI+ojbr9TrMcXV1lY3f3NyEOV6+fJmNv3r1KsxxSe+HUtHc11jvJd8xou8QJd8xorUYrff5fB72EeUo2Zcl+z9S45xqwbmMk5/HL+wAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhsxOPYAWdV2Xjfd9P9JInt7xeBzc5nA4DIqnlNJ+vx+co6RNJHrW+/v7bPzu7i7sY7PZZOOzWbwt1+t1Nj6ZjFOLH6uflkTnQytKzqkWnuVcxtmKkrmoMV9jvOdK3j9DjXVGWaOMJVprY50RQ8dRY5wl+ztqM51OwxzRvWyxWAyKpxTf61arVZgjul9eXV2FOa6vr7Pxm5ubMMcYZ3trttttNl6y3qO1OJ/PwxxRm5K1GI0jylGyp6I2JTmiOS3JEZ0PNc6pGmfhGIyjXc/v2zYAAAAANEzBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQkNmpB/BTdV2Xjfd9P9JI8s5lnCWOx2M2vtvtsvG7u7uwj9vb28E5DodDNj6ZxPXpqE30uUVjKOljNou35XK5zMbn8/ngcURruCTHJSqZl6FqnA9jjLNE9CytjLOGGuf+ueSooUYfrayfVsbBzzf0M7ykNVDyLFGbKF7jTlaSYzqdZuMld66oTXTnWiwWYR+r1Sob32w2YY6ozXa7DXNEba6ursIc0feGS1QyL5ForUbxlOK1WmO91xhn1KbkDKqRI3JJ5zrn6/l92wYAAACAhinYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA2ZnXoA/Hxd143Sz+FwyMbv7++z8Y8fP4Z9/Pjjj9n47e1tmOPh4SEbn0zi+vRiscjGp9PpoL9PKaX5fD44x3K5HJwjmo+S+Sppw0831t6O9H0/OMe5PEuNcY7RR0mOFp61RCtro5Vx8J/5fOob4ywaGk+pzj0lurdF8ZTie1sUj+5sKaW0Wq2y8fV6Hea4uroaFE8ppc1mMzhH9L3hEpV8PpFoPZfsmdks/9W+ZL0P3TMl+zJ6lho5SuarhbOwRh9j5eA0fNsGAAAAgIYo2AEAAABAQxTsAAAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGjI79QB4Wl3XPXkffd9n48fjMcxR0iYym+WX83a7DXOsVqtBfVxdXYV9RG2iMaSU0mKxGBRPKaXJJF+vj+KlbS5NC3tqLGM8a4ka89HKs1wK8wk/3Rj7pkYfY+Qo6SNqU+OeEt3rStosl8tsvORet16vs/HNZhPmiO6XJXfU6K5cMo7D4RC2uTTRGigRrdWSPTPG3b7GOKfT6eAcY5wx53Jmn4vn9Kw1Pb9v2wAAAADQMAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gtq7
rwjZ935+8j5IcQ/soMZnENdvpdJqNL5fLbHy73Q4ex3q9DnPs9/ts/HA4hDmiOR06FymldH19nY2XPOtiscjGZ7N4a0fPUrJGS9YPP12N86GGGmdMDWPMR41nbWWcrawf4Kcp2bs19vcYOYbGU4rvGCV3kBo55vN5Nh7dyUruhtHdb7PZhDmurq6y8ZL7eNQm6iOllHa7Xdjm0kRroESNPRO1KVnvQ3PUGGcrOUq0cOdqYQw8Hd+2AQAAAKAhCnYAAAAA0BAFOwAAAABoiIIdAAAAADREwQ4AAAAAGqJgBwAAAAANUbADAAAAgIbMTj0Afr6u6wa3mUzimu18Ps/G1+v14D42m002vt/vwxzH43FQvET0LLNZvKWi+Vwul4NzlMx5SZsxcpybkn0X6fu+wkieXo1nrWGM+TqXZ21lnMDlGuOcqXGHLckxnU6z8ZJ7W9Qmipfc61arVTYe3ZNL2kT39Vo5ojm/RCXraKgae6ZEdLevsS9byTG0j7FynIvn9Kxjen7ftgEAAACgYQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADVGwAwAAAICGKNgBAAAAQENmpx7AKXRdl433fX8RfZS0mUzimu10Os3Gl8tlNj6bxcvscDhk48fjMcxR0iYSfS5jzGcUL+1n6DhK1hc/zxhzW+OMaUUra3GMOW3lWYdqZf1dynzCuRlj77VyD57P54PiKcV36dVqFeaI2qzX68HjWCwWYY5Wzv8x1biX1+ijxr6LcgyN1xjDmP2MkaOFPmiXX9gBAAAAQEMU7AAAAACgIQp2AAAAANAQBTsAAAAAaIiCHQAAAAA0RMEOAAAAABqiYAcAAAAADZmdegD8d13XZeN93w/OUWI6nQ7qI/r7lFI6HA4/aUz/STQfJfM1VMl8R20mk7iOHuWoMY4SNXLwNFr5bMbYd2NpYU7PZT5bmKsxPbfnhaFrfqx7yhjjiO5tJffgqE2NHCX3y8jxeAzb1LjTn5uSzydyLuv9qf/+0nK00Ect5zTWS+IXdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCGzUw+gRV3XZeN93z95HyX9lOSoIeqnxnxNJvnacUmOGp/LGGp8bkM/kxp9QIlW1tG5nA+RVuYTeN6GnqmXdK+rIXrW4/EY5tjv99n4/f19mOPz58/Z+HQ6DXPc3d1l45vNJsxxbqLvMTW08r3vqf++Vo5W+jmXe9u5jPM58gs7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCFd3/f9qQcBAAAAAPx/fmEHAAAAAA1RsAMAAACAhijYAQAAAEBDFOwAAAAAoCEKdgAAAADQEAU7AAAAAGiIgh0AAAAANETBDgAAAAAaomAHAAAAAA35f5EOqgphU/Q8AAAAAElFTkSuQmCC", "text/plain": [ "
" ] },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "import torch\n",
    "\n",
    "# Preview a few augmented samples from `dataloader` (scaled to [-1, 1] by GravityDataset).\n",
    "n_show = 4\n",
    "batch = next(iter(dataloader))  # batch shape (B, C, H, W)\n",
    "imgs = batch[:n_show]\n",
    "\n",
    "fig, axs = plt.subplots(1, n_show, figsize=(4 * n_show, 4), squeeze=False)\n",
    "for ax in axs[0]:\n",
    "    ax.axis('off')  # also blanks unused panels when the batch has fewer than n_show images\n",
    "for ax, img in zip(axs[0], imgs):\n",
    "    img = img.detach().cpu()\n",
    "    if img.ndim == 3:  # (C, H, W) -> (H, W, C)\n",
    "        img = img.permute(1, 2, 0)\n",
    "    arr = ((img + 1) / 2).numpy().squeeze()  # [-1, 1] -> [0, 1]\n",
    "    # Fixed vmin/vmax so intensities are shown on the [0, 1] scale instead of being re-normalised per image.\n",
    "    ax.imshow(arr, cmap='gray' if arr.ndim == 2 else None, vmin=0, vmax=1)\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "06c1dff6",
   "metadata": {},
   "source": [
    "**Create the model architecture**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "3ee088b3",
   "metadata": {},
   "outputs": [],
   "source": [
    "from diffusers import UNet2DModel\n",
    "\n",
    "# NOTE(review): six down blocks on image_size=32 inputs presumably leave the deepest levels at\n",
    "# very low spatial resolution; confirm this depth is intended.\n",
    "model = UNet2DModel(\n",
    "    sample_size=config.image_size,  # the target image resolution\n",
    "    in_channels=1,  # single-channel input (3 would be RGB)\n",
    "    out_channels=1,  # the number of output channels\n",
    "    layers_per_block=2,  # how many ResNet layers to use per UNet block\n",
    "    block_out_channels=(128, 128, 256, 256, 512, 512),  # the number of output channels for each UNet block\n",
    "    down_block_types=(\n",
    "        \"DownBlock2D\",  # a regular ResNet downsampling block\n",
    "        \"DownBlock2D\",\n",
    "        \"DownBlock2D\",\n",
    "        \"DownBlock2D\",\n",
    "        \"AttnDownBlock2D\",  # a ResNet downsampling block with spatial self-attention\n",
    "        \"DownBlock2D\",\n",
    "    ),\n",
    "    up_block_types=(\n",
    "        \"UpBlock2D\",  # a regular ResNet upsampling block\n",
    "        \"AttnUpBlock2D\",  # a ResNet upsampling block with spatial self-attention\n",
    "        \"UpBlock2D\",\n",
    "        \"UpBlock2D\",\n",
    "        \"UpBlock2D\",\n",
    "        \"UpBlock2D\",\n",
    "    ),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 77,
   "id": "8a703029",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Input shape: torch.Size([1, 1, 32, 32])\n",
      "Output shape: torch.Size([1, 1, 32, 32])\n"
     ]
    }
   ],
   "source": [
    "# Sanity check: one forward pass on a single sample; no gradients are needed.\n",
    "sample_image = dataset[0].unsqueeze(0)  # add batch dimension\n",
    "print(\"Input shape:\", sample_image.shape)\n",
    "with torch.no_grad():\n",
    "    sample_output = model(sample_image, timestep=0).sample\n",
    "print(\"Output shape:\", sample_output.shape)\n",
    "assert sample_output.shape == sample_image.shape, \"UNet output shape must match its input shape\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d3d9af1",
   "metadata": {},
   "source": [
    "**Create a scheduler**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "id": "22c2ce21",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAfcAAAGdCAYAAAAPGjobAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjcsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvTLEjVAAAAAlwSFlzAAAPYQAAD2EBqD+naQAAO6BJREFUeJzt3XtwVfW5//FPCCQh5EZIyAUChJuAXKxRMdpalJSAHUeUcbR1RrQWRw90qvQmHS9V25PWniraUjhnWqFOS72cqXjsBdQoWG2QgnKQiylgIOGScM2V3EjW7w8P+2d0Qb5PsmOyNu/XzJohO0+e/d17rb0fvt+99rOiPM/zBAAAIka/3h4AAAAIL4o7AAARhuIOAECEobgDABBhKO4AAEQYijsAABGG4g4AQIShuAMAEGH69/YAPq29vV2HDh1SYmKioqKiens4AAAjz/NUV1en7Oxs9evXc3PIpqYmtbS0dDtPTEyM4uLiwjCivqPPFfdDhw4pJyent4cBAOimiooKDR8+vEdyNzU1KTc3V5WVld3OlZmZqbKysogq8D1W3JctW6af//znqqys1LRp0/TLX/5Sl112Wad/l5iYKEl6/fXXNWjQIKf7On78uPO4Ro4c6RwrSUePHnWOTU1NNeUeNmyYc+y//vUvU+6ysjLn2IsuusiU+8MPPzTFZ2dnO8dWV1ebco8bN845tr6+3pS7rq7OOba9vd2Uu7W11RSfl5fnHLt7925T7tjY2B7LnZaW5hw7YMAAU+62tjbn2NOnT5ty9+9ve2u0vPZjYmJMuf/5z386x06bNs2U23KMDxkyxDm2vr5eeXl5offzntDS0qLKykqVl5crKSmpy3lqa2s1YsQItbS0UNw78/zzz2vx4sVasWKFpk+frqVLl6qwsFClpaUaOnToOf/2zFL8oEGDlJCQ4HR/TU1NzmOzHmyNjY09lttyQLo+F2fEx8c7x1rH7fqfrjMsY7cWPcvYrR/zWC67YC3u1qXEnjxWLG9oluPKOpaeLO7W48o6FstxaC3uPflathzjXSnUn8dHq0lJSd0q7pGqRz4MeeKJJ7RgwQLdcccdmjRpklasWKH4+Hg988wzPXF3AIDzlOd53d4iUdiLe0tLi7Zs2aKCgoL/fyf9+qmgoEAlJSWfiW9ublZtbW2HDQAAFxR3f2Ev7seOHVNbW5syMjI63J6RkeF74kNRUZGSk5NDGyfTAQBcUdz99fr33JcsWaKamprQVlFR0dtDAgAg0MJ+Ql1aWpqio6NVVVXV4faqqiplZmZ+Jj42NtZ0ti4AAGd0d/bNzN1RTEyM8vLyVFxcHLqtvb1dxcXFys/PD/fdAQDOYyzL++uRr8ItXrxY8+fP1yWXXKLLLrtMS5cuVUNDg+64446euDsAAPAJPVLcb775Zh09elQPPfSQKisrddFFF2nt2rWfOckOAIDuYFneX491qFu0aJEWLVrU5b+3NLGxdB
6zdLOTbF2Zampqeiz3iRMnTLktnaq2bdtmym3pOCfZuppZG+QcPnzYOdbaQdDS+MT6LQ/rsWLZR6NGjeqxsYwePdqU29IadMKECabcBw8edI61NrGxHuOWzorWLoyWJjbWDmulpaXOsadOnXKObWhoMI2jOyju/nr9bHkAABBefe7CMQAAuGLm7o/iDgAILIq7P5blAQCIMMzcAQCBxczdH8UdABBYFHd/FHcAQGBR3P3xmTsAABGGmTsAILCYufujuAMAAovi7q/PFvddu3Y5t12cPn26c97GxkbTOCw7PikpyZT72LFjzrGWFrvWeGtL0cTERFN8//7uh9nIkSNNufft2+cca2lXah3L6dOnTbkHDBhgio+OjnaOXbdunSm3pV1tcnKyKffQoUOdY5ubm025Bw8e7Bxrvaz09u3bTfGW53D8+PGm3Js2bXKOtbweJNtz2NbW5hwbqQUzSPpscQcAoDPM3P1R3AEAgUVx98fZ8gAARBhm7gCAwGLm7o/iDgAItEgt0N3BsjwAABGGmTsAILBYlvdHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/vpscc/KylJCQoJTrKWP+vHjx03jyM7Odo6trq425T516pRz7JgxY0y5W1panGPr6upMuV17/p9h6UXf2tpqyp2VleUce/jwYVPuDz74wDn2wgsvNOW2PoeWsY8bN86UOz093Tm2tLTUlNv1NSxJBw4cMOW2vDatvfytfe4t1y2w7nuLuLg4U/yQIUOcYy3XT7C+p3QHxd0fy/IAAESYPjtzBwCgM8zc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f322uCckJDi3rrS0LB02bJhpHJaWsj3ZUnTkyJGm3Pv27XOO7d/fdhi8+uqrpvhZs2Y5x27evNmUe/z48c6xu3btMuVOS0tzjm1razPltj7nFtZj5dChQ86xOTk5ptyWN86xY8eacldVVTnHWtvJDh8+3BTf1NTkHGttg2xpD5ybm2vKbWn5a2lt297ebhoHwq/PFncAADrDzN0fxR0AEFgUd398FQ4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACLVILdHewLA8AQIRh5g4ACCyW5f1R3AEAgUVx99eni3tUVJRT3ODBg51ztrS0mMZg6S++e/duU+7k5GTnWGtv7NTUVOdY63Ny7bXXmuI/+OAD59hLLrnElNvSp3vmzJmm3JZ+4Tt27DDltl7joKGhwTl248aNptyWnu7WN0JLz/3nn3/elHvu3LnOsTt37jTlrqmpMcVb+q5bX29XX321c6zlmhKS7biy7MtTp06ZxtEdFHd/fOYOAECECXtx/9GPfqSoqKgO24QJE8J9NwAAhGbu3dkiUY/M3C+88EIdPnw4tL399ts9cTcAgPNcbxT3t956S9ddd52ys7MVFRWlNWvWdPo369ev18UXX6zY2FiNHTtWq1atsj9Ygx4p7v3791dmZmZos3xuDQBAX9bQ0KBp06Zp2bJlTvFlZWX66le/qquvvlpbt27Vvffeq29+85tat25dj42xR06o2717t7KzsxUXF6f8/HwVFRVpxIgRvrHNzc0dTharra3tiSEBACJQb5xQN2fOHM2ZM8c5fsWKFcrNzdUvfvELSdLEiRP19ttv68knn1RhYaH5/l2EfeY+ffp0rVq1SmvXrtXy5ctVVlamL33pS6qrq/ONLyoqUnJycmjLyckJ95AAABEqXMvytbW1HTbrN5TOpaSkRAUFBR1uKywsVElJSdju49PCXtznzJmjm266SVOnTlVhYaH++te/qrq6Wi+88IJv/JIlS1RTUxPaKioqwj0kAADOKScnp8NEs6ioKGy5KysrlZGR0eG2jIwM1dbWqrGxMWz380k9/j33lJQUjR8/Xnv27PH9fWxsrGJjY3t6GA
CACBSuZfmKigolJSWFbg96Xerx77nX19dr7969ysrK6um7AgCcZ8K1LJ+UlNRhC2dxz8zMVFVVVYfbqqqqlJSUpIEDB4btfj4p7MX9u9/9rjZs2KB9+/bpH//4h2644QZFR0fra1/7WrjvCgCAPi8/P1/FxcUdbnvttdeUn5/fY/cZ9mX5AwcO6Gtf+5qOHz+u9PR0ffGLX9TGjRuVnp5uyrN9+3bFx8c7xebm5jrntcRK0uuvv+4cO3ToUFPu6Oho59i9e/eacl955ZXOse+//74pt7W1pOUkyU8ui7mwPC/t7e2m3Pv373eOnTx5sim3tV3t2b5t4sf6WrM858ePHzflbm1tdY6dN2+eKbflOJw0aZIp98GDB03x/fq5z5Msr3tJZz0Z2Y+lpbX08WfBrixtcy1tbburN86Wr6+v7/BRc1lZmbZu3arU1FSNGDFCS5Ys0cGDB/Xss89Kku6++2796le/0ve//3194xvf0BtvvKEXXnhBf/nLX7o87s6Evbg/99xz4U4JAICv3ijumzdv7tDzf/HixZKk+fPna9WqVTp8+LDKy8tDv8/NzdVf/vIX3XfffXrqqac0fPhw/eY3v+mxr8FJffzCMQAAnEtvFPcZM2ac8+/8us/NmDHDvEraHVw4BgCACMPMHQAQWFzy1R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/FHcAQKBFaoHujj5b3CdPnqzExESn2JMnTzrn3bRpk2kcn2wxGM5xSNLp06edY0tLS025LT3ArT3XExISTPGW/C+//LIpt6Wnu/UqT5ae64cPHzblHjdunCk+Jiamx8ZiuSZCc3OzKbelj7r1tXnBBRc4x1qPcWuPdkvfdcvrXpLzNTYknfXS2mdjef1YckdFRZnGgfDrs8UdAIDOsCzvj+IOAAgsirs/ijsAILAo7v74KhwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEmD47c29qalL//m7Da2xsdM6blpZmGscrr7ziHPuFL3zBlDs7O9s51tIiVLL9b7S2ttaU2/ocWlpz5ufnm3K3trY6x+7du9eUOyMjwzl2wIABptxNTU2m+L///e/OsdbjsKamxjn21KlTptyW53zKlCmm3A0NDc6xJ06cMOUeM2aMKd71vUqSDh48aMptaT1t3T87duxwjrW0HraOozuYufvrs8UdAIDOUNz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPujuAMAAovi7o/P3AEAiDDM3AEAgcXM3R/FHQAQWBR3fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7++mxxb2lpce5JPnXqVOe8JSUlpnGMHDnSOTYlJcWU+8iRI86xlj7nkq3ffmJioim3lSW/te/2iBEjnGOtz6FFeXm5Kf706dOm+KysLOfYw4cPm3Jb+uJb3whzcnJ6LLfFoEGDTPHW/Wm53oK19/+uXbucY0ePHm3KbTkOLe9X0dHRpnEg/PpscQcAoDPM3P1R3AEAgRapBbo7OFseAIAIw8wdABBYLMv7o7gDAAKL4u6P4g4ACCyKuz8+cwcAIMIwcwcABBYzd38UdwBAYFHc/bEsDwBAhGHmDgAILGbu/vpscT958qRzb/n6+nrnvMOGDTONIzk52Tm2tbXVlPvYsWPOsXV1dabceXl5zrEfffSRKXd8fLwpvrq62jn2wIEDptyxsbHOsW1tbabcO3bscI619vTu39/20ouLi3OOtR6Hlv1pHbflebH2xLf0i1+3bp0pd1JSkinesn+s10+IiYlxjrW8p0hSQ0ODc6zlOg6W9+Tuorj7Y1keAIAIYy7ub731lq677jplZ2crKipKa9as6fB7z/P00EMPKSsrSwMHDlRBQYF2794drvECABByZubenS0SmYt7Q0ODpk2bpmXLlvn+/vHHH9fTTz+tFStW6N1339WgQYNUWF
iopqambg8WAIBPorj7M3/mPmfOHM2ZM8f3d57naenSpXrggQd0/fXXS5KeffZZZWRkaM2aNbrlllu6N1oAANCpsH7mXlZWpsrKShUUFIRuS05O1vTp01VSUuL7N83Nzaqtre2wAQDggpm7v7AW98rKSklSRkZGh9szMjJCv/u0oqIiJScnh7acnJxwDgkAEMEo7v56/Wz5JUuWqKamJrRVVFT09pAAAAFBcfcX1uKemZkpSaqqqupwe1VVVeh3nxYbG6ukpKQOGwAA6LqwFvfc3FxlZmaquLg4dFttba3effdd5efnh/OuAABg5n4W5rPl6+vrtWfPntDPZWVl2rp1q1JTUzVixAjde++9+vGPf6xx48YpNzdXDz74oLKzszV37txwjhsAADrUnYW5uG/evFlXX3116OfFixdLkubPn69Vq1bp+9//vhoaGnTXXXepurpaX/ziF7V27VpTe0ZJio6OVnR0tFOsa5taSerXz7ZY8cEHHzjHWluQWlrh1tTUmHJbWopOnjzZlPvo0aOm+LS0NOfYQ4cOmXJHRUU5x1pfxJaTOy0tQiXbuKWPv1Xiytqq+MSJE86xQ4YMMeV+8803nWPHjh1rym1pjnXNNdeYclufQ0t8YmKiKbflNWF9f8vKyjLFIzjMxX3GjBnnfJOMiorSo48+qkcffbRbAwMAoDPM3P312QvHAADQGYq7v17/KhwAAAgvZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgRaps+/uoLgDAAKLZXl/LMsDAAKrt9rPLlu2TKNGjVJcXJymT5+uTZs2nTV21apVioqK6rBZG7tZUdwBADB4/vnntXjxYj388MN67733NG3aNBUWFurIkSNn/ZukpCQdPnw4tO3fv79Hx0hxBwAEVm/M3J944gktWLBAd9xxhyZNmqQVK1YoPj5ezzzzzFn/JioqSpmZmaEtIyOjOw+7U332M/eEhAQlJCQ4xaanpzvntfTRlmz94hsbG025XXvnS/ZxW5SWlpriY2NjTfENDQ3OsW1tbabclh7tlv7s1tzWfvvWqyRu27bNOXbw4MGm3Jbj1tqL3NLr3Pr6SU1NdY6trKw05bY6fvy4c+yAAQNMuVNSUpxje/LaGcnJyc6xltd8d4XrM/fa2toOt8fGxvq+17W0tGjLli1asmRJ6LZ+/fqpoKBAJSUlZ72f+vp6jRw5Uu3t7br44ov17//+77rwwgu7PO7OMHMHAJz3cnJylJycHNqKiop8444dO6a2trbPzLwzMjLO+p/ICy64QM8884xefvll/f73v1d7e7uuuOIKHThwIOyP44w+O3MHAKAz4Zq5V1RUKCkpKXS7dYXyXPLz8zus1l1xxRWaOHGi/vM//1OPPfZY2O7nkyjuAIDACldxT0pK6lDczyYtLU3R0dGqqqrqcHtVVZUyMzOd7nPAgAH6whe+oD179tgH7IhleQAAHMXExCgvL0/FxcWh29rb21VcXOx8Lk1bW5s++OAD8zksFszcAQCB1RtNbBYvXqz58+frkksu0WWXXaalS5eqoaFBd9xxhyTptttu07Bhw0Kf2z/66KO6/PLLNXbsWFVXV+vnP/+59u/fr29+85tdHndnKO4AgMDqjeJ+88036+jRo3rooYdUWVmpiy66SGvXrg2dZFdeXt7hmwsnT57UggULVFlZqcGDBysvL0//+Mc/NGnSpC6PuzMUdwBAYPVW+9lFixZp0aJFvr9bv359h5+ffPJJPfnkk126n67iM3cAACIMM3cAQGBx4Rh/FHcAQGBR3P312eK+c+dODRw40Cn22muvdc5rbW86btw459jt27ebch86dMg51trO8e2333aOHTZsmCn3p7/f2Zn6+voeG0tra6tzrGs74zMs7U0trTmljxtmWFha4Vrbm1r2z8mTJ025La1wa2pqTLktr2Vr++aYmBhT/KBBg5xjre9Bn26LGq5xSNKUKVOcY6urq51jLa210TP6bHEHAKAzzN
z9UdwBAIFFcffH2fIAAEQYZu4AgMBi5u6P4g4ACCyKuz+W5QEAiDDM3AEAgcXM3R/FHQAQWBR3fxR3AECgRWqB7g4+cwcAIMIwcwcABBbL8v76bHGfMGGCc59kS991a8/ozZs3O8fGx8ebcqenpzvHpqWlmXK3tLQ4x/brZ1vASUpKMsVPnjzZOXbfvn2m3JY+3f372w73gwcPOsfu3r3blDs2NtYUb+n/bomVpKysLOdY6xvhgQMHnGOPHz9uyt3e3u4ce/nll5tyv/POO6Z4y2vfehzGxcU5xzY2NppyW45Dy7g/z97yFHd/LMsDABBh+uzMHQCAzjBz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/luUBAIgwzNwBAIHFzN0fxR0AEFgUd38UdwBAYFHc/fXZ4h4VFeXcFrW5udk5ryVWsu14a/vMMWPGOMeeOHHClNvSmtNq/Pjxpvh//etfzrGuLYfPsLbytLC0/D169Kgpd21trSl+4sSJzrGWlq+Src2utcXywIEDnWOtbZAzMjKcY63tga0tli1traOioky5U1JSnGNPnz5tym1hOWZPnTrVY+OAmz5b3AEA6Awzd38UdwBAYFHc/Zm/CvfWW2/puuuuU3Z2tqKiorRmzZoOv7/99tsVFRXVYZs9e3a4xgsAADphnrk3NDRo2rRp+sY3vqEbb7zRN2b27NlauXJl6Gfr5S0BAHDBzN2fubjPmTNHc+bMOWdMbGysMjMzuzwoAABcUNz99UiHuvXr12vo0KG64IILdM8995zzLPLm5mbV1tZ22AAAQNeFvbjPnj1bzz77rIqLi/Wzn/1MGzZs0Jw5c9TW1uYbX1RUpOTk5NCWk5MT7iEBACLUmZl7d7ZIFPaz5W+55ZbQv6dMmaKpU6dqzJgxWr9+vWbOnPmZ+CVLlmjx4sWhn2traynwAAAnLMv76/ELx4wePVppaWnas2eP7+9jY2OVlJTUYQMAAF3X499zP3DggI4fP66srKyevisAwHmGmbs/c3Gvr6/vMAsvKyvT1q1blZqaqtTUVD3yyCOaN2+eMjMztXfvXn3/+9/X2LFjVVhYGNaBAwBAcfdnLu6bN2/W1VdfHfr5zOfl8+fP1/Lly7Vt2zb97ne/U3V1tbKzszVr1iw99thj5u+6Hzt2zLk/cUNDg3Pe6Oho0zhSU1OdY62rE5Y+0OXl5abclq8iNjY2mnJbnm9JGjJkiHOstVe85Xmx5j7bSaB+rG8Q1h7gpaWlzrGVlZWm3JZrBVhfx5b9M3r0aFNuS098a9966/686KKLnGP3799vyl1WVuYcm5CQYMptOa5uuOEG59i6ujrTOLorUgt0d5iL+4wZM875RK5bt65bAwIAAN1Db3kAQGCxLO+P4g4ACCyKu78e/yocAAD4fDFzBwAEFjN3fxR3AEBgUdz9sSwPAECEYeYOAAgsZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P312eI+aNAgDRo0yCk2Li7OOW98fLxpHJZ+19ae6zt37nSOtV7jfuDAgc6x1oN7wIABpnhLf/H09HRTbkuP9mHDhply/+///q9z7Fe+8hVT7rNdAvlsLH3urX29Xa/hIEl79+415bZcV2DLli2m3JZrORw/ftyUe9SoUab44uJi51jrcWh5DjMyMky5m5ubnWOPHDniHFtfX28aR3dQ3P312eIOAEBnKO7++MwdAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAIgwzdwBAYDFz90dxBwAEFsXdX58t7vHx8c7tZysrK53zpqSkmMZhabdpbctqaed44sQJU+6JEyf2yDgkW7tfScrMzHSOtbRClaTY2Fjn2I8++siU23KsvPrqq6bcl19+uSne0pp19+7dptzjxo1zjh0/frwpt4W1tW3//u5vX8nJyabc1mM8LS3NOdZyzEpSv37un55a973lfcJyDFpfxw
i/PlvcAQDoDDN3fxR3AEBgUdz9cbY8ACDQzhT4rmxdtWzZMo0aNUpxcXGaPn26Nm3adM74F198URMmTFBcXJymTJmiv/71r12+bxcUdwAADJ5//nktXrxYDz/8sN577z1NmzZNhYWFZz1/6R//+Ie+9rWv6c4779T777+vuXPnau7cudq+fXuPjZHiDgAIrO7M2rs6e3/iiSe0YMEC3XHHHZo0aZJWrFih+Ph4PfPMM77xTz31lGbPnq3vfe97mjhxoh577DFdfPHF+tWvftXdh39WFHcAQGCFq7jX1tZ22Jqbm33vr6WlRVu2bFFBQUHotn79+qmgoEAlJSW+f1NSUtIhXpIKCwvPGh8OFHcAwHkvJydHycnJoa2oqMg37tixY2pra1NGRkaH2zMyMs76tezKykpTfDhwtjwAILDCdbZ8RUWFkpKSQrdb+xH0NRR3AEBghau4JyUldSjuZ5OWlqbo6GhVVVV1uL2qquqsDbsyMzNN8eHAsjwAAI5iYmKUl5en4uLi0G3t7e0qLi5Wfn6+79/k5+d3iJek11577azx4cDMHQAQWL3RxGbx4sWaP3++LrnkEl122WVaunSpGhoadMcdd0iSbrvtNg0bNiz0uf23v/1tffnLX9YvfvELffWrX9Vzzz2nzZs367/+67+6PO7O9Nni7nme2tvbnWI/faLCuVg/R9m5c6dzrGsv/DNGjRrlHFtfX2/KbemjbuldLUk1NTWm+D179jjHWpep9u/f7xxr6f8tScOHD++x3Pv27TPFp6enO8dajivJ9pzv2rXLlHvgwIHOsRdeeKEpt+U1cfr0aVPugwcPmuJbW1udY637x3KMW44TSWpsbHSOtVyDwPp+1R29UdxvvvlmHT16VA899JAqKyt10UUXae3ataFaVF5e3uF99YorrtDq1av1wAMP6Ic//KHGjRunNWvWaPLkyV0ed2f6bHEHAKCvWrRokRYtWuT7u/Xr13/mtptuukk33XRTD4/q/6O4AwACi97y/ijuAIDAorj7o7gDAAKL4u6Pr8IBABBhmLkDAAKLmbs/ijsAILAo7v5YlgcAIMIwcwcABBYzd38UdwBAYFHc/fXZ4j527FinK/RIH7f6c2VpEylJU6ZMcY5NSEgw5baMOzs725Tb8jirq6tNuaOjo03xluewtrbWlHv69OnOsda2uaWlpaZ4iy984QumeMtYrC18LS2Wy8rKTLlTU1OdY61vsnv37nWOLSgoMOU+dOiQKd7SxtV6XFneJ5qbm025LS2z6+rqnGNPnTplGgfCr88WdwAAOsPM3Z/phLqioiJdeumlSkxM1NChQzV37tzP/C+0qalJCxcu1JAhQ5SQkKB58+Z95jq2AACEw5ni3p0tEpmK+4YNG7Rw4UJt3LhRr732mlpbWzVr1iw1NDSEYu677z698sorevHFF7VhwwYdOnRIN954Y9gHDgAA/JmW5deuXdvh51WrVmno0KHasmWLrrrqKtXU1Oi3v/2tVq9erWuuuUaStHLlSk2cOFEbN27U5ZdfHr6RAwDOeyzL++vW99zPnKB05qSZLVu2qLW1tcPJKxMmTNCIESNUUlLim6O5uVm1tbUdNgAAXLAs76/Lxb29vV333nuvrrzyytAF5ysrKxUTE6OUlJQOsRkZGaqsrPTNU1RUpOTk5NCWk5PT1SEBAM5DFPbP6nJxX7hwobZv367nnnuuWwNYsmSJampqQltFRUW38gEAcL7r0lfhFi1apD//+c966623NHz48NDtmZmZamlpUXV1dYfZe1VV1Vm/exsbG6vY2NiuDAMAcJ7jM3d/ppm753latGiRXnrpJb3xxhvKzc3t8Pu8vDwNGDBAxcXFodtKS0tVXl6u/Pz88IwYAID/w2fu/kwz94ULF2r16tV6+eWXlZiYGPocPTk5WQMHDlRycrLuvPNOLV68WKmpqUpKStK3vvUt5efnc6Y8AACfE1NxX758uSRpxowZHW5fuXKlbr/9dknSk0
8+qX79+mnevHlqbm5WYWGhfv3rX4dlsAAAfBLL8v5Mxd3lSYiLi9OyZcu0bNmyLg9K+rjHeE886f37204zaG9vd449duyYKbelR7u1L7q1/7uF9Tk8ffq0c2x9fb0pt6Wn94kTJ0y5ExMTnWM/2cjJxQsvvGCKz8vLc461doRMTk52jo2Li+ux3NavwVqu5XDw4EFT7g8//NAUb+nn39bWZso9ZMgQ51jr4xw3bpxzrOU9pSfffz6N4u6P67kDABBhuHAMACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj767PFfffu3Ro0aJBT7MmTJ53zDhs2zDSOuro651hLm0hJn7mq3rns37/flLulpcU5NjU11ZTb2gr3+PHjzrGTJ0825U5LS3OO/fvf/27KbWmFm5GRYcpt1dTU5BxreT1I0uDBg51jre1nq6urnWPT09NNuS2Xit60aZMp9/jx403xFpbXpmR7Di2vByvX92P0DX22uAMA0Blm7v4o7gCAwKK4++NseQAAIgwzdwBAYDFz90dxBwAEFsXdH8vyAABEGGbuAIDAYubuj+IOAAgsirs/ijsAILAo7v74zB0AgAjDzB0AEGiROvvujj5b3OPj4xUfH+8Um5yc7Jz36NGjpnFYep1HRUWZcu/Zs8c5trGx0ZTb0nf70KFDptzWXvTt7e3OsceOHTPlPnDggHPskSNHTLmHDh3qHNvW1mbKbXlOJGnDhg3OsbNmzTLl3rFjh3Osdd9b+txb9qVk6/2fkpJiyt3a2mqKj46Odo4tKSkx5b788sudY6196xMSEpxjLa9N6/tVd7As749leQAAIkyfnbkDANAZZu7+KO4AgMCiuPtjWR4AgAjDzB0AEFjM3P1R3AEAgUVx98eyPAAAEYaZOwAgsJi5+6O4AwACi+Luj+IOAAgsiru/Plvc4+PjNWjQIKfY9PR057zDhg0zjcPScrF/f9vT6dpeV5IyMzNNuQ8ePOgcO2TIEFPuiooKU/yFF17oHLtv3z5T7tOnTzvHjhs3zpTb0mq1qqrKlNvaqnj69OnOsdb2pgMHDnSOTUxMNOWuq6tzjk1KSjLlbmpqco5taGgw5ba2qba8B02ZMsWUu7a21hRv8d577znHWlpaf57tZ+GvzxZ3AAA6w8zdH8UdABBYFHd/fBUOAIAIw8wdABBYzNz9UdwBAIFFcffHsjwAABGGmTsAILCYufujuAMAAovi7o9leQAAesCJEyd06623KikpSSkpKbrzzjtVX19/zr+ZMWOGoqKiOmx33323+b6ZuQMAAqsvz9xvvfVWHT58WK+99ppaW1t1xx136K677tLq1avP+XcLFizQo48+GvrZ0s30DIo7ACCw+mpx37Vrl9auXat//vOfuuSSSyRJv/zlL3XttdfqP/7jP5SdnX3Wv42Pjze3HP+0Plvc+/fv79yrfcCAAc55d+zYYRpHdHS0c6y1R7uFpQe0JMXFxTnHWsdt7dNtGXtubq4pt6WH/u7du025R48ebYq3qKmpMcVb/uduOWYlW597a0/81tZW59gjR46Ycre0tDjHWl4Pkq1XvCSVl5c7x1p6+UvS4MGDnWN37txpyp2fn+8ca+m3b9nv3RWu4v7pHv6xsbGmfvqfVlJSopSUlFBhl6SCggL169dP7777rm644Yaz/u0f/vAH/f73v1dmZqauu+46Pfjgg+bZe58t7gAAfF5ycnI6/Pzwww/rRz/6UZfzVVZWaujQoR1u69+/v1JTU1VZWXnWv/v617+ukSNHKjs7W9u2bdMPfvADlZaW6k9/+pPp/k0n1BUVFenSSy9VYmKihg4dqrlz56q0tLRDTLhOBgAAwMWZ2XtXtjMqKipUU1MT2pYsWeJ7X/fff/9natyntw8//LDLj+Wuu+5SYWGhpkyZoltvvVXPPvusXnrpJe3du9eUxzRz37Bhgx
YuXKhLL71Up0+f1g9/+EPNmjVLO3fu7HB51nCcDAAAQGfCtSyflJTkdNnh73znO7r99tvPGTN69GhlZmZ+5qOm06dP68SJE6bP089c7nnPnj0aM2aM89+ZivvatWs7/Lxq1SoNHTpUW7Zs0VVXXRW6PRwnAwAA0Nekp6c7nZORn5+v6upqbdmyRXl5eZKkN954Q+3t7aGC7WLr1q2SpKysLNM4u/U99zMnBaWmpna4/Q9/+IPS0tI0efJkLVmyRKdOnTprjubmZtXW1nbYAABw0Z0l+e7O+s9l4sSJmj17thYsWKBNmzbpnXfe0aJFi3TLLbeEzpQ/ePCgJkyYoE2bNkmS9u7dq8cee0xbtmzRvn379D//8z+67bbbdNVVV2nq1Kmm++/yCXXt7e269957deWVV2ry5Mmh260nAxQVFemRRx7p6jAAAOexvvpVOOnjie6iRYs0c+ZM9evXT/PmzdPTTz8d+n1ra6tKS0tDE+CYmBi9/vrrWrp0qRoaGpSTk6N58+bpgQceMN93l4v7woULtX37dr399tsdbr/rrrtC/54yZYqysrI0c+ZM7d271/fzgiVLlmjx4sWhn2traz9z1iIAAEGTmpp6zoY1o0aN6vCfi5ycHG3YsCEs992l4r5o0SL9+c9/1ltvvaXhw4efM7azkwG6+11CAMD5qy/P3HuTqbh7nqdvfetbeumll7R+/XqnhiNdPRkAAIDOUNz9mYr7woULtXr1ar388stKTEwMfRE/OTlZAwcO1N69e7V69Wpde+21GjJkiLZt26b77ruvSycDAACArjEV9+XLl0v6uFHNJ61cuVK33357WE8GAACgM8zc/ZmX5c8lnCcDVFRUODe/sfRqvuaaa0zjeP/9951j6+rqTLktB1V7e7spt6Uv+smTJ025J0yYYIrfs2ePc6y1J/Unmyd1Ztq0aabc1dXVPRIrSU1NTab45uZm51hLD3Dp46/suDp06JApt6XnurX/+0cffeQca2n+IUnHjx83xSckJDjHVlVVmXJbji3ra9PCpcHLGa7XBQkHirs/essDAAKL4u6vW01sAABA38PMHQAQWMzc/VHcAQCBRXH3x7I8AAARhpk7ACCwmLn7o7gDAAKL4u6PZXkAACIMM3cAQGAxc/dHcQcABBbF3V+fLe7JycnOrUUbGhqc827fvt00jtOnTzvH9utn+5Rj7NixzrGWFqGStH//fufYSZMmmXJbWopKUnp6unOstYVvYmKic2xxcbEpd2ZmpnPsqVOnTLmtx+GoUaOcY62XUH7nnXecYy2vB0mKiopyjrW28LW0OLW8HiTpww8/NMVfeOGFzrFf+tKXTLnr6+udY2NiYky5LW12U1JSnGM/z/az8MceAAAEFjN3fxR3AEBgUdz9UdwBAIFFcffHV+EAAIgwzNwBAIEWqbPv7qC4AwACi2V5fyzLAwAQYZi5AwACi5m7P4o7ACCwKO7+WJYHACDCMHMHAAQWM3d/fba4t7e3q7293Sl26NChznmTk5NN4ygtLXWObWlpMeWuqqpyjt2xY4cpd1ZWlnPsunXrTLnz8vJM8U1NTc6xgwcPNuX+6KOPnGOPHj1qyv3GG284x+bk5JhyR0dHm+Itvegt1yyQpOzsbOfYX/3qV6bco0ePdo5ta2sz5W5tbe2x3PHx8ab4kydPOsc2NzebctfW1jrHWp4TSTpy5Ihz7MiRI51jP8/e8hR3fyzLAwAQYfrszB0AgM4wc/dHcQcABBbF3R/FHQAQWBR3f3zmDgBAhGHmDgAILGbu/ijuAIDAorj7Y1keAIAIw8wdABBYzNz9UdwBAIFFcffXZ4v7yJEjlZiY6BRraft64MAB0zgsbSutrW0t487MzDTlrq+vd44dNmyYKfeJEydM8YcOHXKOtbQSlqTKykrn2FdffdWU29La9pJLLjHltrTytLK0K5VsLZaHDBliyt3Q0OAce/jwYVNuS4vlgwcPmnJbWvJK0vDhw51jra2HLS
2ZrW1fLfvTUgQjtWAGSZ8t7gAAdIaZuz+KOwAgsCju/jhbHgCACMPMHQAQWMzc/VHcAQCBRXH3R3EHAAQWxd0fn7kDABBhmLkDAAItUmff3UFxBwAEVncLe6T+x4BleQAAIgwzdwBAYDFz99dni3tMTIxiYmKcYi39xadOnWoax65du5xjLX20JVvPaGvfbdfnTpISEhJMuY8ePWqKHzNmjHNseXm5KXdVVZVzbGFhoSn37373O+fY6upqU+6BAwea4o8dO+Yca30OLf3Frdc4aGxsdI7Nyckx5bYct3v27DHlnjJliine8vo8fvy4Kfe0adOcY0+ePGnKPWjQIOdYy3Fl2e/dRXH3x7I8AAARxlTcly9frqlTpyopKUlJSUnKz8/X3/72t9Dvm5qatHDhQg0ZMkQJCQmaN2+eaWYFAIDFme+5d2eLRKbiPnz4cP30pz/Vli1btHnzZl1zzTW6/vrrtWPHDknSfffdp1deeUUvvviiNmzYoEOHDunGG2/skYEDAEBx92f6zP26667r8PNPfvITLV++XBs3btTw4cP129/+VqtXr9Y111wjSVq5cqUmTpyojRs36vLLLw/fqAEAwFl1+TP3trY2Pffcc2poaFB+fr62bNmi1tZWFRQUhGImTJigESNGqKSk5Kx5mpubVVtb22EDAMAFM3d/5uL+wQcfKCEhQbGxsbr77rv10ksvadKkSaqsrFRMTIxSUlI6xGdkZJzzbPaioiIlJyeHNusZswCA8xfF3Z+5uF9wwQXaunWr3n33Xd1zzz2aP3++du7c2eUBLFmyRDU1NaGtoqKiy7kAAOcXirs/8/fcY2JiNHbsWElSXl6e/vnPf+qpp57SzTffrJaWFlVXV3eYvVdVVZ3zu7GxsbGKjY21jxwAAPjq9vfc29vb1dzcrLy8PA0YMEDFxcWh35WWlqq8vFz5+fndvRsAAD6Dmbs/08x9yZIlmjNnjkaMGKG6ujqtXr1a69ev17p165ScnKw777xTixcvVmpqqpKSkvStb31L+fn5nCkPAOgRdKjzZyruR44c0W233abDhw8rOTlZU6dO1bp16/SVr3xFkvTkk0+qX79+mjdvnpqbm1VYWKhf//rXXRpYXV2dc6yl9ePevXtN45g1a5ZzrPXcA0t8VlaWKfeJEyecYy3tRyVpwIABpnhLS8y4uDhT7ry8POfYrVu3mnLfcMMNzrGWFsiSlJSUZIpPTU11jrW2E7a81nJzc025LW15refbWHLPnz/flLutrc0Uf9VVVznHWlvb7tu3zznW8nqQbO8T/fq5L/RaYtEzTMX9t7/97Tl/HxcXp2XLlmnZsmXdGhQAAC6YufvrsxeOAQCgMxR3f6ydAADQA37yk5/oiiuuUHx8/Gd6wJyN53l66KGHlJWVpYEDB6qgoEC7d+823zfFHQAQWH35bPmWlhbddNNNuueee5z/5vHHH9fTTz+tFStW6N1339WgQYNUWFiopqYm032zLA8ACKy+vCz/yCOPSJJWrVrlPJalS5fqgQce0PXXXy9JevbZZ5WRkaE1a9bolltucb5vZu4AgPPep69x0tzc/LmPoaysTJWVlR2u0ZKcnKzp06ef8xotfijuAIDACteyfE5OTofrnBQVFX3uj+XMV2ozMjI63N7ZNVr8sCwPAAiscC3LV1RUdOg/cba26Pfff79+9rOfnTPnrl27NGHChG6Nq7so7gCAwApXcU9KSnJqLvWd73xHt99++zljRo8e3aWxnLkOS1VVVYfGZVVVVbroootMuSjuAAA4Sk9PV3p6eo/kzs3NVWZmpoqLi0PFvLa2NnQVVos+V9zP/C+qvr7e+W8aGhqcYxsbG03jqa2tdY61jFmSTp061WO5Lc+Jpf1oT4/FehJLS0uLc6z1qySW3K2trabc1vj29nbnWMu4uzKWnsptbfnak8+JdSyW49b6HmQ5bi3vKdaxREVFmfN+Xg1i+mojmvLycp04cULl5eVqa2sLtcAeO3
ZsqE30hAkTVFRUpBtuuEFRUVG699579eMf/1jjxo1Tbm6uHnzwQWVnZ2vu3Lm2O/f6mIqKCk8SGxsbG1vAt4qKih6rFY2NjV5mZmZYxpmZmek1NjaGfYzz58/3vb8333wzFCPJW7lyZejn9vZ278EHH/QyMjK82NhYb+bMmV5paan5vqP+L3mf0d7erkOHDikxMbHD/xRra2uVk5PzmZMeIg2PM3KcD49R4nFGmnA8Ts/zVFdXp+zs7B69iExTU5N5VcZPTEyM+aJVfV2fW5bv16+fhg8fftbfu570EHQ8zshxPjxGiccZabr7OJOTk8M4Gn9xcXERV5TDhe+5AwAQYSjuAABEmMAU99jYWD388MNnbSwQKXickeN8eIwSjzPSnC+PM9L1uRPqAABA9wRm5g4AANxQ3AEAiDAUdwAAIgzFHQCACBOY4r5s2TKNGjVKcXFxmj59ujZt2tTbQwqrH/3oR4qKiuqw9fYlA7vrrbfe0nXXXafs7GxFRUVpzZo1HX7veZ4eeughZWVlaeDAgSooKNDu3bt7Z7Dd0NnjvP322z+zb2fPnt07g+2ioqIiXXrppUpMTNTQoUM1d+5clZaWdohpamrSwoULNWTIECUkJGjevHmqqqrqpRF3jcvjnDFjxmf25913391LI+6a5cuXa+rUqaFGNfn5+frb3/4W+n0k7MvzXSCK+/PPP6/Fixfr4Ycf1nvvvadp06apsLBQR44c6e2hhdWFF16ow4cPh7a33367t4fULQ0NDZo2bZqWLVvm+/vHH39cTz/9tFasWKF3331XgwYNUmFhofkCL72ts8cpSbNnz+6wb//4xz9+jiPsvg0bNmjhwoXauHGjXnvtNbW2tmrWrFkdLgp033336ZVXXtGLL76oDRs26NChQ7rxxht7cdR2Lo9TkhYsWNBhfz7++OO9NOKuGT58uH76059qy5Yt2rx5s6655hpdf/312rFjh6TI2JfnvW51xf+cXHbZZd7ChQtDP7e1tXnZ2dleUVFRL44qvB5++GFv2rRpvT2MHiPJe+mll0I/t7e3e5mZmd7Pf/7z0G3V1dVebGys98c//rEXRhgen36cnvfxxSOuv/76XhlPTzly5IgnyduwYYPneR/vuwEDBngvvvhiKGbXrl2eJK+kpKS3htltn36cnud5X/7yl71vf/vbvTeoHjJ48GDvN7/5TcTuy/NNn5+5t7S0aMuWLSooKAjd1q9fPxUUFKikpKQXRxZ+u3fvVnZ2tkaPHq1bb71V5eXlvT2kHlNWVqbKysoO+zU5OVnTp0+PuP0qSevXr9fQoUN1wQUX6J577tHx48d7e0jdUlNTI0lKTU2VJG3ZskWtra0d9ueECRM0YsSIQO/PTz/OM/7whz8oLS1NkydP1pIlS8yXWu1L2tra9Nxzz6mhoUH5+fkRuy/PN33uwjGfduzYMbW1tSkjI6PD7RkZGfrwww97aVThN336dK1atUoXXHCBDh8+rEceeURf+tKXtH37diUmJvb28MKusrJSknz365nfRYrZs2frxhtvVG5urvbu3asf/vCHmjNnjkpKShQdHd3bwzNrb2/XvffeqyuvvFKTJ0+W9PH+jImJUUpKSofYIO9Pv8cpSV//+tc1cuRIZWdna9u2bfrBD36g0tJS/elPf+rF0dp98MEHys/PV1NTkxISEvTSSy9p0qRJ2rp1a8Tty/NRny/u54s5c+aE/j116lRNnz5dI0eO1AsvvKA777yzF0eG7rrllltC/54yZYqmTp2qMWPGaP369Zo5c2YvjqxrFi5cqO3btwf+nJDOnO1x3nXXXaF/T5kyRVlZWZo5c6b27t2rMWPGfN7D7LILLrhAW7duVU1Njf77v/9b8+fP14YNG3p7WAiTPr8sn5aWpujo6M+cqVlVVaXMzMxeGlXPS0lJ0fjx47Vnz57eHkqPOLPvzrf9KkmjR49WWlpaIPftokWL9Oc//1lvvvlmh0szZ2ZmqqWlRdXV1R3ig7
o/z/Y4/UyfPl2SArc/Y2JiNHbsWOXl5amoqEjTpk3TU089FXH78nzV54t7TEyM8vLyVFxcHLqtvb1dxcXFys/P78WR9az6+nrt3btXWVlZvT2UHpGbm6vMzMwO+7W2tlbvvvtuRO9XSTpw4ICOHz8eqH3reZ4WLVqkl156SW+88YZyc3M7/D4vL08DBgzosD9LS0tVXl4eqP3Z2eP0s3XrVkkK1P70097erubm5ojZl+e93j6jz8Vzzz3nxcbGeqtWrfJ27tzp3XXXXV5KSopXWVnZ20MLm+985zve+vXrvbKyMu+dd97xCgoKvLS0NO/IkSO9PbQuq6ur895//33v/fff9yR5TzzxhPf+++97+/fv9zzP83760596KSkp3ssvv+xt27bNu/76673c3FyvsbGxl0duc67HWVdX5333u9/1SkpKvLKyMu/111/3Lr74Ym/cuHFeU1NTbw/d2T333OMlJyd769ev9w4fPhzaTp06FYq5++67vREjRnhvvPGGt3nzZi8/P9/Lz8/vxVHbdfY49+zZ4z366KPe5s2bvbKyMu/ll1/2Ro8e7V111VW9PHKb+++/39uwYYNXVlbmbdu2zbv//vu9qKgo79VXX/U8LzL25fkuEMXd8zzvl7/8pTdixAgvJibGu+yyy7yNGzf29pDC6uabb/aysrK8mJgYb9iwYd7NN9/s7dmzp7eH1S1vvvmmJ+kz2/z58z3P+/jrcA8++KCXkZHhxcbGejNnzvRKS0t7d9BdcK7HeerUKW/WrFleenq6N2DAAG/kyJHeggULAvcfU7/HJ8lbuXJlKKaxsdH7t3/7N2/w4MFefHy8d8MNN3iHDx/uvUF3QWePs7y83Lvqqqu81NRULzY21hs7dqz3ve99z6upqendgRt94xvf8EaOHOnFxMR46enp3syZM0OF3fMiY1+e77jkKwAAEabPf+YOAABsKO4AAEQYijsAABGG4g4AQIShuAMAEGEo7gAARBiKOwAAEYbiDgBAhKG4AwAQYSjuAABEGIo7AAARhuIOAECE+X9eu/BjisWv5wAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import torch\n", "from PIL import Image\n", "from diffusers import DDPMScheduler\n", "\n", "# Forward-diffuse sample_image to a single fixed timestep and visualise the noisy result\n", "noise_scheduler = DDPMScheduler(num_train_timesteps=1000)\n", "noise = torch.randn(sample_image.shape)\n", "timesteps = torch.LongTensor([50]) # example timestep\n", "noisy_image = noise_scheduler.add_noise(sample_image, noise, timesteps)\n", "\n", "plt.imshow(noisy_image.detach().cpu().numpy().squeeze(), cmap='gray')\n", "plt.colorbar()\n", "plt.show()\n" ] }, { "cell_type": "code", "execution_count": 79, "id": "4e878652", "metadata": {}, "outputs": [], "source": [ "import torch.nn.functional as F\n", "\n", "# Sanity check: one forward pass of the model on the noisy sample and the resulting noise-prediction loss\n", "noise_pred = model(noisy_image, timesteps).sample\n", "loss = F.mse_loss(noise_pred, noise)" ] }, { "cell_type": "markdown", "id": "7dffb56e", "metadata": {}, "source": [ "**Train the model**" ] }, { "cell_type": "code", "execution_count": 80, "id": "62915913", "metadata": {}, "outputs": [], "source": [ "from diffusers.optimization import get_cosine_schedule_with_warmup\n", "\n", "optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate)\n", "# Warm up for lr_warmup_steps, then cosine-decay over len(train_dataloader) * num_epochs scheduler steps\n", "lr_scheduler = get_cosine_schedule_with_warmup(\n", " optimizer=optimizer,\n", " num_warmup_steps=config.lr_warmup_steps,\n", " num_training_steps=(len(train_dataloader) * config.num_epochs),\n", ")" ] }, { "cell_type": "code", "execution_count": 81, "id": "f2f3427a", "metadata": {}, "outputs": [], "source": [ "from diffusers import DDPMPipeline\n", "from diffusers.utils import make_image_grid\n", "import os\n", "\n", "def evaluate(config, epoch, pipeline):\n", " \"\"\"Sample config.eval_batch_size images with `pipeline` and save them as one grid PNG in output_dir/samples.\"\"\"\n", " # Sample some images from random noise (this is the backward diffusion process).\n", " # The default pipeline output type is `List[PIL.Image]`\n", " images = pipeline(\n", " batch_size=config.eval_batch_size,\n", " 
generator=torch.Generator(device='cpu').manual_seed(config.seed), # Use a separate torch generator to avoid rewinding the random state of the main training loop\n", " ).images\n", "\n", " # Make a grid out of the images. make_image_grid requires rows * cols == len(images), so use the\n", " # most square exact factorization instead of a hard-coded 4x4 grid (which only fits 16 images).\n", " import math\n", " n_images = len(images)\n", " rows = max(d for d in range(1, math.isqrt(n_images) + 1) if n_images % d == 0)\n", " image_grid = make_image_grid(images, rows=rows, cols=n_images // rows)\n", "\n", " # Save the images\n", " test_dir = os.path.join(config.output_dir, \"samples\")\n", " os.makedirs(test_dir, exist_ok=True)\n", " image_grid.save(f\"{test_dir}/{epoch:04d}.png\")" ] }, { "cell_type": "code", "execution_count": 85, "id": "4fe935dd", "metadata": {}, "outputs": [], "source": [ "from accelerate import Accelerator\n", "from huggingface_hub import create_repo, upload_folder\n", "from tqdm.auto import tqdm\n", "from pathlib import Path\n", "import os" ] }, { "cell_type": "code", "execution_count": 86, "id": "07df8a4c", "metadata": {}, "outputs": [], "source": [ "def train_loop(config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler):\n", " # Initialize accelerator and tensorboard logging\n", " accelerator = Accelerator(\n", " mixed_precision=config.mixed_precision,\n", " gradient_accumulation_steps=config.gradient_accumulation_steps,\n", " log_with=\"tensorboard\",\n", " project_dir=os.path.join(config.output_dir, \"logs\"),\n", " )\n", " if accelerator.is_main_process:\n", " if config.output_dir is not None:\n", " os.makedirs(config.output_dir, exist_ok=True)\n", " if config.push_to_hub:\n", " repo_id = create_repo(\n", " repo_id=config.hub_model_id or Path(config.output_dir).name, exist_ok=True\n", " ).repo_id\n", " accelerator.init_trackers(\"train_example\")\n", "\n", " # Prepare everything\n", " # There is no specific order to remember, you just need to unpack the\n", " # objects in the same order you gave them to the prepare method.\n", " model, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(\n", " model, optimizer, train_dataloader, lr_scheduler\n", " )\n", "\n", " global_step = 0\n", "\n", " # Now you train the model\n", " for 
epoch in range(config.num_epochs):\n", " progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process)\n", " progress_bar.set_description(f\"Epoch {epoch}\")\n", "\n", " for step, batch in enumerate(train_dataloader):\n", " # GravityDataset.__getitem__ returns a bare image tensor, so the collated batch is already a\n", " # 4-D (B, C, H, W) tensor; indexing it with the string key \"images\" raised IndexError.\n", " # Dict-style batches are still accepted.\n", " clean_images = batch[\"images\"] if isinstance(batch, dict) else batch\n", " # Sample noise to add to the images\n", " noise = torch.randn(clean_images.shape, device=clean_images.device)\n", " bs = clean_images.shape[0]\n", "\n", " # Sample a random timestep for each image\n", " timesteps = torch.randint(\n", " 0, noise_scheduler.config.num_train_timesteps, (bs,), device=clean_images.device,\n", " dtype=torch.int64\n", " )\n", "\n", " # Add noise to the clean images according to the noise magnitude at each timestep\n", " # (this is the forward diffusion process)\n", " noisy_images = noise_scheduler.add_noise(clean_images, noise, timesteps)\n", "\n", " with accelerator.accumulate(model):\n", " # Predict the noise residual\n", " noise_pred = model(noisy_images, timesteps, return_dict=False)[0]\n", " loss = F.mse_loss(noise_pred, noise)\n", " accelerator.backward(loss)\n", "\n", " if accelerator.sync_gradients:\n", " accelerator.clip_grad_norm_(model.parameters(), 1.0)\n", " optimizer.step()\n", " lr_scheduler.step()\n", " optimizer.zero_grad()\n", "\n", " progress_bar.update(1)\n", " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0], \"step\": global_step}\n", " progress_bar.set_postfix(**logs)\n", " accelerator.log(logs, step=global_step)\n", " global_step += 1\n", "\n", " # After each epoch you optionally sample some demo images with evaluate() and save the model\n", " if accelerator.is_main_process:\n", " pipeline = DDPMPipeline(unet=accelerator.unwrap_model(model), scheduler=noise_scheduler)\n", "\n", " if (epoch + 1) % config.save_image_epochs == 0 or epoch == config.num_epochs - 1:\n", " evaluate(config, epoch, pipeline)\n", "\n", " if (epoch + 1) % config.save_model_epochs == 0 or epoch == config.num_epochs - 1:\n",
" # Always write the pipeline to output_dir first: upload_folder only pushes files already on disk,\n", " # and previously the model was never saved at all when push_to_hub was True.\n", " pipeline.save_pretrained(config.output_dir)\n", " if config.push_to_hub:\n", " upload_folder(\n", " repo_id=repo_id,\n", " folder_path=config.output_dir,\n", " commit_message=f\"Epoch {epoch}\",\n", " ignore_patterns=[\"step_*\", \"epoch_*\"],\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "5d820f05", "metadata": {}, "outputs": [], "source": [ "from accelerate import notebook_launcher\n", "args = (config, model, noise_scheduler, optimizer, train_dataloader, lr_scheduler)\n", "notebook_launcher(train_loop, args, num_processes=1)" ] } ], "metadata": { "kernelspec": { "display_name": "my-env", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.3" } }, "nbformat": 4, "nbformat_minor": 5 }