|
|
|
|
|
""" |
|
|
MLX Visual Embedding Server - ColQwen3 |
|
|
|
|
|
HTTP server wrapper for ColQwen3Embedder providing visual document embeddings. |
|
|
Power of Wet Coders edition - custom merged model by LibraxisAI. |
|
|
|
|
|
Uses the production ColQwen3Embedder class from colqwen3_embedder.py |
|
|
|
|
|
Usage: |
|
|
cd knowledge/vista-brain |
|
|
uv run python scripts/mlx_visual_server.py |
|
|
|
|
|
# Or via Makefile: |
|
|
make visual |
|
|
|
|
|
Endpoints: |
|
|
POST /v1/visual-embeddings - Generate visual embeddings from images/PDFs |
|
|
POST /v1/maxsim - Compute MaxSim score between query and docs |
|
|
GET /v1/models - List models |
|
|
GET /health - Health check |
|
|
|
|
|
Created by M&K (c)2025 The LibraxisAI Team |
|
|
Co-Authored-By: Maciej (void@div0.space) & Klaudiusz (the1st@whoai.am) |
|
|
""" |
|
|
import base64 |
|
|
import io |
|
|
import json |
|
|
import os |
|
|
import sys |
|
|
import time |
|
|
from http.server import BaseHTTPRequestHandler, HTTPServer |
|
|
from pathlib import Path |
|
|
from typing import List, Union |
|
|
|
|
|
|
|
|
sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
|
|
|
from colqwen3_embedder import ColQwen3Embedder, load_embedder |
|
|
|
|
|
|
|
|
# TCP port the HTTP server binds to; override with the MLX_VISUAL_PORT
# environment variable.
PORT = int(os.getenv("MLX_VISUAL_PORT", "12347"))

# Embedding dimension advertised in the /health, /v1/models and embedding
# responses.
EMBED_DIM = 320

# Shared embedder instance; stays None until get_embedder() loads the model.
_embedder = None
|
|
|
|
|
|
|
|
def get_embedder() -> ColQwen3Embedder:
    """Return the shared ColQwen3 embedder, loading it on first use.

    The instance is stored in the module-level ``_embedder`` variable, so
    later calls reuse it instead of loading the model again. Progress
    messages are written to stderr.
    """
    global _embedder

    # Fast path: the model has already been loaded.
    if _embedder is not None:
        return _embedder

    print("Loading ColQwen3 Embedder...", file=sys.stderr)
    _embedder = load_embedder()
    print(f"ColQwen3 ready (dim={EMBED_DIM})", file=sys.stderr)
    return _embedder
|
|
|
|
|
|
|
|
def decode_image(image_data: Union[str, bytes]):
    """Turn encoded image data into an RGB PIL image.

    Args:
        image_data: Either raw image bytes, or a base64 string. A base64
            string may be wrapped in a ``data:`` URL, in which case
            everything up to the first comma is discarded.

    Returns:
        The decoded image converted to RGB mode.
    """
    from PIL import Image

    if not isinstance(image_data, str):
        # Already binary: use as-is.
        raw = image_data
    else:
        encoded = image_data
        if encoded.startswith("data:"):
            # "data:image/png;base64,<payload>" -> "<payload>"
            encoded = encoded.split(",", 1)[1]
        raw = base64.b64decode(encoded)

    buffer = io.BytesIO(raw)
    return Image.open(buffer).convert("RGB")
|
|
|
|
|
|
|
|
def embed_images(images: List[Union[str, bytes]]) -> List[dict]:
    """Generate ColBERT-style embeddings for a batch of images.

    Args:
        images: Items to embed. Each item is one of:
            * a filesystem path (a string starting with "." or a string
              starting with "/" that names an existing file), which is
              handed to the embedder unchanged;
            * a base64 string, optionally wrapped in a ``data:`` URL;
            * raw image bytes.

    Returns:
        One dict per input item, in input order. A successful entry has the
        keys ``embedding``, ``num_tokens`` and ``source_type``; a failed
        entry has a single ``error`` key holding the exception message.
        A failure on one item does not abort the rest of the batch.
    """
    embedder = get_embedder()
    import mlx.core as mx

    results = []
    for img_data in images:
        try:
            # "/" is part of the base64 alphabet and base64-encoded JPEG data
            # always starts with "/9j/", so a leading slash alone cannot tell
            # a path from inline image data. A "/"-prefixed string is only
            # treated as a path when it names an existing file. "." is not a
            # base64 character, so a leading dot always means a path.
            # NOTE(security): path inputs make the server read local files on
            # behalf of the client; only expose this to trusted callers.
            is_path = isinstance(img_data, str) and (
                img_data.startswith(".")
                or (img_data.startswith("/") and os.path.isfile(img_data))
            )
            source = img_data if is_path else decode_image(img_data)

            result = embedder.embed_image(source)
            results.append({
                "embedding": embedder.to_numpy(result).tolist(),
                "num_tokens": result.num_tokens,
                "source_type": result.source_type,
            })
        except Exception as e:
            # Best effort per item: record the failure and keep going.
            print(f"Image embed error: {e}", file=sys.stderr)
            results.append({"error": str(e)})

    # Clear MLX's memory cache once the whole batch is done.
    mx.clear_cache()

    return results
|
|
|
|
|
|
|
|
def embed_pdf(pdf_path: str, max_pages: int = None) -> List[dict]:
    """Embed the pages of a PDF document.

    Args:
        pdf_path: Path of the PDF, forwarded to the embedder.
        max_pages: Optional page limit, forwarded to the embedder.

    Returns:
        One dict per embedded page with the keys ``page`` (zero-based
        position in the embedder's output), ``embedding``, ``num_tokens``
        and ``source_type``. If an exception occurs, an ``{"error": ...}``
        entry is appended after whatever pages were already collected.
    """
    embedder = get_embedder()
    import mlx.core as mx

    pages = []
    try:
        embedded = embedder.embed_pdf(pdf_path, max_pages=max_pages)
        page_no = 0
        for page in embedded:
            entry = {"page": page_no}
            entry["embedding"] = embedder.to_numpy(page).tolist()
            entry["num_tokens"] = page.num_tokens
            entry["source_type"] = page.source_type
            pages.append(entry)
            page_no += 1
    except Exception as e:
        print(f"PDF embed error: {e}", file=sys.stderr)
        pages.append({"error": str(e)})

    # Clear MLX's memory cache whether or not the embedding succeeded.
    mx.clear_cache()
    return pages
|
|
|
|
|
|
|
|
def embed_text(text: str) -> dict:
    """Embed a text query.

    Args:
        text: The query string, forwarded to the embedder.

    Returns:
        A dict with the keys ``embedding``, ``num_tokens`` and
        ``source_type`` on success, or ``{"error": message}`` if embedding
        raised an exception.
    """
    embedder = get_embedder()
    import mlx.core as mx

    try:
        result = embedder.embed_text(text)
        return {
            "embedding": embedder.to_numpy(result).tolist(),
            "num_tokens": result.num_tokens,
            "source_type": result.source_type,
        }
    except Exception as e:
        print(f"Text embed error: {e}", file=sys.stderr)
        return {"error": str(e)}
    finally:
        # Clear the MLX memory cache after the result has been converted, and
        # on the error path too, matching the image and PDF helpers.
        mx.clear_cache()
|
|
|
|
|
|
|
|
def compute_maxsim(query_embedding: List, doc_embedding: List) -> float:
    """Compute the MaxSim score between query and document embeddings.

    For every query vector the largest dot product against all document
    vectors is taken, and those maxima are summed.

    Args:
        query_embedding: Nested list of query vectors (rows). A single flat
            vector is accepted and treated as one row.
        doc_embedding: Nested list of document vectors (rows). A single flat
            vector is accepted and treated as one row.

    Returns:
        The MaxSim score as a Python float.

    Raises:
        ValueError: If either input is not 1-D/2-D, or the vector sizes of
            query and document differ.
    """
    import mlx.core as mx

    try:
        # Force float32 so integer-only JSON input is handled like floats.
        query_mx = mx.array(query_embedding, dtype=mx.float32)
        doc_mx = mx.array(doc_embedding, dtype=mx.float32)

        # Promote a single vector to a one-row matrix.
        if query_mx.ndim == 1:
            query_mx = mx.expand_dims(query_mx, 0)
        if doc_mx.ndim == 1:
            doc_mx = mx.expand_dims(doc_mx, 0)

        if query_mx.ndim != 2 or doc_mx.ndim != 2:
            raise ValueError(
                "query_embedding and doc_embedding must be lists of vectors"
            )
        if query_mx.shape[-1] != doc_mx.shape[-1]:
            raise ValueError(
                "embedding dimension mismatch: "
                f"query has {query_mx.shape[-1]}, doc has {doc_mx.shape[-1]}"
            )

        # (num_query, dim) @ (dim, num_doc) -> (num_query, num_doc)
        similarities = query_mx @ doc_mx.T
        max_sims = mx.max(similarities, axis=1)
        return float(mx.sum(max_sims))
    finally:
        # Clear the MLX memory cache on success and on failure.
        mx.clear_cache()
|
|
|
|
|
|
|
|
class VisualHandler(BaseHTTPRequestHandler):
    """HTTP handler for the visual embeddings API.

    Routes:
        GET  /v1/models, /models                      - model listing
        GET  /health                                  - health check
        POST /v1/visual-embeddings, /visual-embeddings - embeddings
        POST /v1/maxsim, /maxsim                      - MaxSim scoring

    All responses are JSON. Unknown paths get a 404 with an ``error`` key.
    """

    def log_message(self, format, *args):
        """Write a timestamped log line to stderr.

        The base class calls this with a %-style format string and its
        arguments; the full formatted message is logged.
        """
        # Skip interpolation when there are no args, so a literal "%" in a
        # plain message cannot raise.
        message = format % args if args else format
        print(f"[{time.strftime('%Y-%m-%d %H:%M:%S')}] {message}", file=sys.stderr)

    def send_json(self, data: dict, status: int = 200):
        """Serialize ``data`` as JSON and send it with the given status."""
        body = json.dumps(data).encode("utf-8")
        self.send_response(status)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)

    def do_GET(self):
        """Handle GET requests (model listing and health check)."""
        if self.path == "/v1/models" or self.path == "/models":
            self.send_json({
                "object": "list",
                "data": [{
                    "id": "colqwen3-8b-wetcoders",
                    "object": "model",
                    "owned_by": "libraxis-local",
                    "type": "visual-embedding",
                    "description": "ColQwen3 8B - Power of Wet Coders edition",
                    "embedding_dim": EMBED_DIM,
                }]
            })
        elif self.path == "/health":
            self.send_json({
                "status": "healthy",
                "model": "colqwen3-8b-wetcoders",
                "dim": EMBED_DIM,
                "type": "colbert-visual-embedding",
            })
        else:
            self.send_json({"error": "Not found"}, 404)

    def do_POST(self):
        """Handle POST requests.

        Reads the body according to Content-Length, parses it as JSON and
        dispatches on the request path. Malformed requests (bad
        Content-Length, invalid JSON, non-object body) get a 400.
        """
        # A non-numeric header would raise, and a negative length would make
        # rfile.read() wait for EOF, so both are rejected up front.
        try:
            content_length = int(self.headers.get("Content-Length", 0))
        except (TypeError, ValueError):
            content_length = -1
        if content_length < 0:
            self.send_json({"error": "Invalid Content-Length"}, 400)
            return

        body = self.rfile.read(content_length)

        try:
            data = json.loads(body)
        except ValueError:
            # Covers json.JSONDecodeError and the UnicodeDecodeError raised
            # for bodies that are not valid UTF-8/16/32.
            self.send_json({"error": "Invalid JSON"}, 400)
            return

        if self.path in ["/v1/visual-embeddings", "/visual-embeddings"]:
            handler = self._handle_embeddings
        elif self.path in ["/v1/maxsim", "/maxsim"]:
            handler = self._handle_maxsim
        else:
            self.send_json({"error": "Not found"}, 404)
            return

        # The handlers use data.get(), so the body must be a JSON object.
        if not isinstance(data, dict):
            self.send_json({"error": "JSON body must be an object"}, 400)
            return

        handler(data)

    def _handle_embeddings(self, data: dict):
        """Handle embedding requests.

        Exactly one input kind is processed per request, in this order of
        precedence: ``pdf_path``, then ``images``, then ``texts``. A bare
        string given for ``images`` or ``texts`` is treated as a
        one-element list. ``max_pages`` is only used with ``pdf_path``.

        NOTE(security): ``pdf_path`` and path-like image strings refer to
        files on the server's filesystem; only expose this endpoint to
        trusted clients.
        """
        images = data.get("images", [])
        texts = data.get("texts", [])
        pdf_path = data.get("pdf_path")
        max_pages = data.get("max_pages")

        # Iterating a bare string would embed it character by character.
        if isinstance(images, str):
            images = [images]
        if isinstance(texts, str):
            texts = [texts]

        response = {
            "object": "embedding_response",
            "model": "colqwen3-8b-wetcoders",
            "dim": EMBED_DIM,
        }

        try:
            if pdf_path:
                response["pdf_embeddings"] = embed_pdf(pdf_path, max_pages)
            elif images:
                response["image_embeddings"] = embed_images(images)
            elif texts:
                response["text_embeddings"] = [embed_text(t) for t in texts]
            else:
                self.send_json({"error": "No images, texts, or pdf_path provided"}, 400)
                return
        except Exception as e:
            print(f"Embedding error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)
            return

        self.send_json(response)

    def _handle_maxsim(self, data: dict):
        """Handle MaxSim scoring requests.

        Requires non-empty ``query_embedding`` and ``doc_embedding`` fields;
        responds 400 if either is missing or empty and 500 if scoring fails.
        """
        query_embedding = data.get("query_embedding")
        doc_embedding = data.get("doc_embedding")

        if not query_embedding or not doc_embedding:
            self.send_json({"error": "query_embedding and doc_embedding required"}, 400)
            return

        try:
            score = compute_maxsim(query_embedding, doc_embedding)
            self.send_json({
                "object": "maxsim_score",
                "score": score,
                "model": "colqwen3-8b-wetcoders",
            })
        except Exception as e:
            print(f"MaxSim error: {e}", file=sys.stderr)
            self.send_json({"error": str(e)}, 500)
|
|
|
|
|
|
|
|
def main():
    """Start the visual embedding server.

    Prints a startup banner to stderr, loads the embedder before accepting
    connections, then serves requests until interrupted with Ctrl+C. The
    listening socket is closed on the way out.
    """
    separator = "=" * 60
    banner = (
        "",
        separator,
        "MLX Visual Embedding Server - ColQwen3",
        "Power of Wet Coders Edition",
        separator,
        f"Port: {PORT}",
        f"Embedding dim: {EMBED_DIM} (ColBERT)",
        "",
        "Endpoints:",
        " POST /v1/visual-embeddings - Generate embeddings",
        " body: {images: [base64...]} or {pdf_path: '/path.pdf'}",
        " POST /v1/maxsim - Compute MaxSim score",
        " body: {query_embedding: [...], doc_embedding: [...]}",
        " GET /v1/models - List models",
        " GET /health - Health check",
        "",
    )
    for line in banner:
        print(line, file=sys.stderr)

    # Load the model up front so the first request does not pay for it.
    get_embedder()

    # NOTE(security): binds to all interfaces, and the API accepts local file
    # paths from clients; restrict network access to trusted callers.
    server = HTTPServer(("0.0.0.0", PORT), VisualHandler)
    print(f"Server ready at http://localhost:{PORT}", file=sys.stderr)
    print(separator, file=sys.stderr)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        # serve_forever() has already returned at this point, so there is no
        # running loop left for shutdown() to stop.
        print("\nShutting down...", file=sys.stderr)
    finally:
        # Release the listening socket.
        server.server_close()


if __name__ == "__main__":
    main()
|
|
|