|
|
import re |
|
|
from pathlib import Path |
|
|
from typing import Dict, Optional, Union |
|
|
|
|
|
import torch |
|
|
from datasets import load_dataset |
|
|
|
|
|
from ..quantization import QuantAlgo |
|
|
|
|
|
|
|
|
def split(v, tp_size, idx, dim=0):
    """Return the ``idx``-th of ``tp_size`` equal chunks of tensor ``v``.

    1-D tensors are always chunked along their only axis and returned
    contiguous; higher-rank tensors are chunked along ``dim`` and cloned.
    When ``tp_size`` is 1 the input is returned unmodified (no copy).
    """
    if tp_size == 1:
        return v
    if v.dim() == 1:
        # ``dim`` is ignored for vectors; chunk along axis 0.
        return torch.chunk(v, tp_size)[idx].contiguous()
    return torch.chunk(v, tp_size, dim=dim)[idx].clone()
|
|
|
|
|
|
|
|
def split_qkv_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Shard a fused QKV weight matrix for tensor parallelism.

    The weight is viewed as three stacked (Q, K, V) matrices of shape
    ``(n_hidden, n_hidden)``; each is split along its output dimension so
    every rank keeps ``n_hidden // tensor_parallel`` rows of Q, K and V.
    Returns the rank-local shard of shape
    ``(3 * n_hidden // tensor_parallel, n_hidden)``.
    """
    qkv = v.reshape(3, n_hidden, n_hidden)
    if tensor_parallel > 1:
        # Chunk the per-matrix output dimension (dim=1) and keep this rank's slice.
        qkv = torch.chunk(qkv, tensor_parallel, dim=1)[rank]
    shard = qkv.reshape(3 * (n_hidden // tensor_parallel), n_hidden)
    return shard.clone()
|
|
|
|
|
|
|
|
def split_qkv_bias_tp(v, n_head, n_hidden, tensor_parallel, rank):
    """Shard a fused QKV bias vector for tensor parallelism.

    The bias is viewed as three stacked (Q, K, V) vectors of length
    ``n_hidden``; each is split so every rank keeps
    ``n_hidden // tensor_parallel`` entries of Q, K and V.
    Returns the rank-local shard of length
    ``3 * n_hidden // tensor_parallel``.
    """
    qkv = v.reshape(3, n_hidden)
    if tensor_parallel > 1:
        # Chunk each of Q/K/V along its hidden dimension.
        qkv = torch.chunk(qkv, tensor_parallel, dim=1)[rank]
    shard = qkv.reshape(3 * (n_hidden // tensor_parallel))
    return shard.clone()
|
|
|
|
|
|
|
|
def split_matrix_tp(v, tensor_parallel, rank, dim):
    """Return this rank's shard of ``v``, chunked ``tensor_parallel``-ways along ``dim``.

    No-op when ``tensor_parallel`` is 1. 1-D tensors are chunked along
    their only axis (``dim`` ignored) and returned contiguous; matrices
    are chunked along ``dim`` and cloned.
    """
    if tensor_parallel == 1:
        return v
    if v.dim() == 1:
        return torch.chunk(v, tensor_parallel)[rank].contiguous()
    return torch.chunk(v, tensor_parallel, dim=dim)[rank].clone()
|
|
|
|
|
|
|
|
def weight_only_quantize(weight: torch.Tensor,
                         quant_algo: str,
                         plugin: bool = True):
    """Symmetrically quantize a weight tensor for weight-only inference.

    Args:
        weight: weight tensor; 2-D, or batched (>2-D).
        quant_algo: one of ``QuantAlgo.W4A16`` / ``QuantAlgo.W8A16``.
        plugin: whether the weight-only GEMM plugin is used. W4A16
            requires the plugin.

    Returns:
        ``(quantized_weight, per_channel_scales)`` when ``plugin`` is True,
        otherwise ``(transposed_weight, per_channel_scales)``.
    """
    assert quant_algo in [QuantAlgo.W4A16, QuantAlgo.W8A16
                          ], f'unsupported quant algo: {quant_algo}'
    if quant_algo == QuantAlgo.W4A16:
        assert plugin, 'W4A16 is only supported with plugin'
    # The quantization op works on the last axis, so swap the last two dims.
    transposed = weight.transpose(-1, -2) if weight.dim() > 2 else weight.t()
    quant_dtype = torch.quint4x2 if quant_algo == QuantAlgo.W4A16 else torch.int8
    processed_weights, weight_scales = \
        torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
            transposed.contiguous(), quant_dtype)
    if plugin:
        return processed_weights, weight_scales
    # Non-plugin path keeps the (transposed) weight alongside the scales.
    return transposed, weight_scales
|
|
|
|
|
|
|
|
def weight_only_quantize_dict(weights: Dict[str, torch.Tensor],
                              quant_algo: str,
                              quant_weights=(
                                  'qkv.weight', 'dense.weight', 'fc.weight',
                                  'proj.weight', 'gate.weight'
                              ),
                              exclude_weights=('shared_expert_gate.weight', ),
                              plugin: bool = True):
    """Apply weight-only quantization to matching entries of a state dict.

    For every weight whose name contains one of ``quant_weights`` (and none
    of ``exclude_weights``) and that is not already int8-quantized, replaces
    the tensor with its quantized form and stores the per-channel scales
    under the matching ``.per_channel_scale`` key. Mutates ``weights`` in
    place and also returns it.

    Args:
        weights: checkpoint state dict, mutated in place.
        quant_algo: quantization algorithm; anything other than
            ``QuantAlgo.W4A16`` / ``QuantAlgo.W8A16`` is a no-op.
        quant_weights: name substrings selecting weights to quantize.
        exclude_weights: name substrings that override ``quant_weights``.
        plugin: forwarded to ``weight_only_quantize``.

    Returns:
        The (mutated) ``weights`` dict.
    """
    if quant_algo not in [QuantAlgo.W4A16, QuantAlgo.W8A16]:
        return weights
    # Iterate over a snapshot of the keys: the loop inserts new scale entries.
    for name in list(weights):
        if any(_name in name for _name in exclude_weights):
            continue
        # Skip tensors already stored as int8 (pre-quantized checkpoints).
        if any(_name in name for _name in quant_weights) \
                and weights[name].dtype != torch.int8:
            quant_weight, quant_scale = weight_only_quantize(
                weight=weights[name], quant_algo=quant_algo, plugin=plugin)
            weights[name] = quant_weight
            weights[name.replace('.weight', '.per_channel_scale')] = quant_scale
    return weights
|
|
|
|
|
|
|
|
def load_state_dict(
    file_path: Union[str, Path],
    dtype: Optional[torch.dtype] = None,
    device: Optional[Union[str, torch.device]] = None,
) -> Dict[str, torch.Tensor]:
    """Load a model state dict from a ``.safetensors`` or ``.bin`` file.

    Args:
        file_path: path to the weights file; must end in ``.bin`` or
            ``.safetensors``.
        dtype: optional target dtype every tensor is cast to.
        device: optional torch device; defaults to CPU.

    Returns:
        Mapping from parameter name to tensor.

    Raises:
        NotImplementedError: for any other file suffix.
    """
    file_path = Path(file_path)
    if dtype is not None:
        assert isinstance(dtype, torch.dtype)
    if device is None:
        device = 'cpu'

    def _cast(tensor):
        # Apply the optional dtype conversion uniformly for both formats.
        return tensor.to(dtype) if dtype is not None else tensor

    if file_path.suffix == '.safetensors':
        # Lazy import: safetensors is only needed for this format.
        from safetensors import safe_open
        with safe_open(file_path, framework='pt', device=device) as f:
            return {name: _cast(f.get_tensor(name)) for name in f.keys()}

    if file_path.suffix == '.bin':
        state_dict = torch.load(file_path, map_location=device)
        return {name: _cast(tensor) for name, tensor in state_dict.items()}

    raise NotImplementedError(
        f'Support .safetensors or .bin files, but got {str(file_path)}')
|
|
|
|
|
|
|
|
def get_model_path(
    model_dir: Union[str, Path],
    name: Optional[str] = None,
) -> Optional[str]:
    """Locate a model weights file inside ``model_dir``.

    ``.safetensors`` is preferred over ``.bin``. With ``name`` given, looks
    for ``<name>.safetensors`` / ``<name>.bin`` exactly; otherwise globs the
    directory and requires at most one file per format.

    Args:
        model_dir: directory to search.
        name: optional file stem (no suffix).

    Returns:
        Full path as a string, or ``None`` when nothing matches.
    """
    model_dir = Path(model_dir)

    if name is not None:
        for suffix in ('.safetensors', '.bin'):
            candidate = model_dir / f"{name}{suffix}"
            if candidate.exists():
                return str(candidate)
        return None

    for pattern, kind in (('*.safetensors', 'safetensors'), ('*.bin', 'bin')):
        matches = list(model_dir.glob(pattern))
        if matches:
            # Ambiguity is an error: the caller must pass `name` to pick one.
            assert len(
                matches
            ) == 1, f"find multiple {kind} files in {model_dir}, please specify one"
            return str(matches[0])
    return None
|
|
|
|
|
|
|
|
def retrieved_layer_index_from_name(name: str) -> Optional[int]:
    """Extract the first run of digits in ``name`` as the layer index.

    Returns ``None`` when the name contains no digits.
    """
    match = re.search(r'\d+', name)
    if match is None:
        return None
    return int(match.group())
|
|
|
|
|
|
|
|
def iterate_shard_files(model_dir: Union[Path, str],
                        rank: int,
                        progress_bar: bool = True):
    """Yield the weight shard files found in ``model_dir``.

    Prefers ``.safetensors`` shards, falling back to ``.bin``. Optionally
    wraps the iteration in a tqdm progress bar (skipped silently when tqdm
    is not installed).

    Raises:
        RuntimeError: when the directory contains neither format.
    """
    model_dir = Path(model_dir)

    shard_files = list(model_dir.glob('*.safetensors'))
    if not shard_files:
        shard_files = list(model_dir.glob('*.bin'))
    if not shard_files:
        raise RuntimeError(
            f"Could not find any .safetensors or .bin files in {model_dir}")

    try:
        import tqdm
    except ImportError:
        # tqdm is optional; fall back to a plain iterator.
        pass
    else:
        if progress_bar:
            shard_files = tqdm.tqdm(shard_files,
                                    desc=f'Rank [{rank}] Loading weights',
                                    position=rank)

    yield from shard_files
|
|
|
|
|
|
|
|
def has_safetensors(model_dir: str):
    """Return True when ``model_dir`` contains at least one .safetensors file."""
    matches = list(Path(model_dir).glob('*.safetensors'))
    return bool(matches)
|
|
|
|
|
|
|
|
# Per-dataset defaults used by `load_calib_dataset` to fill in unspecified
# arguments. Each value is a (config_name, split, text_key) tuple; a None
# config_name means the dataset's default configuration is used.
DEFAULT_HF_DATASET_META = {
    'ccdv/cnn_dailymail': ('3.0.0', 'train', 'article'),
    'cnn_dailymail': ('3.0.0', 'train', 'article'),
    'lambada': (None, 'validation', 'text'),
}
|
|
|
|
|
|
|
|
def load_calib_dataset(dataset_name_or_dir: str,
                       config_name: Optional[str] = None,
                       split: Optional[str] = None,
                       key: Optional[str] = None,
                       trust_remote_code=True,
                       **kwargs):
    """Load the text column of a calibration dataset via HF ``load_dataset``.

    When ``config_name`` is None and the dataset name matches an entry of
    ``DEFAULT_HF_DATASET_META``, any of ``config_name`` / ``split`` / ``key``
    left as None is filled from that entry.

    NOTE(review): ``trust_remote_code`` is accepted but never forwarded to
    ``load_dataset`` — confirm whether this is intentional.
    """
    if config_name is None:
        for known_name, (meta_config, meta_split, meta_key) in \
                DEFAULT_HF_DATASET_META.items():
            if known_name not in dataset_name_or_dir:
                continue
            # config_name is known to be None here, so assign directly.
            config_name = meta_config
            if split is None:
                split = meta_split
            if key is None:
                key = meta_key
            break

    dataset = load_dataset(dataset_name_or_dir,
                           name=config_name,
                           split=split,
                           **kwargs)
    return dataset[key]
|
|
|