"""Helpers for loading LoRA adapter checkpoints (HF or NeMo layout).

The module contains key-parsing helpers for both checkpoint layouts, small
loader classes that validate the adapter directories/files, functions that
patch a model's embedding / lm_head from an HF adapter, and ``LoraManager``
which keeps the per-adapter weights and exposes them as runtime buffers.
"""
import json
import re
import tarfile
from collections import defaultdict
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional

import numpy as np
import torch
import yaml

from ._utils import (DictConversion, pad_vocab_size, release_gc,
                     str_dtype_to_torch, torch_to_numpy)
from .layers.linear import ColumnLinear
from .logger import logger
from .mapping import Mapping
from .models.convert_utils import (get_model_path, load_state_dict,
                                   split_matrix_tp)

# Key layout of an HF LoRA tensor:
#   <prefix>.<layer>.<parent>.<module>.lora_{A,B}.weight
# or, for MoE experts:
#   <prefix>.<layer>.<parent>.experts.<expert>.<module>.lora_{A,B}.weight
# Groups: 1=prefix, 2=layer index, 3=parent module, 4=module (or the whole
# "experts.N.module" span), 5=expert index, 6=expert module, 7="A" or "B".
_HF_LORA_KEY_PATTERN = re.compile(
    r'(.*)\.(\d+)\.(\w+)\.(\w+|experts\.(\d+)\.(\w+))\.lora_(A|B)\.weight')


def get_all_nemo_lora_weights(lora_weights):
    """Group NeMo attention-adapter weights by layer index.

    Only keys containing ``self_attention.adapter_layer.lora_kqv_adapter``
    and ending in ``linear_in.weight`` / ``linear_out.weight`` are kept.

    Args:
        lora_weights: mapping from checkpoint key to weight.

    Returns:
        ``defaultdict`` mapping ``layer_idx -> {"in": w, "out": w}``.
    """
    layer_weights = defaultdict(dict)
    adapter_key = "self_attention.adapter_layer.lora_kqv_adapter"
    layer_pattern = re.compile(r'.*\.layers\.(\d+)\..*')
    for key, weights in lora_weights.items():
        if adapter_key not in key:
            continue
        if key.endswith('linear_in.weight'):
            inout = 'in'
        elif key.endswith('linear_out.weight'):
            inout = 'out'
        else:
            continue
        m = layer_pattern.match(key)
        if m is None:
            # Without a ".layers.<N>." component there is no layer to file
            # the weight under; skip it instead of failing on m.group().
            logger.warning(f"no layer index in NeMo LoRA key {key}")
            continue
        layer_idx = int(m.group(1))
        layer_weights[layer_idx][inout] = weights
    return layer_weights


def get_all_hf_lora_weights(lora_weights, hf_modules, component=None):
    """Group HF LoRA weights by layer, module and (for MoE) expert.

    Args:
        lora_weights: mapping from checkpoint key to weight.
        hf_modules: collection of known HF module names; a key is resolved
            to ``"<parent>.<module>"`` if that is known, else ``"<module>"``.
        component: if given, only keys whose prefix (text before the layer
            index) contains this string are kept.

    Returns:
        Nested ``defaultdict``:
        ``layer_idx -> hf_module -> {"in": w, "out": w}`` for plain modules,
        ``layer_idx -> hf_module -> expert_idx -> {"in": w, "out": w}`` for
        MoE experts. ``expert_idx`` is the string captured from the key.

    Raises:
        AssertionError: if neither candidate module name is in
            ``hf_modules``.
    """
    all_weights = defaultdict(lambda: defaultdict(dict))
    for key, weights in lora_weights.items():
        m = _HF_LORA_KEY_PATTERN.match(key)
        if not m:
            # lm_head / embed_tokens are expected non-LoRA entries.
            if "lm_head" not in key and "embed_tokens" not in key:
                logger.warning(f"no match {key} from HF LoRA weights")
            continue
        if component is not None and component not in m.group(1):
            continue
        layer_idx = int(m.group(2))
        expert_idx = m.group(5)
        is_moe = expert_idx is not None
        module_name = m.group(6 if is_moe else 4)
        hf_module = m.group(3) + "." + module_name
        if hf_module not in hf_modules:
            hf_module = module_name
            assert hf_module in hf_modules
        # lora_A is the input projection, lora_B the output projection.
        inout = "in" if m.group(7) == "A" else "out"
        if not is_moe:
            all_weights[layer_idx][hf_module][inout] = weights
        else:
            all_weights[layer_idx][hf_module].setdefault(expert_idx, {})
            all_weights[layer_idx][hf_module][expert_idx][inout] = weights
    return all_weights


def get_hf_target_modules(lora_weights, hf_modules, lora_target_modules):
    """Return the set of HF module names actually present in the weights.

    A key contributes only if it matches the HF LoRA key layout and contains
    at least one entry of ``lora_target_modules`` as a substring.

    Args:
        lora_weights: mapping from checkpoint key to weight.
        hf_modules: collection of known HF module names.
        lora_target_modules: module name fragments to look for in the keys.

    Returns:
        ``set`` of resolved HF module names.

    Raises:
        AssertionError: if a matched key resolves to no known HF module.
    """
    hf_target_modules = set()
    for key in lora_weights.keys():
        m = _HF_LORA_KEY_PATTERN.match(key)
        if not m:
            if "lm_head" not in key and "embed_tokens" not in key:
                logger.warning(f"no match {key} from HF LoRA weights")
            continue
        if not any(module in key for module in lora_target_modules):
            continue
        expert_idx = m.group(5)
        is_moe = expert_idx is not None
        module_name = m.group(6 if is_moe else 4)
        hf_module = m.group(3) + "." + module_name
        if hf_module not in hf_modules:
            hf_module = module_name
            assert hf_module in hf_modules
        hf_target_modules.add(hf_module)
    return hf_target_modules


def invert_module_mapping(trtllm_modules_to_hf_modules):
    """Invert a ``trtllm_module -> hf_module(s)`` mapping.

    Values may be a single HF module name or a list of names; every HF name
    maps back to its TRT-LLM module. Later entries win on duplicates.
    """
    hf_modules_to_trtllm_modules = {}
    for k, hf_modules in trtllm_modules_to_hf_modules.items():
        if isinstance(hf_modules, list):
            for hf_module in hf_modules:
                hf_modules_to_trtllm_modules[hf_module] = k
        else:
            hf_modules_to_trtllm_modules[hf_modules] = k
    return hf_modules_to_trtllm_modules


@dataclass
class LoraConfig(DictConversion):
    """LoRA build/runtime configuration.

    ``lora_ckpt_source`` must be ``'hf'`` or ``'nemo'``; this is checked in
    ``__post_init__``.
    """
    lora_dir: List[str] = field(default_factory=list)
    lora_ckpt_source: str = 'hf'
    max_lora_rank: int = 64
    lora_target_modules: List[str] = field(default_factory=list)
    trtllm_modules_to_hf_modules: Dict[str, str] = field(default_factory=dict)

    def __post_init__(self):
        assert self.lora_ckpt_source in [
            'hf', 'nemo'
        ], f"lora_ckpt_source must be one of 'hf' or 'nemo', got {self.lora_ckpt_source}"


class HfLoraLoader:
    """Validate HF adapter directories and read the first one's metadata.

    After construction with a non-empty list, ``is_valid`` is True,
    ``lora_target_modules`` holds the first adapter's ``target_modules``,
    ``lora_weight`` its state dict, and ``lm_head`` / ``embed_tokens`` /
    ``vocab_size`` are populated when the adapter saves those modules.
    """

    def __init__(self, lora_dirs: List[str]):
        """Check every directory, then load the first adapter.

        Raises:
            ValueError: if a directory lacks the ``adapter_model`` file or a
                regular ``adapter_config.json`` file.
        """
        self.lora_target_modules = []
        self.is_valid = False
        self.lm_head = None
        self.embed_tokens = None
        self.vocab_size = 0
        # Always defined so attribute access is safe on an empty loader.
        self.lora_weight = None

        if len(lora_dirs) == 0:
            return

        for lora_dir in lora_dirs:
            model_path = get_model_path(lora_dir, "adapter_model")
            if model_path is None:
                raise ValueError(
                    f"adapter_model file does not exist in {lora_dir}")
            config_file = Path(f"{lora_dir}/adapter_config.json")
            if not config_file.exists():
                raise ValueError(f"{config_file} does not exist")
            if not config_file.is_file():
                raise ValueError(f"{config_file} is not a file")
        self.is_valid = True

        # Only the first adapter's config and weights are inspected here.
        lora_dir = lora_dirs[0]
        with open(f"{lora_dir}/adapter_config.json") as f:
            adapter_config = json.load(f)
        self.lora_target_modules = adapter_config["target_modules"]

        lora_weight = load_state_dict(get_model_path(lora_dir, "adapter_model"))
        self.lora_weight = lora_weight

        # The key may be absent or null; both mean "nothing extra saved".
        modules_to_save = adapter_config.get("modules_to_save")
        if modules_to_save is not None:
            if "lm_head" in modules_to_save:
                self.lm_head = lora_weight["base_model.model.lm_head.weight"]
                self.vocab_size = self.lm_head.shape[0]

            if "embed_tokens" in modules_to_save:
                self.embed_tokens = lora_weight[
                    "base_model.model.model.embed_tokens.weight"]

    def get_target_modules(self, trtllm_modules_to_hf_modules):
        """Return TRT-LLM module names for the HF modules found in the adapter.

        Returns an empty list when the loader is not valid.
        """
        hf_modules_to_trtllm_modules = invert_module_mapping(
            trtllm_modules_to_hf_modules)
        lora_target_modules = []
        if self.is_valid:
            hf_target_modules = get_hf_target_modules(
                self.lora_weight,
                hf_modules=set(hf_modules_to_trtllm_modules.keys()),
                lora_target_modules=self.lora_target_modules,
            )
            for m in hf_target_modules:
                trtllm_module = hf_modules_to_trtllm_modules[m]
                lora_target_modules.append(trtllm_module)
        return lora_target_modules


class NemoLoraLoader:
    """Validate NeMo LoRA files and expose the supported target modules."""

    def __init__(self, lora_dirs: List[str]):
        """Check that every entry of ``lora_dirs`` is an existing file.

        Raises:
            ValueError: if a path does not exist or is not a file.
        """
        self.lora_target_modules = []
        self.is_valid = False

        if len(lora_dirs) == 0:
            return

        for lora_file in lora_dirs:
            path = Path(lora_file)
            if not path.exists():
                raise ValueError(f"{path} does not exist")
            if not path.is_file():
                raise ValueError(f"{path} is not a file")
        self.is_valid = True
        # Hardcoded since LoraManager only supports this case now
        self.lora_target_modules = ["attn_qkv"]


def load_nemo_lora(model, lora_config: LoraConfig):
    """Fill ``lora_config.lora_target_modules`` from NeMo files if unset.

    ``model`` is not used by this function.
    """
    lora_loader = NemoLoraLoader(lora_config.lora_dir)
    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.lora_target_modules


def get_default_trtllm_modules_to_hf_modules():
    """Return the default TRT-LLM module name -> HF module name mapping."""
    return {
        "attn_q": "q_proj",
        "attn_k": "k_proj",
        "attn_v": "v_proj",
        "attn_dense": "o_proj",
        "mlp_h_to_4h": "gate_proj",
        "mlp_4h_to_h": "down_proj",
        "mlp_gate": "up_proj",
        "moe_h_to_4h": "w1",
        "moe_4h_to_h": "w2",
        "moe_gate": "w3",
        "moe_router": "gate",
    }


def load_hf_lora(
    model,
    lora_config: LoraConfig,
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    """Configure ``lora_config`` from HF adapters and patch the model.

    Stores the module mapping on ``lora_config``, derives the target modules
    from the adapter when none were given, and - when the adapter ships
    ``embed_tokens`` / ``lm_head`` weights - replaces the model's vocabulary
    embedding and lm_head weights (rebuilding the layers if shapes differ).

    Args:
        model: model object exposing ``config``, ``transformer`` and
            ``lm_head``; it is modified in place.
        lora_config: configuration to update in place.
        trtllm_modules_to_hf_modules: optional mapping; a falsy value
            selects ``get_default_trtllm_modules_to_hf_modules()``.
    """
    trtllm_modules_to_hf_modules = (trtllm_modules_to_hf_modules
                                    or get_default_trtllm_modules_to_hf_modules())
    lora_config.trtllm_modules_to_hf_modules = trtllm_modules_to_hf_modules

    lora_loader = HfLoraLoader(lora_config.lora_dir)

    if len(lora_config.lora_target_modules) == 0:
        lora_config.lora_target_modules = lora_loader.get_target_modules(
            trtllm_modules_to_hf_modules)

    if not lora_loader.is_valid:
        return

    config = model.config
    torch_dtype = str_dtype_to_torch(config.dtype)
    # the lora checkpoint might finetune the embedding
    if lora_loader.vocab_size != 0:
        config.vocab_size = lora_loader.vocab_size
    mapping = config.mapping

    if mapping.is_first_pp_rank() and lora_loader.embed_tokens is not None:
        weight = lora_loader.embed_tokens
        if config.use_parallel_embedding:
            weight = split_matrix_tp(
                weight,
                mapping.tp_size,
                mapping.tp_rank,
                dim=config.embedding_sharding_dim,
            )
        if model.transformer.vocab_embedding.weight.raw_value.shape != weight.shape:
            # Shape changed: rebuild the embedding layer with the same class.
            model.transformer.vocab_embedding = model.transformer.vocab_embedding.__class__(
                num_embeddings=config.vocab_size,
                embedding_dim=config.hidden_size,
                dtype=config.dtype,
                tp_size=mapping.tp_size
                if config.use_parallel_embedding else 1,
                tp_group=mapping.tp_group
                if config.use_parallel_embedding else None,
                sharding_dim=config.embedding_sharding_dim,
                tp_rank=mapping.tp_rank,
            )
        model.transformer.vocab_embedding.weight.value = weight.to(torch_dtype)

    if mapping.is_last_pp_rank() and lora_loader.lm_head is not None:
        weight = lora_loader.lm_head
        vocab_size = lora_loader.vocab_size
        if vocab_size % mapping.tp_size != 0:
            # padding
            vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
            pad_width = vocab_size_padded - vocab_size

            weight = torch.from_numpy(
                np.pad(torch_to_numpy(weight), ((0, pad_width), (0, 0)),
                       'constant',
                       constant_values=0))
        else:
            vocab_size_padded = vocab_size
        if model.lm_head.weight.raw_value.shape != weight.shape:
            model.lm_head = ColumnLinear(
                config.hidden_size,
                vocab_size_padded,
                bias=False,
                dtype=config.dtype,
                tp_group=mapping.tp_group,
                tp_size=mapping.tp_size,
                gather_output=True,
            )
        model.lm_head.weight.value = split_matrix_tp(
            weight,
            mapping.tp_size,
            mapping.tp_rank,
            dim=0,
        ).to(torch_dtype)


def use_lora(
    model,
    lora_config: LoraConfig,
    trtllm_modules_to_hf_modules: Optional[Dict[str, str]] = None,
):
    """Dispatch to the NeMo or HF loader based on ``lora_ckpt_source``.

    Raises:
        ValueError: if ``lora_config.lora_ckpt_source`` is not ``"nemo"``
            or ``"hf"``.
    """
    if lora_config.lora_ckpt_source == "nemo":
        load_nemo_lora(model, lora_config)
    elif lora_config.lora_ckpt_source == "hf":
        load_hf_lora(model, lora_config, trtllm_modules_to_hf_modules)
    else:
        raise ValueError(
            f"Unsupported lora_ckpt_source: {lora_config.lora_ckpt_source}")


def unpack_nemo_weights(nemo_archive_path):
    """Read config and weights out of a NeMo tar archive.

    Members are looked up as ``model_weights.ckpt`` / ``model_config.yaml``
    first and with a leading ``./`` as a fallback.

    Returns:
        Tuple ``(config, weights)``: the YAML config parsed with
        ``yaml.safe_load`` and the checkpoint loaded onto CPU.

    Raises:
        Exception: if neither member naming is present in the archive.
    """
    with tarfile.open(nemo_archive_path) as tar:
        try:
            model_weights = tar.extractfile("model_weights.ckpt")
            model_config = tar.extractfile("model_config.yaml")
        except KeyError:
            try:
                model_weights = tar.extractfile("./model_weights.ckpt")
                model_config = tar.extractfile("./model_config.yaml")
            except KeyError as err:
                err_str = "Both model_weights paths not found in the tar archive."
                raise Exception(err_str) from err
        # NOTE(review): torch.load unpickles the member; only use archives
        # from trusted sources.
        # Must run while the archive is still open: the members are lazy
        # file objects backed by ``tar``.
        return yaml.safe_load(model_config), torch.load(
            model_weights, map_location=torch.device("cpu"))


class LoraManager(object):
    """Holds loaded LoRA weights per adapter uid and builds runtime inputs."""

    # Integer ids written into the per-weight config rows.
    LORA_MODULE_IDS = {
        "attn_qkv": 0,
        "attn_q": 1,
        "attn_k": 2,
        "attn_v": 3,
        "attn_dense": 4,
        "mlp_h_to_4h": 5,
        "mlp_4h_to_h": 6,
        "mlp_gate": 7,
        "cross_attn_qkv": 8,
        "cross_attn_q": 9,
        "cross_attn_k": 10,
        "cross_attn_v": 11,
        "cross_attn_dense": 12,
        "moe_h_to_4h": 13,
        "moe_4h_to_h": 14,
        "moe_gate": 15,
        "moe_router": 16,
    }

    def __init__(self):
        '''
        _lora_uid_to_low_ranks: dict[str -> dict[int -> dict[str -> int]]]
        {
            uid: {
                0: {
                    lora_module: int
                }, # layer_0_rank,
                1: {
                    lora_module: int
                }, # layer_1_rank,
                ...
            }
        }

        _lora_weights_pointers_list: dict[str -> dict[int -> dict[str -> [int, int]]]]
        {
            uid: {
                0: {
                    lora_module: [t_in_ptr, t_out_ptr]
                }, # layer_0,
                1: {
                    lora_module: [t_in_ptr, t_out_ptr]
                }, # layer_1,
                ...
            }
        }

        The pointers are ``Tensor.data_ptr()`` values; the tensors themselves
        are kept alive in ``_lora_weights``.
        '''
        self._lora_uid_to_low_ranks = {}
        self._lora_weights = []
        self._lora_weights_pointers_list = {}
        self._lora_cpp_weights = {}
        self._lora_weight_config = {}
        self.missing_qkv_modules = []
        self.lora_target_modules = []

    @staticmethod
    def get_missing_qkv_modules(lora_target_modules):
        """Return the q/k/v modules absent from a partially-targeted group.

        For each of the self-attention and cross-attention groups: if at
        least one of q/k/v is targeted, the untargeted ones are returned.
        """
        # In current design, q_lora_params, k_lora_params and v_lora_params should be all enabled or all disabled at the same time.
        # However, some lora checkpoint (e.g. BART) only contain two of them, so we use zero tensor to fill the missing ones.
        missing_qkv_modules = []
        qkv_groups = (
            ["attn_q", "attn_k", "attn_v"],
            ["cross_attn_q", "cross_attn_k", "cross_attn_v"],
        )
        for group in qkv_groups:
            if any(x in lora_target_modules for x in group):
                for lora_module in group:
                    if lora_module not in lora_target_modules:
                        missing_qkv_modules.append(lora_module)
        return missing_qkv_modules

    def load_from_ckpt(self, model_dir, model_config, runtime_mapping,
                       ckpt_source):
        """Load adapters from ``model_dir`` using the HF or NeMo loader.

        Raises:
            AssertionError: if ``ckpt_source`` is neither ``"hf"`` nor
                ``"nemo"``.
        """
        if ckpt_source == "hf":
            self.load_from_hf(model_dir, model_config, runtime_mapping)
        elif ckpt_source == "nemo":
            self.load_from_nemo(model_dir, model_config, runtime_mapping)
        else:
            assert False, f"LoraManager does not support source {ckpt_source}"

    def load_from_nemo(self, model_files, model_config, runtime_mapping):
        """Load NeMo LoRA archives; uids are ``"0"``, ``"1"``, ... by position.

        Only ``attn_qkv`` carries weights; every other target module is
        recorded with rank 0. The output projection is split along dim 0
        across tensor-parallel ranks.
        """
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank
        lora_target_modules = model_config.lora_target_modules
        dtype = model_config.dtype
        uids = list(map(str, range(len(model_files))))
        self.lora_target_modules = lora_target_modules
        self.missing_qkv_modules = self.get_missing_qkv_modules(
            lora_target_modules)

        def load_from_model_file(uid, model_file):
            if uid not in self._lora_cpp_weights:
                self._lora_cpp_weights[uid] = []
            if uid not in self._lora_weight_config:
                self._lora_weight_config[uid] = []

            _, nemo_weights = unpack_nemo_weights(model_file)
            all_lora_weights = get_all_nemo_lora_weights(nemo_weights)

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_lora_weights.keys()):
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                for lora_module in lora_target_modules:
                    if lora_module != "attn_qkv":
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue

                    t_in = all_lora_weights[layer_idx]["in"]
                    t_out = all_lora_weights[layer_idx]["out"]
                    assert t_out.shape[0] % tp_size == 0
                    t_out = torch.split(t_out,
                                        t_out.shape[0] // tp_size,
                                        dim=0)[tp_rank].contiguous()

                    t_in = t_in.cuda().to(
                        str_dtype_to_torch(dtype)).contiguous()
                    t_out = t_out.cuda().to(
                        str_dtype_to_torch(dtype)).contiguous()
                    rank = t_in.shape[0]
                    self._lora_uid_to_low_ranks[uid][layer_idx][
                        lora_module] = int(rank)
                    self._lora_weights_pointers_list[uid][layer_idx][
                        lora_module] = [t_in.data_ptr(),
                                        t_out.data_ptr()]

                    # prevent torch free this buffer
                    self._lora_weights.append(t_in)
                    self._lora_weights.append(t_out)
                    self._lora_cpp_weights[uid].append(
                        torch.concatenate([t_in.flatten(),
                                           t_out.flatten()]))
                    self._lora_weight_config[uid].append(
                        np.array([
                            self.LORA_MODULE_IDS[lora_module], layer_idx,
                            int(rank)
                        ],
                                 dtype=np.int32))

        for uid, model_file in zip(uids, model_files):
            load_from_model_file(uid, model_file)
            release_gc()

    def load_from_hf(self,
                     model_dirs,
                     model_config,
                     runtime_mapping,
                     component=None):
        '''
        Load HF LoRA adapters; uids are "0", "1", ... by position in
        ``model_dirs``. ``component`` optionally restricts the keys by
        prefix (see ``get_all_hf_lora_weights``).

        lora config of https://huggingface.co/hfl/chinese-alpaca-2-lora-7b
        {
            "base_model_name_or_path": "/Llama-2-7b-hf",
            "bias": "none",
            "enable_lora": null,
            "fan_in_fan_out": false,
            "inference_mode": true,
            "lora_alpha": 128.0,
            "lora_dropout": 0.05,
            "merge_weights": false,
            "modules_to_save": [
                "embed_tokens",
                "lm_head"
            ],
            "peft_type": "LORA",
            "r": 64,
            "target_modules": [
                "q_proj",
                "v_proj",
                "k_proj",
                "o_proj",
                "gate_proj",
                "down_proj",
                "up_proj"
            ],
            "task_type": "CAUSAL_LM"

        }

        keys in adapter_model.bin:
            base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.k_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.k_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.self_attn.o_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.self_attn.o_proj.lora_B.weight torch.Size([4096, 64])
            base_model.model.model.layers.0.mlp.gate_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.gate_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.up_proj.lora_A.weight torch.Size([64, 4096])
            base_model.model.model.layers.0.mlp.up_proj.lora_B.weight torch.Size([11008, 64])
            base_model.model.model.layers.0.mlp.down_proj.lora_A.weight torch.Size([64, 11008])
            base_model.model.model.layers.0.mlp.down_proj.lora_B.weight torch.Size([4096, 64])
            ...

        '''
        tp_size = runtime_mapping.tp_size
        tp_rank = runtime_mapping.tp_rank

        lora_hf_configs = []
        uids = []
        for i, model_dir in enumerate(model_dirs):
            with open(f"{model_dir}/adapter_config.json", 'r') as f:
                config = json.load(f)
                lora_hf_configs.append(config)
            uids.append(str(i))
        lora_target_modules = model_config.lora_target_modules
        dtype = model_config.dtype
        hf_modules_to_trtllm_modules = invert_module_mapping(
            model_config.trtllm_modules_to_hf_modules)
        hf_modules = set(hf_modules_to_trtllm_modules.keys())
        missing_qkv_modules = self.get_missing_qkv_modules(lora_target_modules)
        self.lora_target_modules = lora_target_modules
        self.missing_qkv_modules = missing_qkv_modules

        def preprocess_lora_weights(lora_model):
            # Swap weights of gate_up_proj
            # Only existing keys are reassigned, so mutating the dict while
            # iterating its items is safe here.
            for key, value in lora_model.items():
                if "gate_up_proj.lora_B.weight" in key:
                    original_weights = value.contiguous().clone()
                    half_split = original_weights.shape[0] // 2
                    first_half = original_weights[:half_split, :]
                    second_half = original_weights[half_split:, :]
                    value = torch.cat((second_half, first_half), dim=0)
                    lora_model[key] = value
            return lora_model

        def load_from_model_dir(uid, model_dir, hf_config):
            if uid not in self._lora_cpp_weights:
                self._lora_cpp_weights[uid] = []
            if uid not in self._lora_weight_config:
                self._lora_weight_config[uid] = []

            lora_model = load_state_dict(
                get_model_path(model_dir, "adapter_model"))
            lora_model = preprocess_lora_weights(lora_model)
            all_weights = get_all_hf_lora_weights(lora_model, hf_modules,
                                                  component)
            rank = int(hf_config["r"])
            rs_lora = bool(hf_config.get("use_rslora", False))

            self._lora_uid_to_low_ranks[uid] = {}
            self._lora_weights_pointers_list[uid] = {}
            for layer_idx in sorted(all_weights.keys()):
                layer_weights = all_weights[layer_idx]
                self._lora_uid_to_low_ranks[uid][layer_idx] = {}
                self._lora_weights_pointers_list[uid][layer_idx] = {}

                # Zero-fill q/k/v modules the checkpoint does not provide.
                # Entries whose module is not in lora_target_modules are
                # recorded with rank 0 by the loop below.
                for lora_module in missing_qkv_modules:
                    hf_module = model_config.trtllm_modules_to_hf_modules[
                        lora_module]
                    if isinstance(hf_module, list):
                        hf_module = hf_module[0]
                    layer_weights[hf_module] = {
                        "in": torch.zeros(rank, model_config.hidden_size),
                        "out": torch.zeros(model_config.hidden_size, rank),
                    }

                for hf_module, module_weights in layer_weights.items():
                    lora_module = hf_modules_to_trtllm_modules[hf_module]
                    if lora_module not in lora_target_modules:
                        self._lora_uid_to_low_ranks[uid][layer_idx][
                            lora_module] = 0
                        continue
                    if "in" not in module_weights:
                        # MoE: module_weights is keyed by expert index.
                        # The indices are strings parsed from the key, so
                        # sort numerically ("10" must come after "2").
                        is_moe = True
                        expert_ids = sorted(module_weights, key=int)
                        t_in = torch.stack([
                            module_weights[expert_idx]["in"]
                            for expert_idx in expert_ids
                        ])
                        t_out = torch.stack([
                            module_weights[expert_idx]["out"]
                            for expert_idx in expert_ids
                        ])
                    else:
                        is_moe = False
                        t_in = module_weights["in"]
                        t_out = module_weights["out"]
                    if lora_module in ["moe_router"]:
                        pass
                    elif "moe" in lora_module and runtime_mapping.has_moe_ep():
                        pass
                    elif lora_module in [
                            "attn_dense",
                            "cross_attn_dense",
                            "mlp_4h_to_h",
                            "moe_4h_to_h",
                    ]:
                        # split by row
                        dim = 2 if is_moe else 1
                        assert t_in.shape[dim] % tp_size == 0
                        t_in = torch.split(t_in,
                                           t_in.shape[dim] // tp_size,
                                           dim=dim)[tp_rank].contiguous()
                    else:
                        # split by column
                        dim = 1 if is_moe else 0
                        assert t_out.shape[dim] % tp_size == 0
                        t_out = torch.split(t_out,
                                            t_out.shape[dim] // tp_size,
                                            dim=dim)[tp_rank].contiguous()

                    t_in = t_in.cuda().contiguous()
                    t_out = t_out.cuda().contiguous()
                    # The LoRA scaling factor is folded into the output
                    # projection: alpha / sqrt(r) with use_rslora, else
                    # alpha / r.
                    if rs_lora:
                        scale = float(hf_config["lora_alpha"]) / np.sqrt(rank)
                    else:
                        scale = float(hf_config["lora_alpha"]) / rank
                    t_out = t_out * scale
                    t_in = t_in.to(str_dtype_to_torch(dtype))
                    t_out = t_out.to(str_dtype_to_torch(dtype))

                    rank_dim = 1 if is_moe else 0
                    assert t_in.shape[rank_dim] == rank
                    self._lora_uid_to_low_ranks[uid][layer_idx][
                        lora_module] = rank
                    self._lora_weights_pointers_list[uid][layer_idx][
                        lora_module] = [t_in.data_ptr(),
                                        t_out.data_ptr()]

                    # prevent torch free this buffer
                    self._lora_weights.append(t_in)
                    self._lora_weights.append(t_out)
                    self._lora_cpp_weights[uid].append(
                        torch.concatenate([t_in.flatten(),
                                           t_out.flatten()]))
                    self._lora_weight_config[uid].append(
                        np.array([
                            self.LORA_MODULE_IDS[lora_module], layer_idx,
                            int(hf_config['r'])
                        ],
                                 dtype=np.int32))

        for uid, model_dir, hf_config in zip(uids, model_dirs, lora_hf_configs):
            load_from_model_dir(uid, model_dir, hf_config)
            release_gc()

    def save_lora_weights_to_bin(self, out_dir):
        """Write each adapter's weights and config as ``.npy`` files.

        For every uid except ``'-1'`` this creates ``<out_dir>/<uid>/`` and
        saves ``model.lora_weights.npy`` and ``model.lora_config.npy``, each
        with a leading axis of size 1.

        Args:
            out_dir: output directory as ``str`` or ``pathlib.Path``.

        Raises:
            AssertionError: if ``out_dir`` is neither ``str`` nor ``Path``.
        """

        def save_val(val, directory, key, tp_num=None, write_npy=False):
            ext = "npy" if write_npy else "bin"
            suffix = ext if tp_num is None else f"{tp_num}.{ext}"
            if write_npy:
                np.save(directory / f"model.{key}.{suffix}", val)
            else:
                val.tofile(directory / f"model.{key}.{suffix}")

        if isinstance(out_dir, str):
            out_dir_path = Path(out_dir)
        elif isinstance(out_dir, Path):
            out_dir_path = out_dir
        else:
            assert False
        for uid in self._lora_cpp_weights:
            if uid == '-1':
                continue

            all_weights = np.expand_dims(
                np.stack([
                    torch_to_numpy(w.flatten().contiguous())
                    for w in self._lora_cpp_weights[uid]
                ]), 0)
            all_configs = np.expand_dims(
                np.stack(self._lora_weight_config[uid]), 0)

            uid_path = out_dir_path / f"{uid}"
            uid_path.mkdir(parents=True, exist_ok=True)
            save_val(all_weights,
                     uid_path,
                     "lora_weights",
                     tp_num=None,
                     write_npy=True)
            save_val(all_configs,
                     uid_path,
                     "lora_config",
                     tp_num=None,
                     write_npy=True)

    def uid_to_low_ranks(self, uid: str):
        """Return the ``layer_idx -> {lora_module: rank}`` dict for ``uid``."""
        assert isinstance(uid, str)
        return self._lora_uid_to_low_ranks[uid]

    @property
    def lora_weights(self):
        """List of tensors kept alive so their data pointers stay valid."""
        return self._lora_weights

    @property
    def lora_weights_pointers_list(self):
        """``uid -> layer_idx -> lora_module -> [in_ptr, out_ptr]``."""
        return self._lora_weights_pointers_list

    def input_buffers(self, lora_uids, mapping: Mapping, num_layers: int):
        """Build per-layer rank and weight-pointer tensors for a batch.

        For every layer in ``mapping.pp_layers(num_layers)`` and every
        module in ``lora_target_modules + missing_qkv_modules`` two entries
        are produced:
        ``"{module}_lora_ranks_{layer}"`` (``torch.IntTensor``) and
        ``"{module}_lora_weights_pointers_{layer}"`` (``torch.LongTensor``),
        with one row per entry of ``lora_uids``. A uid of ``"-1"``, or a
        module with no/zero rank for that uid, yields rank 0 and pointers
        ``[0, 0]``.
        """
        inputs = {}
        for layer_idx in mapping.pp_layers(num_layers):
            for lora_module in (self.lora_target_modules +
                                self.missing_qkv_modules):
                lora_ranks_ = []
                lora_ptrs_ = []
                for lora_uid in lora_uids:
                    lora_rank = 0
                    lora_ptrs = [0, 0]

                    if lora_uid != "-1":
                        low_ranks = self.uid_to_low_ranks(lora_uid)

                        if (layer_idx in low_ranks
                                and lora_module in low_ranks[layer_idx].keys()
                                and low_ranks[layer_idx][lora_module] != 0):

                            lora_rank = low_ranks[layer_idx][lora_module]
                            lora_ptrs = self.lora_weights_pointers_list[
                                lora_uid][layer_idx][lora_module]

                    lora_ranks_.append(lora_rank)
                    lora_ptrs_.append(lora_ptrs)

                inputs[
                    f'{lora_module}_lora_ranks_{layer_idx}'] = torch.IntTensor(
                        lora_ranks_)
                inputs[
                    f'{lora_module}_lora_weights_pointers_{layer_idx}'] = torch.LongTensor(
                        lora_ptrs_)
        return inputs