File size: 3,616 Bytes
25b1d8f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import torch
import sentencepiece as spm
import torch.nn.functional as F

from trainer import TinyDecoderModel, SPTokenizer, Config

# --- Paths and device ---
# Checkpoint whose 'model_state' entry is restored into the model.
CHECKPOINT_PATH = './checkpoints/ckpt_step46500.pt'
# SentencePiece model file handed to SPTokenizer.
TOKENIZER_MODEL_PATH = "mymodel.model"
# Prefer the GPU whenever PyTorch can see one.
if torch.cuda.is_available():
    DEVICE = 'cuda'
else:
    DEVICE = 'cpu'

# --- Generation parameters ---
PROMPT = "ئەمما بۇ ھادىسە ئېلىپ كەلگەن مەسىلە شۇكى"
MAX_NEW_TOKENS = 300  # number of tokens sampled after the prompt
TOP_P = 0.95          # nucleus (top-p) sampling mass
TEMPERATURE = 0.5     # softmax temperature; < 1 sharpens the distribution

# The pad id is not part of the vocabulary, so a sentinel is defined here;
# ids equal to it are dropped before decoding.
PAD_ID = -1

def _sample_top_p(logits, top_p):
    """Draw one token id per row from the top-p nucleus of ``logits``.

    The nucleus is the smallest set of highest-scoring tokens whose softmax
    mass exceeds ``top_p``; the single best token is always kept. Returns a
    LongTensor of shape ``(rows, 1)`` holding ids in the original column order.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    remove = cumulative_probs > top_p
    # Shift right by one so the token that crosses the threshold is kept too.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    sorted_logits = sorted_logits.masked_fill(remove, float('-inf'))
    choice = torch.multinomial(F.softmax(sorted_logits, dim=-1), num_samples=1)
    # Map positions in the sorted order back to vocabulary ids.
    return sorted_indices.gather(-1, choice)


def generate_text(model, tokenizer, prompt, max_tokens, device, cfg,
                  repetition_penalty=1.8, top_p=None, temperature=None,
                  pad_id=None, stop_at_eos=True):
    """Sample a continuation of ``prompt``, print it, and return it.

    Each step feeds at most the last ``cfg.seq_len`` token ids to ``model``,
    keeps the logits of the final position (truncated to the tokenizer's
    vocabulary size), penalises ids already present in that window, applies
    the temperature and samples from the top-p nucleus.

    Args:
        model: callable taking a ``(1, T)`` LongTensor of ids and returning
            logits indexed as ``[batch, position, vocab]``.
        tokenizer: object with ``encode(str) -> list[int]`` and an ``sp``
            attribute providing ``vocab_size()``, ``unk_id()``, ``eos_id()``
            and ``decode(list[int])`` (SentencePiece-processor style).
        prompt: text to continue.
        max_tokens: maximum number of new tokens to sample.
        device: device the prompt ids are moved to.
        cfg: configuration object; only ``cfg.seq_len`` is read.
        repetition_penalty: positive logits of ids already in the context
            window are divided by this value, negative ones multiplied by it.
            ``1.0`` disables the penalty.
        top_p: nucleus probability mass; ``None`` uses the module-level
            ``TOP_P``.
        temperature: softmax temperature; ``None`` uses the module-level
            ``TEMPERATURE``. A value ``<= 0`` selects greedy (argmax) decoding
            instead of dividing by zero.
        pad_id: id dropped before decoding; ``None`` uses ``PAD_ID``.
        stop_at_eos: stop as soon as ``tokenizer.sp.eos_id()`` is produced
            (the EOS id itself is not appended).

    Returns:
        The decoded prompt plus continuation (the same text that is printed).

    Raises:
        ValueError: if ``prompt`` encodes to an empty token sequence.
    """
    if top_p is None:
        top_p = TOP_P
    if temperature is None:
        temperature = TEMPERATURE
    if pad_id is None:
        pad_id = PAD_ID

    print(f"--- 提示 ---\n{prompt}")
    model.eval()

    sp = tokenizer.sp
    vocab_size = sp.vocab_size()
    # SentencePiece reports -1 when EOS is disabled; sampled ids are >= 0,
    # so that value simply never matches.
    eos_id = sp.eos_id() if stop_at_eos else None

    ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    if ids.size(1) == 0:
        raise ValueError("prompt encodes to an empty token sequence")

    with torch.no_grad():
        for _ in range(max_tokens):
            # Only the most recent seq_len tokens fit in the model's context.
            idx_cond = ids[:, -cfg.seq_len:]
            # Drop any logit columns beyond the tokenizer's real vocabulary.
            logits = model(idx_cond)[:, -1, :vocab_size]

            if repetition_penalty != 1.0:
                # unique() avoids writing the same index several times.
                seen = torch.unique(idx_cond[0])
                score = logits[0, seen]
                # Lower the score of seen tokens whatever the sign of the logit.
                logits[0, seen] = torch.where(
                    score > 0, score / repetition_penalty, score * repetition_penalty
                )

            if temperature <= 0:
                next_token_id = logits.argmax(dim=-1, keepdim=True)
            else:
                next_token_id = _sample_top_p(logits / temperature, top_p)

            if eos_id is not None and next_token_id.item() == eos_id:
                break
            ids = torch.cat((ids, next_token_id), dim=1)

    # Skip the pad sentinel, unknown-token ids and anything outside the
    # vocabulary, so decode() never receives an invalid id.
    skip_ids = {pad_id, sp.unk_id()}
    token_list = [
        tid for tid in ids[0].tolist()
        if tid not in skip_ids and 0 <= tid < vocab_size
    ]
    generated_text = sp.decode(token_list)

    print(f"\n--- 模型生成结果 ---\n{generated_text}")
    print("\n" + "="*50)
    return generated_text

if __name__ == '__main__':
    # Script entry point: build tokenizer and model, restore the checkpoint
    # weights, then sample one continuation of PROMPT.
    cfg = Config()
    cfg.device = torch.device(DEVICE)
    print(f"当前设备: {cfg.device}")

    print("1. 正在加载分词器...")
    tokenizer = SPTokenizer(model_file=TOKENIZER_MODEL_PATH, seq_len=cfg.seq_len)

    print("2. 正在创建模型结构...")
    model = TinyDecoderModel(cfg).to(cfg.device)

    print(f"3. 正在加载检查点: {CHECKPOINT_PATH}")
    try:
        # NOTE(review): torch.load unpickles the file, so only load trusted
        # checkpoints. On torch>=2.6 it defaults to weights_only=True; if this
        # checkpoint stores non-tensor objects that may need revisiting.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=cfg.device)
    except FileNotFoundError:
        # SystemExit with a message prints it to stderr and exits with status 1.
        # The previous bare exit() reported success (status 0) and relied on
        # the optional `site` builtin.
        raise SystemExit(f"错误: 找不到检查点文件 '{CHECKPOINT_PATH}'!") from None
    model.load_state_dict(checkpoint['model_state'])
    print("模型权重加载成功!")

    print("4. 开始生成文本...")
    generate_text(model, tokenizer, PROMPT, MAX_NEW_TOKENS, cfg.device, cfg)