# Inference script: generate text from a trained TinyDecoderModel checkpoint
# using nucleus (top-p) sampling with temperature and repetition penalty.
import torch
import sentencepiece as spm
import torch.nn.functional as F
from trainer import TinyDecoderModel, SPTokenizer, Config
# --- Configuration ---
CHECKPOINT_PATH = './checkpoints/ckpt_step46500.pt'
TOKENIZER_MODEL_PATH = "mymodel.model"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
# --- Generation parameters ---
PROMPT = "ئەمما بۇ ھادىسە ئېلىپ كەلگەن مەسىلە شۇكى"  # seed prompt (Uyghur text)
MAX_NEW_TOKENS = 300
TOP_P = 0.95 # nucleus sampling: keep smallest prefix of tokens whose cumulative prob > TOP_P
TEMPERATURE = 0.5 # sampling temperature (<1 sharpens the distribution)
# Sentinel pad id, used when the vocab defines no pad token; filtered out before decoding.
PAD_ID = -1
def generate_text(model, tokenizer, prompt, max_tokens, device, cfg):
    """Autoregressively generate text from *prompt* with nucleus (top-p) sampling.

    Args:
        model: decoder-only LM; called as ``model(ids)`` and expected to return
            logits of shape (batch, seq, head_dim) — head may be padded past
            the tokenizer vocab, so logits are truncated to ``vocab_size``.
        tokenizer: wrapper exposing ``encode`` and a SentencePiece processor ``sp``.
        prompt: seed string.
        max_tokens: maximum number of new tokens to sample.
        device: torch device the prompt ids are moved to.
        cfg: config object providing ``seq_len`` (context window).

    Returns:
        The decoded generated string (also printed to stdout).
    """
    print(f"--- 提示 ---\n{prompt}")
    model.eval()
    prompt_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)
    # Repetition-penalty hyperparameter (CTRL-style): >1 discourages re-emitting
    # tokens already present in the context window.
    REPETITION_PENALTY = 1.8
    # Hoist loop invariants.
    vocab_size = tokenizer.sp.vocab_size()
    eos_id = tokenizer.sp.eos_id()  # SentencePiece returns -1 when no EOS piece is defined
    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop context to the model's maximum sequence length.
            idx_cond = prompt_ids[:, -cfg.seq_len:] if prompt_ids.size(1) > cfg.seq_len else prompt_ids
            logits = model(idx_cond)
            # Last position only, truncated to the real tokenizer vocab.
            logits = logits[:, -1, :vocab_size]
            if REPETITION_PENALTY != 1.0:
                prev_tokens = idx_cond[0]
                score = torch.gather(logits[0], 0, prev_tokens)
                # Divide positive scores / multiply negative ones so the penalty
                # always lowers the token's sampling probability.
                score = torch.where(score > 0, score / REPETITION_PENALTY, score * REPETITION_PENALTY)
                logits[0].scatter_(0, prev_tokens, score)
            logits = logits / TEMPERATURE
            # Nucleus filtering: mask the tail whose cumulative probability
            # exceeds TOP_P, always keeping at least the top-1 token.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumulative_probs > TOP_P
            mask[..., 1:] = mask[..., :-1].clone()  # shift right: first token over threshold stays
            mask[..., 0] = 0
            sorted_logits[mask] = -float('Inf')
            probs = F.softmax(sorted_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            next_token_id = sorted_indices.gather(-1, next_token_id)
            next_token_id = torch.clamp(next_token_id, min=0, max=vocab_size - 1)
            prompt_ids = torch.cat((prompt_ids, next_token_id), dim=1)
            # Stop early once the model emits end-of-sentence (if the SP model defines one).
            if eos_id >= 0 and next_token_id.item() == eos_id:
                break
    # Ids to drop before decoding: pad, unknown, and the -1 sentinel.
    # (Previously misnamed `valid_ids` — these are the *excluded* ids.)
    skip_ids = {PAD_ID, tokenizer.sp.unk_id(), -1}
    token_list = [tid for tid in prompt_ids[0].tolist() if tid not in skip_ids]
    try:
        generated_text = tokenizer.sp.decode(token_list)
    except Exception:
        # Fallback: decode only in-range ids if SentencePiece rejects the list.
        generated_text = tokenizer.sp.decode([tid for tid in token_list if 0 <= tid < vocab_size])
    print(f"\n--- 模型生成结果 ---\n{generated_text}")
    print("\n" + "="*50)
    return generated_text
if __name__ == '__main__':
    # Script entry point: build config, tokenizer, and model, restore weights
    # from the checkpoint, then generate from the fixed PROMPT.
    cfg = Config()
    cfg.device = torch.device(DEVICE)
    print(f"当前设备: {cfg.device}")
    print("1. 正在加载分词器...")
    tokenizer = SPTokenizer(model_file=TOKENIZER_MODEL_PATH, seq_len=cfg.seq_len)
    print("2. 正在创建模型结构...")
    model = TinyDecoderModel(cfg).to(cfg.device)
    print(f"3. 正在加载检查点: {CHECKPOINT_PATH}")
    try:
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources. `weights_only=True` would be safer
        # if the checkpoint contains only tensors; confirm before enabling.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=cfg.device)
        model.load_state_dict(checkpoint['model_state'])
        print("模型权重加载成功!")
    except FileNotFoundError:
        print(f"错误: 找不到检查点文件 '{CHECKPOINT_PATH}'!")
        # exit() is the interactive-site helper and exits with status 0;
        # raise SystemExit(1) so shells/CI can detect the failure.
        raise SystemExit(1)
    print("4. 开始生成文本...")
    generate_text(model, tokenizer, PROMPT, MAX_NEW_TOKENS, cfg.device, cfg)