Spaces:
Sleeping
Sleeping
# Standard library
import gc
import os
import shutil
import sys
import warnings
from pathlib import Path

# Third-party
import torch
import yaml

# --- SILENCE PYDANTIC WARNINGS ---
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic")

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# Allowing arbitrary types on the shared BaseModel config prevents pydantic
# validation errors when mergekit's schemas reference non-pydantic types.
# This must happen before any mergekit module is imported below.
import pydantic
from pydantic import ConfigDict, BaseModel

BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    # Standard Merging
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge, MergeOptions
    # MoE Merging
    from mergekit.moe.config import MoEMergeConfig
    from mergekit.scripts.moe import build as build_moe
    # Raw PyTorch Merging
    from mergekit.scripts.merge_raw_pytorch import RawPyTorchMergeConfig, plan_flat_merge
    from mergekit.graph import Executor
except ImportError:
    # Best-effort: the functions below will fail with NameError if called
    # without mergekit installed, but the module itself stays importable.
    print("Warning: mergekit not installed. Please install it via requirements.txt")
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a MergeKit run: an MoE build if the config has an "experts" key,
    otherwise a standard merge.

    Args:
        config_dict: Parsed merge configuration as a dict.
        out_path: Directory to write the merged model to.
        shard_gb: Maximum output shard size in gigabytes.
        device: Device for merge math ("cpu" or "cuda").

    Raises:
        ValueError: If a standard merge configuration fails validation.
        RuntimeError: If the merge or MoE build itself fails.
    """
    # Free as much memory as possible before loading model weights.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Shared options for both branches.
    merge_opts = MergeOptions(
        device=device,
        copy_tokenizer=True,
        lazy_unpickle=True,
        low_cpu_memory=True,
        max_shard_size=int(shard_gb * 1024**3),
        allow_crimes=True,
    )

    # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
    if "experts" in config_dict:
        print("🚀 Detected MoE Configuration.")
        try:
            # Validate using the specific MoE schema.
            conf = MoEMergeConfig.model_validate(config_dict)
            # Execute using the build function from mergekit.scripts.moe.
            build_moe(
                config=conf,
                out_path=out_path,
                merge_options=merge_opts,
                load_in_4bit=False,
                load_in_8bit=False,
                device=device,
                verbose=True,
            )
            print("✅ MoE Construction Complete.")
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise RuntimeError(f"MoE Build Failed: {e}") from e
    # --- BRANCH 2: STANDARD MERGE ---
    else:
        print("⚡ Detected Standard Merge Configuration.")
        try:
            # Validate using the standard schema.
            conf = MergeConfiguration.model_validate(config_dict)
            # NOTE(review): merge_opts is not used on this branch; run_merge is
            # called with flat kwargs instead. Recent mergekit versions expect
            # run_merge(config, out_path, options=MergeOptions(...)) — verify
            # against the installed mergekit API.
            run_merge(
                conf,
                out_path=out_path,
                device=device,
                low_cpu_mem=True,
                copy_tokenizer=True,
                lazy_unpickle=True,
                max_shard_size=int(shard_gb * 1024**3)
            )
            print("✅ Standard Merge Complete.")
        except pydantic.ValidationError as e:
            raise ValueError(f"Invalid Merge Configuration: {e}") from e
        except Exception as e:
            raise RuntimeError(f"Merge Failed: {e}") from e
    gc.collect()
def execute_raw_pytorch(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a raw PyTorch merge for non-transformer checkpoints.

    Args:
        config_dict: Parsed configuration dict for RawPyTorchMergeConfig.
        out_path: Directory to write the merged tensors to.
        shard_gb: Maximum output shard size in gigabytes.
        device: Device used for the merge math.

    Raises:
        RuntimeError: If validation, planning, or execution fails.
    """
    print("🧠 Executing Raw PyTorch Merge...")
    try:
        conf = RawPyTorchMergeConfig.model_validate(config_dict)
        merge_opts = MergeOptions(
            device=device,
            low_cpu_memory=True,
            out_shard_size=int(shard_gb * 1024**3),
            lazy_unpickle=True,
            safe_serialization=True,
        )
        # Plan the flat tensor merge into a task graph.
        tasks = plan_flat_merge(
            conf,
            out_path,
            tensor_union=False,
            tensor_intersection=False,
            options=merge_opts,
        )
        # Compute on `device` but keep results resident on CPU.
        executor = Executor(
            tasks,
            math_device=device,
            storage_device="cpu",
        )
        executor.execute()
        print("✅ Raw PyTorch Merge Complete.")
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Raw Merge Failed: {e}") from e
    finally:
        gc.collect()
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """
    Build a standard MergeKit configuration dict.

    Args:
        method: Merge method name (lowercased, e.g. "ties", "slerp").
        models: List of model identifiers to merge.
        base_model: Base model identifier; falls back to models[0] if falsy.
        weights: Comma-separated per-model weights string, or falsy to skip.
        density: Density parameter used by TIES/DARE-style methods.
        dtype: dtype name for the merge.
        tokenizer_source: Tokenizer source option for MergeKit.
        layer_ranges: Optional YAML snippet whose top-level keys (e.g. slices)
            are merged into the config; ignored if not a mapping.

    Returns:
        dict: A config suitable for MergeConfiguration.model_validate().
    """
    method_key = method.lower()
    config = {
        "merge_method": method_key,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    # Parse "0.5, 0.5"-style weight lists; malformed input means no weights
    # (each model then defaults to weight 1.0 below).
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            pass

    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method_key in ["ties", "dare_ties", "dare_linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
            entry["parameters"]["density"] = density
        elif method_key in ["slerp", "linear"]:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        config["models"].append(entry)

    # Merge optional advanced keys (e.g. layer ranges) parsed as YAML.
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges YAML: {e}")
    return config
def build_moe_config(
    base_model, experts, prompts, gate_mode, dtype,
    tokenizer_source, shared_experts=None
):
    """
    Construct the YAML dictionary for MoE.

    Key Logic based on MergeKit source:
    - 'random'/'uniform_random' modes do NOT require prompts.
    - 'hidden'/'cheap_embed' modes REQUIRE prompts.
    - Qwen2 MoE requires exactly one shared expert.
    - Mixtral requires ZERO shared experts.

    Args:
        base_model: Base model identifier.
        experts: List of expert model identifiers.
        prompts: List of positive prompts, one per expert (may be shorter;
            missing entries are treated as empty). The caller's list is NOT
            modified.
        gate_mode: MergeKit gate mode ("hidden", "cheap_embed", "random", ...).
        dtype: dtype name for the build.
        tokenizer_source: Tokenizer source option for MergeKit.
        shared_experts: Optional list of shared-expert model identifiers.

    Returns:
        dict: A config suitable for MoEMergeConfig.model_validate().
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }

    # Pad prompts to one per expert on a COPY. The previous implementation
    # used `prompts += [...]`, which mutated the caller's list in place.
    padded = list(prompts) + [""] * max(0, len(experts) - len(prompts))

    for i, exp in enumerate(experts):
        expert_entry = {"source_model": exp}
        # Only attach prompts if they exist.
        # mergekit.moe.config.is_bad_config will fail if prompts are missing
        # BUT ONLY IF gate_mode != "random".
        if padded[i].strip():
            expert_entry["positive_prompts"] = [padded[i].strip()]
        config["experts"].append(expert_entry)

    # Handle Shared Experts (Required for Qwen2, Optional for DeepSeek)
    if shared_experts:
        config["shared_experts"] = []
        for sh_exp in shared_experts:
            # Shared experts usually don't use gating prompts in MergeKit implementations
            # (DeepSeek forbids them, Qwen2 requires them if not random)
            # We add a basic entry here; users might need advanced YAML editing for complex shared gating.
            config["shared_experts"].append({"source_model": sh_exp})
    return config
def build_raw_config(method, models, base_model, dtype, weights):
    """
    Build a configuration dict for the raw PyTorch merge path.

    Args:
        method: Merge method name (lowercased).
        models: List of checkpoint identifiers to merge.
        base_model: Optional base model; omitted from the config if falsy.
        dtype: dtype name for the merge.
        weights: Comma-separated per-model weights string, or falsy to skip.

    Returns:
        dict: A config suitable for RawPyTorchMergeConfig.model_validate().
    """
    config = {
        "merge_method": method.lower(),
        "dtype": dtype,
        "models": [],
    }
    if base_model:
        config["base_model"] = base_model

    # Parse comma-separated weights; malformed input falls back to 1.0 each.
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(',')]
        except ValueError:
            pass

    for i, m in enumerate(models):
        weight = w_list[i] if i < len(w_list) else 1.0
        config["models"].append({"model": m, "parameters": {"weight": weight}})
    return config