#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
"""
import argparse
import time
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer
from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # default=(project_path / "trained_models/gpt2-sst2-generation"),
        # default=(project_path / "trained_models/gpt2-sst2-ppo/checkpoint-150"),
        default=(project_path / "trained_models/gpt2_sst2_ppo"),
        # default="qgyd2021/gpt2_sst2_ppo",
        type=str
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,  # 8192, 128
        type=int,
        help="maximum number of new tokens to generate (note: not the model's actual long-text capability)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus sampling threshold (0-1)")
    parser.add_argument("--temperature", default=0.85, type=float,
                        help="sampling temperature controlling randomness (0-1, larger is more random)")
    args = parser.parse_args()
    return args


def main():
    args = get_args()

    # Prefer CUDA; MPS is intentionally disabled here, so fall back to CPU otherwise.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # device = "mps"
        device = "cpu"
    else:
        device = "cpu"
    print(f"device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    tokenized = tokenizer(
        # "this",
        "this is ",
        # "who needs mind-bending",
        # "eldom has a movie",
        # "thanks to scott 's charismatic",
        return_tensors="pt"
    )
    # Move the prompt tensors onto the same device as the model.
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)

    # Stream decoded tokens to stdout as they are generated.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generated_ids = model.generate(
        inputs=input_ids, attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
        # GPT-2 has no pad token by default, so fall back to the EOS token for padding.
        pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p, temperature=args.temperature, repetition_penalty=3.0,
    )

    # response = tokenizer.decode(generated_ids[0][len(input_ids[0]):], skip_special_tokens=True)
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(response)
    print(generated_ids)
    return


if __name__ == "__main__":
    main()
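

# Example invocation (a sketch; the script filename is assumed, flags map to get_args above):
#   python3 eval_llm.py --max_new_tokens 128 --top_p 0.85 --temperature 0.85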