import argparse
import logging

import torch
import torch_npu  # noqa: F401  # assumption: registers the torch.npu backend used below
import yaml

from model import PanguUltraMoEForCausalLM
from runner import ModelRunner

# Clear any handlers pre-installed by imported libraries so basicConfig applies.
root_logger = logging.getLogger()
root_logger.handlers.clear()
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - [LLM](%(filename)s:%(lineno)d): %(message)s",
    level=logging.INFO,
)

# Fix random seeds so generation is reproducible across runs.
torch.manual_seed(42)
torch.npu.manual_seed_all(42)

"""
NOTE:
To enhance model safety, we recommend the following system prompt.
Remove it for other normal use cases and for model evaluation.
"""
# English gloss of the safety prompt below: "You must strictly abide by laws,
# regulations, and social ethics. When generating any content, avoid violence,
# pornography, terrorism, racial discrimination, gender discrimination, and other
# inappropriate content. If such tendencies are detected in the input or output,
# refuse to answer and issue a warning. For example, if the input contains violent
# threats or pornographic descriptions, return the error message: 'Your input
# contains inappropriate content and cannot be processed.'"
safe_word = (
    "你必须严格遵守法律法规和社会道德规范。"
    "生成任何内容时,都应避免涉及暴力、色情、恐怖主义、种族歧视、性别歧视等不当内容。"
    "一旦检测到输入或输出有此类倾向,应拒绝回答并发出警告。例如,如果输入内容包含暴力威胁或色情描述,"
    "应返回错误信息:“您的输入包含不当内容,无法处理。”"
)


def generate_default_prompt(bs):
    # "/no_think" after the user turn switches the model to fast (non-reasoning)
    # responses; the slow-thinking template is used by default below.
    fast_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{} /no_think[unused10][unused9]助手:" % (safe_word,)
    slow_thinking_template = "[unused9]系统:%s[unused10][unused9]用户:{}[unused10][unused9]助手:" % (safe_word,)
    preset_prompts = [slow_thinking_template.format(args.prompt)]
    # Repeat the prompt to fill the batch, then truncate to exactly bs entries.
    preset_prompts = preset_prompts * (bs // len(preset_prompts) + 1)
    preset_prompts = preset_prompts[:bs]
    logging.info(f"prompt batch size: {bs}")
    return preset_prompts
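
# For illustration, the slow-thinking prompt rendered above for the default
# --prompt "3*7=?" looks like:
#   [unused9]系统:<safe_word>[unused10][unused9]用户:3*7=?[unused10][unused9]助手: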


def generate_chat_prompt(bs):
    # Build one chat-format conversation and replicate it once per batch
    # element, so the returned list always has exactly bs conversations.
    conversation = [
        {"role": "system", "content": safe_word},
        {"role": "user", "content": args.prompt},
    ]
    preset_prompts = [conversation] * bs
    logging.info(f"chat prompt batch size: {bs}")
    return preset_prompts
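
# For illustration: with bs == 2 the return value is two copies of the same
# two-message conversation:
#   [[{"role": "system", "content": safe_word}, {"role": "user", "content": args.prompt}],
#    [{"role": "system", "content": safe_word}, {"role": "user", "content": args.prompt}]]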


def generate_prompt(bs, tokenizer_mode):
    if tokenizer_mode == "default":
        return generate_default_prompt(bs)
    else:
        return generate_chat_prompt(bs)


def parse_args():
    parser = argparse.ArgumentParser(description="LLM run parameters")
    parser.add_argument("--yaml_file_path", type=str, help="inference configurations")
    parser.add_argument(
        "--local_rank",
        type=int,
        default=0,
        help="Local rank id for torch distributed launch",
    )
    parser.add_argument("--prompt", type=str, default="3*7=?", help="user prompts")
    parser_args = parser.parse_args()
    return parser_args


def main(runner_config):
    bs = runner_config.get("data_config").get("batch_size", 1)
    tokenizer_mode = runner_config.get("model_config").get("tokenizer_mode", "default")
    preset_prompts = generate_prompt(bs, tokenizer_mode)
    logging.info(f"input prompts: {preset_prompts}")

    model_runner = ModelRunner(runner_config)
    # Run NPU operators in eager mode instead of JIT-compiling them.
    torch.npu.set_compile_mode(jit_compile=False)
    model_runner.init_model(PanguUltraMoEForCausalLM)

    # Warm-up pass first, then the actual generation.
    model_runner.model_generate(preset_prompts, warm_up=True)
    model_runner.model_generate(preset_prompts)


def read_yaml(yaml_file_path):
    data = None
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as file:
            data = yaml.safe_load(file)
    except FileNotFoundError:
        logging.error(f"No such yaml file: {yaml_file_path}")
    except yaml.YAMLError as e:
        logging.error(f"Load yaml file failed: {e}")
    return data
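
# Illustrative sketch of the YAML expected via --yaml_file_path. Only the keys
# read in main() are shown; the full schema is defined by ModelRunner, so treat
# the values here as assumptions:
#
#   data_config:
#     batch_size: 4
#   model_config:
#     tokenizer_mode: default  # any other value selects chat-style prompts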


if __name__ == "__main__":
    args = parse_args()
    yaml_file_path = args.yaml_file_path
    runner_config = read_yaml(yaml_file_path)
    if runner_config is None:
        raise SystemExit("failed to load runner config")
    main(runner_config)
    logging.info("model run success")
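
# Example launch (hypothetical script name and world size; the legacy
# torch.distributed.launch helper passes --local_rank to each process,
# matching the argument parsed above):
#
#   python -m torch.distributed.launch --nproc_per_node=8 infer.py \
#       --yaml_file_path config.yaml --prompt "3*7=?"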