# Soon_Merger_Toolkit / merge_utils4.py
# (renamed from merge_utils.py; commit dda94e6, author: AlekseyCalvin)
import os
import yaml
import gc
import torch
import shutil
import subprocess
import sys
from pathlib import Path
# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# NOTE(review): reassigning BaseModel.model_config is intended to make pydantic
# models accept arbitrary (non-pydantic) types such as torch objects held by
# mergekit configs. This assignment only influences classes defined *after*
# this line executes — confirm it actually reaches mergekit's models given
# the import order below.
import pydantic
from pydantic import ConfigDict, BaseModel
BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)
try:
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge
except ImportError:
    # Soft-fail so the module can still be imported (e.g. for UI startup).
    # Any subsequent merge call will raise NameError on MergeConfiguration /
    # run_merge if mergekit is truly absent.
    print("MergeKit not found. Please install 'mergekit' in requirements.txt")
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a MergeKit run described by ``config_dict``.

    Branches between two pipelines:
      1. MoE construction (config contains an ``experts`` key) -> shells out
         to the ``mergekit-moe`` CLI, which requires a config file on disk.
      2. Standard graph merges (TIES, SLERP, linear, ...) -> validates the
         config against MergeKit's strict schema and calls ``run_merge``
         in-process.

    Args:
        config_dict: MergeKit configuration as a plain dict (YAML-shaped).
        out_path: Directory to write the merged model into.
        shard_gb: Maximum shard size in gigabytes (converted to bytes).
        device: Torch device for the in-process merge path ("cpu"/"cuda").
            NOTE: ignored by the MoE CLI branch.

    Raises:
        RuntimeError: if the mergekit-moe subprocess exits non-zero.
        ValueError: if a standard-merge config fails schema validation.
    """
    # Convert dict to YAML once; both branches consume this string.
    config_yaml = yaml.dump(config_dict, sort_keys=False)
    print("--- Generated MergeKit Config ---")
    print(config_yaml)
    print("---------------------------------")

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration. Switching to 'mergekit-moe' pipeline...")
        # The CLI requires a config file on disk; stage it next to out_path.
        config_path = Path(out_path).parent / "moe_config.yaml"
        # exist_ok makes the prior exists() check redundant.
        config_path.parent.mkdir(parents=True, exist_ok=True)
        config_path.write_text(config_yaml)

        # NOTE(review): this relies on 'mergekit-moe' being on PATH. Invoking
        # via sys.executable -m would be more hermetic, but the module path
        # varies across mergekit versions, so the console script is kept.
        cmd = [
            "mergekit-moe",
            str(config_path),
            str(out_path),
            "--shard-size", f"{int(shard_gb * 1024**3)}",  # bytes
            "--copy-tokenizer",
            "--trust-remote-code",
        ]
        try:
            subprocess.run(cmd, check=True)
            print("✅ MoE Construction Complete.")
        except subprocess.CalledProcessError as e:
            # Chain the original error so the exit code context is preserved.
            raise RuntimeError(
                f"MoE Build Failed with exit code {e.returncode}. Check logs."
            ) from e
        finally:
            # Best-effort cleanup of the temp config regardless of outcome.
            if config_path.exists():
                os.remove(config_path)

    # --- BRANCH 2: STANDARD MERGE (TIES, SLERP, ETC.) ---
    else:
        print("⚡ Detected Standard Merge Configuration. Using internal engine...")
        # Validate against the strict schema; a missing 'merge_method' is a
        # genuine error for standard merges and should fail loudly here.
        try:
            conf = MergeConfiguration.model_validate(yaml.safe_load(config_yaml))
        except pydantic.ValidationError as e:
            raise ValueError(
                f"Invalid Merge Configuration: {e}\n(Did you forget 'merge_method'?)"
            ) from e
        run_merge(
            conf,
            out_path=out_path,
            device=device,
            low_cpu_mem=True,
            copy_tokenizer=True,
            lazy_unpickle=True,
            max_shard_size=int(shard_gb * 1024**3),
        )

    # Force cleanup of any tensors still referenced by the merge graph
    # (now runs for both branches, not just the in-process merge).
    gc.collect()
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """
    Build the MergeKit config dict for a standard merge (linear/SLERP/TIES/...).

    Args:
        method: Merge method name (case-insensitive), e.g. "ties", "slerp".
        models: List of model identifiers to merge.
        base_model: Base model id; falls back to ``models[0]`` when falsy.
        weights: Comma-separated per-model weights (e.g. "0.5, 0.5") or falsy.
        density: Density value applied to TIES-family methods.
        dtype: Output dtype string (e.g. "bfloat16").
        tokenizer_source: MergeKit tokenizer_source value (e.g. "base", "union").
        layer_ranges: Optional raw YAML/JSON string merged over the config
            (slice / layer-range overrides); parse errors are logged, not raised.

    Returns:
        dict ready to be dumped to YAML for MergeKit.
    """
    method_key = method.lower()  # hoisted: previously recomputed per model
    config = {
        "merge_method": method_key,
        # Fall back to the first model when no explicit base is given.
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    # Parse the comma-separated weight string; missing entries default to 1.0.
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(",")]
        except (ValueError, AttributeError):
            # Was a bare `except:` — narrowed so unrelated bugs aren't swallowed.
            print("Warning: Could not parse weights, defaulting to 1.0")
            w_list = []

    # Methods that take a per-model weight; the TIES family also needs density.
    weighted_methods = {"ties", "dare_ties", "dare_linear", "slerp", "linear"}
    density_methods = {"ties", "dare_ties", "dare_linear"}
    for i, model_id in enumerate(models):
        entry = {"model": model_id, "parameters": {}}
        if method_key in weighted_methods:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        if method_key in density_methods:
            entry["parameters"]["density"] = density
        config["models"].append(entry)

    # Raw override hook: merge user-provided slices/layer ranges on top.
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            # Best-effort by design: a bad override should not kill the run.
            print(f"Error parsing layer ranges JSON: {e}")
    return config
def build_moe_config(
    base_model, experts, gate_mode, dtype,
    tokenizer_source, positive_prompts=None
):
    """
    Build the MergeKit config dict for a Mixture-of-Experts (MoE) build.

    Note: no 'merge_method' key is added because MoE configs do not use that
    field in the standard MergeKit schema — the presence of the 'experts'
    key is what routes execution to the mergekit-moe pipeline.

    Args:
        base_model: Model id providing the shared (non-expert) weights.
        experts: List of expert model ids.
        gate_mode: MergeKit gate_mode (e.g. "hidden", "random", "cheap_embed").
        dtype: Output dtype string (e.g. "bfloat16").
        tokenizer_source: MergeKit tokenizer_source value.
        positive_prompts: Optional per-expert routing prompts, parallel to
            ``experts``; each entry may be a single string or a list of
            strings. Missing/falsy entries fall back to an "expert_<i>"
            placeholder. (Previously this parameter was accepted but
            silently ignored — the placeholder was always used.)

    Returns:
        dict ready to be dumped to YAML for mergekit-moe.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }
    for i, expert_id in enumerate(experts):
        # Use the caller-supplied prompts for this expert when available.
        if positive_prompts is not None and i < len(positive_prompts) and positive_prompts[i]:
            prompts = positive_prompts[i]
            if isinstance(prompts, str):
                prompts = [prompts]  # normalize a bare string to a list
        else:
            prompts = [f"expert_{i}"]  # placeholder keeps the config valid
        config["experts"].append({
            "source_model": expert_id,
            "positive_prompts": prompts,
        })
    return config