#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""
https://github.com/jingyaogong/minimind/blob/master/eval_llm.py
"""
import argparse
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer

from project_settings import project_path


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--pretrained_model_name_or_path",
        # default=(project_path / "trained_models/gpt2-sst2-generation"),
        # default=(project_path / "trained_models/gpt2-sst2-ppo/checkpoint-150"),
        default=(project_path / "trained_models/gpt2_sst2_ppo"),
        # default="qgyd2021/gpt2_sst2_ppo",
        type=str
    )
    parser.add_argument(
        "--max_new_tokens",
        default=1024,  # e.g. 8192, 128
        type=int, help="maximum number of new tokens to generate (note: not the model's actual long-context capability)"
    )
    parser.add_argument("--top_p", default=0.85, type=float, help="nucleus sampling threshold (0-1)")
    parser.add_argument("--temperature", default=0.85, type=float, help="sampling temperature; controls randomness (0-1, larger is more random)")

    args = parser.parse_args()
    return args


def main():
    args = get_args()

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        # MPS is available, but it is intentionally disabled here; fall back to CPU.
        # device = "mps"
        device = "cpu"
    else:
        device = "cpu"
    print(f"device: {device}")

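    # Load the tokenizer and model weights from the checkpoint, then switch to eval mode on the chosen device.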
    tokenizer = AutoTokenizer.from_pretrained(args.pretrained_model_name_or_path)
    model = AutoModelForCausalLM.from_pretrained(args.pretrained_model_name_or_path)
    model = model.eval().to(device)

    # Tokenize the prompt and move the tensors to the same device as the model.
    tokenized = tokenizer(
        # "this",
        "this is ",
        # "who needs mind-bending",
        # "eldom has a movie",
        # "thanks to scott 's charismatic",
        return_tensors="pt"
    ).to(device)

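    # Stream decoded tokens to stdout as they are generated; the prompt echo is suppressed.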
    streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    generated_ids = model.generate(
        inputs=tokenized["input_ids"], attention_mask=tokenized["attention_mask"],
        max_new_tokens=args.max_new_tokens, do_sample=True, streamer=streamer,
        pad_token_id=tokenizer.pad_token_id, eos_token_id=tokenizer.eos_token_id,
        top_p=args.top_p, temperature=args.temperature,
        # 3.0 is an aggressive repetition penalty; 1.0 would disable it.
        # early_stopping only applies to beam search, so it is omitted here.
        repetition_penalty=3.0,
    )
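    # The commented-out decode below would strip the prompt and special tokens;
    # the active call keeps the full sequence so the raw output can be inspected.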
    # response = tokenizer.decode(generated_ids[0][len(tokenized["input_ids"][0]):], skip_special_tokens=True)
    response = tokenizer.decode(generated_ids[0], skip_special_tokens=False)
    print(response)
    print(generated_ids)

    return


if __name__ == "__main__":
    main()