# colqwen3-8b-vetcoders-mlx / scripts /mlx_visual_server.py
# div0-space's picture
# Upload folder using huggingface_hub
# c6c3a3b verified
#!/usr/bin/env python3
"""
MLX Visual Embedding Server - ColQwen3

HTTP server wrapper for ColQwen3Embedder providing visual document embeddings.
Power of Wet Coders edition - custom merged model by LibraxisAI.

Uses the production ColQwen3Embedder class from colqwen3_embedder.py.

Usage:
    cd knowledge/vista-brain
    uv run python scripts/mlx_visual_server.py

    # Or via Makefile:
    make visual

Endpoints:
    POST /v1/visual-embeddings - Generate visual embeddings from images, PDFs, or text queries
    POST /v1/maxsim            - Compute MaxSim score between query and docs
    GET  /v1/models            - List models
    GET  /health               - Health check

Created by M&K (c)2025 The LibraxisAI Team
Co-Authored-By: Maciej (void@div0.space) & Klaudiusz (the1st@whoai.am)
"""
import base64
import io
import json
import os
import sys
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import List, Union
# Add parent directory to path for colqwen3_embedder import
sys.path.insert(0, str(Path(__file__).parent.parent))
from colqwen3_embedder import ColQwen3Embedder, load_embedder
# Runtime configuration, overridable through the environment.
# MLX_VISUAL_PORT selects the TCP port the HTTP server listens on.
PORT = int(os.getenv("MLX_VISUAL_PORT", "12347"))

# Width of each ColBERT token vector (320 for our custom projection).
EMBED_DIM = 320

# Process-wide embedder instance; stays None until get_embedder() fills it.
_embedder = None
def get_embedder() -> ColQwen3Embedder:
    """Return the shared ColQwen3 embedder, creating it on first use.

    The instance is cached in the module-level ``_embedder`` variable, so
    ``load_embedder()`` runs at most once per process. Progress messages are
    written to stderr around the load.
    """
    global _embedder
    if _embedder is not None:
        return _embedder

    print("Loading ColQwen3 Embedder...", file=sys.stderr)
    _embedder = load_embedder()
    print(f"ColQwen3 ready (dim={EMBED_DIM})", file=sys.stderr)
    return _embedder
def decode_image(image_data: Union[str, bytes]):
    """Turn a base64 string or raw encoded bytes into an RGB PIL image.

    Args:
        image_data: Either a base64 string (bare, or a data URL such as
            ``data:image/png;base64,xxxx``) or the encoded image bytes.

    Returns:
        The image opened with PIL and converted to ``"RGB"`` mode.
    """
    from PIL import Image

    if not isinstance(image_data, str):
        # Non-string input is passed to PIL as the encoded image bytes.
        raw = image_data
    else:
        payload = image_data
        if payload.startswith("data:"):
            # Keep only what follows the first comma of the data URL header.
            payload = payload.split(",", 1)[1]
        raw = base64.b64decode(payload)

    buffer = io.BytesIO(raw)
    return Image.open(buffer).convert("RGB")
def _looks_like_file_path(value) -> bool:
    """Return True when *value* should be treated as a local file path."""
    if not isinstance(value, str) or not value.startswith(("/", ".")):
        return False
    # "." is not part of the base64 alphabet, so a string containing one
    # (e.g. "./page.png", "/tmp/scan.jpg") cannot be a base64 payload.
    if "." in value:
        return True
    # A leading "/" alone proves nothing: base64-encoded JPEG data always
    # starts with "/9j/". Only accept it as a path if such a file exists.
    return os.path.isfile(value)


def embed_images(images: List[Union[str, bytes]]) -> List[dict]:
    """Generate ColBERT-style embeddings for a batch of images.

    Args:
        images: Items to embed. Each item is either a local file path
            (string starting with ``/`` or ``.``), a base64 string (bare or
            as a ``data:`` URL), or raw encoded image bytes.

    Returns:
        One dict per input, in input order. On success the dict holds
        ``embedding`` (``embedder.to_numpy(result).tolist()``),
        ``num_tokens`` and ``source_type``. On failure it holds a single
        ``error`` message, so one bad image does not abort the batch.
    """
    embedder = get_embedder()
    import mlx.core as mx

    results = []
    for img_data in images:
        try:
            # NOTE(review): file paths arrive straight from HTTP clients, so
            # any file readable by this process can be submitted for
            # embedding. Confirm that is acceptable for the deployment.
            if _looks_like_file_path(img_data):
                # The embedder receives the path string itself.
                source = img_data
            else:
                source = decode_image(img_data)

            result = embedder.embed_image(source)
            results.append({
                "embedding": embedder.to_numpy(result).tolist(),
                "num_tokens": result.num_tokens,
                "source_type": result.source_type,
            })
        except Exception as e:
            # Best-effort batch: report the failure in place and continue.
            print(f"Image embed error: {e}", file=sys.stderr)
            results.append({"error": str(e)})

    # Clear MLX cache once the whole batch has been processed.
    mx.clear_cache()
    return results
def embed_pdf(pdf_path: str, max_pages: int = None) -> List[dict]:
    """Embed the pages of a PDF and return one entry per page.

    Args:
        pdf_path: Path of the PDF, forwarded to ``embedder.embed_pdf``.
        max_pages: Forwarded as the ``max_pages`` keyword; ``None`` by default.

    Returns:
        A list of dicts with ``page`` (zero-based position in the embedder's
        output), ``embedding``, ``num_tokens`` and ``source_type``. If an
        exception occurs, an ``{"error": ...}`` entry is appended after any
        pages that were already converted.
    """
    embedder = get_embedder()
    import mlx.core as mx

    pages = []
    try:
        embedded_pages = embedder.embed_pdf(pdf_path, max_pages=max_pages)
        for page_index, page in enumerate(embedded_pages):
            entry = {
                "page": page_index,
                "embedding": embedder.to_numpy(page).tolist(),
                "num_tokens": page.num_tokens,
                "source_type": page.source_type,
            }
            pages.append(entry)
    except Exception as exc:
        print(f"PDF embed error: {exc}", file=sys.stderr)
        pages.append({"error": str(exc)})

    # Release cached MLX buffers whether or not embedding succeeded.
    mx.clear_cache()
    return pages
def embed_text(text: str) -> dict:
    """Embed a text query.

    Args:
        text: The query string passed to ``embedder.embed_text``.

    Returns:
        A dict with ``embedding``, ``num_tokens`` and ``source_type`` on
        success, or ``{"error": message}`` if embedding raised.
    """
    embedder = get_embedder()
    import mlx.core as mx

    try:
        result = embedder.embed_text(text)
        return {
            "embedding": embedder.to_numpy(result).tolist(),
            "num_tokens": result.num_tokens,
            "source_type": result.source_type,
        }
    except Exception as e:
        print(f"Text embed error: {e}", file=sys.stderr)
        return {"error": str(e)}
    finally:
        # Clear the MLX cache on the error path too, matching the image and
        # PDF helpers which clear it regardless of the outcome.
        mx.clear_cache()
def compute_maxsim(query_embedding: List, doc_embedding: List) -> float:
    """Return the MaxSim score of a query against a document.

    Both arguments are converted to MLX arrays. The score is the sum, over
    the rows of the query, of the largest dot product with any row of the
    document.
    # assumes both inputs are 2-D (tokens x dim) nested lists - TODO confirm
    """
    import mlx.core as mx

    query_tokens = mx.array(query_embedding)
    doc_tokens = mx.array(doc_embedding)

    # Row i, column j holds the dot product of query token i and doc token j.
    pairwise = query_tokens @ doc_tokens.T
    # Best-matching doc token per query token, then summed into one scalar.
    best_per_query = mx.max(pairwise, axis=1)
    total = float(mx.sum(best_per_query))

    mx.clear_cache()
    return total
class VisualHandler(BaseHTTPRequestHandler):
    """HTTP handler for the visual embeddings API.

    Routes:
        GET  /v1/models, /models                        - model listing
        GET  /health                                    - health check
        POST /v1/visual-embeddings, /visual-embeddings  - embeddings
        POST /v1/maxsim, /maxsim                        - MaxSim scoring

    Every response body is JSON; errors have the shape ``{"error": message}``.
    """

    def log_message(self, format, *args):
        """Write a timestamped log line to stderr.

        The complete message is rendered with ``format % args`` so that
        access-log lines keep their status code and error lines keep their
        text, and a call without extra arguments does not fail.
        """
        message = format % args
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}", file=sys.stderr)

    def send_json(self, data: dict, status: int = 200):
        """Serialize *data* as JSON and send it with the given HTTP status."""
        body = json.dumps(data).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle GET requests: model listing and health check."""
        if self.path in ("/v1/models", "/models"):
            self.send_json({
                "object": "list",
                "data": [{
                    "id": "colqwen3-8b-wetcoders",
                    "object": "model",
                    "owned_by": "libraxis-local",
                    "type": "visual-embedding",
                    "description": "ColQwen3 8B - Power of Wet Coders edition",
                    "embedding_dim": EMBED_DIM,
                }]
            })
        elif self.path == "/health":
            self.send_json({
                "status": "healthy",
                "model": "colqwen3-8b-wetcoders",
                "dim": EMBED_DIM,
                "type": "colbert-visual-embedding",
            })
        else:
            self.send_json({"error": "Not found"}, 404)

    def do_POST(self):
        """Handle POST requests: parse the JSON body, then route by path.

        Malformed requests (bad ``Content-Length``, undecodable or invalid
        JSON, or a JSON value that is not an object) get a 400 response
        instead of raising inside the handler.
        """
        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            content_length = -1
        if content_length < 0:
            # A negative length would make rfile.read() wait for EOF.
            self.send_json({"error": "Invalid Content-Length"}, 400)
            return

        body = self.rfile.read(content_length)
        try:
            data = json.loads(body)
        except ValueError:
            # Covers json.JSONDecodeError and the UnicodeDecodeError raised
            # for bodies that are not valid UTF-8/16/32.
            self.send_json({"error": "Invalid JSON"}, 400)
            return

        if not isinstance(data, dict):
            # The handlers below read fields with data.get().
            self.send_json({"error": "JSON body must be an object"}, 400)
            return

        if self.path in ["/v1/visual-embeddings", "/visual-embeddings"]:
            self._handle_embeddings(data)
        elif self.path in ["/v1/maxsim", "/maxsim"]:
            self._handle_maxsim(data)
        else:
            self.send_json({"error": "Not found"}, 404)

    def _handle_embeddings(self, data: dict):
        """Handle embedding requests.

        Exactly one input kind is processed per request, chosen in this
        order: ``pdf_path`` (with optional ``max_pages``), then ``images``,
        then ``texts``. Responds 400 when none is supplied and 500 if the
        embedding call raises.
        """
        images = data.get("images", [])
        texts = data.get("texts", [])
        pdf_path = data.get("pdf_path")
        max_pages = data.get("max_pages")

        response = {
            "object": "embedding_response",
            "model": "colqwen3-8b-wetcoders",
            "dim": EMBED_DIM,
        }

        try:
            if pdf_path:
                # PDF embedding
                response["pdf_embeddings"] = embed_pdf(pdf_path, max_pages)
            elif images:
                # Image embeddings
                response["image_embeddings"] = embed_images(images)
            elif texts:
                # Text embeddings
                response["text_embeddings"] = [embed_text(t) for t in texts]
            else:
                self.send_json({"error": "No images, texts, or pdf_path provided"}, 400)
                return
        except Exception as e:
            print(f"Embedding error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)
            return

        self.send_json(response)

    def _handle_maxsim(self, data: dict):
        """Handle MaxSim scoring requests.

        Requires non-empty ``query_embedding`` and ``doc_embedding`` fields;
        responds 400 when either is missing or empty and 500 if scoring raises.
        """
        query_embedding = data.get("query_embedding")
        doc_embedding = data.get("doc_embedding")

        if not query_embedding or not doc_embedding:
            self.send_json({"error": "query_embedding and doc_embedding required"}, 400)
            return

        try:
            score = compute_maxsim(query_embedding, doc_embedding)
            self.send_json({
                "object": "maxsim_score",
                "score": score,
                "model": "colqwen3-8b-wetcoders",
            })
        except Exception as e:
            print(f"MaxSim error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)
def main():
    """Start the visual embedding server and serve until interrupted.

    Prints a startup banner to stderr, loads the embedder up front so the
    first request does not pay the model-loading cost, then serves HTTP on
    ``PORT``. Ctrl+C stops the loop; the listening socket is always closed
    on the way out.
    """
    print("", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    print("MLX Visual Embedding Server - ColQwen3", file=sys.stderr)
    print("Power of Wet Coders Edition", file=sys.stderr)
    print("=" * 60, file=sys.stderr)
    print(f"Port: {PORT}", file=sys.stderr)
    print(f"Embedding dim: {EMBED_DIM} (ColBERT)", file=sys.stderr)
    print("", file=sys.stderr)
    print("Endpoints:", file=sys.stderr)
    print(" POST /v1/visual-embeddings - Generate embeddings", file=sys.stderr)
    print(" body: {images: [base64...]} or {pdf_path: '/path.pdf'}", file=sys.stderr)
    print(" POST /v1/maxsim - Compute MaxSim score", file=sys.stderr)
    print(" body: {query_embedding: [...], doc_embedding: [...]}", file=sys.stderr)
    print(" GET /v1/models - List models", file=sys.stderr)
    print(" GET /health - Health check", file=sys.stderr)
    print("", file=sys.stderr)

    # Pre-load embedder
    get_embedder()

    # NOTE(review): this binds every network interface while the embeddings
    # endpoint accepts file paths from clients - confirm the exposure is
    # intended before running on an untrusted network.
    server = HTTPServer(("0.0.0.0", PORT), VisualHandler)
    print(f"Server ready at http://localhost:{PORT}", file=sys.stderr)
    print("=" * 60, file=sys.stderr)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        print("\nShutting down...", file=sys.stderr)
    finally:
        # serve_forever() has already returned at this point, so there is no
        # running loop for shutdown() to stop; server_close() is what
        # releases the listening socket.
        server.server_close()


if __name__ == "__main__":
    main()