waltgrace committed on
Commit
cc1d5e2
·
verified ·
1 Parent(s): 4a30158

Add Gemma 4 MLX model class + preprocess

Browse files
Files changed (1) hide show
  1. src/mlx_expert_sniper/download.py +16 -1
src/mlx_expert_sniper/download.py CHANGED
@@ -45,6 +45,13 @@ MODEL_REGISTRY = {
45
  "default_dir": "qwen3-235b-stream",
46
  "description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
47
  },
 
 
 
 
 
 
 
48
  }
49
 
50
  TENSOR_ORDER = [
@@ -69,6 +76,10 @@ def list_models():
69
  for name, info in MODEL_REGISTRY.items():
70
  if "64+" in info["description"]:
71
  print(f" {name:<22} {info['description']}")
 
 
 
 
72
  print(f"\nUsage: mlx-sniper download <model-name> [-o output_dir]")
73
 
74
 
@@ -114,7 +125,11 @@ def download_model(model_name, output_dir=None, calibrate_quick=True, keep_downl
114
  print(f"Step 2/3: Preprocessing into sniper streaming format...")
115
  print(f" This takes ~5-20 minutes. Shards are deleted after processing to save disk.\n")
116
 
117
- _preprocess(download_dir, output_dir)
 
 
 
 
118
 
119
  # Clean up download dir
120
  if not keep_download:
 
45
  "default_dir": "qwen3-235b-stream",
46
  "description": "Qwen3-235B-A22B 4-bit (~130 GB, 128 experts, needs 64+ GB RAM)",
47
  },
48
+ # Gemma 4 (Google) — NEW ARCHITECTURE
49
+ "gemma4-26b": {
50
+ "repo": "google/gemma-4-26B-A4B-it",
51
+ "default_dir": "gemma4-26b-stream",
52
+ "description": "Gemma 4-26B-A4B bf16 (~50 GB, 128 experts, Google MoE — EXPERIMENTAL)",
53
+ "preprocess": "gemma4",
54
+ },
55
  }
56
 
57
  TENSOR_ORDER = [
 
76
  for name, info in MODEL_REGISTRY.items():
77
  if "64+" in info["description"]:
78
  print(f" {name:<22} {info['description']}")
79
+ print("\n Experimental (new architectures):")
80
+ for name, info in MODEL_REGISTRY.items():
81
+ if "EXPERIMENTAL" in info["description"]:
82
+ print(f" {name:<22} {info['description']}")
83
  print(f"\nUsage: mlx-sniper download <model-name> [-o output_dir]")
84
 
85
 
 
125
  print(f"Step 2/3: Preprocessing into sniper streaming format...")
126
  print(f" This takes ~5-20 minutes. Shards are deleted after processing to save disk.\n")
127
 
128
+ if info.get("preprocess") == "gemma4":
129
+ from .preprocess_gemma4 import preprocess_gemma4
130
+ preprocess_gemma4(download_dir, output_dir)
131
+ else:
132
+ _preprocess(download_dir, output_dir)
133
 
134
  # Clean up download dir
135
  if not keep_download: