import gc
import os
import shutil
import subprocess
import sys
from pathlib import Path

import torch
import yaml

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# MergeKit's pydantic models can reject non-pydantic field types; relaxing the
# global BaseModel config lets its configurations validate. This mutates
# pydantic globally, so it must execute before any mergekit import below.
import pydantic
from pydantic import BaseModel, ConfigDict

BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge
except ImportError:
    # Best-effort: warn rather than crash at import time. Calling
    # execute_mergekit_config without mergekit installed will NameError.
    print("MergeKit not found. Please install 'mergekit' in requirements.txt")


def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """Execute a MergeKit run from a config dictionary.

    Intelligently branches between:
      1. Standard graph merges (TIES, SLERP, etc.) -> internal ``run_merge``
      2. MoE construction (config has an ``experts`` key) -> ``mergekit-moe`` CLI

    Args:
        config_dict: MergeKit configuration as a plain dict.
        out_path: Directory to write the merged model to.
        shard_gb: Maximum shard size in gigabytes (converted to bytes below).
        device: Device string passed to the internal merge engine.

    Raises:
        RuntimeError: if the mergekit-moe subprocess exits non-zero.
        ValueError: if a standard-merge config fails schema validation.
    """
    # Serialize once; both branches consume the YAML text.
    config_yaml = yaml.dump(config_dict, sort_keys=False)
    print("--- Generated MergeKit Config ---")
    print(config_yaml)
    print("---------------------------------")

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration. Switching to 'mergekit-moe' pipeline...")

        # 1. Write config to a temp file (the CLI requires a file path).
        config_path = Path(out_path).parent / "moe_config.yaml"
        config_path.parent.mkdir(parents=True, exist_ok=True)
        with open(config_path, "w") as f:
            f.write(config_yaml)

        # 2. Build the CLI command. NOTE(review): this assumes 'mergekit-moe'
        # is on PATH in the current environment — confirm; sys.executable is
        # imported but not used to pin the interpreter here.
        cmd = [
            "mergekit-moe",
            str(config_path),
            str(out_path),
            "--shard-size", f"{int(shard_gb * 1024**3)}",  # Bytes
            "--copy-tokenizer",
            "--trust-remote-code",
        ]

        # 3. Execute, always removing the temp config afterwards.
        try:
            subprocess.run(cmd, check=True)
            print("✅ MoE Construction Complete.")
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"MoE Build Failed with exit code {e.returncode}. Check logs."
            )
        finally:
            if config_path.exists():
                os.remove(config_path)

    # --- BRANCH 2: STANDARD MERGE (TIES, SLERP, ETC.) ---
    else:
        print("⚡ Detected Standard Merge Configuration. Using internal engine...")

        # Validate against the strict schema (MergeConfiguration). This will
        # fail if 'merge_method' is missing, which is correct for standard
        # merges.
        try:
            conf = MergeConfiguration.model_validate(yaml.safe_load(config_yaml))
        except pydantic.ValidationError as e:
            raise ValueError(
                f"Invalid Merge Configuration: {e}\n(Did you forget 'merge_method'?)"
            )

        run_merge(
            conf,
            out_path=out_path,
            device=device,
            low_cpu_mem=True,
            copy_tokenizer=True,
            lazy_unpickle=True,
            max_shard_size=int(shard_gb * 1024**3),
        )

    # Force cleanup of any large tensors still referenced by the merge.
    gc.collect()


def build_full_merge_config(
    method, models, base_model, weights, density, dtype, tokenizer_source, layer_ranges
):
    """Construct the config dict for general merging (Linear, SLERP, TIES, etc.).

    Args:
        method: Merge method name (case-insensitive).
        models: List of model identifiers to merge.
        base_model: Base model; falls back to ``models[0]`` when falsy.
        weights: Comma-separated per-model weights as a string, or falsy.
        density: Density parameter for TIES-family methods.
        dtype: Output dtype string.
        tokenizer_source: Tokenizer source setting.
        layer_ranges: Optional raw YAML/JSON string merged over the config
            (e.g. a ``slices`` override).

    Returns:
        dict ready for ``execute_mergekit_config``.
    """
    method_lower = method.lower()
    config = {
        "merge_method": method_lower,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    # Parse weights defensively: a malformed string degrades to the 1.0
    # default rather than aborting the build.
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(",")]
        except ValueError:
            print("Warning: Could not parse weights, defaulting to 1.0")

    # Per-model parameter injection. All weighted methods share the same
    # weight rule; TIES-family methods additionally take a density.
    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method_lower in ("ties", "dare_ties", "dare_linear", "slerp", "linear"):
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        if method_lower in ("ties", "dare_ties", "dare_linear"):
            entry["parameters"]["density"] = density
        config["models"].append(entry)

    # Inject slices/layer ranges if provided (raw override merged on top).
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges JSON: {e}")

    return config


def build_moe_config(
    base_model, experts, gate_mode, dtype, tokenizer_source, positive_prompts=None
):
    """Construct the config dict for a Mixture of Experts (MoE) build.

    Note: we do NOT add 'merge_method' here because MoE configs do not use
    that field in the standard MergeKit schema.

    Args:
        base_model: Base model identifier.
        experts: List of expert model identifiers.
        gate_mode: MergeKit gate mode (e.g. "hidden", "random").
        dtype: Output dtype string.
        tokenizer_source: Tokenizer source setting.
        positive_prompts: Optional sequence of routing prompts aligned with
            ``experts`` (each element a string or list of strings). Experts
            without a provided prompt get an ``expert_{i}`` placeholder.

    Returns:
        dict ready for ``execute_mergekit_config``.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }

    for i, exp in enumerate(experts):
        # Use caller-supplied prompts when available (previously this
        # parameter was accepted but ignored); otherwise keep the placeholder.
        if positive_prompts is not None and i < len(positive_prompts):
            prompts = positive_prompts[i]
            if isinstance(prompts, str):
                prompts = [prompts]
        else:
            prompts = [f"expert_{i}"]
        config["experts"].append({"source_model": exp, "positive_prompts": prompts})

    return config