"""Interactive text generation with a fine-tuned NovelLM checkpoint."""

import torch
import torch.nn.functional as F
from tokenizers import Tokenizer
import intel_extension_for_pytorch as ipex

from novel_model import NovelTransformer, NovelLM

# Configuration
VOCAB_SIZE = 8000
MAX_LEN = 4096
MODEL_PATH = "d:/图像/novel_model_ft/best_model_ft.pt"
TOKENIZER_PATH = "d:/图像/novel_tokenizer.json"


def _filter_logits(logits, top_k, top_p):
    """Return a copy of *logits* with non-candidate tokens set to -inf.

    Args:
        logits: Tensor whose last dimension indexes the vocabulary.
        top_k: Keep only the ``top_k`` highest logits; ``<= 0`` disables it.
        top_p: Keep the smallest set of tokens whose cumulative probability
            exceeds ``top_p``; ``>= 1.0`` disables it.
    """
    if top_k > 0:
        # torch.topk raises when k exceeds the dimension size, so clamp it.
        k = min(top_k, logits.size(-1))
        kth_best = torch.topk(logits, k)[0][..., -1, None]
        logits = logits.masked_fill(logits < kth_best, float("-inf"))

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(
            F.softmax(sorted_logits, dim=-1), dim=-1
        )
        sorted_remove = cumulative_probs > top_p
        # Shift right by one so the first token that crosses the threshold
        # is kept; the most likely token is therefore never removed.
        sorted_remove[..., 1:] = sorted_remove[..., :-1].clone()
        sorted_remove[..., 0] = False
        # Map the mask from sorted order back to vocabulary order. Using
        # scatter keeps this correct for any batch size.
        remove = torch.zeros_like(sorted_remove).scatter(
            -1, sorted_indices, sorted_remove
        )
        logits = logits.masked_fill(remove, float("-inf"))

    return logits


def generate_text(model, tokenizer, prompt, max_length=100, temperature=0.7,
                  top_k=50, top_p=0.9, device="cpu", eos_token=""):
    """Autoregressively sample a continuation of *prompt*.

    Args:
        model: Callable language model. It is called with a ``(1, seq)``
            tensor of token ids; its result is indexed as
            ``[:, -1, :]`` (assumed to be ``(batch, seq, vocab)`` logits --
            TODO confirm against ``NovelLM``).
        tokenizer: Object providing ``encode(text).ids``, ``decode(ids)``
            and ``token_to_id(token)``.
        prompt: Text to continue.
        max_length: Maximum number of new tokens to generate.
        temperature: Divisor applied to the logits before sampling. A value
            ``<= 0`` selects greedy (argmax) decoding instead of sampling.
        top_k: Top-k cut-off, ``<= 0`` disables it.
        top_p: Nucleus cut-off, ``>= 1.0`` disables it.
        device: Device the input ids are moved to.
        eos_token: Token string that ends generation when sampled.
            NOTE(review): the default is an empty string, which looks like a
            special token that was stripped from the source; pass the real
            end-of-sequence token once it is confirmed.

    Returns:
        The decoded text of the prompt followed by the generated tokens.

    Raises:
        ValueError: If *prompt* encodes to zero tokens.
    """
    model.eval()

    prompt_ids = tokenizer.encode(prompt).ids
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token")
    input_ids = torch.tensor(
        prompt_ids, dtype=torch.long
    ).unsqueeze(0).to(device)

    # Resolve the stop token once instead of on every step; it is None when
    # the tokenizer does not know the token, in which case we never stop
    # early.
    eos_id = tokenizer.token_to_id(eos_token)

    with torch.no_grad():
        for _ in range(max_length):
            # Feed at most MAX_LEN trailing tokens to the model, but keep the
            # full sequence so the decoded result still contains the prompt.
            context = input_ids[:, -MAX_LEN:]
            logits = model(context)[:, -1, :]

            if temperature > 0:
                logits = _filter_logits(logits / temperature, top_k, top_p)
                probs = F.softmax(logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                next_token = torch.argmax(logits, dim=-1, keepdim=True)

            input_ids = torch.cat([input_ids, next_token], dim=1)

            if eos_id is not None and next_token.item() == eos_id:
                break

    return tokenizer.decode(input_ids[0].tolist())


def _extract_generated(output, full_prompt, marker="输出: "):
    """Return the part of *output* that follows the prompt's last marker.

    Only as many markers as *full_prompt* contains are split off, so a marker
    that appears inside the generated text does not truncate it. If the
    decoded output contains fewer markers than the prompt, the whole output
    is returned unchanged.
    """
    marker_count = full_prompt.count(marker)
    parts = output.split(marker, marker_count)
    if len(parts) <= marker_count:
        return output
    return parts[-1]


def main():
    """Load the tokenizer and checkpoint, then run an interactive prompt loop."""
    # Select the device
    device = torch.device("xpu" if torch.xpu.is_available() else "cpu")
    print(f"使用设备: {device}")

    # Load the tokenizer
    tokenizer = Tokenizer.from_file(TOKENIZER_PATH)

    # Load the model
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints (consider weights_only=True if the checkpoint allows it).
    checkpoint = torch.load(MODEL_PATH, map_location=device)
    base_model = NovelTransformer(
        vocab_size=VOCAB_SIZE,
        d_model=256,
        nhead=8,
        num_layers=6,
        dim_feedforward=1024,
        dropout=0.1,
        max_len=MAX_LEN,
    )
    model = NovelLM(base_model)
    model.load_state_dict(checkpoint['model_state_dict'])
    model = model.to(device)
    # Switch to inference mode before ipex.optimize: no optimizer is passed,
    # so the model must not be in training mode when it is optimized.
    model.eval()
    model = ipex.optimize(model)

    # Interactive generation
    print("小说语言模型已加载。输入提示进行生成,输入'exit'退出。")
    while True:
        try:
            prompt = input("\n请输入提示 (或输入'exit'退出): ")
        except (EOFError, KeyboardInterrupt):
            # Closed stdin or Ctrl-C ends the session without a traceback.
            break

        if prompt.strip().lower() == 'exit':
            break

        # Build the instruction-format prompt
        if not prompt.startswith("指令:"):
            full_prompt = f"指令: 继续写下去\n输入: {prompt}\n输出: "
        else:
            full_prompt = prompt + "\n输出: "

        # Generate text
        output = generate_text(
            model, tokenizer, full_prompt, max_length=200, device=device
        )

        # Show only the generated part when it can be located
        print("\n生成的文本:")
        print(_extract_generated(output, full_prompt))


if __name__ == "__main__":
    main()