---
license: mit
language:
- en
base_model:
- laion/CLIP-ViT-B-32-laion2B-s34B-b79K
- MIT/ast-finetuned-audioset-10-10-0.4593
pipeline_tag: feature-extraction
tags:
- multimodal
- cross-modal
- embeddings
- audio
- vision
- text
- contrastive-learning
- openclip
- ast
---

# 🌿 Sage Embed — v1.1

**Sage Embed** is a cross-modal embedding model that encodes **text**, **images**, and **audio** into a unified shared latent space. By aligning all three modalities, it enables semantic similarity and retrieval across any combination — query audio with text, compare images to sounds, or any mix in between.

> 🚧 **This project is in early development.** Quantitative evaluation metrics will be added in a future release. See the [Evaluation](#evaluation) section for details.

---

## Model Summary

| Property | Details |
|---|---|
| **Version** | v1.1 |
| **Modalities** | Text · Image · Audio |
| **Shared Embedding Dim** | 512 |
| **Text & Image Backbone** | [`laion/CLIP-ViT-B-32-laion2B-s34B-b79K`](https://huggingface.co/laion/CLIP-ViT-B-32-laion2B-s34B-b79K) |
| **Audio Backbone** | [`MIT/ast-finetuned-audioset-10-10-0.4593`](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) |
| **Training Dataset** | [Clotho v2](https://github.com/audio-captioning/clotho-dataset) |
| **Training Hardware** | Google Colab — NVIDIA T4 (free tier) |
| **Epochs** | 20 |
| **License** | MIT |

---

## Architecture

### Text & Image: OpenCLIP ViT-B/32

- **Model:** `hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K`
- Weights are **fully frozen** during training — used purely as a fixed embedding target.
- Produces **512-dimensional** L2-normalized embeddings for both text and images.
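Freezing a backbone in PyTorch amounts to disabling gradients on its parameters. A minimal generic sketch, using a stand-in `nn.Linear` rather than the project's actual training code (in Sage Embed the frozen target is the OpenCLIP model):

```python
import torch.nn as nn

def freeze(module: nn.Module) -> None:
    # Disable gradient updates for every parameter in the module;
    # the module then acts as a fixed feature extractor.
    for p in module.parameters():
        p.requires_grad = False

# Stand-in module for illustration only
backbone = nn.Linear(768, 512)
freeze(backbone)
print(all(not p.requires_grad for p in backbone.parameters()))  # True
```

Frozen parameters are skipped by the optimizer, so only the trainable parts (here, the audio side) receive updates.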
### Audio: Audio Spectrogram Transformer (AST)

- **Model:** `MIT/ast-finetuned-audioset-10-10-0.4593`
- AST hidden size: **768**
- The model is **mostly frozen**; the **last 2 encoder layers** and the **final LayerNorm** are unfrozen for fine-tuning, allowing the audio encoder to adapt to the audio-caption alignment task without catastrophic forgetting.
- The mean of the `[CLS]` and distillation token states (`last_hidden_state[:, 0]` and `last_hidden_state[:, 1]`) is extracted and fed to the projection layer.

### Audio Projection Layer

The projection layer maps AST's 768-dimensional output into CLIP's 512-dimensional shared space via a small MLP: `Linear(768 → 1024) → LayerNorm → ReLU → Dropout(0.3) → Linear(1024 → 512)`.

---

## Evaluation

> **Metrics Coming Soon**
>
> Quantitative benchmarks will be added in a future update as the project matures. This release focuses on establishing the cross-modal architecture.

---

## Usage

### Installation

```bash
pip install torch torchaudio transformers open_clip_torch huggingface_hub
```

### Step 1 — Define the Model & Load Weights

Copy this class definition exactly. The weights file is pulled directly from the Hub.
```python
import os
import torch
import torch.nn as nn
import torchaudio
import torchaudio.transforms as T
from transformers import ASTModel, AutoProcessor
from huggingface_hub import hf_hub_download
import open_clip
from PIL import Image
import numpy as np

MODEL_NAME = "hf-hub:laion/CLIP-ViT-B-32-laion2B-s34B-b79K"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class OpenCLIP_AST_Model(nn.Module):
    def __init__(self, embedding_dim=512):
        super().__init__()
        self.ast = ASTModel.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
        # pretrained=None: architecture only — the trained weights are loaded
        # from the checkpoint below
        self.clip_model, _, self.image_preprocess = open_clip.create_model_and_transforms(
            MODEL_NAME, pretrained=None
        )
        self.audio_projection = nn.Sequential(
            nn.Linear(self.ast.config.hidden_size, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(1024, embedding_dim)
        )
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    def forward_audio(self, input_values):
        outputs = self.ast(input_values)
        # Mean of the [CLS] and distillation token states for a richer
        # audio representation
        audio_features = (outputs.last_hidden_state[:, 0] + outputs.last_hidden_state[:, 1]) / 2
        return self.audio_projection(audio_features)

# --- Load weights from the Hugging Face Hub ---
weights_path = hf_hub_download(
    repo_id="harryfrz/sage-embed",
    filename="sage_embed_v1.1.pt"
)

model = OpenCLIP_AST_Model().to(DEVICE)
model.load_state_dict(torch.load(weights_path, map_location=DEVICE))
model.eval()

# Processors
ast_processor = AutoProcessor.from_pretrained("MIT/ast-finetuned-audioset-10-10-0.4593")
tokenizer = open_clip.get_tokenizer(MODEL_NAME)

print("Sage Embed v1.1 loaded.")
```

---

### Step 2 — Encode Your Inputs

**Encode a single audio file** (`.wav`, `.mp3`, `.flac`):

```python
def encode_audio(filepath):
    waveform, sr = torchaudio.load(filepath)

    # Convert to mono
    if waveform.shape[0] > 1:
        waveform = torch.mean(waveform, dim=0)
    else:
        waveform = waveform.squeeze(0)

    # Resample to 16 kHz
    if sr != 16000:
        waveform = T.Resample(sr, 16000)(waveform)

    # Pad or crop to ~10 seconds (163840 samples @ 16 kHz)
    target = 163840
    if waveform.shape[0] < target:
        waveform = torch.nn.functional.pad(waveform, (0, target - waveform.shape[0]))
    else:
        waveform = waveform[:target]

    # The AST feature extractor expects a NumPy array
    inputs = ast_processor(waveform.numpy(), sampling_rate=16000, return_tensors="pt")
    input_tensor = inputs["input_values"].to(DEVICE)

    with torch.no_grad():
        emb = model.forward_audio(input_tensor)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

audio_emb = encode_audio("your_audio.wav")
```

**Encode a single image file** (`.jpg`, `.png`, `.webp`):

```python
def encode_image(filepath):
    image = Image.open(filepath).convert("RGB")
    image_input = model.image_preprocess(image).unsqueeze(0).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_image(image_input)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [1, 512]

image_emb = encode_image("your_image.jpg")
```

**Encode text:**

```python
def encode_text(texts):
    # texts: a string or list of strings
    if isinstance(texts, str):
        texts = [texts]
    tokens = tokenizer(texts).to(DEVICE)

    with torch.no_grad():
        emb = model.clip_model.encode_text(tokens)
        emb = emb / emb.norm(dim=-1, keepdim=True)
    return emb  # shape: [N, 512]

text_emb = encode_text("a dog barking in the park")
```

---

### Step 3 — Cross-Modal Similarity Test

All embeddings are **L2-normalized**, so dot product equals cosine similarity directly.
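This identity is easy to check with plain tensors — a synthetic sketch, independent of the model, using random vectors in place of real embeddings:

```python
import torch

torch.manual_seed(0)
a = torch.randn(1, 512)   # stands in for one query embedding
b = torch.randn(3, 512)   # stands in for three candidate embeddings

# L2-normalize rows, as the encode_* helpers above do
a = a / a.norm(dim=-1, keepdim=True)
b = b / b.norm(dim=-1, keepdim=True)

dot = (a @ b.T)[0]                                   # shape: [3]
cos = torch.nn.functional.cosine_similarity(a, b)    # shape: [3]
print(torch.allclose(dot, cos, atol=1e-6))  # True
```

Because normalization is already done at encode time, retrieval reduces to a single matrix multiplication.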
```python
# --- Audio ↔ Text ---
audio_emb = encode_audio("dog_bark.wav")
text_embs = encode_text(["a dog barking", "rain on a rooftop", "busy city traffic"])

sim = (audio_emb @ text_embs.T)[0]  # shape: [3]

labels = ["a dog barking", "rain on a rooftop", "busy city traffic"]
print("Audio ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")
# Expected: highest score for "a dog barking"

# --- Image ↔ Text ---
image_emb = encode_image("sunset.jpg")
text_embs = encode_text(["a sunset over the ocean", "a crowded street", "birds chirping"])

sim = (image_emb @ text_embs.T)[0]
labels = ["a sunset over the ocean", "a crowded street", "birds chirping"]
print("\nImage ↔ Text similarity:")
for label, score in zip(labels, sim.tolist()):
    print(f"  '{label}': {score:.4f}")

# --- Audio ↔ Image (cross-modal) ---
sim_cross = (audio_emb @ image_emb.T)[0]
print(f"\nAudio ↔ Image similarity: {sim_cross.item():.4f}")
```

---

## Limitations

- **Early-stage model:** Sage Embed v1.1 is the first public release. The projection layer was trained on Clotho only; performance on out-of-distribution audio domains may vary.
- **CLIP is frozen:** The text and image encoders are not jointly fine-tuned, so alignment quality is bounded by how well the audio projection adapts to CLIP's fixed representation space.
- **Training scale:** Trained on a free-tier T4 GPU with a relatively small dataset. Larger-scale training is expected to improve retrieval quality significantly in future versions.
- **Audio length:** Optimized for clips up to ~10 seconds. Longer clips keep only the first ~10 seconds; the remainder is discarded.

---

## Acknowledgements

- [OpenCLIP](https://github.com/mlfoundations/open_clip) — LAION CLIP ViT-B/32 backbone
- [Audio Spectrogram Transformer](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) — MIT AST backbone
- [Clotho Dataset](https://github.com/audio-captioning/clotho-dataset) — Training data