import os
import yaml
import gc
import torch
import shutil
import sys
import warnings
from pathlib import Path
# --- SILENCE PYDANTIC WARNINGS ---
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
import pydantic
from pydantic import ConfigDict, BaseModel
BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)
try:
    # Standard Merging
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge, MergeOptions
    # MoE Merging
    from mergekit.moe.config import MoEMergeConfig
    from mergekit.scripts.moe import build as build_moe
    # Raw PyTorch Merging
    from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig, plan_flat_merge
    from mergekit.graph import Executor
except ImportError:
    print("Warning: mergekit not installed. Please install it via requirements.txt")

def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Executes a MergeKit run, dispatching to the MoE builder when the config
    contains an 'experts' section and to the standard merge pipeline otherwise.
    """
    # Force garbage collection before touching model weights
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Shared Options
    merge_opts = MergeOptions(
        device=device,
        copy_tokenizer=True,
        lazy_unpickle=True,
        low_cpu_memory=True,
        max_shard_size=int(shard_gb * 1024**3),
        allow_crimes=True
    )

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration.")
        try:
            # Validate using the specific MoE Schema
            conf = MoEMergeConfig.model_validate(config_dict)
            # Execute using the build function from mergekit.scripts.moe
            build_moe(
                config=conf,
                out_path=out_path,
                merge_options=merge_opts,
                load_in_4bit=False,
                load_in_8bit=False,
                device=device,
                verbose=True
            )
            print("✅ MoE Construction Complete.")
        except Exception as e:
            raise RuntimeError(f"MoE Build Failed: {e}")
    # --- BRANCH 2: STANDARD MERGE ---
    else:
        print("⚡ Detected Standard Merge Configuration.")
        try:
            # Validate using the Standard Schema
            conf = MergeConfiguration.model_validate(config_dict)
            run_merge(
                conf,
                out_path=out_path,
                device=device,
                low_cpu_mem=True,
                copy_tokenizer=True,
                lazy_unpickle=True,
                max_shard_size=int(shard_gb * 1024**3)
            )
            print("✅ Standard Merge Complete.")
        except pydantic.ValidationError as e:
            raise ValueError(f"Invalid Merge Configuration: {e}")
        except Exception as e:
            raise RuntimeError(f"Merge Failed: {e}")

    gc.collect()
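
# Illustrative call (hypothetical paths; assumes the config was built with one of
# the helpers below or loaded from a YAML file):
#   cfg = yaml.safe_load(open("merge_config.yml"))
#   execute_mergekit_config(cfg, out_path="./merged-model", shard_gb=5, device="cuda")
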
def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
    """
    Executes a Raw PyTorch merge for non-transformer models.
    """
    print("🧠 Executing Raw PyTorch Merge...")
    try:
        conf = RawPyTorchMergeConfig.model_validate(config_dict)
        merge_opts = MergeOptions(
            device=device,
            low_cpu_memory=True,
            out_shard_size=int(shard_gb * 1024**3),
            lazy_unpickle=True,
            safe_serialization=True
        )
        tasks = plan_flat_merge(
            conf,
            out_path,
            tensor_union=False,
            tensor_intersection=False,
            options=merge_opts
        )
        executor = Executor(
            tasks,
            math_device=device,
            storage_device="cpu"
        )
        executor.execute()
        print("✅ Raw PyTorch Merge Complete.")
    except Exception as e:
        raise RuntimeError(f"Raw Merge Failed: {e}")
    finally:
        gc.collect()
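
# Illustrative config shape for execute_raw_pytorch (hypothetical checkpoint paths),
# matching what build_raw_config below produces:
#   {"merge_method": "linear", "dtype": "float32",
#    "models": [{"model": "ckpt_a.pt", "parameters": {"weight": 0.5}},
#               {"model": "ckpt_b.pt", "parameters": {"weight": 0.5}}]}
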
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    config = {
        "merge_method": method.lower(),
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": []
    }

    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            # Fall back to a default weight of 1.0 per model if parsing fails
            pass

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method.lower() in ["ties", "dare_ties", "dare_linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
            entry["parameters"]["density"] = density
        elif method.lower() in ["slerp", "linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)

    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges YAML: {e}")

    return config
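
# Illustrative output (hypothetical model names): build_full_merge_config(
#     "ties", ["org/model-a", "org/model-b"], "org/base-model",
#     "0.6,0.4", 0.5, "bfloat16", "base", "") returns roughly:
#   {"merge_method": "ties", "base_model": "org/base-model", "dtype": "bfloat16",
#    "tokenizer_source": "base",
#    "models": [{"model": "org/model-a", "parameters": {"weight": 0.6, "density": 0.5}},
#               {"model": "org/model-b", "parameters": {"weight": 0.4, "density": 0.5}}]}
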
def build_moe_config(
    base_model, experts, prompts, gate_mode, dtype,
    tokenizer_source, shared_experts=None
):
    """
    Constructs the YAML dictionary for MoE.

    Key logic based on MergeKit source:
    - 'random'/'uniform_random' modes do NOT require prompts.
    - 'hidden'/'cheap_embed' modes REQUIRE prompts.
    - Qwen2 MoE requires exactly one shared expert.
    - Mixtral requires ZERO shared experts.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": []
    }

    # Handle Experts: pad the prompt list so every expert has an entry
    if len(prompts) < len(experts):
        prompts += [""] * (len(experts) - len(prompts))

    for i, exp in enumerate(experts):
        expert_entry = {"source_model": exp}
        # Only attach prompts if they exist.
        # mergekit.moe.config.is_bad_config will fail if prompts are missing,
        # BUT ONLY IF gate_mode != "random".
        if prompts[i].strip():
            expert_entry["positive_prompts"] = [prompts[i].strip()]
        config["experts"].append(expert_entry)

    # Handle Shared Experts (required for Qwen2, optional for DeepSeek)
    if shared_experts:
        config["shared_experts"] = []
        for sh_exp in shared_experts:
            # Shared experts usually don't use gating prompts in MergeKit implementations
            # (DeepSeek forbids them; Qwen2 requires them if gate_mode is not random).
            # We add a basic entry here; users may need to hand-edit the YAML for
            # more complex shared-expert gating.
            config["shared_experts"].append({"source_model": sh_exp})

    return config
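
# Illustrative output (hypothetical model names): build_moe_config(
#     "org/base-model", ["org/expert-math", "org/expert-code"],
#     ["solve math problems", "write python code"], "hidden", "bfloat16", "base")
# returns roughly:
#   {"base_model": "org/base-model", "gate_mode": "hidden", "dtype": "bfloat16",
#    "tokenizer_source": "base",
#    "experts": [{"source_model": "org/expert-math", "positive_prompts": ["solve math problems"]},
#                {"source_model": "org/expert-code", "positive_prompts": ["write python code"]}]}
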
def build_raw_config(method, models, base_model, dtype, weights):
    config = {
        "merge_method": method.lower(),
        "dtype": dtype,
        "models": []
    }
    if base_model:
        config["base_model"] = base_model

    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            # Fall back to a default weight of 1.0 per model if parsing fails
            pass

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)

    return config
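

# Minimal sketch of how the helpers compose (hypothetical model names and paths).
# It only builds and prints a config; uncomment the execute_* call to actually run
# a merge, which requires mergekit and the referenced models to be available.
if __name__ == "__main__":
    demo_cfg = build_full_merge_config(
        method="ties",
        models=["org/model-a", "org/model-b"],
        base_model="org/base-model",
        weights="0.6,0.4",
        density=0.5,
        dtype="bfloat16",
        tokenizer_source="base",
        layer_ranges=""
    )
    print(yaml.dump(demo_cfg, sort_keys=False))
    # execute_mergekit_config(demo_cfg, out_path="./merged-model", shard_gb=5, device="cpu")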