| """ |
| Training Loop for ProbVLM-Style Probabilistic Adapters. |
| |
| Trains lightweight post-hoc adapters on top of frozen CLIP/CLAP encoders. |
| Each adapter learns to predict uncertainty (Generalized Gaussian parameters) |
| for a single embedding space. |
| |
| Two adapters to train: |
| 1. CLIP adapter: trained on (image_embedding, text_embedding) pairs |
| 2. CLAP adapter: trained on (audio_embedding, text_embedding) pairs |
| |
| Training data: |
| - Our 57 images paired with text descriptions (CLIP pairs) |
| - Our 104 audio files paired with text descriptions (CLAP pairs) |
| - All 30 RQ1 prompts × matched media as additional pairs |
| |
| Loss: |
| L = L1(mu, target) + GenGaussLoss(mu, alpha, beta, target) |
| |
GenGaussLoss (negative log-likelihood, up to an additive constant log 2):
-log p(target | mu, alpha, beta) = log(alpha) - log(beta) + log Γ(1/beta) + (|target - mu| / alpha)^beta
| """ |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
|
|
| logger = logging.getLogger(__name__) |
|
|
| try: |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset, random_split |
| TORCH_AVAILABLE = True |
| except ImportError: |
| TORCH_AVAILABLE = False |
|
|
| from src.embeddings.probabilistic_adapter import ProbabilisticAdapter |
|
|
|
|
class EmbeddingPairDataset(Dataset):
    """Dataset of (input_embedding, target_embedding) pairs.

    Wraps two aligned numpy arrays as float32 tensors for use with a
    DataLoader; row ``i`` of ``inputs`` is paired with row ``i`` of
    ``targets``.
    """

    def __init__(self, inputs: np.ndarray, targets: np.ndarray):
        """
        Args:
            inputs: Source embeddings, shape [N, D].
            targets: Target embeddings, shape [N, D].

        Raises:
            ValueError: If the two arrays have different lengths.
        """
        # NOTE: no TORCH_AVAILABLE guard needed here -- the class statement
        # itself references torch's Dataset, so importing this module without
        # torch already fails before __init__ can run.
        if len(inputs) != len(targets):
            # Explicit exception instead of `assert`: asserts are stripped
            # under `python -O` and would silently admit misaligned pairs.
            raise ValueError(
                f"inputs and targets must have the same length "
                f"({len(inputs)} != {len(targets)})"
            )
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.inputs[idx], self.targets[idx]
|
|
|
|
class GenGaussNLL(nn.Module):
    """
    Negative log-likelihood loss for the Generalized Gaussian distribution.

    p(x | mu, alpha, beta) = beta / (2 * alpha * Gamma(1/beta))
                             * exp(-(|x - mu| / alpha)^beta)

    -log p = log(2) + log(alpha) - log(beta) + lgamma(1/beta)
             + (|x - mu| / alpha)^beta

    Dropping only the true constant log(2):

        L = log(alpha) - log(beta) + lgamma(1/beta) + (|target - mu| / alpha)^beta

    The beta-dependent terms ``-log(beta) + lgamma(1/beta)`` must be kept:
    they are not constant w.r.t. the predicted beta, and without them the
    remaining power term is monotone in beta whenever residual/alpha < 1,
    so the optimizer can drive beta to degenerate values.
    """

    def forward(
        self,
        mu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean per-element NLL of `target` under GG(mu, alpha, beta).

        Args:
            mu: Predicted mean, same shape as `target`.
            alpha: Predicted scale (> 0); clamped to >= 1e-6 for stability.
            beta: Predicted shape (> 0); clamped to >= 1e-3 so log/lgamma
                stay finite.
            target: Observed values.

        Returns:
            Scalar tensor: mean NLL over all elements.
        """
        residual = torch.abs(target - mu)

        alpha_c = torch.clamp(alpha, min=1e-6)
        beta_c = torch.clamp(beta, min=1e-3)
        nll = (
            torch.log(alpha_c)
            - torch.log(beta_c)
            + torch.lgamma(1.0 / beta_c)
            + (residual / alpha_c).pow(beta_c)
        )
        return nll.mean()
|
|
|
|
def train_prob_adapter(
    input_embeddings: np.ndarray,
    target_embeddings: np.ndarray,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 32,
    val_split: float = 0.15,
    patience: int = 15,
    output_path: Optional[str] = None,
    adapter_name: str = "adapter",
) -> ProbabilisticAdapter:
    """
    Train a ProbabilisticAdapter on paired embeddings.

    Optimizes the combined objective ``L1(mu, target) +
    GenGaussNLL(mu, alpha, beta, target)`` with AdamW, a cosine-annealed
    learning rate, gradient clipping, and early stopping on validation
    loss. When ``output_path`` is set, the checkpoint with the lowest
    validation loss is saved during training and reloaded at the end, so
    the returned adapter reflects the best epoch rather than the last one.

    Args:
        input_embeddings: Source embeddings [N, 512] (e.g. image CLIP or audio CLAP).
        target_embeddings: Target embeddings [N, 512] (e.g. text CLIP or text CLAP).
        epochs: Maximum training epochs.
        lr: Learning rate.
        batch_size: Batch size.
        val_split: Fraction for validation (at least one sample is held out).
        patience: Early stopping patience, in epochs without val improvement.
        output_path: If set, save best model here.
        adapter_name: Name for logging.

    Returns:
        Trained ProbabilisticAdapter in eval mode. Lives on the training
        device if reloaded from ``output_path``, otherwise it is moved to CPU.

    Raises:
        ImportError: If PyTorch is not installed.
    """
    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for training")

    # Prefer CUDA, then Apple Metal (MPS), then CPU.
    # NOTE(review): torch.backends.mps only exists on torch >= 1.12 — confirm
    # the project's minimum torch version, otherwise this raises AttributeError.
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    # Fixed-seed split so train/val membership is reproducible across runs.
    dataset = EmbeddingPairDataset(input_embeddings, target_embeddings)
    n_val = max(1, int(len(dataset) * val_split))  # always hold out >= 1 sample
    n_train = len(dataset) - n_val
    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )

    # Drop the trailing partial batch only when the train set exceeds one
    # batch, so tiny datasets still yield a training batch every epoch.
    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, drop_last=len(train_ds) > batch_size)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Adapter width follows the input embeddings (512 for CLIP/CLAP here).
    input_dim = input_embeddings.shape[1]
    adapter = ProbabilisticAdapter(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    l1_loss = nn.L1Loss()
    gg_loss = GenGaussNLL()

    best_val_loss = float("inf")
    patience_counter = 0

    logger.info(
        "Training %s adapter: %d train, %d val, %d epochs, device=%s",
        adapter_name, n_train, n_val, epochs, device,
    )

    for epoch in range(epochs):
        # ---- Training pass ----
        adapter.train()
        train_losses = []
        for inp, tgt in train_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            optimizer.zero_grad()

            # Combined objective: L1 anchors the mean prediction, the
            # GenGauss NLL fits the uncertainty parameters (alpha, beta).
            mu, alpha, beta = adapter(inp)
            loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
            loss.backward()
            # Clip gradients; the NLL term can spike early in training.
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            train_losses.append(loss.item())

        # Cosine LR schedule advances once per epoch, not per batch.
        scheduler.step()

        # ---- Validation pass ----
        adapter.eval()
        val_losses = []
        with torch.no_grad():
            for inp, tgt in val_loader:
                inp, tgt = inp.to(device), tgt.to(device)
                mu, alpha, beta = adapter(inp)
                loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
                val_losses.append(loss.item())

        avg_train = np.mean(train_losses)
        avg_val = np.mean(val_losses) if val_losses else float("inf")

        # Log the first epoch and every 10th thereafter.
        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                " [%s] Epoch %d/%d: train=%.4f, val=%.4f",
                adapter_name, epoch + 1, epochs, avg_train, avg_val,
            )

        # ---- Checkpointing & early stopping ----
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            if output_path:
                # Persist the best checkpoint seen so far.
                adapter.save(output_path)
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(" [%s] Early stopping at epoch %d", adapter_name, epoch + 1)
                break

    # Reload the best checkpoint so callers get the best epoch, not the
    # final (possibly overfit) one. Without an output_path the last-epoch
    # weights are returned, moved to CPU.
    if output_path and Path(output_path).exists():
        adapter = ProbabilisticAdapter.load(output_path)
        adapter = adapter.to(device)
    else:
        adapter = adapter.cpu()

    adapter.eval()
    logger.info(" [%s] Training complete. Best val_loss=%.4f", adapter_name, best_val_loss)
    return adapter
|
|
|
|
def build_training_pairs_from_index(
    embedding_index_path: str,
    text_embedder_fn,
    modality: str = "image",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build (media_embedding, text_embedding) pairs from an embedding index.

    A pseudo-caption is derived from each file's name (underscores and
    hyphens become spaces, known source prefixes are stripped, and a
    non-"other" domain label is prepended), then embedded with
    ``text_embedder_fn``. Entries whose caption fails to embed are skipped
    with a warning.

    Args:
        embedding_index_path: Path to image_index.npz or audio_index.npz.
        text_embedder_fn: Function that takes text -> np.ndarray embedding.
        modality: "image" for CLIP text, "audio" for CLAP text.

    Returns:
        (media_embeddings, text_embeddings) both shape [N, 512].
    """
    archive = np.load(embedding_index_path, allow_pickle=True)
    ids = archive["ids"] if "ids" in archive else archive.get("paths", np.array([]))
    embs = archive["embs"] if "embs" in archive else archive.get("embeddings", np.array([]))
    domains = archive["domains"] if "domains" in archive else np.array(["other"] * len(ids))

    def _caption_for(file_id, domain) -> str:
        # Filename stem with word separators normalized to spaces.
        text = Path(str(file_id)).stem.replace("_", " ").replace("-", " ")
        # Strip known source-collection markers, checked case-insensitively
        # and in order (so stacked markers are peeled one after another).
        for marker in ["fs ", "wm ", "proc "]:
            if text.lower().startswith(marker):
                text = text[len(marker):]
        # Prepend the domain label unless it is the catch-all bucket.
        return text if domain == "other" else f"{domain}: {text}"

    media_rows: List[np.ndarray] = []
    text_rows: List[np.ndarray] = []

    for row, (file_id, domain) in enumerate(zip(ids, domains)):
        caption = _caption_for(file_id, domain)
        try:
            # Embed first; only append to either list once both values
            # exist, so the two lists always stay aligned.
            caption_emb = text_embedder_fn(caption)
            media_rows.append(embs[row])
            text_rows.append(caption_emb)
        except Exception as e:
            logger.warning("Skipping %s: %s", file_id, e)

    return np.array(media_rows), np.array(text_rows)
|
|