| """ |
Cross-Space Bridge: CLIP Image ↔ CLAP Audio Alignment
| |
| Problem: |
| CLIP image embeddings and CLAP audio embeddings live in DIFFERENT 512-d spaces. |
| Cosine similarity between them is meaningless. This is why si_a = None in |
| the coherence engine. |
| |
| Solution: |
| Train two lightweight projection heads that map: |
    CLIP image (512-d) → shared bridge space (256-d)
    CLAP audio (512-d) → shared bridge space (256-d)
| |
| After training, cosine similarity in the bridge space gives a meaningful |
| image-audio coherence score (si_a). |
| |
| Architecture: |
┌──────────────┐                       ┌──────────────┐
│  CLIP Image  │──▶ image_proj ───────▶│              │
│  Embedding   │     (512 → 256)       │    Shared    │───▶ cosine_sim(i, a)
│   (512-d)    │                       │    Bridge    │
└──────────────┘                       │    Space     │
                                       │   (256-d)    │
┌──────────────┐                       │              │
│  CLAP Audio  │──▶ audio_proj ───────▶│              │
│  Embedding   │     (512 → 256)       └──────────────┘
│   (512-d)    │
└──────────────┘
| |
| Training: |
| Uses paired (image, audio) data where both depict the same scene. |
| InfoNCE contrastive loss pulls matched pairs together, pushes |
mismatched pairs apart. Text is NOT involved; the bridge operates
| purely between image and audio spaces. |
| |
| Critically, existing CLIP text-image and CLAP text-audio paths are |
UNCHANGED. The bridge is additive: it enables si_a without
| degrading st_i or st_a. |
| |
| Integration: |
| Once trained, load the bridge into CoherenceEngine via: |
| engine.load_bridge("models/bridge/bridge_final.pt") |
| This enables si_a computation and activates the full MSCI formula: |
| MSCI = 0.45 * st_i + 0.45 * st_a + 0.10 * si_a |
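
Quick-start (sketch; untested, uses only the APIs defined in this module,
with placeholder variables for the raw 512-d CLIP image / CLAP audio
embeddings):
    dataset = build_domain_matched_dataset()       # weak domain-matched pairs
    bridge = CrossSpaceBridge()
    trainer = BridgeTrainer(bridge, output_dir="models/bridge")
    bridge = trainer.train(dataset)                # writes models/bridge/bridge_final.pt
    si_a = bridge.compute_similarity(clip_image_emb, clap_audio_emb)
    engine.load_bridge("models/bridge/bridge_final.pt")   # CoherenceEngine integration, as above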
| |
Status: ARCHITECTURE ONLY (not trained). Requires paired image-audio data.
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| import logging |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple |
|
|
| import numpy as np |
|
|
| try: |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.utils.data import DataLoader, Dataset |
|
|
| TORCH_AVAILABLE = True |
| except ImportError: |
| TORCH_AVAILABLE = False |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
| def _check_torch(): |
| if not TORCH_AVAILABLE: |
| raise ImportError( |
| "PyTorch is required for the cross-space bridge. " |
| "Install with: pip install torch" |
| ) |
|
|
|
|
| |
| |
| |
|
|
|
|
| class BridgeProjectionHead(nn.Module): |
| """ |
| Single projection head for one modality. |
| |
Architecture: Linear(in, hidden) → GELU → Dropout → Linear(hidden, out) → L2 norm
| |
| Uses GELU instead of ReLU following modern transformer conventions. |
| L2 normalization ensures cosine similarity operates on the unit hypersphere. |
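
Shape sketch (illustrative only):
    head = BridgeProjectionHead(input_dim=512, hidden_dim=384, output_dim=256)
    z = head(torch.randn(8, 512))   # z.shape == (8, 256); rows have unit L2 norm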
| """ |
|
|
| def __init__( |
| self, |
| input_dim: int = 512, |
| hidden_dim: int = 384, |
| output_dim: int = 256, |
| dropout: float = 0.1, |
| ): |
| _check_torch() |
| super().__init__() |
| self.net = nn.Sequential( |
| nn.Linear(input_dim, hidden_dim, bias=True), |
| nn.GELU(), |
| nn.Dropout(dropout), |
| nn.Linear(hidden_dim, output_dim, bias=False), |
| ) |
|
|
| def forward(self, x: torch.Tensor) -> torch.Tensor: |
| return F.normalize(self.net(x), p=2, dim=-1) |
|
|
|
|
| class CrossSpaceBridge(nn.Module): |
| """ |
| Learned bridge between CLIP image space and CLAP audio space. |
| |
| Maps both to a shared 256-d space where cosine similarity is meaningful. |
Does NOT touch text embeddings; CLIP text-image and CLAP text-audio
| paths remain identity (pre-trained alignment preserved). |
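
Forward sketch (illustrative shapes only):
    bridge = CrossSpaceBridge()
    out = bridge(image_emb=torch.randn(4, 512), audio_emb=torch.randn(4, 512))
    # out["image"].shape == out["audio"].shape == (4, 256)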
| """ |
|
|
| def __init__( |
| self, |
| clip_image_dim: int = 512, |
| clap_audio_dim: int = 512, |
| bridge_dim: int = 256, |
| hidden_dim: int = 384, |
| dropout: float = 0.1, |
| ): |
| _check_torch() |
| super().__init__() |
|
|
| self.image_proj = BridgeProjectionHead( |
| input_dim=clip_image_dim, |
| hidden_dim=hidden_dim, |
| output_dim=bridge_dim, |
| dropout=dropout, |
| ) |
| self.audio_proj = BridgeProjectionHead( |
| input_dim=clap_audio_dim, |
| hidden_dim=hidden_dim, |
| output_dim=bridge_dim, |
| dropout=dropout, |
| ) |
|
|
| self.bridge_dim = bridge_dim |
| self.config = { |
| "clip_image_dim": clip_image_dim, |
| "clap_audio_dim": clap_audio_dim, |
| "bridge_dim": bridge_dim, |
| "hidden_dim": hidden_dim, |
| "dropout": dropout, |
| } |
|
|
| def forward( |
| self, |
| image_emb: Optional[torch.Tensor] = None, |
| audio_emb: Optional[torch.Tensor] = None, |
| ) -> Dict[str, torch.Tensor]: |
| """ |
| Project image and/or audio embeddings into bridge space. |
| |
| Args: |
| image_emb: CLIP image embeddings [batch, 512] |
| audio_emb: CLAP audio embeddings [batch, 512] |
| |
| Returns: |
| Dict with 'image' and/or 'audio' keys, values are [batch, bridge_dim] |
| """ |
| result = {} |
| if image_emb is not None: |
| result["image"] = self.image_proj(image_emb) |
| if audio_emb is not None: |
| result["audio"] = self.audio_proj(audio_emb) |
| return result |
|
|
| def compute_similarity( |
| self, |
| image_emb: np.ndarray, |
| audio_emb: np.ndarray, |
| ) -> float: |
| """ |
| Compute image-audio similarity through the bridge. |
| |
| This is the main inference method. Takes raw CLIP image and CLAP audio |
| embeddings (numpy), projects both into bridge space, returns cosine sim. |
| |
| Args: |
| image_emb: CLIP image embedding, shape (512,) |
| audio_emb: CLAP audio embedding, shape (512,) |
| |
| Returns: |
| Cosine similarity in bridge space (float, range [-1, 1]) |
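
Example (sketch; assumes trained weights exist at the path from the module
docstring, and random arrays stand in for real embeddings):
    bridge = CrossSpaceBridge.load("models/bridge/bridge_final.pt")
    img = np.random.randn(512).astype(np.float32)   # stand-in for a CLIP image embedding
    aud = np.random.randn(512).astype(np.float32)   # stand-in for a CLAP audio embedding
    si_a = bridge.compute_similarity(img, aud)      # float in [-1, 1]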
| """ |
| _check_torch() |
| self.eval() |
| with torch.no_grad(): |
| img_t = torch.tensor(image_emb, dtype=torch.float32).unsqueeze(0) |
| aud_t = torch.tensor(audio_emb, dtype=torch.float32).unsqueeze(0) |
| projected = self.forward(image_emb=img_t, audio_emb=aud_t) |
| sim = F.cosine_similarity(projected["image"], projected["audio"]) |
| return float(sim.item()) |
|
|
| def save(self, path: Path): |
| """Save bridge weights + config.""" |
| _check_torch() |
| path = Path(path) |
| path.parent.mkdir(parents=True, exist_ok=True) |
| torch.save(self.state_dict(), path) |
| config_path = path.with_suffix(".json") |
| with config_path.open("w") as f: |
| json.dump(self.config, f, indent=2) |
| logger.info("Saved bridge to %s", path) |
|
|
| @classmethod |
| def load(cls, path: Path) -> "CrossSpaceBridge": |
| """Load bridge from saved weights.""" |
| _check_torch() |
| path = Path(path) |
| config_path = path.with_suffix(".json") |
| with config_path.open("r") as f: |
| config = json.load(f) |
| model = cls(**config) |
| state_dict = torch.load(path, map_location="cpu", weights_only=True) |
| model.load_state_dict(state_dict) |
| model.eval() |
| logger.info("Loaded bridge from %s", path) |
| return model |
|
|
|
|
| |
| |
| |
|
|
|
|
| class ImageAudioPairDataset(Dataset): |
| """ |
| Dataset of paired (CLIP image, CLAP audio) embeddings. |
| |
Each pair represents the same scene, e.g. a beach photo paired
| with ocean wave audio. Text is not needed for bridge training. |
| |
| Data sources for future training: |
| - RQ1/RQ2 baseline runs: each has a matched (image, audio) pair |
| - Manual curation: pair images from data/wikimedia/images with |
| audio from data/freesound/audio by domain |
| - External datasets: AudioCaps + MSCOCO overlap, VGGSound, etc. |
| """ |
|
|
| def __init__( |
| self, |
| image_embeddings: np.ndarray, |
| audio_embeddings: np.ndarray, |
| ): |
| """ |
| Args: |
| image_embeddings: CLIP image embeddings [N, 512] |
| audio_embeddings: CLAP audio embeddings [N, 512] |
| """ |
| _check_torch() |
| assert len(image_embeddings) == len(audio_embeddings), ( |
| f"Mismatched pair count: {len(image_embeddings)} images, " |
| f"{len(audio_embeddings)} audio" |
| ) |
| self.images = torch.tensor(image_embeddings, dtype=torch.float32) |
| self.audio = torch.tensor(audio_embeddings, dtype=torch.float32) |
|
|
| def __len__(self) -> int: |
| return len(self.images) |
|
|
| def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]: |
| return {"image": self.images[idx], "audio": self.audio[idx]} |
|
|
|
|
| class BridgeInfoNCELoss(nn.Module): |
| """ |
| Symmetric InfoNCE loss for image-audio bridge training. |
| |
| For a batch of N paired (image, audio) embeddings: |
| - Each image should be most similar to its paired audio (and vice versa) |
| - All other items in the batch are treated as negatives |
| |
Loss = 0.5 * (image→audio NCE + audio→image NCE)
| |
| This is the same loss structure used by CLIP and CLAP themselves, |
| applied here to bridge their output spaces. |
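
Concretely, for L2-normalized bridge embeddings z_i (image) and a_j (audio)
with temperature tau, the image→audio term for pair i is
    L_i2a(i) = -log( exp(z_i · a_i / tau) / sum_j exp(z_i · a_j / tau) )
and the audio→image term swaps the roles of rows and columns; the final
loss is the mean of the two.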
| """ |
|
|
| def __init__(self, temperature: float = 0.07): |
| _check_torch() |
| super().__init__() |
| |
        # Parameterized as log(1/temperature), i.e. the log logit scale used by
        # CLIP; the `temperature` property recovers tau = exp(-log_temperature).
        self.log_temperature = nn.Parameter(
            torch.tensor(float(np.log(1.0 / temperature)), dtype=torch.float32)
        )
|
|
| @property |
| def temperature(self) -> torch.Tensor: |
| return torch.exp(-self.log_temperature) |
|
|
| def forward( |
| self, |
| image_emb: torch.Tensor, |
| audio_emb: torch.Tensor, |
| ) -> Tuple[torch.Tensor, Dict[str, float]]: |
| """ |
| Compute symmetric InfoNCE loss. |
| |
| Args: |
| image_emb: Projected image embeddings [batch, bridge_dim], L2-normalized |
| audio_emb: Projected audio embeddings [batch, bridge_dim], L2-normalized |
| |
| Returns: |
| (loss, metrics_dict) |
| """ |
| batch_size = image_emb.size(0) |
|
|
| |
| logits = torch.mm(image_emb, audio_emb.t()) / self.temperature |
|
|
| |
| labels = torch.arange(batch_size, device=logits.device) |
|
|
| |
| loss_i2a = F.cross_entropy(logits, labels) |
| loss_a2i = F.cross_entropy(logits.t(), labels) |
| loss = 0.5 * (loss_i2a + loss_a2i) |
|
|
| |
| with torch.no_grad(): |
| acc_i2a = (logits.argmax(dim=1) == labels).float().mean() |
| acc_a2i = (logits.t().argmax(dim=1) == labels).float().mean() |
|
|
| metrics = { |
| "loss": loss.item(), |
| "loss_i2a": loss_i2a.item(), |
| "loss_a2i": loss_a2i.item(), |
| "acc_i2a": acc_i2a.item(), |
| "acc_a2i": acc_a2i.item(), |
| "temperature": self.temperature.item(), |
| } |
|
|
| return loss, metrics |
|
|
|
|
| class BridgeTrainer: |
| """ |
| Training loop for the cross-space bridge. |
| |
| Minimal, focused trainer: |
| - AdamW optimizer with cosine LR schedule |
| - Symmetric InfoNCE loss with learnable temperature |
| - Early stopping on validation loss |
| - Checkpoint saving |
| |
| Usage (future, when paired data is available): |
| bridge = CrossSpaceBridge() |
| trainer = BridgeTrainer(bridge) |
| dataset = ImageAudioPairDataset(image_embs, audio_embs) |
| trainer.train(dataset) |
| """ |
|
|
| def __init__( |
| self, |
| model: CrossSpaceBridge, |
| lr: float = 3e-4, |
| weight_decay: float = 1e-4, |
| batch_size: int = 32, |
| n_epochs: int = 50, |
| patience: int = 10, |
| output_dir: str = "models/bridge", |
| ): |
| _check_torch() |
        if torch.cuda.is_available():
            self.device = "cuda"
        elif torch.backends.mps.is_available():
            self.device = "mps"
        else:
            self.device = "cpu"
| self.model = model.to(self.device) |
| self.batch_size = batch_size |
| self.n_epochs = n_epochs |
| self.patience = patience |
| self.output_dir = Path(output_dir) |
| self.output_dir.mkdir(parents=True, exist_ok=True) |
|
|
| self.loss_fn = BridgeInfoNCELoss().to(self.device) |
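        # The learnable temperature lives inside the loss module, so its
        # parameter is optimized jointly with the two projection heads below.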
|
|
| |
| self.optimizer = torch.optim.AdamW( |
| list(model.parameters()) + list(self.loss_fn.parameters()), |
| lr=lr, |
| weight_decay=weight_decay, |
| ) |
|
|
| self.history = [] |
|
|
| def train( |
| self, |
| train_data: ImageAudioPairDataset, |
| val_data: Optional[ImageAudioPairDataset] = None, |
| val_split: float = 0.15, |
| ) -> CrossSpaceBridge: |
| """ |
| Train the bridge. |
| |
| Args: |
| train_data: Paired image-audio embeddings |
| val_data: Optional separate validation set. If None, splits from train_data. |
| val_split: Fraction to hold out for validation if val_data is None. |
| |
| Returns: |
| Trained CrossSpaceBridge model |
| """ |
| |
| if val_data is None and val_split > 0: |
| n_val = max(1, int(len(train_data) * val_split)) |
| n_train = len(train_data) - n_val |
| train_data, val_data = torch.utils.data.random_split( |
| train_data, [n_train, n_val], |
| generator=torch.Generator().manual_seed(42), |
| ) |
|
|
| train_loader = DataLoader( |
| train_data, batch_size=self.batch_size, shuffle=True, drop_last=True, |
| ) |
| val_loader = DataLoader( |
| val_data, batch_size=self.batch_size, shuffle=False, |
| ) if val_data is not None else None |
|
|
| |
| scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( |
| self.optimizer, T_max=self.n_epochs, |
| ) |
|
|
| best_val_loss = float("inf") |
| patience_counter = 0 |
|
|
| logger.info( |
| "Training bridge: %d train, %d val, %d epochs, batch=%d, device=%s", |
| len(train_data), len(val_data) if val_data else 0, |
| self.n_epochs, self.batch_size, self.device, |
| ) |
|
|
| for epoch in range(self.n_epochs): |
| |
| self.model.train() |
| self.loss_fn.train() |
| epoch_metrics = [] |
|
|
| for batch in train_loader: |
| img = batch["image"].to(self.device) |
| aud = batch["audio"].to(self.device) |
|
|
| self.optimizer.zero_grad() |
| projected = self.model(image_emb=img, audio_emb=aud) |
| loss, metrics = self.loss_fn(projected["image"], projected["audio"]) |
| loss.backward() |
| torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0) |
| self.optimizer.step() |
|
|
| epoch_metrics.append(metrics) |
|
|
| scheduler.step() |
|
|
| |
            # Cast to plain floats so the history stays JSON-serializable.
            avg = {k: float(np.mean([m[k] for m in epoch_metrics])) for k in epoch_metrics[0]}
|
|
| |
| val_loss = None |
| if val_loader: |
| val_loss = self._validate(val_loader) |
| avg["val_loss"] = val_loss |
|
|
| avg["epoch"] = epoch + 1 |
| avg["lr"] = scheduler.get_last_lr()[0] |
| self.history.append(avg) |
|
|
| logger.info( |
| "Epoch %d/%d: loss=%.4f acc_i2a=%.3f acc_a2i=%.3f temp=%.3f%s", |
| epoch + 1, self.n_epochs, avg["loss"], |
| avg["acc_i2a"], avg["acc_a2i"], avg["temperature"], |
| f" val_loss={val_loss:.4f}" if val_loss is not None else "", |
| ) |
|
|
| |
| if val_loss is not None: |
| if val_loss < best_val_loss: |
| best_val_loss = val_loss |
| patience_counter = 0 |
| self.model.save(self.output_dir / "bridge_best.pt") |
| else: |
| patience_counter += 1 |
| if patience_counter >= self.patience: |
| logger.info("Early stopping at epoch %d", epoch + 1) |
| break |
|
|
| |
| self.model.save(self.output_dir / "bridge_final.pt") |
| self._save_history() |
|
|
| |
| if val_loader and (self.output_dir / "bridge_best.pt").exists(): |
| self.model = CrossSpaceBridge.load(self.output_dir / "bridge_best.pt") |
| logger.info("Loaded best checkpoint (val_loss=%.4f)", best_val_loss) |
|
|
| return self.model |
|
|
| def _validate(self, val_loader: DataLoader) -> float: |
| self.model.eval() |
| self.loss_fn.eval() |
| losses = [] |
| with torch.no_grad(): |
| for batch in val_loader: |
| img = batch["image"].to(self.device) |
| aud = batch["audio"].to(self.device) |
| projected = self.model(image_emb=img, audio_emb=aud) |
| loss, _ = self.loss_fn(projected["image"], projected["audio"]) |
| losses.append(loss.item()) |
| return float(np.mean(losses)) |
|
|
| def _save_history(self): |
| path = self.output_dir / "bridge_training_history.json" |
| with path.open("w") as f: |
| json.dump(self.history, f, indent=2) |
|
|
|
|
| |
| |
| |
|
|
|
|
| def build_domain_matched_dataset( |
| image_index_path: str = "data/embeddings/image_index.npz", |
| audio_index_path: str = "data/embeddings/audio_index.npz", |
| ) -> ImageAudioPairDataset: |
| """ |
| Build paired training data via weak domain supervision. |
| |
| Pairs every image with every same-domain audio file: |
    nature images × nature audio → nature pairs
    urban images  × urban audio  → urban pairs
    water images  × water audio  → water pairs
| |
| Expected yield (~1,000+ pairs): |
    5 nature imgs  × 22 nature audio = 110
    23 urban imgs  × 28 urban audio  = 644
    9 water imgs   × 33 water audio  = 297
| Total: ~1,051 pairs |
| |
These are "weakly supervised" pairs: same domain, not exact scene matches.
| Sufficient for learning a rough bridge alignment. |
| |
| Args: |
| image_index_path: Path to image embedding index (.npz) |
| audio_index_path: Path to audio embedding index (.npz) |
| |
| Returns: |
| ImageAudioPairDataset ready for bridge training |
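
Usage sketch (untested; paths are the defaults above):
    dataset = build_domain_matched_dataset()
    trainer = BridgeTrainer(CrossSpaceBridge())
    trainer.train(dataset)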
| """ |
| _check_torch() |
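    # Each .npz index is assumed to provide "paths" and "embeddings" arrays,
    # plus an optional "domains" array; domains are inferred from paths when absent.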
|
|
| img_index = np.load(image_index_path, allow_pickle=True) |
| aud_index = np.load(audio_index_path, allow_pickle=True) |
|
|
| img_paths = list(img_index["paths"]) |
| img_embeddings = img_index["embeddings"] |
| img_domains = list(img_index.get("domains", [])) |
|
|
| aud_paths = list(aud_index["paths"]) |
| aud_embeddings = aud_index["embeddings"] |
| aud_domains = list(aud_index.get("domains", [])) |
|
|
| |
| def _infer_domain(path_str: str) -> str: |
| p = str(path_str).lower() |
| for domain in ["nature", "urban", "water"]: |
| if f"/{domain}/" in p or f"_{domain}" in p: |
| return domain |
| return "other" |
|
|
| if not img_domains or len(img_domains) != len(img_paths): |
| img_domains = [_infer_domain(p) for p in img_paths] |
| if not aud_domains or len(aud_domains) != len(aud_paths): |
| aud_domains = [_infer_domain(p) for p in aud_paths] |
|
|
| |
| image_embs_list: List[np.ndarray] = [] |
| audio_embs_list: List[np.ndarray] = [] |
| pair_counts: Dict[str, int] = {} |
|
|
| for domain in ["nature", "urban", "water"]: |
| img_indices = [i for i, d in enumerate(img_domains) if d == domain] |
| aud_indices = [i for i, d in enumerate(aud_domains) if d == domain] |
|
|
| n_pairs = len(img_indices) * len(aud_indices) |
| pair_counts[domain] = n_pairs |
|
|
| for ii in img_indices: |
| for ai in aud_indices: |
| image_embs_list.append(img_embeddings[ii]) |
| audio_embs_list.append(aud_embeddings[ai]) |
|
|
| total = sum(pair_counts.values()) |
| logger.info( |
| "Domain-matched pairs: %s = %d total", |
| ", ".join(f"{d}={n}" for d, n in pair_counts.items()), |
| total, |
| ) |
|
|
| if total == 0: |
| raise ValueError("No domain-matched pairs found. Check embedding indexes and domain labels.") |
|
|
| return ImageAudioPairDataset( |
| image_embeddings=np.array(image_embs_list), |
| audio_embeddings=np.array(audio_embs_list), |
| ) |
|
|
|
|
| def build_paired_dataset_from_runs( |
| results_json: str, |
| embedder=None, |
| ) -> ImageAudioPairDataset: |
| """ |
| Build paired training data from existing experiment runs. |
| |
| Each baseline run in RQ1/RQ2 has a matched (image_path, audio_path) pair |
| for the same prompt. These pairs can be used to train the bridge. |
| |
| Args: |
| results_json: Path to rq1_results.json or rq2_results.json |
| embedder: AlignedEmbedder instance (for re-embedding if needed) |
| |
| Returns: |
| ImageAudioPairDataset ready for training |
| |
| Example: |
| embedder = AlignedEmbedder() |
| dataset = build_paired_dataset_from_runs( |
| "runs/rq1/rq1_results.json", |
| embedder=embedder, |
| ) |
| bridge = CrossSpaceBridge() |
| trainer = BridgeTrainer(bridge) |
| trainer.train(dataset) |
| """ |
| _check_torch() |

    with open(results_json) as f:
        data = json.load(f)
|
|
| |
| pairs = {} |
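    # Assumed record layout, based on the fields read below:
    #   {"results": [{"condition": "baseline", "image_path": "...", "audio_path": "...", ...}, ...]}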
| for r in data["results"]: |
| |
| condition = r.get("condition", r.get("mode", "")) |
| if condition not in ("baseline", "direct"): |
| continue |
|
|
| img = r.get("image_path") |
| aud = r.get("audio_path") |
| if img and aud: |
| key = f"{img}||{aud}" |
| if key not in pairs: |
| pairs[key] = (img, aud) |
|
|
| logger.info("Found %d unique image-audio pairs from %s", len(pairs), results_json) |
|
|
| if embedder is None: |
| from src.embeddings.aligned_embeddings import AlignedEmbedder |
| embedder = AlignedEmbedder() |
|
|
| image_embs = [] |
| audio_embs = [] |
| for img_path, aud_path in pairs.values(): |
| try: |
| img_emb = embedder.embed_image(img_path) |
| aud_emb = embedder.embed_audio(aud_path) |
| image_embs.append(img_emb) |
| audio_embs.append(aud_emb) |
| except Exception as e: |
| logger.warning("Skipping pair %s / %s: %s", img_path, aud_path, e) |
|
|
| logger.info("Successfully embedded %d pairs", len(image_embs)) |
|
|
| return ImageAudioPairDataset( |
| image_embeddings=np.array(image_embs), |
| audio_embeddings=np.array(audio_embs), |
| ) |
|
|