A compact multimodal embedding model that unifies text, image, and audio representations in a shared semantic space. Part of the MoM (Mixture of Models) family.
## Model Description

`multi-modal-embed-small` is a lightweight multimodal encoder (~120M parameters) supporting:

- Text encoding via MiniLM-L6-v2 (22M params)
- Image encoding via SigLIP-base-patch16-512 (86M params)
- Audio encoding via Whisper-tiny encoder (8M params)
- Cross-modal fusion via 2-layer transformer attention
- 2DMSE: Two-Dimensional Matryoshka Sentence Embeddings for adaptive compute
- MRL: Matryoshka Representation Learning for flexible embedding dimensions
## Key Features

| Feature | Description |
|---|---|
| Embedding Dimension | 384 (supports MRL truncation to 32, 64, 128, 256) |
| Image Resolution | 512×512 |
| Audio Input | Up to 30 s, 16 kHz (Whisper Mel spectrogram) |
| Modalities | Text, Image, Audio, Multimodal fusion |
| 2DMSE Support | Early exit at any encoder layer |
| Languages | English |
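As a quick illustration of the MRL row above: the leading dimensions of an embedding already form a usable lower-dimensional embedding once re-normalized. The sketch below uses a random stand-in vector only to show the slice-and-renormalize recipe; real embeddings come from the model loaded later.

```python
import torch
import torch.nn.functional as F

# Stand-in for a 384-d embedding produced by the model
emb = F.normalize(torch.randn(1, 384), p=2, dim=-1)

# MRL: keep a leading prefix of dimensions and re-normalize
for dim in (32, 64, 128, 256):
    truncated = F.normalize(emb[:, :dim], p=2, dim=-1)
    print(dim, truncated.shape)  # e.g. 128 -> torch.Size([1, 128])
```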
## Installation

```bash
pip install torch transformers pillow safetensors
```
## Usage

### Load Model

Two checkpoint formats are available:

- `model.pt` (932 MB) - PyTorch format
- `model.safetensors` (1.35 GB) - SafeTensors format
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModel, AutoTokenizer, SiglipModel, SiglipProcessor, WhisperModel, WhisperFeatureExtractor
from huggingface_hub import hf_hub_download


class MultiModalEmbedder(nn.Module):
    """Standalone multimodal embedder - no external dependencies."""

    def __init__(self):
        super().__init__()
        # Text encoder (384d, no projection needed)
        self.text_tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        self.text_encoder = AutoModel.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")
        # Image encoder (768d -> 384d projection)
        self.image_processor = SiglipProcessor.from_pretrained("google/siglip-base-patch16-512")
        self.image_encoder = SiglipModel.from_pretrained("google/siglip-base-patch16-512").vision_model
        self.image_proj = nn.Linear(768, 384)
        # Audio encoder (384d, no projection needed)
        self.audio_processor = WhisperFeatureExtractor.from_pretrained("openai/whisper-tiny")
        self.audio_encoder = WhisperModel.from_pretrained("openai/whisper-tiny").encoder

    def encode_text(self, texts):
        if isinstance(texts, str):
            texts = [texts]
        inputs = self.text_tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
        inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
        outputs = self.text_encoder(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
        return F.normalize(embeddings, p=2, dim=-1)

    def encode_image(self, images):
        inputs = self.image_processor(images=images, return_tensors="pt")
        inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
        outputs = self.image_encoder(**inputs)
        embeddings = outputs.pooler_output
        embeddings = self.image_proj(embeddings)  # 768 -> 384
        return F.normalize(embeddings, p=2, dim=-1)

    def encode_audio(self, waveform):
        # waveform: numpy array or tensor at 16kHz
        if isinstance(waveform, torch.Tensor):
            waveform = waveform.squeeze().numpy()
        inputs = self.audio_processor(waveform, sampling_rate=16000, return_tensors="pt")
        inputs = {k: v.to(next(self.parameters()).device) for k, v in inputs.items()}
        outputs = self.audio_encoder(**inputs)
        embeddings = outputs.last_hidden_state.mean(dim=1)  # Mean pooling
        return F.normalize(embeddings, p=2, dim=-1)


# Load model
model = MultiModalEmbedder()

# Download and load trained weights
checkpoint_path = hf_hub_download(
    repo_id="llm-semantic-router/multi-modal-embed-small",
    filename="model.pt"
)
state_dict = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

# Load text encoder weights
model.text_encoder.load_state_dict({
    k.replace("text_encoder.encoder.", ""): v
    for k, v in state_dict.items()
    if k.startswith("text_encoder.encoder.")
})

# Load image encoder and projection weights
model.image_encoder.load_state_dict({
    k.replace("image_encoder.vision_encoder.", ""): v
    for k, v in state_dict.items()
    if k.startswith("image_encoder.vision_encoder.")
})
model.image_proj.load_state_dict({
    k.replace("image_encoder.projection.", ""): v
    for k, v in state_dict.items()
    if k.startswith("image_encoder.projection.")
})

# Load audio encoder weights
model.audio_encoder.load_state_dict({
    k.replace("audio_encoder.encoder.", ""): v
    for k, v in state_dict.items()
    if k.startswith("audio_encoder.encoder.")
})

model.eval()
print("Model loaded successfully!")
```
### Text Embedding

```python
import torch.nn.functional as F

# Single text
text_embedding = model.encode_text("A photo of a cat")  # Shape: [1, 384]

# Batch of texts
texts = ["A fluffy orange cat", "A golden retriever dog", "A red sports car"]
text_embeddings = model.encode_text(texts)  # Shape: [3, 384]

# Compute similarity
similarities = F.cosine_similarity(text_embeddings[0:1], text_embeddings[1:], dim=-1)
print(f"Cat vs Dog: {similarities[0]:.3f}")
print(f"Cat vs Car: {similarities[1]:.3f}")
```
### Image Embedding

```python
from PIL import Image
import requests
from io import BytesIO

# Load image
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/3/3a/Cat03.jpg/1200px-Cat03.jpg"
image = Image.open(BytesIO(requests.get(url).content)).convert('RGB')

# Get embedding
image_embedding = model.encode_image(image)  # Shape: [1, 384]

# Image-to-text retrieval
image = Image.open("cat.jpg").convert('RGB')
image_emb = model.encode_image(image)

captions = [
    "A cat sleeping on a bed",
    "A dog playing in the park",
    "A car driving on the highway",
]
text_embs = model.encode_text(captions)

similarities = F.cosine_similarity(image_emb, text_embs)
best_idx = similarities.argmax().item()
print(f"Best match: {captions[best_idx]} ({similarities[best_idx]:.3f})")
```