# Copyright (C) 2025 Arcee AI
# SPDX-License-Identifier: BUSL-1.1
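
# mergekit-moe builds a Mixture of Experts model from several pretrained
# models, driven by a YAML merge configuration (an illustrative example is
# sketched below the imports). Typical invocation, with placeholder paths:
#
#   mergekit-moe ./moe-config.yml ./output-model --device cuda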

import logging
import os
import sys
from typing import List

import click
import transformers
import yaml

from mergekit.merge import MergeOptions
from mergekit.moe import ALL_OUTPUT_ARCHITECTURES, MoEOutputArchitecture
from mergekit.moe.config import MoEMergeConfig, is_bad_config
from mergekit.moe.router import get_gate_params, warn_degenerate_gates
from mergekit.options import PrettyPrintHelp, add_merge_options
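
# Illustrative shape of the YAML config this script consumes; model names and
# prompts below are placeholders, and MoEMergeConfig is the authoritative schema:
#
#   base_model: org/base-model-name
#   gate_mode: hidden            # hidden, cheap_embed, random, or uniform_random
#   dtype: bfloat16
#   experts:
#     - source_model: org/expert-model-name
#       positive_prompts:
#         - "example prompt that should route to this expert"
#   shared_experts:              # optional
#     - source_model: org/shared-expert-model-name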


def build(
    config: MoEMergeConfig,
    out_path: str,
    merge_options: MergeOptions,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    device: str = "auto",
    allow_all_same: bool = False,
    verbose: bool = False,
):
    """Build a Mixture of Experts model from the given merge configuration.

    Computes router (gate) parameters for each expert, then writes the
    combined model (and optionally the tokenizer) to ``out_path``.
    """
    if is_bad_config(config, allow_all_same=allow_all_same):
        sys.exit(1)

    base_model = config.base_model
    out_arch = select_output_arch(config, merge_options, verbose=verbose)

    tokenizer = transformers.AutoTokenizer.from_pretrained(
        base_model.model.path, revision=base_model.model.revision
    )
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.bos_token_id
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token

    logging.info("Getting gate parameters...")
    need_gates = list(config.experts)
    if config.shared_experts:
        has_prompts = any(e.positive_prompts for e in config.shared_experts)
        assert all(
            bool(e.positive_prompts) == has_prompts for e in config.shared_experts
        ), "Must specify prompts for all shared experts or none, not a mix"
        if has_prompts or config.gate_mode in ("random", "uniform_random"):
            need_gates.extend(config.shared_experts)

    gate_vecs = get_gate_params(
        base_model,
        tokenizer,
        need_gates,
        mode=config.gate_mode,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        lazy_unpickle=merge_options.lazy_unpickle,
        trust_remote_code=merge_options.trust_remote_code,
        device=device,
    )
    # gate_vecs: (num_layers, num_experts, hidden_size)

    # Split the gate parameters between the routed experts and any shared experts.
    router_weights = gate_vecs[:, : len(config.experts), :]
    shared_router_weights = gate_vecs[:, len(config.experts) :, :]
    warn_degenerate_gates(gate_vecs)

    out_arch.write_model(
        out_path,
        config,
        merge_options,
        router_weights=[router_weights[i, ...] for i in range(router_weights.shape[0])],
        shared_router_weights=[
            shared_router_weights[i, ...] for i in range(router_weights.shape[0])
        ],
    )
    if merge_options.copy_tokenizer:
        logging.info("Saving tokenizer...")
        tokenizer.save_pretrained(out_path, safe_serialization=True)

    logging.info("Done.")


def select_output_arch(
    config: MoEMergeConfig,
    merge_options: MergeOptions,
    verbose: bool = False,
) -> MoEOutputArchitecture:
    """Choose the output architecture used to write the merged MoE model."""
    candidates_in = ALL_OUTPUT_ARCHITECTURES
    if config.architecture:
        candidates_in = [
            a
            for a in candidates_in
            if a.name().lower().startswith(config.architecture.lower())
        ]
        if not candidates_in:
            logging.error(
                f"No output architecture found that matches the given architecture: {config.architecture}"
            )
            logging.error("All supported output architectures:")
            for arch in ALL_OUTPUT_ARCHITECTURES:
                logging.error(f" * {arch.name()}")
            sys.exit(1)

    candidates: List[MoEOutputArchitecture] = []
    for arch in candidates_in:
        if arch.supports_config(
            config, explain=verbose, trust_remote_code=merge_options.trust_remote_code
        ):
            candidates.append(arch)
        else:
            logging.info(f"Output architecture {arch.name()} does not support config.")

    if not candidates:
        logging.error(
            "No output architecture found that is compatible with the given models."
        )
        logging.error("All supported output architectures:")
        for arch in ALL_OUTPUT_ARCHITECTURES:
            logging.error(f" * {arch.name()}")
        sys.exit(1)

    # for compatibility with older configs, default to Mixtral if available
    for arch in candidates:
        if arch.name() == "Mixtral":
            return arch

    if len(candidates) > 1:
        logging.warning(
            "Multiple output architectures found that are compatible with the given models."
        )
        logging.warning(f"Defaulting to {candidates[0].name()}")
    else:
        logging.info(f"Selected output architecture: {candidates[0].name()}")
    return candidates[0]


@click.command("mergekit-moe", cls=PrettyPrintHelp)
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=click.Path())
@click.option(
    "--load-in-4bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 4bit for computing hidden states",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 8bit for computing hidden states",
)
@click.option(
    "--device",
    type=str,
    default="auto",
    help="Device to use to compute embeddings",
    show_default=True,
)
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--i-understand-this-is-not-useful-without-training",
    type=bool,
    default=False,
    is_flag=True,
    help="Really make the questionable model you want.",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    load_in_4bit: bool,
    load_in_8bit: bool,
    device: str,
    verbose: bool,
    i_understand_this_is_not_useful_without_training: bool,
    merge_options: MergeOptions,
):
    """Create a Mixture of Experts model by combining the pretrained weights of multiple models."""
    merge_options.apply_global_options()
    if merge_options.cuda:
        logging.warning(
            '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
        )

    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = MoEMergeConfig.model_validate(yaml.safe_load(config_source))
    build(
        config,
        out_path=out_path,
        merge_options=merge_options,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        device=device,
        allow_all_same=i_understand_this_is_not_useful_without_training,
        verbose=verbose,
    )

    if merge_options.write_model_card:
        # TODO: generate a README.md as well
        with open(
            os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
        ) as fp:
            fp.write(config_source)


if __name__ == "__main__":
    main()