import os
import yaml
import gc
import torch
import shutil
import sys
import warnings
from pathlib import Path

# --- SILENCE PYDANTIC WARNINGS ---
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# mergekit's pydantic models reject arbitrary (e.g. torch) types unless the
# base config allows them; this must be applied before mergekit is imported.
import pydantic
from pydantic import ConfigDict, BaseModel

BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    # Standard Merging
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge, MergeOptions
    # MoE Merging
    from mergekit.moe.config import MoEMergeConfig
    from mergekit.scripts.moe import build as build_moe
    # Raw PyTorch Merging
    from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig, plan_flat_merge
    from mergekit.graph import Executor
except ImportError:
    print("Warning: mergekit not installed. Please install it via requirements.txt")


def _parse_weights(weights):
    """Parse a comma-separated weight string into a list of floats.

    Returns an empty list when *weights* is falsy or malformed; callers then
    fall back to a default weight of 1.0 per model, preserving the original
    best-effort behavior (we deliberately do not raise on bad input).
    """
    if not weights:
        return []
    try:
        return [float(chunk.strip()) for chunk in weights.split(',')]
    except ValueError:
        # Malformed number in the list — treat the whole string as unusable.
        return []


def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """Execute a MergeKit run described by *config_dict*.

    Dispatches to the MoE builder when the config contains an ``experts``
    key, otherwise runs a standard merge.

    Args:
        config_dict: Parsed YAML config as a plain dict.
        out_path:    Directory to write the merged model to.
        shard_gb:    Maximum output shard size, in gibibytes.
        device:      Compute device for the merge ("cpu" or "cuda").

    Raises:
        ValueError:   The standard-merge config fails pydantic validation.
        RuntimeError: The underlying mergekit build/merge fails.
    """
    # Free as much memory as possible before loading model shards.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Shared options for the MoE branch.
    # NOTE(review): this object is NOT passed to run_merge() in the standard
    # branch below — confirm whether the pinned mergekit version expects
    # run_merge(conf, out_path, options=merge_opts) instead of flat kwargs.
    merge_opts = MergeOptions(
        device=device,
        copy_tokenizer=True,
        lazy_unpickle=True,
        low_cpu_memory=True,
        max_shard_size=int(shard_gb * 1024**3),
        allow_crimes=True,
    )

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration.")
        try:
            # Validate using the specific MoE Schema
            conf = MoEMergeConfig.model_validate(config_dict)
            # Execute using the build function from mergekit.scripts.moe
            build_moe(
                config=conf,
                out_path=out_path,
                merge_options=merge_opts,
                load_in_4bit=False,
                load_in_8bit=False,
                device=device,
                verbose=True,
            )
            print("✅ MoE Construction Complete.")
        except Exception as e:
            raise RuntimeError(f"MoE Build Failed: {e}")

    # --- BRANCH 2: STANDARD MERGE ---
    else:
        print("⚡ Detected Standard Merge Configuration.")
        try:
            # Validate using the Standard Schema
            conf = MergeConfiguration.model_validate(config_dict)
            run_merge(
                conf,
                out_path=out_path,
                device=device,
                low_cpu_mem=True,
                copy_tokenizer=True,
                lazy_unpickle=True,
                max_shard_size=int(shard_gb * 1024**3),
            )
            print("✅ Standard Merge Complete.")
        except pydantic.ValidationError as e:
            raise ValueError(f"Invalid Merge Configuration: {e}")
        except Exception as e:
            raise RuntimeError(f"Merge Failed: {e}")

    gc.collect()


def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
    """Execute a raw PyTorch merge for non-transformer checkpoints.

    Plans a flat tensor merge via mergekit's raw-pytorch script and runs the
    resulting task graph with math on *device* and results stored on CPU.

    Raises:
        RuntimeError: Any failure during validation, planning, or execution.
    """
    print("🧠 Executing Raw PyTorch Merge...")
    try:
        conf = RawPyTorchMergeConfig.model_validate(config_dict)

        # NOTE(review): this branch uses `out_shard_size` while the standard
        # branch uses `max_shard_size` — verify which name the pinned
        # MergeOptions schema actually accepts.
        merge_opts = MergeOptions(
            device=device,
            low_cpu_memory=True,
            out_shard_size=int(shard_gb * 1024**3),
            lazy_unpickle=True,
            safe_serialization=True,
        )

        tasks = plan_flat_merge(
            conf,
            out_path,
            tensor_union=False,
            tensor_intersection=False,
            options=merge_opts,
        )

        executor = Executor(
            tasks,
            math_device=device,
            storage_device="cpu",
        )
        executor.execute()
        print("✅ Raw PyTorch Merge Complete.")
    except Exception as e:
        raise RuntimeError(f"Raw Merge Failed: {e}")
    finally:
        gc.collect()


def build_full_merge_config(
    method, models, base_model, weights, density, dtype,
    tokenizer_source, layer_ranges
):
    """Build a standard mergekit config dict from UI-style inputs.

    Args:
        method:           Merge method name (e.g. "ties", "slerp", "linear").
        models:           List of model paths/IDs to merge.
        base_model:       Base model; falls back to ``models[0]`` when falsy.
        weights:          Comma-separated per-model weights, or None/"".
        density:          Density parameter for TIES/DARE-family methods.
        dtype:            Output dtype string.
        tokenizer_source: Tokenizer source spec.
        layer_ranges:     Optional YAML string of extra top-level parameters
                          merged into the config (e.g. slices/layer ranges).

    Returns:
        A dict suitable for ``MergeConfiguration.model_validate``.
    """
    config = {
        "merge_method": method.lower(),
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    w_list = _parse_weights(weights)

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method.lower() in ["ties", "dare_ties", "dare_linear"]:
            # TIES/DARE methods take both a weight and a sparsity density.
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
            entry["parameters"]["density"] = density
        elif method.lower() in ["slerp", "linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)

    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            # Input is YAML (safe_load above); the old message said "JSON".
            print(f"Error parsing layer ranges YAML: {e}")

    return config


def build_moe_config(
    base_model, experts, prompts, gate_mode, dtype,
    tokenizer_source, shared_experts=None
):
    """Construct the YAML dictionary for MoE.

    Key Logic based on MergeKit source:
    - 'random'/'uniform_random' modes do NOT require prompts.
    - 'hidden'/'cheap_embed' modes REQUIRE prompts.
    - Qwen2 MoE requires exactly one shared expert.
    - Mixtral requires ZERO shared experts.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }

    # Pad prompts to one per expert on a local copy — the original code
    # extended the caller's list in place, a surprising side effect.
    prompts = list(prompts)
    if len(prompts) < len(experts):
        prompts += [""] * (len(experts) - len(prompts))

    for i, exp in enumerate(experts):
        expert_entry = {"source_model": exp}
        # Only attach prompts if they exist.
        # mergekit.moe.config.is_bad_config will fail if prompts are missing
        # BUT ONLY IF gate_mode != "random".
        if prompts[i].strip():
            expert_entry["positive_prompts"] = [prompts[i].strip()]
        config["experts"].append(expert_entry)

    # Handle Shared Experts (Required for Qwen2, Optional for DeepSeek)
    if shared_experts:
        config["shared_experts"] = []
        for sh_exp in shared_experts:
            # Shared experts usually don't use gating prompts in MergeKit
            # implementations (DeepSeek forbids them, Qwen2 requires them if
            # not random). We add a basic entry here; users might need
            # advanced YAML editing for complex shared gating.
            config["shared_experts"].append({"source_model": sh_exp})

    return config


def build_raw_config(method, models, base_model, dtype, weights):
    """Build a raw-PyTorch merge config dict.

    Like :func:`build_full_merge_config` but without tokenizer/density
    handling; every model gets a weight (default 1.0) and ``base_model`` is
    only included when provided.
    """
    config = {
        "merge_method": method.lower(),
        "dtype": dtype,
        "models": [],
    }
    if base_model:
        config["base_model"] = base_model

    w_list = _parse_weights(weights)

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)

    return config