Spaces:
Sleeping
Sleeping
| #!/usr/bin/python3 | |
| # -*- coding: utf-8 -*- | |
| """ | |
| https://github.com/jingyaogong/minimind/blob/master/eval_llm.py | |
| """ | |
| import argparse | |
| import time | |
| import torch | |
| from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer | |
| from project_settings import project_path | |
def get_args():
    """Parse command-line arguments for text generation.

    Returns:
        argparse.Namespace with:
            pretrained_model_name_or_path (str): local path or HF hub id of
                the causal-LM checkpoint to load.
            max_new_tokens (int): upper bound on generated tokens.
            top_p (float): nucleus-sampling threshold in (0, 1].
            temperature (float): sampling temperature; higher is more random.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # NOTE: argparse does not apply `type=` to defaults, so the default
        # must be converted to str explicitly to match the declared type.
        default=str(project_path / "trained_models/gpt2_sst2_ppo"),
        type=str
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,
        type=int, help="最大生成长度(注意:并非模型实际长文本能力)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus采样阈值(0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="生成温度,控制随机性(0-1,越大越随机)")
    args = parser.parse_args()
    return args
def main():
    """Load a causal LM, stream a sampled continuation of a fixed prompt,
    and print the decoded output plus the raw generated token ids.
    """
    args = get_args()

    # Pick the fastest available backend; MPS is deliberately disabled
    # (falls through to CPU) per the original author's workaround.
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # device = "mps"
        device = "cpu"
    else:
        device = "cpu"
    print(f"device: {device}")

    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    tokenized = tokenizer(
        "this is ",
        return_tensors="pt"
    )
    # BUG FIX: the model lives on `device`, so the input tensors must be
    # moved there too — otherwise generate() raises a device-mismatch
    # error whenever CUDA is selected.
    input_ids = tokenized["input_ids"].to(device)
    attention_mask = tokenized["attention_mask"].to(device)

    # Stream tokens to stdout as they are produced; skip echoing the prompt.
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generated_ids = model.generate(
        inputs=input_ids, attention_mask=attention_mask,
        max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
        pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p, temperature=args.temperature, repetition_penalty=3.0,
        early_stopping=True,
    )
    # Decode the full sequence (prompt + continuation), keeping special tokens
    # so the raw model output is visible alongside the streamed text.
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(response)
    print(generated_ids)
    return
# Script entry point: run generation only when executed directly,
# not when imported as a module.
if __name__ == "__main__":
    main()