| """ |
| ColQwen3 MLX Embedder |
| |
| Production-ready multimodal document embedding using Tomoro-ColQwen3 on MLX. |
| Provides ColPali-style multi-vector embeddings for visual document retrieval. |
| |
| Key insight: For proper image embeddings, <|image_pad|> tokens must be expanded |
| to match the number of vision patches, and only image token embeddings should |
| be used for MaxSim scoring. |
| |
| Created by M&K (c)2025 The LibraxisAI Team |
| Co-Authored-By: Maciej (void@div0.space) & Klaudiusz (the1st@whoai.am) |
| """ |
|
|
| import os |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import List, Optional, Tuple, Union |
|
|
| import mlx.core as mx |
| import numpy as np |
| from PIL import Image |
|
|
| |
# Token id of the <|image_pad|> placeholder in the tokenizer vocabulary;
# compared against input_ids to locate where vision-patch embeddings go.
IMAGE_PAD_TOKEN = 151655
|
|
|
|
@dataclass
class EmbeddingResult:
    """Result of embedding operation."""

    # Multi-vector embedding matrix, shape [num_tokens, embedding_dim].
    embeddings: mx.array
    # Number of token vectors in `embeddings` (text tokens or image patches).
    num_tokens: int
    # Provenance of the embedding: "text" or "image".
    source_type: str
|
|
|
|
class ColQwen3Embedder:
    """
    ColQwen3 document embedder using MLX.

    Provides multi-vector embeddings optimized for document retrieval
    using Late Interaction (MaxSim) scoring.

    Environment Variables:
        COLQWEN3_MODEL_PATH: Path to Tomoro-ColQwen3 MLX model directory.
            Default: /Volumes/Maciejowe/mlx_lm/models/tomoro-colqwen3-8b-mlx
        COLQWEN3_PROJECTION_PATH: Path to embedding projection weights (.safetensors).
            Default: /Volumes/Maciejowe/mlx_lm/models/colqwen3_projection.safetensors

    Usage:
        # Option 1: Set environment variables
        export COLQWEN3_MODEL_PATH="/path/to/tomoro-colqwen3-8b-mlx"
        export COLQWEN3_PROJECTION_PATH="/path/to/colqwen3_projection.safetensors"

        embedder = ColQwen3Embedder()
        embedder.load()

        # Option 2: Pass paths directly (overrides env vars)
        embedder = ColQwen3Embedder(
            model_path="/path/to/model",
            projection_path="/path/to/projection.safetensors"
        )
        embedder.load()

        # Embed a document image
        doc_emb = embedder.embed_image("document.png")

        # Embed a text query
        query_emb = embedder.embed_text("search query")

        # Score relevance
        score = embedder.maxsim_score(query_emb, doc_emb)

    Created by M&K (c)2025 The LibraxisAI Team
    """

    # Environment variable names consulted when no explicit path is passed.
    ENV_MODEL_PATH = "COLQWEN3_MODEL_PATH"
    ENV_PROJECTION_PATH = "COLQWEN3_PROJECTION_PATH"

    # Last-resort fallback paths (machine-specific defaults).
    DEFAULT_MODEL_PATH = "/Volumes/Maciejowe/mlx_lm/models/tomoro-colqwen3-8b-mlx"
    DEFAULT_PROJ_PATH = "/Volumes/Maciejowe/mlx_lm/models/colqwen3_projection.safetensors"

    def __init__(
        self,
        model_path: Optional[str] = None,
        projection_path: Optional[str] = None,
        embedding_dim: int = 320,
    ):
        """
        Initialize the embedder. No weights are loaded yet; see load().

        Args:
            model_path: Path to Tomoro-ColQwen3 MLX model (overrides env var)
            projection_path: Path to embedding projection weights (overrides env var)
            embedding_dim: Output embedding dimension (default 320)

        Path resolution order:
            1. Explicitly passed argument
            2. Environment variable (COLQWEN3_MODEL_PATH / COLQWEN3_PROJECTION_PATH)
            3. Default fallback path
        """
        self.model_path = model_path or os.environ.get(self.ENV_MODEL_PATH) or self.DEFAULT_MODEL_PATH
        self.projection_path = projection_path or os.environ.get(self.ENV_PROJECTION_PATH) or self.DEFAULT_PROJ_PATH
        self.embedding_dim = embedding_dim

        # Populated by load().
        self.model = None
        self.mlx_processor = None
        self.tomoro_processor = None
        self.proj_weight = None
        self.proj_bias = None
        self._loaded = False

    def load(self) -> None:
        """Load model, processor, and projection weights. Idempotent."""
        if self._loaded:
            return

        # Imported lazily so the module can be imported without these deps.
        from mlx_vlm import load
        from safetensors.torch import load_file
        from transformers import AutoProcessor

        print(f"Loading ColQwen3 from {self.model_path}...")
        self.model, self.mlx_processor = load(self.model_path)

        # The Tomoro processor performs the <|image_pad|> expansion that the
        # MLX processor does not (see module docstring).
        print("Loading Tomoro processor for image token expansion...")
        self.tomoro_processor = AutoProcessor.from_pretrained(
            "TomoroAI/tomoro-colqwen3-embed-8b", trust_remote_code=True
        )

        print(f"Loading projection from {self.projection_path}...")
        proj_weights = load_file(self.projection_path)
        self.proj_weight = mx.array(proj_weights["embedding_proj_layer.weight"].float().numpy())
        self.proj_bias = mx.array(proj_weights["embedding_proj_layer.bias"].float().numpy())

        self._loaded = True
        print("ColQwen3 Embedder ready!")

    def _ensure_loaded(self) -> None:
        """Lazily load weights on first use."""
        if not self._loaded:
            self.load()

    @staticmethod
    def _import_fitz():
        """Import PyMuPDF, chaining a helpful error if it is missing."""
        try:
            import fitz
        except ImportError as err:
            # Chain the original error so the real cause is not lost.
            raise ImportError("PyMuPDF required: pip install pymupdf") from err
        return fitz

    @staticmethod
    def _render_pdf_page(doc, page_num: int, dpi: int) -> Image.Image:
        """Render one page of an already-open fitz document to a PIL RGB image."""
        page = doc.load_page(page_num)
        pix = page.get_pixmap(dpi=dpi)
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)

    @staticmethod
    def _make_position_ids(batch_size: int, seq_len: int) -> mx.array:
        """
        Build position ids of shape [3, batch, seq]: three identical planes
        of 0..seq_len-1 (presumably the layout Qwen's multimodal RoPE
        expects -- TODO confirm against mlx_vlm's language model).
        """
        position_ids = mx.arange(seq_len).reshape(1, -1)
        position_ids = mx.broadcast_to(position_ids, (batch_size, seq_len))
        return mx.broadcast_to(position_ids[None, ...], (3, batch_size, seq_len))

    def _project_and_normalize(self, hidden_states: mx.array) -> mx.array:
        """Apply the embedding projection layer, then L2-normalize each vector."""
        embeddings = hidden_states @ self.proj_weight.T + self.proj_bias
        # Epsilon guards against division by zero for all-zero rows.
        norm = mx.sqrt(mx.sum(embeddings**2, axis=-1, keepdims=True) + 1e-12)
        return embeddings / norm

    def embed_text(self, text: str) -> EmbeddingResult:
        """
        Embed text query.

        Args:
            text: Query string

        Returns:
            EmbeddingResult with shape [num_tokens, 320]
        """
        self._ensure_loaded()

        inner_model = self.model["language_model"]["model"]

        # Tokenize with the Tomoro tokenizer to stay consistent with
        # embed_image's token ids.
        inputs = self.tomoro_processor.tokenizer(text, return_tensors="np")
        input_ids = mx.array(inputs["input_ids"])
        batch_size, seq_len = input_ids.shape

        position_ids = self._make_position_ids(batch_size, seq_len)

        hidden_states = inner_model(input_ids, position_ids=position_ids)

        embeddings = self._project_and_normalize(hidden_states)
        embeddings = embeddings.squeeze(0)  # drop the batch dimension
        mx.eval(embeddings)

        return EmbeddingResult(
            embeddings=embeddings,
            num_tokens=seq_len,
            source_type="text",
        )

    def embed_image(
        self,
        image: Union[str, Path, Image.Image],
    ) -> EmbeddingResult:
        """
        Embed document image with proper token expansion.

        Uses Tomoro's ColQwen3Processor to correctly expand <|image_pad|>
        tokens to match the number of vision patches. Only the image token
        embeddings are returned for MaxSim scoring.

        Args:
            image: Image path or PIL Image object

        Returns:
            EmbeddingResult with shape [num_patches, 320]
        """
        self._ensure_loaded()

        if isinstance(image, (str, Path)):
            image = Image.open(image).convert("RGB")

        # The Tomoro processor expands <|image_pad|> to one token per patch.
        inputs = self.tomoro_processor(
            text="",
            images=[image],
            return_tensors="pt",
        )

        input_ids = inputs["input_ids"]
        pixel_values = inputs["pixel_values"]
        image_grid_thw = inputs["image_grid_thw"]

        # Boolean mask / positions of the image placeholder tokens.
        image_mask = (input_ids == IMAGE_PAD_TOKEN).numpy()[0]
        image_positions = np.where(image_mask)[0].tolist()

        # Encode patches with the vision tower.
        pixel_values_mx = mx.array(pixel_values.numpy())
        image_grid_thw_mx = mx.array(image_grid_thw.numpy())
        hidden_states_vision, _ = self.model["vision_tower"](pixel_values_mx, image_grid_thw_mx)

        # Splice vision patch embeddings into the text embedding sequence at
        # the placeholder positions (done in numpy for in-place assignment).
        input_ids_mx = mx.array(input_ids.numpy())
        embed_tokens = self.model["language_model"]["model"]["embed_tokens"]
        text_emb_np = np.array(embed_tokens(input_ids_mx)[0])
        vision_np = np.array(hidden_states_vision)

        for i, pos in enumerate(image_positions):
            if i < vision_np.shape[0]:
                text_emb_np[pos] = vision_np[i]

        batch_size, seq_len = input_ids_mx.shape
        combined_embeddings = mx.array(text_emb_np).reshape(1, seq_len, -1)

        position_ids = self._make_position_ids(batch_size, seq_len)

        # Run decoder layers manually over the combined embeddings
        # (bypassing the token-embedding lookup), then the final norm.
        inner_model = self.model["language_model"]["model"]
        h = combined_embeddings
        for layer in inner_model["layers"]:
            h = layer(h, position_ids=position_ids)
        h = inner_model["norm"](h)

        # Keep only hidden states at image token positions for MaxSim.
        h_np = np.array(h[0])
        image_hidden_states = mx.array(h_np[image_mask])

        embeddings = self._project_and_normalize(image_hidden_states)
        mx.eval(embeddings)

        return EmbeddingResult(
            embeddings=embeddings,
            num_tokens=embeddings.shape[0],
            source_type="image",
        )

    def embed_pdf_page(
        self,
        pdf_path: Union[str, Path],
        page_num: int = 0,
        dpi: int = 150,
    ) -> EmbeddingResult:
        """
        Embed a page from a PDF document.

        Args:
            pdf_path: Path to PDF file
            page_num: Page number (0-indexed)
            dpi: Resolution for rendering

        Returns:
            EmbeddingResult with shape [num_patches, 320]
        """
        fitz = self._import_fitz()

        doc = fitz.open(pdf_path)
        try:
            # Render inside try/finally so the document is closed even if
            # rendering raises (e.g. bad page number).
            image = self._render_pdf_page(doc, page_num, dpi)
        finally:
            doc.close()

        return self.embed_image(image)

    def embed_pdf(
        self,
        pdf_path: Union[str, Path],
        dpi: int = 150,
        max_pages: Optional[int] = None,
    ) -> List[EmbeddingResult]:
        """
        Embed all pages from a PDF document.

        Args:
            pdf_path: Path to PDF file
            dpi: Resolution for rendering
            max_pages: Maximum pages to process (None for all; 0 yields [])

        Returns:
            List of EmbeddingResult, one per page
        """
        fitz = self._import_fitz()

        doc = fitz.open(pdf_path)
        try:
            # `is None` (not truthiness) so max_pages=0 means zero pages.
            num_pages = len(doc) if max_pages is None else min(len(doc), max_pages)

            # Render every page from the single open document instead of
            # re-opening the PDF once per page via embed_pdf_page().
            results = []
            for i in range(num_pages):
                image = self._render_pdf_page(doc, i, dpi)
                results.append(self.embed_image(image))
        finally:
            doc.close()

        return results

    @staticmethod
    def maxsim_score(
        query_emb: Union[mx.array, EmbeddingResult],
        doc_emb: Union[mx.array, EmbeddingResult],
    ) -> float:
        """
        Compute MaxSim score between query and document embeddings.

        MaxSim (Late Interaction): For each query token, find maximum
        similarity across all document tokens, then sum.

        Args:
            query_emb: Query embeddings [q_len, dim]
            doc_emb: Document embeddings [d_len, dim]

        Returns:
            Similarity score (higher = more relevant)
        """
        if isinstance(query_emb, EmbeddingResult):
            query_emb = query_emb.embeddings
        if isinstance(doc_emb, EmbeddingResult):
            doc_emb = doc_emb.embeddings

        # Pairwise similarities: [q_len, d_len].
        similarities = query_emb @ doc_emb.T

        # Best document match per query token, summed over query tokens.
        max_sims = mx.max(similarities, axis=1)
        score = mx.sum(max_sims)
        mx.eval(score)

        return float(score)

    @staticmethod
    def cosine_similarity(
        emb1: Union[mx.array, EmbeddingResult],
        emb2: Union[mx.array, EmbeddingResult],
    ) -> float:
        """
        Compute mean-pooled cosine similarity.

        Args:
            emb1: First embeddings [n, dim]
            emb2: Second embeddings [m, dim]

        Returns:
            Cosine similarity in [-1, 1]
        """
        if isinstance(emb1, EmbeddingResult):
            emb1 = emb1.embeddings
        if isinstance(emb2, EmbeddingResult):
            emb2 = emb2.embeddings

        # Mean-pool token vectors into a single vector each.
        v1 = mx.mean(emb1, axis=0)
        v2 = mx.mean(emb2, axis=0)

        sim = mx.sum(v1 * v2) / (mx.sqrt(mx.sum(v1**2)) * mx.sqrt(mx.sum(v2**2)))
        mx.eval(sim)

        return float(sim)

    def rank_documents(
        self,
        query: str,
        documents: List[EmbeddingResult],
        top_k: Optional[int] = None,
    ) -> List[Tuple[int, float]]:
        """
        Rank documents by relevance to query.

        Args:
            query: Query string
            documents: List of document embeddings
            top_k: Return top K results (None for all; 0 yields [])

        Returns:
            List of (doc_index, score) sorted by descending score
        """
        query_emb = self.embed_text(query)

        scores = [
            (i, self.maxsim_score(query_emb, doc_emb))
            for i, doc_emb in enumerate(documents)
        ]
        scores.sort(key=lambda x: x[1], reverse=True)

        # `is not None` so top_k=0 returns an empty list rather than all.
        if top_k is not None:
            scores = scores[:top_k]

        return scores

    def to_numpy(self, emb: Union[mx.array, EmbeddingResult]) -> np.ndarray:
        """Convert embeddings to numpy array (for storage/indexing)."""
        if isinstance(emb, EmbeddingResult):
            emb = emb.embeddings
        return np.array(emb)
|
|
|
|
| |
def load_embedder(
    model_path: Optional[str] = None,
    projection_path: Optional[str] = None,
) -> ColQwen3Embedder:
    """Construct a ColQwen3Embedder, load its weights, and return it ready to use."""
    instance = ColQwen3Embedder(model_path=model_path, projection_path=projection_path)
    instance.load()
    return instance
|
|
|
|
if __name__ == "__main__":
    # Smoke test: load the model and embed two semantically related queries.
    print("Testing ColQwen3 Embedder...")

    embedder = load_embedder()

    # Polish veterinary query.
    query_pl = "dawkowanie meloksykamu dla psa"
    emb_pl = embedder.embed_text(query_pl)
    print(f"\nText: '{query_pl}'")
    print(f" Tokens: {emb_pl.num_tokens}")
    print(f" Embedding shape: {emb_pl.embeddings.shape}")

    # Compare against its English equivalent.
    query_en = "metacam dose for dogs"
    emb_en = embedder.embed_text(query_en)
    similarity = embedder.cosine_similarity(emb_pl, emb_en)
    print(f"\nSimilarity to '{query_en}': {similarity:.4f}")

    print("\nColQwen3 Embedder test complete!")
|
|