# wangrongsheng's picture
# Add files using upload-large-folder tool
# a86f2f6 verified
# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import copy
import logging
import os
import time
import torch
from torch.distributed.distributed_c10d import _world
from transformers import AutoTokenizer
# Configure root logging: clear any handlers installed by imported libraries so
# the basicConfig format below becomes the single source of truth for output.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)
# Fix random seeds (host and all NPU devices) for reproducible sampling.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)
def get_init_attn_mask(mask_length, device, valid_len=None):
    """Build the prefill attention mask.

    Returns a [mask_length, mask_length] bool tensor where True marks
    positions that must NOT be attended (the strictly upper triangle, i.e.
    a causal mask). When ``valid_len`` is given, the last ``valid_len``
    rows are cleared entirely so those query positions may attend anywhere.

    Args:
        mask_length: side length of the square mask.
        device: device on which to allocate the mask.
        valid_len: optional number of trailing rows to un-mask.
    """
    share_mask_tril = ~torch.tril(
        torch.ones((mask_length, mask_length), dtype=torch.bool, device=device)
    )
    if valid_len is not None:
        # Fix: fill with False directly instead of assigning a freshly
        # allocated CPU float tensor (the original forced an implicit
        # host-to-device copy plus a float->bool cast on every call).
        share_mask_tril[-valid_len:, :] = False
    return share_mask_tril
def get_decode_mask(mask_length, device, position):
    """Return a [1, mask_length] float mask: ones at indices < position, zeros after."""
    index_row = torch.arange(mask_length, device=device)
    return (index_row < position).unsqueeze(0).to(torch.get_default_dtype())
def sample(input_logits: torch.Tensor, temperature=1.0, top_p=0.0, top_k=0, top_n_sigma=-1.0, **kwargs):
    """Pick next-token ids from logits of shape [batch_size, 1, vocab_size].

    Falls back to greedy argmax whenever any knob requests determinism
    (temperature <= 0, top_k == 1, top_p == 0 — the default — or
    top_n_sigma == 0). Otherwise applies, in order: temperature scaling,
    top-n-sigma truncation, top-k truncation, top-p (nucleus) truncation,
    then draws one token per batch row with ``torch.multinomial``.
    """
    is_greedy = (
        temperature <= 0.0
        or top_k == 1
        or top_p == 0.0
        or top_n_sigma == 0.0
    )
    if is_greedy:
        return torch.argmax(input_logits, dim=-1)

    neg_inf = -3.4028e+38  # ~ -float32 max, used to zero out probabilities
    scaled = input_logits / temperature

    # top_n_sigma: drop logits further than n standard deviations below the max.
    if top_n_sigma > 0.0:
        peak = scaled.max(dim=-1, keepdim=True)[0]
        spread = scaled.std(dim=-1, keepdim=True)
        cutoff = peak - top_n_sigma * spread
        scaled = torch.where(scaled < cutoff, neg_inf, scaled)

    # top_k: keep only the k largest logits per row.
    if top_k > 0:
        kth_value = torch.topk(scaled, top_k)[0][..., -1, None]
        scaled[scaled < kth_value] = neg_inf

    # top_p: ascending sort, drop the low-probability tail whose cumulative
    # mass is <= 1 - top_p; the largest logit is always kept.
    if 0.0 < top_p < 1.0:
        asc_logits, asc_idx = torch.sort(scaled, descending=False)
        cum_probs = asc_logits.softmax(dim=-1).cumsum(dim=-1)
        drop_sorted = cum_probs <= (1 - top_p)
        drop_sorted[..., -1:] = 0  # keep at least 1 token
        drop = drop_sorted.scatter(-1, asc_idx, drop_sorted)
        scaled = scaled.masked_fill(drop, neg_inf)

    return torch.multinomial(scaled.softmax(dim=-1).squeeze(1), num_samples=1)
class ModelRunner:
def __init__(self, runner_config):
self.runner_config = runner_config
self.model_name = runner_config.get("model_name", "default_model_name")
model_path = self.runner_config.get("model_path")
self.dtype = runner_config.get("model_config").get("dtype", torch.bfloat16)
self.max_position_embeddings = runner_config.get("data_config").get(
"max_position_embeddings", 131072
)
self.input_max_len = runner_config.get("data_config").get("input_max_len", 1024)
self.max_new_tokens = runner_config.get("data_config").get("max_new_tokens", 32)
self.batch_size = runner_config.get("data_config").get("batch_size", 16)
self.sampling_params = runner_config.get("sampling_config", {})
self.tokenizer = None
self.model = None
self.device = None
self.local_rank = int(os.getenv("LOCAL_RANK", "0"))
self.rank_offset = int(os.getenv("RANK_OFFSET", "0"))
self.global_rank = self.local_rank + self.rank_offset
self.world_size = int(os.getenv("WORLD_SIZE", "1"))
if self.world_size == 1:
self.model_path = model_path
else:
self.model_path = os.path.join(model_path, f"rank_{self.global_rank}")
self.res_path = os.getenv("RES_PATH", "./")
self.enable_profiler = runner_config.get("model_config").get(
"enable_profiler", 0
)
self.use_pretrained_model = True
self.execute_mode = runner_config.get("exe_mode", "dynamo")
self.tokenizer_mode = runner_config.get("model_config").get(
"tokenizer_mode", "default"
)
self.init_device()
self.start_time = None
self.end_time = None
self.with_ckpt = runner_config.get("model_config").get("with_ckpt", 1)
@staticmethod
def repeat_batch(tensor, repeat_num):
if repeat_num == 1:
return tensor
return tensor.repeat(repeat_num, *[1] * (tensor.dim() - 1))
def init_device(self):
logging.info(
"Set execution using npu index: %s, global: %s",
self.local_rank,
self.global_rank,
)
self.device = torch.device("%s:%s" % ("npu", self.local_rank))
torch.npu.set_device(self.device)
if torch.npu.is_available() and self.world_size > 1:
if _world._default_pg is None:
torch.distributed.init_process_group(
backend="hccl", world_size=self.world_size, rank=self.global_rank
)
def init_model(self, model, config=None):
if self.with_ckpt:
self.use_pretrained_model = True
config = None
else:
self.use_pretrained_model = False
from configuration_openpangu_moe import PanguUltraMoEConfig as config
logging.info(f"use_pretrained_model: {self.use_pretrained_model}")
if self.use_pretrained_model:
self.load_model(model)
else:
self.init_model_from_config(model, config=config)
self.to_device()
self.compile_model()
self.init_tokenizer()
def init_model_from_config(self, model, config):
if config is None:
raise Exception("config cannot be None")
config_file = f"{self.model_path}/config.json"
model_config = config.from_pretrained(
config_file,
torch_dtype=self.dtype,
max_position_embeddings=self.max_position_embeddings,
)
self.model = model(model_config, runner_config=self.runner_config).to(
self.dtype
)
def load_model(self, model):
logging.info("Try to load pretrained model in path: %s", self.model_path)
self.model = model.from_pretrained(
self.model_path,
low_cpu_mem_usage=True,
ignore_mismatched_sizes=True,
torch_dtype=self.dtype,
max_position_embeddings=self.max_position_embeddings,
runner_config=self.runner_config,
)
for name, params in self.model.named_parameters():
logging.info(
"Param of %s: %s, %s, %s",
self.model_name,
name,
params.size(),
params.dtype,
)
def to_device(self):
self.model.to(self.device)
logging.info("Model weights H2D finished.")
def init_tokenizer(self):
self.tokenizer = AutoTokenizer.from_pretrained(
self.model_path,
trust_remote_code=True,
local_files_only=True,
padding_side="right",
truncation_side="right",
)
if self.tokenizer.pad_token is None:
self.tokenizer.pad_token = self.tokenizer.eos_token
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
def compile_model(self):
logging.info("The final model structure is: \n %s", self.model)
if self.execute_mode == "dynamo":
logging.info("Try to compile model")
self.graph_compile()
def graph_compile(self):
import torchair as tng
import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce
from torchair.configs.compiler_config import CompilerConfig
compiler_config = CompilerConfig()
compiler_config.experimental_config.frozen_parameter = True
compiler_config.experimental_config.tiling_schedule_optimize = True
npu_backend = tng.get_npu_backend(compiler_config=compiler_config)
self.model = torch.compile(
self.model, dynamic=True, fullgraph=True, backend=npu_backend
)
def mark_inputs(self, model_inputs):
if self.execute_mode == "dynamo":
input_ids = model_inputs.get("input_ids")
kv_len = model_inputs.get("kv_len")
attention_mask = model_inputs.get("attention_mask")
position_ids = model_inputs.get("position_ids")
past_key_values = model_inputs.get("past_key_values")
# prefill with dynamic sequence length, decode with static sequence length
torch._dynamo.mark_static(kv_len)
for item in past_key_values:
for sub_item in item:
torch._dynamo.mark_static(sub_item)
torch._dynamo.mark_static(input_ids)
if attention_mask is not None:
torch._dynamo.mark_static(attention_mask)
torch._dynamo.mark_static(position_ids)
def model_input_prepare(self, input_dict):
input_ids = input_dict.get("input_ids")
attention_mask = input_dict.get("attention_mask")
past_key_values = input_dict.get("past_key_values")
is_prefill = input_dict.get("is_prefill")
kv_len = input_dict.get("kv_len")
share_mask_tril = input_dict.get("share_mask_tril")
model_inputs = self.model.prepare_inputs_for_generation(
input_ids=input_ids,
attention_mask=attention_mask,
past_key_values=past_key_values,
is_prefill=is_prefill,
kv_len=kv_len,
input_lens=input_dict.get("input_lens"),
share_mask_tril=share_mask_tril,
)
return model_inputs
def model_inference(self, model_inputs, warm_up=False):
torch.npu.synchronize()
if warm_up:
self.mark_inputs(model_inputs)
if self.start_time is None:
self.start_time = time.time()
with torch.no_grad():
logits = self.model(**model_inputs)
torch.npu.synchronize()
self.end_time = time.time()
if torch.distributed.get_rank() != 0:
logging.info(
f"{self.model_name} inference time cost {(self.end_time - self.start_time)*1000:.2f} ms"
)
self.start_time = time.time()
return logits
    def model_generate(self, prompts, warm_up=False, **kwargs):
        """Run end-to-end generation for a batch of prompts.

        Tokenizes `prompts` (padded to self.input_max_len), performs one
        prefill step followed by single-token decode steps until every
        sequence has emitted EOS or self.max_new_tokens is reached, then
        decodes and logs the generated text.

        Args:
            prompts: batch of prompt strings (or chat messages in chat mode).
            warm_up: when True the loop exits after a few decode steps
                (see get_jump_flag); used to trigger compilation only.
            **kwargs: currently discarded — see NOTE below.

        Returns:
            The result of tokenizer.batch_decode (a list of strings).
        """
        # Tokenizer entry point depends on mode: plain call or chat template.
        calling_func = {
            "default": self.tokenizer,
            "chat": self.tokenizer.apply_chat_template,
        }
        # NOTE(review): this rebinding discards the caller-supplied **kwargs —
        # confirm whether external tokenizer options were meant to pass through.
        kwargs = {
            "return_tensors": "pt",
            "truncation": True,
            "padding": "max_length",
            "max_length": self.input_max_len,
        }
        if self.tokenizer_mode == "chat":
            chat_kwargs = {"add_generation_prompt": True, "return_dict": True}
            kwargs.update(chat_kwargs)
        tokenizer = calling_func.get(self.tokenizer_mode, self.tokenizer)
        inputs = tokenizer(prompts, **kwargs).to(self.device)
        # get init input_dict
        # Full causal mask for prefill; the last input_max_len rows are cleared
        # by get_init_attn_mask (those query positions may attend everywhere).
        share_mask_tril = get_init_attn_mask(
            self.max_position_embeddings, self.device, valid_len=self.input_max_len
        )
        share_mask_tril = share_mask_tril[None, None, ...]
        input_lens = copy.deepcopy(inputs.input_ids.size()[1])
        logging.info("Padding max prompts lens is : %d", input_lens)
        # Mutable loop state; model_output_process advances it after each step.
        input_dict = {
            "input_ids": inputs.input_ids,
            "generate_ids": inputs.input_ids,
            "input_lens": input_lens,
            "kv_len": None,
            "past_key_values": None,
            "attention_mask": inputs.attention_mask,
            "share_mask_tril": share_mask_tril,
            "is_prefill": True,
        }
        if torch.distributed.get_rank() == 0:
            logging.info(
                f"inputs.input_ids {inputs.input_ids[:,:30]} eod id {self.tokenizer.eos_token_id}"
            )
        generate_tokens = 0
        cnt = 0
        # Per-sequence completion flags and the decode step at which each
        # sequence produced EOS (-1 = still generating).
        all_done = [False for _ in range(input_dict["input_ids"].size(0))]
        done_len = [-1 for _ in range(input_dict["input_ids"].size(0))]
        while True:
            jump_flag = self.get_jump_flag(cnt, warm_up, generate_tokens)
            if jump_flag:
                break
            # exit until all reach eod
            # During decode input_ids is [batch, 1]; a row equal to EOS marks
            # that sequence as finished.
            if input_dict["input_ids"].size(1) == 1:
                for bi in range(input_dict["input_ids"].size(0)):
                    if (
                        input_dict["input_ids"][bi, 0].item()
                        == self.tokenizer.eos_token_id
                    ):
                        all_done[bi] = True
                        done_len[bi] = generate_tokens
                if all(all_done):
                    break
            model_inputs = self.model_input_prepare(input_dict)
            # fix decode mask
            # For decode steps (query length == 1) rebuild the additive mask:
            # large-negative everywhere, zeros over each row's valid prefix.
            if model_inputs["position_ids"].shape[1] == 1:
                model_inputs["attention_mask"].fill_(-3.4028e38)
                for bi in range(model_inputs["position_ids"].size(0)):
                    max_l = model_inputs["position_ids"][bi].max().item()
                    model_inputs["attention_mask"][bi, :, :, : max_l + 1] = 0
            outputs = self.model_inference(model_inputs, warm_up=warm_up)
            self.model_output_process(model_inputs, outputs, input_dict)
            # prof.step()
            generate_tokens += 1
            cnt += 1
        # Strip the prompt; clip guards against out-of-vocab sampled ids.
        generate_ids = input_dict["generate_ids"][:, input_lens:].clip(
            0, self.model.config.vocab_size - 1
        )
        # Overwrite everything after each sequence's EOS with EOS so that
        # batch_decode (skip_special_tokens=True) drops the tail.
        for bi in range(generate_ids.size(0)):
            if done_len[bi] != -1:
                generate_ids[bi, done_len[bi] :] = self.tokenizer.eos_token_id
        res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True)
        if isinstance(res, list):
            for answer in res:
                logging.info("Inference decode result: \n%s", answer)
        else:
            logging.info("Inference decode result: \n%s", res)
        return res
def get_jump_flag(self, cnt, warm_up, generate_tokens):
default_decode_dump = 2
# warm up only perform for 5 times(decode)
jump_flag_warm = warm_up and cnt >= default_decode_dump
# do not generate after max_token
jump_flag_oversize = generate_tokens >= self.max_new_tokens
jump_flag = jump_flag_oversize or jump_flag_warm
return jump_flag
def model_output_process(self, model_inputs, outputs, input_dict):
next_batch = self.batch_size
attn_tp_size = self.runner_config.get("parallel_config").get("attn_tp_size", 1)
if self.world_size % attn_tp_size != 0:
raise Exception(
f"world_size ({self.world_siz}) not divisible by attn_tp_size ({attn_tp_size})!"
)
attn_dp_size = self.world_size // attn_tp_size
input_dict["is_prefill"] = False
input_dict["input_lens"] = input_dict["input_lens"] + 1
kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1
input_dict["kv_len"] = kv_len
logits = outputs
past_key_values = model_inputs.get("past_key_values")
input_dict["past_key_values"] = past_key_values
attention_mask = None
share_mask_tril = get_decode_mask(
mask_length=self.max_position_embeddings,
device=self.device,
position=input_dict["input_lens"],
)
share_mask_tril = share_mask_tril[None, None, ...]
input_dict["attention_mask"] = attention_mask
input_dict["share_mask_tril"] = ModelRunner.repeat_batch(
share_mask_tril, self.batch_size
)
next_tokens = sample(logits, **self.sampling_params)
torch.distributed.broadcast(next_tokens, src=0)
input_dict["input_ids"] = next_tokens
input_dict["generate_ids"] = torch.cat(
[input_dict["generate_ids"], next_tokens], dim=-1
)