Upload 2 files
README.md
ADDED
@@ -0,0 +1,10 @@
---
license: mit
tags:
- audio-to-image
- stable-diffusion
---

# Audio2Image Model

Generates images from audio by mapping CLAP audio embeddings into Stable Diffusion's text-conditioning space.
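## Usage

`main2.py` exposes a train/infer CLI (`python main2.py --mode infer --wav sound.wav --out output.png`). Below is a minimal programmatic sketch, assuming `main2.py` is importable from the working directory and a trained checkpoint already exists at `Config.ckpt_path`; the audio filename is hypothetical.

```python
from main2 import Config, infer

cfg = Config()                       # paths and hyperparameters live here
infer(cfg,
      wav_path="example.wav",        # hypothetical input audio file
      out_path="output.png")         # generated image is written here
```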
main2.py
ADDED
@@ -0,0 +1,1032 @@
"""
Audio → Image Generator (Multi-Task Loss Version)
Key features:
- Dual-head MLP: one head for CLAP text space, one for SD embedding space
- Multi-task training: CLAP alignment loss + SD alignment loss
- Both heads are trained simultaneously
- The to_sd head is properly trained and used during inference
"""

# ========================
# Imports
# ========================
import os, math, csv, random, sys
from typing import List, Tuple
from dataclasses import dataclass
import zipfile
from io import BytesIO

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

from transformers import AutoProcessor, ClapModel, AutoTokenizer, CLIPProcessor, CLIPModel
from diffusers import StableDiffusionPipeline, DDPMScheduler, DDIMScheduler
from PIL import Image
from torchvision import transforms


# ========================
# Configuration
# ========================
@dataclass
class Config:
    CLAP_ID: str = "laion/clap-htsat-fused"
    SD_ID: str = "runwayml/stable-diffusion-v1-5"
    CLIP_ID: str = "openai/clip-vit-base-patch32"

    # Device configuration - automatically uses a GPU (MPS or CUDA) if available
    device: str = "mps" if torch.backends.mps.is_available() else ("cuda" if torch.cuda.is_available() else "cpu")

    lr: float = 2e-4
    weight_decay: float = 1e-4
    temperature: float = 0.07

    # Multi-task loss weights
    clap_loss_weight: float = 0.5
    sd_loss_weight: float = 1.0
    diffusion_loss_weight: float = 1.0

    batch_size: int = 2  # Reduced for Mac GPU memory
    max_epochs: int = 20
    base_prompt: str = "A photo of"
    guidance: float = 7.5
    steps: int = 30

    # Dataset paths
    train_csv: str = "/Users/rajvarun/Desktop/SIT/Trimester 4/AAI 3001 - Computer Vision & Deep Learning/Seeing Sound II/raj/main_dataV1.csv"
    image_folder: str = "/Users/rajvarun/OneDrive - Singapore Institute Of Technology/ALEXI KIZHAKKEPURATHU GEORGE's files - VGGSound"  # OneDrive folder with ZIP files
    ckpt_path: str = "audio2image_mapper_dual_best.pt"

    # ZIP file support (if data is in ZIP files instead of extracted)
    use_zip_files: bool = True  # Set to True to read from ZIP files directly
    zip_files: dict = None  # Will be populated automatically

    # Fine-tuning control
    finetune_sd: bool = False  # Set to False to train without images
    sd_lr: float = 1e-5
    freeze_vae: bool = True
    freeze_text_encoder: bool = True

    # Evaluation settings
    eval_every_n_epochs: int = 1  # Evaluate every N epochs
    num_eval_samples: int = 4  # Number of samples to evaluate per batch
    save_eval_images: bool = True  # Save example generated images

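# --------------------------------------------------------------------------
# Note (added): the dataset paths above are machine-specific absolute paths.
# Since Config is a dataclass, they can be overridden at construction time
# instead of editing the defaults, e.g. (hypothetical local paths):
#
#   cfg = Config(
#       train_csv="data/main_dataV1.csv",
#       image_folder="data/vggsound_zips",
#       use_zip_files=True,
#   )
# --------------------------------------------------------------------------
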
# ========================
# Dataset
# ========================
class AudioCaptionDataset(Dataset):
    """
    Reads a CSV file with audio-image-caption triplets.
    Handles a structure where data lives in: base_folder/image/ and base_folder/audio/

    Can read from extracted folders OR directly from ZIP files (no extraction needed).

    Example:
    - CSV row:    vggsound_00,g-f_I2yQ_1.png,g-f_I2yQ_000001.wav,people marching
    - Audio path: vggsound_00/audio/g-f_I2yQ_000001.wav
    - Image path: vggsound_00/image/g-f_I2yQ_1.png
    """
    def __init__(self, captions_path: str, image_folder: str = None, use_zip_files: bool = False):
        self.items = []
        base_dir = os.path.dirname(captions_path)
        self.image_folder = image_folder or base_dir
        self.use_zip_files = use_zip_files
        self.zip_handles = {}  # Cache opened ZIP files

        # Image preprocessing for SD (512x512, normalized to [-1, 1])
        self.img_transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])

        print(f"Loading dataset from: {captions_path}")
        print(f"Base folder: {self.image_folder}")
        print(f"Use ZIP files: {use_zip_files}")

        # If using ZIP files, find and open them
        if use_zip_files:
            self._find_zip_files()

        # Read the CSV file (csv is imported at module level)
        with open(captions_path, "r", encoding="utf-8") as f:
            reader = csv.DictReader(f)

            for row_num, row in enumerate(reader, 1):
                # CSV format: base_folder,image_file,audio_file,caption
                if 'base_folder' in row and 'image_file' in row and 'audio_file' in row and 'caption' in row:
                    base_folder = row['base_folder']        # e.g., "vggsound_00"
                    img_filename = row['image_file']        # e.g., "g-f_I2yQ_1.png"
                    audio_filename = row['audio_file']      # e.g., "g-f_I2yQ_000001.wav"
                    caption = row['caption']

                    if use_zip_files:
                        # Use ZIP-internal paths
                        audio_path = f"{base_folder}/audio/{audio_filename}"
                        img_path = f"{base_folder}/image/{img_filename}"

                        # Check if the files exist in the ZIP
                        audio_exists = self._file_in_zip(base_folder, audio_path)
                        img_exists = self._file_in_zip(base_folder, img_path)

                        # Debug output for the first few rows
                        if row_num <= 3:
                            print(f"Row {row_num}: base_folder='{base_folder}', audio='{audio_path}', exists={audio_exists}")
                    else:
                        # Use regular file-system paths
                        audio_path = os.path.join(self.image_folder, base_folder, "audio", audio_filename)
                        img_path = os.path.join(self.image_folder, base_folder, "image", img_filename)

                        audio_exists = os.path.exists(audio_path)
                        img_exists = os.path.exists(img_path)

                    if audio_exists:
                        if img_exists:
                            self.items.append((base_folder, audio_path, img_path, caption))
                        else:
                            # Audio exists but the image doesn't
                            self.items.append((base_folder, audio_path, None, caption))
                            if row_num <= 3:
                                print(f"Warning: Image not found: {img_path}")
                    else:
                        if row_num <= 3:
                            print(f"Warning: Audio not found: {audio_path}")
                else:
                    if row_num <= 3:
                        print(f"Warning: Row {row_num} missing required columns")

        if not self.items:
            raise ValueError("Empty dataset: no valid audio files found")

        # Count how many items have images
        with_images = sum(1 for _, _, img_path, _ in self.items if img_path is not None)
        print(f"✓ Loaded {len(self.items)} audio files ({with_images} with matching images)")

    def _find_zip_files(self):
        """Find and open ZIP files in the image_folder"""
        print("Searching for ZIP files...")
        for item in os.listdir(self.image_folder):
            if item.endswith('.zip'):
                zip_name = item.replace('.zip', '')
                zip_path = os.path.join(self.image_folder, item)
                try:
                    self.zip_handles[zip_name] = zipfile.ZipFile(zip_path, 'r')
                    # Get the number of files in the ZIP for debugging
                    file_count = len(self.zip_handles[zip_name].namelist())
                    print(f"  ✓ Opened {item} (key: '{zip_name}', {file_count} files)")
                except Exception as e:
                    print(f"  ✗ Failed to open {item}: {e}")

    def _file_in_zip(self, base_folder, file_path):
        """Check if a file exists in the corresponding ZIP"""
        if base_folder not in self.zip_handles:
            print(f"  ! ZIP handle not found for base_folder='{base_folder}'. Available: {list(self.zip_handles.keys())}")
            return False
        try:
            self.zip_handles[base_folder].getinfo(file_path)
            return True
        except KeyError:
            return False

    def _read_from_zip(self, base_folder, file_path):
        """Read a file from a ZIP archive"""
        if base_folder in self.zip_handles:
            return self.zip_handles[base_folder].read(file_path)
        return None

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        base_folder, audio_path, img_path, cap = self.items[idx]

        # Load audio
        if self.use_zip_files:
            # Read audio from the ZIP
            audio_bytes = self._read_from_zip(base_folder, audio_path)
            if audio_bytes is None:
                raise FileNotFoundError(f"Audio not found in ZIP: {audio_path}")
            wav, sr = torchaudio.load(BytesIO(audio_bytes))
        else:
            # Read from the file system
            wav, sr = torchaudio.load(audio_path)

        # Convert to mono
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        wav = wav.squeeze(0).float()
        # Resample to 48kHz for CLAP
        if sr != 48000:
            resampler = torchaudio.transforms.Resample(sr, 48000)
            wav = resampler(wav)

        # Load the image if available
        if img_path is not None:
            if self.use_zip_files:
                # Read the image from the ZIP
                img_bytes = self._read_from_zip(base_folder, img_path)
                if img_bytes:
                    img = Image.open(BytesIO(img_bytes)).convert('RGB')
                    img_tensor = self.img_transform(img)
                else:
                    img_tensor = torch.zeros((3, 512, 512))
            else:
                # Read from the file system
                img = Image.open(img_path).convert('RGB')
                img_tensor = self.img_transform(img)
        else:
            # Create a dummy image if none is available
            img_tensor = torch.zeros((3, 512, 512))

        return wav, 48000, cap, img_tensor, (img_path is not None)

    def __del__(self):
        """Close ZIP files when done"""
        for zip_handle in self.zip_handles.values():
            try:
                zip_handle.close()
            except Exception:
                pass


def collate_audio(batch):
    wavs, srs, caps, imgs, has_imgs = [], [], [], [], []
    for w, sr, c, img, has_img in batch:
        wavs.append(w)
        srs.append(sr)
        caps.append(c)
        imgs.append(img)
        has_imgs.append(has_img)
    return wavs, srs[0], caps, torch.stack(imgs), torch.tensor(has_imgs)

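# Usage sketch (added; file paths are hypothetical): the dataset can be exercised
# on its own, outside train(), to sanity-check the CSV and ZIP layout:
#
#   ds = AudioCaptionDataset("main_dataV1.csv", "zips/", use_zip_files=True)
#   loader = DataLoader(ds, batch_size=2, collate_fn=collate_audio)
#   wavs, sr, caps, imgs, has_imgs = next(iter(loader))
#   # wavs: list of 1-D 48 kHz tensors, imgs: [2, 3, 512, 512], sr == 48000
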
# ========================
# Model Components
# ========================
class AudioProjectionMLP(nn.Module):
    """
    Dual-head MLP projection:
    - to_text: CLAP audio → CLAP text space (for CLAP alignment)
    - to_sd:   CLAP audio → SD embedding space (for image generation)
    Both heads are trained with the multi-task loss.
    """
    def __init__(self, in_dim, text_dim, sd_dim, hidden=1024):
        super().__init__()

        # Shared backbone
        self.shared = nn.Sequential(
            nn.Linear(in_dim, hidden),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.1)
        )

        # Head 1: CLAP text space (for training alignment)
        self.to_text = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, text_dim)
        )

        # Head 2: SD embedding space (for generation)
        self.to_sd = nn.Sequential(
            nn.Linear(hidden, hidden),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(hidden, sd_dim)
        )

    def forward(self, z):
        shared_features = self.shared(z)
        return self.to_text(shared_features), self.to_sd(shared_features)

# ========================
|
| 313 |
+
# Main Model
|
| 314 |
+
# ========================
|
| 315 |
+
class Audio2ImageModel(nn.Module):
|
| 316 |
+
def __init__(self, cfg: Config, load_sd: bool = False):
|
| 317 |
+
super().__init__()
|
| 318 |
+
self.cfg = cfg
|
| 319 |
+
device = cfg.device
|
| 320 |
+
|
| 321 |
+
# -------- Frozen CLAP --------
|
| 322 |
+
print("Loading CLAP model...")
|
| 323 |
+
self.clap = ClapModel.from_pretrained(cfg.CLAP_ID).eval().to(device)
|
| 324 |
+
for p in self.clap.parameters():
|
| 325 |
+
p.requires_grad = False
|
| 326 |
+
self.proc = AutoProcessor.from_pretrained(cfg.CLAP_ID)
|
| 327 |
+
|
| 328 |
+
# -------- CLIP for Evaluation (Frozen) --------
|
| 329 |
+
print("Loading CLIP for evaluation...")
|
| 330 |
+
self.clip_model = CLIPModel.from_pretrained(cfg.CLIP_ID).eval().to(device)
|
| 331 |
+
self.clip_processor = CLIPProcessor.from_pretrained(cfg.CLIP_ID)
|
| 332 |
+
for p in self.clip_model.parameters():
|
| 333 |
+
p.requires_grad = False
|
| 334 |
+
print(" โ CLIP loaded (frozen for evaluation only)")
|
| 335 |
+
|
| 336 |
+
# -------- Stable Diffusion (conditionally trainable) --------
|
| 337 |
+
self.sd_pipe = None
|
| 338 |
+
self.sd_tok = None
|
| 339 |
+
self.sd_text_encoder = None
|
| 340 |
+
self.sd_unet = None
|
| 341 |
+
self.sd_vae = None
|
| 342 |
+
self.sd_hidden = 768
|
| 343 |
+
|
| 344 |
+
# Always load full SD for training or inference
|
| 345 |
+
if True:
|
| 346 |
+
print("Loading Stable Diffusion...")
|
| 347 |
+
# Use float32 for training, float16 for inference only
|
| 348 |
+
dtype = torch.float32 if cfg.finetune_sd else (torch.float16 if device == "cuda" else torch.float32)
|
| 349 |
+
self.sd_pipe = StableDiffusionPipeline.from_pretrained(cfg.SD_ID, torch_dtype=dtype)
|
| 350 |
+
self.sd_pipe.to(device)
|
| 351 |
+
|
| 352 |
+
self.sd_tok = self.sd_pipe.tokenizer
|
| 353 |
+
self.sd_text_encoder = self.sd_pipe.text_encoder
|
| 354 |
+
self.sd_unet = self.sd_pipe.unet
|
| 355 |
+
self.sd_vae = self.sd_pipe.vae
|
| 356 |
+
self.sd_hidden = self.sd_pipe.text_encoder.config.hidden_size
|
| 357 |
+
|
| 358 |
+
# Configure trainability based on config
|
| 359 |
+
if cfg.finetune_sd:
|
| 360 |
+
print("๐ฅ End-to-End Training Mode:")
|
| 361 |
+
|
| 362 |
+
# UNet: TRAINABLE (this learns to generate!)
|
| 363 |
+
for p in self.sd_unet.parameters():
|
| 364 |
+
p.requires_grad = True
|
| 365 |
+
self.sd_unet.train()
|
| 366 |
+
print(" โ UNet: TRAINABLE")
|
| 367 |
+
|
| 368 |
+
# VAE: Usually frozen for stability
|
| 369 |
+
if cfg.freeze_vae:
|
| 370 |
+
for p in self.sd_vae.parameters():
|
| 371 |
+
p.requires_grad = False
|
| 372 |
+
self.sd_vae.eval()
|
| 373 |
+
print(" โ VAE: FROZEN")
|
| 374 |
+
else:
|
| 375 |
+
for p in self.sd_vae.parameters():
|
| 376 |
+
p.requires_grad = True
|
| 377 |
+
self.sd_vae.train()
|
| 378 |
+
print(" โ VAE: TRAINABLE")
|
| 379 |
+
|
| 380 |
+
# Text Encoder: Usually frozen
|
| 381 |
+
if cfg.freeze_text_encoder:
|
| 382 |
+
for p in self.sd_text_encoder.parameters():
|
| 383 |
+
p.requires_grad = False
|
| 384 |
+
self.sd_text_encoder.eval()
|
| 385 |
+
print(" โ Text Encoder: FROZEN")
|
| 386 |
+
else:
|
| 387 |
+
for p in self.sd_text_encoder.parameters():
|
| 388 |
+
p.requires_grad = True
|
| 389 |
+
self.sd_text_encoder.train()
|
| 390 |
+
print(" โ Text Encoder: TRAINABLE")
|
| 391 |
+
else:
|
| 392 |
+
print("Inference Mode: All SD components frozen")
|
| 393 |
+
for comp in (self.sd_unet, self.sd_vae, self.sd_text_encoder):
|
| 394 |
+
for p in comp.parameters():
|
| 395 |
+
p.requires_grad = False
|
| 396 |
+
comp.eval()
|
| 397 |
+
|
| 398 |
+
# -------- Get CLAP dims --------
|
| 399 |
+
dummy_text = ["test"]
|
| 400 |
+
dummy_audio = [torch.zeros(48000).numpy()]
|
| 401 |
+
|
| 402 |
+
with torch.no_grad():
|
| 403 |
+
text_proc = self.proc(text=dummy_text, return_tensors="pt")
|
| 404 |
+
text_proc = {k: v.to(device) for k,v in text_proc.items()}
|
| 405 |
+
t = self.clap.get_text_features(**text_proc)
|
| 406 |
+
clap_text_dim = t.shape[-1]
|
| 407 |
+
|
| 408 |
+
audio_proc = self.proc(audio=dummy_audio, sampling_rate=48000, return_tensors="pt")
|
| 409 |
+
audio_proc = {k: v.to(device) for k,v in audio_proc.items()}
|
| 410 |
+
a = self.clap.get_audio_features(**audio_proc)
|
| 411 |
+
clap_audio_dim = a.shape[-1]
|
| 412 |
+
|
| 413 |
+
# -------- Trainable Dual-Head MLP --------
|
| 414 |
+
print(f"Creating MLP: CLAP audio ({clap_audio_dim}) โ CLAP text ({clap_text_dim}) & SD ({self.sd_hidden})")
|
| 415 |
+
self.mapper = AudioProjectionMLP(clap_audio_dim, clap_text_dim, self.sd_hidden)
|
| 416 |
+
|
| 417 |
+
# --- Encoders ---
|
| 418 |
+
def encode_text_clap(self, caps):
|
| 419 |
+
"""Encode text using CLAP text encoder"""
|
| 420 |
+
proc = self.proc(text=caps, return_tensors="pt", padding=True)
|
| 421 |
+
proc = {k: v.to(self.cfg.device) for k,v in proc.items()}
|
| 422 |
+
|
| 423 |
+
# Ensure CLAP is in eval mode
|
| 424 |
+
was_training = self.clap.training
|
| 425 |
+
self.clap.eval()
|
| 426 |
+
|
| 427 |
+
with torch.no_grad():
|
| 428 |
+
e = self.clap.get_text_features(**proc)
|
| 429 |
+
|
| 430 |
+
# Restore training state if needed
|
| 431 |
+
if was_training:
|
| 432 |
+
self.clap.train()
|
| 433 |
+
|
| 434 |
+
return F.normalize(e, dim=-1)
|
| 435 |
+
|
| 436 |
+
def encode_text_sd(self, caps):
|
| 437 |
+
"""Encode text using SD text encoder (for target embeddings)"""
|
| 438 |
+
tokens = self.sd_tok(
|
| 439 |
+
caps,
|
| 440 |
+
padding="max_length",
|
| 441 |
+
max_length=self.sd_tok.model_max_length,
|
| 442 |
+
truncation=True,
|
| 443 |
+
return_tensors="pt"
|
| 444 |
+
).to(self.cfg.device)
|
| 445 |
+
|
| 446 |
+
with torch.no_grad():
|
| 447 |
+
# Get the pooled output (last hidden state mean)
|
| 448 |
+
outputs = self.sd_text_encoder(tokens["input_ids"])
|
| 449 |
+
# Use pooler_output if available, else mean pool
|
| 450 |
+
if hasattr(outputs, 'pooler_output') and outputs.pooler_output is not None:
|
| 451 |
+
embeddings = outputs.pooler_output
|
| 452 |
+
else:
|
| 453 |
+
embeddings = outputs.last_hidden_state.mean(dim=1)
|
| 454 |
+
|
| 455 |
+
return embeddings
|
| 456 |
+
|
| 457 |
+
def encode_audio(self, wavs, sr):
|
| 458 |
+
"""Returns raw CLAP audio embeddings - batched processing"""
|
| 459 |
+
# Convert all wavs to numpy for batch processing
|
| 460 |
+
audio_list = [w.cpu().numpy() for w in wavs]
|
| 461 |
+
|
| 462 |
+
# Process all audios in a single batch
|
| 463 |
+
proc = self.proc(audio=audio_list, sampling_rate=sr, return_tensors="pt")
|
| 464 |
+
proc = {k: v.to(self.cfg.device) for k, v in proc.items()}
|
| 465 |
+
|
| 466 |
+
# Ensure CLAP is in eval mode to avoid batch norm issues
|
| 467 |
+
was_training = self.clap.training
|
| 468 |
+
self.clap.eval()
|
| 469 |
+
|
| 470 |
+
with torch.no_grad():
|
| 471 |
+
embeddings = self.clap.get_audio_features(**proc)
|
| 472 |
+
|
| 473 |
+
# Restore training state if needed
|
| 474 |
+
if was_training:
|
| 475 |
+
self.clap.train()
|
| 476 |
+
|
| 477 |
+
return embeddings
|
| 478 |
+
|
    # --- Loss ---
    @staticmethod
    def info_nce(a, b, temp):
        """InfoNCE contrastive loss"""
        a, b = F.normalize(a, dim=-1), F.normalize(b, dim=-1)
        logits = a @ b.t() / temp
        tgt = torch.arange(a.size(0), device=a.device)
        return 0.5 * (F.cross_entropy(logits, tgt) + F.cross_entropy(logits.t(), tgt))

    def compute_diffusion_loss(self, images, audio_emb):
        """
        Diffusion loss: trains the SD UNet to denoise images conditioned on audio.
        This enables end-to-end learning of the generative model.

        Args:
            images: Ground-truth images [B, 3, 512, 512] in range [-1, 1]
            audio_emb: Audio embeddings from CLAP

        Returns:
            Denoising loss (MSE between predicted and actual noise)
        """
        # 1. Encode images to latent space (no grad through the VAE)
        with torch.no_grad():
            latents = self.sd_vae.encode(images).latent_dist.sample()
            latents = latents * 0.18215  # SD's latent scaling factor

        # 2. Sample random noise and timesteps for diffusion training
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        timesteps = torch.randint(
            0, 1000, (bsz,),
            device=latents.device
        ).long()

        # 3. Add noise to the latents according to the timestep
        if not hasattr(self, 'noise_scheduler'):
            self.noise_scheduler = DDPMScheduler.from_pretrained(
                self.cfg.SD_ID,
                subfolder="scheduler"
            )

        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)

        # 4. Get the audio conditioning (gradients flow to the mapper)
        _, audio_to_sd = self.mapper(audio_emb)

        # Reshape for the UNet: [batch, 1, hidden_dim]
        encoder_hidden_states = audio_to_sd.unsqueeze(1)

        # 5. The UNet predicts the noise (this is where SD learns)
        noise_pred = self.sd_unet(
            noisy_latents,           # Noisy input
            timesteps,               # Time conditioning
            encoder_hidden_states    # Audio conditioning
        ).sample

        # 6. Compute the denoising loss
        # Gradients flow back to both the UNet and the mapper
        loss = F.mse_loss(noise_pred, noise, reduction='mean')

        return loss

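    # Reference note (added for clarity, not part of the training logic): with
    # L2-normalized embeddings A, B and temperature tau, info_nce above is the
    # symmetric InfoNCE objective
    #     L = 1/2 * [ CE(A @ B.T / tau, I) + CE(B @ A.T / tau, I) ]
    # where the target for row i is index i, i.e. each audio embedding should be
    # closest to the text embedding of its own caption.
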
    @torch.inference_mode()
    def evaluate_generation(self, wavs, sr, captions, num_samples=None):
        """
        Evaluate the quality of generated images using CLIP text-image similarity.

        Args:
            wavs: List of audio waveforms
            sr: Sample rate
            captions: List of text captions describing the audio
            num_samples: Number of samples to evaluate (None = all)

        Returns:
            avg_clip_score: Average CLIP similarity score (higher is better)
            generated_images: List of PIL images
            clip_scores: List of individual CLIP scores
        """
        was_training = self.training
        self.eval()

        if num_samples is not None:
            wavs = wavs[:num_samples]
            captions = captions[:num_samples]

        generated_images = []
        clip_scores = []

        for wav, caption in zip(wavs, captions):
            # Generate an image from the audio
            img = self.generate(wav, sr)
            generated_images.append(img)

            # Compute the CLIP score (text-image similarity)
            inputs = self.clip_processor(
                text=[caption],
                images=[img],
                return_tensors="pt",
                padding=True
            ).to(self.cfg.device)

            outputs = self.clip_model(**inputs)

            # Similarity score (logits are already scaled by CLIP's learned temperature);
            # a higher score means a better match between image and caption
            logits_per_image = outputs.logits_per_image
            clip_score = logits_per_image[0, 0].item()
            clip_scores.append(clip_score)

        avg_clip_score = sum(clip_scores) / len(clip_scores) if clip_scores else 0.0

        if was_training:
            self.train()

        return avg_clip_score, generated_images, clip_scores

+
# --- Forward (Training with Multi-Task Loss) ---
|
| 596 |
+
def forward(self, wavs, sr, caps, images=None, has_images=None):
|
| 597 |
+
"""
|
| 598 |
+
Forward pass with three parallel losses:
|
| 599 |
+
1. CLAP alignment (semantic understanding)
|
| 600 |
+
2. SD embedding alignment (embedding compatibility)
|
| 601 |
+
3. Diffusion loss (pixel-level generation) - requires images
|
| 602 |
+
|
| 603 |
+
All losses train simultaneously in end-to-end fashion!
|
| 604 |
+
"""
|
| 605 |
+
# Get target embeddings (frozen encoders)
|
| 606 |
+
clap_text_emb = self.encode_text_clap(caps)
|
| 607 |
+
sd_text_emb = self.encode_text_sd(caps)
|
| 608 |
+
|
| 609 |
+
# Get audio embeddings
|
| 610 |
+
audio_emb = self.encode_audio(wavs, sr)
|
| 611 |
+
|
| 612 |
+
# Project audio to both spaces (gradients flow here!)
|
| 613 |
+
audio_to_clap, audio_to_sd = self.mapper(audio_emb)
|
| 614 |
+
|
| 615 |
+
# Loss 1: CLAP alignment (InfoNCE)
|
| 616 |
+
loss_clap = self.info_nce(audio_to_clap, clap_text_emb, self.cfg.temperature)
|
| 617 |
+
|
| 618 |
+
# Loss 2: SD embedding alignment (MSE)
|
| 619 |
+
loss_sd = F.mse_loss(audio_to_sd, sd_text_emb)
|
| 620 |
+
|
| 621 |
+
# Loss 3: Diffusion loss (pixel-level generation)
|
| 622 |
+
loss_diffusion = torch.tensor(0.0, device=self.cfg.device)
|
| 623 |
+
if self.cfg.finetune_sd and images is not None:
|
| 624 |
+
# Only compute on samples that have images
|
| 625 |
+
if has_images is not None:
|
| 626 |
+
valid_mask = has_images.to(self.cfg.device)
|
| 627 |
+
if valid_mask.sum() > 0:
|
| 628 |
+
valid_imgs = images[valid_mask]
|
| 629 |
+
valid_audio_emb = audio_emb[valid_mask]
|
| 630 |
+
loss_diffusion = self.compute_diffusion_loss(valid_imgs, valid_audio_emb)
|
| 631 |
+
else:
|
| 632 |
+
loss_diffusion = self.compute_diffusion_loss(images, audio_emb)
|
| 633 |
+
|
| 634 |
+
# Combined multi-task loss - all train together! ๐
|
| 635 |
+
total_loss = (
|
| 636 |
+
self.cfg.clap_loss_weight * loss_clap +
|
| 637 |
+
self.cfg.sd_loss_weight * loss_sd +
|
| 638 |
+
self.cfg.diffusion_loss_weight * loss_diffusion
|
| 639 |
+
)
|
| 640 |
+
|
| 641 |
+
# Compute similarities for monitoring
|
| 642 |
+
with torch.no_grad():
|
| 643 |
+
clap_sim = torch.diagonal(
|
| 644 |
+
F.normalize(audio_to_clap, dim=-1) @ F.normalize(clap_text_emb, dim=-1).t()
|
| 645 |
+
).mean()
|
| 646 |
+
|
| 647 |
+
sd_sim = F.cosine_similarity(audio_to_sd, sd_text_emb, dim=-1).mean()
|
| 648 |
+
|
| 649 |
+
return total_loss, {
|
| 650 |
+
"loss_clap": loss_clap.item(),
|
| 651 |
+
"loss_sd": loss_sd.item(),
|
| 652 |
+
"loss_diffusion": loss_diffusion.item(),
|
| 653 |
+
"clap_sim": clap_sim.item(),
|
| 654 |
+
"sd_sim": sd_sim.item()
|
| 655 |
+
}
|
| 656 |
+
|
| 657 |
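    # Objective summary (comment added for clarity): the total training loss above is
    #     L_total = clap_loss_weight      * InfoNCE(audio→clap, clap_text)
    #             + sd_loss_weight        * MSE(audio→sd, sd_text)
    #             + diffusion_loss_weight * MSE(noise_pred, noise)
    # where the diffusion term is only active when cfg.finetune_sd is True and the
    # batch contains ground-truth images.
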
+
# --- Inference ---
|
| 658 |
+
@torch.inference_mode()
|
| 659 |
+
def generate(self, wav, sr):
|
| 660 |
+
if self.sd_pipe is None:
|
| 661 |
+
raise RuntimeError("Stable Diffusion not loaded. Init with load_sd=True.")
|
| 662 |
+
|
| 663 |
+
# Get audio embedding and project to SD space
|
| 664 |
+
audio_emb = self.encode_audio([wav], sr)
|
| 665 |
+
_, soft_token = self.mapper(audio_emb) # Use to_sd head
|
| 666 |
+
|
| 667 |
+
# Tokenize base prompt
|
| 668 |
+
tok = self.sd_tok(
|
| 669 |
+
self.cfg.base_prompt,
|
| 670 |
+
padding="max_length",
|
| 671 |
+
max_length=self.sd_tok.model_max_length,
|
| 672 |
+
truncation=True,
|
| 673 |
+
return_tensors="pt"
|
| 674 |
+
).to(self.cfg.device)
|
| 675 |
+
|
| 676 |
+
# Get SD text embeddings
|
| 677 |
+
enc = self.sd_text_encoder(tok["input_ids"])[0]
|
| 678 |
+
|
| 679 |
+
# Find position to insert audio token (after last real token)
|
| 680 |
+
attention_mask = tok["attention_mask"][0]
|
| 681 |
+
last_token_pos = attention_mask.nonzero(as_tuple=False).max().item()
|
| 682 |
+
|
| 683 |
+
# Insert audio soft token AFTER the last token
|
| 684 |
+
if last_token_pos + 1 < enc.shape[1]:
|
| 685 |
+
enc[0, last_token_pos + 1:last_token_pos + 2, :] = soft_token
|
| 686 |
+
else:
|
| 687 |
+
# If no space, replace the last token
|
| 688 |
+
enc[0, last_token_pos:last_token_pos + 1, :] = soft_token
|
| 689 |
+
|
| 690 |
+
# Generate image
|
| 691 |
+
img = self.sd_pipe(
|
| 692 |
+
num_inference_steps=self.cfg.steps,
|
| 693 |
+
guidance_scale=self.cfg.guidance, # 7.5
|
| 694 |
+
prompt_embeds=enc
|
| 695 |
+
).images[0]
|
| 696 |
+
|
| 697 |
+
return img
|
| 698 |
+
|
| 699 |
+
|
# ========================
# Training
# ========================
def train(cfg: Config):
    # Load the dataset with images
    full_ds = AudioCaptionDataset(cfg.train_csv, cfg.image_folder, use_zip_files=cfg.use_zip_files)

    # Create a train/validation split (90/10)
    train_size = int(0.9 * len(full_ds))
    val_size = len(full_ds) - train_size
    train_ds, val_ds = torch.utils.data.random_split(
        full_ds,
        [train_size, val_size],
        generator=torch.Generator().manual_seed(42)  # For reproducibility
    )

    print(f"\nDataset split:")
    print(f"  Training: {len(train_ds)} samples")
    print(f"  Validation: {len(val_ds)} samples\n")

    # Create dataloaders
    train_loader = DataLoader(
        train_ds,
        batch_size=cfg.batch_size,
        shuffle=True,
        collate_fn=collate_audio,
        num_workers=0,
        drop_last=True
    )

    val_loader = DataLoader(
        val_ds,
        batch_size=cfg.batch_size,
        shuffle=False,
        collate_fn=collate_audio,
        num_workers=0
    )

    # Initialize the model
    model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)

    # Separate optimizers with different learning rates
    if cfg.finetune_sd:
        print("\nSetting up END-TO-END training:")

        # Optimizer 1: Mapper (higher LR)
        opt_mapper = torch.optim.AdamW(
            model.mapper.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay
        )
        print(f"  Mapper optimizer: LR={cfg.lr}")

        # Optimizer 2: SD UNet (lower LR for stability)
        opt_sd = torch.optim.AdamW(
            model.sd_unet.parameters(),
            lr=cfg.sd_lr,
            weight_decay=cfg.weight_decay
        )
        print(f"  SD UNet optimizer: LR={cfg.sd_lr}")

        opts = [opt_mapper, opt_sd]
    else:
        # Only the mapper is trainable
        opt_mapper = torch.optim.AdamW(
            model.parameters(),
            lr=cfg.lr,
            weight_decay=cfg.weight_decay
        )
        opts = [opt_mapper]

    print(f"\n{'='*60}")
    print(f"Starting {'End-to-End' if cfg.finetune_sd else 'Mapper-Only'} Training")
    print(f"{'='*60}")
    print(f"Dataset: {len(full_ds)} samples ({len(train_ds)} train, {len(val_ds)} val)")
    print(f"Batch size: {cfg.batch_size}")
    print(f"Epochs: {cfg.max_epochs}")
    print(f"Evaluation: Every {cfg.eval_every_n_epochs} epoch(s)")
    print(f"Loss weights:")
    print(f"  CLAP: {cfg.clap_loss_weight}")
    print(f"  SD Embedding: {cfg.sd_loss_weight}")
    if cfg.finetune_sd:
        print(f"  Diffusion: {cfg.diffusion_loss_weight}")
    print(f"{'='*60}\n")

    # Track the best model based on CLIP score
    best_clip_score = -float('inf')

    for ep in range(1, cfg.max_epochs + 1):
        # ============================================
        # TRAINING PHASE
        # ============================================
        model.train()
        pbar = tqdm(train_loader, desc=f"Epoch {ep}/{cfg.max_epochs} [TRAIN]")

        epoch_stats = {
            "total": 0, "clap": 0, "sd": 0, "diff": 0,
            "clap_sim": 0, "sd_sim": 0
        }

        for wavs, sr, caps, imgs, has_imgs in pbar:
            wavs = [w.to(cfg.device) for w in wavs]
            imgs = imgs.to(cfg.device)

            # Forward pass - all losses are computed here
            loss, stats = model(wavs, sr, caps, imgs if cfg.finetune_sd else None, has_imgs)

            # Zero gradients for all optimizers
            for opt in opts:
                opt.zero_grad()

            # Backward pass - gradients flow to the mapper AND the UNet
            loss.backward()

            # Clip gradients for stability
            if cfg.finetune_sd:
                nn.utils.clip_grad_norm_(model.mapper.parameters(), 1.0)
                nn.utils.clip_grad_norm_(model.sd_unet.parameters(), 1.0)
            else:
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update all parameters simultaneously
            for opt in opts:
                opt.step()

            # Accumulate stats
            epoch_stats["total"] += loss.item()
            epoch_stats["clap"] += stats['loss_clap']
            epoch_stats["sd"] += stats['loss_sd']
            epoch_stats["diff"] += stats['loss_diffusion']
            epoch_stats["clap_sim"] += stats['clap_sim']
            epoch_stats["sd_sim"] += stats['sd_sim']

            pbar.set_postfix({
                "total loss": f"{loss.item():.3f}",
                "diff": f"{stats['loss_diffusion']:.3f}",
                "c_sim": f"{stats['clap_sim']:.2f}",
                "s_sim": f"{stats['sd_sim']:.2f}"
            })

        # Compute training epoch averages
        n_train = len(train_loader)
        for k in epoch_stats:
            epoch_stats[k] /= n_train

        # ============================================
        # VALIDATION & EVALUATION PHASE
        # ============================================
        if ep % cfg.eval_every_n_epochs == 0:
            print(f"\n{'='*60}")
            print(f"Evaluating Epoch {ep}...")
            print(f"{'='*60}")

            model.eval()
            val_clip_scores = []
            all_gen_images = []
            all_captions = []

            # Evaluate on the validation set (limited to save time)
            eval_batches = min(3, len(val_loader))  # At most 3 batches

            for batch_idx, (wavs, sr, caps, imgs, has_imgs) in enumerate(val_loader):
                if batch_idx >= eval_batches:
                    break

                wavs = [w.to(cfg.device) for w in wavs]

                # Generate images and compute CLIP scores
                avg_score, gen_imgs, scores = model.evaluate_generation(
                    wavs, sr, caps,
                    num_samples=cfg.num_eval_samples
                )

                val_clip_scores.extend(scores)
                all_gen_images.extend(gen_imgs)
                all_captions.extend(caps[:cfg.num_eval_samples])

                print(f"  Batch {batch_idx + 1}/{eval_batches}: Avg CLIP = {avg_score:.3f}")

            # Compute the overall validation CLIP score
            avg_val_clip = sum(val_clip_scores) / len(val_clip_scores) if val_clip_scores else 0.0

            # Save example images from the evaluation
            if cfg.save_eval_images and all_gen_images:
                os.makedirs("eval_samples", exist_ok=True)
                for i, (img, cap, score) in enumerate(zip(all_gen_images[:4], all_captions[:4], val_clip_scores[:4])):
                    save_path = f"eval_samples/ep{ep}_sample{i}_score{score:.2f}.png"
                    img.save(save_path)
                    print(f"  Sample {i}: '{cap[:50]}...' | CLIP: {score:.3f}")
                    print(f"  Saved to: {save_path}")

            # Clear the MPS cache after evaluation
            if cfg.device == "mps":
                torch.mps.empty_cache()

            print(f"\n{'='*60}")
            print(f"Epoch {ep} Summary:")
            print(f"{'='*60}")
            print(f"Training Metrics:")
            print(f"  Total Loss: {epoch_stats['total']:.4f}")
            print(f"  CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
            print(f"  SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
            if cfg.finetune_sd:
                print(f"  Diffusion Loss: {epoch_stats['diff']:.4f}")
            print(f"\nValidation Metrics:")
            print(f"  CLIP Score: {avg_val_clip:.3f} (higher = better image-text match)")
            print(f"{'='*60}\n")

        else:
            # Just print training stats if not evaluating
            avg_val_clip = None
            print(f"\n{'='*60}")
            print(f"Epoch {ep} Summary:")
            print(f"  Total Loss: {epoch_stats['total']:.4f}")
            print(f"  CLAP Loss: {epoch_stats['clap']:.4f} | Sim: {epoch_stats['clap_sim']:.3f}")
            print(f"  SD Loss: {epoch_stats['sd']:.4f} | Sim: {epoch_stats['sd_sim']:.3f}")
            if cfg.finetune_sd:
                print(f"  Diffusion Loss: {epoch_stats['diff']:.4f}")
            print(f"{'='*60}\n")

        # ============================================
        # CHECKPOINT SAVING
        # ============================================
        checkpoint = {
            "mapper": model.mapper.state_dict(),
            "epoch": ep,
            "val_clip_score": avg_val_clip if avg_val_clip is not None else -1,
            **epoch_stats,
            "config": {
                "clap_loss_weight": cfg.clap_loss_weight,
                "sd_loss_weight": cfg.sd_loss_weight,
                "diffusion_loss_weight": cfg.diffusion_loss_weight,
                "finetune_sd": cfg.finetune_sd
            }
        }

        if cfg.finetune_sd:
            checkpoint["unet"] = model.sd_unet.state_dict()

        # Always save the latest checkpoint
        torch.save(checkpoint, cfg.ckpt_path)
        print(f"Checkpoint saved: {cfg.ckpt_path}")

        # Save the best model based on CLIP score
        if avg_val_clip is not None and avg_val_clip > best_clip_score:
            best_clip_score = avg_val_clip
            best_path = cfg.ckpt_path.replace('.pt', '_best.pt')
            torch.save(checkpoint, best_path)
            print(f"✓ New best model! CLIP: {avg_val_clip:.3f} -> Saved to {best_path}")
        elif avg_val_clip is not None:
            print(f"  Current best CLIP: {best_clip_score:.3f}")

        print()

    print("Training completed!")
    if best_clip_score > -float('inf'):
        print(f"  Best CLIP score achieved: {best_clip_score:.3f}")

# ========================
# Inference
# ========================
def infer(cfg: Config, wav_path: str, out_path: str):
    # Load the audio
    print(f"Loading audio from {wav_path}...")
    wav, sr = torchaudio.load(wav_path)
    if wav.size(0) > 1:
        wav = wav.mean(0, keepdim=True)
    wav = wav.squeeze(0).float()

    # Resample to 48kHz for CLAP
    if sr != 48000:
        print(f"Resampling from {sr}Hz to 48000Hz...")
        resampler = torchaudio.transforms.Resample(sr, 48000)
        wav = resampler(wav)
        sr = 48000

    wav = wav.to(cfg.device)

    # Load the model with SD
    model = Audio2ImageModel(cfg, load_sd=True).to(cfg.device)

    # Load trained weights
    print(f"Loading checkpoint from {cfg.ckpt_path}...")
    ckpt = torch.load(cfg.ckpt_path, map_location=cfg.device)
    model.mapper.load_state_dict(ckpt["mapper"])

    # Load UNet weights if available (from fine-tuning)
    if "unet" in ckpt:
        print("Loading fine-tuned UNet weights...")
        model.sd_unet.load_state_dict(ckpt["unet"])

    print(f"Checkpoint info:")
    print(f"  Epoch: {ckpt.get('epoch', 'unknown')}")
    print(f"  CLAP Sim: {ckpt.get('clap_sim'):.3f}" if isinstance(ckpt.get('clap_sim'), (int, float)) else "  CLAP Sim: N/A")
    print(f"  SD Sim: {ckpt.get('sd_sim'):.3f}" if isinstance(ckpt.get('sd_sim'), (int, float)) else "  SD Sim: N/A")
    if "unet" in ckpt:
        print("  Fine-tuned UNet: ✓")

    # Generate the image
    print("\nGenerating image...")
    img = model.generate(wav, sr)
    img.save(out_path)
    print(f"✓ Generated image saved to {out_path}")

# ========================
# Main
# ========================
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--mode", choices=["train", "infer"], default="train")
    parser.add_argument("--wav", help="Audio file path for inference mode")
    parser.add_argument("--out", default="output.png", help="Output image path")
    args = parser.parse_args()

    cfg = Config()
    print(f"Device: {cfg.device}")

    if args.mode == "train":
        print(f"Dataset: {cfg.train_csv}")
        if not os.path.exists(cfg.train_csv):
            print(f"ERROR: Dataset not found at {cfg.train_csv}")
            print("Please ensure the dataset CSV file exists")
            sys.exit(1)
        train(cfg)
    else:
        if not args.wav:
            raise ValueError("Need --wav for inference mode")
        if not os.path.exists(args.wav):
            raise ValueError(f"Audio file not found: {args.wav}")
        infer(cfg, args.wav, args.out)
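# Example invocations (comment sketch; the audio path is hypothetical):
#   python main2.py --mode train
#   python main2.py --mode infer --wav path/to/sound.wav --out generated.png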