import argparse
import copy
import dataclasses
import json
import os
from enum import IntFlag, auto
from functools import cached_property
from typing import Dict, List, Optional, Union

import numpy as np
import safetensors
import torch

from .._common import default_net
from .._utils import (get_init_params, numpy_to_torch, release_gc,
                      str_dtype_to_torch, str_dtype_to_trt, trt_dtype_to_torch)
from ..functional import PositionEmbeddingType, Tensor, gather_last_token_logits
from ..layers import (AttentionParams, Embedding, FusedGatedMLP, FusedRgLru,
                      GatedMLP, KeyValueCacheParams, LoraParams,
                      PromptTuningEmbedding, RgLru)
from ..layers.attention import Attention, BertAttention
from ..layers.linear import ColumnLinear, Linear, RowLinear
from ..layers.lora import Lora
from ..layers.moe import MOE, MoeOOTB
from ..logger import logger
from ..mapping import Mapping
from ..module import Module, ModuleList
from ..parameter import Parameter
from ..plugin import init_all_reduce_helper
from ..quantization import QuantMode
from ..quantization.layers import (WeightOnlyGroupwiseQuantLinear,
                                   WeightOnlyGroupwiseQuantRowLinear,
                                   WeightOnlyQuantLinear,
                                   WeightOnlyQuantRowLinear)
from ..quantization.mode import W8A8_SQ_PLUGIN_LIST, QuantAlgo
from ..top_model_mixin import TopModelMixin
from .convert_utils import weight_only_quantize_dict
from .generation_mixin import GenerationMixin

WEIGHT_LOADER_MODELS = {"PhiForCausalLM"}


class SpeculativeDecodingMode(IntFlag):

    NONE = auto()
    DRAFT_TOKENS_EXTERNAL = auto()
    MEDUSA = auto()
    LOOKAHEAD_DECODING = auto()
    EXPLICIT_DRAFT_TOKENS = auto()

    @staticmethod
    def from_arguments(args: argparse.Namespace):
        if args.speculative_decoding_mode is None:
            return SpeculativeDecodingMode.NONE
        elif args.speculative_decoding_mode == "draft_tokens_external":
            return SpeculativeDecodingMode.DRAFT_TOKENS_EXTERNAL
        elif args.speculative_decoding_mode == "medusa":
            return SpeculativeDecodingMode.MEDUSA
        elif args.speculative_decoding_mode == "lookahead_decoding":
            return SpeculativeDecodingMode.LOOKAHEAD_DECODING
        elif args.speculative_decoding_mode == "explicit_draft_tokens":
            return SpeculativeDecodingMode.EXPLICIT_DRAFT_TOKENS
        else:
            assert False, "Unknown speculative_decoding_mode " + args.speculative_decoding_mode
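
# Illustrative mapping from CLI arguments to SpeculativeDecodingMode (a minimal
# sketch, not part of the module; the Namespace below is a hypothetical stand-in
# for the parsed build arguments):
#
#   args = argparse.Namespace(speculative_decoding_mode="medusa")
#   assert SpeculativeDecodingMode.from_arguments(args) == SpeculativeDecodingMode.MEDUSA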


@dataclasses.dataclass
class QuantConfig:
    '''Serializable quantization configuration class, part of the PretrainedConfig.
    '''

    quant_algo: Optional[QuantAlgo] = None
    kv_cache_quant_algo: Optional[QuantAlgo] = None
    group_size: Optional[int] = 128
    smoothquant_val: Optional[float] = None
    clamp_val: Optional[List[float]] = None
    has_zero_point: Optional[bool] = False
    pre_quant_scale: Optional[bool] = False
    exclude_modules: Optional[List[str]] = None

    @property
    def use_plugin_sq(self):
        return self.quant_algo in W8A8_SQ_PLUGIN_LIST

    @cached_property
    def quant_mode(self) -> QuantMode:
        return QuantMode.from_quant_algo(
            self.quant_algo,
            self.kv_cache_quant_algo,
        )

    def quant_algo_to_modelopt_qformat(self):
        algo_to_modelopt_map = {
            QuantAlgo.W8A16: "int8_wo",
            QuantAlgo.W4A16: "int4_wo",
            QuantAlgo.W4A16_AWQ: "int4_awq",
            QuantAlgo.W4A8_AWQ: 'w4a8_awq',
            QuantAlgo.FP8: 'fp8',
            QuantAlgo.W8A8_SQ_PER_CHANNEL: 'int8_sq',
        }
        if self.quant_algo is not None:
            assert self.quant_algo in algo_to_modelopt_map, \
                f"Modelopt is not used for quantization algorithm {self.quant_algo}; this method should not be called"
            qformat = algo_to_modelopt_map[self.quant_algo]
        else:
            qformat = 'full_prec'
        return qformat

    @classmethod
    def from_dict(cls, config: dict):
        return cls(**config)

    def to_dict(self):
        return dataclasses.asdict(self)
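
# Illustrative use of QuantConfig (a minimal sketch; the values below are
# hypothetical and only exercise the methods defined above):
#
#   cfg = QuantConfig(quant_algo=QuantAlgo.W4A16_AWQ, group_size=128)
#   cfg.quant_mode                          # QuantMode derived from the algorithms
#   cfg.quant_algo_to_modelopt_qformat()    # 'int4_awq'
#   cfg.to_dict()                           # JSON-serializable dict for config.json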


def default_weight_loader(mapping: Mapping, param: torch.Tensor,
                          loaded_weight: torch.Tensor) -> None:
    """Default weight loader."""
    param.value = loaded_weight


def save_checkpoint(output_dir: str, config: dict, weights: dict) -> None:
    """Checkpoint saver for weight loader."""
    with open(os.path.join(output_dir, 'config.json'), 'w') as f:
        json.dump(config, f, indent=4)
    safetensors.torch.save_file(weights,
                                os.path.join(output_dir, 'rank0.safetensors'))
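
# The module-level save_checkpoint above writes a single-rank checkpoint layout
# (a sketch of the resulting directory, assuming output_dir already exists):
#
#   output_dir/
#     config.json          # serialized PretrainedConfig dict
#     rank0.safetensors    # all weights for rank 0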


class PretrainedConfig:

    def __init__(self,
                 *,
                 architecture: str,
                 dtype: str,
                 hidden_size: int,
                 num_hidden_layers: int,
                 num_attention_heads: int,
                 vocab_size: Optional[int] = None,
                 hidden_act: str = 'gelu',
                 logits_dtype: str = 'float32',
                 norm_epsilon: float = 1e-5,
                 position_embedding_type: Union[
                     PositionEmbeddingType,
                     str] = PositionEmbeddingType.learned_absolute,
                 max_position_embeddings: Optional[int] = None,
                 num_key_value_heads: Optional[int] = None,
                 intermediate_size: Optional[int] = None,
                 mapping: Optional[Union[Mapping, dict]] = None,
                 quantization: Optional[Union[QuantConfig, dict]] = None,
                 use_parallel_embedding: bool = False,
                 embedding_sharding_dim: int = 0,
                 share_embedding_table: bool = False,
                 head_size: Optional[int] = None,
                 qk_layernorm: bool = False,
                 **kwargs):
        self.architecture = architecture
        self.dtype = dtype
        self.vocab_size = vocab_size
        self.hidden_size = hidden_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.hidden_act = hidden_act

        self.logits_dtype = logits_dtype
        self.norm_epsilon = norm_epsilon

        if isinstance(position_embedding_type, str):
            position_embedding_type = PositionEmbeddingType.from_string(
                position_embedding_type)
        assert isinstance(position_embedding_type, PositionEmbeddingType)
        self.position_embedding_type = position_embedding_type

        self.max_position_embeddings = max_position_embeddings

        if num_key_value_heads is None:
            num_key_value_heads = num_attention_heads
        self.num_key_value_heads = num_key_value_heads

        if intermediate_size is None:
            intermediate_size = hidden_size * 4
        self.intermediate_size = intermediate_size

        if mapping is None:
            mapping = Mapping()
        elif isinstance(mapping, dict):
            mapping = Mapping.from_dict(mapping)
        assert isinstance(mapping, Mapping)
        self.mapping = mapping

        if quantization is None:
            quantization = QuantConfig()
        elif isinstance(quantization, dict):
            quantization = QuantConfig.from_dict(quantization)
        assert isinstance(quantization, QuantConfig)
        self.quantization = quantization

        self.use_parallel_embedding = use_parallel_embedding
        self.embedding_sharding_dim = embedding_sharding_dim
        self.share_embedding_table = share_embedding_table

        if share_embedding_table and mapping.tp_size > 1:
            if (not use_parallel_embedding) or (use_parallel_embedding and
                                                embedding_sharding_dim == 1):
                raise NotImplementedError(
                    "For tensor parallelism, sharing the embedding table requires "
                    "use_parallel_embedding=True and embedding_sharding_dim=0")
        if share_embedding_table and mapping.pp_size > 1:
            raise NotImplementedError(
                "Embedding table cannot be shared for pipeline parallelism")

        if head_size is None:
            head_size = hidden_size // num_attention_heads
        self.head_size = head_size
        self.qk_layernorm = qk_layernorm

        for key, value in kwargs.items():
            try:
                setattr(self, key, value)
                logger.warning(
                    f"Implicitly setting {self.__class__.__name__}.{key} = {value}"
                )
            except AttributeError as err:
                raise err

    @property
    def kv_dtype(self):
        if self.quant_mode.has_int8_kv_cache():
            return 'int8'
        elif self.quant_mode.has_fp8_kv_cache():
            return 'fp8'
        else:
            return self.dtype

    def set_if_not_exist(self, key, value):
        if not hasattr(self, key):
            setattr(self, key, value)

    @classmethod
    def from_dict(cls, config: dict):
        from . import MODEL_MAP
        model_cls = MODEL_MAP[config['architecture']]
        config_cls = getattr(model_cls, 'config_class', cls)
        return config_cls(**config)

    def to_dict(self):
        output = copy.deepcopy(self.__dict__)

        output['position_embedding_type'] = str(self.position_embedding_type)
        output['mapping'] = self.mapping.to_dict()
        output['mapping'].pop('rank')
        output['quantization'] = self.quantization.to_dict()

        return output

    @classmethod
    def from_json_file(cls, config_file: str):
        with open(config_file) as f:
            config = json.load(f)
            return cls.from_dict(config)

    @classmethod
    def from_checkpoint(cls, ckpt_dir: str):
        return cls.from_json_file(os.path.join(ckpt_dir, 'config.json'))

    def to_json_file(self, config_file: str):
        with open(config_file, 'w') as f:
            json.dump(self.to_dict(), f, indent=4)

    @property
    def quant_mode(self):
        return self.quantization.quant_mode

    def set_rank(self, rank):
        self.mapping = Mapping(self.mapping.world_size,
                               rank=rank,
                               tp_size=self.mapping.tp_size,
                               pp_size=self.mapping.pp_size,
                               moe_tp_size=self.mapping.moe_tp_size,
                               moe_ep_size=self.mapping.moe_ep_size,
                               gpus_per_node=self.mapping.gpus_per_node)
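
# Illustrative round trip through PretrainedConfig (a minimal sketch; the field
# values and file name are hypothetical):
#
#   config = PretrainedConfig(architecture='GPTJForCausalLM',
#                             dtype='float16',
#                             hidden_size=4096,
#                             num_hidden_layers=28,
#                             num_attention_heads=16,
#                             vocab_size=50400)
#   config.to_json_file('config.json')
#   same = PretrainedConfig.from_json_file('config.json')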


class DecoderLayerList(ModuleList):

    def __init__(self, cls, config):
        self.num_hidden_layers = config.num_hidden_layers
        self.layer_list = config.mapping.pp_layers(config.num_hidden_layers)
        super().__init__([cls(config, idx) for idx in self.layer_list])

    def forward(self,
                hidden_states,
                use_cache=False,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                position_ids=None,
                lora_params=None,
                spec_decoding_params=None):
        kv_cache_params.fill_none_tensor_list(len(self.layer_list))

        if use_cache:
            presents = []

        for layer_idx, (layer, past) in enumerate(
                zip(self, kv_cache_params.past_key_value)):

            lora_layer_params = None
            if lora_params is not None and lora_params.lora_ranks is not None:
                lora_layer_params = lora_params.get_layer_params(layer_idx)

            kwargs = {}
            if position_ids is not None:
                kwargs['position_ids'] = position_ids
            if lora_layer_params is not None:
                kwargs['lora_layer_params'] = lora_layer_params
            if spec_decoding_params is not None:
                kwargs['spec_decoding_params'] = spec_decoding_params
            if default_net().plugin_config.reduce_fusion:
                if layer_idx < self.layer_list[-1]:
                    kwargs['next_layer_input_layernorm_args'] = (
                        self[layer_idx + 1].input_layernorm.weight.value,
                        self[layer_idx + 1].input_layernorm.eps)
                else:
                    kwargs['next_layer_input_layernorm_args'] = None

            hidden_states = layer(
                hidden_states,
                use_cache=use_cache,
                attention_mask=attention_mask,
                kv_cache_params=KeyValueCacheParams(
                    past_key_value=[past],
                    host_past_key_value_lengths=kv_cache_params.
                    host_past_key_value_lengths,
                    host_max_attention_window_sizes=kv_cache_params.
                    host_max_attention_window_sizes,
                    host_sink_token_length=kv_cache_params.
                    host_sink_token_length,
                    kv_cache_block_offsets=kv_cache_params.
                    kv_cache_block_offsets,
                    host_kv_cache_block_offsets=kv_cache_params.
                    host_kv_cache_block_offsets,
                    host_kv_cache_pool_pointers=kv_cache_params.
                    host_kv_cache_pool_pointers,
                    cache_indirection=kv_cache_params.cache_indirection),
                attention_params=attention_params,
                **kwargs)

            if use_cache:
                presents.append(hidden_states[1])
                hidden_states = hidden_states[0]

        if use_cache:
            return hidden_states, presents
        return hidden_states


class PostInitCaller(type):

    def __call__(cls, *args, **kwargs):
        obj = type.__call__(cls, *args, **kwargs)
        obj.__post_init__()
        return obj
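
# PostInitCaller makes every instantiation of a class using it as metaclass run
# __post_init__ automatically after __init__ returns. A minimal, self-contained
# sketch of the pattern (the Toy class is hypothetical, for illustration only):
#
#   class Toy(metaclass=PostInitCaller):
#       def __init__(self):
#           self.quantized = False
#       def __post_init__(self):
#           self.quantized = True   # runs right after __init__
#
#   assert Toy().quantized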


class PretrainedModel(Module,
                      GenerationMixin,
                      TopModelMixin,
                      metaclass=PostInitCaller):

    def __init__(self, config: PretrainedConfig):
        super().__init__()
        init_all_reduce_helper()
        self.config = config

    def __post_init__(self):
        from ..quantization.quantize import quantize

        quantize(self, self.config.quantization)

        optimize_model(
            self,
            use_parallel_embedding=self.config.use_parallel_embedding,
            share_embedding_table=self.config.share_embedding_table,
        )

    def release(self):
        release_gc()

    def __del__(self):
        self.release()

    def check_config(self, config):
        raise NotImplementedError(
            f"{self.__class__} is an abstract class. Only subclasses of it should be instantiated and used."
        )

    @classmethod
    def from_config(cls, config: PretrainedConfig):
        return cls(config)

    @classmethod
    def from_checkpoint(cls,
                        ckpt_dir: str,
                        rank: Optional[int] = None,
                        config: Optional[PretrainedConfig] = None):
        if config is None:
            config = PretrainedConfig.from_json_file(
                os.path.join(ckpt_dir, 'config.json'))

        if rank is not None:
            config.set_rank(rank)

        if config.architecture in WEIGHT_LOADER_MODELS:
            weights_path = os.path.join(ckpt_dir, 'rank0.safetensors')
        else:
            rank = config.mapping.rank
            weights_path = os.path.join(ckpt_dir, f'rank{rank}.safetensors')

        assert os.path.isfile(weights_path)
        weights = safetensors.torch.load_file(weights_path)

        is_checkpoint_pruned = getattr(config, 'is_pruned', False)
        preprocess_weights(weights, config, from_pruned=is_checkpoint_pruned)
        model = cls(config)
        model.load(weights, from_pruned=is_checkpoint_pruned)
        return model
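
    # Typical checkpoint-loading flow built on from_checkpoint (an illustrative
    # sketch; 'SomeForCausalLM' and the directory names are hypothetical):
    #
    #   config = PretrainedConfig.from_checkpoint('ckpt_dir')
    #   model = SomeForCausalLM.from_checkpoint('ckpt_dir', rank=0, config=config)
    #   model.save_checkpoint('out_dir', save_config=True)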

    def load(self, weights, from_pruned=False):
        expected_names = set()
        required_names = set()
        for name, param in self.named_parameters():
            expected_names.add(name)
            if not param.is_inited():
                required_names.add(name)

        provided_names = set(weights.keys())
        if not required_names.issubset(provided_names):
            raise RuntimeError(
                f"Required but not provided tensors: {required_names.difference(provided_names)}"
            )
        if not provided_names.issubset(expected_names):
            logger.warning(
                f"Provided but not expected tensors: {provided_names.difference(expected_names)}"
            )

        if self.config.architecture in WEIGHT_LOADER_MODELS:
            mapping = self.config.mapping
            for name, param in self.named_parameters():
                if name in provided_names:
                    weight_loader = getattr(param, "weight_loader",
                                            default_weight_loader)
                    if from_pruned and param._shape != weights[name].shape:
                        dummy_weight = torch.empty(param._shape,
                                                   dtype=trt_dtype_to_torch(
                                                       param._dtype))
                        weight_loader(mapping, param, dummy_weight)
                    else:
                        weight_loader(mapping, param, weights[name])
        else:
            for name, param in self.named_parameters():
                if name in provided_names:
                    if not from_pruned:
                        try:
                            param.value = weights[name]
                        except Exception as e:
                            raise RuntimeError(
                                f"Encountered error '{e}' for parameter '{name}'")
                    else:
                        param.set_value_or_dummy(weights[name])

    def load_partial_weights(self, weights: dict):
        params = {name: param for name, param in self.named_parameters()}
        mapping = self.config.mapping

        for k, v in weights.items():
            if k in params.keys():
                param = params[k]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(mapping, param, v)
            elif mapping.pp_size == 1:
                logger.warning(f"Provided but not expected tensors: {k}")

    def save_checkpoint(self, output_dir, save_config=True):
        # Multiple ranks can share one config.json; save_config lets only one
        # rank write it.
        rank = self.config.mapping.rank
        weights = {
            name: numpy_to_torch(param.raw_value)
            for name, param in self.named_parameters()
        }
        safetensors.torch.save_file(
            weights, os.path.join(output_dir, f'rank{rank}.safetensors'))
        if save_config:
            self.config.to_json_file(os.path.join(output_dir, 'config.json'))

    def prepare_inputs(
            self,
            max_batch_size,
            max_input_len,
            max_seq_len,
            max_num_tokens,
            use_cache,
            max_beam_width: int = 1,
            opt_num_tokens: int = None,
            prompt_embedding_table_size: int = 0,
            position_encoding_2d: bool = False,
            max_draft_len: int = 0,
            speculative_decoding_draft_tokens_external: bool = False,
            spec_decoding_is_generation_length_variable: bool = False,
            gather_context_logits: bool = False,
            gather_generation_logits: bool = False,
            lora_target_modules: List[str] = None,
            opt_batch_size: int = 0):
        '''@brief: Prepare input Tensors for the model; the given sizes are used to determine
        the ranges of the dimensions when using TRT dynamic shapes.

        @return: a dict of values that can be fed into self.forward()
        '''

        remove_input_padding = default_net().plugin_config.remove_input_padding
        use_gpt_attention_plugin = default_net(
        ).plugin_config.gpt_attention_plugin
        use_gemm_plugin = default_net().plugin_config.gemm_plugin
        paged_kv_cache = default_net().plugin_config.paged_kv_cache
        tokens_per_block = default_net().plugin_config.tokens_per_block
        use_lora_plugin = default_net().plugin_config.lora_plugin
        multiple_profiles = default_net().plugin_config.multiple_profiles
        streamingllm = default_net().plugin_config.streamingllm

        model_inputs = self.prepare_basic_inputs(
            max_batch_size=max_batch_size,
            max_beam_width=max_beam_width,
            max_input_len=max_input_len,
            max_seq_len=max_seq_len,
            hidden_size=self.config.hidden_size,
            num_kv_heads=self.config.num_key_value_heads,
            head_size=self.config.head_size,
            num_layers=self.config.num_hidden_layers,
            kv_dtype=str_dtype_to_trt(self.config.kv_dtype),
            remove_input_padding=remove_input_padding,
            use_gpt_attention_plugin=use_gpt_attention_plugin,
            use_gemm_plugin=use_gemm_plugin,
            paged_kv_cache=paged_kv_cache,
            tokens_per_block=tokens_per_block,
            num_heads=self.config.num_attention_heads,
            max_num_tokens=max_num_tokens,
            opt_num_tokens=opt_num_tokens,
            dtype=str_dtype_to_trt(self.config.dtype),
            prompt_embedding_table_size=prompt_embedding_table_size,
            position_encoding_2d=position_encoding_2d,
            mapping=self.config.mapping,
            gather_context_logits=gather_context_logits,
            gather_generation_logits=gather_generation_logits,
            use_lora_plugin=use_lora_plugin,
            max_draft_len=max_draft_len,
            speculative_decoding_draft_tokens_external=
            speculative_decoding_draft_tokens_external,
            spec_decoding_is_generation_length_variable=
            spec_decoding_is_generation_length_variable,
            lora_target_modules=lora_target_modules,
            multiple_profiles=multiple_profiles,
            streamingllm=streamingllm,
            opt_batch_size=opt_batch_size)

        result = {
            'input_ids':
            model_inputs['input_ids'],
            'position_ids':
            model_inputs['position_ids'],
            'use_cache':
            True,
            'last_token_ids':
            model_inputs['last_token_ids'],
            'attention_mask':
            model_inputs['attention_mask'],
            'kv_cache_params':
            KeyValueCacheParams(
                past_key_value=model_inputs['past_key_value'],
                host_past_key_value_lengths=model_inputs[
                    'host_past_key_value_lengths'],
                host_max_attention_window_sizes=model_inputs[
                    'host_max_attention_window_sizes'],
                host_sink_token_length=model_inputs['host_sink_token_length'],
                kv_cache_block_offsets=model_inputs['kv_cache_block_offsets'],
                host_kv_cache_block_offsets=model_inputs[
                    'host_kv_cache_block_offsets'],
                host_kv_cache_pool_pointers=model_inputs[
                    'host_kv_cache_pool_pointers'],
                cache_indirection=model_inputs['cache_indirection'],
            ),
            'attention_params':
            AttentionParams(
                sequence_length=model_inputs['sequence_length'],
                context_lengths=model_inputs['context_lengths'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'],
                host_runtime_perf_knobs=model_inputs['host_runtime_perf_knobs'])
        }

        if prompt_embedding_table_size > 0:
            result['prompt_embedding_table'] = model_inputs[
                'prompt_embedding_table']
            result['prompt_tasks'] = model_inputs['tasks']
            result['prompt_vocab_size'] = model_inputs['prompt_vocab_size']
        if model_inputs['hidden_states_input'] is not None:
            result['hidden_states'] = model_inputs['hidden_states_input']
        if use_lora_plugin:
            result['lora_params'] = LoraParams(
                model_inputs['lora_ranks'],
                model_inputs['lora_weights_pointers'],
                host_context_lengths=model_inputs['host_context_lengths'],
                max_context_length=max_input_len,
                host_request_types=model_inputs['host_request_types'])
        if model_inputs['spec_decoding_params'] is not None:
            result['spec_decoding_params'] = model_inputs[
                'spec_decoding_params']

        return result
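
    # Build-time usage of prepare_inputs (an illustrative sketch; it assumes the
    # model is being traced into a TensorRT network and the sizes are hypothetical):
    #
    #   inputs = model.prepare_inputs(max_batch_size=8,
    #                                 max_input_len=1024,
    #                                 max_seq_len=2048,
    #                                 max_num_tokens=8192,
    #                                 use_cache=True)
    #   model(**inputs)   # runs forward() on the dynamic-shape input tensors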

    @classmethod
    def quantize(
        cls,
        hf_model_dir: str,
        output_dir: str,
        dtype: str = 'float16',
        mapping: Optional[Mapping] = None,
        quant_config: Optional[QuantConfig] = None,
        *,
        device: str = 'cuda',
        calib_dataset: str = 'cnn_dailymail',
        calib_batches: int = 512,
        calib_batch_size: int = 1,
        calib_max_seq_length: int = 512,
        random_seed: int = 1234,
        tokenizer_max_seq_length: int = 2048,
    ):
        if mapping is None:
            mapping = Mapping()
        if mapping.moe_ep_size > 1:
            raise NotImplementedError(
                "Quantization for expert parallelism is not supported")
        modelopt_qformat = quant_config.quant_algo_to_modelopt_qformat()
        kv_cache_dtype = quant_config.kv_cache_quant_algo
        assert modelopt_qformat is not None
        from ..quantization import quantize_and_export
        # quantize_and_export expects a plain string path, not a Path object.
        hf_model_dir = str(hf_model_dir)
        quantize_and_export(
            model_dir=hf_model_dir,
            device=device,
            calib_dataset=calib_dataset,
            dtype=dtype,
            qformat=modelopt_qformat,
            kv_cache_dtype=kv_cache_dtype,
            calib_size=calib_batches,
            batch_size=calib_batch_size,
            calib_max_seq_length=calib_max_seq_length,
            awq_block_size=quant_config.group_size,
            output_dir=output_dir,
            tp_size=mapping.tp_size,
            pp_size=mapping.pp_size,
            seed=random_seed,
            tokenizer_max_seq_length=tokenizer_max_seq_length,
        )
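
# Offline quantization flow built on the classmethod above (an illustrative
# sketch; the paths and the chosen algorithm are hypothetical, and in practice
# the method is invoked on a concrete architecture subclass):
#
#   PretrainedModel.quantize('hf_model_dir',
#                            'quantized_ckpt_dir',
#                            dtype='float16',
#                            quant_config=QuantConfig(quant_algo=QuantAlgo.FP8))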


class DecoderModelForCausalLM(PretrainedModel):

    def __init__(self, config: PretrainedConfig, transformer, lm_head):
        super().__init__(config)
        self.transformer = transformer
        self.lm_head = lm_head
        self.mup_width_multiplier = getattr(config, 'mup_width_multiplier',
                                            None)

    def forward(self,
                input_ids: Tensor,
                position_ids=None,
                use_cache=False,
                last_token_ids=None,
                attention_mask=None,
                kv_cache_params=None,
                attention_params=None,
                hidden_states=None,
                prompt_embedding_table: Optional[Tensor] = None,
                prompt_tasks: Optional[Tensor] = None,
                prompt_vocab_size: Optional[Tensor] = None,
                lora_params=None,
                spec_decoding_params=None):
        kwargs = {
            'input_ids': input_ids,
            'position_ids': position_ids,
            'use_cache': use_cache,
            'attention_mask': attention_mask,
            'kv_cache_params': kv_cache_params,
            'attention_params': attention_params,
        }
        if lora_params is not None:
            kwargs['lora_params'] = lora_params
        if hidden_states is not None:
            kwargs['hidden_states'] = hidden_states
        if prompt_embedding_table is not None:
            kwargs['prompt_embedding_table'] = prompt_embedding_table
        if prompt_tasks is not None:
            kwargs['prompt_tasks'] = prompt_tasks
        if prompt_vocab_size is not None:
            kwargs['prompt_vocab_size'] = prompt_vocab_size

        if spec_decoding_params is not None:
            kwargs['spec_decoding_params'] = spec_decoding_params

        hidden_states = self.transformer.forward(**kwargs)

        if use_cache:
            hidden_states, presents = hidden_states

        if self.config.mapping.is_last_pp_rank():
            hidden_states = gather_last_token_logits(
                hidden_states, last_token_ids,
                default_net().plugin_config.remove_input_padding)

            lm_logits = self.lm_head(hidden_states)
            if hasattr(self.config, 'output_multiplier_scale'):
                lm_logits *= getattr(self.config, 'output_multiplier_scale', 1)
            if self.mup_width_multiplier is not None:
                lm_logits = lm_logits / self.mup_width_multiplier
            lm_logits.mark_output('logits', self.config.logits_dtype)
        else:
            hidden_states.mark_output('hidden_states_output', self.config.dtype)

        if use_cache and not default_net().plugin_config.paged_kv_cache:
            for i, present in zip(
                    self.config.mapping.pp_layers(
                        self.config.num_hidden_layers), presents):
                present.mark_output(f'present_key_value_{i}',
                                    self.config.kv_dtype)
            if self.config.mapping.is_last_pp_rank():
                return (lm_logits, presents, hidden_states)
            return (hidden_states, presents)
        else:
            if self.config.mapping.is_last_pp_rank():
                return lm_logits, hidden_states
            return hidden_states


def fuse_gate_mlp(
    model: PretrainedModel,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
) -> PretrainedModel:
    from ..quantization.quantize import fp8_quantize

    quant_algo = model.config.quantization.quant_algo
    for name, mlp, layer in model.named_modules_with_parent():
        if isinstance(mlp, GatedMLP):
            init_params = get_init_params(mlp)
            init_params["inner_layernorm"] = mlp.inner_layernorm is not None
            fused_layer = FusedGatedMLP(**init_params)

            if quant_algo == QuantAlgo.FP8:
                fused_layer = fp8_quantize(fused_layer,
                                           model.config.quantization)

                if isinstance(mlp.dtype, str):
                    dtype = str_dtype_to_torch(mlp.dtype)
                else:
                    dtype = trt_dtype_to_torch(mlp.dtype)

                # Dequantize the gate and fc weights back to the working dtype.
                gate_weight = numpy_to_torch(
                    mlp.gate.weight.raw_value).to(dtype) * numpy_to_torch(
                        mlp.gate.weights_scaling_factor.raw_value)
                fc_weight = numpy_to_torch(
                    mlp.fc.weight.raw_value).to(dtype) * numpy_to_torch(
                        mlp.fc.weights_scaling_factor.raw_value)

                # Concatenate them into a single fused weight.
                fused_weight = torch.cat([gate_weight, fc_weight], dim=0)

                # Re-quantize with a shared (maximum) scaling factor.
                fused_weight_scaling_factor = numpy_to_torch(
                    max(
                        mlp.gate.weights_scaling_factor.raw_value,
                        mlp.fc.weights_scaling_factor.raw_value,
                    ))
                fused_weight = (fused_weight / fused_weight_scaling_factor).to(
                    torch.float8_e4m3fn)

                if gemm_swiglu_plugin_dtype == 'fp8':
                    # The gemm_swiglu plugin expects (k, n)-shaped weights.
                    fused_layer.fused_fc.weight = Parameter(
                        shape=(fused_layer.fused_fc.in_features,
                               fused_layer.fused_fc.out_features),
                        dtype='fp8')
                    fused_layer.fused_fc.weight.value = fused_weight.view(
                        fused_layer.fused_fc.in_features,
                        fused_layer.fused_fc.out_features)
                else:
                    fused_layer.fused_fc.weight.value = fused_weight
                fused_layer.fused_fc.weights_scaling_factor.value = fused_weight_scaling_factor

                fused_layer.fused_fc.activation_scaling_factor.value = max(
                    mlp.gate.activation_scaling_factor.raw_value,
                    mlp.fc.activation_scaling_factor.raw_value,
                )
            elif quant_algo is None:
                fused_layer.fused_fc.weight.value = np.concatenate(
                    [
                        mlp.gate.weight.raw_value,
                        mlp.fc.weight.raw_value,
                    ],
                    axis=0,
                )
                if mlp.bias:
                    fused_layer.fused_fc.bias.value = np.concatenate(
                        [mlp.gate.bias.raw_value, mlp.fc.bias.raw_value],
                        axis=0)
            else:
                raise ValueError(f'Unsupported quant algo: {quant_algo}')

            fused_layer.proj = mlp.proj
            fused_layer.inner_layernorm = mlp.inner_layernorm

            mlp_name = name.rsplit('.', 1)[-1]
            setattr(layer, mlp_name, fused_layer)

    return model


def unfuse_qkv_gemm(model: PretrainedModel) -> PretrainedModel:
    '''Split each Attention layer's fused QKV GEMM into three GEMMs (layer.q,
    layer.k, layer.v) and return the changed model.
    '''
    from ..quantization.quantize import quantize

    for name, layer in model.named_modules():
        if isinstance(layer, Attention) and not layer.cross_attention:
            assert layer.tp_size == 1, "please disable manual TP when enabling auto parallel"
            if layer.qkv is None:
                continue
            qkv_params = get_init_params(layer.qkv, ColumnLinear)
            qkv_params["bias"] = qkv_params["bias"] is not None
            qkv_params["strict_dtype"] = qkv_params.get(
                "strict_dtype") is not None
            q = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_heads *
                    layer.attention_head_size,
                })
            k = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_kv_heads *
                    layer.attention_head_size,
                })
            v = ColumnLinear(
                **{
                    **qkv_params,
                    "out_features":
                    layer.tp_size * layer.num_attention_kv_heads *
                    layer.attention_head_size,
                })
            q = quantize(q, model.config.quantization)
            k = quantize(k, model.config.quantization)
            v = quantize(v, model.config.quantization)
            out_features = q.out_features + k.out_features + v.out_features
            if isinstance(layer.qkv, (
                    WeightOnlyQuantLinear,
                    WeightOnlyQuantRowLinear,
                    WeightOnlyGroupwiseQuantLinear,
                    WeightOnlyGroupwiseQuantRowLinear,
            )):
                out_dim = 1
            else:
                out_dim = 0
            if layer.qkv.weight.is_inited():
                qkv_weight = layer.qkv.weight.raw_value
                weights = np.split(qkv_weight, [
                    qkv_weight.shape[out_dim] * q.out_features // out_features,
                    qkv_weight.shape[out_dim] *
                    (q.out_features + k.out_features) // out_features,
                ],
                                   axis=out_dim)
                for gemm, weight in zip([q, k, v], weights):
                    gemm.weight.value = weight
            if layer.qkv.bias is not None and layer.qkv.bias.is_inited():
                qkv_bias = layer.qkv.bias.raw_value
                biases = np.split(qkv_bias, [
                    qkv_bias.shape[out_dim] * q.out_features // out_features,
                    qkv_bias.shape[out_dim] *
                    (q.out_features + k.out_features) // out_features,
                ],
                                  axis=out_dim)
                for gemm, bias in zip([q, k, v], biases):
                    gemm.bias.value = bias
            for name, parameter in layer.qkv._parameters.items():
                if name not in ["weight", "bias"]:
                    for gemm in [q, k, v]:
                        setattr(gemm, name, parameter)
            layer.q = q
            layer.k = k
            layer.v = v
            layer.qkv = None
    return model


def fuse_rg_lru(model: PretrainedModel) -> PretrainedModel:
    for name, rg_lru, parent in model.named_modules_with_parent():
        if isinstance(rg_lru, RgLru):
            fused_layer = FusedRgLru(**get_init_params(rg_lru))
            fused_layer.gate.weight.value = np.concatenate(
                [
                    rg_lru.input_gate.weight.raw_value,
                    rg_lru.recurrent_gate.weight.raw_value,
                ],
                axis=-1,
            )
            fused_layer.gate.bias.value = np.concatenate(
                [
                    rg_lru.input_gate.bias.raw_value,
                    rg_lru.recurrent_gate.bias.raw_value,
                ],
                axis=-1,
            )
            fused_layer.recurrent_param.value = rg_lru.recurrent_param.raw_value
            rg_lru_name = name.rsplit('.', 1)[-1]
            setattr(parent, rg_lru_name, fused_layer)
    return model


def set_prompt_tuning(model: PretrainedModel) -> PretrainedModel:
    '''Replace the given model's embedding layer with a PromptTuningEmbedding layer in place and return the changed model.

    Pre-condition: vocab_embedding exists.
    Post-condition: isinstance(vocab_embedding, PromptTuningEmbedding)
    '''
    for name, embedding, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "vocab_embedding" and isinstance(embedding, Embedding):
            ptuning_embedding = PromptTuningEmbedding(
                **get_init_params(embedding))
            ptuning_embedding.weight.value = embedding.weight.raw_value
            parent.vocab_embedding = ptuning_embedding
    return model


def add_lora(model: PretrainedModel,
             max_lora_rank: Optional[int]) -> PretrainedModel:
    '''Add LoRA layers to the Attention/BertAttention/Linear/RowLinear/FusedGatedMLP layers of the given model and return the changed model.
    '''
    for name, layer in model.named_modules():
        max_rank = max_lora_rank
        if isinstance(layer, (Attention, BertAttention)):
            if max_rank is None:
                max_rank = min(
                    layer.hidden_size,
                    layer.num_attention_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size)
            layer.qkv_lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[
                    layer.num_attention_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size,
                    layer.num_attention_kv_heads * layer.attention_head_size
                ],
                max_low_rank=max_rank,
            )
        if isinstance(layer, (Linear, RowLinear)):
            if max_rank is None:
                max_rank = min(layer.in_features, layer.out_features)
            layer.lora = Lora(
                in_hidden_size=layer.in_features,
                out_hidden_sizes=[layer.out_features],
                max_low_rank=max_rank,
            )
        if isinstance(layer, FusedGatedMLP):
            if max_rank is None:
                max_rank = min(layer.hidden_size,
                               layer.ffn_hidden_size // layer.tp_size)
            layer.lora = Lora(
                in_hidden_size=layer.hidden_size,
                out_hidden_sizes=[
                    layer.ffn_hidden_size // layer.tp_size,
                    layer.ffn_hidden_size // layer.tp_size
                ],
                max_low_rank=max_rank,
            )
    return model


def to_ootb_moe(model: PretrainedModel) -> PretrainedModel:
    '''Use OOTB MoE instead of the MoE plugin and return the changed model.
    '''
    for name, layer, parent in model.named_modules_with_parent():
        if isinstance(layer, MOE):
            layer_name = name.rsplit('.', 1)[-1]
            ootb_layer = layer.to(MoeOOTB, model.config.quantization)
            setattr(parent, layer_name, ootb_layer)
    return model


def parallelize_embedding(model: PretrainedModel) -> PretrainedModel:
    for name, embedding, parent in model.named_modules_with_parent():
        layer_name = name.rsplit('.', 1)[-1]
        if isinstance(embedding, Embedding) and embedding.tp_group is None:
            init_params = get_init_params(embedding)
            init_params["tp_group"] = model.config.mapping.tp_group
            init_params["tp_size"] = model.config.mapping.tp_size
            init_params["tp_rank"] = model.config.mapping.tp_rank
            init_params["sharding_dim"] = model.config.embedding_sharding_dim
            new_embedding = embedding.__class__(**init_params)
            setattr(parent, layer_name, new_embedding)
    return model


def share_embedding(model: PretrainedModel) -> PretrainedModel:
    lm_head = None
    vocab_embedding = None
    for name, layer in model.named_modules():
        layer_name = name.rsplit('.', 1)[-1]
        if layer_name == "lm_head":
            lm_head = layer
        if layer_name == "vocab_embedding":
            vocab_embedding = layer
        if lm_head is not None and vocab_embedding is not None:
            break

    if lm_head is not None and vocab_embedding is not None:
        lm_head.weight = vocab_embedding.weight
        if (hasattr(vocab_embedding, "per_token_scale")
                and vocab_embedding.per_token_scale is not None):
            lm_head.per_channel_scale = vocab_embedding.per_token_scale
    return model


def set_fp8_context_fhma(model: PretrainedModel) -> PretrainedModel:
    for name, layer in model.named_modules():
        if isinstance(layer, Attention):
            scale = [1.0] / layer.dense.activation_scaling_factor.raw_value
            layer.attention_output_orig_quant_scale = Parameter(
                value=scale.astype(np.float32))
    return model


def optimize_model(
    model: PretrainedModel,
    use_parallel_embedding: bool = False,
    share_embedding_table: bool = False,
    use_ootb_moe: bool = False,
    use_fused_mlp: bool = False,
    gemm_swiglu_plugin_dtype: Optional[str] = None,
    use_fused_rg_lru: bool = False,
    use_unfused_qkv_gemm: bool = False,
    use_prompt_tuning: bool = False,
    use_lora: bool = False,
    max_lora_rank: Optional[int] = None,
    use_fp8_context_fmha: bool = False,
) -> PretrainedModel:
    """
    Run optimization passes on the model.
    Some passes depend on each other, so the passes always run in the order of
    the arguments to guarantee a fixed execution order.
    """

    if use_parallel_embedding:
        model = parallelize_embedding(model)
    if share_embedding_table:
        model = share_embedding(model)

    if use_ootb_moe:
        model = to_ootb_moe(model)
    if use_fused_mlp:
        model = fuse_gate_mlp(model, gemm_swiglu_plugin_dtype)
    if use_fused_rg_lru:
        model = fuse_rg_lru(model)
    if use_unfused_qkv_gemm:
        model = unfuse_qkv_gemm(model)
    if use_prompt_tuning:
        model = set_prompt_tuning(model)
    if use_lora:
        model = add_lora(model, max_lora_rank)
    if use_fp8_context_fmha:
        model = set_fp8_context_fhma(model)
    return model
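
# Example invocation of the optimization pipeline (an illustrative sketch; which
# passes make sense depends on the architecture and the enabled plugins):
#
#   model = optimize_model(model,
#                          use_parallel_embedding=True,
#                          use_fused_mlp=True,
#                          use_lora=True,
#                          max_lora_rank=64)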


def preprocess_weights(weights: Dict[str, torch.Tensor],
                       model_config: PretrainedConfig,
                       from_pruned=False) -> None:
    """This function in-place modifies weights and model_config, making them compatible with each other.

    Note: Typically, it should be called before model creation and weight loading. For example,
        preprocess_weights(weights, model_config)
        model = XXXForCausalLM(model_config)
        model.load(weights)
    """
    quant_algo = model_config.quantization.quant_algo
    kv_cache_quant_algo = model_config.quantization.kv_cache_quant_algo

    # AWQ (W4A16 / W4A8): pre-process weights for the mixed GEMM kernels.
    if quant_algo == QuantAlgo.W4A8_AWQ or quant_algo == QuantAlgo.W4A16_AWQ:
        preprocessor = torch.ops.trtllm.preprocess_weights_for_mixed_gemm
        if quant_algo == QuantAlgo.W4A8_AWQ:
            activation_type = torch.float8_e4m3fn
        elif quant_algo == QuantAlgo.W4A16_AWQ:
            activation_type = torch.float16
        for name, param in weights.items():
            if from_pruned and param.numel() == 0:
                continue
            if name.endswith('weight') and param.dtype == torch.int8:
                dtype = torch.float16
                if model_config.dtype == "bfloat16":
                    dtype = torch.bfloat16
                weights[name] = preprocessor(param.T.contiguous(),
                                             torch.quint4x2,
                                             activation_type).view(dtype)
            if name.endswith('weights_scaling_factor'):
                weights[name] = param.T.contiguous().to(
                    str_dtype_to_torch(model_config.dtype))
            if name.endswith('prequant_scaling_factor'):
                weights[name] = param.reshape(1, -1)
            if model_config.mapping.tp_rank > 0:
                if name.endswith('attention.dense.bias') or name.endswith(
                        'mlp.proj.bias'):
                    weights[name] = torch.zeros_like(param)

        if quant_algo == QuantAlgo.W4A8_AWQ:
            for name in list(weights):
                if name.endswith('weights_scaling_factor'):
                    activation_scaling_factor = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'activation_scaling_factor'))
                    weights_scaling_factor_2 = weights.pop(
                        name.replace('weights_scaling_factor',
                                     'weights_scaling_factor_2'))
                    weights[name] /= weights_scaling_factor_2
                    weights[name.replace(
                        'weights_scaling_factor',
                        'prequant_scaling_factor')] /= activation_scaling_factor
                    weights[name.replace(
                        'weights_scaling_factor', 'alpha'
                    )] = activation_scaling_factor * weights_scaling_factor_2

    # FP8
    elif quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not quantized to FP8.
        if "lm_head.weight" in weights:
            assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
                model_config.dtype)
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)
    elif quant_algo == QuantAlgo.FP8_PER_CHANNEL_PER_TOKEN:
        for name, param in weights.items():
            if name.endswith('weight') and param.dtype == torch.int8:
                weights[name] = param.view(torch.float8_e4m3fn)
        # lm_head is not quantized to FP8.
        if "lm_head.weight" in weights:
            assert weights['lm_head.weight'].dtype == str_dtype_to_torch(
                model_config.dtype)
            weights.pop('lm_head.weights_scaling_factor', None)
            weights.pop('lm_head.activation_scaling_factor', None)

    # Weight-only quantization (INT4 / INT8).
    elif quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        weights = weight_only_quantize_dict(weights=weights,
                                            quant_algo=quant_algo,
                                            plugin=True)

    # Force FP8 kv_cache_scaling_factor to 1.0.
    if kv_cache_quant_algo == QuantAlgo.FP8:
        for name, param in weights.items():
            if name.endswith('kv_cache_scaling_factor'):
                weights[name] = torch.tensor([1.0], dtype=torch.float32)

    # For tensor parallelism, only rank 0 keeps the dense/proj bias so it is not
    # added multiple times after the all-reduce.
    if model_config.architecture == 'GPTJForCausalLM':
        if model_config.mapping.tp_rank > 0:
            for name, param in weights.items():
                if 'attention.dense.bias' in name or 'mlp.proj.bias' in name:
                    weights[name] = torch.zeros_like(param)

    check_share_embedding(weights, model_config)


def check_share_embedding(weights: Dict[str, torch.Tensor],
                          model_config: PretrainedConfig):
    if model_config.share_embedding_table:
        if "lm_head.weight" in weights and "transformer.vocab_embedding.weight" in weights:
            if (weights["lm_head.weight"] -
                    weights["transformer.vocab_embedding.weight"]).any():
                logger.warning(
                    "lm_head.weight and transformer.vocab_embedding.weight are not identical, "
                    "share_embedding_table cannot be enabled; setting share_embedding_table=False."
                )
                model_config.share_embedding_table = False
            else:
                # The two tables are identical; drop lm_head.weight and reuse the
                # vocab embedding for the LM head.
                weights.pop("lm_head.weight")