"""Command-line inference for the SFT-tuned Transformer language model."""

import argparse

import torch

from model import TransformerConfig, TransformerLanguageModel
from tokenizer import load_tokenizer, SpecialToken

# Model hyperparameters (SFT variant: vocabulary expanded to 50306 entries,
# context length 2048).
config = TransformerConfig(
    vocab_size=50306,
    block_size=2048,
    n_embed=768,
    n_heads=12,
    n_layers=12,
    dropout=0.0,
    bias=True,
)

# Tokenizer used for both encoding the prompt and decoding the output.
tokenizer = load_tokenizer("tokenizer.model")

# Build the model once at import time and move it to the GPU when available.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
model = TransformerLanguageModel(config)
model = model.to(device)


def build_prompt(user_text):
    """Build a prompt in the chat format used for SFT.

    The prompt is a single user turn followed by the opening of an
    assistant turn, so the model continues as the assistant.

    Args:
        user_text: The user's message.

    Returns:
        A list mixing ``SpecialToken`` markers and plain strings, in the
        form accepted by ``tokenizer.encode_all``.
    """
    return [
        SpecialToken("<|im_start|>"),
        f"user\n{user_text}",
        SpecialToken("<|im_end|>"),
        "\n",
        SpecialToken("<|im_start|>"),
        "assistant\n",
    ]


def decode_output(token_list):
    """Join the items returned by ``tokenizer.decode`` into a readable string.

    Plain ``str`` items are kept verbatim, ``SpecialToken`` items are
    rendered as ``<name>``, and items of any other type are skipped.

    Args:
        token_list: Iterable of decoded items (``str`` / ``SpecialToken``).

    Returns:
        The concatenated text.
    """
    parts = []
    for item in token_list:
        # The str check comes first so that a str-like SpecialToken (if any)
        # is handled the same way as before.
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, SpecialToken):
            parts.append(f"<{item.name}>")
    # join is linear; repeated ``+=`` on str can be quadratic.
    return "".join(parts)


def generate(user_text, checkpoint_path='checkpoints/sft/sft_final.pt',
             max_new_tokens=512, temperature=0.8, top_k=40):
    """Load the SFT checkpoint into the module-level model and sample a reply.

    Args:
        user_text: The user's message, wrapped with ``build_prompt``.
        checkpoint_path: Path to a state dict saved with ``torch.save``.
        max_new_tokens: Maximum number of tokens to sample.
        temperature: Sampling temperature passed to ``model.generate``.
        top_k: Top-k cutoff passed to ``model.generate``.

    Returns:
        The decoded text of the first (only) sequence returned by
        ``model.generate``. NOTE(review): presumably this includes the
        prompt tokens if ``model.generate`` returns prompt + continuation;
        confirm.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Consider ``weights_only=True`` on torch versions that
    # support it.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    prompt_tokens = tokenizer.encode_all(build_prompt(user_text))
    # Create the tensor directly on the target device, with shape (1, T):
    # a batch holding one sequence.
    context = torch.tensor(prompt_tokens, dtype=torch.int64, device=device).view(1, -1)

    with torch.no_grad():
        result = model.generate(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            use_cache=True,
        )[0, :]

    decoded = tokenizer.decode(result.tolist())
    return decode_output(decoded)


def main():
    """Parse command-line arguments, run generation, and print the result."""
    parser = argparse.ArgumentParser(description="SFT模型推理")
    # NOTE(review): this default differs from generate()'s default
    # checkpoint_path ('checkpoints/sft/sft_final.pt'); confirm which is
    # intended.
    parser.add_argument("--checkpoint", type=str, default="model_sft.pt",
                        help="模型检查点路径")
    parser.add_argument("--prompt", type=str,
                        default="写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。",
                        help="用户输入prompt")
    parser.add_argument("--max_tokens", type=int, default=512, help="最大生成token数")
    parser.add_argument("--temperature", type=float, default=0.8, help="采样温度")
    parser.add_argument("--top_k", type=int, default=40, help="top_k采样")
    args = parser.parse_args()

    output = generate(
        user_text=args.prompt,
        checkpoint_path=args.checkpoint,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )

    print("=" * 60)
    print(f"Prompt: {args.prompt}")
    print("=" * 60)
    print(output)
    print("=" * 60)


if __name__ == "__main__":
    main()