"""
Training Loop for ProbVLM-Style Probabilistic Adapters.

Trains lightweight post-hoc adapters on top of frozen CLIP/CLAP encoders.
Each adapter learns to predict uncertainty (Generalized Gaussian parameters)
for a single embedding space.

Two adapters are trained:
1. CLIP adapter: trained on (image_embedding, text_embedding) pairs.
2. CLAP adapter: trained on (audio_embedding, text_embedding) pairs.

Training data:
- Our 57 images paired with text descriptions (CLIP pairs).
- Our 104 audio files paired with text descriptions (CLAP pairs).
- All 30 RQ1 prompts × matched media as additional pairs.

Loss:
    L = L1(mu, target) + GenGaussNLL(mu, alpha, beta, target)

where GenGaussNLL is the Generalized Gaussian negative log-likelihood
(up to the additive constant log 2):
    -log p(target | mu, alpha, beta)
        = log(alpha) + lgamma(1/beta) - log(beta) + (|target - mu| / alpha)^beta
"""
from __future__ import annotations

import logging
from pathlib import Path
from typing import Optional, Tuple

import numpy as np

logger = logging.getLogger(__name__)

try:
    import torch
    import torch.nn as nn
    from torch.utils.data import DataLoader, Dataset, random_split

    TORCH_AVAILABLE = True
except ImportError:
    TORCH_AVAILABLE = False
    # Fallbacks so this module can still be imported without torch installed;
    # anything that actually needs torch raises ImportError at call time.
    Dataset = object
    nn = type("nn", (), {"Module": object})

from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
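
# Interface assumed of ProbabilisticAdapter throughout this module (consistent
# with its use below; see src/embeddings/probabilistic_adapter.py):
#   adapter(x) -> (mu, alpha, beta) for a [B, D] input batch,
#   adapter.save(path), and ProbabilisticAdapter.load(path).

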
class EmbeddingPairDataset(Dataset):
    """Dataset of (input_embedding, target_embedding) pairs."""

    def __init__(self, inputs: np.ndarray, targets: np.ndarray):
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        assert len(inputs) == len(targets), "inputs and targets must have equal length"
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.inputs[idx], self.targets[idx]
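
# Usage sketch (512-d embeddings as this module assumes; shapes illustrative):
#
#   ds = EmbeddingPairDataset(np.zeros((8, 512), np.float32),
#                             np.zeros((8, 512), np.float32))
#   loader = DataLoader(ds, batch_size=4, shuffle=True)
#   inp, tgt = next(iter(loader))   # each of shape [4, 512]

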
class GenGaussNLL(nn.Module):
    """
    Negative log-likelihood loss for the Generalized Gaussian distribution.

    -log p(x | mu, alpha, beta)
        = log(2*alpha) + log(Gamma(1/beta)/beta) + (|x - mu| / alpha)^beta

    Dropping only the additive constant log(2) (alpha and beta are both
    predicted, so the lgamma(1/beta) - log(beta) terms are not constant
    and must be kept):

        L = log(alpha) + lgamma(1/beta) - log(beta) + (|target - mu| / alpha)^beta
    """

    def forward(
        self,
        mu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        residual = torch.abs(target - mu)
        # Clamp alpha and beta away from zero so log/lgamma stay finite.
        alpha_c = torch.clamp(alpha, min=1e-6)
        beta_c = torch.clamp(beta, min=1e-3)
        # Clamp the ratio away from zero: pow()'s gradient w.r.t. beta
        # involves log(ratio), which is NaN-prone at exactly zero residual.
        ratio = (residual / alpha_c).clamp(min=1e-12)
        nll = (
            ratio.pow(beta_c)
            + torch.log(alpha_c)
            + torch.lgamma(1.0 / beta_c)
            - torch.log(beta_c)
        )
        return nll.mean()
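
# Sanity-check sketch: with beta = 2 the NLL above matches a Gaussian NLL up
# to additive constants, and with beta = 1 a Laplace NLL. Illustrative only:
#
#   gg = GenGaussNLL()
#   target = torch.randn(4, 512)
#   mu = torch.zeros(4, 512)
#   alpha = torch.ones(4, 512)
#   gg(mu, alpha, torch.full_like(alpha, 2.0), target)  # Gaussian-like tails
#   gg(mu, alpha, torch.full_like(alpha, 1.0), target)  # Laplace-like tails

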
def train_prob_adapter(
    input_embeddings: np.ndarray,
    target_embeddings: np.ndarray,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 32,
    val_split: float = 0.15,
    patience: int = 15,
    output_path: Optional[str] = None,
    adapter_name: str = "adapter",
) -> ProbabilisticAdapter:
    """
    Train a ProbabilisticAdapter on paired embeddings.

    Args:
        input_embeddings: Source embeddings [N, 512] (e.g. image CLIP or audio CLAP).
        target_embeddings: Target embeddings [N, 512] (e.g. text CLIP or text CLAP).
        epochs: Maximum number of training epochs.
        lr: Learning rate.
        batch_size: Batch size.
        val_split: Fraction of pairs held out for validation.
        patience: Early-stopping patience, in epochs.
        output_path: If set, save the best checkpoint here.
        adapter_name: Name used in log messages.

    Returns:
        The trained ProbabilisticAdapter.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for training")

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Build dataset and a fixed train/validation split.
    dataset = EmbeddingPairDataset(input_embeddings, target_embeddings)
    n_val = max(1, int(len(dataset) * val_split))
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset,
        [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        # Only drop the last partial batch when there is more than one batch.
        drop_last=len(train_ds) > batch_size,
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Build model, optimizer, and losses.
    input_dim = input_embeddings.shape[1]
    adapter = ProbabilisticAdapter(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    l1_loss = nn.L1Loss()
    gg_loss = GenGaussNLL()

    best_val_loss = float("inf")
    patience_counter = 0
    logger.info(
        "Training %s adapter: %d train, %d val, %d epochs, device=%s",
        adapter_name, n_train, n_val, epochs, device,
    )

    for epoch in range(epochs):
        # --- Train ---
        adapter.train()
        train_losses = []
        for inp, tgt in train_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            optimizer.zero_grad()
            mu, alpha, beta = adapter(inp)
            loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()

        # --- Validate ---
        adapter.eval()
        val_losses = []
        with torch.no_grad():
            for inp, tgt in val_loader:
                inp, tgt = inp.to(device), tgt.to(device)
                mu, alpha, beta = adapter(inp)
                loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
                val_losses.append(loss.item())

        avg_train = np.mean(train_losses)
        avg_val = np.mean(val_losses) if val_losses else float("inf")
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                " [%s] Epoch %d/%d: train=%.4f, val=%.4f",
                adapter_name, epoch + 1, epochs, avg_train, avg_val,
            )

        # Early stopping on validation loss; checkpoint on improvement.
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            if output_path:
                adapter.save(output_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(" [%s] Early stopping at epoch %d", adapter_name, epoch + 1)
                break

    # Reload the best checkpoint if one was saved.
    if output_path and Path(output_path).exists():
        adapter = ProbabilisticAdapter.load(output_path)
        adapter = adapter.to(device)
    else:
        adapter = adapter.cpu()
    adapter.eval()
    logger.info(" [%s] Training complete. Best val_loss=%.4f", adapter_name, best_val_loss)
    return adapter
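
# Call sketch (array contents and the output path are assumptions for
# illustration; real pairs come from build_training_pairs_from_index below):
#
#   img_embs, txt_embs = ...          # [N, 512] each
#   clip_adapter = train_prob_adapter(
#       img_embs, txt_embs,
#       output_path="models/clip_adapter.pt",
#       adapter_name="clip",
#   )
#   mu, alpha, beta = clip_adapter(torch.tensor(img_embs[:1]))

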
def build_training_pairs_from_index(
    embedding_index_path: str,
    text_embedder_fn,
    modality: str = "image",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build (media_embedding, text_embedding) pairs from an embedding index.

    For each media file in the index, generates a text description from
    the filename/metadata and embeds it.

    Args:
        embedding_index_path: Path to image_index.npz or audio_index.npz.
        text_embedder_fn: Function mapping a text string to an np.ndarray embedding.
        modality: "image" for CLIP text, "audio" for CLAP text.

    Returns:
        (media_embeddings, text_embeddings), both of shape [N, 512].
    """
    data = np.load(embedding_index_path, allow_pickle=True)
    ids = data["ids"] if "ids" in data else data.get("paths", np.array([]))
    embs = data["embs"] if "embs" in data else data.get("embeddings", np.array([]))
    domains = data["domains"] if "domains" in data else np.array(["other"] * len(ids))

    media_embs = []
    text_embs = []
    for i, (file_id, domain) in enumerate(zip(ids, domains)):
        # Generate a caption from the filename.
        name = Path(str(file_id)).stem
        caption = name.replace("_", " ").replace("-", " ")
        # Strip common filename prefixes.
        for prefix in ("fs ", "wm ", "proc "):
            if caption.lower().startswith(prefix):
                caption = caption[len(prefix):]
        # Add domain context.
        if domain != "other":
            caption = f"{domain}: {caption}"
        try:
            text_emb = text_embedder_fn(caption)
            media_embs.append(embs[i])
            text_embs.append(text_emb)
        except Exception as e:
            logger.warning("Skipping %s: %s", file_id, e)
    return np.array(media_embs), np.array(text_embs)
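

if __name__ == "__main__":
    # Minimal end-to-end sketch. The index path and the stand-in text embedder
    # are assumptions for illustration; swap in the project's real index files
    # and CLIP/CLAP text encoders.
    import zlib

    logging.basicConfig(level=logging.INFO)

    def _fake_text_embedder(text: str) -> np.ndarray:
        # Stable pseudo-embedding seeded from the caption, standing in for a
        # real text encoder.
        rng = np.random.default_rng(zlib.crc32(text.encode("utf-8")))
        return rng.normal(size=512).astype(np.float32)

    media_arr, text_arr = build_training_pairs_from_index(
        "data/embeddings/image_index.npz",  # assumed path
        _fake_text_embedder,
        modality="image",
    )
    if len(media_arr):
        train_prob_adapter(media_arr, text_arr, epochs=10, adapter_name="clip-demo")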