Add Gemma 4 MLX model class + preprocess
Browse files
src/mlx_expert_sniper/download.py
CHANGED
|
@@ -45,6 +45,13 @@ MODEL_REGISTRY = {
|
|
| 45 |
"default_dir": "qwen3-235b-stream",
|
| 46 |
"description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
|
| 47 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 48 |
}
|
| 49 |
|
| 50 |
TENSOR_ORDER = [
|
|
@@ -69,6 +76,10 @@ def list_models():
|
|
| 69 |
for name, info in MODEL_REGISTRY.items():
|
| 70 |
if "64+" in info["description"]:
|
| 71 |
print(f" {name:<22} {info['description']}")
|
|
|
|
|
|
|
|
|
|
|
|
|
| 72 |
print(f"\nUsage: mlx-sniper download <model-name> [-o output_dir]")
|
| 73 |
|
| 74 |
|
|
@@ -114,7 +125,11 @@ def download_model(model_name, output_dir=None, calibrate_quick=True, keep_downl
|
|
| 114 |
print(f"Step 2/3: Preprocessing into sniper streaming format...")
|
| 115 |
print(f" This takes ~5-20 minutes. Shards are deleted after processing to save disk.\n")
|
| 116 |
|
| 117 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
|
| 119 |
# Clean up download dir
|
| 120 |
if not keep_download:
|
|
|
|
| 45 |
"default_dir": "qwen3-235b-stream",
|
| 46 |
"description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
|
| 47 |
},
|
| 48 |
+
# Gemma 4 (Google) — NEW ARCHITECTURE
|
| 49 |
+
"gemma4-26b": {
|
| 50 |
+
"repo": "google/gemma-4-26B-A4B-it",
|
| 51 |
+
"default_dir": "gemma4-26b-stream",
|
| 52 |
+
"description": "Gemma 4-26B-A4B bf16 (~50 GB, 128 experts, Google MoE — EXPERIMENTAL)",
|
| 53 |
+
"preprocess": "gemma4",
|
| 54 |
+
},
|
| 55 |
}
|
| 56 |
|
| 57 |
TENSOR_ORDER = [
|
|
|
|
| 76 |
for name, info in MODEL_REGISTRY.items():
|
| 77 |
if "64+" in info["description"]:
|
| 78 |
print(f" {name:<22} {info['description']}")
|
| 79 |
+
print("\n Experimental (new architectures):")
|
| 80 |
+
for name, info in MODEL_REGISTRY.items():
|
| 81 |
+
if "EXPERIMENTAL" in info["description"]:
|
| 82 |
+
print(f" {name:<22} {info['description']}")
|
| 83 |
print(f"\nUsage: mlx-sniper download <model-name> [-o output_dir]")
|
| 84 |
|
| 85 |
|
|
|
|
| 125 |
print(f"Step 2/3: Preprocessing into sniper streaming format...")
|
| 126 |
print(f" This takes ~5-20 minutes. Shards are deleted after processing to save disk.\n")
|
| 127 |
|
| 128 |
+
if info.get("preprocess") == "gemma4":
|
| 129 |
+
from .preprocess_gemma4 import preprocess_gemma4
|
| 130 |
+
preprocess_gemma4(download_dir, output_dir)
|
| 131 |
+
else:
|
| 132 |
+
_preprocess(download_dir, output_dir)
|
| 133 |
|
| 134 |
# Clean up download dir
|
| 135 |
if not keep_download:
|