#!/usr/bin/env python3
"""
MLX Visual Embedding Server - ColQwen3

HTTP server wrapper for ColQwen3Embedder providing visual document embeddings.
Power of Wet Coders edition - custom merged model by LibraxisAI.

Uses the production ColQwen3Embedder class from colqwen3_embedder.py

Usage:
    cd knowledge/vista-brain
    uv run python scripts/mlx_visual_server.py

    # Or via Makefile:
    make visual

Endpoints:
    POST /v1/visual-embeddings - Generate visual embeddings from images/PDFs
    POST /v1/maxsim            - Compute MaxSim score between query and docs
    GET  /v1/models            - List models
    GET  /health               - Health check

Created by M&K (c)2025 The LibraxisAI Team
Co-Authored-By: Maciej (void@div0.space) & Klaudiusz (the1st@whoai.am)
"""

import base64
import io
import json
import os
import sys
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import List, Optional, Union

# Add parent directory to path for colqwen3_embedder import
sys.path.insert(0, str(Path(__file__).parent.parent))

from colqwen3_embedder import ColQwen3Embedder, load_embedder

# Configuration from environment
PORT = int(os.environ.get("MLX_VISUAL_PORT", "12347"))

# ColBERT embedding dimension (320 for our custom projection)
EMBED_DIM = 320

# Lazy load embedder
_embedder = None


def get_embedder() -> ColQwen3Embedder:
    """Return the process-wide embedder, loading it on first use.

    The instance is cached in the module-level ``_embedder`` so that
    ``load_embedder()`` runs at most once per process.
    """
    global _embedder
    if _embedder is None:
        print("Loading ColQwen3 Embedder...", file=sys.stderr)
        _embedder = load_embedder()
        print(f"ColQwen3 ready (dim={EMBED_DIM})", file=sys.stderr)
    return _embedder


def decode_image(image_data: Union[str, bytes]):
    """Decode an image from base64 text or raw bytes into an RGB PIL image.

    Args:
        image_data: Either raw encoded image bytes, or a base64 string with
            or without a ``data:<mime>;base64,`` URL prefix.

    Returns:
        A ``PIL.Image.Image`` converted to RGB mode.

    Raises:
        ValueError: If a ``data:`` URL carries no comma-separated payload
            (``binascii.Error`` from bad base64 is a ``ValueError`` too).
        Exception: Whatever PIL raises for unreadable image data.
    """
    from PIL import Image

    if isinstance(image_data, str):
        # Handle base64 with or without data URL prefix
        if image_data.startswith("data:"):
            # Strip the "data:<mime>;base64," header, keep only the payload.
            _, separator, payload = image_data.partition(",")
            if not separator:
                raise ValueError("Malformed data URL: missing ',' before payload")
            image_data = payload
        image_bytes = base64.b64decode(image_data)
    else:
        image_bytes = image_data

    return Image.open(io.BytesIO(image_bytes)).convert("RGB")


def _is_file_reference(value) -> bool:
    """Tell whether *value* should be handed to the embedder as a file path."""
    if not isinstance(value, str):
        return False
    if value.startswith("."):
        # "." is not part of the base64 alphabet, so this cannot be image data.
        return True
    if value.startswith("/"):
        # "/" IS a base64 character: every base64-encoded JPEG begins with
        # "/9j/". Only treat the string as a path when the file really exists.
        return os.path.isfile(value)
    return False


def embed_images(images: List[Union[str, bytes]]) -> List[dict]:
    """Generate ColBERT-style embeddings for images.

    Args:
        images: Items that are either file paths (strings starting with "."
            or naming an existing file under "/"), base64 strings, or raw
            image bytes.

    Returns:
        One dict per input, in order. On success it holds ``embedding``,
        ``num_tokens`` and ``source_type``; on failure it holds ``error``
        so that one bad image does not abort the whole batch.
    """
    embedder = get_embedder()
    import mlx.core as mx

    results = []
    try:
        for img_data in images:
            try:
                if _is_file_reference(img_data):
                    # File path: passed through to the embedder unchanged.
                    pil_img = img_data
                else:
                    # Base64 text or raw bytes
                    pil_img = decode_image(img_data)

                # Embed using ColQwen3Embedder
                result = embedder.embed_image(pil_img)
                results.append({
                    "embedding": embedder.to_numpy(result).tolist(),
                    "num_tokens": result.num_tokens,
                    "source_type": result.source_type,
                })
            except Exception as e:
                print(f"Image embed error: {e}", file=sys.stderr)
                results.append({"error": str(e)})
    finally:
        # Clear MLX cache
        mx.clear_cache()

    return results


def embed_pdf(pdf_path: str, max_pages: Optional[int] = None) -> List[dict]:
    """Embed all pages from a PDF.

    Args:
        pdf_path: Path of the PDF, forwarded to ``embedder.embed_pdf``.
        max_pages: Optional page limit, forwarded as ``max_pages``.

    Returns:
        One dict per embedded page with ``page`` (0-based index),
        ``embedding``, ``num_tokens`` and ``source_type``. If embedding
        fails, an ``{"error": ...}`` dict is appended instead of raising.
    """
    embedder = get_embedder()
    import mlx.core as mx

    results = []
    try:
        page_results = embedder.embed_pdf(pdf_path, max_pages=max_pages)
        for i, result in enumerate(page_results):
            results.append({
                "page": i,
                "embedding": embedder.to_numpy(result).tolist(),
                "num_tokens": result.num_tokens,
                "source_type": result.source_type,
            })
    except Exception as e:
        print(f"PDF embed error: {e}", file=sys.stderr)
        results.append({"error": str(e)})
    finally:
        mx.clear_cache()

    return results


def embed_text(text: str) -> dict:
    """Embed a text query.

    Returns:
        A dict with ``embedding``, ``num_tokens`` and ``source_type``, or
        ``{"error": ...}`` if the embedder raised.
    """
    embedder = get_embedder()
    import mlx.core as mx

    try:
        result = embedder.embed_text(text)
        return {
            "embedding": embedder.to_numpy(result).tolist(),
            "num_tokens": result.num_tokens,
            "source_type": result.source_type,
        }
    except Exception as e:
        print(f"Text embed error: {e}", file=sys.stderr)
        return {"error": str(e)}
    finally:
        # Release cached buffers on the failure path too, matching
        # embed_images / embed_pdf.
        mx.clear_cache()


def compute_maxsim(query_embedding: List, doc_embedding: List) -> float:
    """Compute the MaxSim score between query and document embeddings.

    For each query row, takes the maximum dot product over all document
    rows, then sums those maxima.

    Args:
        query_embedding: Nested list; assumed shape (query_tokens, dim) -
            TODO confirm against callers.
        doc_embedding: Nested list; assumed shape (doc_tokens, dim).

    Returns:
        The summed per-query-token maximum similarity as a Python float.
    """
    import mlx.core as mx

    query_mx = mx.array(query_embedding)
    doc_mx = mx.array(doc_embedding)

    # MaxSim: for each query token, max over doc tokens, then sum
    similarities = query_mx @ doc_mx.T
    max_sims = mx.max(similarities, axis=1)
    score = float(mx.sum(max_sims))

    mx.clear_cache()
    return score


def _as_list(value) -> list:
    """Normalize a request field to a list: None -> [], lone str/bytes -> [value]."""
    if value is None:
        return []
    if isinstance(value, (str, bytes)):
        # A bare string would otherwise be iterated character by character.
        return [value]
    return value


class VisualHandler(BaseHTTPRequestHandler):
    """HTTP handler for visual embeddings API."""

    def log_message(self, format, *args):
        """Log the fully formatted message to stderr with a timestamp."""
        # Interpolate all args (request line, status code, size), not just
        # the first one, so the response status is visible in the log.
        message = format % args
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}", file=sys.stderr)

    def send_json(self, data: dict, status: int = 200):
        """Serialize *data* as JSON and send it with the given HTTP status."""
        body = json.dumps(data).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle GET requests: model listing and health check."""
        if self.path == "/v1/models" or self.path == "/models":
            self.send_json({
                "object": "list",
                "data": [{
                    "id": "colqwen3-8b-wetcoders",
                    "object": "model",
                    "owned_by": "libraxis-local",
                    "type": "visual-embedding",
                    "description": "ColQwen3 8B - Power of Wet Coders edition",
                    "embedding_dim": EMBED_DIM,
                }]
            })
        elif self.path == "/health":
            self.send_json({
                "status": "healthy",
                "model": "colqwen3-8b-wetcoders",
                "dim": EMBED_DIM,
                "type": "colbert-visual-embedding",
            })
        else:
            self.send_json({"error": "Not found"}, 404)

    def do_POST(self):
        """Handle POST requests: parse the JSON body and dispatch by path.

        Responds 400 for a bad ``Content-Length``, an undecodable or
        non-object JSON body, and 404 for unknown paths.
        """
        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except ValueError:
            self.send_json({"error": "Invalid Content-Length"}, 400)
            return
        if content_length < 0:
            # read(-1) would block until the client closes the connection.
            self.send_json({"error": "Invalid Content-Length"}, 400)
            return

        body = self.rfile.read(content_length)

        try:
            data = json.loads(body)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError alike.
            self.send_json({"error": "Invalid JSON"}, 400)
            return

        if not isinstance(data, dict):
            # Handlers below call data.get(); arrays/scalars are not valid.
            self.send_json({"error": "JSON body must be an object"}, 400)
            return

        if self.path in ["/v1/visual-embeddings", "/visual-embeddings"]:
            self._handle_embeddings(data)
        elif self.path in ["/v1/maxsim", "/maxsim"]:
            self._handle_maxsim(data)
        else:
            self.send_json({"error": "Not found"}, 404)

    def _handle_embeddings(self, data: dict):
        """Handle embedding requests.

        Exactly one input kind is processed per request, in priority order:
        ``pdf_path``, then ``images``, then ``texts``.
        """
        images = _as_list(data.get("images"))
        texts = _as_list(data.get("texts"))
        pdf_path = data.get("pdf_path")
        max_pages = data.get("max_pages")

        response = {
            "object": "embedding_response",
            "model": "colqwen3-8b-wetcoders",
            "dim": EMBED_DIM,
        }

        try:
            if pdf_path:
                # PDF embedding
                response["pdf_embeddings"] = embed_pdf(pdf_path, max_pages)
            elif images:
                # Image embeddings
                response["image_embeddings"] = embed_images(images)
            elif texts:
                # Text embeddings
                response["text_embeddings"] = [embed_text(t) for t in texts]
            else:
                self.send_json({"error": "No images, texts, or pdf_path provided"}, 400)
                return
        except Exception as e:
            print(f"Embedding error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)
            return

        self.send_json(response)

    def _handle_maxsim(self, data: dict):
        """Handle MaxSim scoring requests.

        Requires non-empty ``query_embedding`` and ``doc_embedding`` fields;
        responds 400 when either is missing and 500 if scoring raises.
        """
        query_embedding = data.get("query_embedding")
        doc_embedding = data.get("doc_embedding")

        if not query_embedding or not doc_embedding:
            self.send_json({"error": "query_embedding and doc_embedding required"}, 400)
            return

        try:
            score = compute_maxsim(query_embedding, doc_embedding)
            self.send_json({
                "object": "maxsim_score",
                "score": score,
                "model": "colqwen3-8b-wetcoders",
            })
        except Exception as e:
            print(f"MaxSim error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)


def main():
    """Start the visual embedding server and serve until interrupted."""
    print("", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    print("MLX Visual Embedding Server - ColQwen3", file=sys.stderr)
    print("Power of Wet Coders Edition", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    print(f"Port: {PORT}", file=sys.stderr)
    print(f"Embedding dim: {EMBED_DIM} (ColBERT)", file=sys.stderr)
    print("", file=sys.stderr)
    print("Endpoints:", file=sys.stderr)
    print(" POST /v1/visual-embeddings - Generate embeddings", file=sys.stderr)
    print(" body: {images: [base64...]} or {pdf_path: '/path.pdf'}", file=sys.stderr)
    print(" POST /v1/maxsim - Compute MaxSim score", file=sys.stderr)
    print(" body: {query_embedding: [...], doc_embedding: [...]}", file=sys.stderr)
    print(" GET /v1/models - List models", file=sys.stderr)
    print(" GET /health - Health check", file=sys.stderr)
    print("", file=sys.stderr)

    # Pre-load embedder
    get_embedder()

    # NOTE(review): the server listens on all interfaces and accepts
    # filesystem paths (images / pdf_path) from request bodies - confirm it
    # is only reachable from trusted hosts.
    server = HTTPServer(("0.0.0.0", PORT), VisualHandler)
    print(f"Server ready at http://localhost:{PORT}", file=sys.stderr)
    print("=" * 60, file=sys.stderr)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down...", file=sys.stderr)
    finally:
        # serve_forever() has already returned here, so shutdown() (meant to
        # be called from another thread) is not needed; release the socket.
        server.server_close()


if __name__ == "__main__":
    main()