import sys

import torch
import torch.nn.functional as F
import sentencepiece as spm  # noqa: F401 — SentencePiece backend for SPTokenizer

from trainer import TinyDecoderModel, SPTokenizer, Config

# --- Configuration ---
CHECKPOINT_PATH = './checkpoints/ckpt_step46500.pt'
TOKENIZER_MODEL_PATH = "mymodel.model"
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

# --- Generation parameters ---
PROMPT = "ئەمما بۇ ھادىسە ئېلىپ كەلگەن مەسىلە شۇكى"
MAX_NEW_TOKENS = 300
TOP_P = 0.95              # nucleus-sampling cumulative-probability threshold
TEMPERATURE = 0.5         # softmax temperature (<1.0 sharpens the distribution)
REPETITION_PENALTY = 1.8  # >1.0 discourages tokens already in the context

# Sentinel pad id used when the pad token is not part of the vocab.
PAD_ID = -1


def generate_text(model, tokenizer, prompt, max_tokens, device, cfg,
                  repetition_penalty=REPETITION_PENALTY):
    """Autoregressively sample up to *max_tokens* tokens after *prompt* and print the result.

    Sampling pipeline per step: CTRL-style repetition penalty (divide
    positive / multiply negative logits of tokens already in the context
    window), temperature scaling, then nucleus (top-p) filtering.

    Args:
        model: decoder-only LM; ``model(idx)`` is expected to return logits of
            shape (batch, seq, >=vocab_size) — TODO confirm against trainer.
        tokenizer: ``SPTokenizer`` wrapper exposing ``encode`` and ``sp``
            (the underlying SentencePieceProcessor).
        prompt: seed text to continue.
        max_tokens: maximum number of new tokens to sample.
        device: torch device the model lives on.
        cfg: ``Config`` providing ``seq_len`` (context-window length).
        repetition_penalty: >1.0 penalizes repeated tokens; 1.0 disables
            the penalty entirely.
    """
    print(f"--- 提示 ---\n{prompt}")
    model.eval()

    vocab_size = tokenizer.sp.vocab_size()
    eos_id = tokenizer.sp.eos_id()  # -1 when the tokenizer defines no EOS piece
    prompt_ids = torch.tensor(tokenizer.encode(prompt),
                              dtype=torch.long).unsqueeze(0).to(device)

    with torch.no_grad():
        for _ in range(max_tokens):
            # Crop context to the model's window ([:, -n:] is a no-op when
            # the sequence is already shorter than n).
            idx_cond = prompt_ids[:, -cfg.seq_len:]
            logits = model(idx_cond)
            # Last position only; slice in case the model's output dim is
            # padded beyond the real vocabulary.
            logits = logits[:, -1, :vocab_size]

            # CTRL-style repetition penalty (batch size is always 1 here).
            if repetition_penalty != 1.0:
                prev_tokens = idx_cond[0]
                score = torch.gather(logits[0], 0, prev_tokens)
                score = torch.where(score > 0,
                                    score / repetition_penalty,
                                    score * repetition_penalty)
                logits[0].scatter_(0, prev_tokens, score)

            logits = logits / TEMPERATURE

            # Nucleus (top-p) filtering: keep the smallest prefix of the
            # descending-sorted distribution whose cumulative mass > TOP_P.
            sorted_logits, sorted_indices = torch.sort(logits, descending=True)
            cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
            mask = cumulative_probs > TOP_P
            # Shift right so the first token crossing the threshold survives.
            mask[..., 1:] = mask[..., :-1].clone()
            mask[..., 0] = 0
            sorted_logits[mask] = -float('Inf')

            probs = F.softmax(sorted_logits, dim=-1)
            next_token_id = torch.multinomial(probs, num_samples=1)
            # Map the sampled rank back to a vocab id; sorted_indices is
            # already in [0, vocab_size), so no clamp is needed.
            next_token_id = sorted_indices.gather(-1, next_token_id)

            prompt_ids = torch.cat((prompt_ids, next_token_id), dim=1)

            # Stop early once the model emits EOS (if the vocab has one).
            if eos_id >= 0 and next_token_id.item() == eos_id:
                break

    # Drop ids SentencePiece cannot decode: the pad sentinel and <unk>.
    skip_ids = {PAD_ID, tokenizer.sp.unk_id()}
    token_list = [tid for tid in prompt_ids[0].tolist() if tid not in skip_ids]
    try:
        generated_text = tokenizer.sp.decode(token_list)
    except Exception:
        # Fall back to decoding only in-range ids if decode rejects the list.
        generated_text = tokenizer.sp.decode(
            [tid for tid in token_list if 0 <= tid < vocab_size])

    print(f"\n--- 模型生成结果 ---\n{generated_text}")
    print("\n" + "="*50)


def main():
    """Load the tokenizer, model and checkpoint, then run one generation."""
    cfg = Config()
    cfg.device = torch.device(DEVICE)
    print(f"当前设备: {cfg.device}")

    print("1. 正在加载分词器...")
    tokenizer = SPTokenizer(model_file=TOKENIZER_MODEL_PATH, seq_len=cfg.seq_len)

    print("2. 正在创建模型结构...")
    model = TinyDecoderModel(cfg).to(cfg.device)

    print(f"3. 正在加载检查点: {CHECKPOINT_PATH}")
    try:
        # Only torch.load can raise FileNotFoundError; keep the try minimal.
        checkpoint = torch.load(CHECKPOINT_PATH, map_location=cfg.device)
    except FileNotFoundError:
        print(f"错误: 找不到检查点文件 '{CHECKPOINT_PATH}'!")
        sys.exit(1)  # non-zero exit code on fatal error (was exit() -> 0)
    model.load_state_dict(checkpoint['model_state'])
    print("模型权重加载成功!")

    print("4. 开始生成文本...")
    generate_text(model, tokenizer, PROMPT, MAX_NEW_TOKENS, cfg.device, cfg)


if __name__ == '__main__':
    main()