KAHABKALU committed on
Commit 9fc7ce7 · verified · 1 Parent(s): 9aed40e

Upload generate_images_direct.py

Files changed (1)
  1. generate_images_direct.py +361 -0
generate_images_direct.py ADDED
@@ -0,0 +1,361 @@
+ import torch
+ import random
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import torch.optim as optim
+ from torch.utils.data import Dataset, DataLoader
+ import torchvision.transforms as T
+ from PIL import Image
+ import os
+ import json
+ from tqdm import tqdm
+ from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline
+ from transformers import CLIPTokenizer, CLIPTextModel
+ def seed_everything(seed=42):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+
+ seed_everything(42)
+ # Sinusoidal timestep embedding for diffusion steps
+ def get_timestep_embedding(timesteps, embedding_dim):
+     half_dim = embedding_dim // 2
+     emb = torch.exp(
+         torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *
+         -(torch.log(torch.tensor(10000.0)) / half_dim)
+     )
+     emb = timesteps.float()[:, None] * emb[None, :]
+     emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
+     if embedding_dim % 2 == 1:  # Handle odd embedding dimensions
+         emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)
+     return emb
+
+ # Residual block with time and context embeddings
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):
+         super().__init__()
+         self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)
+         self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
+         self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)
+         self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
+         self.time_mlp = nn.Linear(time_emb_dim, out_channels)
+         self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None
+         self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()
+
+     def forward(self, x, t_emb, context=None):
+         h = self.norm1(x)
+         h = F.silu(h)
+         h = self.conv1(h)
+
+         # Add time embedding
+         t_proj = self.time_mlp(t_emb)[:, :, None, None]
+         h = h + t_proj
+
+         # Add context embedding if available
+         if self.context_proj is not None and context is not None:
+             context_pooled = context.mean(dim=1)  # [batch, context_dim]
+             context_proj = self.context_proj(context_pooled)[:, :, None, None]
+             h = h + context_proj
+
+         h = self.norm2(h)
+         h = F.silu(h)
+         h = self.conv2(h)
+
+         return h + self.shortcut(x)
+
+ # Cross-attention to integrate text embeddings
+ class CrossAttention(nn.Module):
+     def __init__(self, channels, context_dim):
+         super().__init__()
+         self.channels = channels
+         self.query = nn.Linear(channels, channels)
+         self.key = nn.Linear(context_dim, channels)
+         self.value = nn.Linear(context_dim, channels)
+         self.out = nn.Linear(channels, channels)
+         self.norm = nn.LayerNorm(channels)
+
+     def forward(self, x, context):
+         if context is None:
+             return x
+
+         B, C, H, W = x.shape
+         x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
+         x_norm = self.norm(x_flat)
+
+         q = self.query(x_norm)   # [B, H*W, C]
+         k = self.key(context)    # [B, seq_len, C]
+         v = self.value(context)  # [B, seq_len, C]
+
+         scale = (C ** -0.5)
+         attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale
+         attn_weights = F.softmax(attn_weights, dim=-1)
+         attn_out = torch.bmm(attn_weights, v)
+         attn_out = self.out(attn_out)
+
+         attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)
+         return x + attn_out
+
+ # Self-attention block for image features
+ class AttentionBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.norm = nn.GroupNorm(min(32, channels), channels)
+         self.qkv = nn.Conv2d(channels, channels * 3, 1)
+         self.proj = nn.Conv2d(channels, channels, 1)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         h = self.norm(x)
+         qkv = self.qkv(h).reshape(B, 3, C, H * W)
+         q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]
+
+         scale = (C ** -0.5)
+         attn = torch.bmm(q.transpose(1, 2), k) * scale
+         attn = F.softmax(attn, dim=-1)
+
+         out = torch.bmm(v, attn.transpose(1, 2))
+         out = out.reshape(B, C, H, W)
+         return self.proj(out) + x
+
+ # U-Net model updated for 256x256 latents
+ class UNetConditional(nn.Module):
+     def __init__(self, in_channels=4, base_channels=128, context_dim=768):
+         super().__init__()
+         self.time_emb_dim = base_channels * 4
+         from types import SimpleNamespace
+         self.config = SimpleNamespace()
+         self.config._diffusers_version = "0.34.0"
+         self.config.in_channels = in_channels
+         self.config.out_channels = in_channels
+         self.config.sample_size = 256  # Updated for 256x256 latents
+         self.config.layers_per_block = 2
+         self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]
+         self.config.attention_head_dim = 8
+         self.config.cross_attention_dim = context_dim
+
+         # Time embedding MLP
+         self.time_mlp = nn.Sequential(
+             nn.Linear(base_channels, self.time_emb_dim),
+             nn.SiLU(),
+             nn.Linear(self.time_emb_dim, self.time_emb_dim),
+         )
+
+         # Input projection
+         self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)
+
+         # Encoder
+         self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)
+         self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)
+         self.cross1 = CrossAttention(base_channels * 2, context_dim)
+
+         self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)
+         self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)
+         self.cross2 = CrossAttention(base_channels * 4, context_dim)
+
+         self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)
+         self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)
+         self.cross3 = CrossAttention(base_channels * 8, context_dim)
+
+         # Middle
+         self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
+         self.middle_attn = AttentionBlock(base_channels * 8)
+         self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)
+
+         # Decoder
+         self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)
+         self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)
+         self.cross_up3 = CrossAttention(base_channels * 4, context_dim)
+
+         self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)
+         self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)
+         self.cross_up2 = CrossAttention(base_channels * 2, context_dim)
+
+         self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)
+         self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)
+
+         # Output
+         self.output_conv = nn.Sequential(
+             nn.GroupNorm(min(32, base_channels), base_channels),
+             nn.SiLU(),
+             nn.Conv2d(base_channels, in_channels, 3, padding=1)
+         )
+
+     def forward(self, x, t, context, cfg_scale=1.0):
+         t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)
+         t_emb = self.time_mlp(t_emb)
+
+         def denoise(x, t_emb, context):
+             h = self.input_conv(x)
+
+             # Encoder
+             h1 = self.down1(h, t_emb, context)
+             h1_cross = self.cross1(h1, context)
+             h1_down = self.downsample1(h1_cross)
+
+             h2 = self.down2(h1_down, t_emb, context)
+             h2_cross = self.cross2(h2, context)
+             h2_down = self.downsample2(h2_cross)
+
+             h3 = self.down3(h2_down, t_emb, context)
+             h3_cross = self.cross3(h3, context)
+             h3_down = self.downsample3(h3_cross)
+
+             # Middle
+             h_mid = self.middle1(h3_down, t_emb, context)
+             h_mid = self.middle_attn(h_mid)
+             h_mid = self.middle2(h_mid, t_emb, context)
+
+             # Decoder
+             h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')
+             h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)
+             h = self.upsample3(h)
+             h = self.cross_up3(h, context)
+
+             h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')
+             h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)
+             h = self.upsample2(h)
+             h = self.cross_up2(h, context)
+
+             h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')
+             h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)
+             h = self.upsample1(h)
+
+             return self.output_conv(h)
+
+         if cfg_scale == 1.0 or context is None:
+             return denoise(x, t_emb, context)
+
+         uncond = denoise(x, t_emb, context=None)
+         cond = denoise(x, t_emb, context)
+         return uncond + cfg_scale * (cond - uncond)
+ import torch
+ from diffusers import AutoencoderKL, DDPMScheduler
+ from transformers import CLIPTextModel, CLIPTokenizer
+ from PIL import Image
+ import numpy as np
+ from tqdm import tqdm
+ import argparse
+ import sys
+
+
+
+ def seed_everything(seed):
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+     np.random.seed(seed)
+
+ def generate_images_direct(unet_path="output/KahabMinGenT2Im-v1.pt", device="cuda", output_dir="output", prompt=None, num_inference_steps=50):
+     """Generate 256x256 images with a custom UNet and user-specified text prompt"""
+     seed_everything(42)
+     print(f"Using device: {device}")
+
+     # Load components
+     print("Loading VAE...")
+     vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae").to(device).eval().requires_grad_(False)
+
+     print("Loading tokenizer and text encoder...")
+     tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
+     text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device).eval().requires_grad_(False)
+
+     print("Loading trained UNet...")
+     unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)
+     checkpoint = torch.load(unet_path, map_location=device, weights_only=True)
+     unet.load_state_dict(checkpoint['model_state_dict'])
+     unet = unet.to(device).eval()
+
+     # Create scheduler (note: DDPMScheduler's first positional argument is num_train_timesteps)
+     scheduler = DDPMScheduler(num_inference_steps)
+
+     # Get prompt from user if not provided
+     if prompt is None:
+         # Check if running in Jupyter
+         if 'ipykernel' in sys.modules:
+             prompt = input("Enter your text prompt (e.g., 'A friendly dragon'): ").strip()
+         else:
+             prompt = ""  # Will be handled by argparse default or user input
+     if not prompt:
+         prompt = "A friendly dragon"  # Default prompt if empty
+
+     test_prompts = [prompt]
+
+     print("🎨 Generating 256x256 images...")
+     for i, prompt in enumerate(test_prompts):
+         print(f"Generating: {prompt}")
+         try:
+             with torch.no_grad():
+                 # Encode prompt
+                 inputs = tokenizer(
+                     prompt,
+                     padding="max_length",
+                     truncation=True,
+                     max_length=77,
+                     return_tensors="pt"
+                 )
+                 inputs = {k: v.to(device) for k, v in inputs.items()}
+                 text_embeddings = text_encoder(**inputs).last_hidden_state
+                 print(f"Text embeddings shape: {text_embeddings.shape}, device: {text_embeddings.device}")
+
+                 # Create random latents for 256x256 output (256/8 = 32 due to VAE scaling)
+                 latents = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float32)
+                 print(f"Initial latents shape: {latents.shape}, device: {latents.device}")
+
+                 # Set timesteps
+                 scheduler.set_timesteps(num_inference_steps)
+
+                 # Denoising loop
+                 for t in tqdm(scheduler.timesteps, desc=f"Denoising {prompt}"):
+                     t_tensor = torch.tensor([t], device=device, dtype=torch.long)
+                     noise_pred = unet(latents, t_tensor, context=text_embeddings)
+                     latents = scheduler.step(noise_pred, t, latents).prev_sample
+
+                 print(f"Final latents shape: {latents.shape}")
+
+                 # Decode latents to image
+                 latents = latents / 0.18215
+                 images = vae.decode(latents).sample
+                 images = (images / 2 + 0.5).clamp(0, 1)  # Denormalize
+                 images = images.cpu().permute(0, 2, 3, 1).numpy()
+                 image = Image.fromarray((images[0] * 255).astype(np.uint8))
+
+                 # Save
+                 filename = f"{output_dir}/generated_256_{i+1}_{prompt.replace(' ', '_')}.png"
+                 image.save(filename)
+                 print(f"✅ Saved: {filename}")
+
+         except Exception as e:
+             print(f"❌ Error generating '{prompt}': {e}")
+             print(f"Error type: {type(e).__name__}")
+             continue
+
+ def main():
+     # Check if running in Jupyter
+     if 'ipykernel' in sys.modules:
+         generate_images_direct(
+             unet_path="output/KahabMinGenT2Im-v1.pt",
+             device="cuda" if torch.cuda.is_available() else "cpu",
+             output_dir="output",
+             prompt=None
+         )
+     else:
+         parser = argparse.ArgumentParser(description="Generate images with custom UNet and text prompt")
+         parser.add_argument("--unet_path", type=str, default="output/KahabMinGenT2Im-v1.pt", help="Path to UNet checkpoint")
+         parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device to use (cuda or cpu)")
+         parser.add_argument("--output_dir", type=str, default="output", help="Output directory for generated images")
+         parser.add_argument("--prompt", type=str, default=None, help="Text prompt for image generation")
+         args = parser.parse_args()
+
+         generate_images_direct(
+             unet_path=args.unet_path,
+             device=args.device,
+             output_dir=args.output_dir,
+             prompt=args.prompt
+         )
+
+ if __name__ == "__main__":
+     main()
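
For reference, a minimal usage sketch built only from the defaults in the file above: it assumes the trained checkpoint really exists at output/KahabMinGenT2Im-v1.pt and that the script is importable from the working directory (the script can also be run directly, in which case the argparse flags --unet_path, --device, --output_dir, and --prompt apply). The prompt value is the script's own default and is purely illustrative.

import torch
from generate_images_direct import generate_images_direct

# Generate one 256x256 image for a single prompt; all arguments mirror the defaults above.
generate_images_direct(
    unet_path="output/KahabMinGenT2Im-v1.pt",
    device="cuda" if torch.cuda.is_available() else "cpu",
    output_dir="output",
    prompt="A friendly dragon",
    num_inference_steps=50,
)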