KAHABKALU committed on
Commit
d47670f
·
verified ·
1 Parent(s): d4339b9

Upload generate_images_direct.py

Browse files
Files changed (1) hide show
  1. generate_images_direct.py +446 -0
generate_images_direct.py ADDED
@@ -0,0 +1,446 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 6,
6
+ "id": "6b7a883f-d686-4cd8-b625-7633d078f373",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stdout",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "Using device: cuda\n",
14
+ "Loading VAE...\n",
15
+ "Loading tokenizer and text encoder...\n",
16
+ "Loading trained UNet...\n"
17
+ ]
18
+ },
19
+ {
20
+ "name": "stdin",
21
+ "output_type": "stream",
22
+ "text": [
23
+ "Enter your text prompt (e.g., 'A friendly dragon'): A childran in Cyberpunk\n"
24
+ ]
25
+ },
26
+ {
27
+ "name": "stdout",
28
+ "output_type": "stream",
29
+ "text": [
30
+ "🎨 Generating 256x256 images...\n",
31
+ "Generating: A childran in Cyberpunk\n",
32
+ "Text embeddings shape: torch.Size([1, 77, 768]), device: cuda:0\n",
33
+ "Initial latents shape: torch.Size([1, 4, 32, 32]), device: cuda:0\n"
34
+ ]
35
+ },
36
+ {
37
+ "name": "stderr",
38
+ "output_type": "stream",
39
+ "text": [
40
+ "Denoising A childran in Cyberpunk: 100%|███████████████████████████████████████████████| 50/50 [00:06<00:00, 7.63it/s]\n"
41
+ ]
42
+ },
43
+ {
44
+ "name": "stdout",
45
+ "output_type": "stream",
46
+ "text": [
47
+ "Final latents shape: torch.Size([1, 4, 32, 32])\n",
48
+ "✅ Saved: output/generated_256_1_A_childran_in_Cyberpunk.png\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "import torch\n",
54
+ "import random\n",
55
+ "import numpy as np\n",
56
+ "import torch\n",
57
+ "import torch.nn as nn\n",
58
+ "import torch.nn.functional as F\n",
59
+ "import torch.optim as optim\n",
60
+ "from torch.utils.data import Dataset, DataLoader\n",
61
+ "import torchvision.transforms as T\n",
62
+ "from PIL import Image\n",
63
+ "import os\n",
64
+ "import json\n",
65
+ "from tqdm import tqdm\n",
66
+ "from diffusers import AutoencoderKL, DDPMScheduler, StableDiffusionPipeline\n",
67
+ "from transformers import CLIPTokenizer, CLIPTextModel\n",
68
+ "def seed_everything(seed=42):\n",
69
+ " torch.manual_seed(seed)\n",
70
+ " torch.cuda.manual_seed(seed)\n",
71
+ " torch.cuda.manual_seed_all(seed)\n",
72
+ " random.seed(seed)\n",
73
+ " np.random.seed(seed)\n",
74
+ " torch.backends.cudnn.deterministic = True\n",
75
+ " torch.backends.cudnn.benchmark = False\n",
76
+ "\n",
77
+ "seed_everything(42)\n",
78
+ "# Sinusoidal timestep embedding for diffusion steps\n",
79
+ "def get_timestep_embedding(timesteps, embedding_dim):\n",
80
+ " half_dim = embedding_dim // 2\n",
81
+ " emb = torch.exp(\n",
82
+ " torch.arange(half_dim, dtype=torch.float32, device=timesteps.device) *\n",
83
+ " -(torch.log(torch.tensor(10000.0)) / half_dim)\n",
84
+ " )\n",
85
+ " emb = timesteps.float()[:, None] * emb[None, :]\n",
86
+ " emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)\n",
87
+ " if embedding_dim % 2 == 1: # Handle odd embedding dimensions\n",
88
+ " emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=1)\n",
89
+ " return emb\n",
90
+ "\n",
91
+ "# Residual block with time and context embeddings\n",
92
+ "class ResidualBlock(nn.Module):\n",
93
+ " def __init__(self, in_channels, out_channels, time_emb_dim, context_dim=None):\n",
94
+ " super().__init__()\n",
95
+ " self.norm1 = nn.GroupNorm(min(32, in_channels), in_channels)\n",
96
+ " self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)\n",
97
+ " self.norm2 = nn.GroupNorm(min(32, out_channels), out_channels)\n",
98
+ " self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)\n",
99
+ " self.time_mlp = nn.Linear(time_emb_dim, out_channels)\n",
100
+ " self.context_proj = nn.Linear(context_dim, out_channels) if context_dim else None\n",
101
+ " self.shortcut = nn.Conv2d(in_channels, out_channels, 1) if in_channels != out_channels else nn.Identity()\n",
102
+ "\n",
103
+ " def forward(self, x, t_emb, context=None):\n",
104
+ " h = self.norm1(x)\n",
105
+ " h = F.silu(h)\n",
106
+ " h = self.conv1(h)\n",
107
+ "\n",
108
+ " # Add time embedding\n",
109
+ " t_proj = self.time_mlp(t_emb)[:, :, None, None]\n",
110
+ " h = h + t_proj\n",
111
+ "\n",
112
+ " # Add context embedding if available\n",
113
+ " if self.context_proj is not None and context is not None:\n",
114
+ " context_pooled = context.mean(dim=1) # [batch, context_dim]\n",
115
+ " context_proj = self.context_proj(context_pooled)[:, :, None, None]\n",
116
+ " h = h + context_proj\n",
117
+ "\n",
118
+ " h = self.norm2(h)\n",
119
+ " h = F.silu(h)\n",
120
+ " h = self.conv2(h)\n",
121
+ "\n",
122
+ " return h + self.shortcut(x)\n",
123
+ "\n",
124
+ "# Cross-attention to integrate text embeddings\n",
125
+ "class CrossAttention(nn.Module):\n",
126
+ " def __init__(self, channels, context_dim):\n",
127
+ " super().__init__()\n",
128
+ " self.channels = channels\n",
129
+ " self.query = nn.Linear(channels, channels)\n",
130
+ " self.key = nn.Linear(context_dim, channels)\n",
131
+ " self.value = nn.Linear(context_dim, channels)\n",
132
+ " self.out = nn.Linear(channels, channels)\n",
133
+ " self.norm = nn.LayerNorm(channels)\n",
134
+ "\n",
135
+ " def forward(self, x, context):\n",
136
+ " if context is None:\n",
137
+ " return x\n",
138
+ "\n",
139
+ " B, C, H, W = x.shape\n",
140
+ " x_flat = x.permute(0, 2, 3, 1).reshape(B, H * W, C)\n",
141
+ " x_norm = self.norm(x_flat)\n",
142
+ "\n",
143
+ " q = self.query(x_norm) # [B, H*W, C]\n",
144
+ " k = self.key(context) # [B, seq_len, C]\n",
145
+ " v = self.value(context) # [B, seq_len, C]\n",
146
+ "\n",
147
+ " scale = (C ** -0.5)\n",
148
+ " attn_weights = torch.bmm(q, k.transpose(1, 2)) * scale\n",
149
+ " attn_weights = F.softmax(attn_weights, dim=-1)\n",
150
+ " attn_out = torch.bmm(attn_weights, v)\n",
151
+ " attn_out = self.out(attn_out)\n",
152
+ "\n",
153
+ " attn_out = attn_out.reshape(B, H, W, C).permute(0, 3, 1, 2)\n",
154
+ " return x + attn_out\n",
155
+ "\n",
156
+ "# Self-attention block for image features\n",
157
+ "class AttentionBlock(nn.Module):\n",
158
+ " def __init__(self, channels):\n",
159
+ " super().__init__()\n",
160
+ " self.norm = nn.GroupNorm(min(32, channels), channels)\n",
161
+ " self.qkv = nn.Conv2d(channels, channels * 3, 1)\n",
162
+ " self.proj = nn.Conv2d(channels, channels, 1)\n",
163
+ "\n",
164
+ " def forward(self, x):\n",
165
+ " B, C, H, W = x.shape\n",
166
+ " h = self.norm(x)\n",
167
+ " qkv = self.qkv(h).reshape(B, 3, C, H * W)\n",
168
+ " q, k, v = qkv[:, 0], qkv[:, 1], qkv[:, 2]\n",
169
+ "\n",
170
+ " scale = (C ** -0.5)\n",
171
+ " attn = torch.bmm(q.transpose(1, 2), k) * scale\n",
172
+ " attn = F.softmax(attn, dim=-1)\n",
173
+ "\n",
174
+ " out = torch.bmm(v, attn.transpose(1, 2))\n",
175
+ " out = out.reshape(B, C, H, W)\n",
176
+ " return self.proj(out) + x\n",
177
+ "\n",
178
+ "# U-Net model for 256x256 images (operates on 32x32 VAE latents)\n",
179
+ "class UNetConditional(nn.Module):\n",
180
+ " def __init__(self, in_channels=4, base_channels=128, context_dim=768):\n",
181
+ " super().__init__()\n",
182
+ " self.time_emb_dim = base_channels * 4\n",
183
+ " from types import SimpleNamespace\n",
184
+ " self.config = SimpleNamespace()\n",
185
+ " self.config._diffusers_version = \"0.34.0\"\n",
186
+ " self.config.in_channels = in_channels\n",
187
+ " self.config.out_channels = in_channels\n",
188
+ " self.config.sample_size = 256 # Updated for 256x256 latents\n",
189
+ " self.config.layers_per_block = 2\n",
190
+ " self.config.block_out_channels = [base_channels, base_channels * 2, base_channels * 4, base_channels * 8]\n",
191
+ " self.config.attention_head_dim = 8\n",
192
+ " self.config.cross_attention_dim = context_dim\n",
193
+ "\n",
194
+ " # Time embedding MLP\n",
195
+ " self.time_mlp = nn.Sequential(\n",
196
+ " nn.Linear(base_channels, self.time_emb_dim),\n",
197
+ " nn.SiLU(),\n",
198
+ " nn.Linear(self.time_emb_dim, self.time_emb_dim),\n",
199
+ " )\n",
200
+ "\n",
201
+ " # Input projection\n",
202
+ " self.input_conv = nn.Conv2d(in_channels, base_channels, 3, padding=1)\n",
203
+ "\n",
204
+ " # Encoder\n",
205
+ " self.down1 = ResidualBlock(base_channels, base_channels * 2, self.time_emb_dim, context_dim)\n",
206
+ " self.downsample1 = nn.Conv2d(base_channels * 2, base_channels * 2, 3, stride=2, padding=1)\n",
207
+ " self.cross1 = CrossAttention(base_channels * 2, context_dim)\n",
208
+ "\n",
209
+ " self.down2 = ResidualBlock(base_channels * 2, base_channels * 4, self.time_emb_dim, context_dim)\n",
210
+ " self.downsample2 = nn.Conv2d(base_channels * 4, base_channels * 4, 3, stride=2, padding=1)\n",
211
+ " self.cross2 = CrossAttention(base_channels * 4, context_dim)\n",
212
+ "\n",
213
+ " self.down3 = ResidualBlock(base_channels * 4, base_channels * 8, self.time_emb_dim, context_dim)\n",
214
+ " self.downsample3 = nn.Conv2d(base_channels * 8, base_channels * 8, 3, stride=2, padding=1)\n",
215
+ " self.cross3 = CrossAttention(base_channels * 8, context_dim)\n",
216
+ "\n",
217
+ " # Middle\n",
218
+ " self.middle1 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)\n",
219
+ " self.middle_attn = AttentionBlock(base_channels * 8)\n",
220
+ " self.middle2 = ResidualBlock(base_channels * 8, base_channels * 8, self.time_emb_dim, context_dim)\n",
221
+ "\n",
222
+ " # Decoder\n",
223
+ " self.up3 = ResidualBlock(base_channels * 16, base_channels * 4, self.time_emb_dim, context_dim)\n",
224
+ " self.upsample3 = nn.ConvTranspose2d(base_channels * 4, base_channels * 4, 4, stride=2, padding=1)\n",
225
+ " self.cross_up3 = CrossAttention(base_channels * 4, context_dim)\n",
226
+ "\n",
227
+ " self.up2 = ResidualBlock(base_channels * 8, base_channels * 2, self.time_emb_dim, context_dim)\n",
228
+ " self.upsample2 = nn.ConvTranspose2d(base_channels * 2, base_channels * 2, 4, stride=2, padding=1)\n",
229
+ " self.cross_up2 = CrossAttention(base_channels * 2, context_dim)\n",
230
+ "\n",
231
+ " self.up1 = ResidualBlock(base_channels * 4, base_channels, self.time_emb_dim, context_dim)\n",
232
+ " self.upsample1 = nn.ConvTranspose2d(base_channels, base_channels, 4, stride=2, padding=1)\n",
233
+ "\n",
234
+ " # Output\n",
235
+ " self.output_conv = nn.Sequential(\n",
236
+ " nn.GroupNorm(min(32, base_channels), base_channels),\n",
237
+ " nn.SiLU(),\n",
238
+ " nn.Conv2d(base_channels, in_channels, 3, padding=1)\n",
239
+ " )\n",
240
+ "\n",
241
+ " def forward(self, x, t, context, cfg_scale=1.0):\n",
242
+ " t_emb = get_timestep_embedding(t, self.time_emb_dim // 4)\n",
243
+ " t_emb = self.time_mlp(t_emb)\n",
244
+ "\n",
245
+ " def denoise(x, t_emb, context):\n",
246
+ " h = self.input_conv(x)\n",
247
+ "\n",
248
+ " # Encoder\n",
249
+ " h1 = self.down1(h, t_emb, context)\n",
250
+ " h1_cross = self.cross1(h1, context)\n",
251
+ " h1_down = self.downsample1(h1_cross)\n",
252
+ "\n",
253
+ " h2 = self.down2(h1_down, t_emb, context)\n",
254
+ " h2_cross = self.cross2(h2, context)\n",
255
+ " h2_down = self.downsample2(h2_cross)\n",
256
+ "\n",
257
+ " h3 = self.down3(h2_down, t_emb, context)\n",
258
+ " h3_cross = self.cross3(h3, context)\n",
259
+ " h3_down = self.downsample3(h3_cross)\n",
260
+ "\n",
261
+ " # Middle\n",
262
+ " h_mid = self.middle1(h3_down, t_emb, context)\n",
263
+ " h_mid = self.middle_attn(h_mid)\n",
264
+ " h_mid = self.middle2(h_mid, t_emb, context)\n",
265
+ "\n",
266
+ " # Decoder\n",
267
+ " h3_cross_resized = F.interpolate(h3_cross, size=h_mid.shape[-2:], mode='nearest')\n",
268
+ " h = self.up3(torch.cat([h_mid, h3_cross_resized], dim=1), t_emb, context)\n",
269
+ " h = self.upsample3(h)\n",
270
+ " h = self.cross_up3(h, context)\n",
271
+ "\n",
272
+ " h2_cross_resized = F.interpolate(h2_cross, size=h.shape[-2:], mode='nearest')\n",
273
+ " h = self.up2(torch.cat([h, h2_cross_resized], dim=1), t_emb, context)\n",
274
+ " h = self.upsample2(h)\n",
275
+ " h = self.cross_up2(h, context)\n",
276
+ "\n",
277
+ " h1_cross_resized = F.interpolate(h1_cross, size=h.shape[-2:], mode='nearest')\n",
278
+ " h = self.up1(torch.cat([h, h1_cross_resized], dim=1), t_emb, context)\n",
279
+ " h = self.upsample1(h)\n",
280
+ "\n",
281
+ " return self.output_conv(h)\n",
282
+ "\n",
283
+ " if cfg_scale == 1.0 or context is None:\n",
284
+ " return denoise(x, t_emb, context)\n",
285
+ "\n",
286
+ " uncond = denoise(x, t_emb, context=None)\n",
287
+ " cond = denoise(x, t_emb, context)\n",
288
+ " return uncond + cfg_scale * (cond - uncond)\n",
289
+ "import torch\n",
290
+ "from diffusers import AutoencoderKL, DDPMScheduler\n",
291
+ "from transformers import CLIPTextModel, CLIPTokenizer\n",
292
+ "from PIL import Image\n",
293
+ "import numpy as np\n",
294
+ "from tqdm import tqdm\n",
295
+ "import argparse\n",
296
+ "import sys\n",
297
+ "\n",
298
+ "\n",
299
+ "\n",
300
+ "def seed_everything(seed):\n",
301
+ " torch.manual_seed(seed)\n",
302
+ " torch.cuda.manual_seed_all(seed)\n",
303
+ " np.random.seed(seed)\n",
304
+ "\n",
305
+ "def generate_images_direct(unet_path=\"KahabMinGenT2Im-v1.pt\", device=\"cuda\", output_dir=\"output\", prompt=None,timesteps=50):\n",
306
+ " \"\"\"Generate 256x256 images with a custom UNet and user-specified text prompt\"\"\"\n",
307
+ " seed_everything(42)\n",
308
+ " print(f\"Using device: {device}\")\n",
309
+ "\n",
310
+ " # Load components\n",
311
+ " print(\"Loading VAE...\")\n",
312
+ " vae = AutoencoderKL.from_pretrained(\"CompVis/stable-diffusion-v1-4\", subfolder=\"vae\").to(device).eval().requires_grad_(False)\n",
313
+ "\n",
314
+ " print(\"Loading tokenizer and text encoder...\")\n",
315
+ " tokenizer = CLIPTokenizer.from_pretrained(\"openai/clip-vit-large-patch14\")\n",
316
+ " text_encoder = CLIPTextModel.from_pretrained(\"openai/clip-vit-large-patch14\").to(device).eval().requires_grad_(False)\n",
317
+ "\n",
318
+ " print(\"Loading trained UNet...\")\n",
319
+ " unet = UNetConditional(in_channels=4, base_channels=128, context_dim=768)\n",
320
+ " checkpoint = torch.load(unet_path, map_location=device, weights_only=True)\n",
321
+ " unet.load_state_dict(checkpoint['model_state_dict'])\n",
322
+ " unet = unet.to(device).eval()\n",
323
+ "\n",
324
+ " # Create scheduler\n",
325
+ " scheduler = DDPMScheduler(num_train_timesteps=1000)\n",
326
+ "\n",
327
+ " # Get prompt from user if not provided\n",
328
+ " if prompt is None:\n",
329
+ " # Check if running in Jupyter\n",
330
+ " if 'ipykernel' in sys.modules:\n",
331
+ " prompt = input(\"Enter your text prompt (e.g., 'A friendly dragon'): \").strip()\n",
332
+ " else:\n",
333
+ " prompt = \"\" # Will be handled by argparse default or user input\n",
334
+ " if not prompt:\n",
335
+ " prompt = \"A friendly dragon\" # Default prompt if empty\n",
336
+ "\n",
337
+ " test_prompts = [prompt]\n",
338
+ "\n",
339
+ " print(\"🎨 Generating 256x256 images...\")\n",
340
+ " for i, prompt in enumerate(test_prompts):\n",
341
+ " print(f\"Generating: {prompt}\")\n",
342
+ " try:\n",
343
+ " with torch.no_grad():\n",
344
+ " # Encode prompt\n",
345
+ " inputs = tokenizer(\n",
346
+ " prompt,\n",
347
+ " padding=\"max_length\",\n",
348
+ " truncation=True,\n",
349
+ " max_length=77,\n",
350
+ " return_tensors=\"pt\"\n",
351
+ " )\n",
352
+ " inputs = {k: v.to(device) for k, v in inputs.items()}\n",
353
+ " text_embeddings = text_encoder(**inputs).last_hidden_state\n",
354
+ " print(f\"Text embeddings shape: {text_embeddings.shape}, device: {text_embeddings.device}\")\n",
355
+ "\n",
356
+ " # Create random latents for 256x256 output (256/8 = 32 due to VAE scaling)\n",
357
+ " latents = torch.randn(1, 4, 32, 32, device=device, dtype=torch.float32)\n",
358
+ " print(f\"Initial latents shape: {latents.shape}, device: {latents.device}\")\n",
359
+ "\n",
360
+ " # Set timesteps\n",
361
+ " scheduler.set_timesteps(timesteps)\n",
362
+ "\n",
363
+ " # Denoising loop\n",
364
+ " for t in tqdm(scheduler.timesteps, desc=f\"Denoising {prompt}\"):\n",
365
+ " t_tensor = torch.tensor([t], device=device, dtype=torch.long)\n",
366
+ " noise_pred = unet(latents, t_tensor, context=text_embeddings)\n",
367
+ " latents = scheduler.step(noise_pred, t, latents).prev_sample\n",
368
+ "\n",
369
+ " print(f\"Final latents shape: {latents.shape}\")\n",
370
+ "\n",
371
+ " # Decode latents to image\n",
372
+ " latents = latents / 0.18215\n",
373
+ " images = vae.decode(latents).sample\n",
374
+ " images = (images / 2 + 0.5).clamp(0, 1) # Denormalize\n",
375
+ " images = images.cpu().permute(0, 2, 3, 1).numpy()\n",
376
+ " image = Image.fromarray((images[0] * 255).astype(np.uint8))\n",
377
+ "\n",
378
+ " # Save\n",
379
+ " filename = f\"{output_dir}/generated_256_{i+1}_{prompt.replace(' ', '_')}.png\"\n",
380
+ " image.save(filename)\n",
381
+ " print(f\"✅ Saved: {filename}\")\n",
382
+ "\n",
383
+ " except Exception as e:\n",
384
+ " print(f\"❌ Error generating '{prompt}': {e}\")\n",
385
+ " print(f\"Error type: {type(e).__name__}\")\n",
386
+ " continue\n",
387
+ "\n",
388
+ "def main():\n",
389
+ " # Check if running in Jupyter\n",
390
+ " if 'ipykernel' in sys.modules:\n",
391
+ " generate_images_direct(\n",
392
+ " unet_path=\"KahabMinGenT2Im-v1.pt\",\n",
393
+ " device=\"cuda\" if torch.cuda.is_available() else \"cpu\",\n",
394
+ " output_dir=\"output\",\n",
395
+ " prompt=None\n",
396
+ " )\n",
397
+ " else:\n",
398
+ " parser = argparse.ArgumentParser(description=\"Generate images with custom UNet and text prompt\")\n",
399
+ " parser.add_argument(\"--unet_path\", type=str, default=\"KahabMinGenT2Im-v1.pt\", help=\"Path to UNet checkpoint\")\n",
400
+ " parser.add_argument(\"--device\", type=str, default=\"cuda\" if torch.cuda.is_available() else \"cpu\", help=\"Device to use (cuda or cpu)\")\n",
401
+ " parser.add_argument(\"--output_dir\", type=str, default=\"output\", help=\"Output directory for generated images\")\n",
402
+ " parser.add_argument(\"--prompt\", type=str, default=None, help=\"Text prompt for image generation\")\n",
403
+ " args = parser.parse_args()\n",
404
+ "\n",
405
+ " generate_images_direct(\n",
406
+ " unet_path=args.unet_path,\n",
407
+ " device=args.device,\n",
408
+ " output_dir=args.output_dir,\n",
409
+ " prompt=args.prompt\n",
410
+ " )\n",
411
+ "\n",
412
+ "if __name__ == \"__main__\":\n",
413
+ " main()"
414
+ ]
415
+ },
416
+ {
417
+ "cell_type": "code",
418
+ "execution_count": null,
419
+ "id": "7a86f43b-1e8e-4ead-bcf5-c8ff9f065782",
420
+ "metadata": {},
421
+ "outputs": [],
422
+ "source": []
423
+ }
424
+ ],
425
+ "metadata": {
426
+ "kernelspec": {
427
+ "display_name": "Python 3 (ipykernel)",
428
+ "language": "python",
429
+ "name": "python3"
430
+ },
431
+ "language_info": {
432
+ "codemirror_mode": {
433
+ "name": "ipython",
434
+ "version": 3
435
+ },
436
+ "file_extension": ".py",
437
+ "mimetype": "text/x-python",
438
+ "name": "python",
439
+ "nbconvert_exporter": "python",
440
+ "pygments_lexer": "ipython3",
441
+ "version": "3.10.18"
442
+ }
443
+ },
444
+ "nbformat": 4,
445
+ "nbformat_minor": 5
446
+ }