|
|
import torch
|
|
|
import sentencepiece as spm
|
|
|
import torch.nn.functional as F
|
|
|
|
|
|
from trainer import TinyDecoderModel, SPTokenizer, Config
|
|
|
|
|
|
|
|
|
# Path of the trained checkpoint loaded for inference.
CHECKPOINT_PATH = './checkpoints/ckpt_step46500.pt'

# SentencePiece tokenizer model file (same one used during training).
TOKENIZER_MODEL_PATH = "mymodel.model"

# Run on GPU when available, otherwise fall back to CPU.
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# Seed prompt for generation (Uyghur text).
PROMPT = "ئەمما بۇ ھادىسە ئېلىپ كەلگەن مەسىلە شۇكى"

# Sampling hyperparameters.
MAX_NEW_TOKENS = 300  # maximum number of tokens appended to the prompt

TOP_P = 0.95  # nucleus-sampling cumulative-probability cutoff

TEMPERATURE = 0.5  # softmax temperature (<1.0 sharpens the distribution)

# Token id treated as padding; filtered out before decoding.
# NOTE(review): presumably matches the pad id used in trainer.py — confirm.
PAD_ID = -1
|
|
|
|
|
|
def generate_text(model, tokenizer, prompt, max_tokens, device, cfg,
                  temperature=None, top_p=None, repetition_penalty=1.8,
                  pad_id=None):
    """Autoregressively sample a continuation of *prompt* and print it.

    Args:
        model: decoder LM callable as ``model(ids) -> logits`` of shape
            (batch, seq, vocab-or-larger).
        tokenizer: wrapper exposing ``encode`` and a SentencePiece processor
            at ``tokenizer.sp`` (``vocab_size``, ``unk_id``, ``decode``).
        prompt: seed text to condition on.
        max_tokens: number of new tokens to sample.
        device: torch device for the token tensor.
        cfg: config object providing ``seq_len`` (context window).
        temperature: softmax temperature; defaults to the module-level
            TEMPERATURE when None (keeps the original global-driven behavior).
        top_p: nucleus-sampling cutoff; defaults to module-level TOP_P.
        repetition_penalty: CTRL-style penalty (>1 discourages repeats);
            1.0 disables it. Default 1.8 matches the previous hard-coded value.
        pad_id: padding token id to drop before decoding; defaults to the
            module-level PAD_ID.

    Returns:
        The decoded text (prompt plus continuation).
    """
    temperature = TEMPERATURE if temperature is None else temperature
    top_p = TOP_P if top_p is None else top_p
    pad_id = PAD_ID if pad_id is None else pad_id

    print(f"--- 提示 ---\n{prompt}")
    model.eval()

    vocab_size = tokenizer.sp.vocab_size()
    # Encode prompt -> (1, T) long tensor on the target device.
    prompt_ids = torch.tensor(tokenizer.encode(prompt), dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop the context to the model's window; a negative slice
            # already returns the whole sequence when it is shorter.
            idx_cond = prompt_ids[:, -cfg.seq_len:]
            logits = model(idx_cond)
            # Keep only the last position, truncated to the real vocabulary
            # (the output head may be padded beyond vocab_size).
            logits = logits[:, -1, :vocab_size]

            # Repetition penalty: push down logits of tokens already in the
            # context (divide positive scores, multiply negative ones).
            if repetition_penalty != 1.0:
                prev_tokens = idx_cond[0]
                score = torch.gather(logits[0], 0, prev_tokens)
                score = torch.where(score > 0,
                                    score / repetition_penalty,
                                    score * repetition_penalty)
                logits[0].scatter_(0, prev_tokens, score)

            logits = logits / temperature

            # Nucleus (top-p) filtering on the sorted distribution.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumulative_probs > top_p
            # Shift right so the first token crossing the threshold is kept.
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = 0
            sorted_logits[mask] = -float('Inf')

            probs = F.softmax(sorted_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            # Map the sampled position back to the original vocabulary index.
            # Indices come from a vocab_size-wide tensor, so no clamp needed.
            next_token_id = sorted_indices.gather(-1, next_token_id)

            prompt_ids = torch.cat((prompt_ids, next_token_id), dim=1)

    # Drop padding/unknown ids before decoding.
    invalid_ids = {pad_id, tokenizer.sp.unk_id(), -1}
    token_list = [tid for tid in prompt_ids[0].tolist() if tid not in invalid_ids]
    try:
        generated_text = tokenizer.sp.decode(token_list)
    except Exception:
        # Fallback: strip any out-of-range ids SentencePiece would reject.
        generated_text = tokenizer.sp.decode(
            [tid for tid in token_list if 0 <= tid < vocab_size])

    print(f"\n--- 模型生成结果 ---\n{generated_text}")
    print("\n" + "="*50)
    return generated_text
|
|
|
|
|
|
if __name__ == '__main__':
    # Inference entry point: build config, load tokenizer + model weights,
    # then sample a continuation of PROMPT.
    cfg = Config()
    cfg.device = torch.device(DEVICE)
    print(f"当前设备: {cfg.device}")

    print("1. 正在加载分词器...")
    tokenizer = SPTokenizer(model_file=TOKENIZER_MODEL_PATH, seq_len=cfg.seq_len)

    print("2. 正在创建模型结构...")
    model = TinyDecoderModel(cfg).to(cfg.device)

    print(f"3. 正在加载检查点: {CHECKPOINT_PATH}")
    try:
        # NOTE(review): torch.load performs full unpickling by default;
        # consider weights_only=True if the checkpoint format allows it.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=cfg.device)
    except FileNotFoundError:
        print(f"错误: 找不到检查点文件 '{CHECKPOINT_PATH}'!")
        # Bare exit() depends on the site module and returns status 0;
        # SystemExit(1) works everywhere and reports failure to the shell.
        raise SystemExit(1)
    model.load_state_dict(checkpoint['model_state'])
    print("模型权重加载成功!")

    print("4. 开始生成文本...")
    generate_text(model, tokenizer, PROMPT, MAX_NEW_TOKENS, cfg.device, cfg)
|
|
|
|
|
|
|