|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
|
|
|
def gen_suffix(rank, use_smooth_quant, quant_per_channel):
    """Build the file-name suffix for a given rank.

    The base suffix is ``"<rank>.bin"``. When ``use_smooth_quant`` is truthy
    it is prefixed with ``"int8."``, or with ``"int8.col."`` when
    ``quant_per_channel`` is also truthy. ``quant_per_channel`` has no effect
    unless ``use_smooth_quant`` is set.

    Args:
        rank: Value interpolated into the suffix.
        use_smooth_quant: Whether to add the ``int8.`` prefix.
        quant_per_channel: Whether to additionally add ``col.`` to the prefix.

    Returns:
        The suffix string.
    """
    base = f"{rank}.bin"
    if not use_smooth_quant:
        return base
    tag = "int8.col." if quant_per_channel else "int8."
    return tag + base
|
|
|
|
|
|
|
|
def extract_layer_idx(name):
    """Return the first all-digit component of a dot-separated name.

    Args:
        name: A string whose components are separated by ``'.'``.

    Returns:
        The first component for which ``str.isdigit()`` is true, as a string
        (it is not converted to ``int``), or ``None`` when no component
        qualifies.
    """
    numeric_parts = (part for part in name.split('.') if part.isdigit())
    return next(numeric_parts, None)
|
|
|
|
|
|
|
|
def split(v: Union[np.ndarray, torch.Tensor],
          tp_size: int,
          tp_rank: int,
          dim=0):
    """Return the ``tp_rank``-th of ``tp_size`` equal slices of ``v``.

    Args:
        v: Array or tensor to slice. Anything that is not an ``np.ndarray``
            is handled through the tensor ``.split`` API.
        tp_size: Number of equal parts to cut ``v`` into along ``dim``.
        tp_rank: Index of the part to return.
        dim: Axis along which to cut. Must be 0 for 1-D inputs.

    Returns:
        ``v`` itself (not a copy) when ``tp_size == 1``; otherwise an
        independent copy of the selected slice (a contiguous array for NumPy
        input, a cloned and detached tensor otherwise).

    Raises:
        AssertionError: If ``v`` is 1-D and ``dim != 0``, or (tensor path)
            if ``v.shape[dim]`` is not divisible by ``tp_size``.
    """
    if tp_size == 1:
        # Nothing to shard: hand back the input object unchanged.
        return v
    assert len(v.shape) > 1 or dim == 0
    if isinstance(v, np.ndarray):
        # np.split with an integer section count performs its own
        # even-division check.
        return np.ascontiguousarray(
            np.split(v, tp_size, axis=dim)[tp_rank].copy())
    else:
        # The message must be an f-string; otherwise the placeholders are
        # printed literally instead of the offending values.
        assert v.shape[dim] % tp_size == 0, \
            f'Unable to split: shape={v.shape} (dim={dim}) tp_size={tp_size}.'
        split_size = v.shape[dim] // tp_size
        return v.split(split_size, dim=dim)[tp_rank].clone().detach()
|
|
|
|
|
|
|
|
def dup_kv_weight(v, num_head, tp_size):
    """Repeat each head's rows so there are ``tp_size`` head blocks in total.

    ``v`` is viewed as ``num_head`` consecutive blocks of
    ``v.shape[0] // num_head`` rows. Each block is repeated
    ``tp_size // num_head`` times back-to-back, keeping the original head
    order.

    Args:
        v: Tensor indexed as ``v.shape[0]`` / ``v.shape[1]``; presumably a
            2-D weight matrix -- confirm against callers.
        num_head: Number of head blocks along dim 0 of ``v``.
        tp_size: Target number of head blocks; must be a multiple of
            ``num_head``.

    Returns:
        A cloned, detached tensor with
        ``num_head * (tp_size // num_head) * head_size`` rows.

    Raises:
        AssertionError: If ``tp_size`` is not a multiple of ``num_head``.
    """
    assert tp_size % num_head == 0
    copies = tp_size // num_head
    rows_per_head = v.shape[0] // num_head
    # (num_head, rows_per_head, cols)
    per_head = v.reshape(num_head, rows_per_head, -1)
    # Insert a singleton "copy" axis and broadcast it to `copies` entries.
    tiled = per_head.unsqueeze(1).expand(num_head, copies, rows_per_head,
                                         v.shape[1])
    flat_rows = num_head * copies * rows_per_head
    return tiled.reshape(flat_rows, -1).clone().detach()
|
|
|