# MergeKit merge / MoE helper utilities (Hugging Face Space backend).
import os
import yaml
import gc
import torch
import shutil
import subprocess
import sys
from pathlib import Path

# --- CRITICAL PATCH: MUST RUN BEFORE MERGEKIT IMPORTS ---
# MergeKit's pydantic models carry non-pydantic types (e.g. torch objects);
# globally allowing arbitrary types keeps its schema construction from failing.
import pydantic
from pydantic import ConfigDict, BaseModel

BaseModel.model_config = ConfigDict(arbitrary_types_allowed=True)

try:
    from mergekit.config import MergeConfiguration
    from mergekit.merge import run_merge
except ImportError:
    # Best-effort: keep this module importable so the app can surface the error.
    # NOTE: MergeConfiguration / run_merge remain undefined on this path, so
    # standard merges will raise NameError until mergekit is installed.
    print("MergeKit not found. Please install 'mergekit' in requirements.txt")
def execute_mergekit_config(config_dict, out_path, shard_gb, device="cpu"):
    """
    Execute a MergeKit run, branching on the shape of the config:

    1. MoE construction (config contains an 'experts' key) -> 'mergekit-moe' CLI.
    2. Standard graph merges (TIES, SLERP, etc.)           -> internal run_merge.

    Args:
        config_dict: MergeKit configuration as a plain dict.
        out_path: Output directory for the merged model.
        shard_gb: Maximum shard size in gigabytes (converted to bytes below).
        device: Torch device string for the standard-merge path.

    Raises:
        RuntimeError: If the mergekit-moe subprocess exits non-zero.
        ValueError: If a standard-merge config fails schema validation.
    """
    # Render the dict to YAML once, for both logging and the MoE config file.
    config_yaml = yaml.dump(config_dict, sort_keys=False)
    print("--- Generated MergeKit Config ---")
    print(config_yaml)
    print("---------------------------------")

    if "experts" in config_dict:
        # --- BRANCH 1: MIXTURE OF EXPERTS (MoE) ---
        print("🚀 Detected MoE Configuration. Switching to 'mergekit-moe' pipeline...")

        # The CLI requires a config file on disk, so write the YAML out first.
        config_path = Path(out_path).parent / "moe_config.yaml"
        config_path.parent.mkdir(parents=True, exist_ok=True)
        config_path.write_text(config_yaml)

        # NOTE(review): this invokes the 'mergekit-moe' console script from
        # PATH. To guarantee the current interpreter's environment is used it
        # could be launched via `sys.executable -m ...` instead — confirm which
        # entry point the installed mergekit version exposes.
        cmd = [
            "mergekit-moe",
            str(config_path),
            str(out_path),
            "--shard-size", f"{int(shard_gb * 1024**3)}",  # bytes
            "--copy-tokenizer",
            "--trust-remote-code",
        ]
        try:
            subprocess.run(cmd, check=True)
            print("✅ MoE Construction Complete.")
        except subprocess.CalledProcessError as e:
            raise RuntimeError(
                f"MoE Build Failed with exit code {e.returncode}. Check logs."
            ) from e
        finally:
            # Always remove the temp config, even when the build fails.
            if config_path.exists():
                config_path.unlink()
    else:
        # --- BRANCH 2: STANDARD MERGE (TIES, SLERP, ETC.) ---
        print("⚡ Detected Standard Merge Configuration. Using internal engine...")

        # Validate against the strict schema (MergeConfiguration); this
        # correctly rejects configs missing 'merge_method', which standard
        # merges require. Validating config_dict directly avoids a redundant
        # YAML round-trip.
        try:
            conf = MergeConfiguration.model_validate(config_dict)
        except pydantic.ValidationError as e:
            raise ValueError(
                f"Invalid Merge Configuration: {e}\n(Did you forget 'merge_method'?)"
            ) from e

        run_merge(
            conf,
            out_path=out_path,
            device=device,
            low_cpu_mem=True,
            copy_tokenizer=True,
            lazy_unpickle=True,
            max_shard_size=int(shard_gb * 1024**3),
        )
        # Force cleanup of any large tensors held by the merge graph.
        gc.collect()
def build_full_merge_config(
    method, models, base_model, weights, density,
    dtype, tokenizer_source, layer_ranges
):
    """
    Construct the MergeKit config dict for general merges (Linear, SLERP, TIES...).

    Args:
        method: Merge method name (case-insensitive).
        models: List of model identifiers to merge.
        base_model: Base model id; falls back to models[0] when falsy.
        weights: Comma-separated per-model weights, e.g. "0.5, 0.5" (may be falsy).
        density: Density parameter, applied only to TIES/DARE-family methods.
        dtype: Output dtype string (e.g. "float16").
        tokenizer_source: Tokenizer source selector.
        layer_ranges: Optional raw YAML/JSON string merged over the config
            (e.g. a 'slices' definition) — top-level keys override.

    Returns:
        dict ready to be dumped to YAML for MergeKit.
    """
    method_key = method.lower()
    config = {
        "merge_method": method_key,
        "base_model": base_model if base_model else models[0],
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "models": [],
    }

    # Parse weights defensively; malformed input falls back to weight 1.0.
    w_list = []
    if weights:
        try:
            w_list = [float(x.strip()) for x in weights.split(",")]
        except (ValueError, AttributeError):
            print("Warning: Could not parse weights, defaulting to 1.0")

    # Per-model parameter injection (branches consolidated: all supported
    # methods take a weight; only the TIES/DARE family also takes density).
    weighted = {"ties", "dare_ties", "dare_linear", "slerp", "linear"}
    density_methods = {"ties", "dare_ties", "dare_linear"}
    for i, m in enumerate(models):
        entry = {"model": m, "parameters": {}}
        if method_key in weighted:
            entry["parameters"]["weight"] = w_list[i] if i < len(w_list) else 1.0
        if method_key in density_methods:
            entry["parameters"]["density"] = density
        config["models"].append(entry)

    # Raw override: lets callers inject 'slices'/layer ranges verbatim.
    if layer_ranges and layer_ranges.strip():
        try:
            extra_params = yaml.safe_load(layer_ranges)
            if isinstance(extra_params, dict):
                config.update(extra_params)
        except Exception as e:
            print(f"Error parsing layer ranges JSON: {e}")

    return config
def build_moe_config(
    base_model, experts, gate_mode, dtype,
    tokenizer_source, positive_prompts=None
):
    """
    Construct the YAML dictionary for Mixture of Experts (MoE).

    Note: 'merge_method' is deliberately absent — MoE configs do not use
    that field in the standard MergeKit schema.

    Args:
        base_model: Base model the experts are grafted onto.
        experts: List of expert model identifiers.
        gate_mode: Router gate initialization mode.
        dtype: Output dtype string (e.g. "float16").
        tokenizer_source: Tokenizer source selector.
        positive_prompts: Optional per-expert prompts, aligned with 'experts'.
            Each entry may be a single string or a list of strings. Missing or
            absent entries fall back to the "expert_<i>" placeholder (the
            previous behavior).

    Returns:
        dict ready to be dumped to YAML for 'mergekit-moe'.
    """
    config = {
        "base_model": base_model,
        "gate_mode": gate_mode,
        "dtype": dtype,
        "tokenizer_source": tokenizer_source,
        "experts": [],
    }

    for i, exp in enumerate(experts):
        # BUG FIX: 'positive_prompts' was accepted but never used; honor it
        # when provided, keeping the placeholder as a backward-compatible
        # fallback.
        if positive_prompts is not None and i < len(positive_prompts):
            prompts = positive_prompts[i]
            if isinstance(prompts, str):
                prompts = [prompts]
        else:
            prompts = [f"expert_{i}"]
        config["experts"].append({
            "source_model": exp,
            "positive_prompts": prompts,
        })

    return config