| | |
| | """ |
| | Merges multiple models and their dependencies into a single model |
| | using multiple merge yaml documents in a single yaml file as the input |
| | """ |
| |
|
| | import logging |
| | import os |
| | import sys |
| | from pathlib import Path |
| |
|
| | import click |
| | import yaml |
| |
|
| | from mergekit.config import MergeConfiguration |
| | from mergekit.merge import MergeOptions, run_merge |
| | from mergekit.options import add_merge_options |
| |
|
# Registry of pending merge documents keyed by merge name. main() fills it
# from the YAML input; merge() deletes each entry once it has been run or
# skipped because its output already exists.
merges = dict()
| |
|
| |
|
def has_circular_dependency(nodes):
    """
    Detects circular dependencies between merges using depth-first search.

    Params:
        nodes: mapping of merge name -> merge document; each document must
            have a "deps" list naming the merges it depends on. Dependency
            names that are not keys of `nodes` (model references that are not
            merges defined in this graph) are treated as leaves: they have no
            outgoing edges here, so they cannot be part of a cycle.

    Returns:
        The name of the node whose DFS traversal found a cycle, or None if
        the dependency graph is acyclic.
    """

    def dfs(node, visited, stack):
        """
        Returns True if a cycle is reachable from `node`.

        `stack` marks the nodes on the current DFS path; reaching one of them
        again means a back edge was followed, i.e. a cycle exists.
        """
        visited[node] = True
        stack[node] = True

        for dependency in nodes[node]["deps"]:
            if dependency not in nodes:
                # not a merge from this graph: nothing to traverse
                continue
            if not visited[dependency]:
                if dfs(dependency, visited, stack):
                    return True
            elif stack[dependency]:
                return True

        stack[node] = False
        return False

    visited = {key: False for key in nodes}
    stack = {key: False for key in nodes}

    for node in nodes:
        if not visited[node] and dfs(node, visited, stack):
            return node

    return None
| |
|
| |
|
def merge(m: str, merge_options: MergeOptions, force: bool, out_path: Path):
    """
    Run the merge named ``m`` after first running its pending dependencies.

    Operates on the module-level ``merges`` registry: every merge that is run,
    or skipped because its output already exists, is removed from it.

    Params:
        m: name of the merge to run (a key of ``merges``)
        merge_options: MergeOptions forwarded to run_merge
        force: re-run the merge even if its output directory already exists
        out_path: directory under which each merge writes its output
    """
    already_done = os.path.exists(out_path / m)
    if already_done and not force:
        logging.info("Skipping %s as it already exists", m)
        del merges[m]
        return
    if already_done:
        logging.info("Overwriting %s as --force was specified", m)

    # Build dependencies first. Membership is re-checked for each one because
    # an earlier recursive call may already have run (and removed) a shared
    # dependency; names never registered as merges are skipped as well.
    for dep in merges[m]["deps"]:
        if dep not in merges:
            continue
        merge(dep, merge_options, force, out_path)

    logging.info("Merging model %s", m)
    spec = merges[m]
    merge_config: MergeConfiguration = MergeConfiguration.model_validate(spec)
    destination = str(out_path / spec["name"])
    run_merge(merge_config, destination, options=merge_options)
    del merges[m]
| |
|
| |
|
def add_model_deps(model: str, name: str, out_path: Path):
    """
    Resolves a model reference used by merge `name`, recording it as a
    dependency when it refers to another merge from the config file.

    Merge names may not contain a slash, so a reference whose model part
    (the text before the first "+") has no "/" is taken to be the name of a
    merge. Such a name is added once to ``merges[name]["deps"]``, and the
    reference is rewritten to point at that merge's output directory under
    `out_path`. Everything from the first "+" onward (the LoRA part) is
    preserved verbatim. References containing a slash are returned unchanged.

    Params:
        model: model reference, "model" or "model+lora"
        name: name of the merge whose dependency list is updated
        out_path: output directory containing the merge results

    Returns:
        The (possibly rewritten) model reference.
    """
    base, plus, suffix = model.partition("+")
    if "/" in base:
        return model

    deps = merges[name]["deps"]
    # avoid duplicate deps
    if base not in deps:
        deps.append(base)
    # keep the whole suffix so that references with extra "+" parts are
    # passed on intact instead of being silently truncated
    return str(out_path / base) + plus + suffix
| |
|
| |
|
@click.command("mergekit-mega")
@click.argument("config_file")
@click.argument("out_path")
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--force",
    "-f",
    type=bool,
    default=False,
    is_flag=True,
    help="Overwrite existing merge results instead of skipping them",
)
@click.option(
    "--require-nameless",
    "-R",
    type=bool,
    default=False,
    is_flag=True,
    help="Enforces exactly one unnamed merge in the YAML, which will inherit the input file's name.",
)
@add_merge_options
def main(
    merge_options: MergeOptions,
    config_file: str,
    out_path: str,
    force: bool,
    verbose: bool,
    require_nameless: bool,
):
    """
    Main entrypoint for the mergekit-mega command; see the module docstring.

    Reads every YAML document in CONFIG_FILE as a separate merge
    configuration and registers each one in `merges` under its name. It
    records which merges reference which, rejects circular references, and
    then runs all merges (dependencies first) into subdirectories of
    OUT_PATH. Exits with status 1 on invalid input.

    Params are supplied by click decorators.
    """
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    out_path = Path(out_path)
    final_found = False

    with open(config_file, "r", encoding="utf-8") as f:
        # NOTE(review): plain merge configs should only need yaml.SafeLoader;
        # consider yaml.safe_load_all if configs may come from untrusted sources.
        data = yaml.load_all(f, Loader=yaml.FullLoader)

        for d in data:
            if d is None:
                # empty document, e.g. produced by a trailing "---" separator
                continue
            if not isinstance(d, dict):
                logging.error("Each YAML document must be a mapping")
                sys.exit(1)

            if "name" not in d:
                if final_found:
                    logging.error("Only one merge must not have a name")
                    sys.exit(1)
                # the unnamed merge inherits the config file's name (sans extension)
                d["name"] = os.path.basename(config_file).rsplit(".", maxsplit=1)[0]
                final_found = True

            name = d["name"]
            if not isinstance(name, str):
                logging.error("name must be a string")
                sys.exit(1)
            # merge names become directory names under out_path, so they
            # must not be able to point anywhere else
            if "/" in name:
                logging.error("name must not contain a slash")
                sys.exit(1)
            if name in ("", ".", ".."):
                logging.error("name must not be empty, '.' or '..'")
                sys.exit(1)
            # a repeated name would silently replace the earlier merge
            if name in merges:
                logging.error("Duplicate merge name: %s", name)
                sys.exit(1)

            merges[name] = d
            d["deps"] = []
            if "base_model" in d:
                d["base_model"] = add_model_deps(d["base_model"], name, out_path)
            if "slices" in d:
                for slc in d["slices"]:
                    for src in slc["sources"]:
                        src["model"] = add_model_deps(src["model"], name, out_path)
            if "models" in d:
                for mdl in d["models"]:
                    mdl["model"] = add_model_deps(mdl["model"], name, out_path)

    if require_nameless and not final_found:
        logging.error("No final merge found")
        sys.exit(1)

    logging.info("Merging: %s", ", ".join(merges))

    if (dep := has_circular_dependency(merges)) is not None:
        logging.error("Circular dependency detected: %s", dep)
        sys.exit(1)

    # merge() removes the merge it runs (and any dependencies it runs first),
    # so the registry shrinks on every iteration
    while merges:
        merge(next(iter(merges)), merge_options, force, out_path)


if __name__ == "__main__":
    main()
| |
|