# colqwen3-8b-vetcoders-mlx / scripts/colqwen3_embedder.py
# Uploaded by div0-space via huggingface_hub (rev c6c3a3b, verified)
"""
ColQwen3 MLX Embedder
Production-ready multimodal document embedding using Tomoro-ColQwen3 on MLX.
Provides ColPali-style multi-vector embeddings for visual document retrieval.
Key insight: For proper image embeddings, <|image_pad|> tokens must be expanded
to match the number of vision patches, and only image token embeddings should
be used for MaxSim scoring.
Created by M&K (c)2025 The LibraxisAI Team
Co-Authored-By: Maciej (void@div0.space) & Klaudiusz (the1st@whoai.am)
"""
import os
from dataclasses import dataclass
from pathlib import Path
from typing import List, Optional, Tuple, Union
import mlx.core as mx
import numpy as np
from PIL import Image
# Special token ID for image patches.
# Qwen's <|image_pad|> placeholder: positions carrying this id in the input
# sequence are replaced with vision-tower patch embeddings, and only those
# positions are kept for MaxSim scoring (see ColQwen3Embedder.embed_image).
IMAGE_PAD_TOKEN = 151655
@dataclass
class EmbeddingResult:
    """Result of embedding operation."""

    # Multi-vector embeddings, one row per token/patch: [num_tokens, 320]
    # (320 is the default; actual width follows ColQwen3Embedder.embedding_dim).
    embeddings: mx.array  # [num_tokens, 320]
    # Number of rows in `embeddings`.
    num_tokens: int
    # Origin of the embedding: "text" or "image".
    source_type: str  # "text" or "image"
class ColQwen3Embedder:
    """
    ColQwen3 document embedder using MLX.

    Provides multi-vector embeddings optimized for document retrieval
    using Late Interaction (MaxSim) scoring.

    Environment Variables:
        COLQWEN3_MODEL_PATH: Path to Tomoro-ColQwen3 MLX model directory.
            Default: /Volumes/Maciejowe/mlx_lm/models/tomoro-colqwen3-8b-mlx
        COLQWEN3_PROJECTION_PATH: Path to embedding projection weights (.safetensors).
            Default: /Volumes/Maciejowe/mlx_lm/models/colqwen3_projection.safetensors

    Usage:
        # Option 1: Set environment variables
        export COLQWEN3_MODEL_PATH="/path/to/tomoro-colqwen3-8b-mlx"
        export COLQWEN3_PROJECTION_PATH="/path/to/colqwen3_projection.safetensors"

        embedder = ColQwen3Embedder()
        embedder.load()

        # Option 2: Pass paths directly (overrides env vars)
        embedder = ColQwen3Embedder(
            model_path="/path/to/model",
            projection_path="/path/to/projection.safetensors",
        )
        embedder.load()

        # Embed a document image
        doc_emb = embedder.embed_image("document.png")

        # Embed a text query
        query_emb = embedder.embed_text("search query")

        # Score relevance
        score = embedder.maxsim_score(query_emb, doc_emb)

    Created by M&K (c)2025 The LibraxisAI Team
    """

    # Environment variable names for configuration
    ENV_MODEL_PATH = "COLQWEN3_MODEL_PATH"
    ENV_PROJECTION_PATH = "COLQWEN3_PROJECTION_PATH"

    # Default paths (backward compatibility with existing setup)
    DEFAULT_MODEL_PATH = "/Volumes/Maciejowe/mlx_lm/models/tomoro-colqwen3-8b-mlx"
    DEFAULT_PROJ_PATH = "/Volumes/Maciejowe/mlx_lm/models/colqwen3_projection.safetensors"

    def __init__(
        self,
        model_path: Optional[str] = None,
        projection_path: Optional[str] = None,
        embedding_dim: int = 320,
    ):
        """
        Initialize the embedder. No weights are loaded until `load()` is called.

        Args:
            model_path: Path to Tomoro-ColQwen3 MLX model (overrides env var)
            projection_path: Path to embedding projection weights (overrides env var)
            embedding_dim: Output embedding dimension (default 320)

        Path resolution order:
            1. Explicitly passed argument
            2. Environment variable (COLQWEN3_MODEL_PATH / COLQWEN3_PROJECTION_PATH)
            3. Default fallback path
        """
        self.model_path = model_path or os.environ.get(self.ENV_MODEL_PATH) or self.DEFAULT_MODEL_PATH
        self.projection_path = (
            projection_path or os.environ.get(self.ENV_PROJECTION_PATH) or self.DEFAULT_PROJ_PATH
        )
        self.embedding_dim = embedding_dim
        # Populated by load():
        self.model = None
        self.mlx_processor = None
        self.tomoro_processor = None
        self.proj_weight = None
        self.proj_bias = None
        self._loaded = False

    def load(self) -> None:
        """Load model, processor, and projection weights. Idempotent."""
        if self._loaded:
            return
        # Lazy imports: constructing the embedder stays cheap, and these heavy
        # dependencies are only required when the model is actually used.
        from mlx_vlm import load
        from safetensors.torch import load_file
        from transformers import AutoProcessor

        print(f"Loading ColQwen3 from {self.model_path}...")
        self.model, self.mlx_processor = load(self.model_path)

        # Load Tomoro processor for proper image token expansion
        print("Loading Tomoro processor for image token expansion...")
        self.tomoro_processor = AutoProcessor.from_pretrained(
            "TomoroAI/tomoro-colqwen3-embed-8b", trust_remote_code=True
        )

        print(f"Loading projection from {self.projection_path}...")
        proj_weights = load_file(self.projection_path)
        # torch tensor -> float32 numpy -> MLX array
        self.proj_weight = mx.array(proj_weights["embedding_proj_layer.weight"].float().numpy())
        self.proj_bias = mx.array(proj_weights["embedding_proj_layer.bias"].float().numpy())
        self._loaded = True
        print("ColQwen3 Embedder ready!")

    def _ensure_loaded(self) -> None:
        """Load weights lazily on first use."""
        if not self._loaded:
            self.load()

    def _project_and_normalize(self, hidden_states: mx.array) -> mx.array:
        """Project hidden states to `embedding_dim` and L2-normalize.

        Args:
            hidden_states: [..., hidden_size] activations from the language model.

        Returns:
            [..., embedding_dim] unit-norm embeddings.
        """
        embeddings = hidden_states @ self.proj_weight.T + self.proj_bias
        # Epsilon guards against division by zero for all-zero rows.
        norm = mx.sqrt(mx.sum(embeddings**2, axis=-1, keepdims=True) + 1e-12)
        return embeddings / norm

    @staticmethod
    def _mrope_position_ids(batch_size: int, seq_len: int) -> mx.array:
        """Build sequential [3, batch, seq] position IDs for M-ROPE.

        The same 0..seq_len-1 ramp is broadcast across all three rotary axes.
        """
        position_ids = mx.arange(seq_len).reshape(1, -1)
        position_ids = mx.broadcast_to(position_ids, (batch_size, seq_len))
        return mx.broadcast_to(position_ids[None, ...], (3, batch_size, seq_len))

    def embed_text(self, text: str) -> EmbeddingResult:
        """
        Embed a text query.

        Args:
            text: Query string

        Returns:
            EmbeddingResult with embeddings of shape [num_tokens, embedding_dim]
        """
        self._ensure_loaded()
        # Inner language model (skips lm_head — we need hidden states, not logits)
        inner_model = self.model["language_model"]["model"]

        # Tokenize with the Tomoro processor for consistency with embed_image
        inputs = self.tomoro_processor.tokenizer(text, return_tensors="np")
        input_ids = mx.array(inputs["input_ids"])
        batch_size, seq_len = input_ids.shape

        position_ids = self._mrope_position_ids(batch_size, seq_len)
        hidden_states = inner_model(input_ids, position_ids=position_ids)

        embeddings = self._project_and_normalize(hidden_states)
        embeddings = embeddings.squeeze(0)  # drop batch dim
        mx.eval(embeddings)
        return EmbeddingResult(
            embeddings=embeddings,
            num_tokens=seq_len,
            source_type="text",
        )

    def embed_image(
        self,
        image: Union[str, Path, Image.Image],
    ) -> EmbeddingResult:
        """
        Embed document image with proper token expansion.

        Uses Tomoro's ColQwen3Processor to correctly expand <|image_pad|>
        tokens to match the number of vision patches. Only the image token
        embeddings are returned for MaxSim scoring.

        Args:
            image: Image path or PIL Image object

        Returns:
            EmbeddingResult with embeddings of shape [num_patches, embedding_dim]
        """
        self._ensure_loaded()
        # Load image if a path was given
        if isinstance(image, (str, Path)):
            image = Image.open(image).convert("RGB")

        # Tomoro processor expands <|image_pad|> to one token per vision patch
        inputs = self.tomoro_processor(
            text="",  # No text prompt - only image
            images=[image],
            return_tensors="pt",
        )
        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]
        image_grid_thw = inputs["image_grid_thw"]

        # Boolean mask over sequence positions that hold image-pad tokens
        image_mask = (input_ids == IMAGE_PAD_TOKEN).numpy()[0]
        image_positions = np.where(image_mask)[0].tolist()

        # Vision tower: one embedding per patch
        pixel_values_mx = mx.array(pixel_values.numpy())
        image_grid_thw_mx = mx.array(image_grid_thw.numpy())
        hidden_states_vision, _ = self.model["vision_tower"](pixel_values_mx, image_grid_thw_mx)

        # Splice vision patch embeddings into the text embedding sequence at the
        # image-pad positions (numpy round-trip allows in-place row assignment)
        input_ids_mx = mx.array(input_ids.numpy())
        embed_tokens = self.model["language_model"]["model"]["embed_tokens"]
        text_emb_np = np.array(embed_tokens(input_ids_mx)[0])
        vision_np = np.array(hidden_states_vision)
        for i, pos in enumerate(image_positions):
            # Guard: never index past the vision output if counts disagree
            if i < vision_np.shape[0]:
                text_emb_np[pos] = vision_np[i]

        batch_size, seq_len = input_ids_mx.shape
        combined_embeddings = mx.array(text_emb_np).reshape(1, seq_len, -1)
        position_ids = self._mrope_position_ids(batch_size, seq_len)

        # Forward manually through the language model layers — we supply input
        # embeddings rather than token ids, so we cannot call the model directly
        inner_model = self.model["language_model"]["model"]
        h = combined_embeddings
        for layer in inner_model["layers"]:
            h = layer(h, position_ids=position_ids)
        h = inner_model["norm"](h)

        # Extract ONLY image token embeddings for MaxSim
        h_np = np.array(h[0])
        image_hidden_states = mx.array(h_np[image_mask])

        embeddings = self._project_and_normalize(image_hidden_states)
        mx.eval(embeddings)
        return EmbeddingResult(
            embeddings=embeddings,
            num_tokens=embeddings.shape[0],
            source_type="image",
        )

    @staticmethod
    def _open_pdf(pdf_path: Union[str, Path]):
        """Open a PDF with PyMuPDF, raising a clear error if it is missing.

        Returns:
            An open fitz.Document; the caller is responsible for closing it.
        """
        try:
            import fitz  # PyMuPDF
        except ImportError as exc:
            raise ImportError("PyMuPDF required: pip install pymupdf") from exc
        return fitz.open(pdf_path)

    @staticmethod
    def _render_pdf_page(doc, page_num: int, dpi: int) -> Image.Image:
        """Render one page of an open PyMuPDF document to a PIL RGB image."""
        page = doc.load_page(page_num)
        pix = page.get_pixmap(dpi=dpi)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

    def embed_pdf_page(
        self,
        pdf_path: Union[str, Path],
        page_num: int = 0,
        dpi: int = 150,
    ) -> EmbeddingResult:
        """
        Embed a page from a PDF document.

        Args:
            pdf_path: Path to PDF file
            page_num: Page number (0-indexed)
            dpi: Resolution for rendering

        Returns:
            EmbeddingResult with embeddings of shape [num_patches, embedding_dim]
        """
        doc = self._open_pdf(pdf_path)
        try:
            image = self._render_pdf_page(doc, page_num, dpi)
        finally:
            # Close the document even if rendering fails
            doc.close()
        return self.embed_image(image)

    def embed_pdf(
        self,
        pdf_path: Union[str, Path],
        dpi: int = 150,
        max_pages: Optional[int] = None,
    ) -> List[EmbeddingResult]:
        """
        Embed all pages from a PDF document.

        Args:
            pdf_path: Path to PDF file
            dpi: Resolution for rendering
            max_pages: Maximum pages to process (None for all)

        Returns:
            List of EmbeddingResult, one per page
        """
        doc = self._open_pdf(pdf_path)
        try:
            num_pages = min(len(doc), max_pages) if max_pages else len(doc)
            results = []
            # Render from the already-open document instead of re-opening the
            # PDF once per page (the old path opened the file num_pages+1 times)
            for i in range(num_pages):
                image = self._render_pdf_page(doc, i, dpi)
                results.append(self.embed_image(image))
        finally:
            doc.close()
        return results

    @staticmethod
    def maxsim_score(
        query_emb: Union[mx.array, EmbeddingResult],
        doc_emb: Union[mx.array, EmbeddingResult],
    ) -> float:
        """
        Compute MaxSim score between query and document embeddings.

        MaxSim (Late Interaction): for each query token, take the maximum
        similarity across all document tokens, then sum over query tokens.

        Args:
            query_emb: Query embeddings [q_len, dim]
            doc_emb: Document embeddings [d_len, dim]

        Returns:
            Similarity score (higher = more relevant)
        """
        if isinstance(query_emb, EmbeddingResult):
            query_emb = query_emb.embeddings
        if isinstance(doc_emb, EmbeddingResult):
            doc_emb = doc_emb.embeddings
        # All pairwise similarities: [q_len, d_len]
        similarities = query_emb @ doc_emb.T
        # Per query token, max over document tokens; then sum
        max_sims = mx.max(similarities, axis=1)
        score = mx.sum(max_sims)
        mx.eval(score)
        return float(score)

    @staticmethod
    def cosine_similarity(
        emb1: Union[mx.array, EmbeddingResult],
        emb2: Union[mx.array, EmbeddingResult],
    ) -> float:
        """
        Compute mean-pooled cosine similarity.

        Args:
            emb1: First embeddings [n, dim]
            emb2: Second embeddings [m, dim]

        Returns:
            Cosine similarity in [-1, 1]
        """
        if isinstance(emb1, EmbeddingResult):
            emb1 = emb1.embeddings
        if isinstance(emb2, EmbeddingResult):
            emb2 = emb2.embeddings
        # Mean pool each multi-vector embedding down to one vector
        v1 = mx.mean(emb1, axis=0)
        v2 = mx.mean(emb2, axis=0)
        sim = mx.sum(v1 * v2) / (mx.sqrt(mx.sum(v1**2)) * mx.sqrt(mx.sum(v2**2)))
        mx.eval(sim)
        return float(sim)

    def rank_documents(
        self,
        query: str,
        documents: List[EmbeddingResult],
        top_k: Optional[int] = None,
    ) -> List[Tuple[int, float]]:
        """
        Rank documents by relevance to query.

        Args:
            query: Query string
            documents: List of document embeddings
            top_k: Return top K results (None for all)

        Returns:
            List of (doc_index, score) sorted by descending score
        """
        query_emb = self.embed_text(query)
        scores = [
            (i, self.maxsim_score(query_emb, doc_emb))
            for i, doc_emb in enumerate(documents)
        ]
        scores.sort(key=lambda item: item[1], reverse=True)
        # Explicit None check so top_k=0 means "no results", not "all results"
        if top_k is not None:
            scores = scores[:top_k]
        return scores

    def to_numpy(self, emb: Union[mx.array, EmbeddingResult]) -> np.ndarray:
        """Convert embeddings to a numpy array (for storage/indexing)."""
        if isinstance(emb, EmbeddingResult):
            emb = emb.embeddings
        return np.array(emb)
# Convenience functions
def load_embedder(
    model_path: Optional[str] = None,
    projection_path: Optional[str] = None,
) -> ColQwen3Embedder:
    """Construct a ColQwen3Embedder, load its weights, and return it ready to use.

    Args:
        model_path: Optional model directory (env var / default used if None).
        projection_path: Optional projection weights path (same fallback order).

    Returns:
        A fully loaded ColQwen3Embedder instance.
    """
    instance = ColQwen3Embedder(model_path=model_path, projection_path=projection_path)
    instance.load()
    return instance
if __name__ == "__main__":
    # Quick test: smoke-check text embedding and similarity scoring.
    # NOTE(review): requires the model/projection paths to resolve (env vars
    # or the hard-coded defaults) — this will download/load large weights.
    print("Testing ColQwen3 Embedder...")
    embedder = load_embedder()
    # Test text embedding
    text = "dawkowanie meloksykamu dla psa"  # Polish: "meloxicam dosage for a dog"
    result = embedder.embed_text(text)
    print(f"\nText: '{text}'")
    print(f" Tokens: {result.num_tokens}")
    print(f" Embedding shape: {result.embeddings.shape}")
    # Test text similarity (cross-lingual paraphrase should score high)
    text2 = "metacam dose for dogs"
    result2 = embedder.embed_text(text2)
    sim = embedder.cosine_similarity(result, result2)
    print(f"\nSimilarity to '{text2}': {sim:.4f}")
    print("\nColQwen3 Embedder test complete!")