# Soon_Merger_Toolkit / merge_utils.py
# Author: AlekseyCalvin
# Update merge_utils.py (commit 0032edf, verified)
import os
import yaml
import gc
import torch
import shutil
import sys
import warnings
from pathlib import Path
# --- SILENCE PYDANTIC WARNINGS ---
# Pydantic emits UserWarnings for models holding arbitrary/unknown field
# types; they are pure noise here, so suppress them for the pydantic module.
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")
# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
import pydantic
from pydantic import ConfigDict, BaseModel
# Monkey-patch the default model config so every pydantic model defined
# afterwards (including mergekit's) accepts arbitrary Python types such as
# torch dtypes. NOTE(review): this mutates BaseModel process-wide, which is
# why it must precede the mergekit imports below.
BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)
try:
    # Standard Merging
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge, MergeOptions
    # MoE Merging
    from mergekit.moe.config import MoEMergeConfig
    from mergekit.scripts.moe import build as build_moe
    # Raw PyTorch Merging
    from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig, plan_flat_merge
    from mergekit.graph import Executor
except ImportError:
    # Deliberate soft failure: the module stays importable; the functions
    # below will raise NameError if mergekit is actually needed but absent.
    print("Warning: mergekit not installed. Please install it via requirements.txt")
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a MergeKit run described by ``config_dict``.

    Dispatches to the MoE build path when the config contains an
    "experts" key; otherwise runs a standard merge.

    Args:
        config_dict: Merge configuration as a plain dict (parsed YAML).
        out_path: Output directory for the merged model.
        shard_gb: Maximum output shard size, in gigabytes.
        device: Device string passed to MergeKit (default "cpu").

    Raises:
        ValueError: if a standard-merge config fails schema validation.
        RuntimeError: if the merge/build itself fails.
    """
    # Reclaim as much memory as possible before loading model shards.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Options shared by both merge paths.
    merge_opts = MergeOptions(
        device=device,
        copy_tokenizer=True,
        lazy_unpickle=True,
        low_cpu_memory=True,
        max_shard_size=int(shard_gb * 1024**3),
        allow_crimes=True
    )
    try:
        # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
        if "experts" in config_dict:
            print("🚀 Detected MoE Configuration.")
            try:
                # Validate using the specific MoE schema.
                conf = MoEMergeConfig.model_validate(config_dict)
                # Execute using the build function from mergekit.scripts.moe.
                build_moe(
                    config=conf,
                    out_path=out_path,
                    merge_options=merge_opts,
                    load_in_4bit=False,
                    load_in_8bit=False,
                    device=device,
                    verbose=True
                )
                print("✅ MoE Construction Complete.")
            except Exception as e:
                # Chain the cause so the original traceback is preserved.
                raise RuntimeError(f"MoE Build Failed: {e}") from e
        # --- BRANCH 2: STANDARD MERGE ---
        else:
            print("⚡ Detected Standard Merge Configuration.")
            try:
                # Validate using the standard schema.
                conf = MergeConfiguration.model_validate(config_dict)
                run_merge(
                    conf,
                    out_path=out_path,
                    device=device,
                    low_cpu_mem=True,
                    copy_tokenizer=True,
                    lazy_unpickle=True,
                    max_shard_size=int(shard_gb * 1024**3)
                )
                print("✅ Standard Merge Complete.")
            except pydantic.ValidationError as e:
                raise ValueError(f"Invalid Merge Configuration: {e}") from e
            except Exception as e:
                raise RuntimeError(f"Merge Failed: {e}") from e
    finally:
        # Previously this only ran on success; run cleanup on failure too.
        gc.collect()
def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a raw PyTorch merge for non-transformer checkpoints.

    Args:
        config_dict: Merge configuration as a plain dict (parsed YAML).
        out_path: Output directory for the merged checkpoint.
        shard_gb: Maximum output shard size, in gigabytes.
        device: Device used for the tensor math (default "cpu").

    Raises:
        RuntimeError: if validation, planning, or execution fails.
    """
    print("🧠 Executing Raw PyTorch Merge...")
    try:
        conf = RawPyTorchMergeConfig.model_validate(config_dict)
        merge_opts = MergeOptions(
            device=device,
            low_cpu_memory=True,
            out_shard_size=int(shard_gb * 1024**3),
            lazy_unpickle=True,
            safe_serialization=True
        )
        # Plan a flat (non-layered) merge over the raw state dicts.
        tasks = plan_flat_merge(
            conf,
            out_path,
            tensor_union=False,
            tensor_intersection=False,
            options=merge_opts
        )
        # Compute on `device`, but stage results in CPU memory.
        executor = Executor(
            tasks,
            math_device=device,
            storage_device="cpu"
        )
        executor.execute()
        print("✅ Raw PyTorch Merge Complete.")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Raw Merge Failed: {e}") from e
    finally:
        gc.collect()
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """
    Build a standard MergeKit configuration dict.

    Args:
        method: Merge method name, case-insensitive (e.g. "ties", "slerp").
        models: List of model identifiers to merge.
        base_model: Base model id; falls back to ``models[0]`` when falsy.
        weights: Comma-separated per-model weights (e.g. "0.5,0.5").
            Malformed input is ignored and every weight defaults to 1.0.
        density: Density parameter applied by TIES/DARE-style methods.
        dtype: Output dtype string.
        tokenizer_source: MergeKit ``tokenizer_source`` value.
        layer_ranges: Optional YAML snippet of extra top-level keys
            (e.g. slices / layer ranges) merged into the config.

    Returns:
        dict ready for ``MergeConfiguration`` validation.
    """
    method_l = method.lower()
    config = {
        "merge_method": method_l,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": []
    }
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except (ValueError, AttributeError):
            # Unparseable weights: fall back to the 1.0-per-model default.
            w_list = []
    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method_l in ["ties", "dare_ties", "dare_linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
            entry["parameters"]["density"] = density
        elif method_l in ["slerp", "linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            # Only merge mappings; a scalar/list would corrupt the config.
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            # Message used to say "JSON", but the input is parsed as YAML.
            print(f"Error parsing layer ranges YAML: {e}")
    return config
def build_moe_config(
    base_model, experts, prompts, gate_mode, dtype,
    tokenizer_source, shared_experts=None
):
    """
    Construct the YAML dictionary for a MergeKit MoE build.

    Key logic based on MergeKit source:
      - 'random'/'uniform_random' gate modes do NOT require prompts.
      - 'hidden'/'cheap_embed' gate modes REQUIRE prompts.
      - Qwen2 MoE requires exactly one shared expert.
      - Mixtral requires ZERO shared experts.

    Args:
        base_model: Base model id the experts are grafted onto.
        experts: List of expert model ids.
        prompts: Positive-prompt strings, one per expert; padded with ""
            when shorter than ``experts``. The caller's list is not modified.
        gate_mode: MergeKit gate mode ("hidden", "cheap_embed", "random", ...).
        dtype: Output dtype string.
        tokenizer_source: MergeKit ``tokenizer_source`` value.
        shared_experts: Optional list of shared-expert model ids.

    Returns:
        dict ready for ``MoEMergeConfig`` validation.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": []
    }
    # Pad a COPY so the caller's prompts list is never mutated
    # (the previous `prompts += ...` extended the argument in place).
    padded = list(prompts) + [""] * max(0, len(experts) - len(prompts))
    for prompt, exp in zip(padded, experts):
        expert_entry = {"source_model": exp}
        # Only attach prompts if they exist:
        # mergekit.moe.config.is_bad_config rejects missing prompts,
        # BUT ONLY IF gate_mode != "random".
        if prompt.strip():
            expert_entry["positive_prompts"] = [prompt.strip()]
        config["experts"].append(expert_entry)
    # Shared experts (required for Qwen2, optional for DeepSeek).
    if shared_experts:
        # Shared experts usually don't use gating prompts in MergeKit
        # (DeepSeek forbids them; Qwen2 requires them unless random).
        # Users needing complex shared gating should edit the YAML directly.
        config["shared_experts"] = [{"source_model": sh} for sh in shared_experts]
    return config
def build_raw_config(method, models, base_model, dtype, weights):
    """
    Build a raw-PyTorch merge configuration dict.

    Args:
        method: Merge method name, case-insensitive.
        models: List of checkpoint identifiers to merge.
        base_model: Optional base checkpoint; omitted when falsy.
        dtype: Output dtype string.
        weights: Comma-separated per-model weights. Malformed input is
            ignored and every weight defaults to 1.0.

    Returns:
        dict ready for ``RawPyTorchMergeConfig`` validation.
    """
    config = {
        "merge_method": method.lower(),
        "dtype": dtype,
        "models": []
    }
    if base_model:
        config["base_model"] = base_model
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except (ValueError, AttributeError):
            # Unparseable weights: fall back to the 1.0-per-model default.
            w_list = []
    for i, m in enumerate(models):
        weight = w_list[i] if i < len(w_list) else 1.0
        config["models"].append({"model": m, "parameters": {"weight": weight}})
    return config