Suyamprakasam committed on
Commit
7b93c28
·
verified ·
1 Parent(s): 959d97a

Upload 2 files

Browse files
Files changed (2)
  1. README.md +10 -0
  2. main2.py +1032 -0
README.md ADDED
@@ -0,0 +1,10 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - audio-to-image
5
+ - stable-diffusion
6
+ ---
7
+
8
+ # Audio2Image Model
9
+
10
+ Generates images from audio by mapping CLAP audio embeddings into the Stable Diffusion text-conditioning space.
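+
+ ## Usage
+
+ A minimal inference sketch, assuming `main2.py` from this repo is importable, a trained checkpoint exists at the path set in `Config.ckpt_path`, and `example.wav` stands in for your own audio file:
+
+ ```python
+ from main2 import Config, infer
+
+ cfg = Config()
+ infer(cfg, "example.wav", "output.png")  # "example.wav" is a placeholder path
+ ```
+
+ The same entry point is exposed on the command line: `python main2.py --mode infer --wav example.wav --out output.png`.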
main2.py ADDED
@@ -0,0 +1,1032 @@
1
+ """
2
+ Audio → Image Generator (Multi-Task Loss Version)
3
+ Key features:
4
+ - Dual-head MLP: one for CLAP text space, one for SD embedding space
5
+ - Multi-task training: CLAP alignment loss + SD alignment loss
6
+ - Both heads are trained simultaneously
7
+ - to_sd head is properly trained and used during inference
8
+ """
9
+
10
+ # ========================
11
+ # Imports
12
+ # ========================
13
+ import os, math, csv, random, sys
14
+ from typing import List, Tuple
15
+ from dataclasses import dataclass
16
+ import zipfile
17
+ from io import BytesIO
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+ import torch.nn.functional as F
22
+ import torchaudio
23
+ from torch.utils.data import Dataset, DataLoader
24
+ from tqdm import tqdm
25
+
26
+ from transformers import AutoProcessor, ClapModel, AutoTokenizer, CLIPProcessor, CLIPModel
27
+ from diffusers import StableDiffusionPipeline, DDPMScheduler, DDIMScheduler
28
+ from PIL import Image
29
+ from torchvision import transforms
30
+
31
+
32
+ # ========================
33
+ # Configuration
34
+ # ========================
35
+ @dataclass
36
+ class Config:
37
+ CLAP_ID: str = "laion/clap-htsat-fused"
38
+ SD_ID: str = "runwayml/stable-diffusion-v1-5"
39
+ CLIP_ID: str = "openai/clip-vit-base-patch32"
40
+
41
+ # Device configuration - automatically uses GPU if available
42
+ device: str = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
43
+
44
+ lr: float = 2e-4
45
+ weight_decay: float = 1e-4
46
+ temperature: float = 0.07
47
+
48
+ # Multi-task loss weights
49
+ clap_loss_weight: float = 0.5
50
+ sd_loss_weight: float = 1.0
51
+ diffusion_loss_weight: float = 1.0
52
+
53
+ batch_size: int = 2 # Reduced for Mac GPU memory
54
+ max_epochs: int = 20
55
+ base_prompt: str = "A photo of"
56
+ guidance: float = 7.5
57
+ steps: int = 30
58
+
59
+ # Dataset paths
60
+ train_csv: str = "/Users/rajvarun/Desktop/SIT/Trimester 4/AAI 3001 - Computer Vision & Deep Learning/Seeing Sound II/raj/main_dataV1.csv"
61
+ image_folder: str = "/Users/rajvarun/OneDrive - Singapore Institute Of Technology/ALEXI KIZHAKKEPURATHU GEORGE's files - VGGSound" # OneDrive folder with ZIP files
62
+ ckpt_path: str = "audio2image_mapper_dual_best.pt"
63
+
64
+ # ZIP file support (if data is in ZIP files instead of extracted)
65
+ use_zip_files: bool = True # Set to True to read from ZIP files directly
66
+ zip_files: dict = None # Will be populated automatically
67
+
68
+ # Fine-tuning control
69
+ finetune_sd: bool = False # Set to False to train without images
70
+ sd_lr: float = 1e-5
71
+ freeze_vae: bool = True
72
+ freeze_text_encoder: bool = True
73
+
74
+ # Evaluation settings
75
+ eval_every_n_epochs: int = 1 # Evaluate every N epochs
76
+ num_eval_samples: int = 4 # Number of samples to evaluate per batch
77
+ save_eval_images: bool = True # Save example generated images
78
+
79
+
80
+ # ========================
81
+ # Dataset
82
+ # ========================
83
+ class AudioCaptionDataset(Dataset):
84
+ """
85
+ Reads a CSV file with audio-image-caption triplets.
86
+ Handles structure where data is in: base_folder/image/ and base_folder/audio/
87
+
88
+ Can read from extracted folders OR directly from ZIP files (no extraction needed!)
89
+
90
+ Example:
91
+ - CSV row (header: base_folder,image_file,audio_file,caption): vggsound_00,g-f_I2yQ_1.png,g-f_I2yQ_000001.wav,people marching
92
+ - Audio path: vggsound_00/audio/g-f_I2yQ_000001.wav
93
+ - Image path: vggsound_00/image/g-f_I2yQ_1.png
94
+ """
95
+ def __init__(self, captions_path: str, image_folder: str = None, use_zip_files: bool = False):
96
+ self.items = []
97
+ base_dir = os.path.dirname(captions_path)
98
+ self.image_folder = image_folder or base_dir
99
+ self.use_zip_files = use_zip_files
100
+ self.zip_handles = {} # Cache opened ZIP files
101
+
102
+ # Image preprocessing for SD (512x512, normalized to [-1, 1])
103
+ self.img_transform = transforms.Compose([
104
+ transforms.Resize((512, 512)),
105
+ transforms.ToTensor(),
106
+ transforms.Normalize([0.5], [0.5])
107
+ ])
108
+
109
+ print(f"Loading dataset from: {captions_path}")
110
+ print(f"Base folder: {self.image_folder}")
111
+ print(f"Use ZIP files: {use_zip_files}")
112
+
113
+ # If using ZIP files, find and open them
114
+ if use_zip_files:
115
+ self._find_zip_files()
116
+
117
+ # Read CSV file
118
+ import csv
119
+ with open(captions_path, "r", encoding="utf-8") as f:
120
+ reader = csv.DictReader(f)
121
+
122
+ for row_num, row in enumerate(reader, 1):
123
+ # CSV format: base_folder,image_file,audio_file,caption
124
+ if 'base_folder' in row and 'image_file' in row and 'audio_file' in row and 'caption' in row:
125
+ base_folder = row['base_folder'] # e.g., "vggsound_00"
126
+ img_filename = row['image_file'] # e.g., "g-f_I2yQ_1.png"
127
+ audio_filename = row['audio_file'] # e.g., "g-f_I2yQ_000001.wav"
128
+ caption = row['caption']
129
+
130
+ if use_zip_files:
131
+ # Use ZIP file paths
132
+ audio_path = f"{base_folder}/audio/{audio_filename}"
133
+ img_path = f"{base_folder}/image/{img_filename}"
134
+
135
+ # Check if files exist in ZIP
136
+ audio_exists = self._file_in_zip(base_folder, audio_path)
137
+ img_exists = self._file_in_zip(base_folder, img_path)
138
+
139
+ # Debug first few rows
140
+ if row_num <= 3:
141
+ print(f"Row {row_num}: base_folder='{base_folder}', audio='{audio_path}', exists={audio_exists}")
142
+ else:
143
+ # Use regular file paths
144
+ audio_path = os.path.join(self.image_folder, base_folder, "audio", audio_filename)
145
+ img_path = os.path.join(self.image_folder, base_folder, "image", img_filename)
146
+
147
+ audio_exists = os.path.exists(audio_path)
148
+ img_exists = os.path.exists(img_path)
149
+
150
+ if audio_exists:
151
+ if img_exists:
152
+ self.items.append((base_folder, audio_path, img_path, caption))
153
+ else:
154
+ # Audio exists but image doesn't
155
+ self.items.append((base_folder, audio_path, None, caption))
156
+ if row_num <= 3:
157
+ print(f"Warning: Image not found: {img_path}")
158
+ else:
159
+ if row_num <= 3:
160
+ print(f"Warning: Audio not found: {audio_path}")
161
+ else:
162
+ if row_num <= 3:
163
+ print(f"Warning: Row {row_num} missing required columns")
164
+
165
+ if not self.items:
166
+ raise ValueError("Empty dataset: no valid audio files found")
167
+
168
+ # Count how many have images
169
+ with_images = sum(1 for _, _, img_path, _ in self.items if img_path is not None)
170
+ print(f"✓ Loaded {len(self.items)} audio files ({with_images} with matching images)")
171
+
172
+ def _find_zip_files(self):
173
+ """Find and open ZIP files in the image_folder"""
174
+ print("Searching for ZIP files...")
175
+ for item in os.listdir(self.image_folder):
176
+ if item.endswith('.zip'):
177
+ zip_name = item.replace('.zip', '')
178
+ zip_path = os.path.join(self.image_folder, item)
179
+ try:
180
+ self.zip_handles[zip_name] = zipfile.ZipFile(zip_path, 'r')
181
+ # Get number of files in ZIP for debugging
182
+ file_count = len(self.zip_handles[zip_name].namelist())
183
+ print(f" ✓ Opened {item} (key: '{zip_name}', {file_count} files)")
184
+ except Exception as e:
185
+ print(f" ✗ Failed to open {item}: {e}")
186
+
187
+ def _file_in_zip(self, base_folder, file_path):
188
+ """Check if a file exists in the corresponding ZIP"""
189
+ if base_folder not in self.zip_handles:
190
+ print(f" ! ZIP handle not found for base_folder='{base_folder}'. Available: {list(self.zip_handles.keys())}")
191
+ return False
192
+ try:
193
+ self.zip_handles[base_folder].getinfo(file_path)
194
+ return True
195
+ except KeyError:
196
+ return False
197
+
198
+ def _read_from_zip(self, base_folder, file_path):
199
+ """Read a file from ZIP archive"""
200
+ if base_folder in self.zip_handles:
201
+ return self.zip_handles[base_folder].read(file_path)
202
+ return None
203
+
204
+ def __len__(self):
205
+ return len(self.items)
206
+
207
+ def __getitem__(self, idx: int):
208
+ base_folder, audio_path, img_path, cap = self.items[idx]
209
+
210
+ # Load audio
211
+ if self.use_zip_files:
212
+ # Read audio from ZIP
213
+ audio_bytes = self._read_from_zip(base_folder, audio_path)
214
+ if audio_bytes is None:
215
+ raise FileNotFoundError(f"Audio not found in ZIP: {audio_path}")
216
+ wav, sr = torchaudio.load(BytesIO(audio_bytes))
217
+ else:
218
+ # Read from file system
219
+ wav, sr = torchaudio.load(audio_path)
220
+
221
+ if wav.size(0) > 1:
222
+ wav = wav.mean(dim=0, keepdim=True)
223
+ wav = wav.squeeze(0).float()
224
+ # Resample to 48kHz for CLAP
225
+ if sr != 48000:
226
+ resampler = torchaudio.transforms.Resample(sr, 48000)
227
+ wav = resampler(wav)
228
+
229
+ # Load image if available
230
+ if img_path is not None:
231
+ if self.use_zip_files:
232
+ # Read image from ZIP
233
+ img_bytes = self._read_from_zip(base_folder, img_path)
234
+ if img_bytes:
235
+ img = Image.open(BytesIO(img_bytes)).convert('RGB')
236
+ img_tensor = self.img_transform(img)
237
+ else:
238
+ img_tensor = torch.zeros((3, 512, 512))
239
+ else:
240
+ # Read from file system
241
+ img = Image.open(img_path).convert('RGB')
242
+ img_tensor = self.img_transform(img)
243
+ else:
244
+ # Create dummy image if not available
245
+ img_tensor = torch.zeros((3, 512, 512))
246
+
247
+ return wav, 48000, cap, img_tensor, (img_path is not None)
248
+
249
+ def __del__(self):
250
+ """Close ZIP files when done"""
251
+ for zip_handle in self.zip_handles.values():
252
+ try:
253
+ zip_handle.close()
254
+ except Exception:
255
+ pass
256
+
257
+ def collate_audio(batch):
258
+ wavs, srs, caps, imgs, has_imgs = [], [], [], [], []
259
+ for w, sr, c, img, has_img in batch:
260
+ wavs.append(w)
261
+ srs.append(sr)
262
+ caps.append(c)
263
+ imgs.append(img)
264
+ has_imgs.append(has_img)
265
+ return wavs, srs[0], caps, torch.stack(imgs), torch.tensor(has_imgs)
266
+
267
+
268
+ # ========================
269
+ # Model Components
270
+ # ========================
271
+ class AudioProjectionMLP(nn.Module):
272
+ """
273
+ Dual-head MLP projection:
274
+ - to_text: CLAP audio → CLAP text space (for CLAP alignment)
275
+ - to_sd: CLAP audio → SD embedding space (for image generation)
276
+ Both heads are trained with multi-task loss.
277
+ """
278
+ def __init__(self, in_dim, text_dim, sd_dim, hidden=1024):
279
+ super().__init__()
280
+
281
+ # Shared backbone
282
+ self.shared = nn.Sequential(
283
+ nn.Linear(in_dim, hidden),
284
+ nn.GELU(),
285
+ nn.Dropout(0.1),
286
+ nn.Linear(hidden, hidden),
287
+ nn.GELU(),
288
+ nn.Dropout(0.1)
289
+ )
290
+
291
+ # Head 1: CLAP text space (for training alignment)
292
+ self.to_text = nn.Sequential(
293
+ nn.Linear(hidden, hidden),
294
+ nn.GELU(),
295
+ nn.Dropout(0.1),
296
+ nn.Linear(hidden, text_dim)
297
+ )
298
+
299
+ # Head 2: SD embedding space (for generation)
300
+ self.to_sd = nn.Sequential(
301
+ nn.Linear(hidden, hidden),
302
+ nn.GELU(),
303
+ nn.Dropout(0.1),
304
+ nn.Linear(hidden, sd_dim)
305
+ )
306
+
307
+ def forward(self, z):
308
+ shared_features = self.shared(z)
309
+ return self.to_text(shared_features), self.to_sd(shared_features)
310
+
311
+
312
+ # ========================
313
+ # Main Model
314
+ # ========================
315
+ class Audio2ImageModel(nn.Module):
316
+ def __init__(self, cfg: Config, load_sd: bool = False):
317
+ super().__init__()
318
+ self.cfg = cfg
319
+ device = cfg.device
320
+
321
+ # -------- Frozen CLAP --------
322
+ print("Loading CLAP model...")
323
+ self.clap = ClapModel.from_pretrained(cfg.CLAP_ID).eval().to(device)
324
+ for p in self.clap.parameters():
325
+ p.requires_grad = False
326
+ self.proc = AutoProcessor.from_pretrained(cfg.CLAP_ID)
327
+
328
+ # -------- CLIP for Evaluation (Frozen) --------
329
+ print("Loading CLIP for evaluation...")
330
+ self.clip_model = CLIPModel.from_pretrained(cfg.CLIP_ID).eval().to(device)
331
+ self.clip_processor = CLIPProcessor.from_pretrained(cfg.CLIP_ID)
332
+ for p in self.clip_model.parameters():
333
+ p.requires_grad = False
334
+ print(" ✓ CLIP loaded (frozen for evaluation only)")
335
+
336
+ # -------- Stable Diffusion (conditionally trainable) --------
337
+ self.sd_pipe = None
338
+ self.sd_tok = None
339
+ self.sd_text_encoder = None
340
+ self.sd_unet = None
341
+ self.sd_vae = None
342
+ self.sd_hidden = 768
343
+
344
+ # Always load full SD for training or inference
345
+ if True:
346
+ print("Loading Stable Diffusion...")
347
+ # Use float32 for training, float16 for inference only
348
+ dtype = torch.float32 if cfg.finetune_sd else (torch.float16 if device == "cuda" else torch.float32)
349
+ self.sd_pipe = StableDiffusionPipeline.from_pretrained(cfg.SD_ID, torch_dtype=dtype)
350
+ self.sd_pipe.to(device)
351
+
352
+ self.sd_tok = self.sd_pipe.tokenizer
353
+ self.sd_text_encoder = self.sd_pipe.text_encoder
354
+ self.sd_unet = self.sd_pipe.unet
355
+ self.sd_vae = self.sd_pipe.vae
356
+ self.sd_hidden = self.sd_pipe.text_encoder.config.hidden_size
357
+
358
+ # Configure trainability based on config
359
+ if cfg.finetune_sd:
360
+ print("🔥 End-to-End Training Mode:")
361
+
362
+ # UNet: TRAINABLE (this learns to generate!)
363
+ for p in self.sd_unet.parameters():
364
+ p.requires_grad = True
365
+ self.sd_unet.train()
366
+ print(" ✓ UNet: TRAINABLE")
367
+
368
+ # VAE: Usually frozen for stability
369
+ if cfg.freeze_vae:
370
+ for p in self.sd_vae.parameters():
371
+ p.requires_grad = False
372
+ self.sd_vae.eval()
373
+ print(" ✓ VAE: FROZEN")
374
+ else:
375
+ for p in self.sd_vae.parameters():
376
+ p.requires_grad = True
377
+ self.sd_vae.train()
378
+ print(" ✓ VAE: TRAINABLE")
379
+
380
+ # Text Encoder: Usually frozen
381
+ if cfg.freeze_text_encoder:
382
+ for p in self.sd_text_encoder.parameters():
383
+ p.requires_grad = False
384
+ self.sd_text_encoder.eval()
385
+ print(" ✓ Text Encoder: FROZEN")
386
+ else:
387
+ for p in self.sd_text_encoder.parameters():
388
+ p.requires_grad = True
389
+ self.sd_text_encoder.train()
390
+ print(" ✓ Text Encoder: TRAINABLE")
391
+ else:
392
+ print("Inference Mode: All SD components frozen")
393
+ for comp in (self.sd_unet, self.sd_vae, self.sd_text_encoder):
394
+ for p in comp.parameters():
395
+ p.requires_grad = False
396
+ comp.eval()
397
+
398
+ # -------- Get CLAP dims --------
399
+ dummy_text = ["test"]
400
+ dummy_audio = [torch.zeros(48000).numpy()]
401
+
402
+ with torch.no_grad():
403
+ text_proc = self.proc(text=dummy_text, return_tensors="pt")
404
+ text_proc = {k: v.to(device) for k,v in text_proc.items()}
405
+ t = self.clap.get_text_features(**text_proc)
406
+ clap_text_dim = t.shape[-1]
407
+
408
+ audio_proc = self.proc(audio=dummy_audio, sampling_rate=48000, return_tensors="pt")
409
+ audio_proc = {k: v.to(device) for k,v in audio_proc.items()}
410
+ a = self.clap.get_audio_features(**audio_proc)
411
+ clap_audio_dim = a.shape[-1]
412
+
413
+ # -------- Trainable Dual-Head MLP --------
414
+ print(f"Creating MLP: CLAP audio ({clap_audio_dim}) → CLAP text ({clap_text_dim}) & SD ({self.sd_hidden})")
415
+ self.mapper = AudioProjectionMLP(clap_audio_dim, clap_text_dim, self.sd_hidden)
416
+
417
+ # --- Encoders ---
418
+ def encode_text_clap(self, caps):
419
+ """Encode text using CLAP text encoder"""
420
+ proc = self.proc(text=caps, return_tensors="pt", padding=True)
421
+ proc = {k: v.to(self.cfg.device) for k,v in proc.items()}
422
+
423
+ # Ensure CLAP is in eval mode
424
+ was_training = self.clap.training
425
+ self.clap.eval()
426
+
427
+ with torch.no_grad():
428
+ e = self.clap.get_text_features(**proc)
429
+
430
+ # Restore training state if needed
431
+ if was_training:
432
+ self.clap.train()
433
+
434
+ return F.normalize(e, dim=-1)
435
+
436
+ def encode_text_sd(self, caps):
437
+ """Encode text using SD text encoder (for target embeddings)"""
438
+ tokens = self.sd_tok(
439
+ caps,
440
+ padding="max_length",
441
+ max_length=self.sd_tok.model_max_length,
442
+ truncation=True,
443
+ return_tensors="pt"
444
+ ).to(self.cfg.device)
445
+
446
+ with torch.no_grad():
447
+ # Get the pooled output (last hidden state mean)
448
+ outputs = self.sd_text_encoder(tokens["input_ids"])
449
+ # Use pooler_output if available, else mean pool
450
+ if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
451
+ embeddings = outputs.pooler_output
452
+ else:
453
+ embeddings = outputs.last_hidden_state.mean(dim=1)
454
+
455
+ return embeddings
456
+
457
+ def encode_audio(self, wavs, sr):
458
+ """Returns raw CLAP audio embeddings - batched processing"""
459
+ # Convert all wavs to numpy for batch processing
460
+ audio_list = [w.cpu().numpy() for w in wavs]
461
+
462
+ # Process all audios in a single batch
463
+ proc = self.proc(audio=audio_list, sampling_rate=sr, return_tensors="pt")
464
+ proc = {k: v.to(self.cfg.device) for k, v in proc.items()}
465
+
466
+ # Ensure CLAP is in eval mode to avoid batch norm issues
467
+ was_training = self.clap.training
468
+ self.clap.eval()
469
+
470
+ with torch.no_grad():
471
+ embeddings = self.clap.get_audio_features(**proc)
472
+
473
+ # Restore training state if needed
474
+ if was_training:
475
+ self.clap.train()
476
+
477
+ return embeddings
478
+
479
+ # --- Loss ---
480
+ @staticmethod
481
+ def info_nce(a, b, temp):
482
+ """InfoNCE contrastive loss"""
483
+ a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
484
+ logits = a @ b.t() / temp
485
+ tgt = torch.arange(a.size(0), device=a.device)
486
+ return 0.5 * (F.cross_entropy(logits, tgt) + F.cross_entropy(logits.t(), tgt))
487
+
488
+ def compute_diffusion_loss(self, images, audio_emb):
489
+ """
490
+ Diffusion loss: Trains SD UNet to denoise images conditioned on audio.
491
+ This enables end-to-end learning of the generative model!
492
+
493
+ Args:
494
+ images: Ground truth images [B, 3, 512, 512] in range [-1, 1]
495
+ audio_emb: Audio embeddings from CLAP
496
+
497
+ Returns:
498
+ Denoising loss (MSE between predicted and actual noise)
499
+ """
500
+ # 1. Encode images to latent space (no grad through VAE)
501
+ with torch.no_grad():
502
+ latents = self.sd_vae.encode(images).latent_dist.sample()
503
+ latents = latents * 0.18215 # SD's scaling factor
504
+
505
+ # 2. Sample random timesteps for diffusion training
506
+ noise = torch.randn_like(latents)
507
+ bsz = latents.shape[0]
508
+ timesteps = torch.randint(
509
+ 0, 1000, (bsz,),
510
+ device=latents.device
511
+ ).long()
512
+
513
+ # 3. Add noise to latents according to timestep
514
+ if not hasattr(self, 'noise_scheduler'):
515
+ self.noise_scheduler = DDPMScheduler.from_pretrained(
516
+ self.cfg.SD_ID,
517
+ subfolder="scheduler"
518
+ )
519
+
520
+ noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
521
+
522
+ # 4. Get audio conditioning (gradients flow to mapper!)
523
+ _, audio_to_sd = self.mapper(audio_emb)
524
+
525
+ # Reshape for UNet: [batch, 1, hidden_dim]
526
+ encoder_hidden_states = audio_to_sd.unsqueeze(1)
527
+
528
+ # 5. UNet predicts the noise (this is where the SD UNet learns)
529
+ noise_pred = self.sd_unet(
530
+ noisy_latents, # Noisy input
531
+ timesteps, # Time conditioning
532
+ encoder_hidden_states # Audio conditioning
533
+ ).sample
534
+
535
+ # 6. Compute denoising loss
536
+ # Gradients flow back to both the UNet and the mapper
537
+ loss = F.mse_loss(noise_pred, noise, reduction='mean')
538
+
539
+ return loss
540
+
541
+ @torch.inference_mode()
542
+ def evaluate_generation(self, wavs, sr, captions, num_samples=None):
543
+ """
544
+ Evaluate quality of generated images using CLIP text-image similarity.
545
+
546
+ Args:
547
+ wavs: List of audio waveforms
548
+ sr: Sample rate
549
+ captions: List of text captions describing the audio
550
+ num_samples: Number of samples to evaluate (None = all)
551
+
552
+ Returns:
553
+ avg_clip_score: Average CLIP similarity score (0-100)
554
+ generated_images: List of PIL images
555
+ clip_scores: List of individual CLIP scores
556
+ """
557
+ was_training = self.training
558
+ self.eval()
559
+
560
+ if num_samples is not None:
561
+ wavs = wavs[:num_samples]
562
+ captions = captions[:num_samples]
563
+
564
+ generated_images = []
565
+ clip_scores = []
566
+
567
+ for wav, caption in zip(wavs, captions):
568
+ # Generate image from audio
569
+ img = self.generate(wav, sr)
570
+ generated_images.append(img)
571
+
572
+ # Compute CLIP score (text-image similarity)
573
+ inputs = self.clip_processor(
574
+ text=[caption],
575
+ images=[img],
576
+ return_tensors="pt",
577
+ padding=True
578
+ ).to(self.cfg.device)
579
+
580
+ outputs = self.clip_model(**inputs)
581
+
582
+ # Get similarity score (logits are already scaled by temperature)
583
+ # Higher score = better match between image and caption
584
+ logits_per_image = outputs.logits_per_image
585
+ clip_score = logits_per_image[0, 0].item()
586
+ clip_scores.append(clip_score)
587
+
588
+ avg_clip_score = sum(clip_scores) / len(clip_scores) if clip_scores else 0.0
589
+
590
+ if was_training:
591
+ self.train()
592
+
593
+ return avg_clip_score, generated_images, clip_scores
594
+
595
+ # --- Forward (Training with Multi-Task Loss) ---
596
+ def forward(self, wavs, sr, caps, images=None, has_images=None):
597
+ """
598
+ Forward pass with three parallel losses:
599
+ 1. CLAP alignment (semantic understanding)
600
+ 2. SD embedding alignment (embedding compatibility)
601
+ 3. Diffusion loss (pixel-level generation) - requires images
602
+
603
+ All losses train simultaneously in end-to-end fashion!
604
+ """
605
+ # Get target embeddings (frozen encoders)
606
+ clap_text_emb = self.encode_text_clap(caps)
607
+ sd_text_emb = self.encode_text_sd(caps)
608
+
609
+ # Get audio embeddings
610
+ audio_emb = self.encode_audio(wavs, sr)
611
+
612
+ # Project audio to both spaces (gradients flow here!)
613
+ audio_to_clap, audio_to_sd = self.mapper(audio_emb)
614
+
615
+ # Loss 1: CLAP alignment (InfoNCE)
616
+ loss_clap = self.info_nce(audio_to_clap, clap_text_emb, self.cfg.temperature)
617
+
618
+ # Loss 2: SD embedding alignment (MSE)
619
+ loss_sd = F.mse_loss(audio_to_sd, sd_text_emb)
620
+
621
+ # Loss 3: Diffusion loss (pixel-level generation)
622
+ loss_diffusion = torch.tensor(0.0, device=self.cfg.device)
623
+ if self.cfg.finetune_sd and images is not None:
624
+ # Only compute on samples that have images
625
+ if has_images is not None:
626
+ valid_mask = has_images.to(self.cfg.device)
627
+ if valid_mask.sum() > 0:
628
+ valid_imgs = images[valid_mask]
629
+ valid_audio_emb = audio_emb[valid_mask]
630
+ loss_diffusion = self.compute_diffusion_loss(valid_imgs, valid_audio_emb)
631
+ else:
632
+ loss_diffusion = self.compute_diffusion_loss(images, audio_emb)
633
+
634
+ # Combined multi-task loss: all objectives are trained together
635
+ total_loss = (
636
+ self.cfg.clap_loss_weight * loss_clap +
637
+ self.cfg.sd_loss_weight * loss_sd +
638
+ self.cfg.diffusion_loss_weight * loss_diffusion
639
+ )
640
+
641
+ # Compute similarities for monitoring
642
+ with torch.no_grad():
643
+ clap_sim = torch.diagonal(
644
+ F.normalize(audio_to_clap, dim=-1) @ F.normalize(clap_text_emb, dim=-1).t()
645
+ ).mean()
646
+
647
+ sd_sim = F.cosine_similarity(audio_to_sd, sd_text_emb, dim=-1).mean()
648
+
649
+ return total_loss, {
650
+ "loss_clap": loss_clap.item(),
651
+ "loss_sd": loss_sd.item(),
652
+ "loss_diffusion": loss_diffusion.item(),
653
+ "clap_sim": clap_sim.item(),
654
+ "sd_sim": sd_sim.item()
655
+ }
656
+
657
+ # --- Inference ---
658
+ @torch.inference_mode()
659
+ def generate(self, wav, sr):
660
+ if self.sd_pipe is None:
661
+ raise RuntimeError("Stable Diffusion not loaded. Init with load_sd=True.")
662
+
663
+ # Get audio embedding and project to SD space
664
+ audio_emb = self.encode_audio([wav], sr)
665
+ _, soft_token = self.mapper(audio_emb) # Use to_sd head
666
+
667
+ # Tokenize base prompt
668
+ tok = self.sd_tok(
669
+ self.cfg.base_prompt,
670
+ padding="max_length",
671
+ max_length=self.sd_tok.model_max_length,
672
+ truncation=True,
673
+ return_tensors="pt"
674
+ ).to(self.cfg.device)
675
+
676
+ # Get SD text embeddings
677
+ enc = self.sd_text_encoder(tok["input_ids"])[0]
678
+
679
+ # Find position to insert audio token (after last real token)
680
+ attention_mask = tok["attention_mask"][0]
681
+ last_token_pos = attention_mask.nonzero(as_tuple=False).max().item()
682
+
683
+ # Insert audio soft token AFTER the last token
684
+ if last_token_pos + 1 < enc.shape[1]:
685
+ enc[0, last_token_pos + 1:last_token_pos + 2, :] = soft_token
686
+ else:
687
+ # If no space, replace the last token
688
+ enc[0, last_token_pos:last_token_pos + 1, :] = soft_token
689
+
690
+ # Generate image
691
+ img = self.sd_pipe(
692
+ num_inference_steps=self.cfg.steps,
693
+ guidance_scale=self.cfg.guidance, # 7.5
694
+ prompt_embeds=enc
695
+ ).images[0]
696
+
697
+ return img
698
+
699
+
700
+ # ========================
701
+ # Training
702
+ # ========================
703
+ def train(cfg: Config):
704
+ # Load dataset with images
705
+ full_ds = AudioCaptionDataset(cfg.train_csv, cfg.image_folder, use_zip_files=cfg.use_zip_files)
706
+
707
+ # Create train/validation split (90/10)
708
+ train_size = int(0.9 * len(full_ds))
709
+ val_size = len(full_ds) - train_size
710
+ train_ds, val_ds = torch.utils.data.random_split(
711
+ full_ds,
712
+ [train_size, val_size],
713
+ generator=torch.Generator().manual_seed(42) # For reproducibility
714
+ )
715
+
716
+ print(f"\nDataset split:")
717
+ print(f" Training: {len(train_ds)} samples")
718
+ print(f" Validation: {len(val_ds)} samples\n")
719
+
720
+ # Create dataloaders
721
+ train_loader = DataLoader(
722
+ train_ds,
723
+ batch_size=cfg.batch_size,
724
+ shuffle=True,
725
+ collate_fn=collate_audio,
726
+ num_workers=0,
727
+ drop_last=True
728
+ )
729
+
730
+ val_loader = DataLoader(
731
+ val_ds,
732
+ batch_size=cfg.batch_size,
733
+ shuffle=False,
734
+ collate_fn=collate_audio,
735
+ num_workers=0
736
+ )
737
+
738
+ # Initialize model
739
+ model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)
740
+
741
+ # Separate optimizers with different learning rates
742
+ if cfg.finetune_sd:
743
+ print("\n🔥 Setting up END-TO-END training:")
744
+
745
+ # Optimizer 1: Mapper (higher LR)
746
+ opt_mapper = torch.optim.AdamW(
747
+ model.mapper.parameters(),
748
+ lr=cfg.lr,
749
+ weight_decay=cfg.weight_decay
750
+ )
751
+ print(f" Mapper optimizer: LR={cfg.lr}")
752
+
753
+ # Optimizer 2: SD UNet (lower LR for stability)
754
+ opt_sd = torch.optim.AdamW(
755
+ model.sd_unet.parameters(),
756
+ lr=cfg.sd_lr,
757
+ weight_decay=cfg.weight_decay
758
+ )
759
+ print(f" SD UNet optimizer: LR={cfg.sd_lr}")
760
+
761
+ opts = [opt_mapper, opt_sd]
762
+ else:
763
+ # Only train mapper
764
+ opt_mapper = torch.optim.AdamW(
765
+ model.mapper.parameters(),
766
+ lr=cfg.lr,
767
+ weight_decay=cfg.weight_decay
768
+ )
769
+ opts = [opt_mapper]
770
+
771
+ print(f"\n{'='*60}")
772
+ print(f"Starting {'End-to-End' if cfg.finetune_sd else 'Mapper-Only'} Training")
773
+ print(f"{'='*60}")
774
+ print(f"Dataset: {len(full_ds)} samples ({len(train_ds)} train, {len(val_ds)} val)")
775
+ print(f"Batch size: {cfg.batch_size}")
776
+ print(f"Epochs: {cfg.max_epochs}")
777
+ print(f"Evaluation: Every {cfg.eval_every_n_epochs} epoch(s)")
778
+ print(f"Loss weights:")
779
+ print(f" CLAP: {cfg.clap_loss_weight}")
780
+ print(f" SD Embedding: {cfg.sd_loss_weight}")
781
+ if cfg.finetune_sd:
782
+ print(f" Diffusion: {cfg.diffusion_loss_weight}")
783
+ print(f"{'='*60}\n")
784
+
785
+ # Track best model based on CLIP score
786
+ best_clip_score = -float('inf')
787
+
788
+ for ep in range(1, cfg.max_epochs + 1):
789
+ # ============================================
790
+ # TRAINING PHASE
791
+ # ============================================
792
+ model.train()
793
+ pbar = tqdm(train_loader, desc=f"Epoch {ep}/{cfg.max_epochs} [TRAIN]")
794
+
795
+ epoch_stats = {
796
+ "total": 0, "clap": 0, "sd": 0, "diff": 0,
797
+ "clap_sim": 0, "sd_sim": 0
798
+ }
799
+
800
+ for wavs, sr, caps, imgs, has_imgs in pbar:
801
+ wavs = [w.to(cfg.device) for w in wavs]
802
+ imgs = imgs.to(cfg.device)
803
+
804
+ # Forward pass - all losses computed!
805
+ loss, stats = model(wavs, sr, caps, imgs if cfg.finetune_sd else None, has_imgs)
806
+
807
+ # Zero gradients for all optimizers
808
+ for opt in opts:
809
+ opt.zero_grad()
810
+
811
+ # Backward pass - gradients flow to mapper AND UNet!
812
+ loss.backward()
813
+
814
+ # Clip gradients for stability
815
+ if cfg.finetune_sd:
816
+ nn.utils.clip_grad_norm_(model.mapper.parameters(), 1.0)
817
+ nn.utils.clip_grad_norm_(model.sd_unet.parameters(), 1.0)
818
+ else:
819
+ nn.utils.clip_grad_norm_(model.parameters(), 1.0)
820
+
821
+ # Step all optimizers (mapper and, when enabled, the SD UNet)
822
+ for opt in opts:
823
+ opt.step()
824
+
825
+ # Accumulate stats
826
+ epoch_stats["total"] += loss.item()
827
+ epoch_stats["clap"] += stats['loss_clap']
828
+ epoch_stats["sd"] += stats['loss_sd']
829
+ epoch_stats["diff"] += stats['loss_diffusion']
830
+ epoch_stats["clap_sim"] += stats['clap_sim']
831
+ epoch_stats["sd_sim"] += stats['sd_sim']
832
+
833
+ pbar.set_postfix({
834
+ "total loss": f"{loss.item():.3f}",
835
+ "diff": f"{stats['loss_diffusion']:.3f}",
836
+ "c_sim": f"{stats['clap_sim']:.2f}",
837
+ "s_sim": f"{stats['sd_sim']:.2f}"
838
+ })
839
+
840
+ # Compute training epoch averages
841
+ n_train = len(train_loader)
842
+ for k in epoch_stats:
843
+ epoch_stats[k] /= n_train
844
+
845
+ # ============================================
846
+ # VALIDATION & EVALUATION PHASE
847
+ # ============================================
848
+ if ep % cfg.eval_every_n_epochs == 0:
849
+ print(f"\n{'='*60}")
850
+ print(f"🔍 Evaluating Epoch {ep}...")
851
+ print(f"{'='*60}")
852
+
853
+ model.eval()
854
+ val_clip_scores = []
855
+ all_gen_images = []
856
+ all_captions = []
857
+
858
+ # Evaluate on validation set (limit to save time)
859
+ eval_batches = min(3, len(val_loader)) # Max 3 batches
860
+
861
+ for batch_idx, (wavs, sr, caps, imgs, has_imgs) in enumerate(val_loader):
862
+ if batch_idx >= eval_batches:
863
+ break
864
+
865
+ wavs = [w.to(cfg.device) for w in wavs]
866
+
867
+ # Generate images and compute CLIP scores
868
+ avg_score, gen_imgs, scores = model.evaluate_generation(
869
+ wavs, sr, caps,
870
+ num_samples=cfg.num_eval_samples
871
+ )
872
+
873
+ val_clip_scores.extend(scores)
874
+ all_gen_images.extend(gen_imgs)
875
+ all_captions.extend(caps[:cfg.num_eval_samples])
876
+
877
+ print(f" Batch {batch_idx + 1}/{eval_batches}: Avg CLIP = {avg_score:.3f}")
878
+
879
+ # Compute overall validation CLIP score
880
+ avg_val_clip = sum(val_clip_scores) / len(val_clip_scores) if val_clip_scores else 0.0
881
+
882
+ # Save example images from evaluation
883
+ if cfg.save_eval_images and all_gen_images:
884
+ os.makedirs("eval_samples", exist_ok=True)
885
+ for i, (img, cap, score) in enumerate(zip(all_gen_images[:4], all_captions[:4], val_clip_scores[:4])):
886
+ save_path = f"eval_samples/ep{ep}_sample{i}_score{score:.2f}.png"
887
+ img.save(save_path)
888
+ print(f" Sample {i}: '{cap[:50]}...' | CLIP: {score:.3f}")
889
+ print(f" Saved to: {save_path}")
890
+
891
+ # Clear MPS cache after evaluation
892
+ if cfg.device == "mps":
893
+ torch.mps.empty_cache()
894
+
895
+ print(f"\n{'='*60}")
896
+ print(f"📊 Epoch {ep} Summary:")
897
+ print(f"{'='*60}")
898
+ print(f"Training Metrics:")
899
+ print(f" Total Loss: {epoch_stats['total']:.4f}")
900
+ print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
901
+ print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
902
+ if cfg.finetune_sd:
903
+ print(f" Diffusion Loss: {epoch_stats['diff']:.4f}")
904
+ print(f"\nValidation Metrics:")
905
+ print(f" 🎯 CLIP Score: {avg_val_clip:.3f} (higher = better image-text match)")
906
+ print(f"{'='*60}\n")
907
+
908
+ else:
909
+ # Just print training stats if not evaluating
910
+ avg_val_clip = None
911
+ print(f"\n{'='*60}")
912
+ print(f"Epoch {ep} Summary:")
913
+ print(f" Total Loss: {epoch_stats['total']:.4f}")
914
+ print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
915
+ print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
916
+ if cfg.finetune_sd:
917
+ print(f" Diffusion Loss: {epoch_stats['diff']:.4f}")
918
+ print(f"{'='*60}\n")
919
+
920
+ # ============================================
921
+ # CHECKPOINT SAVING
922
+ # ============================================
923
+ checkpoint = {
924
+ "mapper": model.mapper.state_dict(),
925
+ "epoch": ep,
926
+ "val_clip_score": avg_val_clip if avg_val_clip is not None else -1,
927
+ **{k: v for k, v in epoch_stats.items()},
928
+ "config": {
929
+ "clap_loss_weight": cfg.clap_loss_weight,
930
+ "sd_loss_weight": cfg.sd_loss_weight,
931
+ "diffusion_loss_weight": cfg.diffusion_loss_weight,
932
+ "finetune_sd": cfg.finetune_sd
933
+ }
934
+ }
935
+
936
+ if cfg.finetune_sd:
937
+ checkpoint["unet"] = model.sd_unet.state_dict()
938
+
939
+ # Always save latest checkpoint
940
+ torch.save(checkpoint, cfg.ckpt_path)
941
+ print(f"💾 Checkpoint saved: {cfg.ckpt_path}")
942
+
943
+ # Save best model based on CLIP score
944
+ if avg_val_clip is not None and avg_val_clip > best_clip_score:
945
+ best_clip_score = avg_val_clip
946
+ best_path = cfg.ckpt_path.replace('.pt', '_best.pt')
947
+ torch.save(checkpoint, best_path)
948
+ print(f"✅ New best model! CLIP: {avg_val_clip:.3f} -> Saved to {best_path}")
949
+ elif avg_val_clip is not None:
950
+ print(f" Current best CLIP: {best_clip_score:.3f}")
951
+
952
+ print()
953
+
954
+ print("🎉 Training completed!")
955
+ if best_clip_score > -float('inf'):
956
+ print(f" Best CLIP score achieved: {best_clip_score:.3f}")
957
+
958
+
959
+ # ========================
960
+ # Inference
961
+ # ========================
962
+ def infer(cfg: Config, wav_path: str, out_path: str):
963
+ # Load audio
964
+ print(f"Loading audio from {wav_path}...")
965
+ wav, sr = torchaudio.load(wav_path)
966
+ if wav.size(0) > 1:
967
+ wav = wav.mean(0, keepdim=True)
968
+ wav = wav.squeeze(0).float()
969
+
970
+ # Resample to 48kHz for CLAP
971
+ if sr != 48000:
972
+ print(f"Resampling from {sr}Hz to 48000Hz...")
973
+ resampler = torchaudio.transforms.Resample(sr, 48000)
974
+ wav = resampler(wav)
975
+ sr = 48000
976
+
977
+ wav = wav.to(cfg.device)
978
+
979
+ # Load model with SD
980
+ model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)
981
+
982
+ # Load trained weights
983
+ print(f"Loading checkpoint from {cfg.ckpt_path}...")
984
+ ckpt = torch.load(cfg.ckpt_path, map_location=cfg.device)
985
+ model.mapper.load_state_dict(ckpt["mapper"])
986
+
987
+ # Load UNet weights if available (from fine-tuning)
988
+ if "unet" in ckpt:
989
+ print("Loading fine-tuned UNet weights...")
990
+ model.sd_unet.load_state_dict(ckpt["unet"])
991
+
992
+ print(f"Checkpoint info:")
993
+ print(f" Epoch: {ckpt.get('epoch', 'unknown')}")
994
+ print(f" CLAP Sim: {ckpt.get('clap_sim', 'N/A'):.3f}" if isinstance(ckpt.get('clap_sim'), (int, float)) else f" CLAP Sim: N/A")
995
+ print(f" SD Sim: {ckpt.get('sd_sim', 'N/A'):.3f}" if isinstance(ckpt.get('sd_sim'), (int, float)) else f" SD Sim: N/A")
996
+ if "unet" in ckpt:
997
+ print(" Fine-tuned UNet: ✓")
998
+
999
+ # Generate image
1000
+ print("\nGenerating image...")
1001
+ img = model.generate(wav, sr)
1002
+ img.save(out_path)
1003
+ print(f"✓ Generated image saved to {out_path}")
1004
+
1005
+
1006
+ # ========================
1007
+ # Main
1008
+ # ========================
1009
+ if __name__ == "__main__":
1010
+ import argparse
1011
+ parser = argparse.ArgumentParser()
1012
+ parser.add_argument("--mode", choices=["train", "infer"], default="train")
1013
+ parser.add_argument("--wav", help="Audio file path for inference mode")
1014
+ parser.add_argument("--out", default="output.png", help="Output image path")
1015
+ args = parser.parse_args()
1016
+
1017
+ cfg = Config()
1018
+ print(f"Device: {cfg.device}")
1019
+
1020
+ if args.mode == "train":
1021
+ print(f"Dataset: {cfg.train_csv}")
1022
+ if not os.path.exists(cfg.train_csv):
1023
+ print(f"ERROR: Dataset not found at {cfg.train_csv}")
1024
+ print("Please ensure the training CSV file exists")
1025
+ sys.exit(1)
1026
+ train(cfg)
1027
+ else:
1028
+ if not args.wav:
1029
+ raise ValueError("Need --wav for inference mode")
1030
+ if not os.path.exists(args.wav):
1031
+ raise ValueError(f"Audio file not found: {args.wav}")
1032
+ infer(cfg, args.wav, args.out)