| import inspect |
| import re |
| from dataclasses import dataclass |
| from typing import Any, Callable, Dict, Generic, List, Optional, Tuple, TypeVar, Union, overload |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
|
|
| from nemo.lightning.pytorch.utils import extract_dtypes |
| from nemo.utils import logging |
|
|
| SourceModuleT = TypeVar("SourceModuleT", bound=nn.Module) |
| TargetModuleT = TypeVar("TargetModuleT", bound=nn.Module) |
| F = TypeVar("F", bound=Callable[..., Any]) |
|
|
|
|
@dataclass
class TransformCTX:
    """Context object passed to state-dict transforms during model conversion."""

    source: nn.Module
    source_state: dict
    target: nn.Module
    target_state: dict
|
|
|
|
| class _ModelState: |
| """ |
    Helper class used to modify the state dict of a source model during model conversion.
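
    A minimal illustrative example (hypothetical Hugging Face-style state dict):

    >>> hf_state = {"model.embed_tokens.weight": torch.zeros(8, 4)}
    >>> source = _ModelState(hf_state)
    >>> source.to(torch.bfloat16)
    >>> source.state_dict()["model.embed_tokens.weight"].dtype
    torch.bfloat16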
| """ |
|
|
| def __init__(self, state_dict, config=None): |
| self._state_dict = state_dict |
| self.config = config |
|
|
    def state_dict(self):
        """Return the wrapped state dict."""
        return self._state_dict
|
|
    def to(self, dtype):
        """Cast every tensor in the wrapped state dict to ``dtype``, warning on each cast."""
| for k, v in self._state_dict.items(): |
| if v.dtype != dtype: |
| logging.warning(f"Converting {k} from {v.dtype} (source model) to {dtype} (target model)") |
| self._state_dict[k] = v.to(dtype) |
|
|
|
|
| @torch.no_grad |
| def apply_transforms( |
| source: Union[nn.Module, _ModelState], |
| target: TargetModuleT, |
| mapping: Dict[str, str], |
| transforms: Optional[List[Callable[[TransformCTX], TransformCTX]]] = [], |
| state_dict_ignored_entries: List = [], |
| ) -> TargetModuleT: |
| """ |
| Applies a series of transformations to adapt the state dictionary of a source module to |
| match the structure of a target module's state dictionary. |
| |
| This function renames keys according to a provided mapping and modifies values using a list |
| of transformation functions. Each transformation function typically is decorated |
| with `io.state_transform`. |
| |
| Args: |
| source (nn.Module): The source module from which parameters and buffers are taken. |
| target (TargetModuleT): The target module to which parameters and buffers are adapted. |
| mapping (Dict[str, str]): Key-value pairs where each key from the source state dictionary |
| is mapped to a corresponding key in the target state dictionary. |
        transforms (Optional[List[Callable[[TransformCTX], TransformCTX]]]): A list of functions
            that modify the `TransformCTX` object. If None or empty, no transformations beyond
            key renaming are applied. Defaults to an empty list.
        state_dict_ignored_entries: Entries to ignore in the target module's state_dict(). This is
            needed when several state_dict entries point to the same parameter, e.g. in the
            Hugging Face T5 implementation `encoder.embed_tokens.weight`, `decoder.embed_tokens.weight`
            and `shared.weight` all point to the same shared parameter. In such cases, the
            redundant entries should be ignored.
| |
| Returns |
| ------- |
| TargetModuleT: The modified target module with its state dictionary adjusted according to |
| the specified mappings and transformations. |
| |
| Raises |
| ------ |
| ValueError: If there's a mismatch in shape between corresponding source and target parameters |
| or buffers. |
| RuntimeError: If the target state dictionary contains keys that are not present in the source |
| state dictionary after all transformations. |
| |
| Examples |
| -------- |
    >>> source_module = nn.Linear(10, 5)
    >>> target_module = nn.Linear(10, 5)
    >>> mapping = {'weight': 'weights', 'bias': 'biases'}
    >>> @io.state_transform(
    ...     source_key="weight",
    ...     target_key="weights",
    ... )
    ... def scale_weights(ctx, weight):
    ...     return weight * 2
    >>> transformed_target = apply_transforms(
    ...     source_module, target_module, mapping, [scale_weights]
    ... )
    >>> print(transformed_target.state_dict()['weights'])
| |
| See Also |
| -------- |
| - `TransformCTX`: For more details on the context object used in transformations. |
| - `StateDictTransform`: For creating complex transformations. |
| |
| Note: |
| This function is particularly useful when adapting models from different frameworks or |
| when consolidating models with different architectural changes. |
| """ |
| from megatron.core.transformer.module import MegatronModule |
|
|
| |
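    # If either side wraps a MegatronModule (e.g. a Lightning module exposing `.module`),
    # operate on the wrapped Megatron module directly.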
| _source = source |
| if hasattr(source, "module") and isinstance(source.module, MegatronModule): |
| _source = source.module |
| _target = target |
| if hasattr(target, "module") and isinstance(target.module, MegatronModule): |
| _target = target.module |
|
|
| |
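    # Snapshot the target's parameter dtypes so we can verify after conversion that no
    # parameter silently changed dtype.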
| target_orig_dtypes = extract_dtypes(_target.named_parameters()) |
|
|
| target_state = _target.state_dict() |
| ctx = TransformCTX( |
| source=_source, |
| source_state=_source.state_dict(), |
| target=_target, |
| target_state=target_state, |
| ) |
|
|
| for key, val in mapping.items(): |
| logging.debug(f"Mapping {key} -> {val}") |
| ctx = StateDictTransform(key, val)(ctx) |
|
|
| for transform in transforms: |
| logging.debug(f"Transforming {transform.source_key} -> {transform.target_key}") |
| ctx = transform(ctx) |
|
|
| _params: Dict[str, nn.Parameter] = {} |
| for name, param in _target.named_parameters(): |
| if name in target_state: |
| target_param = target_state[name] |
| if param.data.shape != target_param.shape: |
| raise ValueError( |
| f"Shape mismatch for parameter {name}: target shape {param.shape} vs " |
| f"converted source shape {target_param.shape}" |
| ) |
|
|
| _params[name] = nn.Parameter(target_param, requires_grad=param.requires_grad) |
| target_state.pop(name) |
| else: |
| print(f"Unexpected key: {name} not in checkpoint but in model.") |
|
|
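    # Re-register the converted parameters on the target module, walking dotted names
    # (e.g. "decoder.layers.0.weight") down to the owning submodule.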
| for key, val in _params.items(): |
| _module, _key = _target, key |
| if "." in key: |
| for part in key.split(".")[:-1]: |
| _module = getattr(_module, part) |
| _key = key.split(".")[-1] |
|
|
| _module.register_parameter(_key, val) |
|
|
| _buffers = {} |
| for name, buffer in _target.named_buffers(): |
| if name in target_state: |
| if buffer.shape != target_state[name].shape: |
| raise ValueError(f"Shape mismatch for buffer {name}: {buffer.shape} vs {target_state[name].shape}") |
|
|
| _buffers[name] = nn.Parameter(target_state[name], requires_grad=False) |
| target_state.pop(name) |
|
|
| for key, val in _buffers.items(): |
| _module, _key = _target, key |
| if "." in key: |
| for part in key.split(".")[:-1]: |
| _module = getattr(_module, part) |
| _key = key.split(".")[-1] |
|
|
| _module.register_buffer(_key, val) |
|
|
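    # Any keys still left in target_state were not consumed by a parameter or buffer of the
    # target model; ignore `*_extra_state` and explicitly ignored entries, then fail on the rest.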
| keys = list(filter(lambda x: x is not None and not x.endswith("_extra_state"), target_state.keys())) |
| keys = [key for key in keys if key not in state_dict_ignored_entries] |
| if len(keys) != 0: |
| raise RuntimeError(f"Additional keys: {keys} in checkpoint but not in model.") |
|
|
| |
| |
| |
| |
|
|
| """finally: |
| cls._set_model_restore_state(is_being_restored=False)""" |
|
|
| meta_tensor_keys = [] |
| for name, param in target.named_parameters(): |
| if param.is_meta: |
| meta_tensor_keys.append(name) |
|
|
    assert not meta_tensor_keys, (
        f"{meta_tensor_keys}\nThere are meta tensors in the model after conversion. "
        f"Did you forget to include these parameters in the mapping or transforms in `convert_state`?"
    )
|
|
    assert target_orig_dtypes == extract_dtypes(_target.named_parameters()), (
        f"dtype mismatch between the target's original and converted parameters. "
        f"Original (non-bf16) dtypes: { {k: v for k, v in target_orig_dtypes.items() if v != torch.bfloat16} }, "
        f"converted (non-bf16) dtypes: "
        f"{ {k: v for k, v in extract_dtypes(_target.named_parameters()).items() if v != torch.bfloat16} }"
    )
| if hasattr(target, "module") and isinstance(target.module, MegatronModule): |
| target.module = _target |
|
|
| return target |
|
|
| return _target |
|
|
|
|
def _default_transform(inp):
    """Identity transform used when a mapping entry only renames a key."""
    return inp
|
|
|
|
| class StateDictTransform(Generic[F]): |
| """ |
| A transformation class for state dictionaries, allowing for flexible key matching and |
| transformation of values between source and target state dictionaries. |
| |
| Attributes |
| ---------- |
| source_key: A string, tuple of strings, or a dictionary specifying the keys in the source |
| state dictionary to match. Wildcards (*) are supported. |
| target_key: A string or tuple of strings specifying the keys in the target state dictionary |
| to match. Wildcards (*) are supported. |
| transform: A callable that performs the transformation on matched keys' values. |
| |
| Examples |
| -------- |
| >>> def example_transform(ctx, *args): |
| ... return sum(args) |
| >>> transform = StateDictTransform( |
| ... source_key="model.layers.*.self_attn.*_proj.weight", |
| ... target_key="decoder.layers.*.self_attention.linear_qkv.weight", |
| ... transform=example_transform |
| ... ) |
| """ |
|
|
| def __init__( |
| self, |
| source_key: Union[str, Tuple[str, ...], Dict[str, str]], |
| target_key: Union[str, Tuple[str, ...]], |
| transform: F = _default_transform, |
| ): |
| self.source_key = source_key |
| self.target_key = target_key |
| self.transform = transform |
|
|
| def __call__(self, ctx: TransformCTX) -> TransformCTX: |
| source_key = self.source_key |
| target_key = self.target_key |
| source_dict, target_dict = ctx.source_state, ctx.target_state |
|
|
| fn_params = dict(inspect.signature(self.transform).parameters) |
| fn_params.pop("ctx", None) |
|
|
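        # Multiple named source patterns (tuple or dict): match each pattern separately,
        # zip the matches against the target matches, and call the transform with one
        # keyword argument per source pattern.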
| if isinstance(source_key, (dict, tuple)): |
| if isinstance(source_key, tuple): |
| source_key_dict = {param: source_key[i] for i, param in enumerate(fn_params)} |
| else: |
| source_key_dict = source_key |
| source_matches_dict = {k: _match_keys(list(source_dict.keys()), v) for k, v in source_key_dict.items()} |
| target_matches = _match_keys(list(target_dict.keys()), target_key) |
| param_names = list(filter(lambda x: x in source_matches_dict, fn_params)) |
| source_matches = [ |
| source_matches_dict[v] if source_matches_dict[v].ndim > 0 else [source_matches_dict[v].item()] |
| for v in param_names |
| ] |
| target_matches = [target_matches if target_matches.ndim > 0 else [target_matches.item()]] |
| for layer_names_group in zip(*(source_matches + target_matches)): |
| |
| if isinstance(layer_names_group[0], str): |
| layer_names_group = [[x] for x in layer_names_group] |
| for layer_names in zip(*layer_names_group): |
| target_dict[layer_names[-1]] = self.call_transform( |
| ctx, **dict(zip(param_names, [source_dict[x] for x in layer_names[:-1]])) |
| ) |
| logging.debug(f"Matched (transform)! {layer_names_group=}") |
| else: |
| source_keys = list(source_dict.keys()) |
| target_keys = list(target_dict.keys()) |
|
|
| source_matches = _match_keys(source_keys, source_key) |
| if source_matches.size == 1 and source_matches == np.array(None): |
| raise ValueError(f"No matches found for source key: {source_key}") |
|
|
| if isinstance(target_key, str): |
| target_matches = _match_keys(target_keys, target_key) |
| if target_matches.size == 1 and target_matches == np.array(None): |
| raise ValueError(f"No matches found for target key: {target_key}") |
| else: |
| if isinstance(target_key, dict): |
| raise ValueError("Target key must be a string or a tuple of strings.") |
| _matches = [_match_keys(target_keys, key) for key in target_key] |
| target_matches = np.stack(_matches, axis=-1) |
|
|
| |
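            # If the source match array has at least as many wildcard axes as the target,
            # each target entry is built from the source tensors at its index (e.g. merging
            # q/k/v into qkv); otherwise one source tensor fans out to several target keys
            # (e.g. splitting qkv).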
| multiple_sources = source_matches.ndim >= target_matches.ndim |
| accepts_var_args = any( |
| param.kind == param.VAR_POSITIONAL for param in inspect.signature(self.transform).parameters.values() |
| ) |
|
|
| if multiple_sources: |
| for target_index, target_match in np.ndenumerate(target_matches): |
| try: |
| source_match = source_matches[target_index] |
| except IndexError as e: |
| logging.error(f"Enountered IndexError during transform.\n{source_matches=}\n{target_matches=}") |
| raise e |
| if accepts_var_args: |
| source_values = [source_dict[k] for k in source_match] |
| target_dict[target_match] = self.call_transform(ctx, *source_values) |
| else: |
| _source_match_list = [source_match] if isinstance(source_match, str) else list(source_match) |
| if len(fn_params) != len(_source_match_list): |
| raise ValueError( |
| f"Mismatch between source and target keys: {source_match} vs {target_match}" |
| ) |
|
|
| kwargs = {param: source_dict[k] for param, k in zip(fn_params, _source_match_list)} |
| target_dict[target_match] = self.call_transform(ctx, **kwargs) |
| logging.debug(f"Matched (multi source)! {target_match=} {source_match=}") |
| else: |
| for source_index, source_match in np.ndenumerate(source_matches): |
| target_match = target_matches[source_index] |
| source_values = ( |
| [source_dict[source_match]] |
| if np.isscalar(source_match) |
| else [source_dict[k] for k in source_match] |
| ) |
| if accepts_var_args: |
| outputs = self.call_transform(ctx, *source_values) |
| else: |
| kwargs = {param: val for param, val in zip(fn_params, source_values)} |
| outputs = self.call_transform(ctx, **kwargs) |
|
|
| if isinstance(target_match, str): |
| target_dict[target_match] = outputs |
| else: |
| for i, t in enumerate(outputs): |
| target_dict[target_match[i]] = t |
| logging.debug(f"Matched (single source)! {target_match=} {source_match=}") |
|
|
| return ctx |
|
|
| def call_transform(self, ctx: TransformCTX, *args, **kwargs): |
| """Perform transform and check if the given args valid.""" |
| func_params = inspect.signature(self.transform).parameters |
| expected_num_args = len([p for p in func_params if p not in ['self', 'ctx']]) |
| provided_num_args = len(args) + len(kwargs) |
| accepts_var_args = any(param.kind == param.VAR_POSITIONAL for param in func_params.values()) |
|
|
| if not accepts_var_args and provided_num_args != expected_num_args: |
| raise ValueError( |
| f"Expected {expected_num_args} arguments for the transformation function, but got {provided_num_args}." |
| ) |
|
|
| if 'ctx' in func_params: |
| return self.transform(ctx, *args, **kwargs) |
|
|
| return self.transform(*args, **kwargs) |
|
|
|
|
def _match_keys(keys: List[str], pattern: str) -> np.ndarray:
    """Match ``keys`` against ``pattern``, where ``*`` matches a single dot-separated
    component and ``**`` matches across dots. Returns an object ndarray of the matched
    keys, with one axis per wildcard indexed by that wildcard's sorted captured values."""
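    # Illustrative example (hypothetical keys):
    #   _match_keys(["layers.0.w", "layers.1.w"], "layers.*.w")
    #   -> np.array(["layers.0.w", "layers.1.w"], dtype=object)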
| escaped_pattern = '' |
| i = 0 |
| wildcard_positions = [] |
| while i < len(pattern): |
| if pattern[i : i + 2] == '**': |
| escaped_pattern += r'(.+)' |
| wildcard_positions.append('**') |
| i += 2 |
| elif pattern[i] == '*': |
| escaped_pattern += r'([^.]+)' |
| wildcard_positions.append('*') |
| i += 1 |
| else: |
| if pattern[i] == '.': |
| escaped_pattern += r'\.' |
| else: |
| escaped_pattern += pattern[i] |
| i += 1 |
|
|
| regex_pattern = re.compile("^" + escaped_pattern + "$") |
| num_wildcards = len(wildcard_positions) |
| wildcard_matches = [[] for _ in range(num_wildcards)] |
|
|
| for key in filter(lambda x: x is not None, keys): |
| match = regex_pattern.match(key) |
| if match: |
| for i, group in enumerate(match.groups()): |
| if group not in wildcard_matches[i]: |
| wildcard_matches[i].append(group) |
|
|
| |
| for i in range(len(wildcard_matches)): |
| wildcard_matches[i].sort(key=lambda x: int(x) if x.isdigit() else x) |
|
|
| |
| shape = [len(matches) for matches in wildcard_matches] |
|
|
| if len(wildcard_matches) == 0: |
| |
| shape = [1] |
| |
| output_array = np.empty(shape, dtype=object) |
|
|
| |
| for key in filter(lambda x: x is not None, keys): |
| match = regex_pattern.match(key) |
| if match: |
| |
| indices = [wildcard_matches[i].index(group) for i, group in enumerate(match.groups())] |
| output_array[tuple(indices)] = key |
|
|
| return output_array |
|
|
|
|
| @overload |
| def state_transform( |
| source_key: Union[str, Tuple[str, ...], Dict[str, str]], |
| target_key: Union[str, Tuple[str, ...]], |
| ) -> Callable[[F], StateDictTransform[F]]: ... |
|
|
|
|
| @overload |
| def state_transform( |
| source_key: Union[str, Tuple[str, ...], Dict[str, str]], target_key: Union[str, Tuple[str, ...]], fn: F |
| ) -> StateDictTransform[F]: ... |
|
|
|
|
| def state_transform( |
| source_key: Union[str, Tuple[str, ...], Dict[str, str]], |
| target_key: Union[str, Tuple[str, ...]], |
| fn: Optional[F] = None, |
| ): |
| """ |
| A decorator for creating StateDictTransform instances with specified source and target keys, |
| and a transformation function. This allows for concise definition of state dictionary |
| transformations. |
| |
| Args: |
| source_key: A string, tuple of strings, or a dictionary specifying the keys in the source |
| state dictionary to match. Wildcards (*) are supported. |
| target_key: A string or tuple of strings specifying the keys in the target state dictionary |
| to match. Wildcards (*) are supported. |
| fn: An optional callable that performs the transformation on matched keys' values. If not |
| provided, the decorator can be used to wrap a function definition. |
| |
| Returns |
| ------- |
| A StateDictTransform instance if `fn` is provided, otherwise returns a decorator that |
| takes a function and returns a StateDictTransform instance. |
| |
| Examples |
| -------- |
| >>> @state_transform( |
| ... source_key="model.layers.*.self_attn.*_proj.weight", |
| ... target_key="decoder.layers.*.self_attention.linear_qkv.weight" |
| ... ) |
| ... def sum_transform(ctx, *args): |
| ... return sum(args) |
| """ |
|
|
| def wrapper(fn) -> StateDictTransform: |
| return StateDictTransform(source_key, target_key, fn) |
|
|
| if fn is None: |
| return wrapper |
|
|
| return wrapper(fn) |
|
|
|
|
| class TransformFns: |
| """ |
| A collection of common functions used in state dict transformation. |
| """ |
|
|
| @staticmethod |
| def split_qkv(ctx: TransformCTX, linear_qkv: torch.Tensor): |
| """ |
| Split interleave-concatenated qkv to q, k, v |
| |
| Example: export layer linear_qkv to HF {q|k|v}_proj |
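
        Layout sketch (hypothetical sizes): with ``num_attention_heads=8``,
        ``num_query_groups=4`` and ``kv_channels=64``, ``linear_qkv`` has
        ``(8 + 2 * 4) * 64 = 1024`` rows laid out per query group as ``[q, q, k, v]``;
        the returned q, k and v projections have ``8 * 64``, ``4 * 64`` and ``4 * 64``
        rows respectively.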
| """ |
| megatron_config = ctx.source.config |
|
|
| head_num = megatron_config.num_attention_heads |
| num_query_groups = megatron_config.num_query_groups |
| heads_per_group = head_num // num_query_groups |
| |
| head_size = megatron_config.kv_channels |
| qkv_total_dim = head_num + 2 * num_query_groups |
|
|
| linear_qkv = linear_qkv.reshape([qkv_total_dim, head_size, -1]) |
| |
| |
| hidden_size = linear_qkv.size(-1) |
| q_slice = torch.cat( |
| [ |
| torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) |
| for i in range(num_query_groups) |
| ] |
| ) |
| k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) |
| v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) |
|
|
| q_proj = linear_qkv[q_slice].reshape(-1, hidden_size).cpu() |
| k_proj = linear_qkv[k_slice].reshape(-1, hidden_size).cpu() |
| v_proj = linear_qkv[v_slice].reshape(-1, hidden_size).cpu() |
|
|
| return q_proj, k_proj, v_proj |
|
|
| @staticmethod |
| def split_qkv_bias(ctx: TransformCTX, qkv_bias: torch.Tensor): |
| """ |
| Split interleave-concatenated qkv bias to separate q, k, v bias |
| |
| Example: export layer linear_qkv bias to HF {q|k|v}_proj bias |
| """ |
| megatron_config = ctx.source.config |
|
|
| head_num = megatron_config.num_attention_heads |
| num_query_groups = megatron_config.num_query_groups |
| heads_per_group = head_num // num_query_groups |
| head_size = megatron_config.kv_channels |
| qkv_total_dim = head_num + 2 * num_query_groups |
|
|
| qkv_bias = qkv_bias.reshape([qkv_total_dim, head_size]) |
| q_slice = torch.cat( |
| [ |
| torch.arange((heads_per_group + 2) * i, (heads_per_group + 2) * i + heads_per_group) |
| for i in range(num_query_groups) |
| ] |
| ) |
| k_slice = torch.arange(heads_per_group, qkv_total_dim, (heads_per_group + 2)) |
| v_slice = torch.arange(heads_per_group + 1, qkv_total_dim, (heads_per_group + 2)) |
|
|
| q_bias = qkv_bias[q_slice].reshape(-1).cpu() |
| k_bias = qkv_bias[k_slice].reshape(-1).cpu() |
| v_bias = qkv_bias[v_slice].reshape(-1).cpu() |
|
|
| return q_bias, k_bias, v_bias |
|
|
| @staticmethod |
| def merge_qkv(ctx: TransformCTX, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): |
| """ |
| Merge q, k, v to interleave-concatenated qkv. |
| |
| Example: import HF {q|k|v}_proj to layer linear_qkv |
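
        Inverse of ``split_qkv``: for each query group, ``heads_per_group`` q heads are
        placed first, followed by one k head and one v head, yielding a weight with
        ``(num_attention_heads + 2 * num_query_groups) * kv_channels`` rows.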
| """ |
| megatron_config = ctx.target.config |
|
|
| head_num = megatron_config.num_attention_heads |
| num_query_groups = megatron_config.num_query_groups |
| heads_per_group = head_num // num_query_groups |
| hidden_size = megatron_config.hidden_size |
| head_size = megatron_config.kv_channels |
| old_tensor_shape = q.size() |
| new_q_tensor_shape = (head_num, head_size) + old_tensor_shape[1:] |
| new_kv_tensor_shape = (num_query_groups, head_size) + old_tensor_shape[1:] |
|
|
| q = q.view(*new_q_tensor_shape) |
| k = k.view(*new_kv_tensor_shape) |
| v = v.view(*new_kv_tensor_shape) |
|
|
| qkv_weights_l = [] |
| for i in range(num_query_groups): |
| qkv_weights_l.append(q[i * heads_per_group : (i + 1) * heads_per_group, :, :]) |
| qkv_weights_l.append(k[i : i + 1, :, :]) |
| qkv_weights_l.append(v[i : i + 1, :, :]) |
| qkv_weights = torch.cat(qkv_weights_l) |
| assert qkv_weights.ndim == 3, qkv_weights.shape |
| assert qkv_weights.shape[0] == (heads_per_group + 2) * num_query_groups, qkv_weights.shape |
| assert qkv_weights.shape[1] == head_size, qkv_weights.shape |
| assert qkv_weights.shape[2] == old_tensor_shape[1], qkv_weights.shape |
|
|
| qkv_weights = qkv_weights.reshape([head_size * (head_num + 2 * num_query_groups), hidden_size]) |
|
|
| return qkv_weights |
|
|
| @staticmethod |
| def merge_qkv_bias(ctx: TransformCTX, qb: torch.Tensor, kb: torch.Tensor, vb: torch.Tensor): |
| """ |
| Merge q, k, v bias to interleave-concatenated qkv bias. |
| |
| Example: import HF {q|k|v}_proj bias to layer linear_qkv bias |
| """ |
| megatron_config = ctx.target.config |
|
|
| head_num = megatron_config.num_attention_heads |
| num_query_groups = megatron_config.num_query_groups |
| heads_per_group = head_num // num_query_groups |
| head_size = megatron_config.kv_channels |
|
|
| new_q_tensor_shape = (head_num, head_size) |
| new_kv_tensor_shape = (num_query_groups, head_size) |
|
|
| qb = qb.view(*new_q_tensor_shape) |
| kb = kb.view(*new_kv_tensor_shape) |
| vb = vb.view(*new_kv_tensor_shape) |
|
|
| qkv_bias = torch.empty((0, head_size)).type_as(qb) |
| for i in range(num_query_groups): |
| qkv_bias = torch.cat((qkv_bias, qb[i * heads_per_group : (i + 1) * heads_per_group, :])) |
| qkv_bias = torch.cat((qkv_bias, kb[i : i + 1, :])) |
| qkv_bias = torch.cat((qkv_bias, vb[i : i + 1, :])) |
| qkv_bias = qkv_bias.reshape( |
| [ |
| head_size * (head_num + 2 * num_query_groups), |
| ] |
| ) |
| return qkv_bias |
|
|
| @staticmethod |
| def merge_fc1(gate: torch.Tensor, up: torch.Tensor): |
| """ |
| Merge gate and up proj into concatenated fc1 |
| |
| Example: import HF {gate|up}_proj to layer linear_fc1 |
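
        Shape sketch (hypothetical sizes): ``gate`` and ``up`` of shape
        ``[ffn_hidden_size, hidden_size]`` yield an fc1 weight of shape
        ``[2 * ffn_hidden_size, hidden_size]`` with the gate rows first.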
| """ |
| return torch.cat((gate, up), dim=0) |
|
|
| @staticmethod |
| def split_fc1(linear_fc1: torch.Tensor): |
| """ |
| Split concatenated fc1 to gate and up proj |
| |
| Example: export layer linear_fc1 to HF {gate|up}_proj |
| """ |
| gate_proj, up_proj = torch.chunk(linear_fc1, 2, dim=0) |
| return gate_proj, up_proj |
|
|
| @staticmethod |
| def duplicate2(param: torch.Tensor): |
| """ |
| Duplicate the source parameter to two target parameters |
| |
| Example: export Performant LoRA linear_fc1.adapter.linear_in to HF {gate|up}_proj.lora_A |
| """ |
| return param, param |
|
|
| @staticmethod |
| def duplicate3(param: torch.Tensor): |
| """ |
| Duplicate the source parameter to three target parameters |
| |
| Example: export Performant LoRA linear_qkv.adapter.linear_in to HF {q|k|v}_proj.lora_A |
| """ |
| return param, param, param |
|
|
| @staticmethod |
| def prune_padding(ctx: TransformCTX, embedding: torch.Tensor): |
| """ |
| Prune the embedding size to vocab size |
| |
| Example: export embedding/output layer to HF with non-padded vocab size |
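
        For example (hypothetical sizes): an embedding padded from a vocab size of 32000
        to 32128 rows for divisibility is truncated back to its first 32000 rows on export.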
| """ |
| megatron_config = ctx.target.config |
| return embedding[: megatron_config.vocab_size, :] |
|
|