# coding=utf-8 # Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved. import copy import logging import os import time import torch from torch.distributed.distributed_c10d import _world from transformers import AutoTokenizer root_logger = logging.getLogger() root_logger.handlers.clear() logging.basicConfig( format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s", level=logging.INFO, ) torch.manual_seed(42) torch.npu.manual_seed_all(42) def get_init_attn_mask(mask_length, device, valid_len=None): share_mask_tril = ~torch.tril( torch.ones((mask_length, mask_length), dtype=torch.bool, device=device) ) if valid_len is not None: share_mask_tril[-valid_len:, :] = torch.zeros(valid_len, mask_length) return share_mask_tril def get_decode_mask(mask_length, device, position): decode_mask = torch.zeros((1, mask_length), device=device) decode_mask[0, :position] = 1 return decode_mask def sample(input_logits: torch.Tensor, temperature=1.0, top_p=0.0, top_k=0, top_n_sigma=-1.0, **kwargs): # shape of input_logits: [batch_size, 1, vocab_size] # greedy if temperature <= 0.0 or top_k == 1 or top_p == 0.0 or top_n_sigma == 0.0: return torch.argmax(input_logits, dim=-1) logits = input_logits / temperature filter_value = -3.4028e+38 # top_n_sigma truncation if top_n_sigma > 0.0: max_vals, _ = logits.max(dim=-1, keepdim=True) std_vals = logits.std(dim=-1, keepdim=True) threshold = max_vals - top_n_sigma * std_vals mask = logits < threshold logits = torch.where(mask, filter_value, logits) # top_k truncation if top_k > 0: indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None] logits[indices_to_remove] = filter_value # top_p truncation if 0.0 < top_p < 1.0: sorted_logits, sorted_indices = torch.sort(logits, descending=False) cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1) sorted_indices_to_remove = cumulative_probs <= (1 - top_p) # keep at least 1 token sorted_indices_to_remove[..., -1:] = 0 indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove) logits = logits.masked_fill(indices_to_remove, filter_value) probs = logits.softmax(dim=-1) outputs = torch.multinomial(probs.squeeze(1), num_samples=1) return outputs class ModelRunner: def __init__(self, runner_config): self.runner_config = runner_config self.model_name = runner_config.get("model_name", "default_model_name") model_path = self.runner_config.get("model_path") self.dtype = runner_config.get("model_config").get("dtype", torch.bfloat16) self.max_position_embeddings = runner_config.get("data_config").get( "max_position_embeddings", 131072 ) self.input_max_len = runner_config.get("data_config").get("input_max_len", 1024) self.max_new_tokens = runner_config.get("data_config").get("max_new_tokens", 32) self.batch_size = runner_config.get("data_config").get("batch_size", 16) self.sampling_params = runner_config.get("sampling_config", {}) self.tokenizer = None self.model = None self.device = None self.local_rank = int(os.getenv("LOCAL_RANK", "0")) self.rank_offset = int(os.getenv("RANK_OFFSET", "0")) self.global_rank = self.local_rank + self.rank_offset self.world_size = int(os.getenv("WORLD_SIZE", "1")) if self.world_size == 1: self.model_path = model_path else: self.model_path = os.path.join(model_path, f"rank_{self.global_rank}") self.res_path = os.getenv("RES_PATH", "./") self.enable_profiler = runner_config.get("model_config").get( "enable_profiler", 0 ) self.use_pretrained_model = True self.execute_mode = runner_config.get("exe_mode", "dynamo") self.tokenizer_mode = runner_config.get("model_config").get( "tokenizer_mode", "default" ) self.init_device() self.start_time = None self.end_time = None self.with_ckpt = runner_config.get("model_config").get("with_ckpt", 1) @staticmethod def repeat_batch(tensor, repeat_num): if repeat_num == 1: return tensor return tensor.repeat(repeat_num, *[1] * (tensor.dim() - 1)) def init_device(self): logging.info( "Set execution using npu index: %s, global: %s", self.local_rank, self.global_rank, ) self.device = torch.device("%s:%s" % ("npu", self.local_rank)) torch.npu.set_device(self.device) if torch.npu.is_available() and self.world_size > 1: if _world._default_pg is None: torch.distributed.init_process_group( backend="hccl", world_size=self.world_size, rank=self.global_rank ) def init_model(self, model, config=None): if self.with_ckpt: self.use_pretrained_model = True config = None else: self.use_pretrained_model = False from configuration_openpangu_moe import PanguUltraMoEConfig as config logging.info(f"use_pretrained_model: {self.use_pretrained_model}") if self.use_pretrained_model: self.load_model(model) else: self.init_model_from_config(model, config=config) self.to_device() self.compile_model() self.init_tokenizer() def init_model_from_config(self, model, config): if config is None: raise Exception("config cannot be None") config_file = f"{self.model_path}/config.json" model_config = config.from_pretrained( config_file, torch_dtype=self.dtype, max_position_embeddings=self.max_position_embeddings, ) self.model = model(model_config, runner_config=self.runner_config).to( self.dtype ) def load_model(self, model): logging.info("Try to load pretrained model in path: %s", self.model_path) self.model = model.from_pretrained( self.model_path, low_cpu_mem_usage=True, ignore_mismatched_sizes=True, torch_dtype=self.dtype, max_position_embeddings=self.max_position_embeddings, runner_config=self.runner_config, ) for name, params in self.model.named_parameters(): logging.info( "Param of %s: %s, %s, %s", self.model_name, name, params.size(), params.dtype, ) def to_device(self): self.model.to(self.device) logging.info("Model weights H2D finished.") def init_tokenizer(self): self.tokenizer = AutoTokenizer.from_pretrained( self.model_path, trust_remote_code=True, local_files_only=True, padding_side="right", truncation_side="right", ) if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.tokenizer.pad_token_id = self.tokenizer.eos_token_id def compile_model(self): logging.info("The final model structure is: \n %s", self.model) if self.execute_mode == "dynamo": logging.info("Try to compile model") self.graph_compile() def graph_compile(self): import torchair as tng import torchair.ge_concrete_graph.ge_converter.experimental.patch_for_hcom_allreduce from torchair.configs.compiler_config import CompilerConfig compiler_config = CompilerConfig() compiler_config.experimental_config.frozen_parameter = True compiler_config.experimental_config.tiling_schedule_optimize = True npu_backend = tng.get_npu_backend(compiler_config=compiler_config) self.model = torch.compile( self.model, dynamic=True, fullgraph=True, backend=npu_backend ) def mark_inputs(self, model_inputs): if self.execute_mode == "dynamo": input_ids = model_inputs.get("input_ids") kv_len = model_inputs.get("kv_len") attention_mask = model_inputs.get("attention_mask") position_ids = model_inputs.get("position_ids") past_key_values = model_inputs.get("past_key_values") # prefill with dynamic sequence length, decode with static sequence length torch._dynamo.mark_static(kv_len) for item in past_key_values: for sub_item in item: torch._dynamo.mark_static(sub_item) torch._dynamo.mark_static(input_ids) if attention_mask is not None: torch._dynamo.mark_static(attention_mask) torch._dynamo.mark_static(position_ids) def model_input_prepare(self, input_dict): input_ids = input_dict.get("input_ids") attention_mask = input_dict.get("attention_mask") past_key_values = input_dict.get("past_key_values") is_prefill = input_dict.get("is_prefill") kv_len = input_dict.get("kv_len") share_mask_tril = input_dict.get("share_mask_tril") model_inputs = self.model.prepare_inputs_for_generation( input_ids=input_ids, attention_mask=attention_mask, past_key_values=past_key_values, is_prefill=is_prefill, kv_len=kv_len, input_lens=input_dict.get("input_lens"), share_mask_tril=share_mask_tril, ) return model_inputs def model_inference(self, model_inputs, warm_up=False): torch.npu.synchronize() if warm_up: self.mark_inputs(model_inputs) if self.start_time is None: self.start_time = time.time() with torch.no_grad(): logits = self.model(**model_inputs) torch.npu.synchronize() self.end_time = time.time() if torch.distributed.get_rank() != 0: logging.info( f"{self.model_name} inference time cost {(self.end_time - self.start_time)*1000:.2f} ms" ) self.start_time = time.time() return logits def model_generate(self, prompts, warm_up=False, **kwargs): calling_func = { "default": self.tokenizer, "chat": self.tokenizer.apply_chat_template, } kwargs = { "return_tensors": "pt", "truncation": True, "padding": "max_length", "max_length": self.input_max_len, } if self.tokenizer_mode == "chat": chat_kwargs = {"add_generation_prompt": True, "return_dict": True} kwargs.update(chat_kwargs) tokenizer = calling_func.get(self.tokenizer_mode, self.tokenizer) inputs = tokenizer(prompts, **kwargs).to(self.device) # get init input_dict share_mask_tril = get_init_attn_mask( self.max_position_embeddings, self.device, valid_len=self.input_max_len ) share_mask_tril = share_mask_tril[None, None, ...] input_lens = copy.deepcopy(inputs.input_ids.size()[1]) logging.info("Padding max prompts lens is : %d", input_lens) input_dict = { "input_ids": inputs.input_ids, "generate_ids": inputs.input_ids, "input_lens": input_lens, "kv_len": None, "past_key_values": None, "attention_mask": inputs.attention_mask, "share_mask_tril": share_mask_tril, "is_prefill": True, } if torch.distributed.get_rank() == 0: logging.info( f"inputs.input_ids {inputs.input_ids[:,:30]} eod id {self.tokenizer.eos_token_id}" ) generate_tokens = 0 cnt = 0 all_done = [False for _ in range(input_dict["input_ids"].size(0))] done_len = [-1 for _ in range(input_dict["input_ids"].size(0))] while True: jump_flag = self.get_jump_flag(cnt, warm_up, generate_tokens) if jump_flag: break # exit until all reach eod if input_dict["input_ids"].size(1) == 1: for bi in range(input_dict["input_ids"].size(0)): if ( input_dict["input_ids"][bi, 0].item() == self.tokenizer.eos_token_id ): all_done[bi] = True done_len[bi] = generate_tokens if all(all_done): break model_inputs = self.model_input_prepare(input_dict) # fix decode mask if model_inputs["position_ids"].shape[1] == 1: model_inputs["attention_mask"].fill_(-3.4028e38) for bi in range(model_inputs["position_ids"].size(0)): max_l = model_inputs["position_ids"][bi].max().item() model_inputs["attention_mask"][bi, :, :, : max_l + 1] = 0 outputs = self.model_inference(model_inputs, warm_up=warm_up) self.model_output_process(model_inputs, outputs, input_dict) # prof.step() generate_tokens += 1 cnt += 1 generate_ids = input_dict["generate_ids"][:, input_lens:].clip( 0, self.model.config.vocab_size - 1 ) for bi in range(generate_ids.size(0)): if done_len[bi] != -1: generate_ids[bi, done_len[bi] :] = self.tokenizer.eos_token_id res = self.tokenizer.batch_decode(generate_ids, skip_special_tokens=True) if isinstance(res, list): for answer in res: logging.info("Inference decode result: \n%s", answer) else: logging.info("Inference decode result: \n%s", res) return res def get_jump_flag(self, cnt, warm_up, generate_tokens): default_decode_dump = 2 # warm up only perform for 5 times(decode) jump_flag_warm = warm_up and cnt >= default_decode_dump # do not generate after max_token jump_flag_oversize = generate_tokens >= self.max_new_tokens jump_flag = jump_flag_oversize or jump_flag_warm return jump_flag def model_output_process(self, model_inputs, outputs, input_dict): next_batch = self.batch_size attn_tp_size = self.runner_config.get("parallel_config").get("attn_tp_size", 1) if self.world_size % attn_tp_size != 0: raise Exception( f"world_size ({self.world_siz}) not divisible by attn_tp_size ({attn_tp_size})!" ) attn_dp_size = self.world_size // attn_tp_size input_dict["is_prefill"] = False input_dict["input_lens"] = input_dict["input_lens"] + 1 kv_len = torch.max(model_inputs.get("position_ids"), axis=1)[0] + 1 input_dict["kv_len"] = kv_len logits = outputs past_key_values = model_inputs.get("past_key_values") input_dict["past_key_values"] = past_key_values attention_mask = None share_mask_tril = get_decode_mask( mask_length=self.max_position_embeddings, device=self.device, position=input_dict["input_lens"], ) share_mask_tril = share_mask_tril[None, None, ...] input_dict["attention_mask"] = attention_mask input_dict["share_mask_tril"] = ModelRunner.repeat_batch( share_mask_tril, self.batch_size ) next_tokens = sample(logits, **self.sampling_params) torch.distributed.broadcast(next_tokens, src=0) input_dict["input_ids"] = next_tokens input_dict["generate_ids"] = torch.cat( [input_dict["generate_ids"], next_tokens], dim=-1 )