# LightNovelModel-Alpha / generate.py
# Uploaded by hugfaceguy0001 in commit e10f35b (verified):
# "upload model and train/infer codes" (file size: 2.97 kB).
from model import TransformerConfig, TransformerLanguageModel
import torch
from tokenizer import load_tokenizer, SpecialToken
import argparse
# Model hyperparameters for the SFT variant: vocabulary extended to 50306
# tokens and a context window (block_size) of 2048 tokens. Dropout is 0.0.
config = TransformerConfig(
    vocab_size=50306,
    block_size=2048,
    n_embed=768,
    n_heads=12,
    n_layers=12,
    dropout=0.0,
    bias=True,
)
# Tokenizer shared by prompt encoding and output decoding.
tokenizer = load_tokenizer("tokenizer.model")
# Place the model on the first CUDA device when one is available, else on CPU.
if torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'
model = TransformerLanguageModel(config).to(device)
def build_prompt(user_text):
    """Assemble the SFT chat prompt for a single user message.

    The returned list mixes ``SpecialToken`` markers with plain strings. It is
    consumed by ``tokenizer.encode_all``. Rendered as text, the layout is::

        <|im_start|>user\\n{user_text}<|im_end|>\\n<|im_start|>assistant\\n

    The assistant turn is left open (no closing ``<|im_end|>``) so the model
    continues from there.
    """
    user_turn = [
        SpecialToken("<|im_start|>"),
        f"user\n{user_text}",
        SpecialToken("<|im_end|>"),
    ]
    assistant_header = [
        SpecialToken("<|im_start|>"),
        "assistant\n",
    ]
    # A bare newline separates the closed user turn from the assistant header.
    return user_turn + ["\n"] + assistant_header
def decode_output(token_list):
    """Render a decoded token sequence as one readable string.

    Args:
        token_list: Iterable of decoded pieces, e.g. the result of
            ``tokenizer.decode``. ``str`` pieces are copied verbatim.
            ``SpecialToken`` pieces are rendered as ``<name>`` using their
            ``name`` attribute.

    Returns:
        The concatenated text. Items of any other type are skipped.
    """
    # Collect the pieces and join once. Repeated ``str +=`` is quadratic in the
    # general case (CPython's in-place shortcut is an implementation detail).
    parts = []
    for item in token_list:
        if isinstance(item, str):
            parts.append(item)
        elif isinstance(item, SpecialToken):
            parts.append(f"<{item.name}>")
    return "".join(parts)
def generate(user_text, checkpoint_path='checkpoints/sft/sft_final.pt',
             max_new_tokens=512, temperature=0.8, top_k=40):
    """Load the SFT checkpoint into the module-level model and sample a reply.

    Args:
        user_text: Text placed in the user turn of the prompt.
        checkpoint_path: File loaded with ``torch.load`` and applied via
            ``model.load_state_dict``. It is re-read on every call.
        max_new_tokens: Forwarded unchanged to ``model.generate``.
        temperature: Forwarded unchanged to ``model.generate``.
        top_k: Forwarded unchanged to ``model.generate``.

    Returns:
        The first sequence produced by ``model.generate``, decoded by the
        tokenizer and rendered with ``decode_output``. Whether that sequence
        includes the prompt tokens depends on ``model.generate``. TODO confirm.
    """
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code. Only load trusted checkpoints, and consider weights_only=True if
    # the installed torch version supports it.
    state = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state)
    model.eval()

    # Encode the chat-formatted prompt as a (1, prompt_len) int64 batch.
    prompt_ids = tokenizer.encode_all(build_prompt(user_text))
    context = torch.tensor(prompt_ids, dtype=torch.int64, device=device).view(1, -1)

    with torch.no_grad():
        sequences = model.generate(
            context,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
            top_k=top_k,
            use_cache=True,
        )
        first_row = sequences[0, :]

    pieces = tokenizer.decode(first_row.tolist())
    return decode_output(pieces)
if __name__ == "__main__":
    # Command-line entry point: parse options, run one generation, and print
    # the prompt and output between separator rules.
    parser = argparse.ArgumentParser(description="SFT模型推理")
    # NOTE(review): the --checkpoint default ("model_sft.pt") differs from
    # generate()'s own default ('checkpoints/sft/sft_final.pt'); confirm which
    # file is intended.
    cli_options = [
        ("--checkpoint", dict(type=str, default="model_sft.pt",
                              help="模型检查点路径")),
        ("--prompt", dict(type=str, default="写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。",
                          help="用户输入prompt")),
        ("--max_tokens", dict(type=int, default=512,
                              help="最大生成token数")),
        ("--temperature", dict(type=float, default=0.8,
                               help="采样温度")),
        ("--top_k", dict(type=int, default=40,
                         help="top_k采样")),
    ]
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    output = generate(
        user_text=args.prompt,
        checkpoint_path=args.checkpoint,
        max_new_tokens=args.max_tokens,
        temperature=args.temperature,
        top_k=args.top_k,
    )

    rule = "=" * 60
    for line in (rule, f"Prompt: {args.prompt}", rule, output, rule):
        print(line)