🌿 Sage Embed v1.1

Sage Embed is a cross-modal embedding model that encodes text, images, and audio into a single shared latent space. By aligning all three modalities, it enables semantic similarity and retrieval across any combination: query audio with text, compare images to sounds, or any mix in between.

🚧 This project is in early development. Quantitative evaluation metrics will be added in a future release. See the Evaluation section for details.


Model Summary

| Property | Details |
|---|---|
| Version | v1.1 |
| Modalities | Text · Image · Audio |
| Shared Embedding Dimension | 512 |
| Text & Image Backbone | laion/CLIP-ViT-B-32-laion2B-s34B-b79K |
| Audio Backbone | MIT/ast-finetuned-audioset-10-10-0.4593 |
| Training Dataset | Clotho v2 |
| Training Hardware | Google Colab, NVIDIA T4 (free tier) |
| Epochs | 20 |
| License | MIT |

Architecture

Text & Image: OpenCLIP ViT-B/32

  • Model: hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K
  • Weights are fully frozen during training and serve purely as a fixed embedding target.
  • Produces 512-dimensional embeddings for both text and images; these are L2-normalized before comparison (see the usage code below).

Audio: Audio Spectrogram Transformer (AST)

  • Model: MIT/ast-finetuned-audioset-10-10-0.4593
  • AST hidden size: 768
  • The model is mostly frozen: only the last 2 encoder layers and the final LayerNorm are unfrozen for fine-tuning. This lets the audio encoder adapt to the audio-caption alignment task while limiting the risk of catastrophic forgetting.
  • The [CLS] and distillation token states (last_hidden_state[:, 0] and last_hidden_state[:, 1]) are averaged and fed to the projection layer, matching forward_audio in the usage code.
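The partial unfreezing described above can be sketched as a name-based filter over the AST parameters. This is a minimal illustration, not the training code used for v1.1: the helper is_trainable_ast_param is hypothetical, and it assumes Hugging Face ASTModel parameter naming (encoder blocks under encoder.layer.<i>., the final LayerNorm under layernorm.) and the 12 encoder layers of the base checkpoint.

```python
def is_trainable_ast_param(name: str, num_layers: int = 12, last_n: int = 2) -> bool:
    """Return True if an AST parameter should stay trainable.

    Assumption: Hugging Face ASTModel naming, where encoder blocks are
    'encoder.layer.<i>.' and the final LayerNorm is 'layernorm.'.
    The trailing dots keep 'encoder.layer.1.' from matching 'encoder.layer.10.'.
    """
    trainable_prefixes = tuple(
        f"encoder.layer.{i}." for i in range(num_layers - last_n, num_layers)
    ) + ("layernorm.",)
    return name.startswith(trainable_prefixes)


# Hypothetical usage with a loaded model:
#   for name, param in model.ast.named_parameters():
#       param.requires_grad = is_trainable_ast_param(name)
example_names = [
    "embeddings.cls_token",
    "encoder.layer.0.attention.attention.query.weight",
    "encoder.layer.10.output.dense.weight",
    "encoder.layer.11.layernorm_before.weight",
    "layernorm.weight",
]
for n in example_names:
    print(f"{n}: {'trainable' if is_trainable_ast_param(n) else 'frozen'}")
```

Per-block LayerNorms (layernorm_before, layernorm_after) inside the last two layers are covered by the encoder.layer prefixes; only the top-level layernorm. prefix selects the final LayerNorm.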

Audio Projection Layer

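The projection layer maps AST's 768-dimensional pooled output into CLIP's 512-dimensional shared space. It is a two-layer MLP: Linear (768 → 1024), LayerNorm, ReLU, Dropout (p = 0.3), then Linear (1024 → 512). The model also holds a learnable temperature, logit_scale, initialized to log(1/0.07).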
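To make the shape flow concrete, here is a NumPy sketch of the projection's inference-time forward pass. The weights are random placeholders (the real ones live in sage_embed_v1.1.pt), nn.Linear's [out, in] weight layout is simplified to an [in, out] matrix, and Dropout is omitted because it is inactive in eval mode.

```python
import numpy as np

rng = np.random.default_rng(0)
AST_DIM, HIDDEN_DIM, EMBED_DIM = 768, 1024, 512

# Random stand-in weights; the trained values come from the checkpoint.
W1 = rng.standard_normal((AST_DIM, HIDDEN_DIM)) * 0.02
b1 = np.zeros(HIDDEN_DIM)
gamma, beta = np.ones(HIDDEN_DIM), np.zeros(HIDDEN_DIM)  # LayerNorm affine params
W2 = rng.standard_normal((HIDDEN_DIM, EMBED_DIM)) * 0.02
b2 = np.zeros(EMBED_DIM)


def layer_norm(x, gamma, beta, eps=1e-5):
    """LayerNorm over the last axis with biased variance, as in nn.LayerNorm."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps) * gamma + beta


def project_audio(features):
    """Linear -> LayerNorm -> ReLU -> Linear (Dropout is identity at eval)."""
    h = features @ W1 + b1
    h = layer_norm(h, gamma, beta)
    h = np.maximum(h, 0.0)  # ReLU
    return h @ W2 + b2


pooled = rng.standard_normal((4, AST_DIM))  # batch of 4 pooled AST features
emb = project_audio(pooled)
emb = emb / np.linalg.norm(emb, axis=-1, keepdims=True)  # L2-normalize, as in encode_audio
print(emb.shape)  # (4, 512)
```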


Evaluation

Metrics Coming Soon

Quantitative benchmarks will be added in a future update as the project matures. This release focuses on establishing the cross-modal architecture.
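Until official numbers are published, one common way to score cross-modal retrieval on Clotho-style data is Recall@K: the fraction of queries whose correct match ranks in the top K. The sketch below is illustrative only. recall_at_k is a hypothetical helper that assumes each query has exactly one correct gallery item (as in text-to-audio retrieval, where each Clotho caption belongs to one clip), and the toy similarity matrix is made up, not a Sage Embed result.

```python
import numpy as np


def recall_at_k(sim: np.ndarray, gt_index: np.ndarray, k: int) -> float:
    """Fraction of queries whose ground-truth item ranks in the top k.

    sim:      [Q, N] similarity matrix (e.g. text_embs @ audio_embs.T)
    gt_index: [Q] index of the correct gallery item for each query
    """
    ranked = np.argsort(-sim, axis=1)[:, :k]          # [Q, k] best gallery indices
    hits = (ranked == gt_index[:, None]).any(axis=1)  # [Q] True if the match is in the top k
    return float(hits.mean())


sim = np.array([
    [0.9, 0.1, 0.2],  # query 0: correct item 0 ranks 1st
    [0.3, 0.2, 0.8],  # query 1: correct item 1 ranks 3rd
    [0.1, 0.7, 0.6],  # query 2: correct item 2 ranks 2nd
])
gt = np.array([0, 1, 2])
print(recall_at_k(sim, gt, 1), recall_at_k(sim, gt, 2), recall_at_k(sim, gt, 3))
```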


Usage

Installation

pip install torch torchaudio transformers open_clip_torch huggingface_hub

Step 1: Define the Model & Load Weights

Copy this class definition exactly; the module names must match the keys in the saved state dict. The weights file is downloaded directly from the Hugging Face Hub.

import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import ASTModel, AutoProcessor
from huggingface_hub import hf_hub_download
import open_clip
from PIL import Image
import numpy as np

MODEL_NAME = "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class OpenCLIP_AST_Model(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        self.clip_model, _, self.image_preprocess = open_clip.create_model_and_transforms(
            MODEL_NAME, pretrained=None
        )
        self.audio_projection = nn.Sequential(
            nn.Linear(self.ast.config.hidden_size, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim)
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward_audio(self, input_values):
        outputs = self.ast(input_values)
        # Average the [CLS] and distillation token states (same pooling as ASTModel's pooler_output)
        audio_features = (outputs.last_hidden_state[:, 0] + outputs.last_hidden_state[:, 1]) / 2
        return self.audio_projection(audio_features)

# --- Load weights from HuggingFace Hub ---
weights_path = hf_hub_download(
    repo_id="harryfrz/sage-embed",
    filename="sage_embed_v1.1.pt"
)

model = OpenCLIP_AST_Model().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()

# Processors
ast_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

print("Sage Embed v1.1 loaded.")

Step 2: Encode Your Inputs

Encode a single audio file (.wav, .flac, or .mp3; MP3 decoding depends on the audio backend available to torchaudio):

def encode_audio(filepath):
    waveform, sr = torchaudio.load(filepath)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0)
    else:
        waveform = waveform.squeeze(0)

    # Resample to 16kHz
    if sr != 16000:
        waveform = T.Resample(sr, 16000)(waveform)

    # Pad or crop to 10.24 seconds (163840 samples @ 16 kHz)
    target = 163840
    if waveform.shape[0] < target:
        waveform = torch.nn.functional.pad(waveform, (0, target - waveform.shape[0]))
    else:
        waveform = waveform[:target]

    inputs = ast_processor(waveform, sampling_rate=16000, return_tensors="pt")
    input_tensor = inputs["input_values"].to(DEVICE)

    with torch.no_grad():
        emb = model.forward_audio(input_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]


audio_emb = encode_audio("your_audio.wav")

Encode a single image file (.jpg, .png, .webp):

def encode_image(filepath):
    image = Image.open(filepath).convert("RGB")
    image_input = model.image_preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_image(image_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]


image_emb = encode_image("your_image.jpg")

Encode text:

def encode_text(texts):
    # texts: a string or list of strings
    if isinstance(texts, str):
        texts = [texts]
    tokens = tokenizer(texts).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [N, 512]


text_emb = encode_text("a dog barking in the park")

Step 3: Cross-Modal Similarity Test

All embeddings are L2-normalized, so their dot product equals their cosine similarity.
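A quick numerical check of this equivalence, using random 512-dimensional vectors as stand-ins for real embeddings:

```python
import numpy as np

rng = np.random.default_rng(42)
a = rng.standard_normal(512)
b = rng.standard_normal(512)

# Cosine similarity of the raw vectors.
cosine = a @ b / (np.linalg.norm(a) * np.linalg.norm(b))

# Plain dot product after L2-normalizing each vector.
a_unit = a / np.linalg.norm(a)
b_unit = b / np.linalg.norm(b)
dot = a_unit @ b_unit

print(np.isclose(cosine, dot))  # True
```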

# --- Audio ↔ Text ---
audio_emb = encode_audio("dog_bark.wav")
text_embs = encode_text(["a dog barking", "rain on a rooftop", "busy city traffic"])

sim = (audio_emb @ text_embs.T)[0]  # shape: [3]
labels = ["a dog barking", "rain on a rooftop", "busy city traffic"]

print("Audio ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")
# Expected: highest score for "a dog barking"


# --- Image ↔ Text ---
image_emb = encode_image("sunset.jpg")
text_embs = encode_text(["a sunset over the ocean", "a crowded street", "birds chirping"])

sim = (image_emb @ text_embs.T)[0]
labels = ["a sunset over the ocean", "a crowded street", "birds chirping"]

print("\nImage ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")


# --- Audio ↔ Image (cross-modal) ---
sim_cross = (audio_emb @ image_emb.T)[0]
print(f"\n Audio ↔ Image similarity: {sim_cross.item():.4f}")

Limitations

  • Early-stage model: Sage Embed v1.1 is the first public release. The audio projection layer and the unfrozen AST layers were trained on Clotho only, so performance on out-of-distribution audio domains may vary.
  • CLIP is frozen: The text and image encoders are not jointly fine-tuned, so alignment quality is bounded by how well the audio projection adapts to CLIP's fixed representation space.
  • Training scale: Trained on a free-tier T4 GPU with a relatively small dataset. Larger-scale training may improve retrieval quality in future versions.
  • Audio length: Inputs are padded or cropped to 10.24 seconds (163840 samples at 16 kHz). For longer clips, only the first 10.24 seconds are used and the rest is discarded.
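One possible workaround for longer clips, which has not been evaluated with Sage Embed and is not part of v1.1, is to split the waveform into consecutive 10.24-second windows, embed each window through the same path as encode_audio, and mean-pool the normalized window embeddings. The NumPy sketch below covers only the windowing and pooling; split_into_windows and pool_embeddings are hypothetical helpers, and the window embeddings in the example are random placeholders rather than model outputs.

```python
import numpy as np

TARGET_SAMPLES = 163_840  # 10.24 s at 16 kHz, as in encode_audio


def split_into_windows(waveform: np.ndarray, window: int = TARGET_SAMPLES) -> np.ndarray:
    """Split a mono waveform into consecutive non-overlapping windows,
    zero-padding the final window to full length. Returns [W, window]."""
    n_windows = max(1, -(-waveform.shape[0] // window))  # ceiling division
    padded = np.zeros(n_windows * window, dtype=waveform.dtype)
    padded[: waveform.shape[0]] = waveform
    return padded.reshape(n_windows, window)


def pool_embeddings(window_embs: np.ndarray) -> np.ndarray:
    """Mean-pool per-window embeddings [W, D] and re-normalize to unit length."""
    mean = window_embs.mean(axis=0)
    return mean / np.linalg.norm(mean)


rng = np.random.default_rng(0)
clip_25s = rng.standard_normal(25 * 16_000).astype(np.float32)
windows = split_into_windows(clip_25s)
print(windows.shape)  # (3, 163840)

# Placeholder: in practice each row of `windows` would go through
# ast_processor and model.forward_audio, then be L2-normalized.
window_embs = rng.standard_normal((windows.shape[0], 512))
window_embs /= np.linalg.norm(window_embs, axis=-1, keepdims=True)
clip_emb = pool_embeddings(window_embs)
```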

Acknowledgements
