#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
Sample text from a (PPO fine-tuned) GPT-2 causal LM and stream the generated
tokens to stdout.

Reference:
https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
"""
import argparse
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

from project_settings import project_path


def get_args():
    """Parse CLI arguments: model path plus sampling hyper-parameters."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # default=(project_path / "trained_models/gpt2-sst2-generation"),
        # default=(project_path / "trained_models/gpt2-sst2-ppo/checkpoint-150"),
        default=(project_path / "trained_models/gpt2_sst2_ppo"),
        # default="qgyd2021/gpt2_sst2_ppo",
        type=str
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,  # 8192, 128
        type=int,
        help="最大生成长度(注意:并非模型实际长文本能力)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    args = parser.parse_args()
    return args


def main():
    """Load tokenizer + model, sample a continuation of a fixed prompt, print it."""
    args = get_args()

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # NOTE(review): mps is deliberately disabled here (original kept it
        # commented out); Apple silicon falls back to cpu.
        # device = "mps"
        device = "cpu"
    else:
        device = "cpu"
    print(f"device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    tokenized = tokenizer(
        # "this",
        "this is ",
        # "who needs mind-bending",
        # "eldom has a movie",
        # "thanks to scott 's charismatic",
        return_tensors="pt"
    )
    # BUG FIX: the model was moved to `device` but the tokenized tensors were
    # left on cpu, which makes generate() fail with a device mismatch whenever
    # cuda is selected. Move the inputs onto the model's device.
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)

    # Streamer prints tokens to stdout as they are produced, skipping the prompt.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generated_ids = model.generate(
        inputs=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens,
        do_sample=True,
        streamer=streamer,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p,
        temperature=args.temperature,
        repetition_penalty=3.0,
        # FIX: dropped `early_stopping=True` — it only applies to beam search
        # (num_beams > 1) and merely raised a UserWarning under do_sample.
    )
    # Decode the full sequence (prompt + continuation), special tokens included.
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(response)
    print(generated_ids)
    return


if __name__ == "__main__":
    main()