---
license: mit
language:
- en
base_model:
- laion/CLIP-ViT-B-32-laion2B-s34B-b79K
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: feature-extraction
tags:
- multimodal
- cross-modal
- embeddings
- audio
- vision
- text
- contrastive-learning
- openclip
- ast
---
# 🌿 Sage Embed v1.1
**Sage Embed** is a cross-modal embedding model that encodes **text**, **images**, and **audio** into a unified shared latent space. By aligning all three modalities, it enables semantic similarity and retrieval across any combination: query audio with text, compare images to sounds, or any mix in between.
> 🚧 **This project is in early development.** Quantitative evaluation metrics will be added in a future release. See the [Evaluation](#evaluation) section for details.
---
## Model Summary
| Property | Details |
|---|---|
| **Version** | v1.1 |
| **Modalities** | Text · Image · Audio |
| **Shared Embedding Dim** | 512 |
| **Text & Image Backbone** | [`laion/CLIP-ViT-B-32-laion2B-s34B-b79K`](https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K) |
| **Audio Backbone** | [`MIT/ast-finetuned-audioset-10-10-0.4593`](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) |
| **Training Dataset** | [Clotho v2](https://github.com/audio-captioning/clotho-dataset) |
| **Training Hardware** | Google Colab (NVIDIA T4, free tier) |
| **Epochs** | 20 |
| **License** | MIT |
---
## Architecture
### Text & Image: OpenCLIP ViT-B/32
- **Model:** `hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- Weights are **fully frozen** during training and used purely as a fixed embedding target.
- Produces **512-dimensional** L2-normalized embeddings for both text and images.
### Audio: Audio Spectrogram Transformer (AST)
- **Model:** `MIT/ast-finetuned-audioset-10-10-0.4593`
- AST hidden size: **768**
- The model is **mostly frozen**; the **last 2 encoder layers** and the **final LayerNorm** are unfrozen for fine-tuning, allowing the audio encoder to adapt to the audio-caption alignment task without catastrophic forgetting.
- The mean of the first two token states (the `[CLS]` and distillation tokens, i.e. `last_hidden_state[:, 0]` and `last_hidden_state[:, 1]`) is extracted and fed to the projection layer.
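The partial-freezing scheme above can be sketched as follows. For a download-free illustration this uses a randomly initialized `ASTModel` built from the default `ASTConfig`; in training, the pretrained checkpoint is loaded instead.

```python
from transformers import ASTConfig, ASTModel

# Random-init AST for illustration; in practice use
# ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
ast = ASTModel(ASTConfig())

# Freeze everything first...
for p in ast.parameters():
    p.requires_grad = False

# ...then unfreeze the last 2 encoder layers and the final LayerNorm
for block in ast.encoder.layer[-2:]:
    for p in block.parameters():
        p.requires_grad = True
for p in ast.layernorm.parameters():
    p.requires_grad = True

trainable = sum(p.numel() for p in ast.parameters() if p.requires_grad)
total = sum(p.numel() for p in ast.parameters())
print(f"trainable params: {trainable:,} / {total:,}")
```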
### Audio Projection Layer
The projection layer maps AST's 768-dimensional output into CLIP's 512-dimensional shared space via a small MLP: `Linear(768 → 1024) → LayerNorm → ReLU → Dropout(0.3) → Linear(1024 → 512)`.
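As a standalone sketch (dimensions per the table above; the full model definition appears in the Usage section):

```python
import torch
import torch.nn as nn

# Audio projection head: 768 (AST hidden size) -> 512 (CLIP embedding dim)
audio_projection = nn.Sequential(
    nn.Linear(768, 1024),
    nn.LayerNorm(1024),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(1024, 512),
)

x = torch.randn(2, 768)           # batch of pooled AST features
print(audio_projection(x).shape)  # torch.Size([2, 512])
```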
---
## Evaluation
> **Metrics Coming Soon**
>
> Quantitative benchmarks will be added in a future update as the project matures. This release focuses on establishing the cross-modal architecture.
---
## Usage
### Installation
```bash
pip install torch torchaudio transformers open_clip_torch huggingface_hub
```
### Step 1 β€” Define the Model & Load Weights
Copy this class definition exactly. The weights file is pulled directly from the Hub.
```python
import os
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import ASTModel, AutoProcessor
from huggingface_hub import hf_hub_download
import open_clip
from PIL import Image
import numpy as np
MODEL_NAME = "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
class OpenCLIP_AST_Model(nn.Module):
def __init__(self, embedding_dim=512):
super().__init__()
self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
self.clip_model, _, self.image_preprocess = open_clip.create_model_and_transforms(
MODEL_NAME, pretrained=None
)
self.audio_projection = nn.Sequential(
nn.Linear(self.ast.config.hidden_size, 1024),
nn.LayerNorm(1024),
nn.ReLU(),
nn.Dropout(0.3),
nn.Linear(1024, embedding_dim)
)
self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))
def forward_audio(self, input_values):
outputs = self.ast(input_values)
# Average of first two token states for richer audio representation
audio_features = (outputs.last_hidden_state[:, 0] + outputs.last_hidden_state[:, 1]) / 2
return self.audio_projection(audio_features)
# --- Load weights from HuggingFace Hub ---
weights_path = hf_hub_download(
repo_id="harryfrz/sage-embed",
filename="sage_embed_v1.1.pt"
)
model = OpenCLIP_AST_Model().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()
# Processors
ast_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
tokenizer = open_clip.get_tokenizer(MODEL_NAME)
print("Sage Embed v1.1 loaded.")
```
---
### Step 2 β€” Encode Your Inputs
**Encode a single audio file** (`.wav`, `.mp3`, `.flac`):
```python
def encode_audio(filepath):
waveform, sr = torchaudio.load(filepath)
# Convert to mono
if waveform.shape[0] > 1:
waveform = torch.mean(waveform, dim=0)
else:
waveform = waveform.squeeze(0)
# Resample to 16kHz
if sr != 16000:
waveform = T.Resample(sr, 16000)(waveform)
    # Pad or crop to ~10.24 seconds (163840 samples @ 16 kHz)
target = 163840
if waveform.shape[0] < target:
waveform = torch.nn.functional.pad(waveform, (0, target - waveform.shape[0]))
else:
waveform = waveform[:target]
inputs = ast_processor(waveform, sampling_rate=16000, return_tensors="pt")
input_tensor = inputs["input_values"].to(DEVICE)
with torch.no_grad():
emb = model.forward_audio(input_tensor)
emb = emb / emb.norm(dim=-1, keepdim=True)
return emb # shape: [1, 512]
audio_emb = encode_audio("your_audio.wav")
```
**Encode a single image file** (`.jpg`, `.png`, `.webp`):
```python
def encode_image(filepath):
image = Image.open(filepath).convert("RGB")
image_input = model.image_preprocess(image).unsqueeze(0).to(DEVICE)
with torch.no_grad():
emb = model.clip_model.encode_image(image_input)
emb = emb / emb.norm(dim=-1, keepdim=True)
return emb # shape: [1, 512]
image_emb = encode_image("your_image.jpg")
```
**Encode text:**
```python
def encode_text(texts):
# texts: a string or list of strings
if isinstance(texts, str):
texts = [texts]
tokens = tokenizer(texts).to(DEVICE)
with torch.no_grad():
emb = model.clip_model.encode_text(tokens)
emb = emb / emb.norm(dim=-1, keepdim=True)
return emb # shape: [N, 512]
text_emb = encode_text("a dog barking in the park")
```
---
### Step 3 β€” Cross-Modal Similarity Test
All embeddings are **L2-normalized**, so dot product equals cosine similarity directly.
```python
# --- Audio ↔ Text ---
audio_emb = encode_audio("dog_bark.wav")
text_embs = encode_text(["a dog barking", "rain on a rooftop", "busy city traffic"])
sim = (audio_emb @ text_embs.T)[0] # shape: [3]
labels = ["a dog barking", "rain on a rooftop", "busy city traffic"]
print("Audio ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
print(f" '{label}': {score:.4f}")
# Expected: highest score for "a dog barking"
# --- Image ↔ Text ---
image_emb = encode_image("sunset.jpg")
text_embs = encode_text(["a sunset over the ocean", "a crowded street", "birds chirping"])
sim = (image_emb @ text_embs.T)[0]
labels = ["a sunset over the ocean", "a crowded street", "birds chirping"]
print("\nImage ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
print(f" '{label}': {score:.4f}")
# --- Audio ↔ Image (cross-modal) ---
sim_cross = (audio_emb @ image_emb.T)[0]
print(f"\nAudio ↔ Image similarity: {sim_cross.item():.4f}")
```
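The model also learns a CLIP-style temperature (`model.logit_scale` from Step 1) that the raw-similarity snippets above do not use. Here is a hedged sketch of applying it to turn similarities into a zero-shot probability distribution over candidate captions; the tensors below are random stand-ins, not real embeddings.

```python
import numpy as np
import torch
import torch.nn.functional as F

# Stand-in normalized embeddings; in practice use encode_audio / encode_text.
audio_emb = F.normalize(torch.randn(1, 512), dim=-1)
text_embs = F.normalize(torch.randn(3, 512), dim=-1)

# CLIP-style learned temperature, initialized at log(1/0.07) as in Step 1
logit_scale = torch.ones([]) * np.log(1 / 0.07)

logits = logit_scale.exp() * (audio_emb @ text_embs.T)  # scaled similarities
probs = logits.softmax(dim=-1)                          # [1, 3], rows sum to 1
print(probs)
```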
---
## Limitations
- **Early-stage model:** Sage Embed v1.1 is the first public release. The projection layer was trained on Clotho only; performance on out-of-distribution audio domains may vary.
- **CLIP is frozen:** The text and image encoders are not jointly fine-tuned, so alignment quality is bounded by how well the audio projection adapts to CLIP's fixed representation space.
- **Training scale:** Trained on a free-tier T4 GPU with a relatively small dataset. Larger-scale training is expected to improve retrieval quality significantly in future versions.
- **Audio length:** Optimized for clips up to ~10 seconds. Longer clips are truncated to their first ~10 seconds; the remainder is discarded.
---
## Acknowledgements
- [OpenCLIP](https://github.com/mlfoundations/open_clip): LAION CLIP ViT-B/32 backbone
- [Audio Spectrogram Transformer](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593): MIT AST backbone
- [Clotho Dataset](https://github.com/audio-captioning/clotho-dataset): Training data