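"""mergekit-moe: assemble a Mixtral-style sparse MoE model from a base model and
a list of expert models, driven by a YAML config (schema: MistralMOEConfig below).

Per-layer router ("gate") weights are derived from prompt hidden states, cheap
token-embedding averages, or random initialization, depending on `gate_mode`.

Illustrative config (model names and prompts are placeholders, not from the
source):

    base_model: mistralai/Mistral-7B-v0.1
    gate_mode: hidden
    dtype: bfloat16
    experts:
      - source_model: example-org/math-expert-7b
        positive_prompts:
          - "Solve the following equation"
      - source_model: example-org/code-expert-7b
        positive_prompts:
          - "Write a Python function"
"""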
import logging
import os
import sys
from typing import Dict, List, Optional, Union

import click
import torch
import tqdm
import transformers
import yaml
from pydantic import BaseModel
from transformers import (
    AutoModelForCausalLM,
    LlamaForCausalLM,
    MistralConfig,
    MistralForCausalLM,
    MixtralConfig,
)
from transformers.modeling_outputs import CausalLMOutputWithPast

import mergekit.architecture
from mergekit.common import ModelReference, dtype_from_name
from mergekit.io import LazyTensorLoader, TensorWriter
from mergekit.merge import MergeOptions
from mergekit.options import add_merge_options
class Expert(BaseModel):
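    """One expert: a source model plus the prompts used to derive its routing weights."""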
    source_model: str

    positive_prompts: List[str]
    negative_prompts: Optional[List[str]] = None
    noise_scale: Optional[float] = None

    @property
    def model_ref(self):
        return ModelReference.parse(self.source_model)
class MistralMOEConfig(BaseModel):
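    """Top-level YAML config schema.

    gate_mode selects how the router weights are computed: "hidden" (default,
    prompt hidden states from the base model), "cheap_embed" (token-embedding
    averages), or "random"; see get_gate_params below.
    """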
    base_model: str
    experts: List[Expert]
    gate_mode: str = "hidden"
    dtype: Optional[str] = None
    experts_per_token: int = 2
def get_hidden_states(
    model: Union[MistralForCausalLM, LlamaForCausalLM],
    tokenized: transformers.BatchEncoding,
    average: bool = True,
) -> List[torch.Tensor]:
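    """Run the prompts through `model` and pool hidden states per layer.

    Returns a (num_layers, hidden_size) tensor: one routing vector per layer,
    averaged over the batch, and over sequence positions when `average=True`
    (otherwise the last position is used).
    """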
    with torch.no_grad():
        output: CausalLMOutputWithPast = model(
            **tokenized.to(model.device), output_hidden_states=True, return_dict=True
        )
        hidden_states = torch.stack(
            output.hidden_states[:-1]
        )  # (num_layers, batch_size, seq_len, hidden_size)
        if average:
            # mean over sequence positions
            hidden_states = hidden_states.sum(dim=2) / hidden_states.shape[2]
        else:
            # hidden state at the final position only
            hidden_states = hidden_states[:, :, -1, :]
        return hidden_states.sum(dim=1) / hidden_states.shape[1]
def get_cheap_embedding(
    embed: torch.Tensor,
    tokenized: Dict[str, torch.Tensor],
    num_layers: int,
    vocab_size: int,
) -> torch.Tensor:
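    """Approximate a routing vector from raw token embeddings (no forward pass).

    Mask-weighted sum of the prompts' token embeddings, L2-normalized, then
    repeated so every layer gets the same (num_layers, hidden_size) vector.
    """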
    onehot = torch.nn.functional.one_hot(
        tokenized["input_ids"], num_classes=vocab_size
    )  # (batch_size, seq_len, vocab_size)
    h = onehot.float() @ embed.float()  # (batch_size, seq_len, hidden_size)
    embedded = (
        (h * tokenized["attention_mask"].unsqueeze(-1))
        .sum(dim=1)
        .sum(dim=0, keepdim=True)
    )  # (1, hidden_size)
    res = embedded / embedded.norm(dim=-1, keepdim=True).clamp(
        min=1e-8
    )  # (1, hidden_size)
    return res.repeat(num_layers, 1)
def tokenize_prompts(
    prompts: List[str], tokenizer: transformers.PreTrainedTokenizerBase
):
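    """Batch-tokenize prompts with an explicit BOS token and padding."""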
    return tokenizer(
        [(tokenizer.bos_token or "") + p for p in prompts],
        return_tensors="pt",
        padding=True,
        add_special_tokens=False,
    )
def get_gate_params(
    model_ref: ModelReference,
    tokenizer: transformers.PreTrainedTokenizerBase,
    experts: List[Expert],
    mode: str = "hidden",
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    lazy_unpickle: bool = False,
    trust_remote_code: bool = False,
    device: str = "auto",
):
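    """Compute router weights for every layer and expert.

    Returns a (num_layers, num_experts, hidden_size) tensor. For the "hidden"
    modes the base model is loaded and each expert's positive (minus negative)
    prompt representations are normalized into a gate vector; "cheap_embed"
    uses only the embedding matrix; "random" skips the prompts entirely.
    """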
    gate_vecs = []
    _do_it = None

    model_cfg = model_ref.config(trust_remote_code=trust_remote_code)

    if mode == "random":
        return torch.randn(
            (model_cfg.num_hidden_layers, len(experts), model_cfg.hidden_size)
        )
    elif mode == "cheap_embed":
        embed = model_ref.lazy_loader(lazy_unpickle=lazy_unpickle).get_tensor(
            "model.embed_tokens.weight"
        )

        def _do_it(tokenized):
            return get_cheap_embedding(
                embed,
                tokenized,
                num_layers=model_cfg.num_hidden_layers,
                vocab_size=model_cfg.vocab_size,
            )

    elif mode in ("hidden", "hidden_avg", "hidden_last"):
        model = AutoModelForCausalLM.from_pretrained(
            model_ref.model.path,
            revision=model_ref.model.revision,
            torch_dtype=torch.bfloat16,
            device_map=device,
            low_cpu_mem_usage=True,
            load_in_4bit=load_in_4bit,
            load_in_8bit=load_in_8bit,
            trust_remote_code=trust_remote_code,
        )

        def _do_it(tokenized):
            return get_hidden_states(
                model, tokenized=tokenized, average=mode == "hidden_avg"
            )

    gate_vecs = []
    for expert in tqdm.tqdm(experts, desc="expert prompts"):
        hidden_states = _do_it(tokenize_prompts(expert.positive_prompts, tokenizer))
        if expert.negative_prompts:
            hidden_states -= _do_it(
                tokenize_prompts(expert.negative_prompts, tokenizer)
            )

        hidden_states /= hidden_states.norm(p=2, dim=-1, keepdim=True).clamp(min=1e-8)
        gate_vecs.append(hidden_states)
    gate_vecs = torch.stack(gate_vecs, dim=0)  # (num_experts, num_layers, hidden_size)
    return gate_vecs.permute(1, 0, 2)
def warn_degenerate_gates(gate_vecs: torch.Tensor, threshold: float = 5.0):
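    """Warn when a layer's expert gate vectors are nearly linearly dependent.

    Uses the condition number of each layer's (num_experts, hidden_size) gate
    matrix; values above `threshold` suggest the prompts were too similar to
    separate the experts.
    """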
    degen_indices = []
    num_layers, _num_experts, _hidden_size = gate_vecs.shape
    for idx in range(num_layers):
        c = torch.linalg.cond(gate_vecs[idx, :, :].float())
        if c > threshold:
            degen_indices.append(idx)

    if degen_indices:
        if len(degen_indices) == 1:
            layer_str = f"layer {degen_indices[0]}"
            verb = "has"
        elif len(degen_indices) == 2:
            layer_str = f"layers {' and '.join(map(str, degen_indices))}"
            verb = "have"
        elif len(degen_indices) >= num_layers:
            layer_str = "ALL layers"
            verb = "have"
        else:
            layer_str = (
                "layers "
                + ", ".join(map(str, degen_indices[:-1]))
                + ", and "
                + str(degen_indices[-1])
            )
            verb = "have"

        logging.warning(
            f"{layer_str} {verb} degenerate routing parameters "
            "- your prompts may be too similar."
        )
        logging.warning("One or more experts will be underutilized in your model.")
def is_bad_config(config: MistralMOEConfig, allow_all_same: bool = False) -> bool:
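    """Sanity-check the config; returns True (and logs why) if it should be rejected."""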
    if len(config.experts) < 2:
        logging.error("Must include at least two experts.")
        return True

    if config.gate_mode == "random":
        return False

    def prompt_tup(e: Expert):
        return (tuple(e.positive_prompts), tuple(e.negative_prompts or []))

    # identical prompts for every expert make the routing vectors indistinguishable
    p_first = prompt_tup(config.experts[0])
    if all(prompt_tup(e) == p_first for e in config.experts[1:]):
        logging.error(
            "Your positive and negative prompts are identical for all experts. "
            "This will not produce a functioning MoE."
        )
        logging.error(
            "For each expert, `positive_prompts` must contain one or more example "
            "prompts reflecting what should be routed to that expert."
        )
        return True

    if not allow_all_same:
        if all(
            e.source_model == config.experts[0].source_model for e in config.experts[1:]
        ):
            logging.error(
                "All of your expert models are the same. This will produce "
                "a model that uses more resources but gives the exact same output. "
                "If you plan to train the model after merging, proceed with the "
                "--i-understand-this-is-not-useful-without-training flag."
            )
            return True
    return False
def build(
    config: MistralMOEConfig,
    out_path: str,
    merge_options: MergeOptions,
    load_in_4bit: bool = False,
    load_in_8bit: bool = False,
    device: str = "auto",
    allow_all_same: bool = False,
):
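    """Assemble and write the MoE model: output config, copied weights, gates, tokenizer."""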
    if is_bad_config(config, allow_all_same=allow_all_same):
        sys.exit(1)

    if config.experts_per_token < 1:
        logging.error("Experts per token must be >= 1")
        sys.exit(1)
    if config.experts_per_token > len(config.experts):
        logging.error("Experts per token must be <= number of experts")
        sys.exit(1)

    base_model = ModelReference.parse(config.base_model)
    base_cfg = base_model.config(trust_remote_code=merge_options.trust_remote_code)
    if not isinstance(base_cfg, MistralConfig):
        base_cfg_mistral = MistralConfig(**base_cfg.to_dict())
        base_cfg_mistral.sliding_window = None
        base_cfg_mistral.max_position_embeddings = base_cfg.max_position_embeddings
        base_cfg = base_cfg_mistral

    out_cfg = MixtralConfig(**base_cfg.to_dict())
    out_cfg.architectures = ["MixtralForCausalLM"]
    out_cfg.num_local_experts = len(config.experts)
    out_cfg.num_experts_per_tok = config.experts_per_token
    out_cfg.sliding_window = None
    if config.dtype:
        out_cfg.torch_dtype = config.dtype
    out_cfg.save_pretrained(out_path)

    if (out_cfg.num_local_experts & (out_cfg.num_local_experts - 1)) != 0:
        logging.warning(
            f"Your model has {out_cfg.num_local_experts} experts, which is "
            "not a power of two. The model will not be usable in llama.cpp."
        )

    loaders: Dict[ModelReference, LazyTensorLoader] = {}
    for model in tqdm.tqdm(
        [base_model] + [e.model_ref for e in config.experts], desc="Warm up loaders"
    ):
        loaders[model] = model.lazy_loader(
            cache_dir=merge_options.transformers_cache,
            lazy_unpickle=merge_options.lazy_unpickle,
        )

    base_loader = loaders.get(base_model)
    writer = TensorWriter(
        out_path=out_path,
        max_shard_size=merge_options.out_shard_size,
        safe_serialization=merge_options.safe_serialization,
    )

    if config.dtype:
        out_dtype = dtype_from_name(config.dtype)
    elif base_cfg.torch_dtype:
        out_dtype = base_cfg.torch_dtype
        if isinstance(out_dtype, str):
            out_dtype = dtype_from_name(out_dtype)
    else:
        out_dtype = None

    logging.info("Copying parameters...")
    MISTRAL_INFO = mergekit.architecture.MISTRAL_INFO
    for weight_info in MISTRAL_INFO.pre_weights(base_cfg) + MISTRAL_INFO.post_weights(
        base_cfg
    ):
        tensor_name = weight_info.name
        tensor = base_loader.get_tensor(tensor_name, aliases=weight_info.aliases)
        if not out_dtype:
            # no dtype specified anywhere; fall back to the first tensor's dtype
            out_dtype = tensor.dtype
        writer.save_tensor(
            tensor_name, tensor.to(dtype=out_dtype), clone=merge_options.clone_tensors
        )

    for layer_idx in range(base_cfg.num_hidden_layers):
        for weight_info in MISTRAL_INFO.layer_weights(index=layer_idx, config=base_cfg):
            tensor_name = weight_info.name
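            # MLP weights are replicated per expert as Mixtral w1/w2/w3 tensors;
            # everything else (attention, norms) is copied from the base model.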
| | if ".mlp." in tensor_name: |
| | for moe_index, expert in enumerate(config.experts): |
| | expert_name = tensor_name.replace( |
| | ".mlp.gate_proj", f".block_sparse_moe.experts.{moe_index}.w1" |
| | ) |
| | expert_name = expert_name.replace( |
| | ".mlp.down_proj", f".block_sparse_moe.experts.{moe_index}.w2" |
| | ) |
| | expert_name = expert_name.replace( |
| | ".mlp.up_proj", f".block_sparse_moe.experts.{moe_index}.w3" |
| | ) |
| | expert_loader = loaders.get(expert.model_ref) |
| | tensor = expert_loader.get_tensor( |
| | tensor_name, aliases=weight_info.aliases |
| | ) |
| | if expert.noise_scale: |
| | tensor += torch.randn_like(tensor) * expert.noise_scale |
| | writer.save_tensor( |
| | expert_name, tensor.to(dtype=out_dtype), clone=True |
| | ) |
| | continue |
| | writer.save_tensor( |
| | tensor_name, |
| | base_loader.get_tensor(tensor_name, aliases=weight_info.aliases).to( |
| | dtype=out_dtype |
| | ), |
| | ) |
| |
|
    tokenizer = transformers.AutoTokenizer.from_pretrained(
        base_model.model.path, revision=base_model.model.revision
    )
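    # Left padding keeps the final position aligned with real tokens when the
    # prompts in a batch have different lengths (relevant for last-token gating).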
    tokenizer.padding_side = "left"
    tokenizer.pad_token_id = tokenizer.bos_token_id
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
| | logging.info("Getting gate parameters...") |
| | gate_vecs = get_gate_params( |
| | base_model, |
| | tokenizer, |
| | config.experts, |
| | mode=config.gate_mode, |
| | load_in_4bit=load_in_4bit, |
| | load_in_8bit=load_in_8bit, |
| | lazy_unpickle=merge_options.lazy_unpickle, |
| | trust_remote_code=merge_options.trust_remote_code, |
| | device=device, |
| | ) |
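    # gate_vecs: (num_layers, num_experts, hidden_size)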
    warn_degenerate_gates(gate_vecs)
    for layer_idx in range(base_cfg.num_hidden_layers):
        writer.save_tensor(
            f"model.layers.{layer_idx}.block_sparse_moe.gate.weight",
            gate_vecs[layer_idx, :, :].contiguous().to(dtype=out_dtype),
        )
    writer.finalize()

    if merge_options.copy_tokenizer:
        logging.info("Saving tokenizer...")
        tokenizer.save_pretrained(out_path, safe_serialization=True)

    logging.info("Done.")
@click.command("mergekit-moe")
@click.argument("config_path", type=click.Path(exists=True, dir_okay=False))
@click.argument("out_path", type=click.Path())
@click.option(
    "--load-in-4bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 4bit for computing hidden states",
)
@click.option(
    "--load-in-8bit",
    is_flag=True,
    type=bool,
    default=False,
    help="Load model in 8bit for computing hidden states",
)
@click.option(
    "--device",
    type=str,
    default="auto",
    help="Device to use to compute embeddings",
    show_default=True,
)
@click.option(
    "--verbose", "-v", type=bool, default=False, is_flag=True, help="Verbose logging"
)
@click.option(
    "--i-understand-this-is-not-useful-without-training",
    type=bool,
    default=False,
    is_flag=True,
    help="Really make the questionable model you want.",
)
@add_merge_options
def main(
    config_path: str,
    out_path: str,
    load_in_4bit: bool,
    load_in_8bit: bool,
    device: str,
    merge_options: MergeOptions,
    verbose: bool,
    i_understand_this_is_not_useful_without_training: bool,
):
    logging.basicConfig(level=logging.INFO if verbose else logging.WARNING)

    if merge_options.cuda:
        logging.warning(
            '--cuda is a no-op for mergekit-moe, use "--device cuda" instead'
        )

    with open(config_path, "r", encoding="utf-8") as file:
        config_source = file.read()

    config = MistralMOEConfig.model_validate(yaml.safe_load(config_source))
    build(
        config,
        out_path=out_path,
        merge_options=merge_options,
        load_in_4bit=load_in_4bit,
        load_in_8bit=load_in_8bit,
        device=device,
        allow_all_same=i_understand_this_is_not_useful_without_training,
    )

    if merge_options.write_model_card:
        # keep a copy of the YAML config alongside the model for reproducibility
        with open(
            os.path.join(out_path, "mergekit_moe_config.yml"), "w", encoding="utf-8"
        ) as fp:
            fp.write(config_source)


if __name__ == "__main__":
    main()