| from model import TransformerConfig, TransformerLanguageModel |
| import torch |
| from tokenizer import load_tokenizer, SpecialToken |
| import argparse |
|
|
| |
# Hyperparameters used to construct the model; a checkpoint is loaded into
# this architecture later via load_state_dict.
config = TransformerConfig(
    vocab_size=50306,
    block_size=2048,
    n_embed=768,
    n_heads=12,
    n_layers=12,
    dropout=0.0,
    bias=True,
)

# Tokenizer is read from a file path relative to the working directory.
tokenizer = load_tokenizer("tokenizer.model")

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

# Build the model (weights are not loaded here) and move it to the device.
model = TransformerLanguageModel(config).to(device)
|
|
|
|
def build_prompt(user_text):
    """Build the standard SFT-format prompt for a single user message.

    Args:
        user_text: The user's message; it is interpolated after ``user\\n``.

    Returns:
        A list mixing ``SpecialToken`` markers and plain strings: the user
        turn wrapped in ``<|im_start|>`` / ``<|im_end|>``, a newline, and
        then an opened assistant turn left for the model to complete.
    """
    user_turn = [
        SpecialToken("<|im_start|>"),
        f"user\n{user_text}",
        SpecialToken("<|im_end|>"),
        "\n",
    ]
    # The assistant turn is opened but deliberately not closed.
    assistant_header = [
        SpecialToken("<|im_start|>"),
        "assistant\n",
    ]
    return user_turn + assistant_header
|
|
|
|
def decode_output(token_list):
    """Join decoded tokenizer items into a single readable string.

    Args:
        token_list: Iterable of decoded items. ``str`` items are used
            verbatim; ``SpecialToken`` items are rendered as ``<name>``
            using their ``name`` attribute. Items of any other type are
            skipped.

    Returns:
        The concatenated text; an empty string when there are no items.
    """
    # Collect pieces and join once instead of growing a string with ``+=``,
    # which can be quadratic in the number of items.
    pieces = []
    for item in token_list:
        if isinstance(item, str):
            pieces.append(item)
        elif isinstance(item, SpecialToken):
            pieces.append(f"<{item.name}>")
    return "".join(pieces)
|
|
|
|
def generate(user_text, checkpoint_path='checkpoints/sft/sft_final.pt',
             max_new_tokens=512, temperature=0.8, top_k=40):
    """Load a checkpoint into the module-level model and run SFT inference.

    Args:
        user_text: The user's message, wrapped by ``build_prompt``.
        checkpoint_path: Path of the state-dict file passed to ``torch.load``.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.

    Returns:
        The first row of ``model.generate``'s result, decoded by the
        tokenizer and rendered to a string with ``decode_output``.
        NOTE(review): presumably this row includes the prompt tokens as
        well as the continuation -- confirm against ``model.generate``.
    """
    # The checkpoint is re-read and applied on every call.
    # NOTE(review): torch.load unpickles the file; only point this at
    # checkpoints from a trusted source.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()

    # Encode the prompt and shape it as a single-row batch on the device.
    token_ids = tokenizer.encode_all(build_prompt(user_text))
    context = torch.tensor(
        token_ids, dtype=torch.long, device=device
    ).reshape(1, -1)

    sampling = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_k": top_k,
        "use_cache": True,
    }
    with torch.no_grad():
        sequences = model.generate(context, **sampling)

    # Only the first (and only) batch row is decoded.
    first_sequence = sequences[0, :]
    return decode_output(tokenizer.decode(first_sequence.tolist()))
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse options, run one generation, print it.
    parser = argparse.ArgumentParser(description="SFT模型推理")

    # (flag, type, default, help) for every supported option.
    # NOTE(review): the --checkpoint default here differs from the default
    # of generate()'s checkpoint_path -- confirm which path is intended.
    cli_options = (
        ("--checkpoint", str, "model_sft.pt", "模型检查点路径"),
        ("--prompt", str, "写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。",
         "用户输入prompt"),
        ("--max_tokens", int, 512, "最大生成token数"),
        ("--temperature", float, 0.8, "采样温度"),
        ("--top_k", int, 40, "top_k采样"),
    )
    for flag, kind, default, help_text in cli_options:
        parser.add_argument(flag, type=kind, default=default, help=help_text)
    args = parser.parse_args()

    output = generate(
        args.prompt,
        checkpoint_path=args.checkpoint,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )

    # Frame the prompt and the generated text with separator rules.
    rule = "=" * 60
    for line in (rule, f"Prompt: {args.prompt}", rule, output, rule):
        print(line)
|
|