| |
| """ |
| Vision-aware Gemma 4 sniper for single-machine inference. |
| |
| Combines: |
| - mlx-vlm vision_tower + embed_vision (~1.2 GB, no SSD streaming) |
| - Our own MoESniperEngineGemma4 (~5.8 GB pinned + experts streamed from SSD) |
| |
| Total RAM: ~7 GB. Verified working on M4 Mac Mini 16 GB. |
| |
| Requirements: |
| pip install "mlx>=0.31" "mlx-vlm>=0.4" "mlx-lm>=0.31" |
| (these need Python 3.11+, install in a venv) |
| """ |
|
|
| import os |
| import sys |
| import json |
| import time |
| import glob |
|
|
| import numpy as np |
| import mlx.core as mx |
| import mlx.nn as nn |
|
|
|
|
| |
# Placeholder token id for image positions: encode_chat() repeats it once per
# image token, and _prefill() overwrites the embeddings at those positions
# with the projected image features.
IMAGE_TOKEN_ID = 262144
# Emitted by encode_chat() immediately before / after the placeholder run.
# NOTE(review): the names suggest begin-of-image / end-of-image markers;
# confirm the ids against the tokenizer's special-token table.
BOI_TOKEN_ID = 255999
EOI_TOKEN_ID = 256000
|
|
|
|
class VisionGemma4Sniper:
    """Single-machine vision-aware Gemma 4 sniper.

    The vision encoder (~1.2 GB, bfloat16) is loaded from a source 4-bit
    Gemma 4 model directory. The LLM (text + experts) is loaded via our
    existing MoESniperEngineGemma4 from a pre-split streaming directory.

    Usage: construct, call ``load()`` once, then call ``generate()``.
    """

    # Directories that load() prepends to sys.path (when they exist) so that
    # ``moe_agent_gemma4`` can be imported.
    _SNIPER_PATHS = ("~/expert-sniper-mlx", "~/cli-agent/src")

    # Token ids that stop generation in generate().
    EOS_TOKEN_IDS = frozenset({1, 106})
    # Repetition penalty applied to the logits of recently generated tokens,
    # and how many trailing tokens count as "recent".
    REPETITION_PENALTY = 1.1
    REPETITION_WINDOW = 64

    def __init__(self, stream_dir: str, source_dir: str, cache_size: int = 4000) -> None:
        """Record paths and settings; nothing is loaded until ``load()``.

        Args:
            stream_dir: directory with pinned.safetensors + bin/layer_XX.bin
                (created by split_gemma4.py)
            source_dir: directory with the original
                mlx-community/gemma-4-26b-a4b-it-4bit safetensors files
                (used to extract vision weights)
            cache_size: LRU cache size for the expert reader
        """
        self.stream_dir = os.path.expanduser(stream_dir)
        self.source_dir = os.path.expanduser(source_dir)
        self.cache_size = cache_size

        # Populated by load().
        self.sniper = None
        self.vision_tower = None
        self.embed_vision = None
        self.image_processor = None

    # ------------------------------------------------------------------
    # Loading
    # ------------------------------------------------------------------

    @staticmethod
    def _read_json(path: str) -> dict:
        """Parse the JSON file at *path* (read as UTF-8)."""
        with open(path, encoding="utf-8") as f:
            return json.load(f)

    def _find_weight_files(self) -> list:
        """Return the safetensors files in ``source_dir`` holding vision weights.

        Prefers sharded ``model-*.safetensors`` (sorted); falls back to a
        standalone ``vision.safetensors``.

        Raises:
            FileNotFoundError: if neither layout is present.
        """
        files = sorted(glob.glob(os.path.join(self.source_dir, "model-*.safetensors")))
        if not files:
            standalone = os.path.join(self.source_dir, "vision.safetensors")
            if os.path.exists(standalone):
                files = [standalone]
        if not files:
            raise FileNotFoundError(
                f"No safetensors files found in {self.source_dir}. "
                f"Expected either model-*.safetensors or vision.safetensors"
            )
        return files

    @staticmethod
    def _collect_prefixed_weights(files, prefixes) -> dict:
        """Gather tensors whose key starts with one of *prefixes*.

        Each shard is opened once. Returns ``{prefix: {key_without_prefix: array}}``.
        """
        buckets = {prefix: {} for prefix in prefixes}
        for path in files:
            shard = mx.load(path)
            for key, value in shard.items():
                for prefix, bucket in buckets.items():
                    if key.startswith(prefix):
                        bucket[key[len(prefix):]] = value
                        break
            # Drop the shard dict so tensors we did not keep can be released.
            del shard
        return buckets

    def load(self) -> None:
        """Load the sniper LLM, vision tower, vision embedder and image processor.

        Raises:
            ImportError: if ``moe_agent_gemma4`` cannot be imported.
            RuntimeError: if the source config has no ``vision_config``, or no
                ``vision_tower.*`` / ``embed_vision.*`` tensors are found.
            FileNotFoundError: if no safetensors files exist in ``source_dir``.
        """
        for entry in self._SNIPER_PATHS:
            path = os.path.expanduser(entry)
            if os.path.isdir(path) and path not in sys.path:
                sys.path.insert(0, path)

        try:
            from moe_agent_gemma4 import MoESniperEngineGemma4
        except ImportError as e:
            raise ImportError(
                "Could not import moe_agent_gemma4. The vision engine wraps "
                "the existing single-machine sniper. Make sure these directories "
                "exist and contain the sniper code:\n"
                " ~/expert-sniper-mlx/moe_agent_gemma4.py\n"
                " ~/cli-agent/src/mlx_expert_sniper/models/gemma4.py\n"
                f"Original error: {e}"
            ) from e
        from mlx_vlm.models.gemma4 import VisionModel, VisionConfig, Gemma4ImageProcessor
        from mlx_vlm.models.gemma4.gemma4 import MultimodalEmbedder

        print("Loading Vision Gemma 4 Sniper...")

        # hidden_size of the text model comes from the streaming directory's
        # config; the vision settings come from the original checkpoint.
        text_config = self._read_json(os.path.join(self.stream_dir, "config.json"))
        source_config = self._read_json(os.path.join(self.source_dir, "config.json"))
        vision_config_dict = source_config.get("vision_config", {})
        if not vision_config_dict:
            raise RuntimeError(f"No vision_config in {self.source_dir}/config.json")

        print(" [1/4] Sniper LLM...")
        self.sniper = MoESniperEngineGemma4(
            model_dir=self.stream_dir, cache_size=self.cache_size
        )
        self.sniper.load()

        print(" [2/4] Vision tower...")
        self.vision_tower = VisionModel(VisionConfig.from_dict(vision_config_dict))

        safetensors_files = self._find_weight_files()
        weights = self._collect_prefixed_weights(
            safetensors_files, ("vision_tower.", "embed_vision.")
        )
        vision_weights = weights["vision_tower."]
        embedder_weights = weights["embed_vision."]

        # load_weights(strict=False) would silently keep random initial
        # parameters if nothing matched, so fail loudly instead.
        if not vision_weights:
            raise RuntimeError(
                f"No 'vision_tower.*' tensors found in safetensors files under {self.source_dir}"
            )
        self.vision_tower.load_weights(list(vision_weights.items()), strict=False)
        mx.eval(self.vision_tower.parameters())

        print(" [3/4] embed_vision...")
        self.embed_vision = MultimodalEmbedder(
            embedding_dim=vision_config_dict["hidden_size"],
            text_hidden_size=text_config["hidden_size"],
            eps=vision_config_dict.get("rms_norm_eps", 1e-6),
        )
        # Quantize the module before loading so its parameter layout matches
        # the 4-bit tensors stored in the checkpoint.
        nn.quantize(self.embed_vision, group_size=64, bits=4)
        if not embedder_weights:
            raise RuntimeError(
                f"No 'embed_vision.*' tensors found in safetensors files under {self.source_dir}"
            )
        self.embed_vision.load_weights(list(embedder_weights.items()), strict=False)
        mx.eval(self.embed_vision.parameters())

        print(" [4/4] Image processor...")
        self.image_processor = Gemma4ImageProcessor.from_pretrained(self.source_dir)

        total_gb = mx.get_active_memory() / 1e9
        print(f" Total active memory: {total_gb:.2f} GB")
        print("Ready!")

    # ------------------------------------------------------------------
    # Encoding
    # ------------------------------------------------------------------

    def encode_image(self, image_path: str) -> "tuple[mx.array, int]":
        """Encode an image into projected embeddings.

        Returns:
            (image_features, n_tokens) where image_features is
            mx.array [1, n_tokens, text_hidden_size]
        """
        from collections.abc import Mapping

        from PIL import Image

        # convert() returns a fully loaded copy, so the file handle can be
        # closed as soon as the with-block exits.
        with Image.open(image_path) as raw:
            img = raw.convert("RGB")
        processed = self.image_processor(images=[img], return_tensors="mlx")

        # The processor's return shape is handled defensively: either a
        # (pixel_values_like, token_counts) tuple or a single object.
        if isinstance(processed, tuple):
            pv_obj = processed[0]
            n_tokens_list = processed[1]
        else:
            pv_obj = processed
            n_tokens_list = None

        # Mapping (not just dict) so UserDict-style feature containers work too.
        if isinstance(pv_obj, Mapping):
            pixel_values = pv_obj["pixel_values"]
        else:
            pixel_values = pv_obj

        if not isinstance(pixel_values, mx.array):
            pixel_values = mx.array(np.array(pixel_values))

        image_features = self.vision_tower(pixel_values)
        image_features = self.embed_vision(image_features)
        mx.eval(image_features)

        # Avoid truth-testing the count container directly: that is ambiguous
        # for multi-element arrays. The count is coerced to a plain int because
        # encode_chat() uses it as a list repetition factor.
        if n_tokens_list is not None and len(n_tokens_list) > 0:
            n_tokens = int(n_tokens_list[0])
        else:
            n_tokens = image_features.shape[1]
        return image_features, n_tokens

    def encode_chat(self, prompt: str, image_path: "str | None" = None):
        """Build the input token sequence with optional image placeholders.

        Returns: (token_ids, image_features_or_None, n_image_tokens)
        """
        NL = chr(10)
        encode = self.sniper.tokenizer.encode
        prompt_toks = encode(prompt).ids
        user_toks = encode("user" + NL).ids
        model_toks = encode("model" + NL).ids

        image_features = None
        n_image_tokens = 0
        image_block = []
        if image_path:
            image_features, n_image_tokens = self.encode_image(image_path)
            image_block = (
                [BOI_TOKEN_ID]
                + [IMAGE_TOKEN_ID] * n_image_tokens
                + [EOI_TOKEN_ID]
            )

        # NOTE(review): the literal ids look like the chat-template control
        # tokens (2 / 105 opening the user turn, 106 / 107 / 105 closing it and
        # opening the model turn) — confirm against the tokenizer.
        tokens = (
            [2, 105]
            + user_toks
            + image_block
            + prompt_toks
            + [106, 107, 105]
            + model_toks
        )
        return tokens, image_features, n_image_tokens

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    @staticmethod
    def _apply_repetition_penalty(logits: np.ndarray, token_ids, penalty: float) -> np.ndarray:
        """Dampen, in place, the logits of every distinct id in *token_ids*.

        Positive logits are divided by *penalty* and non-positive ones are
        multiplied by it, so a penalized token always becomes less likely.
        Returns *logits* for convenience.
        """
        ids = np.array(sorted(set(token_ids)), dtype=np.int64)
        if ids.size:
            picked = logits[ids]
            logits[ids] = np.where(picked > 0, picked / penalty, picked * penalty)
        return logits

    @staticmethod
    def _emit_text(full: str, emitted: int, on_chunk, final: bool = False) -> int:
        """Send ``full[emitted:]`` to *on_chunk* and return the new emitted length.

        Unless *final* is set, text ending in U+FFFD is held back: an
        incomplete multi-byte sequence typically decodes to that replacement
        character, and emitting it would permanently swallow the real
        character once the following token completes it.
        """
        if not final and full.endswith("\ufffd"):
            return emitted
        delta = full[emitted:]
        if delta and on_chunk:
            on_chunk(delta)
        return len(full)

    def generate(self, prompt, image_path=None, max_tokens=200, temperature=0.7,
                 on_chunk=None):
        """Generate a response, optionally conditioned on an image.

        Args:
            prompt: user text.
            image_path: optional path of an image to condition on.
            max_tokens: upper bound on generated tokens. The token produced by
                the prefill pass is always included, so at least one token is
                generated.
            temperature: sampling temperature; ``<= 0`` selects greedy argmax.
                The first token (from prefill) is always chosen greedily.
            on_chunk: optional callback that receives streaming text chunks.

        Returns:
            The decoded text of all generated tokens.
        """
        from collections import deque

        tokens, image_features, _ = self.encode_chat(prompt, image_path)
        next_token = self._prefill(tokens, image_features)

        generated = [next_token]
        recent = deque([next_token], maxlen=self.REPETITION_WINDOW)
        decode = self.sniper.tokenizer.decode
        emitted = 0

        for _ in range(max_tokens - 1):
            # The whole sequence is re-decoded each step and only the new
            # suffix is streamed.
            emitted = self._emit_text(decode(generated), emitted, on_chunk)

            if next_token in self.EOS_TOKEN_IDS:
                break

            logits = self.sniper.forward(mx.array([[next_token]]))
            mx.eval(logits)

            # The penalty is applied on a float32 numpy copy of the last
            # position's logits, then moved back to MLX for sampling.
            last_np = np.array(logits[0, -1].astype(mx.float32))
            self._apply_repetition_penalty(last_np, recent, self.REPETITION_PENALTY)
            last_logits = mx.array(last_np)

            if temperature <= 0:
                next_token = int(mx.argmax(last_logits).item())
            else:
                # categorical() takes unnormalized logits, so no explicit
                # softmax/log round trip is needed.
                next_token = int(mx.random.categorical(last_logits / temperature).item())

            generated.append(next_token)
            recent.append(next_token)

        # Flush whatever has not been streamed yet (the last sampled token
        # and any held-back text).
        full = decode(generated)
        self._emit_text(full, emitted, on_chunk, final=True)
        return full

    def _prefill(self, input_token_ids, image_features):
        """Run the first forward pass with image embeddings injected.

        Resets the sniper's cache, runs every layer over the whole prompt
        (streaming the routed experts per layer), and returns the greedy
        (argmax) next-token id as an int.
        """
        from mlx_lm.models.base import create_attention_mask
        from mlx_vlm.models.gemma4.gemma4 import masked_scatter
        from moe_agent_gemma4 import run_expert_ffn

        self.sniper.reset_cache()

        input_ids = mx.array([input_token_ids])

        # Token embeddings, scaled by sqrt(hidden_size).
        h = self.sniper.model.model.embed_tokens(input_ids)
        h = h * (self.sniper.model.args.hidden_size ** 0.5)

        # Overwrite the embeddings at every IMAGE_TOKEN_ID position with the
        # projected image features (applied after the scaling above).
        if image_features is not None:
            image_mask = (input_ids == IMAGE_TOKEN_ID)
            image_feats_flat = image_features.reshape(-1, image_features.shape[-1])
            image_feats_flat = image_feats_flat.astype(h.dtype)
            image_mask_expanded = mx.expand_dims(image_mask, -1)
            image_mask_expanded = mx.broadcast_to(image_mask_expanded, h.shape)
            h = masked_scatter(h, image_mask_expanded, image_feats_flat)

        # One mask, built from the first layer's cache, is shared by all layers.
        mask = create_attention_mask(h, self.sniper.cache[0] if self.sniper.cache else None)

        for i in range(self.sniper.num_layers):
            layer = self.sniper.model.model.layers[i]
            cache_i = self.sniper.cache[i] if self.sniper.cache else None

            # Attention sub-block.
            residual = h
            h_norm = layer.input_layernorm(h)
            h_attn = layer.self_attn(h_norm, mask=mask, cache=cache_i)
            h_attn = layer.post_attention_layernorm(h_attn)
            h = residual + h_attn
            mx.eval(h)

            # Feed-forward sub-block: the layer's own MLP always runs.
            residual = h
            h_ff = layer.pre_feedforward_layernorm(h)
            h_ff = layer.mlp(h_ff)

            if layer.enable_moe_block:
                h_dense = layer.post_feedforward_layernorm_1(h_ff)
                B, L, D = residual.shape
                residual_flat = residual.reshape(-1, D)

                # Routing is computed here from the router's parts so the
                # selected expert ids are available on the host.
                router = layer.router
                x_normed = router._inline_rms_norm(residual_flat)
                x_normed = x_normed * router.scale * (router.hidden_size ** -0.5)
                scores = router.proj(x_normed)
                probs = mx.softmax(scores, axis=-1)
                top_k_indices = mx.argpartition(
                    -probs, kth=router.top_k - 1, axis=-1
                )[..., :router.top_k]
                top_k_weights = mx.take_along_axis(probs, top_k_indices, axis=-1)
                top_k_weights = top_k_weights / mx.sum(top_k_weights, axis=-1, keepdims=True)
                expert_scales = router.per_expert_scale[top_k_indices]
                top_k_weights = top_k_weights * expert_scales

                moe_input = layer.pre_feedforward_layernorm_2(residual_flat)
                mx.eval(moe_input, top_k_indices, top_k_weights)
                top_k_indices_r = top_k_indices.reshape(B, L, -1)
                top_k_weights_r = top_k_weights.reshape(B, L, -1)
                # Distinct expert ids used anywhere in the prompt for this layer.
                active_ids = np.unique(np.array(top_k_indices_r)).tolist()

                expert_data = self.sniper.reader.get_experts(i, active_ids)
                moe_input_r = moe_input.reshape(B, L, D)
                expert_out = run_expert_ffn(
                    moe_input_r, expert_data, top_k_indices_r, top_k_weights_r
                )
                h_moe = layer.post_feedforward_layernorm_2(expert_out)
                h_ff = h_dense + h_moe

            h_ff = layer.post_feedforward_layernorm(h_ff)
            h = residual + h_ff
            h = h * layer.layer_scalar
            # Materialize this layer's output and release MLX's buffer cache
            # before moving on to the next layer.
            mx.eval(h)
            mx.clear_cache()

        h = self.sniper.model.model.norm(h)

        if self.sniper.model.args.tie_word_embeddings:
            logits = self.sniper.model.model.embed_tokens.as_linear(h)
        else:
            logits = self.sniper.model.lm_head(h)
        mx.eval(logits)

        # Greedy choice for the first generated token.
        return int(mx.argmax(logits[0, -1]).item())
|
|