# NOTE(review): the three lines below are web-upload residue (uploader name,
# upload-tool message, commit hash) accidentally captured into the source file;
# kept as comments so the file remains valid Python.
# wangrongsheng's picture
# Add files using upload-large-folder tool
# a86f2f6 verified
# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.
import argparse
import logging
import torch
import yaml
from model import PanguUltraMoEForCausalLM
from runner import ModelRunner
# Drop any handlers a library may have installed on the root logger so that
# basicConfig below actually takes effect (it is a no-op when handlers exist).
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)
# Fix RNG seeds (CPU and all NPU devices) for reproducible generation.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)
"""
NOTE:
For enhancing model safety, we recommend the following system prompt.
It is suggested to be removed for other normal use cases and model evaluation.
"""
# Safety system prompt (Chinese). It instructs the model to comply with laws
# and ethical norms and to refuse/warn on violent, pornographic, terrorist or
# discriminatory content. Runtime string — must not be altered or translated.
safe_word = "你必须严格遵守法律法规和社会道德规范。" \
"生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
"一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
"应返回错误信息:“您的输入包含不当内容,无法处理。”"
# basic token generator
def generate_default_prompt(bs):
    """Build ``bs`` identical raw-text prompts using the slow-thinking template.

    The prompt batch size defines the actual model forward batch size.
    Reads the module globals ``safe_word`` (safety system prompt) and
    ``args.prompt`` (user prompt).

    Args:
        bs: Number of prompts (forward batch size) to produce.

    Returns:
        list[str]: ``bs`` copies of the formatted prompt string.
    """
    # Fast-thinking variant (user turn suffixed with " /no_think"); it was dead
    # code in the original — kept as a reference for switching modes by hand.
    # fast_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{} /no_think[unused10][unused9]助手:" % (safe_word,)
    slow_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{}[unused10][unused9]助手:" % (safe_word,)
    # Single source prompt replicated bs times (the old tile-then-truncate
    # arithmetic reduced to exactly this for a one-element list).
    preset_prompts = [slow_thinking_template.format(args.prompt)] * bs
    logging.info(f"prompt batch size: {bs}")
    return preset_prompts
def generate_chat_prompt(bs):
    """Build ``bs`` identical chat-format conversations.

    Each conversation is a two-message list: the safety system prompt
    (``safe_word``) followed by the user prompt (``args.prompt``).

    Args:
        bs: Number of conversations (forward batch size) to produce.

    Returns:
        list[list[dict]]: ``bs`` copies of the conversation.
    """
    conversation = [
        {"role": "system", "content": safe_word},
        {"role": "user", "content": args.prompt},
    ]
    # BUG FIX: the old code replicated by (bs // 2 + 1) — dividing by the
    # message count even though the replicated unit is the whole conversation —
    # so any odd bs > 1 produced too few conversations (e.g. 2 for bs=3).
    preset_prompts = [conversation] * bs
    logging.info(f"chat prompt batch size: {bs}")
    return preset_prompts
def generate_prompt(bs, tokenizer_mode):
    """Dispatch prompt construction by tokenizer mode.

    Args:
        bs: Forward batch size to build prompts for.
        tokenizer_mode: ``"default"`` selects the raw-text template builder;
            any other value selects the chat-message builder.

    Returns:
        The prompt batch produced by the selected builder.
    """
    builder = generate_default_prompt if tokenizer_mode == "default" else generate_chat_prompt
    return builder(bs)
def parse_args():
    """Parse command-line arguments for an inference run.

    Returns:
        argparse.Namespace: parsed options — ``yaml_file_path`` (config path),
        ``local_rank`` (torch distributed launch rank, default 0) and
        ``prompt`` (user prompt text, default ``"3*7=?"``).
    """
    arg_parser = argparse.ArgumentParser(description="llm run parameters")
    arg_parser.add_argument("--yaml_file_path", type=str, help="inference configurations")
    arg_parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Local rank id for torch distributed launch",
    )
    arg_parser.add_argument("--prompt", type=str, default="3*7=?", help="user prompts")
    return arg_parser.parse_args()
def main(runner_config):
    """Run one generation pass: build prompts, init the model, warm up, generate.

    Args:
        runner_config: Parsed YAML configuration mapping; ``data_config`` and
            ``model_config`` sections are read for batch size and tokenizer
            mode respectively.
    """
    # Fall back to an empty dict when a section is missing or None, so the
    # inner defaults apply instead of raising AttributeError on None.get().
    bs = (runner_config.get("data_config") or {}).get("batch_size", 1)
    tokenizer_mode = (runner_config.get("model_config") or {}).get("tokenizer_mode", "default")
    preset_prompts = generate_prompt(bs, tokenizer_mode)
    logging.info(f"input prompts: {preset_prompts}")
    model_runner = ModelRunner(runner_config)
    # Disable JIT compilation on the NPU backend before model initialization.
    torch.npu.set_compile_mode(jit_compile=False)
    model_runner.init_model(PanguUltraMoEForCausalLM)
    # Warmup pass so one-time setup cost does not pollute the real generation.
    model_runner.model_generate(preset_prompts, warm_up=True)
    # Actual generation.
    model_runner.model_generate(preset_prompts)
def read_yaml(yaml_file_path):
    """Load a YAML configuration file.

    Args:
        yaml_file_path: Path to the YAML configuration file.

    Returns:
        The parsed YAML content (typically a dict).

    Raises:
        FileNotFoundError: If the file does not exist.
        yaml.YAMLError: If the file contains invalid YAML.
    """
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            return yaml.safe_load(file)
    except FileNotFoundError:
        # BUG FIX: the old code logged, then fell through to `return data`
        # with `data` never assigned, raising UnboundLocalError and masking
        # the real problem. Log and re-raise instead.
        logging.error(f"No such yaml file: {yaml_file_path}")
        raise
    except yaml.YAMLError as e:
        # BUG FIX: the old code caught `yaml.YAMLERROR`, which does not exist
        # (the class is yaml.YAMLError), so parse errors surfaced as
        # AttributeError. Log and re-raise the real parse error.
        logging.error(f"Load yaml file failed: {e}")
        raise
if __name__ == "__main__":
args = parse_args()
yaml_file_path = args.yaml_file_path
runner_config = read_yaml(yaml_file_path)
main(runner_config)
logging.info("model run success")