| | """ |
Ex-MCR Cross-Space Alignment: CLAP Audio → CLIP Space.
| | |
Ex-MCR (Extending Multimodal Contrastive Representations) projects CLAP audio embeddings
| | INTO CLIP space while keeping CLIP embeddings unchanged. This lets us |
| | compute meaningful image-audio similarity and full 3-way Gramian volume. |
| | |
| | Architecture decision: Ex-MCR over C-MCR because: |
| | - Ex-MCR keeps CLIP embeddings frozen (no recomputation needed) |
- C-MCR projects BOTH spaces into a new shared space (would invalidate every stored CLIP embedding)
| | |
| | The projector is a lightweight MLP: |
CLAP 512-d → Linear(512, 512) → ReLU → Linear(512, 512) → L2 norm
| | |
| | If Ex-MCR weights are not available, falls back to an untrained identity |
| | projection (which is equivalent to not using the projector). |
| | |
| | CLAP compatibility note: |
| | Our project uses `laion/clap-htsat-unfused`. |
| | Ex-MCR uses `laion_clap_fullset_fusion` (different model). |
| | If projections are poor with our CLAP, switch to the fusion model. |
| | """ |
| |
|
| | from __future__ import annotations |
| |
|
| | import logging |
| | from pathlib import Path |
| | from typing import Optional |
| |
|
| | import numpy as np |
| |
|
| | logger = logging.getLogger(__name__) |
| |
|
| | try: |
| | import torch |
| | import torch.nn as nn |
| | import torch.nn.functional as F |
| | TORCH_AVAILABLE = True |
| | except ImportError: |
| | TORCH_AVAILABLE = False |
| |
|
| |
|
class ExMCRProjector:
    """
    Projects CLAP audio embeddings into CLIP space.

    With trained Ex-MCR weights loaded, audio embeddings pass through a small
    MLP head and are L2-normalized into CLIP space. Without weights (or
    without torch), the projector is a passthrough that only L2-normalizes.

    Usage:
        proj = ExMCRProjector("models/exmcr/ex_clap.pt")
        audio_in_clip = proj.project_audio(clap_embedding)  # now comparable to CLIP
    """

    def __init__(
        self,
        weights_path: Optional[str] = None,
        device: str = "cpu",
    ):
        """
        Args:
            weights_path: Path to Ex-MCR CLAP→CLIP projection weights (.pt).
                If None or file doesn't exist, uses identity (passthrough).
            device: Torch device for inference.
        """
        self._model = None           # nn.Sequential head once weights are loaded
        self._device = device
        self._identity_mode = True   # passthrough until weights load successfully

        path_exists = bool(weights_path) and Path(weights_path).exists()
        if path_exists and TORCH_AVAILABLE:
            self._load_weights(weights_path)
        elif weights_path and not path_exists:
            logger.warning(
                "Ex-MCR weights not found: %s → using identity projection", weights_path
            )
        elif path_exists and not TORCH_AVAILABLE:
            # Previously this case fell through silently; surface the fallback.
            logger.warning(
                "Ex-MCR weights present but torch unavailable → using identity projection"
            )

    @staticmethod
    def _build_mlp(in_dim: int, hidden_dim: int, out_dim: int) -> "nn.Sequential":
        """Two-layer MLP matching the Ex-MCR projection head: Linear → ReLU → Linear."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def _load_weights(self, path: str) -> None:
        """Load the Ex-MCR projection head from saved weights.

        Recognizes three checkpoint layouts:
          1. keys prefixed with "layers." (e.g. "layers.0.weight"),
          2. bare nn.Sequential keys ("0.weight", "2.weight", ...),
          3. any other dict with >= 2 weight tensors (best-effort inference).

        Any read/parse failure leaves the projector in identity mode with a
        logged warning instead of raising, matching the module's documented
        fallback behavior.
        """
        try:
            state_dict = torch.load(path, map_location=self._device, weights_only=True)
        except Exception:
            logger.exception(
                "Failed to read Ex-MCR weights from %s → using identity", path
            )
            return

        keys = list(state_dict.keys())

        try:
            if any("layers" in k for k in keys):
                # Layout 1: strip the "layers." module prefix so the key names
                # line up with a bare nn.Sequential.
                in_dim = state_dict["layers.0.weight"].shape[1]
                hidden_dim = state_dict["layers.0.weight"].shape[0]
                out_dim = state_dict["layers.2.weight"].shape[0]
                model = self._build_mlp(in_dim, hidden_dim, out_dim)
                model.load_state_dict(
                    {k.replace("layers.", ""): v for k, v in state_dict.items()}
                )
            elif any(k.startswith("0.") for k in keys):
                # Layout 2: keys already match nn.Sequential naming.
                in_dim = state_dict["0.weight"].shape[1]
                hidden_dim = state_dict["0.weight"].shape[0]
                out_dim = state_dict["2.weight"].shape[0]
                model = self._build_mlp(in_dim, hidden_dim, out_dim)
                model.load_state_dict(state_dict)
            else:
                # Layout 3: infer dimensions from the first/last weight tensors;
                # load_state_dict still validates the actual key names below.
                weight_keys = [k for k in keys if "weight" in k]
                if len(weight_keys) < 2:
                    logger.warning("Unrecognized Ex-MCR weight format → using identity")
                    return
                first_w = state_dict[weight_keys[0]]
                last_w = state_dict[weight_keys[-1]]
                in_dim = first_w.shape[1]
                hidden_dim = first_w.shape[0]
                out_dim = last_w.shape[0]
                model = self._build_mlp(in_dim, hidden_dim, out_dim)
                model.load_state_dict(state_dict)
        except (KeyError, IndexError, RuntimeError) as exc:
            # Missing expected keys or shape/name mismatches from load_state_dict.
            logger.warning("Could not parse Ex-MCR weights (%s) → using identity", exc)
            return

        model.to(self._device)
        model.eval()  # inference mode
        self._model = model
        self._identity_mode = False
        logger.info(
            "Ex-MCR projector loaded: %d → %d → %d (from %s)",
            in_dim, hidden_dim, out_dim, path,
        )

    @property
    def is_identity(self) -> bool:
        """True if projector is passthrough (no trained weights loaded)."""
        return self._identity_mode

    def project_audio(self, clap_embedding: np.ndarray) -> np.ndarray:
        """
        Project CLAP audio embedding into CLIP space.

        Args:
            clap_embedding: CLAP audio embedding, shape (512,), (1, 512), or (N, 512).

        Returns:
            Projected embedding in CLIP space, L2-normalized. 1-d for (512,)
            or (1, 512) input; (N, 512) for batched input with N > 1.
        """
        if self._identity_mode or not TORCH_AVAILABLE:
            # Passthrough: L2-normalize only. (A non-identity projector can
            # only exist when torch is available, so merging the branches is safe.)
            emb = clap_embedding.squeeze().astype(np.float32)
            if emb.ndim == 2:
                # Fix: batched input was previously divided by its global
                # (Frobenius) norm; each row must be normalized independently.
                norms = np.linalg.norm(emb, axis=1, keepdims=True) + 1e-12
                return (emb / norms).astype(np.float32)
            return (emb / (np.linalg.norm(emb) + 1e-12)).astype(np.float32)

        # Remember whether the caller passed a single vector so we can return
        # the same rank they gave us.
        was_1d = clap_embedding.ndim == 1 or (
            clap_embedding.ndim == 2 and clap_embedding.shape[0] == 1
        )
        emb = clap_embedding.squeeze()
        if emb.ndim == 1:
            emb = emb[np.newaxis, :]

        with torch.no_grad():
            x = torch.as_tensor(emb, dtype=torch.float32, device=self._device)
            projected = F.normalize(self._model(x), p=2, dim=-1)
            result = projected.cpu().numpy()

        return result.squeeze(0) if was_1d else result

    def project_audio_batch(self, clap_embeddings: np.ndarray) -> np.ndarray:
        """
        Batch projection of CLAP audio embeddings into CLIP space.

        Args:
            clap_embeddings: Shape (N, 512).

        Returns:
            Projected embeddings in CLIP space, shape (N, 512), L2-normalized.
        """
        if self._identity_mode or not TORCH_AVAILABLE:
            # Passthrough: row-wise L2 normalization only.
            norms = np.linalg.norm(clap_embeddings, axis=1, keepdims=True) + 1e-12
            return (clap_embeddings / norms).astype(np.float32)

        with torch.no_grad():
            x = torch.as_tensor(
                clap_embeddings, dtype=torch.float32, device=self._device
            )
            projected = F.normalize(self._model(x), p=2, dim=-1)
            return projected.cpu().numpy()
| |
|