# coding=utf-8
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All rights reserved.

import argparse
import logging

import torch
import torch_npu  # noqa: F401  registers the torch.npu backend (assumed available in the Ascend environment)
import yaml
from model import PanguUltraMoEForCausalLM
from runner import ModelRunner

# Clear any handlers already attached to the root logger so that basicConfig below takes effect.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)
# Fix the host and NPU random seeds for reproducible generation.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)

"""
NOTE:
For enhancing model safety, we recommend the following system prompt. 
It is suggested to be removed for other normal use cases and model evaluation.
"""
safe_word = "你必须严格遵守法律法规和社会道德规范。" \
    "生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。" \
    "一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述," \
    "应返回错误信息:“您的输入包含不当内容,无法处理。”"


# Build plain-text prompts for the default tokenizer mode.
def generate_default_prompt(bs):
    # The prompt batch size defines the actual model forward batch size.
    # Appending " /no_think" to the user turn selects the fast-thinking template;
    # the slow-thinking template is used by default.
    fast_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{} /no_think[unused10][unused9]助手:" % (safe_word,)
    slow_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{}[unused10][unused9]助手:" % (safe_word,)
    preset_prompts = [slow_thinking_template.format(args.prompt)]
    # Replicate the single prompt to fill the requested batch size.
    preset_prompts = preset_prompts * (bs // len(preset_prompts) + 1)
    preset_prompts = preset_prompts[:bs]
    logging.info(f"prompt batch size: {bs}")
    return preset_prompts
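# Illustrative sketch only: with the default --prompt ("3*7=?"), each batch element rendered
# by the slow-thinking template looks like (safety prompt abbreviated with "..."):
#   [unused9]系统:你必须严格遵守法律法规...[unused10][unused9]用户:3*7=?[unused10][unused9]助手: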


def generate_chat_prompt(bs):
    # A single two-turn conversation (system + user) replicated once per batch element.
    conversation = [
        {"role": "system", "content": safe_word},
        {"role": "user", "content": args.prompt}
    ]
    preset_prompts = [conversation] * bs
    logging.info(f"chat prompt batch size: {bs}")
    return preset_prompts
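# Illustrative sketch only: for bs=2 the chat-mode prompts are a list of two identical
# conversations, e.g.
#   [[{"role": "system", "content": safe_word}, {"role": "user", "content": "3*7=?"}],
#    [{"role": "system", "content": safe_word}, {"role": "user", "content": "3*7=?"}]]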


def generate_prompt(bs, tokenizer_mode):
    if tokenizer_mode == "default":
        return generate_default_prompt(bs)
    else:
        return generate_chat_prompt(bs)


def parse_args():
    parser = argparse.ArgumentParser(description="llm run parameters")
    parser.add_argument(
        "--yaml_file_path",
        type=str,
        required=True,
        help="path to the yaml file with inference configurations",
    )
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Local rank id for torch distributed launch",
    )
    parser.add_argument("--prompt", type=str, default="3*7=?", help="user prompt")
    parser_args = parser.parse_args()
    return parser_args
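# Usage sketch (the script file name "infer.py" and the config path are assumptions, not part
# of this repo). torch.distributed.launch passes --local_rank to each spawned process, which
# matches the argument defined above:
#   python -m torch.distributed.launch --nproc_per_node=8 infer.py \
#       --yaml_file_path ./inference_config.yaml --prompt "3*7=?"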


def main(runner_config):
    bs = runner_config.get("data_config", {}).get("batch_size", 1)
    tokenizer_mode = runner_config.get("model_config", {}).get("tokenizer_mode", "default")
    preset_prompts = generate_prompt(bs, tokenizer_mode)
    logging.info(f"input prompts: {preset_prompts}")
    model_runner = ModelRunner(runner_config)
    # Disable online JIT compilation on the NPU backend before model initialization.
    torch.npu.set_compile_mode(jit_compile=False)
    model_runner.init_model(PanguUltraMoEForCausalLM)
    # warmup
    model_runner.model_generate(preset_prompts, warm_up=True)
    # generate
    model_runner.model_generate(preset_prompts)


def read_yaml(yaml_file_path):
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
    except FileNotFoundError:
        logging.error(f"No such yaml file: {yaml_file_path}")
        raise
    except yaml.YAMLError as e:
        logging.error(f"Load yaml file failed: {e}")
        raise
    return data


if __name__ == "__main__":
    args = parse_args()
    yaml_file_path = args.yaml_file_path
    runner_config = read_yaml(yaml_file_path)
    main(runner_config)
    logging.info("model run success")