|
|
|
|
|
|
|
|
|
|
|
import copy |
|
|
import logging |
|
|
import os |
|
|
import time |
|
|
|
|
|
import torch |
|
|
from torch.distributed.distributed_c10d import _world |
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
# Reconfigure root logging from scratch: clear any handlers installed by
# imported libraries first, because logging.basicConfig is a no-op when the
# root logger already has handlers.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)

# Fix RNG seeds (CPU and all NPU devices) for reproducible sampling.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)
|
|
|
|
|
|
|
|
def get_init_attn_mask(mask_length, device, valid_len=None):
    """Build an inverted causal attention mask of shape (mask_length, mask_length).

    Entries are True where attention must be BLOCKED (strict upper triangle),
    False where it is allowed.

    Args:
        mask_length: side length of the square mask.
        device: device the mask is created on.
        valid_len: if given, the last `valid_len` query rows are fully
            unmasked (all False), i.e. those positions may attend everywhere.

    Returns:
        A bool tensor on `device`.
    """
    share_mask_tril = ~torch.tril(
        torch.ones((mask_length, mask_length), dtype=torch.bool, device=device)
    )
    if valid_len is not None:
        # BUGFIX: the original assigned `torch.zeros(valid_len, mask_length)`,
        # a CPU float32 tensor — that raises a device mismatch when `device`
        # is not CPU and relies on an implicit float->bool cast. Assigning
        # False keeps dtype and device consistent.
        share_mask_tril[-valid_len:, :] = False
    return share_mask_tril
|
|
|
|
|
|
|
|
def get_decode_mask(mask_length, device, position):
    """Return a (1, mask_length) float mask with ones on the first `position`
    slots (valid KV entries) and zeros elsewhere."""
    slot_index = torch.arange(mask_length, device=device)
    return (slot_index < position).float().unsqueeze(0)
|
|
|
|
|
|
|
|
def sample(input_logits: torch.Tensor, temperature=1.0, top_p=0.0, top_k=0, top_n_sigma=-1.0, **kwargs):
    """Pick next-token ids from `input_logits`.

    Greedy argmax is used when temperature <= 0, top_k == 1, top_p == 0,
    or top_n_sigma == 0; otherwise logits are temperature-scaled, filtered
    by the enabled strategies (top-n-sigma, top-k, top-p, in that order),
    and sampled with torch.multinomial. Extra kwargs are ignored.
    """
    greedy = (
        temperature <= 0.0 or top_k == 1 or top_p == 0.0 or top_n_sigma == 0.0
    )
    if greedy:
        return torch.argmax(input_logits, dim=-1)

    scaled = input_logits / temperature

    # Large negative constant used to knock filtered tokens out of softmax.
    neg_inf = -3.4028e+38

    # Top-n-sigma: drop tokens more than n std-devs below the max logit.
    if top_n_sigma > 0.0:
        peak = scaled.max(dim=-1, keepdim=True)[0]
        spread = scaled.std(dim=-1, keepdim=True)
        cutoff = peak - top_n_sigma * spread
        scaled = torch.where(scaled < cutoff, neg_inf, scaled)

    # Top-k: keep only the k largest logits.
    if top_k > 0:
        kth_value = torch.topk(scaled, top_k)[0][..., -1, None]
        scaled = scaled.masked_fill(scaled < kth_value, neg_inf)

    # Top-p (nucleus), ascending-sort variant: remove the low-probability
    # tail whose cumulative mass is <= 1 - top_p, always keeping the
    # highest-probability token.
    if 0.0 < top_p < 1.0:
        asc_logits, asc_index = torch.sort(scaled, descending=False)
        cumulative = asc_logits.softmax(dim=-1).cumsum(dim=-1)
        drop_sorted = cumulative <= (1 - top_p)
        drop_sorted[..., -1:] = 0
        drop = drop_sorted.scatter(-1, asc_index, drop_sorted)
        scaled = scaled.masked_fill(drop, neg_inf)

    probs = scaled.softmax(dim=-1)
    return torch.multinomial(probs.squeeze(1), num_samples=1)
|
|
|
|
|
|
|
|
class ModelRunner:
    """Runs single- or multi-rank autoregressive generation on Ascend NPUs.

    Reads all settings from a nested `runner_config` dict (model_config,
    data_config, parallel_config, sampling_config), binds the process to its
    NPU in the constructor, and exposes init_model()/model_generate() as the
    main entry points.
    """

    def __init__(self, runner_config):
        """Cache configuration, resolve the per-rank checkpoint path, and
        bind this process to its NPU device.

        Args:
            runner_config: nested dict of runner settings; `model_config`
                and `data_config` sub-dicts are required (accessed with
                .get() chained, so a missing sub-dict raises AttributeError).
        """
        self.runner_config = runner_config
        self.model_name = runner_config.get("model_name", "default_model_name")
        model_path = self.runner_config.get("model_path")
        self.dtype = runner_config.get("model_config").get("dtype", torch.bfloat16)
        self.max_position_embeddings = runner_config.get("data_config").get(
            "max_position_embeddings", 131072
        )
        self.input_max_len = runner_config.get("data_config").get("input_max_len", 1024)
        self.max_new_tokens = runner_config.get("data_config").get("max_new_tokens", 32)
        self.batch_size = runner_config.get("data_config").get("batch_size", 16)
        self.sampling_params = runner_config.get("sampling_config", {})
        self.tokenizer = None
        self.model = None
        self.device = None
        # Rank layout: global rank = node-local rank + per-node offset, both
        # supplied through environment variables by the launcher.
        self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
        self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
        self.global_rank = self.local_rank + self.rank_offset
        self.world_size = int(os.getenv("WORLD_SIZE", "1"))
        if self.world_size == 1:
            self.model_path = model_path
        else:
            # Multi-rank checkpoints are sharded into one directory per rank.
            self.model_path = os.path.join(model_path, f"rank_{self.global_rank}")

        self.res_path = os.getenv("RES_PATH", "./")
        self.enable_profiler = runner_config.get("model_config").get(
            "enable_profiler", 0
        )
        self.use_pretrained_model = True
        self.execute_mode = runner_config.get("exe_mode", "dynamo")
        self.tokenizer_mode = runner_config.get("model_config").get(
            "tokenizer_mode", "default"
        )
        # Side effect: selects the NPU for this process and may initialize
        # the distributed process group.
        self.init_device()
        self.start_time = None
        self.end_time = None
        # Truthy (default 1) -> load pretrained weights; falsy -> build the
        # model from a config with random weights (see init_model).
        self.with_ckpt = runner_config.get("model_config").get("with_ckpt", 1)
|
|
|
|
|
@staticmethod |
|
|
def repeat_batch(tensor, repeat_num): |
|
|
if repeat_num == 1: |
|
|
return tensor |
|
|
return tensor.repeat(repeat_num, *[1] * (tensor.dim() - 1)) |
|
|
|
|
|
    def init_device(self):
        """Bind this process to its local NPU and, for multi-rank runs,
        initialize the HCCL process group exactly once."""
        logging.info(
            "Set execution using npu index: %s, global: %s",
            self.local_rank,
            self.global_rank,
        )
        self.device = torch.device("%s:%s" % ("npu", self.local_rank))
        torch.npu.set_device(self.device)
        if torch.npu.is_available() and self.world_size > 1:
            # _world._default_pg is c10d's private handle for the default
            # process group; None means no group has been initialized yet,
            # so this guard makes the call idempotent.
            if _world._default_pg is None:
                torch.distributed.init_process_group(
                    backend="hccl", world_size=self.world_size, rank=self.global_rank
                )
|
|
|
|
|
    def init_model(self, model, config=None):
        """Build or load the model, move it to device, optionally graph-compile
        it, and load the tokenizer.

        Args:
            model: model class (HF-style; must expose ``from_pretrained``).
            config: optional config class. Ignored when ``self.with_ckpt`` is
                truthy; otherwise it is unconditionally replaced by the
                project's PanguUltraMoEConfig (see below).
        """
        if self.with_ckpt:
            self.use_pretrained_model = True
            config = None
        else:
            self.use_pretrained_model = False
            # Rebinds the `config` parameter to the project config class,
            # discarding whatever the caller passed in.
            from configuration_openpangu_moe import PanguUltraMoEConfig as config
        logging.info(f"use_pretrained_model: {self.use_pretrained_model}")

        if self.use_pretrained_model:
            self.load_model(model)
        else:
            self.init_model_from_config(model, config=config)
        self.to_device()
        self.compile_model()
        self.init_tokenizer()
|
|
|
|
|
def init_model_from_config(self, model, config): |
|
|
if config is None: |
|
|
raise Exception("config cannot be None") |
|
|
config_file = f"{self.model_path}/config.json" |
|
|
model_config = config.from_pretrained( |
|
|
config_file, |
|
|
torch_dtype=self.dtype, |
|
|
max_position_embeddings=self.max_position_embeddings, |
|
|
) |
|
|
self.model = model(model_config, runner_config=self.runner_config).to( |
|
|
self.dtype |
|
|
) |
|
|
|
|
|
def load_model(self, model): |
|
|
logging.info("Try to load pretrained model in path: %s", self.model_path) |
|
|
self.model = model.from_pretrained( |
|
|
self.model_path, |
|
|
low_cpu_mem_usage=True, |
|
|
ignore_mismatched_sizes=True, |
|
|
torch_dtype=self.dtype, |
|
|
max_position_embeddings=self.max_position_embeddings, |
|
|
runner_config=self.runner_config, |
|
|
) |
|
|
for name, params in self.model.named_parameters(): |
|
|
logging.info( |
|
|
"Param of %s: %s, %s, %s", |
|
|
self.model_name, |
|
|
name, |
|
|
params.size(), |
|
|
params.dtype, |
|
|
) |
|
|
|
|
|
def to_device(self): |
|
|
self.model.to(self.device) |
|
|
logging.info("Model weights H2D finished.") |
|
|
|
|
|
def init_tokenizer(self): |
|
|
self.tokenizer = AutoTokenizer.from_pretrained( |
|
|
self.model_path, |
|
|
trust_remote_code=True, |
|
|
local_files_only=True, |
|
|
padding_side="right", |
|
|
truncation_side="right", |
|
|
) |
|
|
if self.tokenizer.pad_token is None: |
|
|
self.tokenizer.pad_token = self.tokenizer.eos_token |
|
|
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id |
|
|
|
|
|
def compile_model(self): |
|
|
logging.info("The final model structure is: \n %s", self.model) |
|
|
if self.execute_mode == "dynamo": |
|
|
logging.info("Try to compile model") |
|
|
self.graph_compile() |
|
|
|
|
|
    def graph_compile(self):
        """Graph-compile ``self.model`` with TorchAir's NPU dynamo backend."""
        import torchair as tng
        # Imported purely for its side effect: patches the GE converter for
        # HCCL all-reduce. Must run before the backend is created.
        import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
        from torchair.configs.compiler_config import CompilerConfig

        compiler_config = CompilerConfig()
        # Treat weights as frozen and enable tiling-schedule optimization for
        # inference-time graph performance.
        compiler_config.experimental_config.frozen_parameter = True
        compiler_config.experimental_config.tiling_schedule_optimize = True
        npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
        self.model = torch.compile(
            self.model, dynamic=True, fullgraph=True, backend=npu_backend
        )
|
|
|
|
|
def mark_inputs(self, model_inputs): |
|
|
if self.execute_mode == "dynamo": |
|
|
input_ids = model_inputs.get("input_ids") |
|
|
kv_len = model_inputs.get("kv_len") |
|
|
attention_mask = model_inputs.get("attention_mask") |
|
|
position_ids = model_inputs.get("position_ids") |
|
|
past_key_values = model_inputs.get("past_key_values") |
|
|
|
|
|
|
|
|
torch._dynamo.mark_static(kv_len) |
|
|
for item in past_key_values: |
|
|
for sub_item in item: |
|
|
torch._dynamo.mark_static(sub_item) |
|
|
|
|
|
torch._dynamo.mark_static(input_ids) |
|
|
if attention_mask is not None: |
|
|
torch._dynamo.mark_static(attention_mask) |
|
|
torch._dynamo.mark_static(position_ids) |
|
|
|
|
|
def model_input_prepare(self, input_dict): |
|
|
input_ids = input_dict.get("input_ids") |
|
|
attention_mask = input_dict.get("attention_mask") |
|
|
past_key_values = input_dict.get("past_key_values") |
|
|
is_prefill = input_dict.get("is_prefill") |
|
|
kv_len = input_dict.get("kv_len") |
|
|
share_mask_tril = input_dict.get("share_mask_tril") |
|
|
model_inputs = self.model.prepare_inputs_for_generation( |
|
|
input_ids=input_ids, |
|
|
attention_mask=attention_mask, |
|
|
past_key_values=past_key_values, |
|
|
is_prefill=is_prefill, |
|
|
kv_len=kv_len, |
|
|
input_lens=input_dict.get("input_lens"), |
|
|
share_mask_tril=share_mask_tril, |
|
|
) |
|
|
return model_inputs |
|
|
|
|
|
    def model_inference(self, model_inputs, warm_up=False):
        """Run one forward pass under no_grad and log per-step latency.

        Synchronizes the NPU before and after the pass so the measured wall
        time includes the device work. Returns the model output (logits).
        """
        torch.npu.synchronize()
        if warm_up:
            # During warm-up, pin input shapes so dynamo does not recompile
            # on every decode step.
            self.mark_inputs(model_inputs)
        if self.start_time is None:
            self.start_time = time.time()
        with torch.no_grad():
            logits = self.model(**model_inputs)
        torch.npu.synchronize()
        self.end_time = time.time()
        if torch.distributed.get_rank() != 0:
            # NOTE(review): latency is logged on every rank EXCEPT rank 0,
            # while other logs in this file are rank-0 only (see
            # model_generate) — confirm this inversion is intentional.
            logging.info(
                f"{self.model_name} inference time cost {(self.end_time - self.start_time)*1000:.2f} ms"
            )
        # Reset the step timer so the next call measures only its own pass.
        self.start_time = time.time()
        return logits
|
|
|
|
|
    def model_generate(self, prompts, warm_up=False, **kwargs):
        """Tokenize `prompts`, run the prefill + decode loop, and return the
        decoded generations.

        Args:
            prompts: text prompt(s) accepted by the tokenizer (or chat
                messages when tokenizer_mode == "chat").
            warm_up: when True, the loop is capped to a few decode steps and
                input shapes are marked static (see get_jump_flag /
                mark_inputs).
            **kwargs: NOTE(review): the incoming kwargs are immediately
                shadowed by the tokenizer-kwargs dict below and thus silently
                discarded — confirm whether callers rely on passing options
                through here.

        Returns:
            List of decoded strings (one per batch element), or a single
            string, as produced by tokenizer.batch_decode.
        """
        # Select the tokenization entry point: plain call vs chat template.
        calling_func = {
            "default": self.tokenizer,
            "chat": self.tokenizer.apply_chat_template,
        }
        kwargs = {
            "return_tensors": "pt",
            "truncation": True,
            "padding": "max_length",
            "max_length": self.input_max_len,
        }
        if self.tokenizer_mode == "chat":
            chat_kwargs = {"add_generation_prompt": True, "return_dict": True}
            kwargs.update(chat_kwargs)
        tokenizer = calling_func.get(self.tokenizer_mode, self.tokenizer)
        inputs = tokenizer(prompts, **kwargs).to(self.device)

        # Prefill attention mask: inverted causal mask whose last
        # input_max_len rows are fully unmasked; add broadcast dims
        # (1, 1, L, L).
        share_mask_tril = get_init_attn_mask(
            self.max_position_embeddings, self.device, valid_len=self.input_max_len
        )
        share_mask_tril = share_mask_tril[None, None, ...]

        # Padded prompt length (an int; the deepcopy is effectively a no-op
        # on an int and only decouples the name from the expression).
        input_lens = copy.deepcopy(inputs.input_ids.size()[1])
        logging.info("Padding max prompts lens is : %d", input_lens)
        # Mutable generation state; model_output_process updates this dict
        # in place after every forward pass.
        input_dict = {
            "input_ids": inputs.input_ids,
            "generate_ids": inputs.input_ids,
            "input_lens": input_lens,
            "kv_len": None,
            "past_key_values": None,
            "attention_mask": inputs.attention_mask,
            "share_mask_tril": share_mask_tril,
            "is_prefill": True,
        }

        if torch.distributed.get_rank() == 0:
            logging.info(
                f"inputs.input_ids {inputs.input_ids[:,:30]} eod id {self.tokenizer.eos_token_id}"
            )

        generate_tokens = 0
        cnt = 0
        # Per-batch-element completion flags and the step index at which
        # each element emitted EOS (-1 = still generating).
        all_done = [False for _ in range(input_dict["input_ids"].size(0))]
        done_len = [-1 for _ in range(input_dict["input_ids"].size(0))]
        while True:
            jump_flag = self.get_jump_flag(cnt, warm_up, generate_tokens)
            if jump_flag:
                break

            # Decode steps feed a single token per element; stop early once
            # every element has produced EOS.
            if input_dict["input_ids"].size(1) == 1:
                for bi in range(input_dict["input_ids"].size(0)):
                    if (
                        input_dict["input_ids"][bi, 0].item()
                        == self.tokenizer.eos_token_id
                    ):
                        all_done[bi] = True
                        done_len[bi] = generate_tokens
                if all(all_done):
                    break
            model_inputs = self.model_input_prepare(input_dict)

            # Decode step (single query position): rebuild the dense mask so
            # each element attends to exactly its own kv history.
            if model_inputs["position_ids"].shape[1] == 1:
                model_inputs["attention_mask"].fill_(-3.4028e38)
                for bi in range(model_inputs["position_ids"].size(0)):
                    max_l = model_inputs["position_ids"][bi].max().item()
                    model_inputs["attention_mask"][bi, :, :, : max_l + 1] = 0
            outputs = self.model_inference(model_inputs, warm_up=warm_up)
            self.model_output_process(model_inputs, outputs, input_dict)

            generate_tokens += 1
            cnt += 1

        # Strip the prompt, clamp ids into vocab range, and overwrite
        # everything after each element's EOS with EOS.
        generate_ids = input_dict["generate_ids"][:, input_lens:].clip(
            0, self.model.config.vocab_size - 1
        )
        for bi in range(generate_ids.size(0)):
            if done_len[bi] != -1:
                generate_ids[bi, done_len[bi] :] = self.tokenizer.eos_token_id
        res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)

        if isinstance(res, list):
            for answer in res:
                logging.info("Inference decode result: \n%s", answer)
        else:
            logging.info("Inference decode result: \n%s", res)
        return res
|
|
|
|
|
def get_jump_flag(self, cnt, warm_up, generate_tokens): |
|
|
default_decode_dump = 2 |
|
|
|
|
|
jump_flag_warm = warm_up and cnt >= default_decode_dump |
|
|
|
|
|
jump_flag_oversize = generate_tokens >= self.max_new_tokens |
|
|
jump_flag = jump_flag_oversize or jump_flag_warm |
|
|
return jump_flag |
|
|
|
|
|
def model_output_process(self, model_inputs, outputs, input_dict): |
|
|
next_batch = self.batch_size |
|
|
attn_tp_size = self.runner_config.get("parallel_config").get("attn_tp_size", 1) |
|
|
if self.world_size % attn_tp_size != 0: |
|
|
raise Exception( |
|
|
f"world_size ({self.world_siz}) not divisible by attn_tp_size ({attn_tp_size})!" |
|
|
) |
|
|
attn_dp_size = self.world_size // attn_tp_size |
|
|
input_dict["is_prefill"] = False |
|
|
input_dict["input_lens"] = input_dict["input_lens"] + 1 |
|
|
|
|
|
kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1 |
|
|
input_dict["kv_len"] = kv_len |
|
|
|
|
|
logits = outputs |
|
|
past_key_values = model_inputs.get("past_key_values") |
|
|
input_dict["past_key_values"] = past_key_values |
|
|
|
|
|
attention_mask = None |
|
|
|
|
|
share_mask_tril = get_decode_mask( |
|
|
mask_length=self.max_position_embeddings, |
|
|
device=self.device, |
|
|
position=input_dict["input_lens"], |
|
|
) |
|
|
share_mask_tril = share_mask_tril[None, None, ...] |
|
|
|
|
|
input_dict["attention_mask"] = attention_mask |
|
|
input_dict["share_mask_tril"] = ModelRunner.repeat_batch( |
|
|
share_mask_tril, self.batch_size |
|
|
) |
|
|
|
|
|
next_tokens = sample(logits, **self.sampling_params) |
|
|
torch.distributed.broadcast(next_tokens, src=0) |
|
|
input_dict["input_ids"] = next_tokens |
|
|
input_dict["generate_ids"] = torch.cat( |
|
|
[input_dict["generate_ids"], next_tokens], dim=-1 |
|
|
) |
|
|
|