"""
Audio → Image Generator (Multi-Task Loss Version)
Key features:
- Dual-head MLP: one for CLAP text space, one for SD embedding space
- Multi-task training: CLAP alignment loss + SD alignment loss
- Both heads are trained simultaneously
- to_sd head is properly trained and used during inference
"""
# ========================
# Imports
# ========================
import os, math, csv, random, sys
from typing import List, Tuple
from dataclasses import dataclass
import zipfile
from io import BytesIO
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import AutoProcessor, ClapModel, AutoTokenizer, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, DDPMScheduler, DDIMScheduler
from PIL import Image
from torchvision import transforms
# ========================
# Configuration
# ========================
@dataclass
class Config:
CLAP_ID: str = "laion/clap-htsat-fused"
SD_ID: str = "runwayml/stable-diffusion-v1-5"
CLIP_ID: str = "openai/clip-vit-base-patch32"
# Device selection: prefer Apple MPS, then CUDA, then CPU
device: str = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")
lr: float = 2e-4
weight_decay: float = 1e-4
temperature: float = 0.07
# Multi-task loss weights
clap_loss_weight: float = 0.5
sd_loss_weight: float = 1.0
diffusion_loss_weight: float = 1.0
batch_size: int = 2 # Reduced for Mac GPU memory
max_epochs: int = 20
base_prompt: str = "A photo of"
guidance: float = 7.5
steps: int = 30
# Dataset paths
train_csv: str = "/Users/rajvarun/Desktop/SIT/Trimester 4/AAI 3001 - Computer Vision & Deep Learning/Seeing Sound II/raj/main_dataV1.csv"
image_folder: str = "/Users/rajvarun/OneDrive - Singapore Institute Of Technology/ALEXI KIZHAKKEPURATHU GEORGE's files - VGGSound" # OneDrive folder with ZIP files
ckpt_path: str = "audio2image_mapper_dual_best.pt"
# ZIP file support (if data is in ZIP files instead of extracted)
use_zip_files: bool = True # Set to True to read from ZIP files directly
zip_files: dict = None # Currently unused; AudioCaptionDataset opens ZIP archives itself
# Fine-tuning control
finetune_sd: bool = False # If True, also fine-tune the SD UNet with the diffusion loss (requires paired images)
sd_lr: float = 1e-5
freeze_vae: bool = True
freeze_text_encoder: bool = True
# Evaluation settings
eval_every_n_epochs: int = 1 # Evaluate every N epochs
num_eval_samples: int = 4 # Number of samples to evaluate per batch
save_eval_images: bool = True # Save example generated images
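# Minimal sketch of overriding Config for a local run (these paths are placeholders,
# not the author's; any field above can be overridden the same way):
#   cfg = Config(train_csv="data/main_dataV1.csv",
#                image_folder="data/vggsound_zips",
#                use_zip_files=True)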
# ========================
# Dataset
# ========================
class AudioCaptionDataset(Dataset):
"""
Reads a CSV file with audio-image-caption triplets.
Handles structure where data is in: base_folder/image/ and base_folder/audio/
Can read from extracted folders OR directly from ZIP files (no extraction needed!)
Example:
- CSV: vggsound_00,g-f_I2yQ_1.png,g-f_I2yQ_000001.wav,people marching
- Audio path: vggsound_00/audio/g-f_I2yQ_000001.wav
- Image path: vggsound_00/image/g-f_I2yQ_1.png
"""
def __init__(self, captions_path: str, image_folder: str = None, use_zip_files: bool = False):
self.items = []
base_dir = os.path.dirname(captions_path)
self.image_folder = image_folder or base_dir
self.use_zip_files = use_zip_files
self.zip_handles = {} # Cache opened ZIP files
# Image preprocessing for SD (512x512, normalized to [-1, 1])
self.img_transform = transforms.Compose([
transforms.Resize((512, 512)),
transforms.ToTensor(),
transforms.Normalize([0.5], [0.5])
])
print(f"Loading dataset from: {captions_path}")
print(f"Base folder: {self.image_folder}")
print(f"Use ZIP files: {use_zip_files}")
# If using ZIP files, find and open them
if use_zip_files:
self._find_zip_files()
# Read the CSV file (csv is already imported at module level)
with open(captions_path, "r", encoding="utf-8") as f:
reader = csv.DictReader(f)
for row_num, row in enumerate(reader, 1):
# CSV format: base_folder,image_file,audio_file,caption
if 'base_folder' in row and 'image_file' in row and 'audio_file' in row and 'caption' in row:
base_folder = row['base_folder'] # e.g., "vggsound_00"
img_filename = row['image_file'] # e.g., "g-f_I2yQ_1.png"
audio_filename = row['audio_file'] # e.g., "g-f_I2yQ_000001.wav"
caption = row['caption']
if use_zip_files:
# Use ZIP file paths
audio_path = f"{base_folder}/audio/{audio_filename}"
img_path = f"{base_folder}/image/{img_filename}"
# Check if files exist in ZIP
audio_exists = self._file_in_zip(base_folder, audio_path)
img_exists = self._file_in_zip(base_folder, img_path)
# Debug first few rows
if row_num <= 3:
print(f"Row {row_num}: base_folder='{base_folder}', audio='{audio_path}', exists={audio_exists}")
else:
# Use regular file paths
audio_path = os.path.join(self.image_folder, base_folder, "audio", audio_filename)
img_path = os.path.join(self.image_folder, base_folder, "image", img_filename)
audio_exists = os.path.exists(audio_path)
img_exists = os.path.exists(img_path)
if audio_exists:
if img_exists:
self.items.append((base_folder, audio_path, img_path, caption))
else:
# Audio exists but image doesn't
self.items.append((base_folder, audio_path, None, caption))
if row_num <= 3:
print(f"Warning: Image not found: {img_path}")
else:
if row_num <= 3:
print(f"Warning: Audio not found: {audio_path}")
else:
if row_num <= 3:
print(f"Warning: Row {row_num} missing required columns")
if not self.items:
raise ValueError("Empty dataset: no valid audio files found")
# Count how many have images
with_images = sum(1 for _, _, img_path, _ in self.items if img_path is not None)
print(f"✓ Loaded {len(self.items)} audio files ({with_images} with matching images)")
def _find_zip_files(self):
"""Find and open ZIP files in the image_folder"""
print("Searching for ZIP files...")
for item in os.listdir(self.image_folder):
if item.endswith('.zip'):
zip_name = item.replace('.zip', '')
zip_path = os.path.join(self.image_folder, item)
try:
self.zip_handles[zip_name] = zipfile.ZipFile(zip_path, 'r')
# Get number of files in ZIP for debugging
file_count = len(self.zip_handles[zip_name].namelist())
print(f" ✓ Opened {item} (key: '{zip_name}', {file_count} files)")
except Exception as e:
print(f" ✗ Failed to open {item}: {e}")
def _file_in_zip(self, base_folder, file_path):
"""Check if a file exists in the corresponding ZIP"""
if base_folder not in self.zip_handles:
print(f" ! ZIP handle not found for base_folder='{base_folder}'. Available: {list(self.zip_handles.keys())}")
return False
try:
self.zip_handles[base_folder].getinfo(file_path)
return True
except KeyError:
return False
def _read_from_zip(self, base_folder, file_path):
"""Read a file from ZIP archive"""
if base_folder in self.zip_handles:
return self.zip_handles[base_folder].read(file_path)
return None
def __len__(self):
return len(self.items)
def __getitem__(self, idx: int):
base_folder, audio_path, img_path, cap = self.items[idx]
# Load audio
if self.use_zip_files:
# Read audio from ZIP
audio_bytes = self._read_from_zip(base_folder, audio_path)
if audio_bytes is None:
raise FileNotFoundError(f"Audio not found in ZIP: {audio_path}")
wav, sr = torchaudio.load(BytesIO(audio_bytes))
else:
# Read from file system
wav, sr = torchaudio.load(audio_path)
if wav.size(0) > 1:
wav = wav.mean(dim=0, keepdim=True)
wav = wav.squeeze(0).float()
# Resample to 48kHz for CLAP
if sr != 48000:
resampler = torchaudio.transforms.Resample(sr, 48000)
wav = resampler(wav)
# Load image if available
if img_path is not None:
if self.use_zip_files:
# Read image from ZIP
img_bytes = self._read_from_zip(base_folder, img_path)
if img_bytes:
img = Image.open(BytesIO(img_bytes)).convert('RGB')
img_tensor = self.img_transform(img)
else:
img_tensor = torch.zeros((3, 512, 512))
else:
# Read from file system
img = Image.open(img_path).convert('RGB')
img_tensor = self.img_transform(img)
else:
# Create dummy image if not available
img_tensor = torch.zeros((3, 512, 512))
return wav, 48000, cap, img_tensor, (img_path is not None)
def __del__(self):
"""Close ZIP files when done"""
for zip_handle in self.zip_handles.values():
try:
zip_handle.close()
except Exception:
pass
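# collate_audio keeps waveforms as a Python list (they have variable lengths), assumes
# one shared sample rate per batch (all clips are resampled to 48 kHz in __getitem__),
# stacks images into a [B, 3, 512, 512] tensor, and returns a bool mask marking which
# samples have a real (non-dummy) image.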
def collate_audio(batch):
wavs, srs, caps, imgs, has_imgs = [], [], [], [], []
for w, sr, c, img, has_img in batch:
wavs.append(w)
srs.append(sr)
caps.append(c)
imgs.append(img)
has_imgs.append(has_img)
return wavs, srs[0], caps, torch.stack(imgs), torch.tensor(has_imgs)
# ========================
# Model Components
# ========================
class AudioProjectionMLP(nn.Module):
"""
Dual-head MLP projection:
- to_text: CLAP audio → CLAP text space (for CLAP alignment)
- to_sd: CLAP audio → SD embedding space (for image generation)
Both heads are trained with multi-task loss.
"""
def __init__(self, in_dim, text_dim, sd_dim, hidden=1024):
super().__init__()
# Shared backbone
self.shared = nn.Sequential(
nn.Linear(in_dim, hidden),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden, hidden),
nn.GELU(),
nn.Dropout(0.1)
)
# Head 1: CLAP text space (for training alignment)
self.to_text = nn.Sequential(
nn.Linear(hidden, hidden),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden, text_dim)
)
# Head 2: SD embedding space (for generation)
self.to_sd = nn.Sequential(
nn.Linear(hidden, hidden),
nn.GELU(),
nn.Dropout(0.1),
nn.Linear(hidden, sd_dim)
)
def forward(self, z):
shared_features = self.shared(z)
return self.to_text(shared_features), self.to_sd(shared_features)
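# Shape sketch for the dual-head mapper (the dims below are illustrative assumptions;
# the real CLAP dims are probed at runtime in Audio2ImageModel, and sd_dim comes from
# the SD text encoder's hidden size):
#   mapper = AudioProjectionMLP(in_dim=512, text_dim=512, sd_dim=768)
#   z = torch.randn(4, 512)            # batch of 4 CLAP audio embeddings
#   to_text, to_sd = mapper(z)         # to_text: [4, 512], to_sd: [4, 768]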
# ========================
# Main Model
# ========================
class Audio2ImageModel(nn.Module):
def __init__(self, cfg: Config, load_sd: bool = False):
super().__init__()
self.cfg = cfg
device = cfg.device
# -------- Frozen CLAP --------
print("Loading CLAP model...")
self.clap = ClapModel.from_pretrained(cfg.CLAP_ID).eval().to(device)
for p in self.clap.parameters():
p.requires_grad = False
self.proc = AutoProcessor.from_pretrained(cfg.CLAP_ID)
# -------- CLIP for Evaluation (Frozen) --------
print("Loading CLIP for evaluation...")
self.clip_model = CLIPModel.from_pretrained(cfg.CLIP_ID).eval().to(device)
self.clip_processor = CLIPProcessor.from_pretrained(cfg.CLIP_ID)
for p in self.clip_model.parameters():
p.requires_grad = False
print(" ✓ CLIP loaded (frozen for evaluation only)")
# -------- Stable Diffusion (conditionally trainable) --------
self.sd_pipe = None
self.sd_tok = None
self.sd_text_encoder = None
self.sd_unet = None
self.sd_vae = None
self.sd_hidden = 768
# Always load full SD for training or inference
if True:
print("Loading Stable Diffusion...")
# Use float32 for training, float16 for inference only
dtype = torch.float32 if cfg.finetune_sd else (torch.float16 if device == "cuda" else torch.float32)
self.sd_pipe = StableDiffusionPipeline.from_pretrained(cfg.SD_ID, torch_dtype=dtype)
self.sd_pipe.to(device)
self.sd_tok = self.sd_pipe.tokenizer
self.sd_text_encoder = self.sd_pipe.text_encoder
self.sd_unet = self.sd_pipe.unet
self.sd_vae = self.sd_pipe.vae
self.sd_hidden = self.sd_pipe.text_encoder.config.hidden_size
# Configure trainability based on config
if cfg.finetune_sd:
print("🔥 End-to-End Training Mode:")
# UNet: TRAINABLE (this learns to generate!)
for p in self.sd_unet.parameters():
p.requires_grad = True
self.sd_unet.train()
print(" ✓ UNet: TRAINABLE")
# VAE: Usually frozen for stability
if cfg.freeze_vae:
for p in self.sd_vae.parameters():
p.requires_grad = False
self.sd_vae.eval()
print(" ✓ VAE: FROZEN")
else:
for p in self.sd_vae.parameters():
p.requires_grad = True
self.sd_vae.train()
print(" ✓ VAE: TRAINABLE")
# Text Encoder: Usually frozen
if cfg.freeze_text_encoder:
for p in self.sd_text_encoder.parameters():
p.requires_grad = False
self.sd_text_encoder.eval()
print(" ✓ Text Encoder: FROZEN")
else:
for p in self.sd_text_encoder.parameters():
p.requires_grad = True
self.sd_text_encoder.train()
print(" ✓ Text Encoder: TRAINABLE")
else:
print("Inference Mode: All SD components frozen")
for comp in (self.sd_unet, self.sd_vae, self.sd_text_encoder):
for p in comp.parameters():
p.requires_grad = False
comp.eval()
# -------- Get CLAP dims --------
dummy_text = ["test"]
dummy_audio = [torch.zeros(48000).numpy()]
with torch.no_grad():
text_proc = self.proc(text=dummy_text, return_tensors="pt")
text_proc = {k: v.to(device) for k,v in text_proc.items()}
t = self.clap.get_text_features(**text_proc)
clap_text_dim = t.shape[-1]
audio_proc = self.proc(audio=dummy_audio, sampling_rate=48000, return_tensors="pt")
audio_proc = {k: v.to(device) for k,v in audio_proc.items()}
a = self.clap.get_audio_features(**audio_proc)
clap_audio_dim = a.shape[-1]
# -------- Trainable Dual-Head MLP --------
print(f"Creating MLP: CLAP audio ({clap_audio_dim}) → CLAP text ({clap_text_dim}) & SD ({self.sd_hidden})")
self.mapper = AudioProjectionMLP(clap_audio_dim, clap_text_dim, self.sd_hidden)
# --- Encoders ---
def encode_text_clap(self, caps):
"""Encode text using CLAP text encoder"""
proc = self.proc(text=caps, return_tensors="pt", padding=True)
proc = {k: v.to(self.cfg.device) for k,v in proc.items()}
# Ensure CLAP is in eval mode
was_training = self.clap.training
self.clap.eval()
with torch.no_grad():
e = self.clap.get_text_features(**proc)
# Restore training state if needed
if was_training:
self.clap.train()
return F.normalize(e, dim=-1)
def encode_text_sd(self, caps):
"""Encode text using SD text encoder (for target embeddings)"""
tokens = self.sd_tok(
caps,
padding="max_length",
max_length=self.sd_tok.model_max_length,
truncation=True,
return_tensors="pt"
).to(self.cfg.device)
with torch.no_grad():
# Pool the text encoder output to a single embedding per caption
outputs = self.sd_text_encoder(tokens["input_ids"])
# Use pooler_output if available, else mean pool
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
embeddings = outputs.pooler_output
else:
embeddings = outputs.last_hidden_state.mean(dim=1)
return embeddings
def encode_audio(self, wavs, sr):
"""Returns raw CLAP audio embeddings - batched processing"""
# Convert all wavs to numpy for batch processing
audio_list = [w.cpu().numpy() for w in wavs]
# Process all audios in a single batch
proc = self.proc(audio=audio_list, sampling_rate=sr, return_tensors="pt")
proc = {k: v.to(self.cfg.device) for k, v in proc.items()}
# Ensure CLAP is in eval mode to avoid batch norm issues
was_training = self.clap.training
self.clap.eval()
with torch.no_grad():
embeddings = self.clap.get_audio_features(**proc)
# Restore training state if needed
if was_training:
self.clap.train()
return embeddings
# --- Loss ---
@staticmethod
def info_nce(a, b, temp):
"""InfoNCE contrastive loss"""
a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
logits = a @ b.t() / temp
tgt = torch.arange(a.size(0), device=a.device)
return 0.5 * (F.cross_entropy(logits, tgt) + F.cross_entropy(logits.t(), tgt))
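# Worked sketch of info_nce (illustrative): for a batch of B=3 audio/text pairs,
# `logits` is a 3x3 matrix of cosine similarities divided by `temp`; cross-entropy
# trains row i to peak at column i, and the transposed term does the same for
# columns, pulling matched pairs together and pushing mismatched pairs apart.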
def compute_diffusion_loss(self, images, audio_emb):
"""
Diffusion loss: Trains SD UNet to denoise images conditioned on audio.
This enables end-to-end learning of the generative model!
Args:
images: Ground truth images [B, 3, 512, 512] in range [-1, 1]
audio_emb: Audio embeddings from CLAP
Returns:
Denoising loss (MSE between predicted and actual noise)
"""
# 1. Encode images to latent space (no grad through VAE)
with torch.no_grad():
latents = self.sd_vae.encode(images).latent_dist.sample()
latents = latents * 0.18215 # SD's scaling factor
# 2. Sample random timesteps for diffusion training
noise = torch.randn_like(latents)
bsz = latents.shape[0]
timesteps = torch.randint(
0, 1000, (bsz,),
device=latents.device
).long()
# 3. Add noise to latents according to timestep
if not hasattr(self, 'noise_scheduler'):
self.noise_scheduler = DDPMScheduler.from_pretrained(
self.cfg.SD_ID,
subfolder="scheduler"
)
noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
# 4. Get audio conditioning (gradients flow to mapper!)
_, audio_to_sd = self.mapper(audio_emb)
# Reshape for UNet: [batch, 1, hidden_dim]
encoder_hidden_states = audio_to_sd.unsqueeze(1)
# 5. UNet predicts noise (THIS IS WHERE SD LEARNS! ✅)
noise_pred = self.sd_unet(
noisy_latents, # Noisy input
timesteps, # Time conditioning
encoder_hidden_states # Audio conditioning
).sample
# 6. Compute denoising loss
# Gradients flow back to: UNet ✅ and Mapper ✅
loss = F.mse_loss(noise_pred, noise, reduction='mean')
return loss
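# In equation form (the standard DDPM training objective, stated here for reference):
#   L_diff = E_{x0, eps ~ N(0, I), t} [ || eps - eps_theta(x_t, t, c_audio) ||^2 ]
# where x_t = add_noise(x0, eps, t) and c_audio is the mapper's to_sd projection.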
@torch.inference_mode()
def evaluate_generation(self, wavs, sr, captions, num_samples=None):
"""
Evaluate quality of generated images using CLIP text-image similarity.
Args:
wavs: List of audio waveforms
sr: Sample rate
captions: List of text captions describing the audio
num_samples: Number of samples to evaluate (None = all)
Returns:
avg_clip_score: Average CLIP similarity score (scaled cosine similarity; higher is better)
generated_images: List of PIL images
clip_scores: List of individual CLIP scores
"""
was_training = self.training
self.eval()
if num_samples is not None:
wavs = wavs[:num_samples]
captions = captions[:num_samples]
generated_images = []
clip_scores = []
for wav, caption in zip(wavs, captions):
# Generate image from audio
img = self.generate(wav, sr)
generated_images.append(img)
# Compute CLIP score (text-image similarity)
inputs = self.clip_processor(
text=[caption],
images=[img],
return_tensors="pt",
padding=True
).to(self.cfg.device)
outputs = self.clip_model(**inputs)
# Similarity score: logits_per_image is cosine similarity scaled by CLIP's learned logit scale
# Higher score = better match between image and caption
logits_per_image = outputs.logits_per_image
clip_score = logits_per_image[0, 0].item()
clip_scores.append(clip_score)
avg_clip_score = sum(clip_scores) / len(clip_scores) if clip_scores else 0.0
if was_training:
self.train()
return avg_clip_score, generated_images, clip_scores
# --- Forward (Training with Multi-Task Loss) ---
def forward(self, wavs, sr, caps, images=None, has_images=None):
"""
Forward pass with three parallel losses:
1. CLAP alignment (semantic understanding)
2. SD embedding alignment (embedding compatibility)
3. Diffusion loss (pixel-level generation) - requires images
All losses train simultaneously in end-to-end fashion!
"""
# Get target embeddings (frozen encoders)
clap_text_emb = self.encode_text_clap(caps)
sd_text_emb = self.encode_text_sd(caps)
# Get audio embeddings
audio_emb = self.encode_audio(wavs, sr)
# Project audio to both spaces (gradients flow here!)
audio_to_clap, audio_to_sd = self.mapper(audio_emb)
# Loss 1: CLAP alignment (InfoNCE)
loss_clap = self.info_nce(audio_to_clap, clap_text_emb, self.cfg.temperature)
# Loss 2: SD embedding alignment (MSE)
loss_sd = F.mse_loss(audio_to_sd, sd_text_emb)
# Loss 3: Diffusion loss (pixel-level generation)
loss_diffusion = torch.tensor(0.0, device=self.cfg.device)
if self.cfg.finetune_sd and images is not None:
# Only compute on samples that have images
if has_images is not None:
valid_mask = has_images.to(self.cfg.device)
if valid_mask.sum() > 0:
valid_imgs = images[valid_mask]
valid_audio_emb = audio_emb[valid_mask]
loss_diffusion = self.compute_diffusion_loss(valid_imgs, valid_audio_emb)
else:
loss_diffusion = self.compute_diffusion_loss(images, audio_emb)
# Combined multi-task loss - all train together! 🚀
total_loss = (
self.cfg.clap_loss_weight * loss_clap +
self.cfg.sd_loss_weight * loss_sd +
self.cfg.diffusion_loss_weight * loss_diffusion
)
# Compute similarities for monitoring
with torch.no_grad():
clap_sim = torch.diagonal(
F.normalize(audio_to_clap, dim=-1) @ F.normalize(clap_text_emb, dim=-1).t()
).mean()
sd_sim = F.cosine_similarity(audio_to_sd, sd_text_emb, dim=-1).mean()
return total_loss, {
"loss_clap": loss_clap.item(),
"loss_sd": loss_sd.item(),
"loss_diffusion": loss_diffusion.item(),
"clap_sim": clap_sim.item(),
"sd_sim": sd_sim.item()
}
# --- Inference ---
@torch.inference_mode()
def generate(self, wav, sr):
if self.sd_pipe is None:
raise RuntimeError("Stable Diffusion not loaded. Init with load_sd=True.")
# Get audio embedding and project to SD space
audio_emb = self.encode_audio([wav], sr)
_, soft_token = self.mapper(audio_emb) # Use to_sd head
# Tokenize base prompt
tok = self.sd_tok(
self.cfg.base_prompt,
padding="max_length",
max_length=self.sd_tok.model_max_length,
truncation=True,
return_tensors="pt"
).to(self.cfg.device)
# Get SD text embeddings
enc = self.sd_text_encoder(tok["input_ids"])[0]
# Find position to insert audio token (after last real token)
attention_mask = tok["attention_mask"][0]
last_token_pos = attention_mask.nonzero(as_tuple=False).max().item()
# Insert audio soft token AFTER the last token
if last_token_pos + 1 < enc.shape[1]:
enc[0, last_token_pos + 1:last_token_pos + 2, :] = soft_token
else:
# If no space, replace the last token
enc[0, last_token_pos:last_token_pos + 1, :] = soft_token
# Generate image
img = self.sd_pipe(
num_inference_steps=self.cfg.steps,
guidance_scale=self.cfg.guidance, # 7.5
prompt_embeds=enc
).images[0]
return img
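# Minimal standalone usage sketch for generate() (mirrors infer() below; the file
# names are placeholders):
#   cfg = Config()
#   model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)
#   ckpt = torch.load(cfg.ckpt_path, map_location=cfg.device)
#   model.mapper.load_state_dict(ckpt["mapper"])
#   wav, sr = torchaudio.load("example.wav")
#   wav = torchaudio.transforms.Resample(sr, 48000)(wav.mean(0))  # mono, 48 kHz for CLAP
#   img = model.generate(wav.to(cfg.device), 48000)
#   img.save("example_out.png")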
# ========================
# Training
# ========================
def train(cfg: Config):
# Load dataset with images
full_ds = AudioCaptionDataset(cfg.train_csv, cfg.image_folder, use_zip_files=cfg.use_zip_files)
# Create train/validation split (90/10)
train_size = int(0.9 * len(full_ds))
val_size = len(full_ds) - train_size
train_ds, val_ds = torch.utils.data.random_split(
full_ds,
[train_size, val_size],
generator=torch.Generator().manual_seed(42) # For reproducibility
)
print(f"\nDataset split:")
print(f" Training: {len(train_ds)} samples")
print(f" Validation: {len(val_ds)} samples\n")
# Create dataloaders
train_loader = DataLoader(
train_ds,
batch_size=cfg.batch_size,
shuffle=True,
collate_fn=collate_audio,
num_workers=0,
drop_last=True
)
val_loader = DataLoader(
val_ds,
batch_size=cfg.batch_size,
shuffle=False,
collate_fn=collate_audio,
num_workers=0
)
# Initialize model
model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)
# Separate optimizers with different learning rates
if cfg.finetune_sd:
print("\n🔥 Setting up END-TO-END training:")
# Optimizer 1: Mapper (higher LR)
opt_mapper = torch.optim.AdamW(
model.mapper.parameters(),
lr=cfg.lr,
weight_decay=cfg.weight_decay
)
print(f" Mapper optimizer: LR={cfg.lr}")
# Optimizer 2: SD UNet (lower LR for stability)
opt_sd = torch.optim.AdamW(
model.sd_unet.parameters(),
lr=cfg.sd_lr,
weight_decay=cfg.weight_decay
)
print(f" SD UNet optimizer: LR={cfg.sd_lr}")
opts = [opt_mapper, opt_sd]
else:
# Only the mapper is trainable, so optimize just its parameters
opt_mapper = torch.optim.AdamW(
model.mapper.parameters(),
lr=cfg.lr,
weight_decay=cfg.weight_decay
)
opts = [opt_mapper]
print(f"\n{'='*60}")
print(f"Starting {'End-to-End' if cfg.finetune_sd else 'Mapper-Only'} Training")
print(f"{'='*60}")
print(f"Dataset: {len(full_ds)} samples ({len(train_ds)} train, {len(val_ds)} val)")
print(f"Batch size: {cfg.batch_size}")
print(f"Epochs: {cfg.max_epochs}")
print(f"Evaluation: Every {cfg.eval_every_n_epochs} epoch(s)")
print(f"Loss weights:")
print(f" CLAP: {cfg.clap_loss_weight}")
print(f" SD Embedding: {cfg.sd_loss_weight}")
if cfg.finetune_sd:
print(f" Diffusion: {cfg.diffusion_loss_weight}")
print(f"{'='*60}\n")
# Track best model based on CLIP score
best_clip_score = -float('inf')
for ep in range(1, cfg.max_epochs + 1):
# ============================================
# TRAINING PHASE
# ============================================
model.train()
pbar = tqdm(train_loader, desc=f"Epoch {ep}/{cfg.max_epochs} [TRAIN]")
epoch_stats = {
"total": 0, "clap": 0, "sd": 0, "diff": 0,
"clap_sim": 0, "sd_sim": 0
}
for wavs, sr, caps, imgs, has_imgs in pbar:
wavs = [w.to(cfg.device) for w in wavs]
imgs = imgs.to(cfg.device)
# Forward pass - all losses computed!
loss, stats = model(wavs, sr, caps, imgs if cfg.finetune_sd else None, has_imgs)
# Zero gradients for all optimizers
for opt in opts:
opt.zero_grad()
# Backward pass - gradients flow to mapper AND UNet!
loss.backward()
# Clip gradients for stability
if cfg.finetune_sd:
nn.utils.clip_grad_norm_(model.mapper.parameters(), 1.0)
nn.utils.clip_grad_norm_(model.sd_unet.parameters(), 1.0)
else:
nn.utils.clip_grad_norm_(model.parameters(), 1.0)
# Update all parameters simultaneously! 🚀
for opt in opts:
opt.step()
# Accumulate stats
epoch_stats["total"] += loss.item()
epoch_stats["clap"] += stats['loss_clap']
epoch_stats["sd"] += stats['loss_sd']
epoch_stats["diff"] += stats['loss_diffusion']
epoch_stats["clap_sim"] += stats['clap_sim']
epoch_stats["sd_sim"] += stats['sd_sim']
pbar.set_postfix({
"total loss": f"{loss.item():.3f}",
"diff": f"{stats['loss_diffusion']:.3f}",
"c_sim": f"{stats['clap_sim']:.2f}",
"s_sim": f"{stats['sd_sim']:.2f}"
})
# Compute training epoch averages
n_train = len(train_loader)
for k in epoch_stats:
epoch_stats[k] /= n_train
# ============================================
# VALIDATION & EVALUATION PHASE
# ============================================
if ep % cfg.eval_every_n_epochs == 0:
print(f"\n{'='*60}")
print(f"🔍 Evaluating Epoch {ep}...")
print(f"{'='*60}")
model.eval()
val_clip_scores = []
all_gen_images = []
all_captions = []
# Evaluate on validation set (limit to save time)
eval_batches = min(3, len(val_loader)) # Max 3 batches
for batch_idx, (wavs, sr, caps, imgs, has_imgs) in enumerate(val_loader):
if batch_idx >= eval_batches:
break
wavs = [w.to(cfg.device) for w in wavs]
# Generate images and compute CLIP scores
avg_score, gen_imgs, scores = model.evaluate_generation(
wavs, sr, caps,
num_samples=cfg.num_eval_samples
)
val_clip_scores.extend(scores)
all_gen_images.extend(gen_imgs)
all_captions.extend(caps[:cfg.num_eval_samples])
print(f" Batch {batch_idx + 1}/{eval_batches}: Avg CLIP = {avg_score:.3f}")
# Compute overall validation CLIP score
avg_val_clip = sum(val_clip_scores) / len(val_clip_scores) if val_clip_scores else 0.0
# Save example images from evaluation
if cfg.save_eval_images and all_gen_images:
os.makedirs("eval_samples", exist_ok=True)
for i, (img, cap, score) in enumerate(zip(all_gen_images[:4], all_captions[:4], val_clip_scores[:4])):
save_path = f"eval_samples/ep{ep}_sample{i}_score{score:.2f}.png"
img.save(save_path)
print(f" Sample {i}: '{cap[:50]}...' | CLIP: {score:.3f}")
print(f" Saved to: {save_path}")
# Clear MPS cache after evaluation
if cfg.device == "mps":
torch.mps.empty_cache()
print(f"\n{'='*60}")
print(f"📊 Epoch {ep} Summary:")
print(f"{'='*60}")
print(f"Training Metrics:")
print(f" Total Loss: {epoch_stats['total']:.4f}")
print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
if cfg.finetune_sd:
print(f" Diffusion Loss: {epoch_stats['diff']:.4f}")
print(f"\nValidation Metrics:")
print(f" 🎯 CLIP Score: {avg_val_clip:.3f} (higher = better image-text match)")
print(f"{'='*60}\n")
else:
# Just print training stats if not evaluating
avg_val_clip = None
print(f"\n{'='*60}")
print(f"Epoch {ep} Summary:")
print(f" Total Loss: {epoch_stats['total']:.4f}")
print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
if cfg.finetune_sd:
print(f" Diffusion Loss: {epoch_stats['diff']:.4f}")
print(f"{'='*60}\n")
# ============================================
# CHECKPOINT SAVING
# ============================================
checkpoint = {
"mapper": model.mapper.state_dict(),
"epoch": ep,
"val_clip_score": avg_val_clip if avg_val_clip is not None else -1,
**{k: v for k, v in epoch_stats.items()},
"config": {
"clap_loss_weight": cfg.clap_loss_weight,
"sd_loss_weight": cfg.sd_loss_weight,
"diffusion_loss_weight": cfg.diffusion_loss_weight,
"finetune_sd": cfg.finetune_sd
}
}
if cfg.finetune_sd:
checkpoint["unet"] = model.sd_unet.state_dict()
# Always save latest checkpoint
torch.save(checkpoint, cfg.ckpt_path)
print(f"💾 Checkpoint saved: {cfg.ckpt_path}")
# Save best model based on CLIP score
if avg_val_clip is not None and avg_val_clip > best_clip_score:
best_clip_score = avg_val_clip
best_path = cfg.ckpt_path.replace('.pt', '_best.pt')
torch.save(checkpoint, best_path)
print(f"✅ New best model! CLIP: {avg_val_clip:.3f} -> Saved to {best_path}")
elif avg_val_clip is not None:
print(f" Current best CLIP: {best_clip_score:.3f}")
print()
print("🎉 Training completed!")
if best_clip_score > -float('inf'):
print(f" Best CLIP score achieved: {best_clip_score:.3f}")
# ========================
# Inference
# ========================
def infer(cfg: Config, wav_path: str, out_path: str):
# Load audio
print(f"Loading audio from {wav_path}...")
wav, sr = torchaudio.load(wav_path)
if wav.size(0) > 1:
wav = wav.mean(0, keepdim=True)
wav = wav.squeeze(0).float()
# Resample to 48kHz for CLAP
if sr != 48000:
print(f"Resampling from {sr}Hz to 48000Hz...")
resampler = torchaudio.transforms.Resample(sr, 48000)
wav = resampler(wav)
sr = 48000
wav = wav.to(cfg.device)
# Load model with SD
model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)
# Load trained weights
print(f"Loading checkpoint from {cfg.ckpt_path}...")
ckpt = torch.load(cfg.ckpt_path, map_location=cfg.device)
model.mapper.load_state_dict(ckpt["mapper"])
# Load UNet weights if available (from fine-tuning)
if "unet" in ckpt:
print("Loading fine-tuned UNet weights...")
model.sd_unet.load_state_dict(ckpt["unet"])
print(f"Checkpoint info:")
print(f" Epoch: {ckpt.get('epoch', 'unknown')}")
print(f" CLAP Sim: {ckpt.get('clap_sim', 'N/A'):.3f}" if isinstance(ckpt.get('clap_sim'), (int, float)) else f" CLAP Sim: N/A")
print(f" SD Sim: {ckpt.get('sd_sim', 'N/A'):.3f}" if isinstance(ckpt.get('sd_sim'), (int, float)) else f" SD Sim: N/A")
if "unet" in ckpt:
print(" Fine-tuned UNet: ✓")
# Generate image
print("\nGenerating image...")
img = model.generate(wav, sr)
img.save(out_path)
print(f"✓ Generated image saved to {out_path}")
# ========================
# Main
# ========================
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--mode", choices=["train", "infer"], default="train")
parser.add_argument("--wav", help="Audio file path for inference mode")
parser.add_argument("--out", default="output.png", help="Output image path")
args = parser.parse_args()
cfg = Config()
print(f"Device: {cfg.device}")
if args.mode == "train":
print(f"Dataset: {cfg.train_csv}")
if not os.path.exists(cfg.train_csv):
print(f"ERROR: Dataset not found at {cfg.train_csv}")
print("Please ensure the captions.txt file exists")
sys.exit(1)
train(cfg)
else:
if not args.wav:
raise ValueError("Need --wav for inference mode")
if not os.path.exists(args.wav):
raise ValueError(f"Audio file not found: {args.wav}")
infer(cfg, args.wav, args.out)