---
license: mit
language:
- en
base_model:
- laion/CLIP-ViT-B-32-laion2B-s34B-b79K
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: feature-extraction
tags:
- multimodal
- cross-modal
- embeddings
- audio
- vision
- text
- contrastive-learning
- openclip
- ast
---
# 🌿 Sage Embed – v1.1
**Sage Embed** is a cross-modal embedding model that encodes **text**, **images**, and **audio** into a single shared latent space. By aligning all three modalities, it enables semantic similarity and retrieval across any combination: query audio with text, compare images to sounds, or any mix in between.
> 🧠 **This project is in early development.** Quantitative evaluation metrics will be added in a future release. See the [Evaluation](#evaluation) section for details.
---
## Model Summary
| Property | Details |
|---|---|
| **Version** | v1.1 |
| **Modalities** | Text · Image · Audio |
| **Shared Embedding Dim** | 512 |
| **Text & Image Backbone** | [`laion/CLIP-ViT-B-32-laion2B-s34B-b79K`](https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K) |
| **Audio Backbone** | [`MIT/ast-finetuned-audioset-10-10-0.4593`](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) |
| **Training Dataset** | [Clotho v2](https://github.com/audio-captioning/clotho-dataset) |
| **Training Hardware** | Google Colab – NVIDIA T4 (free tier) |
| **Epochs** | 20 |
| **License** | MIT |
---
## Architecture
### Text & Image: OpenCLIP ViT-B/32
- **Model:** `hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- Weights are **fully frozen** during training – used purely as a fixed embedding target.
- Produces **512-dimensional** L2-normalized embeddings for both text and images.
### Audio: Audio Spectrogram Transformer (AST)
- **Model:** `MIT/ast-finetuned-audioset-10-10-0.4593`
- AST hidden size: **768**
- The model is **mostly frozen**; the **last 2 encoder layers** and the **final LayerNorm** are unfrozen for fine-tuning, allowing the audio encoder to adapt to the audio-caption alignment task without catastrophic forgetting.
- The first two token states of `last_hidden_state` (the `[CLS]` and distillation tokens) are averaged and fed to the projection layer.
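The partial-freezing setup described above can be sketched as follows. This is a sketch of the training-time configuration, not part of the released inference code; the attribute names `encoder.layer` and `layernorm` follow the `transformers` ASTModel layout:

```python
import torch.nn as nn

def freeze_ast_except_last(ast, num_trainable_layers=2):
    """Freeze an AST backbone except its last N encoder layers and final LayerNorm."""
    # Freeze everything first
    for p in ast.parameters():
        p.requires_grad = False
    # Unfreeze the last N transformer encoder layers
    for layer in ast.encoder.layer[-num_trainable_layers:]:
        for p in layer.parameters():
            p.requires_grad = True
    # Unfreeze the final LayerNorm
    for p in ast.layernorm.parameters():
        p.requires_grad = True
    return ast
```

Keeping most of the backbone frozen is what lets the audio encoder adapt to the alignment task on a small dataset without catastrophic forgetting.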
### Audio Projection Layer
The projection layer maps AST's 768-dimensional output into CLIP's 512-dimensional shared space via a small MLP (768 → 1024 → 512) with LayerNorm, ReLU, and dropout between the two linear layers.
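In isolation, the head looks like this (the same layers appear inside the model class in the Usage section):

```python
import torch
import torch.nn as nn

# Audio projection head: 768 (AST hidden size) -> 1024 -> 512 (CLIP embedding dim)
audio_projection = nn.Sequential(
    nn.Linear(768, 1024),
    nn.LayerNorm(1024),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(1024, 512),
)

pooled_ast_features = torch.randn(4, 768)  # a batch of 4 pooled AST outputs
emb = audio_projection(pooled_ast_features)
print(emb.shape)  # torch.Size([4, 512])
```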
---
## Evaluation
> **Metrics Coming Soon**
>
> Quantitative benchmarks will be added in a future update as the project matures. This release focuses on establishing the cross-modal architecture.
---
## Usage
### Installation
```bash
pip install torch torchaudio transformers open_clip_torch huggingface_hub
```
### Step 1 – Define the Model & Load Weights
Copy this class definition exactly. The weights file is pulled directly from the Hub.
```python
import numpy as np
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
import open_clip
from huggingface_hub import hf_hub_download
from PIL import Image
from transformers import ASTModel, AutoProcessor

MODEL_NAME = "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"


class OpenCLIP_AST_Model(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.clip_model, _, self.image_preprocess = open_clip.create_model_and_transforms(
            MODEL_NAME, pretrained=None
        )
        self.audio_projection = nn.Sequential(
            nn.Linear(self.ast.config.hidden_size, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim),
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward_audio(self, input_values):
        outputs = self.ast(input_values)
        # Average the first two token states for a richer audio representation
        audio_features = (outputs.last_hidden_state[:, 0] + outputs.last_hidden_state[:, 1]) / 2
        return self.audio_projection(audio_features)


# --- Load weights from the Hugging Face Hub ---
weights_path = hf_hub_download(
    repo_id="harryfrz/sage-embed",
    filename="sage_embed_v1.1.pt",
)

model = OpenCLIP_AST_Model().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()

# Processors
ast_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

print("Sage Embed v1.1 loaded.")
```
---
### Step 2 – Encode Your Inputs
**Encode a single audio file** (`.wav`, `.mp3`, `.flac`):
```python
def encode_audio(filepath):
    waveform, sr = torchaudio.load(filepath)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0)
    else:
        waveform = waveform.squeeze(0)

    # Resample to 16 kHz
    if sr != 16000:
        waveform = T.Resample(sr, 16000)(waveform)

    # Pad or crop to ~10.24 seconds (163,840 samples @ 16 kHz)
    target = 163840
    if waveform.shape[0] < target:
        waveform = torch.nn.functional.pad(waveform, (0, target - waveform.shape[0]))
    else:
        waveform = waveform[:target]

    inputs = ast_processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_tensor = inputs["input_values"].to(DEVICE)

    with torch.no_grad():
        emb = model.forward_audio(input_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

audio_emb = encode_audio("your_audio.wav")
```
**Encode a single image file** (`.jpg`, `.png`, `.webp`):
```python
def encode_image(filepath):
    image = Image.open(filepath).convert("RGB")
    image_input = model.image_preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_image(image_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

image_emb = encode_image("your_image.jpg")
```
**Encode text:**
```python
def encode_text(texts):
    # texts: a string or list of strings
    if isinstance(texts, str):
        texts = [texts]
    tokens = tokenizer(texts).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [N, 512]

text_emb = encode_text("a dog barking in the park")
```
---
### Step 3 – Cross-Modal Similarity Test
All embeddings are **L2-normalized**, so dot product equals cosine similarity directly.
```python
# --- Audio ↔ Text ---
audio_emb = encode_audio("dog_bark.wav")
labels = ["a dog barking", "rain on a rooftop", "busy city traffic"]
text_embs = encode_text(labels)
sim = (audio_emb @ text_embs.T)[0]  # shape: [3]

print("Audio ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")
# Expected: highest score for "a dog barking"

# --- Image ↔ Text ---
image_emb = encode_image("sunset.jpg")
labels = ["a sunset over the ocean", "a crowded street", "birds chirping"]
text_embs = encode_text(labels)
sim = (image_emb @ text_embs.T)[0]

print("\nImage ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")

# --- Audio ↔ Image (cross-modal) ---
sim_cross = (audio_emb @ image_emb.T)[0]
print(f"\nAudio ↔ Image similarity: {sim_cross.item():.4f}")
```
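For retrieval over more than a handful of candidates, the same dot-product scoring extends to a ranked top-k search. A minimal sketch; the `retrieve` helper below is not part of the model, just a convenience over L2-normalized embeddings:

```python
import torch

def retrieve(query_emb, candidate_embs, k=3):
    """Rank L2-normalized candidates against a single [1, D] query embedding.

    Returns (scores, indices) of the top-k candidates by cosine similarity.
    """
    sims = (query_emb @ candidate_embs.T).squeeze(0)  # [num_candidates]
    return sims.topk(min(k, candidate_embs.shape[0]))
```

For example, `retrieve(encode_text("a dog barking"), torch.cat([encode_audio(f) for f in files]))` ranks a folder of audio clips against a text query.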
---
## Limitations
- **Early-stage model:** Sage Embed v1.1 is the first public release. The projection layer was trained on Clotho only; performance on out-of-distribution audio domains may vary.
- **CLIP is frozen:** The text and image encoders are not jointly fine-tuned, so alignment quality is bounded by how well the audio projection adapts to CLIP's fixed representation space.
- **Training scale:** Trained on a free-tier T4 GPU with a relatively small dataset. Larger-scale training is expected to improve retrieval quality significantly in future versions.
- **Audio length:** Optimized for clips up to ~10 seconds. Longer clips are truncated: only the first ~10 seconds are used.
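If longer inputs matter for your use case, one workaround (untested against this model) is to window the waveform, embed each window, and mean-pool. Here `encode_chunk` stands in for the `encode_audio` pipeline above, minus the file I/O:

```python
import torch

def encode_long_audio(waveform, encode_chunk, chunk_len=163840):
    """Embed audio longer than one model window by averaging per-chunk embeddings.

    `waveform` is a mono [num_samples] tensor at 16 kHz; `encode_chunk` maps a
    waveform chunk (the last one may be shorter and should be padded inside) to
    a [1, D] L2-normalized embedding.
    """
    chunks = waveform.split(chunk_len)
    embs = torch.cat([encode_chunk(c) for c in chunks])  # [num_chunks, D]
    pooled = embs.mean(dim=0, keepdim=True)
    return pooled / pooled.norm(dim=-1, keepdim=True)    # back to unit norm
```

Mean-pooling blurs temporally distinct events within a clip, so this is best suited to recordings with a consistent soundscape.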
---
## Acknowledgements
- [OpenCLIP](https://github.com/mlfoundations/open_clip) – LAION CLIP ViT-B/32 backbone
- [Audio Spectrogram Transformer](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) – MIT AST backbone
- [Clotho Dataset](https://github.com/audio-captioning/clotho-dataset) – Training data