JBlitzar committed on
Commit
3802079
·
1 Parent(s): 9f5a022
Files changed (3) hide show
  1. infer.py +1 -1
  2. pipeline.py +46 -346
  3. uploadify.py +0 -0
infer.py CHANGED
@@ -20,7 +20,7 @@ except:
20
 
21
  net = UNet_conditional(num_classes=768)
22
  net.to(device)
23
- net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt")))
24
 
25
 
26
 
 
20
 
21
  net = UNet_conditional(num_classes=768)
22
  net.to(device)
23
+ net.load_state_dict(torch.load(os.path.join(EXPERIMENT_DIRECTORY, "ckpt/latest.pt"),weights_only=True))
24
 
25
 
26
 
pipeline.py CHANGED
@@ -1,364 +1,64 @@
1
  # pipeline.py
 
 
 
2
  import torch
 
 
3
  from transformers import Pipeline
4
-
5
-
 
 
6
 
7
  class TextToImagePipeline(Pipeline):
8
- def __init__(self, model, tokenizer):
9
- super().__init__(model=model, tokenizer=tokenizer)
10
-
11
- def __call__(self, inputs):
12
- text_inputs = self.tokenizer(inputs, return_tensors="pt")
13
-
14
-
15
- with torch.no_grad():
16
- image = self.model(text_inputs['input_ids'])
17
-
18
-
19
- image = image.cpu().numpy()
20
-
21
- return image
22
-
23
-
24
-
25
- import torch
26
- import torch.nn as nn
27
- import torch.nn.functional as F
28
-
29
-
30
- class EMA:
31
- def __init__(self, beta):
32
- super().__init__()
33
- self.beta = beta
34
- self.step = 0
35
-
36
- def update_model_average(self, ma_model, current_model):
37
- for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
38
- old_weight, up_weight = ma_params.data, current_params.data
39
- ma_params.data = self.update_average(old_weight, up_weight)
40
-
41
- def update_average(self, old, new):
42
- if old is None:
43
- return new
44
- return old * self.beta + (1 - self.beta) * new
45
-
46
- def step_ema(self, ema_model, model, step_start_ema=2000):
47
- if self.step < step_start_ema:
48
- self.reset_parameters(ema_model, model)
49
- self.step += 1
50
- return
51
- self.update_model_average(ema_model, model)
52
- self.step += 1
53
-
54
- def reset_parameters(self, ema_model, model):
55
- ema_model.load_state_dict(model.state_dict())
56
-
57
-
58
- class SelfAttention(nn.Module):
59
- def __init__(self, channels, size):
60
- super(SelfAttention, self).__init__()
61
- self.channels = channels
62
- self.size = size
63
- self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
64
- self.ln = nn.LayerNorm([channels])
65
- self.ff_self = nn.Sequential(
66
- nn.LayerNorm([channels]),
67
- nn.Linear(channels, channels),
68
- nn.GELU(),
69
- nn.Linear(channels, channels),
70
- )
71
-
72
- def forward(self, x):
73
- x = x.view(-1, self.channels, self.size * self.size).swapaxes(1, 2)
74
- x_ln = self.ln(x)
75
- attention_value, _ = self.mha(x_ln, x_ln, x_ln)
76
- attention_value = attention_value + x
77
- attention_value = self.ff_self(attention_value) + attention_value
78
- return attention_value.swapaxes(2, 1).view(-1, self.channels, self.size, self.size)
79
-
80
-
81
- class CrossAttention(nn.Module):
82
- def __init__(self, channels, size, context_dim):
83
- super(CrossAttention, self).__init__()
84
- self.channels = channels
85
- self.size = size
86
- self.context_dim = context_dim
87
- self.mha = nn.MultiheadAttention(channels, 4, batch_first=True)
88
- self.ln = nn.LayerNorm(channels)
89
- self.context_ln = nn.LayerNorm(channels)
90
- self.ff_self = nn.Sequential(
91
- nn.LayerNorm(channels),
92
- nn.Linear(channels, channels),
93
- nn.GELU(),
94
- nn.Linear(channels, channels),
95
- )
96
-
97
-
98
- self.context_proj = nn.Linear(context_dim, channels)
99
-
100
- def forward(self, x, context):
101
-
102
- # Reshape and permute x for multi-head attention
103
- batch_size, channels, height, width = x.size()
104
- x = x.view(-1, self.channels, self.size * self.size).swapaxes(1,2)
105
- x_ln = self.ln(x)
106
-
107
- # Expand context to match the sequence length of x
108
- context = self.context_proj(context)
109
-
110
- context = context.unsqueeze(1).expand(-1, x_ln.size(1), -1)
111
-
112
- context_ln = self.context_ln(context)
113
-
114
-
115
-
116
-
117
-
118
- # Apply cross-attention
119
- attention_value, _ = self.mha(x_ln, context_ln, context_ln)
120
- attention_value = attention_value + x
121
- attention_value = self.ff_self(attention_value) + attention_value
122
-
123
- # Reshape and permute back to the original format
124
- return attention_value.permute(0, 2, 1).view(batch_size, channels, height, width)
125
-
126
-
127
- class DoubleConv(nn.Module):
128
- def __init__(self, in_channels, out_channels, mid_channels=None, residual=False):
129
- super().__init__()
130
- self.residual = residual
131
- if not mid_channels:
132
- mid_channels = out_channels
133
- self.double_conv = nn.Sequential(
134
- nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
135
- nn.GroupNorm(1, mid_channels),
136
- nn.GELU(),
137
- nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
138
- nn.GroupNorm(1, out_channels),
139
- )
140
-
141
- def forward(self, x):
142
- if self.residual:
143
- return F.gelu(x + self.double_conv(x))
144
- else:
145
- return self.double_conv(x)
146
-
147
-
148
- class Down(nn.Module):
149
- def __init__(self, in_channels, out_channels, emb_dim=256):
150
- super().__init__()
151
- self.maxpool_conv = nn.Sequential(
152
- nn.MaxPool2d(2),
153
- DoubleConv(in_channels, in_channels, residual=True),
154
- DoubleConv(in_channels, out_channels),
155
- )
156
-
157
- self.emb_layer = nn.Sequential(
158
- nn.SiLU(),
159
- nn.Linear(
160
- emb_dim,
161
- out_channels
162
- ),
163
- )
164
-
165
- def forward(self, x, t):
166
- x = self.maxpool_conv(x)
167
- emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
168
- return x + emb
169
-
170
-
171
- class Up(nn.Module):
172
- def __init__(self, in_channels, out_channels, emb_dim=256):
173
- super().__init__()
174
-
175
- self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
176
- self.conv = nn.Sequential(
177
- DoubleConv(in_channels, in_channels, residual=True),
178
- DoubleConv(in_channels, out_channels, in_channels // 2),
179
- )
180
-
181
- self.emb_layer = nn.Sequential(
182
- nn.SiLU(),
183
- nn.Linear(
184
- emb_dim,
185
- out_channels
186
- ),
187
- )
188
-
189
- def forward(self, x, skip_x, t):
190
- x = self.up(x)
191
- x = torch.cat([skip_x, x], dim=1)
192
- x = self.conv(x)
193
- emb = self.emb_layer(t)[:, :, None, None].repeat(1, 1, x.shape[-2], x.shape[-1])
194
- return x + emb
195
-
196
-
197
- class Dome_UNet(nn.Module):
198
- def __init__(self, c_in=3, c_out=3, time_dim=256, device="mps"):
199
- super().__init__()
200
- self.device = device
201
- self.time_dim = time_dim
202
- self.inc = DoubleConv(c_in, 64)
203
- self.down1 = Down(64, 128)
204
- self.sa1 = SelfAttention(128, 32)
205
- self.down2 = Down(128, 256)
206
- self.sa2 = SelfAttention(256, 16)
207
- self.down3 = Down(256, 256)
208
- self.sa3 = SelfAttention(256, 8)
209
-
210
- self.bot1 = DoubleConv(256, 512)
211
- self.bot2 = DoubleConv(512, 512)
212
- self.bot3 = DoubleConv(512, 256)
213
-
214
- self.up1 = Up(512, 128)
215
- self.sa4 = SelfAttention(128, 16)
216
- self.up2 = Up(256, 64)
217
- self.sa5 = SelfAttention(64, 32)
218
- self.up3 = Up(128, 64)
219
- self.sa6 = SelfAttention(64, 64)
220
- self.outc = nn.Conv2d(64, c_out, kernel_size=1)
221
-
222
- def pos_encoding(self, t, channels):
223
- inv_freq = 1.0 / (
224
- 10000
225
- ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
226
- )
227
- pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
228
- pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
229
- pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
230
- return pos_enc
231
-
232
- def forward(self, x, t):
233
- t = t.unsqueeze(-1).type(torch.float)
234
- t = self.pos_encoding(t, self.time_dim)
235
-
236
- x1 = self.inc(x)
237
- x2 = self.down1(x1, t)
238
- x2 = self.sa1(x2)
239
- x3 = self.down2(x2, t)
240
- x3 = self.sa2(x3)
241
- x4 = self.down3(x3, t)
242
- x4 = self.sa3(x4)
243
-
244
- x4 = self.bot1(x4)
245
- x4 = self.bot2(x4)
246
- x4 = self.bot3(x4)
247
-
248
- x = self.up1(x4, x3, t)
249
- x = self.sa4(x)
250
- x = self.up2(x, x2, t)
251
- x = self.sa5(x)
252
- x = self.up3(x, x1, t)
253
- x = self.sa6(x)
254
- output = self.outc(x)
255
- return output
256
-
257
-
258
- class UNet_conditional(nn.Module):
259
- def __init__(self, c_in=3, c_out=3, time_dim=256, num_classes=None, context_dim=None, device="mps"):
260
- super().__init__()
261
-
262
- if context_dim is None:
263
- context_dim = num_classes
264
  self.device = device
265
- self.time_dim = time_dim
266
-
267
-
268
- self.inc = DoubleConv(c_in, 64)
269
- self.down1 = Down(64, 128)
270
- self.sa1 = SelfAttention(128, 32)
271
- self.xa1 = CrossAttention(128, 32, context_dim)
272
- self.down2 = Down(128, 256)
273
- self.xa2 = CrossAttention(256, 16, context_dim)
274
- self.sa2 = SelfAttention(256, 16)
275
- self.down3 = Down(256, 256)
276
- self.xa3 = CrossAttention(256, 8, context_dim)
277
- self.sa3 = SelfAttention(256, 8)
278
-
279
- self.bot1 = DoubleConv(256, 512)
280
- self.bot2 = DoubleConv(512, 512)
281
- self.bot3 = DoubleConv(512, 256)
282
-
283
- self.up1 = Up(512, 128)
284
- self.xa4 = CrossAttention(128, 16, context_dim)
285
- self.sa4 = SelfAttention(128, 16)
286
- self.up2 = Up(256, 64)
287
- self.xa5 = CrossAttention(64, 32, context_dim)
288
- self.sa5 = SelfAttention(64, 32)
289
- self.up3 = Up(128, 64)
290
- self.xa6 = CrossAttention(64, 64, context_dim)
291
- self.sa6 = SelfAttention(64, 64)
292
- self.outc = nn.Conv2d(64, c_out, kernel_size=1)
293
-
294
- if num_classes is not None:
295
- self.label_emb = nn.Linear(num_classes, time_dim)#Embedding(num_classes, time_dim)
296
- self.num_classes = num_classes
297
- if context_dim is None:
298
- context_dim = num_classes
299
-
300
- self.context_dim = context_dim
301
-
302
- self.label_crossattn_emb = nn.Linear(num_classes, context_dim)
303
-
304
- def pos_encoding(self, t, channels):
305
- inv_freq = 1.0 / (
306
- 10000
307
- ** (torch.arange(0, channels, 2, device=self.device).float() / channels)
308
- )
309
- pos_enc_a = torch.sin(t.repeat(1, channels // 2) * inv_freq)
310
- pos_enc_b = torch.cos(t.repeat(1, channels // 2) * inv_freq)
311
- pos_enc = torch.cat([pos_enc_a, pos_enc_b], dim=-1)
312
- return pos_enc
313
-
314
- def forward(self, x, t, y):
315
- t = t.unsqueeze(-1).type(torch.float)
316
- t = self.pos_encoding(t, self.time_dim)
317
-
318
- if y is not None:
319
-
320
- attn_y = y[:,:self.num_classes]
321
- attn_y = self.label_crossattn_emb(attn_y)
322
-
323
- # y = y[:,:self.num_classes]
324
-
325
- # y = self.label_emb(y)
326
-
327
-
328
- # t += y
329
-
330
- x1 = self.inc(x)
331
-
332
- x2 = self.down1(x1, t)
333
- x2 = self.xa1(x2, attn_y)
334
-
335
-
336
- x3 = self.down2(x2, t)
337
- x3 = self.xa2(x3, attn_y)
338
-
339
-
340
- x4 = self.down3(x3, t)
341
- x4 = self.xa3(x4, attn_y)
342
 
 
 
 
 
343
 
 
 
 
344
 
345
- x4 = self.bot1(x4)
 
 
346
 
 
 
 
 
347
 
348
- x = self.up1(x4, x3, t)
349
- x = self.xa4(x,attn_y)
350
 
351
- x = self.up2(x, x2, t)
352
- x = self.xa5(x, attn_y)
353
 
 
 
354
 
355
- x = self.up3(x, x1, t)
356
- x = self.xa6(x, attn_y)
357
- x = self.sa6(x)
358
 
359
- output = self.outc(x)
360
 
 
 
 
 
361
 
362
- #output = F.sigmoid(x)
363
- return output
364
 
 
 
 
 
 
1
  # pipeline.py
2
+ import os
3
+ import re
4
+ import time
5
  import torch
6
+ import torchvision
7
+ from huggingface_hub import HfApi, HfFolder
8
  from transformers import Pipeline
9
+ from factories import UNet_conditional
10
+ from wrapper import DiffusionManager, Schedule
11
+ from bert_vectorize import vectorize_text_with_bert
12
+ from logger import save_grid_with_label
13
 
14
class TextToImagePipeline(Pipeline):
    """Text-to-image pipeline: turns a text prompt into a saved image grid.

    Wraps a conditional UNet diffusion model (loaded from ``model_dir``) and a
    ``DiffusionManager`` sampler.  Calling the pipeline with a prompt
    vectorizes it with BERT, samples images, and writes a labelled grid PNG
    under ``<model_dir>/inferred``.

    NOTE(review): ``super().__init__`` of ``transformers.Pipeline`` is never
    called, so inherited base-class helpers may not be usable — confirm this
    subclass is only ever driven through ``__call__`` as defined here.
    """

    def __init__(self, model_dir: str, device: str = "cpu"):
        """Load the model checkpoint and set up the diffusion sampler.

        Args:
            model_dir: Run directory containing ``ckpt/latest.pt``; generated
                images are written to ``<model_dir>/inferred``.
            device: Torch device string (e.g. ``"cpu"``, ``"mps"``).
        """
        self.device = device
        self.model_dir = model_dir

        # Ensure the output directory exists before any sampling happens.
        os.makedirs(os.path.join(model_dir, "inferred"), exist_ok=True)

        # Load model weights.  map_location keeps loading working on machines
        # without the device the checkpoint was saved on; weights_only avoids
        # unpickling arbitrary objects from the checkpoint file.  Building the
        # path with separate components keeps it portable across OSes.
        self.net = UNet_conditional(num_classes=768)
        self.net.to(self.device)
        self.net.load_state_dict(
            torch.load(
                os.path.join(model_dir, "ckpt", "latest.pt"),
                map_location=self.device,
                weights_only=True,
            )
        )
        # Inference-only pipeline: disable dropout / batch-norm updates.
        self.net.eval()

        # Diffusion sampler wrapped around the network.
        self.wrapper = DiffusionManager(self.net, device=self.device, noise_steps=1000)
        self.wrapper.set_schedule(Schedule.LINEAR)

    def __call__(self, prompt: str, amt: int = 1):
        """Generate ``amt`` images for ``prompt``; returns the saved PNG path."""
        return self.generate_sample_save_images(prompt, amt)

    def generate_sample_save_images(self, prompt: str, amt: int = 1):
        """Sample images for ``prompt`` and save them as one labelled grid.

        Args:
            prompt: Free-form text description of the desired image.
            amt: Number of images to sample into the grid.

        Returns:
            Path of the PNG written under ``<model_dir>/inferred``.
        """
        # Build a filesystem-safe filename: drop non-letters, collapse ANY
        # whitespace (tabs/newlines included, not just spaces) to underscores,
        # and append a timestamp so repeated prompts don't overwrite earlier
        # outputs.
        safe_name = re.sub(r"\s+", "_", re.sub(r"[^a-zA-Z\s]", "", prompt))
        output_path = os.path.join(
            self.model_dir,
            "inferred",
            safe_name + str(int(time.time())) + ".png",
        )

        # Vectorize the prompt with BERT; unsqueeze adds the batch dimension.
        # assumes vectorize_text_with_bert returns a 1-D tensor — TODO confirm.
        vprompt = vectorize_text_with_bert(prompt).unsqueeze(0)

        # Sample `amt` images (size 64 — presumably pixel resolution; verify
        # against DiffusionManager.sample) conditioned on the prompt vector.
        generated = self.wrapper.sample(64, vprompt, amt=amt).detach().cpu()

        # Save all samples as a single labelled grid image.
        save_grid_with_label(torchvision.utils.make_grid(generated), prompt, output_path)

        return output_path  # Return the path to the saved image
 
 
51
 
 
52
 
53
# Usage example
if __name__ == "__main__":
    def _run_demo() -> None:
        """Interactive demo: read a prompt from stdin and render an image grid."""
        backend = "mps" if torch.backends.mps.is_available() else "cpu"
        run_dir = "runs/run_3_jxa"  # Path to your model directory

        # Build the pipeline once, then generate from the user's prompt.
        text_to_image = TextToImagePipeline(model_dir=run_dir, device=backend)

        user_prompt = input("Prompt? ")
        saved_at = text_to_image(user_prompt, amt=8)
        print(f"Generated image saved at: {saved_at}")

    _run_demo()
uploadify.py ADDED
File without changes