# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1

import logging
import os
from typing import Dict, Optional, Set, Tuple, Union

import click
import yaml

from mergekit.common import ImmutableMap, ModelReference
from mergekit.config import MergeConfiguration
from mergekit.graph import Executor, Task
from mergekit.merge import run_merge
from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options

logger = logging.getLogger("multimerge")

# A directory counts as an already-written model when it holds `config.json`
# plus at least one of these weight/index files (see MergeModelTask.execute).
MODEL_CHECK_FILENAMES = [
    "model.safetensors",
    "pytorch_model.bin",
    "model.safetensors.index.json",
    "pytorch_model.bin.index.json",
]


class MergeModelTask(Task[str]):
    """Task that runs one merge recipe and returns the output directory path."""

    # YAML text of the merge configuration; re-parsed in execute().
    config_yaml: str
    # Human-readable label used only in log messages.
    name: str
    # Merges this one depends on, keyed by merge name.
    input_merges: ImmutableMap[str, "MergeModelTask"]
    # Options forwarded unchanged to run_merge().
    options: MergeOptions
    # Directory the merged model is written to.
    out_path: str
    # When true, skip the merge if a model already exists at out_path.
    lazy: bool = True

    def arguments(self):
        """Return the upstream merge tasks keyed by their (stringified) name."""
        return {str(key): self.input_merges[key] for key in self.input_merges}

    def execute(self, **kwargs):
        """Run the merge (or skip it when lazy and already present).

        The keyword arguments are not used here; presumably they carry the
        results of the tasks returned by ``arguments()`` and exist only to
        express ordering -- confirm against the ``Task`` contract.

        Returns:
            The output directory path (``self.out_path``).
        """
        if (
            self.lazy
            and os.path.exists(os.path.join(self.out_path, "config.json"))
            and any(
                os.path.exists(os.path.join(self.out_path, filename))
                for filename in MODEL_CHECK_FILENAMES
            )
        ):
            logger.info(f"Model already exists at {self.out_path}, skipping")
            return self.out_path

        logger.info(f"Running merge for {self.name}")
        cfg = MergeConfiguration.model_validate(yaml.safe_load(self.config_yaml))
        run_merge(
            cfg,
            self.out_path,
            options=self.options,
        )
        logger.info(f"Merge complete for {self.name}")
        return self.out_path


@click.command("mergekit-multimerge", cls=PrettyPrintHelp)
@click.argument("config_file", type=click.Path(exists=True))
@click.option(
    "--out-path",
    type=click.Path(),
    required=False,  # validated later
    help="Path to save the final merged model",
)
@click.option(
    "--intermediate-dir",
    "-I",
    type=click.Path(),
    required=True,
    help="Directory to store intermediate merges",
)
@click.option(
    "--lazy/--no-lazy",
    default=True,
    help="Skip merges that already exist",
)
@add_merge_options
def main(
    config_file: str,
    intermediate_dir: str,
    out_path: Optional[str],
    lazy: bool,
    merge_options: MergeOptions,
):
    """Execute a set of potentially interdependent merge recipes.

    The configuration file should be a YAML file containing multiple
    documents, each of which is a merge configuration with the addition
    of a `name` field. The `intermediate_dir` is used to store intermediate
    merge results. Any merge configuration with a `name` field will be
    saved to this directory. If an unnamed merge configuration is present,
    it will be saved to `out_path` (which is required in this case)."""
    # NOTE: the docstring above doubles as the command's --help text.
    merge_options.apply_global_options()

    os.makedirs(intermediate_dir, exist_ok=True)

    with open(config_file, "r", encoding="utf-8") as file:
        config_source = file.read()

    merge_configs, dependencies = load_config(config_source, intermediate_dir)

    # Validate out_path requirement
    if None in merge_configs and not out_path:
        raise click.UsageError(
            "--out-path is required when configuration contains an unnamed final merge"
        )

    tasks = make_tasks(
        merge_configs, dependencies, merge_options, intermediate_dir, out_path, lazy
    )

    executor = Executor(
        tasks, math_device="cpu", storage_device="cpu"
    )  # inner executors will handle cuda
    executor.execute(desc="Merging models")


def patched_config(
    config: MergeConfiguration, merge_names: Set[str], working_dir: str
) -> Tuple[MergeConfiguration, Set[str]]:
    """Replace instances of intermediate merge names with actual paths.

    Also returns the set of intermediate merge names that were used.

    Args:
        config: The configuration to patch
        merge_names: The set of all merge names
        working_dir: The directory to use as the base for relative paths

    Returns:
        A tuple of the re-validated, patched configuration and the set of
        merge names that the configuration referenced.
    """
    used = set()

    def _patch_mr(value: Union[dict, list, str, int, None]):
        # Recursively walk the dumped config, rewriting model references
        # whose path is the name of another merge.
        if isinstance(value, list):
            return [_patch_mr(x) for x in value]
        elif isinstance(value, dict):
            if set(value.keys()) == {"model", "lora", "override_architecture"}:
                # is a ModelReference
                base = value["model"]["path"]
                if base in merge_names:
                    value["model"] = value["model"].copy()
                    value["model"]["path"] = os.path.join(working_dir, base)
                    used.add(base)
                return value
            return {k: _patch_mr(v) for k, v in value.items()}
        elif isinstance(value, str):
            try:
                mr = ModelReference.model_validate(value)
                if mr.model.path in merge_names:
                    used.add(mr.model.path)
                    return ModelReference(
                        model={
                            "path": os.path.join(working_dir, mr.model.path),
                            "revision": mr.model.revision,
                        },
                        lora=mr.lora,
                        override_architecture=mr.override_architecture,
                    ).model_dump()
            except ValueError:
                # Not every string is a model reference; leave it unchanged.
                pass
        return value

    new_dict = _patch_mr(config.model_dump())
    return MergeConfiguration.model_validate(new_dict), used


def make_tasks(
    merge_configs: Dict[Optional[str], MergeConfiguration],
    dependencies: Dict[Optional[str], Set[str]],
    merge_options: MergeOptions,
    intermediate_dir: str,
    out_path: Optional[str],
    lazy: bool,
):
    """Build the task dependency graph for the merge recipes.

    Args:
        merge_configs: Merge configurations keyed by name; the unnamed
            (final) merge, if any, is keyed by ``None``
        dependencies: For each merge, the names of the merges it uses
        merge_options: Options passed to every merge task
        intermediate_dir: Directory under which named merges are written
        out_path: Output directory for the unnamed merge, if any
        lazy: Whether tasks may skip merges whose output already exists

    Returns:
        A list with one MergeModelTask per merge to run, in the iteration
        order of ``merge_configs``.

    Raises:
        ValueError: If the merges depend on each other circularly.
    """
    # Names whose construction has started; a name that is here but not yet
    # in `built` is still being constructed further up the call stack.
    started = set()
    # Finished tasks, memoised by merge name.
    built = {}

    def _make_task(name: Optional[str]):
        if name in built:
            return built[name]
        if name in started:
            raise ValueError(f"Circular dependency detected involving {name}")
        started.add(name)

        if name is None:
            # out_path validation happens earlier in main()
            merge_out_path = out_path
        else:
            merge_out_path = os.path.join(intermediate_dir, name)

        built[name] = MergeModelTask(
            config_yaml=merge_configs[name].to_yaml(),
            name=name or "final merge",
            input_merges=ImmutableMap(
                {dep: _make_task(dep) for dep in dependencies[name]}
            ),
            options=merge_options,
            out_path=merge_out_path,
            lazy=lazy,
        )
        return built[name]

    # Only create tasks that exist in the config (allow missing None)
    tasks_to_create = [
        name for name in merge_configs if name is not None or out_path
    ]
    return [_make_task(name) for name in tasks_to_create]


def load_config(
    config_source: str, intermediate_dir: str
) -> Tuple[Dict[Optional[str], MergeConfiguration], Dict[Optional[str], Set[str]]]:
    """Load the merge configurations from the YAML source.

    Empty YAML documents (for example those produced by a stray or trailing
    ``---`` separator) are ignored.

    Args:
        config_source: The YAML source to load
        intermediate_dir: The directory to use for intermediate merges

    Returns:
        A tuple containing:
        - A dictionary of merge configurations keyed by name (``None`` for
          the unnamed merge)
        - A dictionary of dependencies keyed by name

    Raises:
        ValueError: If a document is not a mapping, a merge name is
            duplicated, or more than one unnamed merge is present.
    """
    docs = list(yaml.safe_load_all(config_source))

    merge_configs: Dict[Optional[str], MergeConfiguration] = {}
    for doc in docs:
        if doc is None:
            # Empty document: nothing to merge, and `"name" in None` would
            # otherwise raise a TypeError.
            continue
        if not isinstance(doc, dict):
            raise ValueError(
                "Each document in the configuration must be a mapping, "
                f"got {type(doc).__name__}"
            )

        merge_name = doc.pop("name", None)
        if merge_name in merge_configs:
            if merge_name is not None:
                raise ValueError(f"Duplicate merge name {merge_name}")
            raise ValueError(
                "Multiple unnamed merge configurations are not supported"
            )
        merge_configs[merge_name] = MergeConfiguration.model_validate(doc)

    merge_names = set(merge_configs.keys())
    dependencies: Dict[Optional[str], Set[str]] = {}
    # Iterate over a snapshot of the keys in document order so the result is
    # deterministic; only values are replaced, never keys.
    for merge_name in list(merge_configs):
        merge_config, used_names = patched_config(
            merge_configs[merge_name], merge_names, intermediate_dir
        )
        merge_configs[merge_name] = merge_config
        dependencies[merge_name] = used_names
    return merge_configs, dependencies


if __name__ == "__main__":
    main()