# MultiModal-Coherence-AI — src/embeddings/cross_space_bridge.py
"""
Cross-Space Bridge: CLIP Image ↔ CLAP Audio Alignment
Problem:
CLIP image embeddings and CLAP audio embeddings live in DIFFERENT 512-d spaces.
Cosine similarity between them is meaningless. This is why si_a = None in
the coherence engine.
Solution:
Train two lightweight projection heads that map:
CLIP image (512-d) β†’ shared bridge space (256-d)
CLAP audio (512-d) β†’ shared bridge space (256-d)
After training, cosine similarity in the bridge space gives a meaningful
image-audio coherence score (si_a).
Architecture:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ CLIP Image │──▢ image_proj ──▢│ β”‚
β”‚ Embedding β”‚ (512β†’256) β”‚ Shared │──▢ cosine_sim(i, a)
β”‚ (512-d) β”‚ β”‚ Bridge β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ Space β”‚
β”‚ (256-d) β”‚
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ β”‚
β”‚ CLAP Audio │──▢ audio_proj ──▢│ β”‚
β”‚ Embedding β”‚ (512β†’256) β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”‚ (512-d) β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
Training:
Uses paired (image, audio) data where both depict the same scene.
InfoNCE contrastive loss pulls matched pairs together, pushes
mismatched pairs apart. Text is NOT involved β€” the bridge operates
purely between image and audio spaces.
Critically, existing CLIP text-image and CLAP text-audio paths are
UNCHANGED. The bridge is additive β€” it enables si_a without
degrading st_i or st_a.
Integration:
Once trained, load the bridge into CoherenceEngine via:
engine.load_bridge("models/bridge/bridge_final.pt")
This enables si_a computation and activates the full MSCI formula:
MSCI = 0.45 * st_i + 0.45 * st_a + 0.10 * si_a
Status: ARCHITECTURE ONLY β€” not trained. Requires paired image-audio data.
"""
from __future__ import annotations
import json
import logging
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
try:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
TORCH_AVAILABLE = True
except ImportError:
TORCH_AVAILABLE = False
logger = logging.getLogger(__name__)
def _check_torch():
    """Fail fast with an install hint when PyTorch is unavailable."""
    if TORCH_AVAILABLE:
        return
    raise ImportError(
        "PyTorch is required for the cross-space bridge. "
        "Install with: pip install torch"
    )
# ═══════════════════════════════════════════════════════════════
# MODEL
# ═══════════════════════════════════════════════════════════════
class BridgeProjectionHead(nn.Module):
    """
    Single projection head for one modality.

    Architecture: Linear(in, hidden) -> GELU -> Dropout -> Linear(hidden, out) -> L2 norm

    GELU is used instead of ReLU, following modern transformer conventions.
    The final L2 normalization puts outputs on the unit hypersphere so that
    cosine similarity is well-defined downstream.
    """

    def __init__(
        self,
        input_dim: int = 512,
        hidden_dim: int = 384,
        output_dim: int = 256,
        dropout: float = 0.1,
    ):
        """
        Args:
            input_dim: Dimensionality of the incoming embedding.
            hidden_dim: Width of the intermediate layer.
            output_dim: Dimensionality of the bridge space.
            dropout: Dropout probability applied after the activation.
        """
        _check_torch()
        super().__init__()
        # No bias on the output layer: L2 normalization follows anyway.
        stack = [
            nn.Linear(input_dim, hidden_dim, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, output_dim, bias=False),
        ]
        self.net = nn.Sequential(*stack)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project *x* and L2-normalize along the last dimension."""
        projected = self.net(x)
        return F.normalize(projected, p=2, dim=-1)
class CrossSpaceBridge(nn.Module):
    """
    Learned bridge between CLIP image space and CLAP audio space.

    Maps both modalities into a shared ``bridge_dim``-d space where cosine
    similarity is meaningful. Does NOT touch text embeddings — CLIP
    text-image and CLAP text-audio paths remain identity (pre-trained
    alignment preserved).
    """

    def __init__(
        self,
        clip_image_dim: int = 512,
        clap_audio_dim: int = 512,
        bridge_dim: int = 256,
        hidden_dim: int = 384,
        dropout: float = 0.1,
    ):
        """
        Args:
            clip_image_dim: Dimensionality of raw CLIP image embeddings.
            clap_audio_dim: Dimensionality of raw CLAP audio embeddings.
            bridge_dim: Dimensionality of the shared bridge space.
            hidden_dim: Hidden width of each projection head.
            dropout: Dropout probability inside each projection head.
        """
        _check_torch()
        super().__init__()
        self.image_proj = BridgeProjectionHead(
            input_dim=clip_image_dim,
            hidden_dim=hidden_dim,
            output_dim=bridge_dim,
            dropout=dropout,
        )
        self.audio_proj = BridgeProjectionHead(
            input_dim=clap_audio_dim,
            hidden_dim=hidden_dim,
            output_dim=bridge_dim,
            dropout=dropout,
        )
        self.bridge_dim = bridge_dim
        # Stored so save()/load() can reconstruct the exact architecture.
        self.config = {
            "clip_image_dim": clip_image_dim,
            "clap_audio_dim": clap_audio_dim,
            "bridge_dim": bridge_dim,
            "hidden_dim": hidden_dim,
            "dropout": dropout,
        }

    def forward(
        self,
        image_emb: Optional[torch.Tensor] = None,
        audio_emb: Optional[torch.Tensor] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Project image and/or audio embeddings into bridge space.

        Args:
            image_emb: CLIP image embeddings [batch, 512]
            audio_emb: CLAP audio embeddings [batch, 512]

        Returns:
            Dict with 'image' and/or 'audio' keys, values are [batch, bridge_dim]
        """
        result = {}
        if image_emb is not None:
            result["image"] = self.image_proj(image_emb)
        if audio_emb is not None:
            result["audio"] = self.audio_proj(audio_emb)
        return result

    def compute_similarity(
        self,
        image_emb: np.ndarray,
        audio_emb: np.ndarray,
    ) -> float:
        """
        Compute image-audio similarity through the bridge.

        This is the main inference method. Takes raw CLIP image and CLAP audio
        embeddings (numpy), projects both into bridge space, returns cosine sim.

        Note: switches the module to eval mode as a side effect.

        Args:
            image_emb: CLIP image embedding, shape (512,)
            audio_emb: CLAP audio embedding, shape (512,)

        Returns:
            Cosine similarity in bridge space (float, range [-1, 1])
        """
        _check_torch()
        self.eval()
        # BUG FIX: create inputs on the same device as the model parameters.
        # Previously the tensors were always built on CPU, which raised a
        # device-mismatch error once the bridge had been moved to CUDA/MPS
        # (as BridgeTrainer does).
        device = next(self.parameters()).device
        with torch.no_grad():
            img_t = torch.as_tensor(image_emb, dtype=torch.float32).unsqueeze(0).to(device)
            aud_t = torch.as_tensor(audio_emb, dtype=torch.float32).unsqueeze(0).to(device)
            projected = self.forward(image_emb=img_t, audio_emb=aud_t)
            sim = F.cosine_similarity(projected["image"], projected["audio"])
        return float(sim.item())

    def save(self, path: Path):
        """Save bridge weights to *path* and config to the sibling ``.json`` file."""
        _check_torch()
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        torch.save(self.state_dict(), path)
        config_path = path.with_suffix(".json")
        with config_path.open("w") as f:
            json.dump(self.config, f, indent=2)
        logger.info("Saved bridge to %s", path)

    @classmethod
    def load(cls, path: Path) -> "CrossSpaceBridge":
        """Load a bridge from weights saved by :meth:`save` (returned in eval mode)."""
        _check_torch()
        path = Path(path)
        config_path = path.with_suffix(".json")
        with config_path.open("r") as f:
            config = json.load(f)
        model = cls(**config)
        # weights_only=True: refuse arbitrary pickled objects in the checkpoint.
        state_dict = torch.load(path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        model.eval()
        logger.info("Loaded bridge from %s", path)
        return model
# ═══════════════════════════════════════════════════════════════
# TRAINING COMPONENTS
# ═══════════════════════════════════════════════════════════════
class ImageAudioPairDataset(Dataset):
    """
    Dataset of paired (CLIP image, CLAP audio) embeddings.

    Each pair represents the same scene — e.g., a beach photo paired with
    ocean wave audio. Text is not needed for bridge training.

    Data sources for future training:
    - RQ1/RQ2 baseline runs: each has a matched (image, audio) pair
    - Manual curation: pair images from data/wikimedia/images with
      audio from data/freesound/audio by domain
    - External datasets: AudioCaps + MSCOCO overlap, VGGSound, etc.
    """

    def __init__(
        self,
        image_embeddings: np.ndarray,
        audio_embeddings: np.ndarray,
    ):
        """
        Args:
            image_embeddings: CLIP image embeddings [N, 512]
            audio_embeddings: CLAP audio embeddings [N, 512]
        """
        _check_torch()
        # Pairing is positional, so both arrays must have the same length.
        assert len(image_embeddings) == len(audio_embeddings), (
            f"Mismatched pair count: {len(image_embeddings)} images, "
            f"{len(audio_embeddings)} audio"
        )
        self.images = torch.tensor(image_embeddings, dtype=torch.float32)
        self.audio = torch.tensor(audio_embeddings, dtype=torch.float32)

    def __len__(self) -> int:
        """Number of (image, audio) pairs."""
        return self.images.shape[0]

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        """Return the idx-th pair as ``{"image": ..., "audio": ...}``."""
        item = {"image": self.images[idx], "audio": self.audio[idx]}
        return item
class BridgeInfoNCELoss(nn.Module):
    """
    Symmetric InfoNCE loss for image-audio bridge training.

    For a batch of N paired (image, audio) embeddings, each image should be
    most similar to its paired audio (and vice versa); every other item in
    the batch is treated as a negative.

        Loss = 0.5 * (image->audio NCE + audio->image NCE)

    This is the same loss structure used by CLIP and CLAP themselves,
    applied here to bridge their output spaces.
    """

    def __init__(self, temperature: float = 0.07):
        """
        Args:
            temperature: Initial softmax temperature (learnable, as in CLIP).
        """
        _check_torch()
        super().__init__()
        # Parameterized as log(1/T) so the optimizer works in an
        # unconstrained space; `temperature` inverts this mapping.
        init_value = np.log(1.0 / temperature)
        self.log_temperature = nn.Parameter(torch.tensor(init_value))

    @property
    def temperature(self) -> torch.Tensor:
        """Current temperature T, recovered from the log(1/T) parameter."""
        return torch.exp(-self.log_temperature)

    def forward(
        self,
        image_emb: torch.Tensor,
        audio_emb: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, float]]:
        """
        Compute symmetric InfoNCE loss.

        Args:
            image_emb: Projected image embeddings [batch, bridge_dim], L2-normalized
            audio_emb: Projected audio embeddings [batch, bridge_dim], L2-normalized

        Returns:
            (loss, metrics_dict)
        """
        n = image_emb.size(0)
        # Pairwise similarity matrix [n, n], sharpened by the temperature.
        logits = (image_emb @ audio_emb.t()) / self.temperature
        # Matched pairs sit on the diagonal.
        targets = torch.arange(n, device=logits.device)
        # Cross-entropy in both retrieval directions, averaged.
        i2a = F.cross_entropy(logits, targets)
        a2i = F.cross_entropy(logits.t(), targets)
        total = (i2a + a2i) * 0.5
        # Retrieval accuracy within the batch (diagnostic only).
        with torch.no_grad():
            hit_i2a = logits.argmax(dim=1).eq(targets).float().mean()
            hit_a2i = logits.argmax(dim=0).eq(targets).float().mean()
            stats = {
                "loss": total.item(),
                "loss_i2a": i2a.item(),
                "loss_a2i": a2i.item(),
                "acc_i2a": hit_i2a.item(),
                "acc_a2i": hit_a2i.item(),
                "temperature": self.temperature.item(),
            }
        return total, stats
class BridgeTrainer:
    """
    Training loop for the cross-space bridge.

    Minimal, focused trainer:
    - AdamW optimizer with cosine LR schedule
    - Symmetric InfoNCE loss with learnable temperature
    - Early stopping on validation loss
    - Checkpoint saving

    Usage (future, when paired data is available):
        bridge = CrossSpaceBridge()
        trainer = BridgeTrainer(bridge)
        dataset = ImageAudioPairDataset(image_embs, audio_embs)
        trainer.train(dataset)
    """

    def __init__(
        self,
        model: CrossSpaceBridge,
        lr: float = 3e-4,
        weight_decay: float = 1e-4,
        batch_size: int = 32,
        n_epochs: int = 50,
        patience: int = 10,
        output_dir: str = "models/bridge",
    ):
        """
        Args:
            model: Bridge to train; moved to the best available device.
            lr: AdamW learning rate.
            weight_decay: AdamW weight decay.
            batch_size: Batch size for both training and validation.
            n_epochs: Maximum number of epochs.
            patience: Epochs without val-loss improvement before early stop.
            output_dir: Directory for checkpoints and training history.
        """
        _check_torch()
        # Device preference: CUDA, then Apple MPS, then CPU.
        self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
        self.model = model.to(self.device)
        self.batch_size = batch_size
        self.n_epochs = n_epochs
        self.patience = patience
        self.output_dir = Path(output_dir)
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.loss_fn = BridgeInfoNCELoss().to(self.device)
        # Optimize both projection heads AND the loss's learnable temperature.
        self.optimizer = torch.optim.AdamW(
            list(model.parameters()) + list(self.loss_fn.parameters()),
            lr=lr,
            weight_decay=weight_decay,
        )
        # Per-epoch metric dicts, persisted by _save_history().
        self.history: List[Dict] = []

    def train(
        self,
        train_data: ImageAudioPairDataset,
        val_data: Optional[ImageAudioPairDataset] = None,
        val_split: float = 0.15,
    ) -> CrossSpaceBridge:
        """
        Train the bridge.

        Args:
            train_data: Paired image-audio embeddings
            val_data: Optional separate validation set. If None, splits from train_data.
            val_split: Fraction to hold out for validation if val_data is None.

        Returns:
            Trained CrossSpaceBridge model (best-val checkpoint if validation ran).

        Raises:
            ValueError: If the training set is smaller than one batch.
        """
        # Split if no val_data provided (fixed seed for reproducibility).
        if val_data is None and val_split > 0:
            n_val = max(1, int(len(train_data) * val_split))
            n_train = len(train_data) - n_val
            train_data, val_data = torch.utils.data.random_split(
                train_data, [n_train, n_val],
                generator=torch.Generator().manual_seed(42),
            )
        train_loader = DataLoader(
            train_data, batch_size=self.batch_size, shuffle=True, drop_last=True,
        )
        # BUG FIX: with drop_last=True a training set smaller than batch_size
        # yields zero batches, and the epoch-average step below would crash
        # with IndexError on epoch_metrics[0]. Fail early and clearly instead.
        if len(train_loader) == 0:
            raise ValueError(
                f"Training set has {len(train_data)} pairs but batch_size is "
                f"{self.batch_size}; reduce batch_size or provide more pairs."
            )
        val_loader = DataLoader(
            val_data, batch_size=self.batch_size, shuffle=False,
        ) if val_data is not None else None
        # Cosine LR schedule over the full epoch budget.
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=self.n_epochs,
        )
        best_val_loss = float("inf")
        patience_counter = 0
        logger.info(
            "Training bridge: %d train, %d val, %d epochs, batch=%d, device=%s",
            len(train_data), len(val_data) if val_data else 0,
            self.n_epochs, self.batch_size, self.device,
        )
        for epoch in range(self.n_epochs):
            # ── Train ────────────────────────────────
            self.model.train()
            self.loss_fn.train()
            epoch_metrics = []
            for batch in train_loader:
                img = batch["image"].to(self.device)
                aud = batch["audio"].to(self.device)
                self.optimizer.zero_grad()
                projected = self.model(image_emb=img, audio_emb=aud)
                loss, metrics = self.loss_fn(projected["image"], projected["audio"])
                loss.backward()
                # Clip gradients for stability on small, noisy batches.
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                epoch_metrics.append(metrics)
            scheduler.step()
            # Epoch averages across all batches (guaranteed non-empty above).
            avg = {k: np.mean([m[k] for m in epoch_metrics]) for k in epoch_metrics[0]}
            # ── Validate ─────────────────────────────
            val_loss = None
            if val_loader:
                val_loss = self._validate(val_loader)
                avg["val_loss"] = val_loss
            avg["epoch"] = epoch + 1
            avg["lr"] = scheduler.get_last_lr()[0]
            self.history.append(avg)
            logger.info(
                "Epoch %d/%d: loss=%.4f acc_i2a=%.3f acc_a2i=%.3f temp=%.3f%s",
                epoch + 1, self.n_epochs, avg["loss"],
                avg["acc_i2a"], avg["acc_a2i"], avg["temperature"],
                f" val_loss={val_loss:.4f}" if val_loss is not None else "",
            )
            # ── Early stopping ───────────────────────
            if val_loss is not None:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    patience_counter = 0
                    self.model.save(self.output_dir / "bridge_best.pt")
                else:
                    patience_counter += 1
                    if patience_counter >= self.patience:
                        logger.info("Early stopping at epoch %d", epoch + 1)
                        break
        # Save final weights + history regardless of early stopping.
        self.model.save(self.output_dir / "bridge_final.pt")
        self._save_history()
        # Restore the best checkpoint if validation was used.
        if val_loader and (self.output_dir / "bridge_best.pt").exists():
            self.model = CrossSpaceBridge.load(self.output_dir / "bridge_best.pt")
            logger.info("Loaded best checkpoint (val_loss=%.4f)", best_val_loss)
        return self.model

    def _validate(self, val_loader: DataLoader) -> float:
        """Return mean InfoNCE loss over the validation loader (no grads)."""
        self.model.eval()
        self.loss_fn.eval()
        losses = []
        with torch.no_grad():
            for batch in val_loader:
                img = batch["image"].to(self.device)
                aud = batch["audio"].to(self.device)
                projected = self.model(image_emb=img, audio_emb=aud)
                loss, _ = self.loss_fn(projected["image"], projected["audio"])
                losses.append(loss.item())
        return float(np.mean(losses))

    def _save_history(self):
        """Persist per-epoch metrics as JSON next to the checkpoints."""
        path = self.output_dir / "bridge_training_history.json"
        with path.open("w") as f:
            json.dump(self.history, f, indent=2)
# ═══════════════════════════════════════════════════════════════
# COHERENCE ENGINE INTEGRATION (future use)
# ═══════════════════════════════════════════════════════════════
def build_domain_matched_dataset(
    image_index_path: str = "data/embeddings/image_index.npz",
    audio_index_path: str = "data/embeddings/audio_index.npz",
) -> ImageAudioPairDataset:
    """
    Build paired training data via weak domain supervision.

    Pairs every image with every same-domain audio file:
        nature images x nature audio -> nature pairs
        urban  images x urban  audio -> urban pairs
        water  images x water  audio -> water pairs

    Expected yield (~1,000+ pairs):
        5 nature imgs x 22 nature audio = 110
        23 urban imgs x 28 urban audio  = 644
        9 water imgs  x 33 water audio  = 297
        Total: ~1,051 pairs

    These are "weakly supervised" — same domain, not exact scene matches.
    Sufficient for learning a rough bridge alignment.

    Args:
        image_index_path: Path to image embedding index (.npz)
        audio_index_path: Path to audio embedding index (.npz)

    Returns:
        ImageAudioPairDataset ready for bridge training

    Raises:
        ValueError: If no same-domain pairs can be formed.
    """
    _check_torch()
    img_index = np.load(image_index_path, allow_pickle=True)
    aud_index = np.load(audio_index_path, allow_pickle=True)
    img_paths = list(img_index["paths"])
    img_embeddings = img_index["embeddings"]
    # FIX: NpzFile.get() only exists on NumPy versions where NpzFile is a
    # Mapping; membership against the always-present `.files` list is
    # version-safe.
    img_domains = list(img_index["domains"]) if "domains" in img_index.files else []
    aud_paths = list(aud_index["paths"])
    aud_embeddings = aud_index["embeddings"]
    aud_domains = list(aud_index["domains"]) if "domains" in aud_index.files else []

    def _infer_domain(path_str: str) -> str:
        """Guess a domain label from a file path; 'other' when no match."""
        # FIX: normalize Windows separators so "/nature/" also matches
        # paths written with backslashes.
        p = str(path_str).lower().replace("\\", "/")
        for domain in ("nature", "urban", "water"):
            if f"/{domain}/" in p or f"_{domain}" in p:
                return domain
        return "other"

    # Fall back to path inference when domains are absent or inconsistent.
    if not img_domains or len(img_domains) != len(img_paths):
        img_domains = [_infer_domain(p) for p in img_paths]
    if not aud_domains or len(aud_domains) != len(aud_paths):
        aud_domains = [_infer_domain(p) for p in aud_paths]

    # Build cross-product pairs within each domain.
    image_embs_list: List[np.ndarray] = []
    audio_embs_list: List[np.ndarray] = []
    pair_counts: Dict[str, int] = {}
    for domain in ("nature", "urban", "water"):
        img_indices = [i for i, d in enumerate(img_domains) if d == domain]
        aud_indices = [i for i, d in enumerate(aud_domains) if d == domain]
        pair_counts[domain] = len(img_indices) * len(aud_indices)
        for ii in img_indices:
            for ai in aud_indices:
                image_embs_list.append(img_embeddings[ii])
                audio_embs_list.append(aud_embeddings[ai])
    total = sum(pair_counts.values())
    logger.info(
        "Domain-matched pairs: %s = %d total",
        ", ".join(f"{d}={n}" for d, n in pair_counts.items()),
        total,
    )
    if total == 0:
        raise ValueError("No domain-matched pairs found. Check embedding indexes and domain labels.")
    return ImageAudioPairDataset(
        image_embeddings=np.array(image_embs_list),
        audio_embeddings=np.array(audio_embs_list),
    )
def build_paired_dataset_from_runs(
    results_json: str,
    embedder=None,
) -> ImageAudioPairDataset:
    """
    Build paired training data from existing experiment runs.

    Each baseline run in RQ1/RQ2 has a matched (image_path, audio_path) pair
    for the same prompt. These pairs can be used to train the bridge.

    Args:
        results_json: Path to rq1_results.json or rq2_results.json
        embedder: AlignedEmbedder instance (for re-embedding if needed)

    Returns:
        ImageAudioPairDataset ready for training

    Raises:
        ValueError: If no pair could be embedded successfully.

    Example:
        embedder = AlignedEmbedder()
        dataset = build_paired_dataset_from_runs(
            "runs/rq1/rq1_results.json",
            embedder=embedder,
        )
        bridge = CrossSpaceBridge()
        trainer = BridgeTrainer(bridge)
        trainer.train(dataset)
    """
    _check_torch()
    # Uses the module-level `json` import (a redundant local
    # `import json as _json` was removed).
    with open(results_json) as f:
        data = json.load(f)
    # Collect unique (image_path, audio_path) pairs from baseline runs.
    pairs: Dict[str, Tuple[str, str]] = {}
    for r in data["results"]:
        # Only matched ("baseline"/"direct") conditions provide true pairs.
        condition = r.get("condition", r.get("mode", ""))
        if condition not in ("baseline", "direct"):
            continue
        img = r.get("image_path")
        aud = r.get("audio_path")
        if img and aud:
            key = f"{img}||{aud}"
            if key not in pairs:
                pairs[key] = (img, aud)
    logger.info("Found %d unique image-audio pairs from %s", len(pairs), results_json)
    if embedder is None:
        # Lazy import: keeps this module importable without the embedder stack.
        from src.embeddings.aligned_embeddings import AlignedEmbedder
        embedder = AlignedEmbedder()
    image_embs: List[np.ndarray] = []
    audio_embs: List[np.ndarray] = []
    for img_path, aud_path in pairs.values():
        try:
            # Embed both BEFORE appending so the two lists stay index-aligned
            # even when one of the embeddings fails.
            img_emb = embedder.embed_image(img_path)
            aud_emb = embedder.embed_audio(aud_path)
        except Exception as e:
            logger.warning("Skipping pair %s / %s: %s", img_path, aud_path, e)
            continue
        image_embs.append(img_emb)
        audio_embs.append(aud_emb)
    logger.info("Successfully embedded %d pairs", len(image_embs))
    if not image_embs:
        # FIX: fail loudly instead of returning an empty dataset that would
        # only break later inside BridgeTrainer (consistent with
        # build_domain_matched_dataset's empty-result error).
        raise ValueError(f"No usable image-audio pairs could be embedded from {results_json}")
    return ImageAudioPairDataset(
        image_embeddings=np.array(image_embs),
        audio_embeddings=np.array(audio_embs),
    )