|
|
""" |
|
|
Audio → Image Generator (Multi-Task Loss Version) |
|
|
Key features: |
|
|
- Dual-head MLP: one for CLAP text space, one for SD embedding space |
|
|
- Multi-task training: CLAP alignment loss + SD alignment loss |
|
|
- Both heads are trained simultaneously |
|
|
- to_sd head is properly trained and used during inference |
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os, math, csv, random, sys |
|
|
from typing import List, Tuple |
|
|
from dataclasses import dataclass |
|
|
import zipfile |
|
|
from io import BytesIO |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torchaudio |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
from tqdm import tqdm |
|
|
|
|
|
from transformers import AutoProcessor, ClapModel, AutoTokenizer, CLIPProcessor, CLIPModel |
|
|
from diffusers import StableDiffusionPipeline, DDPMScheduler, DDIMScheduler |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass |
|
|
class Config: |
|
|
CLAP_ID: str = "laion/clap-htsat-fused" |
|
|
SD_ID: str = "runwayml/stable-diffusion-v1-5" |
|
|
CLIP_ID: str = "openai/clip-vit-base-patch32" |
|
|
|
|
|
|
|
|
device: str = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu") |
|
|
|
|
|
lr: float = 2e-4 |
|
|
weight_decay: float = 1e-4 |
|
|
temperature: float = 0.07 |
|
|
|
|
|
|
|
|
clap_loss_weight: float = 0.5 |
|
|
sd_loss_weight: float = 1.0 |
|
|
diffusion_loss_weight: float = 1.0 |
|
|
|
|
|
batch_size: int = 2 |
|
|
max_epochs: int = 20 |
|
|
base_prompt: str = "A photo of" |
|
|
guidance: float = 7.5 |
|
|
steps: int = 30 |
|
|
|
|
|
|
|
|
train_csv: str = "/Users/rajvarun/Desktop/SIT/Trimester 4/AAI 3001 - Computer Vision & Deep Learning/Seeing Sound II/raj/main_dataV1.csv" |
|
|
image_folder: str = "/Users/rajvarun/OneDrive - Singapore Institute Of Technology/ALEXI KIZHAKKEPURATHU GEORGE's files - VGGSound" |
|
|
ckpt_path: str = "audio2image_mapper_dual_best.pt" |
|
|
|
|
|
|
|
|
use_zip_files: bool = True |
|
|
zip_files: dict = None |
|
|
|
|
|
|
|
|
finetune_sd: bool = False |
|
|
sd_lr: float = 1e-5 |
|
|
freeze_vae: bool = True |
|
|
freeze_text_encoder: bool = True |
|
|
|
|
|
|
|
|
eval_every_n_epochs: int = 1 |
|
|
num_eval_samples: int = 4 |
|
|
save_eval_images: bool = True |
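
# Example: constructing the config with a few overrides for a quick run
# (values are illustrative; every field defaults as declared above):
#   cfg = Config(batch_size=4, max_epochs=5, finetune_sd=False)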
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioCaptionDataset(Dataset): |
|
|
""" |
|
|
Reads a CSV file with audio-image-caption triplets. |
|
|
Handles structure where data is in: base_folder/image/ and base_folder/audio/ |
|
|
|
|
|
Can read from extracted folders OR directly from ZIP files (no extraction needed!) |
|
|
|
|
|
Example: |
|
|
- CSV: vggsound_00,g-f_I2yQ_1.png,g-f_I2yQ_000001.wav,people marching |
|
|
- Audio path: vggsound_00/audio/g-f_I2yQ_000001.wav |
|
|
- Image path: vggsound_00/image/g-f_I2yQ_1.png |
|
|
""" |
|
|
def __init__(self, captions_path: str, image_folder: str = None, use_zip_files: bool = False): |
|
|
self.items = [] |
|
|
base_dir = os.path.dirname(captions_path) |
|
|
self.image_folder = image_folder or base_dir |
|
|
self.use_zip_files = use_zip_files |
|
|
self.zip_handles = {} |
|
|
|
|
|
|
|
|
self.img_transform = transforms.Compose([ |
|
|
transforms.Resize((512, 512)), |
|
|
transforms.ToTensor(), |
|
|
transforms.Normalize([0.5], [0.5]) |
|
|
]) |
|
|
|
|
|
print(f"Loading dataset from: {captions_path}") |
|
|
print(f"Base folder: {self.image_folder}") |
|
|
print(f"Use ZIP files: {use_zip_files}") |
|
|
|
|
|
|
|
|
if use_zip_files: |
|
|
self._find_zip_files() |
|
|
|
|
|
|
|
|
|
|
with open(captions_path, "r", encoding="utf-8") as f: |
|
|
reader = csv.DictReader(f) |
|
|
|
|
|
for row_num, row in enumerate(reader, 1): |
|
|
|
|
|
if 'base_folder' in row and 'image_file' in row and 'audio_file' in row and 'caption' in row: |
|
|
base_folder = row['base_folder'] |
|
|
img_filename = row['image_file'] |
|
|
audio_filename = row['audio_file'] |
|
|
caption = row['caption'] |
|
|
|
|
|
if use_zip_files: |
|
|
|
|
|
audio_path = f"{base_folder}/audio/{audio_filename}" |
|
|
img_path = f"{base_folder}/image/{img_filename}" |
|
|
|
|
|
|
|
|
audio_exists = self._file_in_zip(base_folder, audio_path) |
|
|
img_exists = self._file_in_zip(base_folder, img_path) |
|
|
|
|
|
|
|
|
if row_num <= 3: |
|
|
print(f"Row {row_num}: base_folder='{base_folder}', audio='{audio_path}', exists={audio_exists}") |
|
|
else: |
|
|
|
|
|
audio_path = os.path.join(self.image_folder, base_folder, "audio", audio_filename) |
|
|
img_path = os.path.join(self.image_folder, base_folder, "image", img_filename) |
|
|
|
|
|
audio_exists = os.path.exists(audio_path) |
|
|
img_exists = os.path.exists(img_path) |
|
|
|
|
|
if audio_exists: |
|
|
if img_exists: |
|
|
self.items.append((base_folder, audio_path, img_path, caption)) |
|
|
else: |
|
|
|
|
|
self.items.append((base_folder, audio_path, None, caption)) |
|
|
if row_num <= 3: |
|
|
print(f"Warning: Image not found: {img_path}") |
|
|
else: |
|
|
if row_num <= 3: |
|
|
print(f"Warning: Audio not found: {audio_path}") |
|
|
else: |
|
|
if row_num <= 3: |
|
|
print(f"Warning: Row {row_num} missing required columns") |
|
|
|
|
|
if not self.items: |
|
|
raise ValueError("Empty dataset: no valid audio files found") |
|
|
|
|
|
|
|
|
with_images = sum(1 for _, _, img_path, _ in self.items if img_path is not None) |
|
|
print(f"✓ Loaded {len(self.items)} audio files ({with_images} with matching images)") |
|
|
|
|
|
def _find_zip_files(self): |
|
|
"""Find and open ZIP files in the image_folder""" |
|
|
print("Searching for ZIP files...") |
|
|
for item in os.listdir(self.image_folder): |
|
|
if item.endswith('.zip'): |
|
|
zip_name = item.replace('.zip', '') |
|
|
zip_path = os.path.join(self.image_folder, item) |
|
|
try: |
|
|
self.zip_handles[zip_name] = zipfile.ZipFile(zip_path, 'r') |
|
|
|
|
|
file_count = len(self.zip_handles[zip_name].namelist()) |
|
|
print(f" ✓ Opened {item} (key: '{zip_name}', {file_count} files)") |
|
|
except Exception as e: |
|
|
print(f" ✗ Failed to open {item}: {e}") |
|
|
|
|
|
def _file_in_zip(self, base_folder, file_path): |
|
|
"""Check if a file exists in the corresponding ZIP""" |
|
|
if base_folder not in self.zip_handles: |
|
|
print(f" ! ZIP handle not found for base_folder='{base_folder}'. Available: {list(self.zip_handles.keys())}") |
|
|
return False |
|
|
try: |
|
|
self.zip_handles[base_folder].getinfo(file_path) |
|
|
return True |
|
|
except KeyError: |
|
|
return False |
|
|
|
|
|
def _read_from_zip(self, base_folder, file_path): |
|
|
"""Read a file from ZIP archive""" |
|
|
if base_folder in self.zip_handles: |
|
|
return self.zip_handles[base_folder].read(file_path) |
|
|
return None |
|
|
|
|
|
def __len__(self): |
|
|
return len(self.items) |
|
|
|
|
|
def __getitem__(self, idx: int): |
|
|
base_folder, audio_path, img_path, cap = self.items[idx] |
|
|
|
|
|
|
|
|
if self.use_zip_files: |
|
|
|
|
|
audio_bytes = self._read_from_zip(base_folder, audio_path) |
|
|
if audio_bytes is None: |
|
|
raise FileNotFoundError(f"Audio not found in ZIP: {audio_path}") |
|
|
wav, sr = torchaudio.load(BytesIO(audio_bytes)) |
|
|
else: |
|
|
|
|
|
wav, sr = torchaudio.load(audio_path) |
|
|
|
|
|
if wav.size(0) > 1: |
|
|
wav = wav.mean(dim=0, keepdim=True) |
|
|
wav = wav.squeeze(0).float() |
|
|
|
|
|
if sr != 48000: |
|
|
resampler = torchaudio.transforms.Resample(sr, 48000) |
|
|
wav = resampler(wav) |
|
|
|
|
|
|
|
|
if img_path is not None: |
|
|
if self.use_zip_files: |
|
|
|
|
|
img_bytes = self._read_from_zip(base_folder, img_path) |
|
|
if img_bytes: |
|
|
img = Image.open(BytesIO(img_bytes)).convert('RGB') |
|
|
img_tensor = self.img_transform(img) |
|
|
else: |
|
|
img_tensor = torch.zeros((3, 512, 512)) |
|
|
else: |
|
|
|
|
|
img = Image.open(img_path).convert('RGB') |
|
|
img_tensor = self.img_transform(img) |
|
|
else: |
|
|
|
|
|
img_tensor = torch.zeros((3, 512, 512)) |
|
|
|
|
|
return wav, 48000, cap, img_tensor, (img_path is not None) |
|
|
|
|
|
def __del__(self): |
|
|
"""Close ZIP files when done""" |
|
|
for zip_handle in self.zip_handles.values(): |
|
|
try: |
|
|
zip_handle.close() |
|
|
            except Exception:
|
|
pass |
|
|
|
|
|
def collate_audio(batch): |
|
|
wavs, srs, caps, imgs, has_imgs = [], [], [], [], [] |
|
|
for w, sr, c, img, has_img in batch: |
|
|
wavs.append(w) |
|
|
srs.append(sr) |
|
|
caps.append(c) |
|
|
imgs.append(img) |
|
|
has_imgs.append(has_img) |
|
|
return wavs, srs[0], caps, torch.stack(imgs), torch.tensor(has_imgs) |
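
# Note: waveforms are kept as a Python list rather than stacked, since clips may have
# different lengths; the CLAP processor pads/crops them to fixed-size features later,
# inside Audio2ImageModel.encode_audio().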
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AudioProjectionMLP(nn.Module): |
|
|
""" |
|
|
Dual-head MLP projection: |
|
|
- to_text: CLAP audio → CLAP text space (for CLAP alignment) |
|
|
- to_sd: CLAP audio → SD embedding space (for image generation) |
|
|
Both heads are trained with multi-task loss. |
|
|
""" |
|
|
def __init__(self, in_dim, text_dim, sd_dim, hidden=1024): |
|
|
super().__init__() |
|
|
|
|
|
|
|
|
self.shared = nn.Sequential( |
|
|
nn.Linear(in_dim, hidden), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(hidden, hidden), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1) |
|
|
) |
|
|
|
|
|
|
|
|
self.to_text = nn.Sequential( |
|
|
nn.Linear(hidden, hidden), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(hidden, text_dim) |
|
|
) |
|
|
|
|
|
|
|
|
self.to_sd = nn.Sequential( |
|
|
nn.Linear(hidden, hidden), |
|
|
nn.GELU(), |
|
|
nn.Dropout(0.1), |
|
|
nn.Linear(hidden, sd_dim) |
|
|
) |
|
|
|
|
|
def forward(self, z): |
|
|
shared_features = self.shared(z) |
|
|
return self.to_text(shared_features), self.to_sd(shared_features) |
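
# Shape sketch (dims are illustrative; the real dims are probed from CLAP/SD at model init):
#   mapper = AudioProjectionMLP(in_dim=512, text_dim=512, sd_dim=768)
#   to_text, to_sd = mapper(torch.randn(2, 512))   # -> shapes (2, 512) and (2, 768)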
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Audio2ImageModel(nn.Module): |
|
|
def __init__(self, cfg: Config, load_sd: bool = False): |
|
|
super().__init__() |
|
|
self.cfg = cfg |
|
|
device = cfg.device |
|
|
|
|
|
|
|
|
print("Loading CLAP model...") |
|
|
self.clap = ClapModel.from_pretrained(cfg.CLAP_ID).eval().to(device) |
|
|
for p in self.clap.parameters(): |
|
|
p.requires_grad = False |
|
|
self.proc = AutoProcessor.from_pretrained(cfg.CLAP_ID) |
|
|
|
|
|
|
|
|
print("Loading CLIP for evaluation...") |
|
|
self.clip_model = CLIPModel.from_pretrained(cfg.CLIP_ID).eval().to(device) |
|
|
self.clip_processor = CLIPProcessor.from_pretrained(cfg.CLIP_ID) |
|
|
for p in self.clip_model.parameters(): |
|
|
p.requires_grad = False |
|
|
print(" ✓ CLIP loaded (frozen for evaluation only)") |
|
|
|
|
|
|
|
|
self.sd_pipe = None |
|
|
self.sd_tok = None |
|
|
self.sd_text_encoder = None |
|
|
self.sd_unet = None |
|
|
self.sd_vae = None |
|
|
self.sd_hidden = 768 |
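        # 768 is the CLIP text hidden size used by SD v1.x; updated below once the pipeline loads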
|
|
|
|
|
|
|
|
        if load_sd:
|
|
print("Loading Stable Diffusion...") |
|
|
|
|
|
dtype = torch.float32 if cfg.finetune_sd else (torch.float16 if device == "cuda" else torch.float32) |
|
|
self.sd_pipe = StableDiffusionPipeline.from_pretrained(cfg.SD_ID, torch_dtype=dtype) |
|
|
self.sd_pipe.to(device) |
|
|
|
|
|
self.sd_tok = self.sd_pipe.tokenizer |
|
|
self.sd_text_encoder = self.sd_pipe.text_encoder |
|
|
self.sd_unet = self.sd_pipe.unet |
|
|
self.sd_vae = self.sd_pipe.vae |
|
|
self.sd_hidden = self.sd_pipe.text_encoder.config.hidden_size |
|
|
|
|
|
|
|
|
if cfg.finetune_sd: |
|
|
print("🔥 End-to-End Training Mode:") |
|
|
|
|
|
|
|
|
for p in self.sd_unet.parameters(): |
|
|
p.requires_grad = True |
|
|
self.sd_unet.train() |
|
|
print(" ✓ UNet: TRAINABLE") |
|
|
|
|
|
|
|
|
if cfg.freeze_vae: |
|
|
for p in self.sd_vae.parameters(): |
|
|
p.requires_grad = False |
|
|
self.sd_vae.eval() |
|
|
print(" ✓ VAE: FROZEN") |
|
|
else: |
|
|
for p in self.sd_vae.parameters(): |
|
|
p.requires_grad = True |
|
|
self.sd_vae.train() |
|
|
print(" ✓ VAE: TRAINABLE") |
|
|
|
|
|
|
|
|
if cfg.freeze_text_encoder: |
|
|
for p in self.sd_text_encoder.parameters(): |
|
|
p.requires_grad = False |
|
|
self.sd_text_encoder.eval() |
|
|
print(" ✓ Text Encoder: FROZEN") |
|
|
else: |
|
|
for p in self.sd_text_encoder.parameters(): |
|
|
p.requires_grad = True |
|
|
self.sd_text_encoder.train() |
|
|
print(" ✓ Text Encoder: TRAINABLE") |
|
|
else: |
|
|
print("Inference Mode: All SD components frozen") |
|
|
for comp in (self.sd_unet, self.sd_vae, self.sd_text_encoder): |
|
|
for p in comp.parameters(): |
|
|
p.requires_grad = False |
|
|
comp.eval() |
|
|
|
|
|
|
|
|
dummy_text = ["test"] |
|
|
dummy_audio = [torch.zeros(48000).numpy()] |
|
|
|
|
|
with torch.no_grad(): |
|
|
text_proc = self.proc(text=dummy_text, return_tensors="pt") |
|
|
text_proc = {k: v.to(device) for k,v in text_proc.items()} |
|
|
t = self.clap.get_text_features(**text_proc) |
|
|
clap_text_dim = t.shape[-1] |
|
|
|
|
|
audio_proc = self.proc(audio=dummy_audio, sampling_rate=48000, return_tensors="pt") |
|
|
audio_proc = {k: v.to(device) for k,v in audio_proc.items()} |
|
|
a = self.clap.get_audio_features(**audio_proc) |
|
|
clap_audio_dim = a.shape[-1] |
|
|
|
|
|
|
|
|
print(f"Creating MLP: CLAP audio ({clap_audio_dim}) → CLAP text ({clap_text_dim}) & SD ({self.sd_hidden})") |
|
|
self.mapper = AudioProjectionMLP(clap_audio_dim, clap_text_dim, self.sd_hidden) |
|
|
|
|
|
|
|
|
def encode_text_clap(self, caps): |
|
|
"""Encode text using CLAP text encoder""" |
|
|
proc = self.proc(text=caps, return_tensors="pt", padding=True) |
|
|
proc = {k: v.to(self.cfg.device) for k,v in proc.items()} |
|
|
|
|
|
|
|
|
was_training = self.clap.training |
|
|
self.clap.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
|
e = self.clap.get_text_features(**proc) |
|
|
|
|
|
|
|
|
if was_training: |
|
|
self.clap.train() |
|
|
|
|
|
return F.normalize(e, dim=-1) |
|
|
|
|
|
def encode_text_sd(self, caps): |
|
|
"""Encode text using SD text encoder (for target embeddings)""" |
|
|
tokens = self.sd_tok( |
|
|
caps, |
|
|
padding="max_length", |
|
|
max_length=self.sd_tok.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt" |
|
|
).to(self.cfg.device) |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
|
|
outputs = self.sd_text_encoder(tokens["input_ids"]) |
|
|
|
|
|
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None: |
|
|
embeddings = outputs.pooler_output |
|
|
else: |
|
|
embeddings = outputs.last_hidden_state.mean(dim=1) |
|
|
|
|
|
return embeddings |
|
|
|
|
|
def encode_audio(self, wavs, sr): |
|
|
"""Returns raw CLAP audio embeddings - batched processing""" |
|
|
|
|
|
audio_list = [w.cpu().numpy() for w in wavs] |
|
|
|
|
|
|
|
|
proc = self.proc(audio=audio_list, sampling_rate=sr, return_tensors="pt") |
|
|
proc = {k: v.to(self.cfg.device) for k, v in proc.items()} |
|
|
|
|
|
|
|
|
was_training = self.clap.training |
|
|
self.clap.eval() |
|
|
|
|
|
with torch.no_grad(): |
|
|
embeddings = self.clap.get_audio_features(**proc) |
|
|
|
|
|
|
|
|
if was_training: |
|
|
self.clap.train() |
|
|
|
|
|
return embeddings |
|
|
|
|
|
|
|
|
@staticmethod |
|
|
def info_nce(a, b, temp): |
|
|
"""InfoNCE contrastive loss""" |
|
|
a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1) |
|
|
logits = a @ b.t() / temp |
|
|
tgt = torch.arange(a.size(0), device=a.device) |
|
|
return 0.5 * (F.cross_entropy(logits, tgt) + F.cross_entropy(logits.t(), tgt)) |
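
    # info_nce is the symmetric contrastive objective over the B x B similarity matrix:
    #   L = 0.5 * (CE(sim / T, diag) + CE(sim.T / T, diag)), pairing audio i with caption i.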
|
|
|
|
|
def compute_diffusion_loss(self, images, audio_emb): |
|
|
""" |
|
|
Diffusion loss: Trains SD UNet to denoise images conditioned on audio. |
|
|
This enables end-to-end learning of the generative model! |
|
|
|
|
|
Args: |
|
|
images: Ground truth images [B, 3, 512, 512] in range [-1, 1] |
|
|
audio_emb: Audio embeddings from CLAP |
|
|
|
|
|
Returns: |
|
|
Denoising loss (MSE between predicted and actual noise) |
|
|
""" |
|
|
|
|
|
        # Encode images into VAE latents (0.18215 is the SD v1.x latent scaling factor)
        with torch.no_grad():
            latents = self.sd_vae.encode(images).latent_dist.sample()
            latents = latents * 0.18215

        # Lazily create the training noise scheduler on first use
        if not hasattr(self, 'noise_scheduler'):
            self.noise_scheduler = DDPMScheduler.from_pretrained(
                self.cfg.SD_ID,
                subfolder="scheduler"
            )

        # Sample Gaussian noise and a random timestep for each sample in the batch
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bsz,),
            device=latents.device
        ).long()

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
|
|
|
|
|
|
|
|
_, audio_to_sd = self.mapper(audio_emb) |
|
|
|
|
|
|
|
|
encoder_hidden_states = audio_to_sd.unsqueeze(1) |
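        # The UNet cross-attends over this sequence: a single "audio token" of shape
        # [B, 1, sd_hidden] instead of the usual [B, 77, 768] text-token sequence.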
|
|
|
|
|
|
|
|
noise_pred = self.sd_unet( |
|
|
noisy_latents, |
|
|
timesteps, |
|
|
encoder_hidden_states |
|
|
).sample |
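
        # Standard epsilon-prediction target: this assumes the scheduler's prediction_type
        # is "epsilon" (the SD v1.5 default); a "v_prediction" scheduler would instead need
        # scheduler.get_velocity(latents, noise, timesteps) as the target.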
|
|
|
|
|
|
|
|
|
|
|
loss = F.mse_loss(noise_pred, noise, reduction='mean') |
|
|
|
|
|
return loss |
|
|
|
|
|
@torch.inference_mode() |
|
|
def evaluate_generation(self, wavs, sr, captions, num_samples=None): |
|
|
""" |
|
|
Evaluate quality of generated images using CLIP text-image similarity. |
|
|
|
|
|
Args: |
|
|
wavs: List of audio waveforms |
|
|
sr: Sample rate |
|
|
captions: List of text captions describing the audio |
|
|
num_samples: Number of samples to evaluate (None = all) |
|
|
|
|
|
Returns: |
|
|
avg_clip_score: Average CLIP similarity score (0-100) |
|
|
generated_images: List of PIL images |
|
|
clip_scores: List of individual CLIP scores |
|
|
""" |
|
|
was_training = self.training |
|
|
self.eval() |
|
|
|
|
|
if num_samples is not None: |
|
|
wavs = wavs[:num_samples] |
|
|
captions = captions[:num_samples] |
|
|
|
|
|
generated_images = [] |
|
|
clip_scores = [] |
|
|
|
|
|
for wav, caption in zip(wavs, captions): |
|
|
|
|
|
img = self.generate(wav, sr) |
|
|
generated_images.append(img) |
|
|
|
|
|
|
|
|
inputs = self.clip_processor( |
|
|
text=[caption], |
|
|
images=[img], |
|
|
return_tensors="pt", |
|
|
padding=True |
|
|
).to(self.cfg.device) |
|
|
|
|
|
outputs = self.clip_model(**inputs) |
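
            # logits_per_image is cosine similarity multiplied by CLIP's learned logit
            # scale (~100), which is why scores land on a roughly 0-100 scale.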
|
|
|
|
|
|
|
|
|
|
|
logits_per_image = outputs.logits_per_image |
|
|
clip_score = logits_per_image[0, 0].item() |
|
|
clip_scores.append(clip_score) |
|
|
|
|
|
avg_clip_score = sum(clip_scores) / len(clip_scores) if clip_scores else 0.0 |
|
|
|
|
|
if was_training: |
|
|
self.train() |
|
|
|
|
|
return avg_clip_score, generated_images, clip_scores |
|
|
|
|
|
|
|
|
def forward(self, wavs, sr, caps, images=None, has_images=None): |
|
|
""" |
|
|
Forward pass with three parallel losses: |
|
|
1. CLAP alignment (semantic understanding) |
|
|
2. SD embedding alignment (embedding compatibility) |
|
|
3. Diffusion loss (pixel-level generation) - requires images |
|
|
|
|
|
All losses train simultaneously in end-to-end fashion! |
|
|
""" |
|
|
|
|
|
clap_text_emb = self.encode_text_clap(caps) |
|
|
sd_text_emb = self.encode_text_sd(caps) |
|
|
|
|
|
|
|
|
audio_emb = self.encode_audio(wavs, sr) |
|
|
|
|
|
|
|
|
audio_to_clap, audio_to_sd = self.mapper(audio_emb) |
|
|
|
|
|
|
|
|
loss_clap = self.info_nce(audio_to_clap, clap_text_emb, self.cfg.temperature) |
|
|
|
|
|
|
|
|
loss_sd = F.mse_loss(audio_to_sd, sd_text_emb) |
|
|
|
|
|
|
|
|
loss_diffusion = torch.tensor(0.0, device=self.cfg.device) |
|
|
if self.cfg.finetune_sd and images is not None: |
|
|
|
|
|
if has_images is not None: |
|
|
valid_mask = has_images.to(self.cfg.device) |
|
|
if valid_mask.sum() > 0: |
|
|
valid_imgs = images[valid_mask] |
|
|
valid_audio_emb = audio_emb[valid_mask] |
|
|
loss_diffusion = self.compute_diffusion_loss(valid_imgs, valid_audio_emb) |
|
|
else: |
|
|
loss_diffusion = self.compute_diffusion_loss(images, audio_emb) |
|
|
|
|
|
|
|
|
total_loss = ( |
|
|
self.cfg.clap_loss_weight * loss_clap + |
|
|
self.cfg.sd_loss_weight * loss_sd + |
|
|
self.cfg.diffusion_loss_weight * loss_diffusion |
|
|
) |
|
|
|
|
|
|
|
|
with torch.no_grad(): |
|
|
clap_sim = torch.diagonal( |
|
|
F.normalize(audio_to_clap, dim=-1) @ F.normalize(clap_text_emb, dim=-1).t() |
|
|
).mean() |
|
|
|
|
|
sd_sim = F.cosine_similarity(audio_to_sd, sd_text_emb, dim=-1).mean() |
|
|
|
|
|
return total_loss, { |
|
|
"loss_clap": loss_clap.item(), |
|
|
"loss_sd": loss_sd.item(), |
|
|
"loss_diffusion": loss_diffusion.item(), |
|
|
"clap_sim": clap_sim.item(), |
|
|
"sd_sim": sd_sim.item() |
|
|
} |
|
|
|
|
|
|
|
|
@torch.inference_mode() |
|
|
def generate(self, wav, sr): |
|
|
if self.sd_pipe is None: |
|
|
raise RuntimeError("Stable Diffusion not loaded. Init with load_sd=True.") |
|
|
|
|
|
|
|
|
audio_emb = self.encode_audio([wav], sr) |
|
|
_, soft_token = self.mapper(audio_emb) |
|
|
|
|
|
|
|
|
tok = self.sd_tok( |
|
|
self.cfg.base_prompt, |
|
|
padding="max_length", |
|
|
max_length=self.sd_tok.model_max_length, |
|
|
truncation=True, |
|
|
return_tensors="pt" |
|
|
).to(self.cfg.device) |
|
|
|
|
|
|
|
|
enc = self.sd_text_encoder(tok["input_ids"])[0] |
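
        # Inject the audio-derived embedding as a "soft token": overwrite the embedding slot
        # right after the last real prompt token (or the final slot if the prompt is full),
        # so the pipeline is conditioned on the base prompt plus the audio content.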
|
|
|
|
|
|
|
|
attention_mask = tok["attention_mask"][0] |
|
|
last_token_pos = attention_mask.nonzero(as_tuple=False).max().item() |
|
|
|
|
|
|
|
|
if last_token_pos + 1 < enc.shape[1]: |
|
|
enc[0, last_token_pos + 1:last_token_pos + 2, :] = soft_token |
|
|
else: |
|
|
|
|
|
enc[0, last_token_pos:last_token_pos + 1, :] = soft_token |
|
|
|
|
|
|
|
|
img = self.sd_pipe( |
|
|
num_inference_steps=self.cfg.steps, |
|
|
guidance_scale=self.cfg.guidance, |
|
|
prompt_embeds=enc |
|
|
).images[0] |
|
|
|
|
|
return img |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train(cfg: Config): |
|
|
|
|
|
full_ds = AudioCaptionDataset(cfg.train_csv, cfg.image_folder, use_zip_files=cfg.use_zip_files) |
|
|
|
|
|
|
|
|
train_size = int(0.9 * len(full_ds)) |
|
|
val_size = len(full_ds) - train_size |
|
|
train_ds, val_ds = torch.utils.data.random_split( |
|
|
full_ds, |
|
|
[train_size, val_size], |
|
|
generator=torch.Generator().manual_seed(42) |
|
|
) |
|
|
|
|
|
print(f"\nDataset split:") |
|
|
print(f" Training: {len(train_ds)} samples") |
|
|
print(f" Validation: {len(val_ds)} samples\n") |
|
|
|
|
|
|
|
|
train_loader = DataLoader( |
|
|
train_ds, |
|
|
batch_size=cfg.batch_size, |
|
|
shuffle=True, |
|
|
collate_fn=collate_audio, |
|
|
num_workers=0, |
|
|
drop_last=True |
|
|
) |
|
|
|
|
|
val_loader = DataLoader( |
|
|
val_ds, |
|
|
batch_size=cfg.batch_size, |
|
|
shuffle=False, |
|
|
collate_fn=collate_audio, |
|
|
num_workers=0 |
|
|
) |
|
|
|
|
|
|
|
|
model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device) |
|
|
|
|
|
|
|
|
if cfg.finetune_sd: |
|
|
print("\n🔥 Setting up END-TO-END training:") |
|
|
|
|
|
|
|
|
opt_mapper = torch.optim.AdamW( |
|
|
model.mapper.parameters(), |
|
|
lr=cfg.lr, |
|
|
weight_decay=cfg.weight_decay |
|
|
) |
|
|
print(f" Mapper optimizer: LR={cfg.lr}") |
|
|
|
|
|
|
|
|
opt_sd = torch.optim.AdamW( |
|
|
model.sd_unet.parameters(), |
|
|
lr=cfg.sd_lr, |
|
|
weight_decay=cfg.weight_decay |
|
|
) |
|
|
print(f" SD UNet optimizer: LR={cfg.sd_lr}") |
|
|
|
|
|
opts = [opt_mapper, opt_sd] |
|
|
else: |
|
|
|
|
|
opt_mapper = torch.optim.AdamW( |
|
|
model.parameters(), |
|
|
lr=cfg.lr, |
|
|
weight_decay=cfg.weight_decay |
|
|
) |
|
|
opts = [opt_mapper] |
|
|
|
|
|
print(f"\n{'='*60}") |
|
|
print(f"Starting {'End-to-End' if cfg.finetune_sd else 'Mapper-Only'} Training") |
|
|
print(f"{'='*60}") |
|
|
print(f"Dataset: {len(full_ds)} samples ({len(train_ds)} train, {len(val_ds)} val)") |
|
|
print(f"Batch size: {cfg.batch_size}") |
|
|
print(f"Epochs: {cfg.max_epochs}") |
|
|
print(f"Evaluation: Every {cfg.eval_every_n_epochs} epoch(s)") |
|
|
print(f"Loss weights:") |
|
|
print(f" CLAP: {cfg.clap_loss_weight}") |
|
|
print(f" SD Embedding: {cfg.sd_loss_weight}") |
|
|
if cfg.finetune_sd: |
|
|
print(f" Diffusion: {cfg.diffusion_loss_weight}") |
|
|
print(f"{'='*60}\n") |
|
|
|
|
|
|
|
|
best_clip_score = -float('inf') |
|
|
|
|
|
for ep in range(1, cfg.max_epochs + 1): |
|
|
|
|
|
|
|
|
|
|
|
model.train() |
|
|
pbar = tqdm(train_loader, desc=f"Epoch {ep}/{cfg.max_epochs} [TRAIN]") |
|
|
|
|
|
epoch_stats = { |
|
|
"total": 0, "clap": 0, "sd": 0, "diff": 0, |
|
|
"clap_sim": 0, "sd_sim": 0 |
|
|
} |
|
|
|
|
|
for wavs, sr, caps, imgs, has_imgs in pbar: |
|
|
wavs = [w.to(cfg.device) for w in wavs] |
|
|
imgs = imgs.to(cfg.device) |
|
|
|
|
|
|
|
|
loss, stats = model(wavs, sr, caps, imgs if cfg.finetune_sd else None, has_imgs) |
|
|
|
|
|
|
|
|
for opt in opts: |
|
|
opt.zero_grad() |
|
|
|
|
|
|
|
|
loss.backward() |
|
|
|
|
|
|
|
|
if cfg.finetune_sd: |
|
|
nn.utils.clip_grad_norm_(model.mapper.parameters(), 1.0) |
|
|
nn.utils.clip_grad_norm_(model.sd_unet.parameters(), 1.0) |
|
|
else: |
|
|
nn.utils.clip_grad_norm_(model.parameters(), 1.0) |
|
|
|
|
|
|
|
|
for opt in opts: |
|
|
opt.step() |
|
|
|
|
|
|
|
|
epoch_stats["total"] += loss.item() |
|
|
epoch_stats["clap"] += stats['loss_clap'] |
|
|
epoch_stats["sd"] += stats['loss_sd'] |
|
|
epoch_stats["diff"] += stats['loss_diffusion'] |
|
|
epoch_stats["clap_sim"] += stats['clap_sim'] |
|
|
epoch_stats["sd_sim"] += stats['sd_sim'] |
|
|
|
|
|
pbar.set_postfix({ |
|
|
"total loss": f"{loss.item():.3f}", |
|
|
"diff": f"{stats['loss_diffusion']:.3f}", |
|
|
"c_sim": f"{stats['clap_sim']:.2f}", |
|
|
"s_sim": f"{stats['sd_sim']:.2f}" |
|
|
}) |
|
|
|
|
|
|
|
|
n_train = len(train_loader) |
|
|
for k in epoch_stats: |
|
|
epoch_stats[k] /= n_train |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if ep % cfg.eval_every_n_epochs == 0: |
|
|
print(f"\n{'='*60}") |
|
|
print(f"🔍 Evaluating Epoch {ep}...") |
|
|
print(f"{'='*60}") |
|
|
|
|
|
model.eval() |
|
|
val_clip_scores = [] |
|
|
all_gen_images = [] |
|
|
all_captions = [] |
|
|
|
|
|
|
|
|
eval_batches = min(3, len(val_loader)) |
|
|
|
|
|
for batch_idx, (wavs, sr, caps, imgs, has_imgs) in enumerate(val_loader): |
|
|
if batch_idx >= eval_batches: |
|
|
break |
|
|
|
|
|
wavs = [w.to(cfg.device) for w in wavs] |
|
|
|
|
|
|
|
|
avg_score, gen_imgs, scores = model.evaluate_generation( |
|
|
wavs, sr, caps, |
|
|
num_samples=cfg.num_eval_samples |
|
|
) |
|
|
|
|
|
val_clip_scores.extend(scores) |
|
|
all_gen_images.extend(gen_imgs) |
|
|
all_captions.extend(caps[:cfg.num_eval_samples]) |
|
|
|
|
|
print(f" Batch {batch_idx + 1}/{eval_batches}: Avg CLIP = {avg_score:.3f}") |
|
|
|
|
|
|
|
|
avg_val_clip = sum(val_clip_scores) / len(val_clip_scores) if val_clip_scores else 0.0 |
|
|
|
|
|
|
|
|
if cfg.save_eval_images and all_gen_images: |
|
|
os.makedirs("eval_samples", exist_ok=True) |
|
|
for i, (img, cap, score) in enumerate(zip(all_gen_images[:4], all_captions[:4], val_clip_scores[:4])): |
|
|
save_path = f"eval_samples/ep{ep}_sample{i}_score{score:.2f}.png" |
|
|
img.save(save_path) |
|
|
print(f" Sample {i}: '{cap[:50]}...' | CLIP: {score:.3f}") |
|
|
print(f" Saved to: {save_path}") |
|
|
|
|
|
|
|
|
if cfg.device == "mps": |
|
|
torch.mps.empty_cache() |
|
|
|
|
|
print(f"\n{'='*60}") |
|
|
print(f"📊 Epoch {ep} Summary:") |
|
|
print(f"{'='*60}") |
|
|
print(f"Training Metrics:") |
|
|
print(f" Total Loss: {epoch_stats['total']:.4f}") |
|
|
print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}") |
|
|
print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}") |
|
|
if cfg.finetune_sd: |
|
|
print(f" Diffusion Loss: {epoch_stats['diff']:.4f}") |
|
|
print(f"\nValidation Metrics:") |
|
|
print(f" 🎯 CLIP Score: {avg_val_clip:.3f} (higher = better image-text match)") |
|
|
print(f"{'='*60}\n") |
|
|
|
|
|
else: |
|
|
|
|
|
avg_val_clip = None |
|
|
print(f"\n{'='*60}") |
|
|
print(f"Epoch {ep} Summary:") |
|
|
print(f" Total Loss: {epoch_stats['total']:.4f}") |
|
|
print(f" CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}") |
|
|
print(f" SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}") |
|
|
if cfg.finetune_sd: |
|
|
print(f" Diffusion Loss: {epoch_stats['diff']:.4f}") |
|
|
print(f"{'='*60}\n") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
checkpoint = { |
|
|
"mapper": model.mapper.state_dict(), |
|
|
"epoch": ep, |
|
|
"val_clip_score": avg_val_clip if avg_val_clip is not None else -1, |
|
|
**{k: v for k, v in epoch_stats.items()}, |
|
|
"config": { |
|
|
"clap_loss_weight": cfg.clap_loss_weight, |
|
|
"sd_loss_weight": cfg.sd_loss_weight, |
|
|
"diffusion_loss_weight": cfg.diffusion_loss_weight, |
|
|
"finetune_sd": cfg.finetune_sd |
|
|
} |
|
|
} |
|
|
|
|
|
if cfg.finetune_sd: |
|
|
checkpoint["unet"] = model.sd_unet.state_dict() |
|
|
|
|
|
|
|
|
torch.save(checkpoint, cfg.ckpt_path) |
|
|
print(f"💾 Checkpoint saved: {cfg.ckpt_path}") |
|
|
|
|
|
|
|
|
if avg_val_clip is not None and avg_val_clip > best_clip_score: |
|
|
best_clip_score = avg_val_clip |
|
|
best_path = cfg.ckpt_path.replace('.pt', '_best.pt') |
|
|
torch.save(checkpoint, best_path) |
|
|
print(f"✅ New best model! CLIP: {avg_val_clip:.3f} -> Saved to {best_path}") |
|
|
elif avg_val_clip is not None: |
|
|
print(f" Current best CLIP: {best_clip_score:.3f}") |
|
|
|
|
|
print() |
|
|
|
|
|
print("🎉 Training completed!") |
|
|
if best_clip_score > -float('inf'): |
|
|
print(f" Best CLIP score achieved: {best_clip_score:.3f}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def infer(cfg: Config, wav_path: str, out_path: str): |
|
|
|
|
|
print(f"Loading audio from {wav_path}...") |
|
|
wav, sr = torchaudio.load(wav_path) |
|
|
if wav.size(0) > 1: |
|
|
wav = wav.mean(0, keepdim=True) |
|
|
wav = wav.squeeze(0).float() |
|
|
|
|
|
|
|
|
if sr != 48000: |
|
|
print(f"Resampling from {sr}Hz to 48000Hz...") |
|
|
resampler = torchaudio.transforms.Resample(sr, 48000) |
|
|
wav = resampler(wav) |
|
|
sr = 48000 |
|
|
|
|
|
wav = wav.to(cfg.device) |
|
|
|
|
|
|
|
|
model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device) |
|
|
|
|
|
|
|
|
print(f"Loading checkpoint from {cfg.ckpt_path}...") |
|
|
ckpt = torch.load(cfg.ckpt_path, map_location=cfg.device) |
|
|
model.mapper.load_state_dict(ckpt["mapper"]) |
|
|
|
|
|
|
|
|
if "unet" in ckpt: |
|
|
print("Loading fine-tuned UNet weights...") |
|
|
model.sd_unet.load_state_dict(ckpt["unet"]) |
|
|
|
|
|
print(f"Checkpoint info:") |
|
|
print(f" Epoch: {ckpt.get('epoch', 'unknown')}") |
|
|
print(f" CLAP Sim: {ckpt.get('clap_sim', 'N/A'):.3f}" if isinstance(ckpt.get('clap_sim'), (int, float)) else f" CLAP Sim: N/A") |
|
|
print(f" SD Sim: {ckpt.get('sd_sim', 'N/A'):.3f}" if isinstance(ckpt.get('sd_sim'), (int, float)) else f" SD Sim: N/A") |
|
|
if "unet" in ckpt: |
|
|
print(" Fine-tuned UNet: ✓") |
|
|
|
|
|
|
|
|
print("\nGenerating image...") |
|
|
img = model.generate(wav, sr) |
|
|
img.save(out_path) |
|
|
print(f"✓ Generated image saved to {out_path}") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
import argparse |
|
|
parser = argparse.ArgumentParser() |
|
|
parser.add_argument("--mode", choices=["train", "infer"], default="train") |
|
|
parser.add_argument("--wav", help="Audio file path for inference mode") |
|
|
parser.add_argument("--out", default="output.png", help="Output image path") |
|
|
args = parser.parse_args() |
|
|
|
|
|
cfg = Config() |
|
|
print(f"Device: {cfg.device}") |
|
|
|
|
|
if args.mode == "train": |
|
|
print(f"Dataset: {cfg.train_csv}") |
|
|
if not os.path.exists(cfg.train_csv): |
|
|
print(f"ERROR: Dataset not found at {cfg.train_csv}") |
|
|
print("Please ensure the captions.txt file exists") |
|
|
sys.exit(1) |
|
|
train(cfg) |
|
|
else: |
|
|
if not args.wav: |
|
|
raise ValueError("Need --wav for inference mode") |
|
|
if not os.path.exists(args.wav): |
|
|
raise ValueError(f"Audio file not found: {args.wav}") |
|
|
infer(cfg, args.wav, args.out) |