# Sage Embed v1.1
Sage Embed is a cross-modal embedding model that encodes text, images, and audio into a unified shared latent space. By aligning all three modalities, it enables semantic similarity and retrieval across any combination: query audio with text, compare images to sounds, or any mix in between.
🚧 This project is in early development. Quantitative evaluation metrics will be added in a future release. See the Evaluation section for details.
## Model Summary
| Property | Details |
|---|---|
| Version | v1.1 |
| Modalities | Text · Image · Audio |
| Shared Embedding Dim | 512 |
| Text & Image Backbone | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| Audio Backbone | MIT/ast-finetuned-audioset-10-10-0.4593 |
| Training Dataset | Clotho v2 |
| Training Hardware | Google Colab, NVIDIA T4 (free tier) |
| Epochs | 20 |
| License | MIT |
## Architecture
### Text & Image: OpenCLIP ViT-B/32
- Model: `hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- Weights are fully frozen during training and used purely as a fixed embedding target.
- Produces 512-dimensional L2-normalized embeddings for both text and images.
### Audio: Audio Spectrogram Transformer (AST)
- Model: `MIT/ast-finetuned-audioset-10-10-0.4593`
- AST hidden size: 768
- The model is mostly frozen; the last 2 encoder layers and the final LayerNorm are unfrozen for fine-tuning, allowing the audio encoder to adapt to the audio-caption alignment task without catastrophic forgetting.
- The `[CLS]` and distillation token states (`last_hidden_state[:, 0]` and `last_hidden_state[:, 1]`) are averaged and fed to the projection layer.
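The partial-freezing step happens at training time and is not shown in the inference code below. A minimal sketch of how it could be done, assuming the `transformers` `ASTModel` layout where encoder blocks live in `encoder.layer` and the final norm in `layernorm`:

```python
import torch.nn as nn

def partially_freeze_ast(ast: nn.Module, trainable_layers: int = 2) -> None:
    """Freeze the whole AST, then unfreeze the last N encoder layers and the final LayerNorm."""
    for param in ast.parameters():
        param.requires_grad = False
    # Assumes ast.encoder.layer is a ModuleList of transformer blocks
    for block in ast.encoder.layer[-trainable_layers:]:
        for param in block.parameters():
            param.requires_grad = True
    # Final LayerNorm applied before pooling
    for param in ast.layernorm.parameters():
        param.requires_grad = True
```

Everything except the last two blocks and the final norm then contributes no gradients, which keeps the bulk of the AudioSet pretraining intact.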
### Audio Projection Layer
The projection layer maps AST's 768-dimensional output into CLIP's 512-dimensional shared space through a two-layer MLP (768 → 1024 → 512) with LayerNorm, ReLU, and dropout.
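In isolation, with the dimensions from the table above hard-coded, the projection head is just:

```python
import torch
import torch.nn as nn

# 768 (AST hidden size) -> 1024 -> 512 (CLIP shared space)
audio_projection = nn.Sequential(
    nn.Linear(768, 1024),
    nn.LayerNorm(1024),
    nn.ReLU(),
    nn.Dropout(0.3),
    nn.Linear(1024, 512),
)

pooled = torch.randn(4, 768)           # batch of pooled AST features
embeddings = audio_projection(pooled)  # -> shape [4, 512]
```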
## Evaluation
### Metrics Coming Soon
Quantitative benchmarks will be added in a future update as the project matures. This release focuses on establishing the cross-modal architecture.
## Usage
### Installation
```bash
pip install torch torchaudio transformers open_clip_torch huggingface_hub
```
### Step 1 – Define the Model & Load Weights
Copy this class definition exactly. The weights file is pulled directly from the Hub.
```python
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import ASTModel, AutoProcessor
from huggingface_hub import hf_hub_download
import open_clip
from PIL import Image
import numpy as np

MODEL_NAME = "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class OpenCLIP_AST_Model(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.clip_model, _, self.image_preprocess = open_clip.create_model_and_transforms(
            MODEL_NAME, pretrained=None
        )
        self.audio_projection = nn.Sequential(
            nn.Linear(self.ast.config.hidden_size, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim)
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward_audio(self, input_values):
        outputs = self.ast(input_values)
        # Average the [CLS] and distillation token states for a richer audio representation
        audio_features = (outputs.last_hidden_state[:, 0] + outputs.last_hidden_state[:, 1]) / 2
        return self.audio_projection(audio_features)

# --- Load weights from the Hugging Face Hub ---
weights_path = hf_hub_download(
    repo_id="harryfrz/sage-embed",
    filename="sage_embed_v1.1.pt"
)

model = OpenCLIP_AST_Model().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()

# Processors
ast_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

print("Sage Embed v1.1 loaded.")
```
### Step 2 – Encode Your Inputs
Encode a single audio file (.wav, .mp3, .flac):
```python
def encode_audio(filepath):
    waveform, sr = torchaudio.load(filepath)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0)
    else:
        waveform = waveform.squeeze(0)

    # Resample to 16 kHz
    if sr != 16000:
        waveform = T.Resample(sr, 16000)(waveform)

    # Pad or crop to ~10 seconds (163840 samples @ 16 kHz)
    target = 163840
    if waveform.shape[0] < target:
        waveform = torch.nn.functional.pad(waveform, (0, target - waveform.shape[0]))
    else:
        waveform = waveform[:target]

    inputs = ast_processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_tensor = inputs["input_values"].to(DEVICE)

    with torch.no_grad():
        emb = model.forward_audio(input_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

audio_emb = encode_audio("your_audio.wav")
```
Encode a single image file (.jpg, .png, .webp):
```python
def encode_image(filepath):
    image = Image.open(filepath).convert("RGB")
    image_input = model.image_preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_image(image_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

image_emb = encode_image("your_image.jpg")
```
Encode text:
```python
def encode_text(texts):
    # texts: a string or a list of strings
    if isinstance(texts, str):
        texts = [texts]
    tokens = tokenizer(texts).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [N, 512]

text_emb = encode_text("a dog barking in the park")
```
### Step 3 – Cross-Modal Similarity Test
All embeddings are L2-normalized, so dot product equals cosine similarity directly.
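A quick sanity check of that claim on random unit vectors:

```python
import torch
import torch.nn.functional as F

a = torch.randn(1, 512)
b = torch.randn(3, 512)
a = a / a.norm(dim=-1, keepdim=True)
b = b / b.norm(dim=-1, keepdim=True)

dot = (a @ b.T)[0]                       # dot products of unit vectors, shape [3]
cos = F.cosine_similarity(a, b, dim=-1)  # explicit cosine similarity, shape [3]
print(torch.allclose(dot, cos, atol=1e-6))  # True
```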
```python
# --- Audio → Text ---
audio_labels = ["a dog barking", "rain on a rooftop", "busy city traffic"]
audio_emb = encode_audio("dog_bark.wav")
text_embs = encode_text(audio_labels)

sim = (audio_emb @ text_embs.T)[0]  # shape: [3]
print("Audio → Text similarity:")
for label, score in zip(audio_labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")
# Expected: highest score for "a dog barking"

# --- Image → Text ---
image_labels = ["a sunset over the ocean", "a crowded street", "birds chirping"]
image_emb = encode_image("sunset.jpg")
text_embs = encode_text(image_labels)

sim = (image_emb @ text_embs.T)[0]
print("\nImage → Text similarity:")
for label, score in zip(image_labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")

# --- Audio ↔ Image (cross-modal) ---
sim_cross = (audio_emb @ image_emb.T)[0]
print(f"\nAudio ↔ Image similarity: {sim_cross.item():.4f}")
```
## Limitations
- Early-stage model: Sage Embed v1.1 is the first public release. The projection layer was trained on Clotho only; performance on out-of-distribution audio domains may vary.
- CLIP is frozen: The text and image encoders are not jointly fine-tuned, so alignment quality is bounded by how well the audio projection adapts to CLIP's fixed representation space.
- Training scale: Trained on a free-tier T4 GPU with a relatively small dataset. Larger-scale training is expected to improve retrieval quality significantly in future versions.
- Audio length: Optimized for clips up to ~10 seconds. Longer clips are truncated to their first ~10 seconds.
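For longer material, one simple workaround (a sketch, not part of the released model) is to embed fixed-size windows and mean-pool the resulting embeddings. The windowing step on its own:

```python
import torch

def split_into_windows(waveform: torch.Tensor, window: int = 163840) -> torch.Tensor:
    """Split a mono waveform into fixed-length windows, zero-padding the last one."""
    n = waveform.shape[0]
    num_windows = max(1, -(-n // window))  # ceiling division
    padded = torch.nn.functional.pad(waveform, (0, num_windows * window - n))
    return padded.view(num_windows, window)

# Each row can then be passed through the audio encoder and the embeddings averaged.
windows = split_into_windows(torch.randn(400000))
print(windows.shape)  # torch.Size([3, 163840])
```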
## Acknowledgements
- OpenCLIP β LAION CLIP ViT-B/32 backbone
- Audio Spectrogram Transformer β MIT AST backbone
- Clotho Dataset β Training data