JuyeopDang commited on
Commit
531f7bb
·
verified ·
1 Parent(s): 83061de

Upload cifar10_example.ipynb

Browse files
Files changed (1) hide show
  1. cifar10_example.ipynb +116 -0
cifar10_example.ipynb ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "private_outputs": true,
7
+ "provenance": [],
8
+ "gpuType": "T4"
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "accelerator": "GPU"
18
+ },
19
+ "cells": [
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "metadata": {
24
+ "id": "eUREgErj2o-Z"
25
+ },
26
+ "outputs": [],
27
+ "source": [
28
+ "!git clone https://github.com/Won-Seong/simple-latent-diffusion-model.git\n",
29
+ "\n",
30
+ "import os\n",
31
+ "os.chdir('simple-latent-diffusion-model') # Replace with your repository name\n",
32
+ "os.chdir('simple-latent-diffusion-model') # Replace with your repository name\n",
33
+ "\n",
34
+ "from google.colab import drive\n",
35
+ "drive.mount('/content/drive')\n",
36
+ "\n",
37
+ "from diffusion_model.models.diffusion_model import DiffusionModel\n",
38
+ "from diffusion_model.sampler.ddim import DDIM\n",
39
+ "from diffusion_model.sampler.ddpm import DDPM\n",
40
+ "from diffusion_model.network.unet import Unet\n",
41
+ "from diffusion_model.network.unet_wrapper import UnetWrapper\n",
42
+ "from helper.painter import Painter\n",
43
+ "from helper.trainer import Trainer\n",
44
+ "from helper.data_generator import DataGenerator\n",
45
+ "from helper.loader import Loader\n",
46
+ "from helper.cond_encoder import ConditionEncoder\n",
47
+ "import torch"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "source": [
53
+ "IMAGE_SHAPE = (3, 32, 32)\n",
54
+ "#CONFIG_PATH = 'Config Path'\n",
55
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu'); device"
56
+ ],
57
+ "metadata": {
58
+ "id": "iN360_ddTmUr"
59
+ },
60
+ "execution_count": null,
61
+ "outputs": []
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "source": [
66
+ "sampler = DDIM(CONFIG_PATH)\n",
67
+ "cond_encoder = ConditionEncoder(CONFIG_PATH)\n",
68
+ "network = UnetWrapper(Unet, CONFIG_PATH, cond_encoder)\n",
69
+ "dm = DiffusionModel(network, sampler, IMAGE_SHAPE)\n",
70
+ "painter = Painter()\n",
71
+ "data_generator = DataGenerator()\n",
72
+ "loader = Loader()"
73
+ ],
74
+ "metadata": {
75
+ "id": "RK7nLUEDTzit"
76
+ },
77
+ "execution_count": null,
78
+ "outputs": []
79
+ },
80
+ {
81
+ "cell_type": "code",
82
+ "source": [
83
+ "dm = loader.model_load('cifar10_diffusion', dm, is_ema=True) # Modify the model path"
84
+ ],
85
+ "metadata": {
86
+ "id": "lNCz3WUOWezr"
87
+ },
88
+ "execution_count": null,
89
+ "outputs": []
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "source": [
94
+ "# Inference\n",
95
+ "dm.eval()\n",
96
+ "sample = dm(9, y = 0)"
97
+ ],
98
+ "metadata": {
99
+ "id": "3bAjCEDbWn3p"
100
+ },
101
+ "execution_count": null,
102
+ "outputs": []
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "source": [
107
+ "painter.show_images(sample)"
108
+ ],
109
+ "metadata": {
110
+ "id": "ixVSSEhwikrv"
111
+ },
112
+ "execution_count": null,
113
+ "outputs": []
114
+ }
115
+ ]
116
+ }