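"""Extract a PEFT-compatible LoRA adapter from a fine-tuned model.

For each supported module, the difference between the fine-tuned and base
weights (the "task vector") is approximated by a low-rank factorization
computed via truncated SVD. Modules that cannot be decomposed, or that are
requested via --save-module, are stored at full rank in modules_to_save.

Example invocation (paths are illustrative):

    mergekit-extract-lora --model ./finetuned --base-model ./base --out-path ./lora-out --max-rank 64
"""
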
import json
import logging
import os
import re
import sys
from typing import Any, Dict, List, Optional, Tuple

import click
import torch
import torch.nn as nn
import tqdm
from pydantic import BaseModel
from transformers import AutoModelForCausalLM

from mergekit.architecture import ArchitectureInfoUtils, WeightInfo
from mergekit.card import generate_card_lora
from mergekit.common import ModelReference
from mergekit.graph import Executor, Task
from mergekit.io.tasks import FinalizeModel, LoadTensor, SaveTensor, TensorWriterTask
from mergekit.io.tensor_writer import TensorWriter
from mergekit.multigpu_executor import MultiGPUExecutor
from mergekit.options import MergeOptions, PrettyPrintHelp, add_merge_options

logger = logging.getLogger("extract_lora")


@click.command("mergekit-extract-lora", cls=PrettyPrintHelp)
@click.option(
    "--model",
    required=True,
    help="Fine-tuned model path",
)
@click.option(
    "--base-model",
    required=True,
    help="Base model path",
)
@click.option(
    "--out-path",
    required=True,
    help="Output path for extracted LoRA adapter",
)
@click.option(
    "--max-rank",
    type=int,
    default=128,
    help="Maximum rank for LoRA decomposition",
)
@click.option(
    "--distribute-scale/--no-distribute-scale",
    is_flag=True,
    default=True,
    help="Distribute scale between A and B matrices",
)
@click.option(
    "--embed-lora/--no-embed-lora",
    is_flag=True,
    default=False,
    help="Extract LoRA weights for embeddings (instead of saving them in modules_to_save)",
)
@click.option(
    "--save-module",
    "modules_to_save",
    type=str,
    multiple=True,
    default=[],
    help="Save the specified module(s) at full rank",
)
@click.option(
    "--exclude-regex",
    "-e",
    "exclude_regexes",
    type=str,
    multiple=True,
    help="Exclude modules matching the specified regex",
)
@click.option(
    "--include-regex",
    "-i",
    "include_regexes",
    type=str,
    multiple=True,
    help="Include modules matching the specified regex",
)
@click.option(
    "--sv-epsilon",
    type=float,
    default=0,
    help="Threshold for singular values to discard",
    show_default=True,
)
@click.option(
    "--skip-undecomposable",
    is_flag=True,
    help="Skip saving undecomposable modules",
    default=False,
)
@add_merge_options
def main(
    base_model: str,
    model: str,
    out_path: str,
    max_rank: int,
    distribute_scale: bool,
    embed_lora: bool,
    modules_to_save: List[str],
    exclude_regexes: List[str],
    include_regexes: List[str],
    sv_epsilon: float,
    skip_undecomposable: bool,
    merge_options: MergeOptions,
):
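    """Extract a LoRA adapter approximating the difference between two models."""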
    merge_options.apply_global_options()

    # click passes multi-value options as tuples; use a mutable list so
    # plan_extraction can append forced modules (e.g. embeddings) to it
    modules_to_save = list(modules_to_save)

    base_model_ref = ModelReference.model_validate(base_model)
    model_ref = ModelReference.model_validate(model)
    plan_result = plan_extraction(
        base_model_ref=base_model_ref.merged(
            cache_dir=merge_options.lora_merge_cache,
            trust_remote_code=merge_options.trust_remote_code,
            lora_merge_dtype=merge_options.lora_merge_dtype,
        ),
        model_ref=model_ref.merged(
            cache_dir=merge_options.lora_merge_cache,
            trust_remote_code=merge_options.trust_remote_code,
            lora_merge_dtype=merge_options.lora_merge_dtype,
        ),
        modules_to_save=modules_to_save,
        out_path=out_path,
        options=merge_options,
        max_rank=max_rank,
        distribute_scale=distribute_scale,
        embed_lora=embed_lora,
        exclude_regexes=exclude_regexes,
        include_regexes=include_regexes,
        sv_epsilon=sv_epsilon,
        skip_undecomposable=skip_undecomposable,
    )

    tasks = plan_result.tasks
    if merge_options.multi_gpu:
        executor = MultiGPUExecutor(
            tasks, storage_device="cpu" if not merge_options.low_cpu_memory else None
        )
    else:
        executor = Executor(
            tasks,
            math_device="cuda" if merge_options.cuda else "cpu",
            storage_device="cuda" if merge_options.low_cpu_memory else "cpu",
        )

    module_real_ranks = {}
    for task, result in executor.run():
        if isinstance(task, TaskVectorDecompositionTask):
            module_real_ranks[task.weight_info.name.removesuffix(".weight")] = result[
                0
            ].shape[0]

    real_max_rank = max(module_real_ranks.values())
    config_dict = make_config_dict(
        base_ref=base_model_ref,
        max_rank=real_max_rank,
        modules_to_save=modules_to_save,
        target_modules=list(
            set(key.split(".")[-1] for key in module_real_ranks.keys())
        ),
        module_ranks=module_real_ranks,
    )
    with open(os.path.join(out_path, "adapter_config.json"), "w") as f:
        json.dump(config_dict, f, indent=4)

    invocation = " ".join(sys.argv)
    with open(os.path.join(out_path, "README.md"), "w", encoding="utf-8") as f:
        f.write(
            generate_card_lora(
                base_model_ref,
                model_ref,
                invocation,
                os.path.basename(out_path),
                base_vocab_size=plan_result.base_vocab_size,
                final_vocab_size=plan_result.final_vocab_size,
            )
        )

    logger.info(f"LoRA adapter extracted to {out_path}")


def make_config_dict(
    base_ref: ModelReference,
    max_rank: int,
    modules_to_save: List[str],
    target_modules: List[str],
    module_ranks: Dict[str, int],
):
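    """Build the contents of adapter_config.json for the extracted adapter.

    Modules whose effective rank differs from the global maximum are listed
    in rank_pattern/alpha_pattern so per-module ranks are applied. Setting
    lora_alpha equal to the rank gives an effective LoRA scale of 1.
    """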
    different_ranked = {k: v for k, v in module_ranks.items() if v != max_rank}
    return {
        "base_model_name_or_path": base_ref.model.path,
        "peft_type": "LORA",
        "use_rslora": False,
        "target_modules": target_modules,
        "modules_to_save": modules_to_save,
        "task_type": "CAUSAL_LM",
        "r": max_rank,
        "lora_alpha": max_rank,
        "rank_pattern": different_ranked,
        "alpha_pattern": different_ranked,
        "lora_dropout": 0.0,
        "fan_in_fan_out": False,
        "inference_mode": True,
    }


class TaskVectorDecompositionTask(Task[Tuple[torch.Tensor, torch.Tensor]]):
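    """Approximate a task vector with a low-rank factorization via truncated SVD.

    Given dW = W_ft - W_base, computes dW = U S V^T and returns (A, B) with
    A = S_a V_r^T and B = U_r S_b, where S_a S_b = S_r, so that B @ A
    approximates dW at rank r = min(max_rank, number of singular values
    above sv_epsilon). With distribute_scale, S_a = S_b = sqrt(S_r);
    otherwise all scale is folded into A. Embedding weights are transposed
    first so the factors match the adapter layout used for embeddings.
    """
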
    weight_info: WeightInfo
    input_task: Task
    max_rank: int
    distribute_scale: bool = True
    transpose: bool = False
    sv_epsilon: float = 0

    def arguments(self) -> Dict[str, Any]:
        return {"task_vector": self.input_task}

    def execute(self, task_vector: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.transpose:
            task_vector = task_vector.T
        out_dtype = task_vector.dtype
        u, s, vh = torch.linalg.svd(
            task_vector.to(dtype=torch.float32), full_matrices=False
        )
        rank = min(self.max_rank, s.shape[0])
        if self.sv_epsilon > 0:
            rank = min((s > self.sv_epsilon).sum().item(), rank)
        if self.distribute_scale:
            # Split the singular values evenly between A and B
            sqrt_s = torch.diag(torch.sqrt(s[:rank]))
            scale_a = sqrt_s
            scale_b = sqrt_s
        else:
            # Fold all scale into A; B keeps the orthonormal columns of U
            # (eye needs explicit dtype/device to match u under CUDA)
            scale_a = torch.diag(s[:rank])
            scale_b = torch.eye(rank, dtype=s.dtype, device=s.device)
        weight_a = scale_a @ vh[:rank]
        weight_b = u[:, :rank] @ scale_b

        return weight_a.to(dtype=out_dtype), weight_b.to(dtype=out_dtype)

    def group_label(self) -> Optional[str]:
        return self.input_task.group_label()

    def uses_accelerator(self):
        return True


class TaskVectorTask(Task[torch.Tensor]):
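    """Compute a task vector: the elementwise difference model - base."""
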
    base_tensor: Task
    model_tensor: Task

    def arguments(self) -> Dict[str, Any]:
        return {"base": self.base_tensor, "model": self.model_tensor}

    def execute(self, base: torch.Tensor, model: torch.Tensor) -> torch.Tensor:
        return model - base

    def group_label(self):
        return max(
            self.base_tensor.group_label() or "", self.model_tensor.group_label() or ""
        )

    def uses_accelerator(self):
        return True


class LoRAModuleSaveTask(Task):
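    """Write the A/B factors of a decomposed module to the adapter file.

    Tensor names follow the PEFT convention: lora_A/lora_B with a .weight
    suffix for ordinary modules, lora_embedding_A/B without a suffix for
    embeddings.
    """
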
    weight_info: WeightInfo
    writer_task: TensorWriterTask
    model_ref: ModelReference
    decomposition_task: TaskVectorDecompositionTask

    def arguments(self) -> Dict[str, Any]:
        return {"writer": self.writer_task, "decomp": self.decomposition_task}

    def execute(
        self, writer: TensorWriter, decomp: Tuple[torch.Tensor, torch.Tensor]
    ) -> None:
        weight_a, weight_b = decomp
        if weight_a is None or weight_b is None:
            if not self.weight_info.optional:
                raise RuntimeError(
                    f"No SVD decomposition for required weight {self.weight_info.name}"
                )
            return
        # Embedding adapters are named lora_embedding_A/B with no .weight
        # suffix; all other modules use lora_A.weight/lora_B.weight
        lora_type = "lora_embedding" if self.decomposition_task.transpose else "lora"
        lora_suffix = "" if self.decomposition_task.transpose else ".weight"
        base_name = self.weight_info.name.removesuffix(".weight")
        writer.save_tensor(
            f"base_model.model.{base_name}.{lora_type}_A{lora_suffix}", weight_a
        )
        writer.save_tensor(
            f"base_model.model.{base_name}.{lora_type}_B{lora_suffix}", weight_b
        )

    def priority(self) -> int:
        return 1000

    def group_label(self) -> Optional[str]:
        return self.decomposition_task.group_label()


def _wi_load(model_ref: ModelReference, weight_info: WeightInfo) -> LoadTensor:
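    """Build a LoadTensor task for the weight described by `weight_info`."""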
    return LoadTensor(
        model=model_ref,
        tensor=weight_info.name,
        dtype=weight_info.force_dtype,
        optional=weight_info.optional,
        aliases=weight_info.aliases,
        tied_names=weight_info.tied_names,
    )


class PlanResults(BaseModel):
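    """Planned task graph plus the vocabulary sizes reported on the model card."""
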
    tasks: List[Task]
    base_vocab_size: int
    final_vocab_size: int


def plan_extraction(
    base_model_ref: ModelReference,
    model_ref: ModelReference,
    modules_to_save: List[str],
    out_path: str,
    options: MergeOptions,
    max_rank: int,
    distribute_scale: bool = True,
    embed_lora: bool = False,
    exclude_regexes: Optional[List[str]] = None,
    include_regexes: Optional[List[str]] = None,
    sv_epsilon: float = 0,
    skip_undecomposable: bool = False,
) -> PlanResults:
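    """Plan the full task graph for LoRA extraction.

    Instantiates both models on the meta device (module structure only, no
    weights) and plans one of three outcomes per module: a full-rank save
    for modules in modules_to_save or of undecomposable types, an SVD
    decomposition for supported types (Linear, Conv1d, Conv2d, Embedding),
    or nothing if the module is excluded by the include/exclude regexes.
    """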
    targets = []
    writer_task = TensorWriterTask(
        out_path=out_path,
        override_basename="adapter_model",
        max_shard_size=-1,
        safe_serialization=options.safe_serialization,
    )

    name_to_wi = all_weights_map(model_ref, options)
    # Instantiate both models on the meta device with empty state dicts:
    # this exposes the module structure without loading any weights
    dummy_model = AutoModelForCausalLM.from_pretrained(
        model_ref.model.path,
        revision=model_ref.model.revision,
        trust_remote_code=options.trust_remote_code,
        device_map="meta",
        state_dict={},
    )
    dummy_base = AutoModelForCausalLM.from_pretrained(
        base_model_ref.model.path,
        revision=base_model_ref.model.revision,
        trust_remote_code=options.trust_remote_code,
        device_map="meta",
        state_dict={},
    )

    embed_in = dummy_model.get_input_embeddings()
    embed_out = dummy_model.get_output_embeddings()

    ft_vocab = embed_in.weight.shape[0]
    base_vocab = dummy_base.get_input_embeddings().weight.shape[0]
    if ft_vocab != base_vocab and embed_lora:
        logger.warning(
            f"Vocabulary size mismatch: fine-tuned model has {ft_vocab} tokens, "
            f"base model has {base_vocab} tokens"
        )
        logger.warning(
            "Embeddings will be saved in modules_to_save instead (embed_lora=False)"
        )
        embed_lora = False

    warned_modules = set()

    def _should_extract(name: str) -> bool:
        if include_regexes and not any(re.search(r, name) for r in include_regexes):
            return False
        if any(re.search(r, name) for r in exclude_regexes or []):
            return False
        return True

    for name, module in tqdm.tqdm(
        list(dummy_model.named_modules()), desc="Planning operations"
    ):
        wi = name_to_wi.get(name + ".weight")
        bias_wi = name_to_wi.get(name + ".bias")
        if wi is None:
            if hasattr(module, "weight"):
                logger.warning(
                    f"Weight {name} present in model but not in architecture info"
                )
                wi = WeightInfo(
                    name=name + ".weight",
                    optional=True,
                    is_embed=isinstance(module, nn.Embedding),
                )
            else:
                continue

        if (
            (not embed_lora)
            and (
                module == embed_in
                or module == embed_out
                or isinstance(module, nn.Embedding)
            )
            and not any(re.search(r, name) for r in exclude_regexes or [])
        ):
            # Embeddings are not decomposed unless --embed-lora is set;
            # force them into modules_to_save so they are kept at full rank
            key = name.split(".")[-1]
            if key not in modules_to_save:
                logger.warning(f"Adding {key} to modules_to_save")
                modules_to_save.append(key)

        if name in modules_to_save or (name.split(".")[-1] in modules_to_save):
            logger.info(f"Planning to save {name} at full rank")
            targets.extend(plan_module_to_save(model_ref, writer_task, wi, bias_wi))
        elif _should_extract(name):
            if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Embedding)):
                logger.info(f"Planning LoRA extraction for {name}")
                targets.extend(
                    plan_lora_module(
                        base_model_ref,
                        model_ref,
                        wi,
                        bias_wi,
                        writer_task,
                        max_rank,
                        distribute_scale,
                        transpose=isinstance(module, nn.Embedding),
                        sv_epsilon=sv_epsilon,
                    )
                )
            else:
                key = name.split(".")[-1]
                message = (
                    f"{key} has unsupported module type {type(module).__name__} - "
                    + ("skipping" if skip_undecomposable else "saving at full rank")
                )
                if not skip_undecomposable:
                    # Fall back to a full-rank save for this module type
                    if key not in modules_to_save:
                        modules_to_save.append(key)
                    targets.extend(
                        plan_module_to_save(model_ref, writer_task, wi, bias_wi)
                    )
                if key not in warned_modules:
                    logger.warning(message)
                    warned_modules.add(key)

    save_tasks = [
        t for t in targets if isinstance(t, (SaveTensor, LoRAModuleSaveTask))
    ]
    finalize = FinalizeModel(tensor_save_tasks=save_tasks, writer_task=writer_task)
    return PlanResults(
        tasks=targets + [finalize],
        base_vocab_size=base_vocab,
        final_vocab_size=ft_vocab,
    )


def plan_lora_module(
    base_model_ref: ModelReference,
    model_ref: ModelReference,
    wi: WeightInfo,
    bias_wi: Optional[WeightInfo],
    writer_task: TensorWriterTask,
    max_rank: int,
    distribute_scale: bool = True,
    transpose: bool = False,
    sv_epsilon: float = 0,
) -> List[Task]:
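    """Plan load -> task vector -> SVD -> save for a single module.

    A bias delta, if present, is a vector and is saved directly as
    lora_B.bias rather than being decomposed.
    """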
    targets = []
    base_load_task = _wi_load(base_model_ref, wi)
    model_load_task = _wi_load(model_ref, wi)
    tv_task = TaskVectorTask(base_tensor=base_load_task, model_tensor=model_load_task)
    decomp_task = TaskVectorDecompositionTask(
        weight_info=wi,
        input_task=tv_task,
        max_rank=max_rank,
        distribute_scale=distribute_scale,
        transpose=transpose,
        sv_epsilon=sv_epsilon,
    )
    targets.append(decomp_task)
    targets.append(
        LoRAModuleSaveTask(
            weight_info=wi,
            writer_task=writer_task,
            model_ref=model_ref,
            decomposition_task=decomp_task,
        )
    )
    if bias_wi is not None:
        base_bias_load_task = _wi_load(base_model_ref, bias_wi)
        model_bias_load_task = _wi_load(model_ref, bias_wi)
        tv_bias_task = TaskVectorTask(
            base_tensor=base_bias_load_task, model_tensor=model_bias_load_task
        )
        base_bias_name = bias_wi.name.removesuffix(".bias")
        name_out = f"base_model.model.{base_bias_name}.lora_B.bias"
        targets.append(
            SaveTensor(
                tensor_name=name_out,
                tensor_task=tv_bias_task,
                writer_task=writer_task,
                optional=bias_wi.optional,
                clone=False,
            )
        )
    return targets


def plan_module_to_save(
    model_ref: ModelReference,
    writer_task: TensorWriterTask,
    wi: WeightInfo,
    bias_wi: Optional[WeightInfo],
) -> List[Task]:
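    """Plan full-rank saves of a module's weight (and bias, if any)."""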
    save_tasks = []
    load_task = _wi_load(model_ref, wi)
    save_task = SaveTensor(
        tensor_name=f"base_model.model.{wi.name}",
        tensor_task=load_task,
        writer_task=writer_task,
        optional=wi.optional,
        clone=False,
    )
    save_tasks.append(save_task)
    if bias_wi is not None:
        bias_load_task = _wi_load(model_ref, bias_wi)
        bias_save_task = SaveTensor(
            tensor_name=f"base_model.model.{bias_wi.name}",
            tensor_task=bias_load_task,
            writer_task=writer_task,
            optional=bias_wi.optional,
            clone=False,
        )
        save_tasks.append(bias_save_task)
    return save_tasks


def all_weights_map(
    model_ref: ModelReference, options: MergeOptions
) -> Dict[str, WeightInfo]:
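    """Map tensor names to WeightInfo for every weight defined by the architecture."""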
    name_to_wi = {}
    model_cfg = model_ref.config(trust_remote_code=options.trust_remote_code)
    arch_info = ArchitectureInfoUtils.get_architecture_info(model_cfg)
    for wi in arch_info.all_weights(model_cfg):
        name_to_wi[wi.name] = wi
    return name_to_wi


if __name__ == "__main__":
    main()