JBlitzar commited on
Commit
897982a
·
1 Parent(s): ed96a20
Files changed (11) hide show
  1. app.py +19 -67
  2. bert_vectorize.py +27 -0
  3. factories.py +343 -0
  4. infer.py +43 -0
  5. logger.py +40 -0
  6. pipeline.py +69 -0
  7. predict.py +54 -0
  8. runner.py +80 -0
  9. runs/run_3_jxa/ckpt/latest.pt +3 -0
  10. runs/run_3_jxa/ckpt/latest_cpu.pt +3 -0
  11. wrapper.py +198 -0
app.py CHANGED
@@ -2,42 +2,30 @@ import gradio as gr
2
  import numpy as np
3
  import random
4
  #import spaces #[uncomment to use ZeroGPU]
5
- from diffusers import DiffusionPipeline
6
  import torch
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
9
- model_repo_id = "stabilityai/sdxl-turbo" #Replace to the model you would like to use
10
 
11
  if torch.cuda.is_available():
12
  torch_dtype = torch.float16
13
  else:
14
  torch_dtype = torch.float32
15
 
16
- pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=torch_dtype)
17
- pipe = pipe.to(device)
18
 
19
  MAX_SEED = np.iinfo(np.int32).max
20
  MAX_IMAGE_SIZE = 1024
21
 
22
  #@spaces.GPU #[uncomment to use ZeroGPU]
23
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps, progress=gr.Progress(track_tqdm=True)):
24
 
25
- if randomize_seed:
26
- seed = random.randint(0, MAX_SEED)
27
-
28
- generator = torch.Generator().manual_seed(seed)
29
 
30
  image = pipe(
31
- prompt = prompt,
32
- negative_prompt = negative_prompt,
33
- guidance_scale = guidance_scale,
34
- num_inference_steps = num_inference_steps,
35
- width = width,
36
- height = height,
37
- generator = generator
38
  ).images[0]
39
 
40
- return image, seed
41
 
42
  examples = [
43
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
@@ -75,58 +63,22 @@ with gr.Blocks(css=css) as demo:
75
 
76
  with gr.Accordion("Advanced Settings", open=False):
77
 
78
- negative_prompt = gr.Text(
79
- label="Negative prompt",
80
- max_lines=1,
81
- placeholder="Enter a negative prompt",
82
- visible=False,
83
- )
84
-
85
- seed = gr.Slider(
86
- label="Seed",
87
- minimum=0,
88
- maximum=MAX_SEED,
89
- step=1,
90
- value=0,
91
- )
92
-
93
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
94
-
95
- with gr.Row():
96
-
97
- width = gr.Slider(
98
- label="Width",
99
- minimum=256,
100
- maximum=MAX_IMAGE_SIZE,
101
- step=32,
102
- value=1024, #Replace with defaults that work for your model
103
- )
104
-
105
- height = gr.Slider(
106
- label="Height",
107
- minimum=256,
108
- maximum=MAX_IMAGE_SIZE,
109
- step=32,
110
- value=1024, #Replace with defaults that work for your model
111
  )
112
 
113
- with gr.Row():
114
-
115
- guidance_scale = gr.Slider(
116
- label="Guidance scale",
117
- minimum=0.0,
118
- maximum=10.0,
119
- step=0.1,
120
- value=0.0, #Replace with defaults that work for your model
121
- )
122
-
123
- num_inference_steps = gr.Slider(
124
- label="Number of inference steps",
125
- minimum=1,
126
- maximum=50,
127
  step=1,
128
- value=2, #Replace with defaults that work for your model
129
  )
 
130
 
131
  gr.Examples(
132
  examples = examples,
@@ -135,8 +87,8 @@ with gr.Blocks(css=css) as demo:
135
  gr.on(
136
  triggers=[run_button.click, prompt.submit],
137
  fn = infer,
138
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
139
- outputs = [result, seed]
140
  )
141
 
142
  demo.queue().launch()
 
2
  import numpy as np
3
  import random
4
  #import spaces #[uncomment to use ZeroGPU]
5
+ from pipeline import TextToImagePipeline
6
  import torch
7
 
8
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
9
 
10
  if torch.cuda.is_available():
11
  torch_dtype = torch.float16
12
  else:
13
  torch_dtype = torch.float32
14
 
15
+ pipe = TextToImagePipeline(device=device)
 
16
 
17
  MAX_SEED = np.iinfo(np.int32).max
18
  MAX_IMAGE_SIZE = 1024
19
 
20
  #@spaces.GPU #[uncomment to use ZeroGPU]
21
def infer(prompt, num_inference_steps, amt, progress=gr.Progress(track_tqdm=True)):
    """Generate an image grid for `prompt` and return it for display.

    Args:
        prompt: Text prompt to condition the diffusion model on.
        num_inference_steps: Number of denoising steps for the sampler.
        amt: Number of images to sample into the grid.
        progress: Gradio progress tracker (mirrors tqdm output).

    Returns:
        Path of the saved grid image (gradio's Image output accepts a filepath).
    """
    # BUG FIX: TextToImagePipeline.__call__ returns the saved file path (see
    # pipeline.py), not a diffusers-style result object — the original
    # `pipe(...).images[0]` raised AttributeError. Return the path directly.
    image_path = pipe(prompt, num_inference_steps, amt)
    return image_path
29
 
30
  examples = [
31
  "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
 
63
 
64
  with gr.Accordion("Advanced Settings", open=False):
65
 
66
+ amt = gr.Slider(
67
+ label="Amount",
68
+ minimum=1,
69
+ maximum=8,
70
+ step=1,
71
+ value=8,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  )
73
 
74
+ steps = gr.Slider(
75
+ label="Num inference steps",
76
+ minimum=100,
77
+ maximum=2000,
 
 
 
 
 
 
 
 
 
 
78
  step=1,
79
+ value=1000,
80
  )
81
+
82
 
83
  gr.Examples(
84
  examples = examples,
 
87
  gr.on(
88
  triggers=[run_button.click, prompt.submit],
89
  fn = infer,
90
+ inputs = [prompt, steps,amt],
91
+ outputs = [result]
92
  )
93
 
94
  demo.queue().launch()
bert_vectorize.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from transformers import BertTokenizer, BertModel, DistilBertTokenizer, DistilBertModel
import torch

# Tokenizer and model are loaded once at import time and shared by all callers.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased', output_hidden_states=True)
model.eval()

device = "mps" if torch.backends.mps.is_available() else "cpu"

model = model.to(device)


def vectorize_text_with_bert(text):  # from hf docs
    """Return a sentence embedding: mean of DistilBERT's last hidden layer.

    Args:
        text: Input string (or list of strings; padding/truncation enabled).

    Returns:
        A tensor of shape (768,) for a single input (batch dim squeezed).
    """
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    last_layer_hidden_states = outputs.hidden_states[-1]
    # Mean-pool over the token axis; squeeze drops the batch dim for one input.
    return torch.mean(last_layer_hidden_states, dim=1).squeeze(0)


def cleanup():
    """Release the module-level model/tokenizer references.

    BUG FIX: runner.py imports this name (`from bert_vectorize import
    vectorize_text_with_bert, cleanup`) but it did not exist, which made that
    import fail. Presumably it exists to free memory once inference jobs
    finish — TODO confirm intended semantics.
    """
    global model, tokenizer
    model = None
    tokenizer = None


if __name__ == "__main__":
    text = "A man walking down the street with a dog holding a balloon in one hand."
    text_representation = vectorize_text_with_bert(text)

    print("Vectorized representation:", text_representation)
    print(text_representation.shape)
factories.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+
7
class EMA:
    """Exponential moving average of model weights.

    Maintains a shadow copy of a model's parameters updated as
    ema = beta * ema + (1 - beta) * current, with a warm-up period during
    which the shadow simply mirrors the raw weights.
    """

    def __init__(self, beta):
        super().__init__()
        self.beta = beta
        self.step = 0

    def update_model_average(self, ma_model, current_model):
        # Walk both parameter lists in lockstep and blend each pair in place.
        for src_params, shadow_params in zip(current_model.parameters(), ma_model.parameters()):
            shadow_params.data = self.update_average(shadow_params.data, src_params.data)

    def update_average(self, old, new):
        # First update: nothing to blend with yet, so take the new value as-is.
        return new if old is None else old * self.beta + (1 - self.beta) * new

    def step_ema(self, ema_model, model, step_start_ema=2000):
        # Warm-up: keep the EMA identical to the live model until
        # step_start_ema iterations have elapsed, then start averaging.
        if self.step < step_start_ema:
            self.reset_parameters(ema_model, model)
        else:
            self.update_model_average(ema_model, model)
        self.step += 1

    def reset_parameters(self, ema_model, model):
        # Hard-copy the live weights into the EMA model.
        ema_model.load_state_dict(model.state_dict())
33
+
34
+
35
class SelfAttention(nn.Module):
    """Multi-head self-attention over the spatial positions of a feature map.

    Expects inputs of shape (B, channels, size, size); `channels` must be
    divisible by the 4 attention heads.
    """

    def __init__(self, channels, size):
        super(SelfAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm([channels])
        self.ff_self = nn.Sequential(
            nn.LayerNorm([channels]),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )

    def forward(self, x):
        # (B, C, H, W) -> (B, H*W, C): one token per spatial position.
        tokens = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        normed = self.ln(tokens)
        attended, _ = self.mha(normed, normed, normed)
        attended = attended + tokens              # residual around attention
        out = self.ff_self(attended) + attended   # residual around the MLP
        # Restore the (B, C, H, W) layout.
        return out.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
56
+
57
+
58
class CrossAttention(nn.Module):
    """Cross-attention from spatial tokens (queries) to a conditioning vector.

    The conditioning vector of width `context_dim` is projected to `channels`
    and broadcast across every spatial position as keys/values.
    """

    def __init__(self, channels, size, context_dim):
        super(CrossAttention, self).__init__()
        self.channels = channels
        self.size = size
        self.context_dim = context_dim
        self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
        self.ln = nn.LayerNorm(channels)
        self.context_ln = nn.LayerNorm(channels)
        self.ff_self = nn.Sequential(
            nn.LayerNorm(channels),
            nn.Linear(channels, channels),
            nn.GELU(),
            nn.Linear(channels, channels),
        )
        # Project the conditioning vector into the attention width.
        self.context_proj = nn.Linear(context_dim, channels)

    def forward(self, x, context):
        batch_size, channels, height, width = x.size()
        # Flatten spatial dims into a token axis: (B, H*W, C).
        tokens = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
        tokens_ln = self.ln(tokens)

        # Project the context, then repeat it once per spatial token.
        ctx = self.context_proj(context)
        ctx = ctx.unsqueeze(1).expand(-1, tokens_ln.size(1), -1)
        ctx_ln = self.context_ln(ctx)

        # Queries come from the image; keys/values from the context.
        attended, _ = self.mha(tokens_ln, ctx_ln, ctx_ln)
        attended = attended + tokens
        attended = self.ff_self(attended) + attended

        # Back to the (B, C, H, W) layout.
        return attended.permute(0, 2, 1).view(batch_size, channels, height, width)
102
+
103
+
104
class DoubleConv(nn.Module):
    """Two 3x3 conv + GroupNorm layers, optionally wrapped in a GELU residual.

    Residual mode (residual=True) requires in_channels == out_channels.
    """

    def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
        super().__init__()
        self.residual = residual
        # Default the hidden width to the output width.
        mid = mid_channels if mid_channels else out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, mid),
            nn.GELU(),
            nn.Conv2d(mid, out_channels, kernel_size=3, padding=1, bias=False),
            nn.GroupNorm(1, out_channels),
        )

    def forward(self, x):
        out = self.double_conv(x)
        return F.gelu(x + out) if self.residual else out
123
+
124
+
125
class Down(nn.Module):
    """Downsample by 2 (max-pool + double convs), then add a timestep embedding."""

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels),
        )
        # Maps the timestep embedding to one bias per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, t):
        x = self.maxpool_conv(x)
        # Tile the per-channel embedding over the spatial grid before adding.
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
        return x + emb
146
+
147
+
148
class Up(nn.Module):
    """Upsample by 2, merge the encoder skip connection, add a timestep embedding."""

    def __init__(self, in_channels, out_channels, emb_dim=256):
        super().__init__()
        self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
        self.conv = nn.Sequential(
            DoubleConv(in_channels, in_channels, residual=True),
            DoubleConv(in_channels, out_channels, in_channels // 2),
        )
        # Maps the timestep embedding to one bias per output channel.
        self.emb_layer = nn.Sequential(
            nn.SiLU(),
            nn.Linear(emb_dim, out_channels),
        )

    def forward(self, x, skip_x, t):
        upsampled = self.up(x)
        # Channel-concat encoder features with the upsampled decoder features.
        merged = torch.cat([skip_x, upsampled], dim=1)
        merged = self.conv(merged)
        emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, merged.shape[-2], merged.shape[-1])
        return merged + emb
172
+
173
+
174
class Dome_UNet(nn.Module):
    """Unconditional diffusion U-Net with self-attention.

    The fixed SelfAttention sizes (32/16/8/64) tie this network to 64x64
    square inputs. `device` is used only for the positional-encoding tensor.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, device="mps"):
        super().__init__()
        self.device = device
        self.time_dim = time_dim
        # Encoder: 64 -> 32 -> 16 -> 8 spatial resolution.
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.down2 = Down(128, 256)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.sa3 = SelfAttention(256, 8)
        # Bottleneck.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)
        # Decoder with skip connections back up to 64x64.
        self.up1 = Up(512, 128)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

    def pos_encoding(self, t, channels):
        """Sinusoidal timestep embedding of width `channels` for timesteps `t` (shape (B, 1))."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        angles = t.repeat(1, channels // 2) * inv_freq
        # First half sin, second half cos.
        return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

    def forward(self, x, t):
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # Encoder (keep skips for the decoder).
        skip1 = self.inc(x)
        skip2 = self.sa1(self.down1(skip1, t))
        skip3 = self.sa2(self.down2(skip2, t))
        h = self.sa3(self.down3(skip3, t))

        # Bottleneck.
        h = self.bot3(self.bot2(self.bot1(h)))

        # Decoder.
        h = self.sa4(self.up1(h, skip3, t))
        h = self.sa5(self.up2(h, skip2, t))
        h = self.sa6(self.up3(h, skip1, t))
        return self.outc(h)
233
+
234
+
235
class UNet_conditional(nn.Module):
    """Conditional diffusion U-Net: conditioning injected via cross-attention.

    `num_classes` is the width of the conditioning vector (e.g. 768 for
    DistilBERT sentence embeddings); `context_dim` defaults to it. The fixed
    attention sizes (32/16/8/64) tie this network to 64x64 inputs. `device`
    is used only for the positional-encoding tensor.
    """

    def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, context_dim=None, device="mps"):
        super().__init__()

        if context_dim is None:
            context_dim = num_classes
        self.device = device
        self.time_dim = time_dim

        # Encoder: 64 -> 32 -> 16 -> 8 spatial resolution.
        self.inc = DoubleConv(c_in, 64)
        self.down1 = Down(64, 128)
        self.sa1 = SelfAttention(128, 32)
        self.xa1 = CrossAttention(128, 32, context_dim)
        self.down2 = Down(128, 256)
        self.xa2 = CrossAttention(256, 16, context_dim)
        self.sa2 = SelfAttention(256, 16)
        self.down3 = Down(256, 256)
        self.xa3 = CrossAttention(256, 8, context_dim)
        self.sa3 = SelfAttention(256, 8)

        # Bottleneck.
        self.bot1 = DoubleConv(256, 512)
        self.bot2 = DoubleConv(512, 512)
        self.bot3 = DoubleConv(512, 256)

        # Decoder. NOTE: the sa* modules are unused in forward() (the calls are
        # disabled below) but are kept so existing checkpoints still load.
        self.up1 = Up(512, 128)
        self.xa4 = CrossAttention(128, 16, context_dim)
        self.sa4 = SelfAttention(128, 16)
        self.up2 = Up(256, 64)
        self.xa5 = CrossAttention(64, 32, context_dim)
        self.sa5 = SelfAttention(64, 32)
        self.up3 = Up(128, 64)
        self.xa6 = CrossAttention(64, 64, context_dim)
        self.sa6 = SelfAttention(64, 64)
        self.outc = nn.Conv2d(64, c_out, kernel_size=1)

        if num_classes is not None:
            self.label_emb = nn.Linear(num_classes, time_dim)  # Embedding(num_classes, time_dim)
            self.num_classes = num_classes
            # BUG FIX: the original re-checked `context_dim is None` here; that
            # branch was dead because context_dim was already defaulted above.
            self.context_dim = context_dim
            self.label_crossattn_emb = nn.Linear(num_classes, context_dim)

    def pos_encoding(self, t, channels):
        """Sinusoidal timestep embedding of width `channels` for timesteps `t` (shape (B, 1))."""
        inv_freq = 1.0 / (
            10000
            ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
        )
        pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
        pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
        pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
        return pos_enc

    def forward(self, x, t, y):
        """Predict noise for `x` at timesteps `t`, conditioned on vectors `y`.

        Raises:
            ValueError: if `y` is None — every stage cross-attends to it.
        """
        t = t.unsqueeze(-1).type(torch.float)
        t = self.pos_encoding(t, self.time_dim)

        # BUG FIX: the original guarded `if y is not None:` but then used
        # `attn_y` unconditionally, so y=None fell through to a NameError.
        # Fail with a clear message instead.
        if y is None:
            raise ValueError("UNet_conditional.forward requires a conditioning tensor y")

        attn_y = self.label_crossattn_emb(y[:, :self.num_classes])

        x1 = self.inc(x)

        x2 = self.down1(x1, t)
        x2 = self.xa1(x2, attn_y)
        # x2 = self.sa1(x2)

        x3 = self.down2(x2, t)
        x3 = self.xa2(x3, attn_y)
        # x3 = self.sa2(x3)

        x4 = self.down3(x3, t)
        x4 = self.xa3(x4, attn_y)
        # x4 = self.sa3(x4)

        x4 = self.bot1(x4)
        x4 = self.bot2(x4)
        x4 = self.bot3(x4)

        x = self.up1(x4, x3, t)
        x = self.xa4(x, attn_y)
        # x = self.sa4(x)

        x = self.up2(x, x2, t)
        x = self.xa5(x, attn_y)
        # x = self.sa5(x)

        x = self.up3(x, x1, t)
        x = self.xa6(x, attn_y)
        # x = self.sa6(x)
        output = self.outc(x)

        return output
infer.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from factories import UNet_conditional
from wrapper import DiffusionManager, Schedule
import os
import re
import torch
from bert_vectorize import vectorize_text_with_bert
import time
import torchvision
from logger import save_grid_with_label


EXPERIMENT_DIRECTORY = "runs/run_3_jxa"
device = "mps" if torch.backends.mps.is_available() else "cpu"

# BUG FIX: the original wrapped os.mkdir in a bare `except:` which would also
# swallow unrelated errors (permissions, bad path). exist_ok covers the
# "already exists" case explicitly.
os.makedirs(os.path.join(EXPERIMENT_DIRECTORY, "inferred"), exist_ok=True)

# 768 matches the DistilBERT embedding width used as conditioning.
net = UNet_conditional(num_classes=768)
net.to(device)
net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"), weights_only=True))

wrapper = DiffusionManager(net, device=device, noise_steps=1000)
wrapper.set_schedule(Schedule.LINEAR)


def generate_sample_save_images(prompt, amt=1):
    """Sample `amt` 64x64 images for `prompt` and save them as one labeled grid."""
    # Slug the prompt (letters/spaces only) and append a timestamp for uniqueness.
    path = os.path.join(
        EXPERIMENT_DIRECTORY,
        "inferred",
        re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_") + str(int(time.time())) + ".png",
    )

    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()

    save_grid_with_label(torchvision.utils.make_grid(generated), prompt, path)


if __name__ == "__main__":
    generate_sample_save_images(input("Prompt? "), 8)
logger.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt

# Lazily-created module-level TensorBoard writer; call init_logger() before
# log_data/log_img, otherwise they raise AttributeError on None.
writer = None


def log_data(data, i):
    """Write each scalar in `data` (name -> value) to TensorBoard at step `i`."""
    # Iterate items() instead of keys() + lookup (PERF102).
    for key, value in data.items():
        writer.add_scalar(key, value, i)


def log_img(img, name):
    """Write a single image tensor to TensorBoard under `name`."""
    writer.add_image(name, img)


def save_grid_with_label(img_grid, label, out_file):
    """Render a CHW image grid with `label` as the title and save it to `out_file`."""
    img_grid = img_grid.permute(1, 2, 0).numpy()  # CHW -> HWC for matplotlib

    fig, ax = plt.subplots(figsize=(8, 8))
    ax.imshow(img_grid)
    ax.set_title(label, fontsize=20)
    ax.axis('off')

    plt.subplots_adjust(top=0.85)  # leave headroom for the title

    plt.savefig(out_file, bbox_inches='tight', pad_inches=0.1)

    # Close explicitly so repeated calls don't leak figures.
    plt.close(fig)
    plt.close("all")


def init_logger(dir="runs"):
    """Create the global SummaryWriter once; later calls are no-ops."""
    global writer
    if writer is None:  # explicit None check rather than truthiness
        writer = SummaryWriter(dir)
pipeline.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pipeline.py
2
+ import os
3
+ import re
4
+ import time
5
+ import torch
6
+ import torchvision
7
+ from huggingface_hub import HfApi, HfFolder
8
+ from transformers import Pipeline
9
+ from factories import UNet_conditional
10
+ from wrapper import DiffusionManager, Schedule
11
+ from bert_vectorize import vectorize_text_with_bert
12
+ from logger import save_grid_with_label
13
+
14
class TextToImagePipeline(Pipeline):
    """Text-to-image diffusion pipeline around UNet_conditional + DiffusionManager.

    NOTE(review): this subclasses transformers.Pipeline but replaces __init__
    and __call__ wholesale (no super().__init__()), so none of the transformers
    pipeline machinery is actually used — confirm the base class is needed.
    """

    def __init__(self, model_dir: str = "runs/run_3_jxa", device: str = "cpu"):
        # Initialize model, diffusion manager, and set up environment.
        self.device = device
        self.model_dir = model_dir

        # Create directories if they do not exist.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # 768 matches the DistilBERT embedding width used as conditioning.
        self.net = UNet_conditional(num_classes=768)
        self.net.to(self.device)
        self.net.load_state_dict(torch.load(os.path.join(model_dir, "ckpt/latest.pt"), weights_only=True))

        # Set up DiffusionManager.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def __call__(self, prompt, num_steps=1000, amt=8):
        """Generate `amt` images for `prompt`; returns the saved grid's path.

        BUG FIX: num_steps/amt were required positional arguments, so this
        file's own `__main__` call `pipeline(prompt, amt=8)` raised TypeError.
        Defaults keep the signature backward-compatible for all callers.
        """
        # Rebuild the manager so a caller-chosen step count takes effect.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=num_steps)
        self.wrapper.set_schedule(Schedule.LINEAR)

        return self.generate_sample_save_images(prompt, amt)

    def generate_sample_save_images(self, prompt: str, amt: int = 1):
        """Sample, save a labeled grid, and return the output file path."""
        # Slug the prompt (letters/spaces only) plus a timestamp for uniqueness.
        output_path = os.path.join(self.model_dir, "inferred",
                                   re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_") + str(int(time.time())) + ".png")

        # Vectorize the prompt.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # Generate images.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # Save images using the provided save function.
        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, output_path)

        return output_path  # Return the path to the saved image
56
+
57
+
58
# Usage example
if __name__ == "__main__":
    device = "mps" if torch.backends.mps.is_available() else "cpu"
    model_dir = "runs/run_3_jxa"  # Path to your model directory

    # Create an instance of the pipeline.
    pipeline = TextToImagePipeline(model_dir=model_dir, device=device)

    # Get user input and generate an image.
    prompt = input("Prompt? ")
    # BUG FIX: __call__ takes (prompt, num_steps, amt) positionally; the
    # original `pipeline(prompt, amt=8)` omitted the required num_steps
    # argument and raised TypeError. Pass it explicitly.
    image_path = pipeline(prompt, 1000, 8)
    print(f"Generated image saved at: {image_path}")
predict.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Prediction interface for Cog ⚙️
2
+ # https://cog.run/python
3
+
4
+ from cog import BasePredictor, Input, Path
5
+ import os
6
+ from factories import UNet_conditional
7
+ from wrapper import DiffusionManager, Schedule
8
+ import torch
9
+ import re
10
+ from bert_vectorize import vectorize_text_with_bert
11
+ from logger import save_grid_with_label
12
+ import torchvision
13
+ import time
14
+
15
+
16
class Predictor(BasePredictor):
    """Cog predictor wrapping the conditional diffusion model (CPU-only)."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        device = "cpu"
        model_dir = "runs/run_3_jxa"
        self.device = device
        self.model_dir = model_dir

        # Create directories if they do not exist.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # Load model. NOTE(review): weights_only=False deserializes arbitrary
        # pickled objects — only safe because this checkpoint ships with the repo.
        self.net = UNet_conditional(num_classes=768, device=device)
        self.net.to(self.device)
        self.net.load_state_dict(torch.load(os.path.join(model_dir, "ckpt/latest_cpu.pt"), weights_only=False))

        # Set up DiffusionManager.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def predict(
        self,
        prompt: str = Input(description="Text prompt"),
        amt: int = Input(description="Amt", default=8)
    ) -> Path:
        """Run a single prediction on the model"""
        # Vectorize the prompt.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # BUG FIX: the declared return type is Path, but the original returned
        # a raw numpy array, which Cog cannot serve as a file. Save the grid
        # and return its path instead.
        out_path = os.path.join(
            self.model_dir, "inferred",
            re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_") + str(int(time.time())) + ".png")
        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, out_path)
        return Path(out_path)
runner.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from factories import UNet_conditional
2
+ from wrapper import DiffusionManager, Schedule
3
+ import os
4
+ import re
5
+ import torch
6
+ from bert_vectorize import vectorize_text_with_bert, cleanup
7
+ import time
8
+ import torchvision
9
+ from logger import save_grid_with_label
10
+ from clip_score import select_top_n_images
11
+ from torchinfo import summary
12
+
13
+
14
+
15
+ EXPERIMENT_DIRECTORY = "runs/run_3_jxa_resumed"
16
+ device = "mps" if torch.backends.mps.is_available() else "cpu"
17
+
18
+ try:
19
+ os.mkdir(os.path.join(EXPERIMENT_DIRECTORY, "inferred"))
20
+ except:
21
+ print("Skipping making directory, directory already exists")
22
+
23
+ net = UNet_conditional(num_classes=768)
24
+ net.to(device)
25
+ net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt")))
26
+
27
+
28
def count_parameters(model):
    """Return the number of trainable parameters in `model`."""
    # sum() over a generator replaces the original tensor-building detour
    # (torch.tensor([...]).sum().item()) — same result, no allocation.
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
30
+ print(f"Parameters: {count_parameters(net)}")
31
+
32
+
33
+
34
+ wrapper = DiffusionManager(net, device=device, noise_steps=1000)
35
+ wrapper.set_schedule(Schedule.LINEAR)
36
+
37
+
38
def infer(prompt, amt=1, topn=8):
    """Sample `amt` images for `prompt`, keep the `topn` best by CLIP score, save a grid."""
    # Slug the prompt (letters/spaces only) plus a timestamp for uniqueness.
    slug = re.sub(r'[^a-zA-Z\s]', '', prompt).replace(" ", "_")
    path = os.path.join(EXPERIMENT_DIRECTORY, "inferred", slug + str(int(time.time())) + ".png")

    vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

    generated = wrapper.sample(64, vprompt, amt=amt).detach().cpu()

    # Rank by CLIP similarity to the prompt and keep the best topn.
    generated, _ = select_top_n_images(generated, prompt, n=topn)

    save_grid_with_label(torchvision.utils.make_grid(generated), prompt + f"({topn} best of {amt})", path)
49
+
50
+
51
def run_jobs():
    """Process inference_jobs.txt line by line, re-reading it until no new lines appear."""
    n = 8
    bestof = 32
    print(f"using best {bestof} of {n}")
    processed_tasks = set()

    def read_jobs():
        # A missing jobs file simply means there is nothing to do yet.
        try:
            with open("inference_jobs.txt", 'r') as file:
                return [line.strip() for line in file.readlines()]
        except FileNotFoundError:
            return []

    new_tasks = [task for task in read_jobs() if task not in processed_tasks]
    # Loop until a re-read of the file yields no unseen prompts.
    # (The original's inner `if new_tasks:` was redundant inside this loop.)
    while new_tasks:
        for task in new_tasks:
            infer(task, n, bestof)
            processed_tasks.add(task)
        new_tasks = [task for task in read_jobs() if task not in processed_tasks]

    cleanup()
77
+
78
+ if __name__ == "__main__":
79
+ #infer(input("Prompt? "), 8)
80
+ run_jobs()
runs/run_3_jxa/ckpt/latest.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0cd39e8429ea0ace24bb40d4bd404baebb8aae471385987b898a966eb79dcc5f
3
+ size 103503678
runs/run_3_jxa/ckpt/latest_cpu.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3e6d31021fe6d0df8d0d8dee730a411648345f13c0d5ae10084efe536d5dc7a2
3
+ size 103505112
wrapper.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from enum import Enum
4
+ from tqdm import trange
5
+
6
+
7
+
8
+
9
+
10
+ Schedule = Enum('Schedule', ['LINEAR', 'COSINE'])
11
+
12
+ class DiffusionManager(nn.Module):
13
+ def __init__(self, model: nn.Module, noise_steps=1000, start=0.0001, end=0.02, device="cpu", **kwargs ) -> None:
14
+ super().__init__(**kwargs)
15
+
16
+ self.model = model
17
+
18
+ self.noise_steps = noise_steps
19
+
20
+ self.start = start
21
+ self.end = end
22
+ self.device = device
23
+
24
+ self.schedule = None
25
+
26
+ self.set_schedule()
27
+
28
+ #model.set_parent(self)
29
+
30
+
31
+ def _get_schedule(self, schedule_type: Schedule = Schedule.LINEAR):
32
+ if schedule_type == Schedule.LINEAR:
33
+ return torch.linspace(self.start, self.end, self.noise_steps)
34
+ elif schedule_type == Schedule.COSINE:
35
+ # https://arxiv.org/pdf/2102.09672 page 4
36
+ #https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
37
+ #line 18
38
+ def get_alphahat_at(t):
39
+ def f(t):
40
+ s=self.start
41
+ return torch.cos((t/self.noise_steps + s)/(1+s) * torch.pi/2) ** 2
42
+
43
+ return f(t)/f(torch.zeros_like(t))
44
+
45
+ t = torch.Tensor(range(self.noise_steps))
46
+
47
+ t = 1-(get_alphahat_at(t + 1)/get_alphahat_at(t))
48
+
49
+ t = torch.minimum(t, torch.ones_like(t) * 0.999) #"In practice, we clip β_t to be no larger than 0.999 to prevent singularities at the end of the diffusion process n"
50
+
51
+ return t
52
+
53
+ def set_schedule(self, schedule: Schedule = Schedule.LINEAR):
54
+ self.schedule = self._get_schedule(schedule).to(self.device)
55
+
56
+ def get_schedule_at(self, step):
57
+ beta = self.schedule
58
+ alpha = 1 - beta
59
+ alpha_hat = torch.cumprod(alpha, dim=0)
60
+
61
+ return self._unsqueezify(beta.data[step]), self._unsqueezify(alpha.data[step]), self._unsqueezify(alpha_hat.data[step])
62
+
63
+ @staticmethod
64
+ def _unsqueezify(value):
65
+ return value.view(-1, 1, 1, 1)#.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
66
+
67
+ def noise_image(self, image, step):
68
+
69
+
70
+ image = image.to(self.device)
71
+
72
+ beta, alpha, alpha_hat = self.get_schedule_at(step)
73
+
74
+ epsilon = torch.randn_like(image)
75
+
76
+ # print(alpha_hat)
77
+
78
+ # print(alpha_hat.size())
79
+ # print(image.size())
80
+
81
+ noised_img = torch.sqrt(alpha_hat) * image + torch.sqrt(1 - alpha_hat) * epsilon
82
+
83
+ return noised_img, epsilon
84
+
85
+ def random_timesteps(self, amt=1):
86
+
87
+ return torch.randint(low=1, high=self.noise_steps, size=(amt,))
88
+
89
+
90
+
91
+
92
+ def sample(self, img_size, condition, amt=5, use_tqdm=True):
93
+
94
+ if tuple(condition.shape)[0] < amt:
95
+ condition = condition.repeat(amt, 1)
96
+
97
+ self.model.eval()
98
+
99
+ condition = condition.to(self.device)
100
+
101
+ my_trange = lambda x, y, z: trange(x,y, z, leave=False,dynamic_ncols=True)
102
+ fn = my_trange if use_tqdm else range
103
+ with torch.no_grad():
104
+
105
+ cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)
106
+ for i in fn(self.noise_steps-1, 0, -1):
107
+
108
+ timestep = torch.ones(amt) * (i)
109
+
110
+ timestep = timestep.to(self.device)
111
+
112
+
113
+
114
+ predicted_noise = self.model(cur_img, timestep, condition)
115
+
116
+ beta, alpha, alpha_hat = self.get_schedule_at(i)
117
+
118
+ cur_img = (1/torch.sqrt(alpha))*(cur_img - (beta/torch.sqrt(1-alpha_hat))*predicted_noise)
119
+ if i > 1:
120
+ cur_img = cur_img + torch.sqrt(beta)*torch.randn_like(cur_img)
121
+
122
+
123
+ self.model.train()
124
+
125
+
126
+
127
+
128
+
129
+ return cur_img
130
+ def sample_multicond(self, img_size, condition, use_tqdm=True):
131
+ num_conditions = condition.shape[0]
132
+
133
+
134
+
135
+ amt = num_conditions
136
+
137
+ self.model.eval()
138
+
139
+ condition = condition.to(self.device)
140
+
141
+ my_trange = lambda x, y, z: trange(x, y, z, leave=False, dynamic_ncols=True)
142
+ fn = my_trange if use_tqdm else range
143
+
144
+ with torch.no_grad():
145
+
146
+ cur_img = torch.randn((amt, 3, img_size, img_size)).to(self.device)
147
+
148
+ for i in fn(self.noise_steps-1, 0, -1):
149
+ timestep = torch.ones(amt) * i
150
+ timestep = timestep.to(self.device)
151
+
152
+
153
+ predicted_noise = self.model(cur_img, timestep, condition)
154
+
155
+ beta, alpha, alpha_hat = self.get_schedule_at(i)
156
+
157
+ cur_img = (1 / torch.sqrt(alpha)) * (cur_img - (beta / torch.sqrt(1 - alpha_hat)) * predicted_noise)
158
+ if i > 1:
159
+ cur_img = cur_img + torch.sqrt(beta) * torch.randn_like(cur_img)
160
+
161
+ self.model.train()
162
+
163
+ # Return images sampled for each condition
164
+ return cur_img
165
+
166
+ def training_loop_iteration(self, optimizer, batch, label, criterion):
167
+
168
+ def print_(string):
169
+ for i in range(10):
170
+ print(string)
171
+ batch = batch.to(self.device)
172
+
173
+ #label = label.long() # uncomment for nn.Embedding
174
+ label = label.to(self.device)
175
+
176
+ timesteps = self.random_timesteps(batch.shape[0]).to(self.device)
177
+
178
+ noisy_batch, real_noise = self.noise_image(batch, timesteps)
179
+
180
+ if torch.isnan(noisy_batch).any() or torch.isnan(real_noise).any():
181
+ print_("NaNs detected in the noisy batch or real noise")
182
+
183
+
184
+ pred_noise = self.model(noisy_batch, timesteps, label)
185
+
186
+ if torch.isnan(pred_noise).any():
187
+ print_("NaNs detected in the predicted noise")
188
+
189
+ loss = criterion(real_noise, pred_noise)
190
+
191
+ if torch.isnan(loss).any():
192
+ print_("NaNs detected in the loss")
193
+
194
+ loss.backward()
195
+ optimizer.step()
196
+
197
+ return loss.item()
198
+