MultiModal-Coherence-AI / src /embeddings /prob_adapter_trainer.py
pratik-250620's picture
Upload folder using huggingface_hub
358d3bc verified
"""
Training Loop for ProbVLM-Style Probabilistic Adapters.
Trains lightweight post-hoc adapters on top of frozen CLIP/CLAP encoders.
Each adapter learns to predict uncertainty (Generalized Gaussian parameters)
for a single embedding space.
Two adapters to train:
1. CLIP adapter: trained on (image_embedding, text_embedding) pairs
2. CLAP adapter: trained on (audio_embedding, text_embedding) pairs
Training data:
- Our 57 images paired with text descriptions (CLIP pairs)
- Our 104 audio files paired with text descriptions (CLAP pairs)
- All 30 RQ1 prompts × matched media as additional pairs
Loss:
L = L1(mu, target) + GenGaussLoss(mu, alpha, beta, target)
GenGaussLoss:
-log p(target | mu, alpha, beta) = log(2*alpha) + log(Gamma(1/beta)) - log(beta) + (|target - mu| / alpha)^beta
"""
from __future__ import annotations
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, random_split
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
from src.embeddings.probabilistic_adapter import ProbabilisticAdapter
class EmbeddingPairDataset(Dataset):
    """Dataset of (input_embedding, target_embedding) pairs.

    Both arrays are copied into float32 tensors at construction time; item
    ``i`` is the tuple ``(inputs[i], targets[i])``.

    Args:
        inputs: Source embeddings, one row per sample.
        targets: Target embeddings, one row per sample; must have the same
            number of rows as ``inputs``.

    Raises:
        ImportError: If PyTorch is not available.
        ValueError: If ``inputs`` and ``targets`` differ in length.
    """

    def __init__(self, inputs: np.ndarray, targets: np.ndarray):
        if not TORCH_AVAILABLE:
            raise ImportError("PyTorch required")
        # Explicit check instead of `assert`: assertions are stripped under
        # `python -O`, which would let mismatched pairs through silently.
        if len(inputs) != len(targets):
            raise ValueError(
                f"inputs and targets must have the same length, "
                f"got {len(inputs)} and {len(targets)}"
            )
        self.inputs = torch.tensor(inputs, dtype=torch.float32)
        self.targets = torch.tensor(targets, dtype=torch.float32)

    def __len__(self) -> int:
        return len(self.inputs)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        return self.inputs[idx], self.targets[idx]
class GenGaussNLL(nn.Module):
    """
    Negative log-likelihood loss for the Generalized Gaussian distribution.

    For the density
        p(x) = beta / (2 * alpha * Gamma(1/beta)) * exp(-(|x - mu| / alpha)^beta)
    the exact negative log-likelihood is
        -log p(x | mu, alpha, beta)
            = log(2*alpha) + log(Gamma(1/beta)) - log(beta) + (|x - mu| / alpha)^beta

    Only the true constant log(2) is dropped:
        L = log(alpha) + log(Gamma(1/beta)) - log(beta) + (|target - mu| / alpha)^beta

    The beta-dependent normaliser is kept on purpose: beta is passed in as a
    tensor and receives gradients, so omitting those terms would optimise
    something other than the likelihood.
    """

    def forward(
        self,
        mu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
        target: torch.Tensor,
    ) -> torch.Tensor:
        """Return the mean NLL over all elements.

        Args:
            mu: Predicted location.
            alpha: Predicted scale; clamped to >= 1e-6 before use.
            beta: Predicted shape; clamped to >= 1e-6 before use.
            target: Values the prediction is scored against.

        Returns:
            Scalar tensor: element-wise NLL averaged with ``.mean()``.
        """
        residual = torch.abs(target - mu)
        # Clamp alpha to avoid division by zero / log(0)
        alpha_c = torch.clamp(alpha, min=1e-6)
        # Clamp beta so 1/beta and log(beta) stay finite
        beta_c = torch.clamp(beta, min=1e-6)
        log_normaliser = (
            torch.log(alpha_c)
            + torch.lgamma(1.0 / beta_c)
            - torch.log(beta_c)
        )
        nll = log_normaliser + (residual / alpha_c).pow(beta_c)
        return nll.mean()
def train_prob_adapter(
    input_embeddings: np.ndarray,
    target_embeddings: np.ndarray,
    epochs: int = 100,
    lr: float = 1e-4,
    batch_size: int = 32,
    val_split: float = 0.15,
    patience: int = 15,
    output_path: Optional[str] = None,
    adapter_name: str = "adapter",
) -> ProbabilisticAdapter:
    """
    Train a ProbabilisticAdapter on paired embeddings.

    The pairs are split into train/validation subsets with a fixed seed (42).
    Each epoch minimises ``L1(mu, target) + GenGaussNLL(mu, alpha, beta, target)``
    with AdamW, cosine learning-rate annealing and gradient-norm clipping at 1.0.
    Training stops early once the validation loss has not improved for
    ``patience`` consecutive epochs.

    Args:
        input_embeddings: Source embeddings [N, D] (e.g. image CLIP or audio CLAP).
            ``D`` is read from ``shape[1]`` and used as the adapter input size.
        target_embeddings: Target embeddings [N, D] (e.g. text CLIP or text CLAP).
        epochs: Maximum training epochs.
        lr: Learning rate.
        batch_size: Batch size.
        val_split: Fraction for validation (at least one sample is always held out).
        patience: Early stopping patience.
        output_path: If set, the best model is saved here whenever the
            validation loss improves.
        adapter_name: Name for logging.

    Returns:
        The adapter in eval mode with its best-validation weights. If this
        run wrote a checkpoint to ``output_path`` it is reloaded from there
        and placed on the training device; otherwise the best weights seen
        are restored from memory and the adapter is moved to CPU.

    Raises:
        ImportError: If PyTorch is not available.
        ValueError: If the split would leave no training samples.
    """
    import copy  # local import: only needed to snapshot the best weights

    if not TORCH_AVAILABLE:
        raise ImportError("PyTorch required for training")

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"

    # Build dataset
    dataset = EmbeddingPairDataset(input_embeddings, target_embeddings)
    n_val = max(1, int(len(dataset) * val_split))
    n_train = len(dataset) - n_val
    if n_train < 1:
        # An empty training split would make every epoch a no-op.
        raise ValueError(
            f"Need at least one training sample: got {len(dataset)} pairs "
            f"with val_split={val_split}"
        )

    train_ds, val_ds = random_split(
        dataset, [n_train, n_val],
        generator=torch.Generator().manual_seed(42),
    )
    # Drop the ragged last batch only when at least one full batch remains.
    train_loader = DataLoader(
        train_ds,
        batch_size=batch_size,
        shuffle=True,
        drop_last=len(train_ds) > batch_size,
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    # Build model
    input_dim = input_embeddings.shape[1]
    adapter = ProbabilisticAdapter(input_dim=input_dim).to(device)
    optimizer = torch.optim.AdamW(adapter.parameters(), lr=lr, weight_decay=1e-4)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
    l1_loss = nn.L1Loss()
    gg_loss = GenGaussNLL()

    best_val_loss = float("inf")
    best_state = None  # in-memory snapshot of the best weights
    checkpoint_written = False  # True once *this run* has saved to output_path
    patience_counter = 0

    logger.info(
        "Training %s adapter: %d train, %d val, %d epochs, device=%s",
        adapter_name, n_train, n_val, epochs, device,
    )

    for epoch in range(epochs):
        # Train
        adapter.train()
        train_losses = []
        for inp, tgt in train_loader:
            inp, tgt = inp.to(device), tgt.to(device)
            optimizer.zero_grad()
            mu, alpha, beta = adapter(inp)
            loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(adapter.parameters(), max_norm=1.0)
            optimizer.step()
            train_losses.append(loss.item())
        scheduler.step()

        # Validate
        adapter.eval()
        val_losses = []
        with torch.no_grad():
            for inp, tgt in val_loader:
                inp, tgt = inp.to(device), tgt.to(device)
                mu, alpha, beta = adapter(inp)
                loss = l1_loss(mu, tgt) + gg_loss(mu, alpha, beta, tgt)
                val_losses.append(loss.item())

        avg_train = np.mean(train_losses)
        avg_val = np.mean(val_losses) if val_losses else float("inf")

        if (epoch + 1) % 10 == 0 or epoch == 0:
            logger.info(
                " [%s] Epoch %d/%d: train=%.4f, val=%.4f",
                adapter_name, epoch + 1, epochs, avg_train, avg_val,
            )

        # Early stopping
        if avg_val < best_val_loss:
            best_val_loss = avg_val
            patience_counter = 0
            # Snapshot the weights so the best model can be returned even
            # when nothing is written to disk.
            # NOTE(review): assumes ProbabilisticAdapter is an nn.Module
            # (it is already used via .parameters()/.to()/.train()) - confirm.
            best_state = copy.deepcopy(adapter.state_dict())
            if output_path:
                adapter.save(output_path)
                checkpoint_written = True
        else:
            patience_counter += 1
            if patience_counter >= patience:
                logger.info(" [%s] Early stopping at epoch %d", adapter_name, epoch + 1)
                break

    # Load best if saved. Only trust a checkpoint this run produced: a file
    # left over from an earlier run must not be returned as this run's result.
    if output_path and checkpoint_written and Path(output_path).exists():
        adapter = ProbabilisticAdapter.load(output_path)
        adapter = adapter.to(device)
    else:
        if best_state is not None:
            adapter.load_state_dict(best_state)
        adapter = adapter.cpu()
    adapter.eval()

    logger.info(" [%s] Training complete. Best val_loss=%.4f", adapter_name, best_val_loss)
    return adapter
def build_training_pairs_from_index(
    embedding_index_path: str,
    text_embedder_fn,
    modality: str = "image",
) -> Tuple[np.ndarray, np.ndarray]:
    """
    Build (media_embedding, text_embedding) pairs from an embedding index.

    For each media file in the index, a caption is derived from the file
    name: the stem has "_" and "-" replaced by spaces, the leading prefixes
    "fs ", "wm " and "proc " are stripped (case-insensitively), and the
    domain is prepended as "<domain>: " unless it is "other". The caption is
    then embedded with ``text_embedder_fn``.

    Args:
        embedding_index_path: Path to image_index.npz or audio_index.npz.
            File ids are read from "ids" (fallback "paths"), embeddings from
            "embs" (fallback "embeddings"), and domains from the optional
            "domains" array (default "other" for every entry).
        text_embedder_fn: Function that takes text -> np.ndarray embedding.
        modality: "image" for CLIP text, "audio" for CLAP text. Currently
            not used by this function.

    Returns:
        (media_embeddings, text_embeddings) stacked with ``np.array``, one row
        per entry that was embedded successfully. Entries for which embedding
        raises are logged and skipped.
    """
    # NOTE(security): allow_pickle=True unpickles object arrays, which can
    # execute arbitrary code; only load index files from a trusted source.
    # The context manager closes the archive's file handle when done.
    with np.load(embedding_index_path, allow_pickle=True) as data:

        def _first_present(*keys):
            # Return the first stored array among `keys`, else an empty array.
            for key in keys:
                if key in data:
                    return data[key]
            return np.array([])

        ids = _first_present("ids", "paths")
        embs = _first_present("embs", "embeddings")
        if "domains" in data:
            domains = data["domains"]
        else:
            domains = np.array(["other"] * len(ids))

    media_embs = []
    text_embs = []
    for i, (file_id, domain) in enumerate(zip(ids, domains)):
        # Generate caption from filename
        name = Path(str(file_id)).stem
        # Clean up filename to make a caption
        caption = name.replace("_", " ").replace("-", " ")
        # Remove common prefixes
        for prefix in ["fs ", "wm ", "proc "]:
            if caption.lower().startswith(prefix):
                caption = caption[len(prefix):]
        # Add domain context
        if domain != "other":
            caption = f"{domain}: {caption}"
        try:
            text_emb = text_embedder_fn(caption)
            media_embs.append(embs[i])
            text_embs.append(text_emb)
        except Exception as e:
            # Deliberate best-effort: one bad entry must not abort the build.
            logger.warning("Skipping %s: %s", file_id, e)

    return np.array(media_embs), np.array(text_embs)