{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "private_outputs": true,
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Sampling from a pretrained CIFAR-10 diffusion model\n",
        "\n",
        "Clones the `simple-latent-diffusion-model` repository, rebuilds a DDIM-sampled U-Net diffusion model with a condition encoder from a config file, loads the EMA weights of the `cifar10_diffusion` checkpoint, and displays the generated samples.\n",
        "\n",
        "Before running: set `CONFIG_PATH` in the configuration cell and, if needed, the checkpoint name passed to `loader.model_load`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Setup"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eUREgErj2o-Z"
      },
      "outputs": [],
      "source": [
        "!git clone https://github.com/Won-Seong/simple-latent-diffusion-model.git\n",
        "\n",
        "import os\n",
        "\n",
        "REPO_DIR = 'simple-latent-diffusion-model'  # Replace with your repository name\n",
        "os.chdir(REPO_DIR)\n",
        "# Descend into a nested folder of the same name only if one exists;\n",
        "# an unconditional second chdir raises FileNotFoundError on a flat checkout.\n",
        "if os.path.isdir(REPO_DIR):\n",
        "    os.chdir(REPO_DIR)\n",
        "\n",
        "from google.colab import drive\n",
        "drive.mount('/content/drive')\n",
        "\n",
        "from diffusion_model.models.diffusion_model import DiffusionModel\n",
        "from diffusion_model.sampler.ddim import DDIM\n",
        "from diffusion_model.sampler.ddpm import DDPM\n",
        "from diffusion_model.network.unet import Unet\n",
        "from diffusion_model.network.unet_wrapper import UnetWrapper\n",
        "from helper.painter import Painter\n",
        "from helper.trainer import Trainer\n",
        "from helper.data_generator import DataGenerator\n",
        "from helper.loader import Loader\n",
        "from helper.cond_encoder import ConditionEncoder\n",
        "import torch"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Configuration"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "IMAGE_SHAPE = (3, 32, 32)\n",
        "# Must be defined before the model cell runs; replace with the path to your config file.\n",
        "CONFIG_PATH = 'Config Path'\n",
        "N_SAMPLES = 9     # first positional argument to dm(...) -- presumably the number of images to sample\n",
        "CLASS_LABEL = 0   # conditioning label passed to dm(...) as `y`\n",
        "SEED = 0          # fixes the sampler's random noise so results are reproducible\n",
        "torch.manual_seed(SEED)\n",
        "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'); device"
      ],
      "metadata": {
        "id": "iN360_ddTmUr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Build the model"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "sampler = DDIM(CONFIG_PATH)\n",
        "cond_encoder = ConditionEncoder(CONFIG_PATH)\n",
        "network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)\n",
        "# NOTE(review): `device` is never passed to the model here; confirm that\n",
        "# Loader/DiffusionModel place the weights on the GPU themselves.\n",
        "dm = DiffusionModel(network, sampler, IMAGE_SHAPE)\n",
        "painter = Painter()\n",
        "data_generator = DataGenerator()\n",
        "loader = Loader()"
      ],
      "metadata": {
        "id": "RK7nLUEDTzit"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Load pretrained weights"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "dm = loader.model_load('cifar10_diffusion', dm, is_ema=True)  # Modify the model path"
      ],
      "metadata": {
        "id": "lNCz3WUOWezr"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sample"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# Inference\n",
        "dm.eval()\n",
        "# Sampling needs no gradients; disabling autograd saves GPU memory.\n",
        "with torch.no_grad():\n",
        "    sample = dm(N_SAMPLES, y=CLASS_LABEL)"
      ],
      "metadata": {
        "id": "3bAjCEDbWn3p"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "painter.show_images(sample)"
      ],
      "metadata": {
        "id": "ixVSSEhwikrv"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}