leowajda committed on
Commit
7578496
·
1 Parent(s): 7014ab1

initial commit

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. README.md +4 -4
  3. app.py +130 -0
  4. diffusion_sampler.py +152 -0
  5. requirements.txt +180 -0
.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ .idea
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Diffusion Model
3
  emoji: 📈
4
- colorFrom: pink
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.12.0
8
  app_file: app.py
9
  pinned: false
10
  license: agpl-3.0
 
1
  ---
2
+ title: Temp Diffusion
3
  emoji: 📈
4
+ colorFrom: indigo
5
+ colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 4.10.0
8
  app_file: app.py
9
  pinned: false
10
  license: agpl-3.0
app.py ADDED
@@ -0,0 +1,130 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from huggingface_hub import from_pretrained_keras
3
+ from diffusion_sampler import DiffusionSampler
4
+
5
# --- Gradio components -------------------------------------------------------
# The input components are passed positionally to ``call_model`` in the order
# they are listed in the ``gr.Interface(inputs=[...])`` call further below.

# Which pair of pretrained checkpoints to sample from.
scheduler_button = gr.Radio(
    choices=["Linear", "Cosine"],
    label="Noise Scheduler",
    value="Linear",
    info="""
    Decides whether to employ a model trained with a linear scheduler,
    as proposed by Jonathan Ho et al. in 'Denoising Diffusion Probabilistic Models',
    or the cosine variant introduced by Alex Nichol et al. in 'Improved Denoising Diffusion Probabilistic Models'.
    """,
)

# Reverse-process update rule (stochastic DDPM vs. implicit DDIM).
sampling_button = gr.Radio(
    choices=["DDPM", "DDIM"],
    label="Sampling Procedure",
    value="DDPM",
    info="""
    Selects either the stochastic sampling procedure described by Jonathan Ho et al. in 'Denoising Diffusion Probabilistic Models',
    or the implicit variant proposed by Jiaming Song et al. in 'Denoising Diffusion Implicit Models'.
    For the latter, it is also necessary to specify the sub-sequence strategy and the number of sampling steps.
    """,
)

# How the timestep sub-sequence is laid out when DDIM is selected.
subsequence_button = gr.Radio(
    choices=["Linear", "Quadratic"],
    label="Sub-Sequence",
    value="Linear",
    info="""
    Specific to DDIM sampling, this parameter chooses the procedure
    for forming the sub-sequence employed during the sampling process.
    """,
)

# Toggle between the raw network and its exponential-moving-average copy.
ema_button = gr.Checkbox(
    value=True,
    label="Exponential Moving Average",
    info="""
    Whether to invoke the network with the applied exponential moving average on the model parameters.
    Recommended for better results.
    """
)

# Batch size of the generation run (integer, 1..64).
images_button = gr.Number(
    label="Number of images to generate",
    value=5,
    precision=0,
    minimum=1,
    maximum=64,
    info="""
    The number of images to be generated.
    Larger batch sizes result in longer inference times.
    """
)

# Number of DDIM steps (500..1000).
step_button = gr.Slider(
    minimum=500,
    value=1_000,
    maximum=1_000,
    randomize=True,
    label="Number of sampling steps",
    info="""
    Relevant exclusively to DDIM sampling, this parameter determines the number of steps to be utilized during sampling.
    The default value is set to 1000 in the case of DDPM sampling.
    """
)

# Output component that displays the generated batch.
gallery = gr.Gallery(
    label="""
    Generated Flowers
    """
)
75
+
76
+
77
# Samplers are built once at import time so the checkpoints are downloaded and
# the coefficient tables are computed a single time, not on every request.
#
# The schedule name passed to DiffusionSampler selects the beta schedule from
# which all sampling coefficients are derived, so each sampler is given the
# schedule its checkpoint is named after.
linear_diffusion_model = DiffusionSampler(
    model=from_pretrained_keras("leowajda/linear_diffusion"),
    ema_model=from_pretrained_keras("leowajda/linear_diffusion_ema"),
    noise_scheduler="linear",
)

cosine_diffusion_model = DiffusionSampler(
    model=from_pretrained_keras("leowajda/cosine_diffusion"),
    ema_model=from_pretrained_keras("leowajda/cosine_diffusion_ema"),
    noise_scheduler="cosine",
)
88
+
89
+
90
def call_model(
    model_to_call: str,
    sample_strategy: str = "ddim",
    step_strategy: str = "uniform",
    ema: bool = True,
    steps: int = 1_000,
    num_images: int = 0,
):
    """Dispatch a generation request to one of the two module-level samplers.

    Args:
        model_to_call: "linear" (case-insensitive) selects
            ``linear_diffusion_model``; any other value selects
            ``cosine_diffusion_model``.
        sample_strategy: Sampling procedure name; lower-cased before use.
        step_strategy: Sub-sequence strategy name; lower-cased before use.
        ema: Forwarded unchanged to ``generate_images``.
        steps: Number of sampling steps; coerced with ``int``.
        num_images: Number of images to generate; coerced with ``int``.

    Returns:
        Whatever ``generate_images`` of the selected sampler returns.
    """
    if model_to_call.lower() == "linear":
        chosen_sampler = linear_diffusion_model
    else:
        chosen_sampler = cosine_diffusion_model

    # UI widgets may deliver numbers as floats and labels in title case,
    # so everything is normalised before being handed to the sampler.
    request = {
        "num_images": int(num_images),
        "steps": int(steps),
        "sample_strategy": sample_strategy.lower(),
        "step_strategy": step_strategy.lower(),
        "ema": ema,
    }
    return chosen_sampler.generate_images(**request)
106
+
107
+
108
# Web UI: wires the input components to ``call_model`` and shows the result
# in the gallery. The order of ``inputs`` (and of each row in ``examples``)
# follows the positional parameters of ``call_model``.
demo = gr.Interface(
    title="""Unconditional Image Generation Through Denoising Diffusion Implicit Models""",
    description="""
<p align="center">
Supervisor: <strong>Wojciech Oronowicz – Jaśkowiak, PhD</strong>
&emsp;
Author: <strong>Leonardo Wajda</strong>
&emsp;
Specialization: <strong>Intelligent Data Processing Systems</strong>
</p>
""",
    fn=call_model,
    inputs=[
        scheduler_button,
        sampling_button,
        subsequence_button,
        ema_button,
        step_button,
        images_button,
    ],
    outputs=gallery,
    examples=[
        ["Linear", "DDPM", "Linear", True, 1_000, 10],
        ["Cosine", "DDIM", "Linear", True, 750, 20],
        ["Linear", "DDIM", "Quadratic", True, 750, 20],
    ],
    # Examples are only pre-filled into the form, never pre-computed.
    cache_examples=False,
)

# Enable request queuing, then start the server.
demo.queue().launch()
diffusion_sampler.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tqdm as tqdm
3
+ import tensorflow as tf
4
+ import math
5
+ from tensorflow import keras
6
+ from keras.models import load_model
7
+
8
+
9
def as_float32(t: tf.Tensor) -> tf.Tensor:
    """Return ``t`` cast to ``tf.float32``."""
    converted = tf.cast(t, dtype=tf.float32)
    return converted
11
+
12
+
13
def batch_reshape(t: tf.Tensor, x: tf.Tensor) -> tf.Tensor:
    """Create a per-sample coefficient lookup bound to timesteps ``t``.

    Note that, despite the return annotation, the value returned is a
    function. That function takes a coefficient table, gathers its entries at
    the indices in ``t`` and reshapes them to ``[batch, 1, 1, 1]``, where
    ``batch`` is the size of the first dimension of ``x``.
    """

    def lookup(coeff: tf.Tensor) -> tf.Tensor:
        selected = tf.gather(coeff, t)
        leading_dim = tf.shape(x)[0]
        return tf.reshape(selected, [leading_dim, 1, 1, 1])

    return lookup
19
+
20
+
21
class DiffusionSampler(keras.Model):
    """Reverse-diffusion image sampler built around a noise-prediction network.

    On construction the beta schedule is computed once and every derived
    coefficient table is stored as a float32 tensor indexed by timestep.
    ``generate_images`` then applies the DDPM or DDIM update rule over a
    timestep sequence, starting from Gaussian noise.
    """

    def __init__(
            self,
            model: keras.Model | str,
            ema_model: keras.Model | str,
            timesteps: int | None = 1_000,
            beta_start: float | None = 1e-4,
            beta_end: float | None = 0.02,
            noise_scheduler: str = "linear",
            ema: float = 0.999,
            **kwargs,
    ):
        """Load the networks and precompute the schedule coefficients.

        Args:
            model: Noise predictor, either an instantiated model or a path
                that is passed to ``load_model``.
            ema_model: EMA copy of the noise predictor, model or path.
            timesteps: Length of the diffusion schedule.
            beta_start: First beta of the linear schedule.
            beta_end: Last beta of the linear schedule.
            noise_scheduler: ``"linear"`` or ``"cosine"``.
            ema: Stored on the instance as ``self.ema``.
            **kwargs: Forwarded to ``keras.Model.__init__``.
        """
        super().__init__(**kwargs)
        self.noise_predictor = (
            load_model(filepath=model, safe_mode=False) if isinstance(model, str) else model
        )
        # Each argument is inspected on its own: the EMA network may be given
        # as a path while the raw network is an instance, or vice versa.
        self.ema_noise_predictor = (
            load_model(filepath=ema_model, safe_mode=False) if isinstance(ema_model, str) else ema_model
        )
        self.ema = ema
        self.beta_start = beta_start
        self.beta_end = beta_end
        # Must be set before the schedule is built: noise_scheduler() reads it.
        self.timesteps = timesteps

        # The tables are computed in float64 and only cast to float32 when
        # they are stored on the instance.
        betas = self.noise_scheduler(noise_scheduler)
        alphas = 1.0 - betas
        alphas_cum_prod = tf.math.cumprod(alphas, axis=0)
        # Cumulative product shifted right by one, with 1.0 as the first entry.
        alphas_cum_prod_prev = tf.concat([tf.constant([1.0], dtype=tf.float64), alphas_cum_prod[:-1]], axis=0)
        posterior_variances = betas * (1.0 - alphas_cum_prod_prev) / (1.0 - alphas_cum_prod)

        self.betas = as_float32(betas)
        self.posterior_variances = as_float32(posterior_variances)
        self.alphas_cum_prod_prev = as_float32(alphas_cum_prod_prev)
        self.one_minus_alphas_cum_prod = as_float32(1.0 - alphas_cum_prod)
        self.one_minus_alphas_cum_prod_prev = as_float32(1.0 - alphas_cum_prod_prev)

        self.sqrt_one_minus_alphas_cum_prod = as_float32(tf.sqrt(1.0 - alphas_cum_prod))
        self.sqrt_alphas_cum_prod_prev = as_float32(tf.sqrt(alphas_cum_prod_prev))
        self.sqrt_alphas_cum_prod = as_float32(tf.sqrt(alphas_cum_prod))

        self.rev_sqrt_alphas_cum_prod = as_float32(1.0 / tf.sqrt(alphas_cum_prod))
        self.rev_sqrt_alphas = as_float32(tf.sqrt(1.0 / alphas))

    def ddpm_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor) -> tf.Tensor:
        """Apply one DDPM reverse step.

        Args:
            pred_noise: Noise predicted by the network for ``x_t``.
            x_t: Current batch of samples.
            t: One integer timestep index per sample.

        Returns:
            The posterior mean plus Gaussian noise scaled by the posterior
            standard deviation; the noise term is masked out where ``t == 0``.
        """
        batch_dim = tf.shape(x_t)[0]
        at_timestep = batch_reshape(t, x_t)

        beta = at_timestep(self.betas)
        rev_sqrt_alpha = at_timestep(self.rev_sqrt_alphas)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        posterior_variance = at_timestep(self.posterior_variances)

        mean = rev_sqrt_alpha * (
                x_t - (beta / sqrt_one_minus_alpha_cum_prod) * pred_noise
        )

        # 0.0 for samples at t == 0, 1.0 otherwise.
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), dtype=tf.float32), [batch_dim, 1, 1, 1]
        )

        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return mean + nonzero_mask * tf.sqrt(posterior_variance) * random_noise

    def ddim_sample(self, pred_noise: tf.Tensor, x_t: tf.Tensor, t: tf.Tensor, eta: float = 0.0) -> tf.Tensor:
        """Apply one DDIM reverse step.

        Args:
            pred_noise: Noise predicted by the network for ``x_t``.
            x_t: Current batch of samples.
            t: One integer timestep index per sample.
            eta: Scale of the injected noise; with the default ``0.0`` the
                noise term is multiplied by zero.

        Returns:
            The batch after the update.
        """
        # NOTE(review): the "previous" coefficients are the tables shifted by
        # one timestep, not the previous element of the sampling sub-sequence
        # -- confirm this is intended when the sequence skips timesteps.
        at_timestep = batch_reshape(t, x_t)

        beta = at_timestep(self.betas)
        sqrt_alpha_cum_prod_prev = at_timestep(self.sqrt_alphas_cum_prod_prev)
        rev_sqrt_alpha_cum_prod = at_timestep(self.rev_sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod = at_timestep(self.one_minus_alphas_cum_prod)
        one_minus_alpha_cum_prod_prev = at_timestep(self.one_minus_alphas_cum_prod_prev)

        # Estimate of the clean sample implied by the predicted noise.
        x0_t = (
                (x_t - (sqrt_one_minus_alpha_cum_prod * pred_noise)) * rev_sqrt_alpha_cum_prod
        )

        # sigma_t = eta * sqrt((1 - a_prev) / (1 - a_t) * (1 - a_t / a_prev)),
        # where a_t is the cumulative alpha product. Because a_t / a_prev is
        # alpha_t = 1 - beta_t, the last factor is exactly beta_t.
        c1 = eta * tf.sqrt(
            (one_minus_alpha_cum_prod_prev / one_minus_alpha_cum_prod) * beta
        )

        x_t_dir = tf.sqrt(one_minus_alpha_cum_prod_prev - tf.square(c1))
        random_noise = tf.random.normal(shape=x_t.shape, dtype=x_t.dtype)
        return sqrt_alpha_cum_prod_prev * x0_t + x_t_dir * pred_noise + c1 * random_noise

    def noise_scheduler(self, scheduler: str, max_beta: float = 0.02) -> tf.Tensor:
        """Build the beta schedule as a float64 tensor of length ``self.timesteps``.

        Args:
            scheduler: ``"linear"`` for evenly spaced betas between
                ``self.beta_start`` and ``self.beta_end``, or ``"cosine"`` for
                betas derived from a squared-cosine cumulative alpha curve.
            max_beta: Upper clip applied to every beta of the cosine schedule.

        Raises:
            ValueError: If ``scheduler`` is not a known schedule name.
        """
        if scheduler == "linear":
            betas = tf.linspace(start=self.beta_start, stop=self.beta_end, num=self.timesteps)
            return tf.cast(betas, dtype=tf.float64)

        if scheduler == "cosine":
            total_steps = tf.cast(self.timesteps, dtype=tf.float64)

            def alpha_bar(u):
                # Squared-cosine curve with a 0.008 offset.
                return tf.math.cos((u + 0.008) / 1.008 * tf.constant(math.pi, dtype=tf.float64) / 2) ** 2

            def beta_at(i):
                # beta_i = 1 - alpha_bar(i + 1) / alpha_bar(i), clipped from above.
                return tf.minimum(
                    1 - alpha_bar((i + 1) / total_steps) / alpha_bar(i / total_steps), max_beta
                )

            betas = tf.vectorized_map(fn=beta_at, elems=tf.range(self.timesteps, dtype=tf.float64))
            return tf.cast(betas, dtype=tf.float64)

        # Falling through would return None and fail later with an obscure
        # TypeError in __init__, so reject unknown names here.
        raise ValueError(f"unknown noise scheduler {scheduler!r}; expected 'linear' or 'cosine'")

    def x_t(self, x_start: tf.Tensor, t: tf.Tensor, noise: tf.Tensor) -> tf.Tensor:
        """Forward-diffuse ``x_start`` to timestep ``t`` using ``noise``.

        Returns ``sqrt(a_t) * x_start + sqrt(1 - a_t) * noise`` where ``a_t``
        is the cumulative alpha product gathered at ``t``.
        """
        at_timestep = batch_reshape(t, x_start)

        sqrt_alpha_cum_prod = at_timestep(self.sqrt_alphas_cum_prod)
        sqrt_one_minus_alpha_cum_prod = at_timestep(self.sqrt_one_minus_alphas_cum_prod)
        return sqrt_alpha_cum_prod * x_start + sqrt_one_minus_alpha_cum_prod * noise

    def generate_images(
            self,
            num_images: int,
            steps: int,
            sample_strategy: str = "ddim",
            step_strategy: str = "uniform",
            ema: bool = True,
    ):
        """Generate a batch of images by iterating the reverse process.

        Args:
            num_images: Batch size.
            steps: Length of the timestep sequence for DDIM; ignored for
                DDPM, which always walks all ``self.timesteps`` steps.
            sample_strategy: ``"ddpm"`` or ``"ddim"``.
            step_strategy: ``"linear"`` or ``"quadratic"``; ``"uniform"`` is
                accepted as an alias of ``"linear"``.
            ema: Use the EMA network when true, the raw network otherwise.

        Returns:
            ``uint8`` NumPy array of shape ``(num_images, 64, 64, 3)``.

        Raises:
            ValueError: If either strategy name is not recognised.
        """
        # The default "uniform" had no entry in the dispatch table; treat it
        # as the evenly spaced ("linear") sequence.
        if step_strategy == "uniform":
            step_strategy = "linear"

        if sample_strategy not in ("ddpm", "ddim"):
            raise ValueError(f"unknown sample_strategy {sample_strategy!r}; expected 'ddpm' or 'ddim'")
        if step_strategy not in ("linear", "quadratic"):
            raise ValueError(f"unknown step_strategy {step_strategy!r}; expected 'linear' or 'quadratic'")

        if sample_strategy == "ddpm":
            sampler = self.ddpm_sample
            seq = tf.range(self.timesteps, dtype=tf.float64)
        else:
            sampler = self.ddim_sample
            if step_strategy == "linear":
                seq = tf.range(steps, dtype=tf.float64)
            else:
                # Squares of evenly spaced points up to sqrt(0.8 * timesteps).
                seq = tf.cast(
                    tf.linspace(start=0.0, stop=tf.sqrt(self.timesteps * 0.8), num=steps) ** 2,
                    dtype=tf.float64,
                )

        noise_predictor = self.ema_noise_predictor if ema else self.noise_predictor
        # Sampling starts from pure Gaussian noise at a fixed 64x64 RGB size.
        samples = tf.random.normal(shape=(num_images, 64, 64, 3), dtype=tf.float32)

        # Walk the sequence from the largest timestep down to the smallest.
        for t in tqdm.tqdm(tf.reverse(seq, axis=[0])):
            # Same timestep for every sample; cast to int so it can index the tables.
            tt = tf.cast(tf.fill(dims=(num_images,), value=t), dtype=tf.int64)
            pred_noise = noise_predictor.predict([samples, tt], verbose=0, batch_size=num_images)
            samples = sampler(pred_noise, samples, tt)

        # Map [-1, 1] to [0, 255] and clip anything outside that range.
        return (
            tf.clip_by_value(samples * 127.5 + 127.5, 0.0, 255.0)
            .numpy()
            .astype(np.uint8)
        )
requirements.txt ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ absl-py==1.4.0
2
+ aiofiles==23.2.1
3
+ altair==5.2.0
4
+ annotated-types==0.6.0
5
+ anyio==3.7.1
6
+ argon2-cffi==23.1.0
7
+ argon2-cffi-bindings==21.2.0
8
+ array-record==0.5.0
9
+ arrow==1.3.0
10
+ asttokens==2.4.1
11
+ astunparse==1.6.3
12
+ async-lru==2.0.4
13
+ attrs==23.1.0
14
+ Babel==2.13.1
15
+ beautifulsoup4==4.12.2
16
+ bleach==6.1.0
17
+ cachetools==5.3.2
18
+ certifi==2023.11.17
19
+ cffi==1.16.0
20
+ charset-normalizer==3.3.2
21
+ click==8.1.7
22
+ colorama==0.4.6
23
+ comm==0.2.0
24
+ contourpy==1.2.0
25
+ cycler==0.12.1
26
+ debugpy==1.8.0
27
+ decorator==5.1.1
28
+ defusedxml==0.7.1
29
+ dm-tree==0.1.8
30
+ etils==1.5.2
31
+ exceptiongroup==1.2.0
32
+ executing==2.0.1
33
+ fastapi==0.104.1
34
+ fastjsonschema==2.19.0
35
+ ffmpy==0.3.1
36
+ filelock==3.13.1
37
+ flatbuffers==23.5.26
38
+ fonttools==4.46.0
39
+ fqdn==1.5.1
40
+ fsspec==2023.12.0
41
+ gast==0.5.4
42
+ google-auth==2.24.0
43
+ google-auth-oauthlib==1.0.0
44
+ google-pasta==0.2.0
45
+ googleapis-common-protos==1.61.0
46
+ gradio==4.8.0
47
+ gradio_client==0.7.1
48
+ graphviz==0.20.1
49
+ grpcio==1.59.3
50
+ h11==0.14.0
51
+ h5py==3.10.0
52
+ httpcore==1.0.2
53
+ httpx==0.25.2
54
+ huggingface-hub==0.19.4
55
+ idna==3.6
56
+ importlib-resources==6.1.1
57
+ ipykernel==6.27.1
58
+ ipython==8.18.1
59
+ ipywidgets==8.1.1
60
+ isoduration==20.11.0
61
+ jedi==0.19.1
62
+ Jinja2==3.1.2
63
+ json5==0.9.14
64
+ jsonpointer==2.4
65
+ jsonschema==4.20.0
66
+ jsonschema-specifications==2023.11.2
67
+ jupyter==1.0.0
68
+ jupyter-console==6.6.3
69
+ jupyter-events==0.9.0
70
+ jupyter-lsp==2.2.1
71
+ jupyter_client==8.6.0
72
+ jupyter_core==5.5.0
73
+ jupyter_server==2.11.2
74
+ jupyter_server_terminals==0.4.4
75
+ jupyterlab==4.0.9
76
+ jupyterlab-widgets==3.0.9
77
+ jupyterlab_pygments==0.3.0
78
+ jupyterlab_server==2.25.2
79
+ keras==2.15.0
80
+ kiwisolver==1.4.5
81
+ libclang==16.0.6
82
+ Markdown==3.5.1
83
+ markdown-it-py==3.0.0
84
+ MarkupSafe==2.1.3
85
+ matplotlib==3.8.2
86
+ matplotlib-inline==0.1.6
87
+ mdurl==0.1.2
88
+ mistune==3.0.2
89
+ ml-dtypes==0.2.0
90
+ nbclient==0.9.0
91
+ nbconvert==7.12.0
92
+ nbformat==5.9.2
93
+ nest-asyncio==1.5.8
94
+ notebook==7.0.6
95
+ notebook_shim==0.2.3
96
+ numpy==1.26.2
97
+ oauthlib==3.2.2
98
+ opt-einsum==3.3.0
99
+ orjson==3.9.10
100
+ overrides==7.4.0
101
+ packaging==23.2
102
+ pandas==2.1.4
103
+ pandocfilters==1.5.0
104
+ parso==0.8.3
105
+ pexpect==4.9.0
106
+ Pillow==10.1.0
107
+ platformdirs==4.1.0
108
+ prometheus-client==0.19.0
109
+ promise==2.3
110
+ prompt-toolkit==3.0.41
111
+ protobuf==3.20.3
112
+ psutil==5.9.6
113
+ ptyprocess==0.7.0
114
+ pure-eval==0.2.2
115
+ pyasn1==0.5.1
116
+ pyasn1-modules==0.3.0
117
+ pycparser==2.21
118
+ pydantic==2.5.2
119
+ pydantic_core==2.14.5
120
+ pydot==1.4.2
121
+ pydub==0.25.1
122
+ Pygments==2.17.2
123
+ pyparsing==3.1.1
124
+ python-dateutil==2.8.2
125
+ python-json-logger==2.0.7
126
+ python-multipart==0.0.6
127
+ pytz==2023.3.post1
128
+ PyYAML==6.0.1
129
+ pyzmq==25.1.1
130
+ qtconsole==5.5.1
131
+ QtPy==2.4.1
132
+ referencing==0.31.1
133
+ requests==2.31.0
134
+ requests-oauthlib==1.3.1
135
+ rfc3339-validator==0.1.4
136
+ rfc3986-validator==0.1.1
137
+ rich==13.7.0
138
+ rpds-py==0.13.2
139
+ rsa==4.9
140
+ semantic-version==2.10.0
141
+ Send2Trash==1.8.2
142
+ shellingham==1.5.4
143
+ six==1.16.0
144
+ sniffio==1.3.0
145
+ soupsieve==2.5
146
+ stack-data==0.6.3
147
+ starlette==0.27.0
148
+ tensorboard==2.15.1
149
+ tensorboard-data-server==0.7.2
150
+ tensorflow==2.15.0
151
+ tensorflow-datasets==4.9.3
152
+ tensorflow-estimator==2.15.0
153
+ tensorflow-io-gcs-filesystem==0.34.0
154
+ tensorflow-metadata==1.14.0
155
+ termcolor==2.4.0
156
+ terminado==0.18.0
157
+ tinycss2==1.2.1
158
+ toml==0.10.2
159
+ tomli==2.0.1
160
+ tomlkit==0.12.0
161
+ toolz==0.12.0
162
+ tornado==6.4
163
+ tqdm==4.66.1
164
+ traitlets==5.14.0
165
+ typer==0.9.0
166
+ types-python-dateutil==2.8.19.14
167
+ typing_extensions==4.8.0
168
+ tzdata==2023.3
169
+ uri-template==1.3.0
170
+ urllib3==2.1.0
171
+ uvicorn==0.24.0.post1
172
+ wcwidth==0.2.12
173
+ webcolors==1.13
174
+ webencodings==0.5.1
175
+ websocket-client==1.7.0
176
+ websockets==11.0.3
177
+ Werkzeug==3.0.1
178
+ widgetsnbextension==4.0.9
179
+ wrapt==1.14.1
180
+ zipp==3.17.0