LightNovelModel-Alpha / train_sft.py
hugfaceguy0001's picture
upload model and train/infer codes
e10f35b verified
Raw
History Blame Contribute Delete
10.6 kB
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, IterableDataset
from torch.amp import autocast, GradScaler
from model import TransformerConfig, TransformerLanguageModel
from tokenizer import load_tokenizer, SpecialToken
import json
import random
import os
from tqdm import tqdm
# ============== 1. 准备Tokenizer ==============
def prepare_tokenizer():
    """Create and save the SFT tokenizer with the chat delimiter tokens added.

    Loads ``tokenizer.model``, registers ``<|im_start|>`` (id 50304) and
    ``<|im_end|>`` (id 50305) as special tokens, ensures both ids have a
    placeholder entry in ``tok.vocab``, updates ``tok.vocab_size`` and saves
    the result through ``tok.save("tokenizer_sft")``.

    Returns:
        The updated tokenizer object.
    """
    tok = load_tokenizer("tokenizer.model")

    # Register the new special tokens and rebuild the id -> token mapping.
    new_tokens = {
        SpecialToken("<|im_start|>"): 50304,
        SpecialToken("<|im_end|>"): 50305,
    }
    tok.special_tokens.update(new_tokens)
    tok.special_tokens_inv = {v: k for k, v in tok.special_tokens.items()}

    # Make sure every new id has a placeholder entry in the vocab.
    for st, sid in new_tokens.items():
        if sid not in tok.vocab:
            tok.vocab[sid] = f"<{st.name}>".encode("utf-8", errors="replace")

    # vocab_size must be max id + 1. It is computed only after the
    # placeholders above exist, so that it accounts for the ids just added.
    tok.vocab_size = max(tok.vocab.keys()) + 1

    tok.save("tokenizer_sft")
    print(f"Saved tokenizer_sft.model with vocab_size={tok.vocab_size}")
    return tok
# ============== 2. 扩展模型词表 ==============
def expand_model_for_new_tokens(
    old_ckpt_path,
    new_vocab_size,
    config,
    *,
    old_vocab_size=50304,
    old_block_size=1024,
):
    """Load an old checkpoint into a model with a larger vocab / context.

    A model with the old dimensions is built and loaded from
    ``old_ckpt_path``; a second model is built with ``new_vocab_size`` and
    ``config.block_size``. Tensors with identical shapes are copied over.
    For the token embedding table, the LM head and the position embedding
    table, the old rows are copied into the leading rows of the new tensor
    and the remaining rows keep the new model's own initialization.

    Args:
        old_ckpt_path: Path to the state dict of the old model.
        new_vocab_size: Vocabulary size of the returned model.
        config: ``TransformerConfig`` supplying ``block_size``, ``n_embed``,
            ``n_heads``, ``n_layers``, ``dropout`` and ``bias``.
        old_vocab_size: Vocabulary size the checkpoint was trained with.
        old_block_size: Context length the checkpoint was trained with.

    Returns:
        The new ``TransformerLanguageModel`` with the old weights copied in.
    """
    old_config = TransformerConfig(
        vocab_size=old_vocab_size,
        block_size=old_block_size,
        n_embed=config.n_embed,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        dropout=config.dropout,
        bias=config.bias,
    )
    old_model = TransformerLanguageModel(old_config)
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    old_model.load_state_dict(torch.load(old_ckpt_path, map_location="cpu"))

    new_config = TransformerConfig(
        vocab_size=new_vocab_size,
        block_size=config.block_size,
        n_embed=config.n_embed,
        n_heads=config.n_heads,
        n_layers=config.n_layers,
        dropout=config.dropout,
        bias=config.bias,
    )
    new_model = TransformerLanguageModel(new_config)

    new_state = new_model.state_dict()
    old_state = old_model.state_dict()

    # Tensors whose first dimension grows with vocab size or context length.
    row_expandable = (
        "token_embedding_table",
        "lm_head",
        "position_embedding_table",
    )

    for key, new_tensor in new_state.items():
        if key not in old_state:
            print(f"Key {key} not in old model, initialized randomly.")
            continue

        old_tensor = old_state[key]
        if new_tensor.shape == old_tensor.shape:
            new_tensor.copy_(old_tensor)
            continue

        print(f"Expanding {key}: {old_tensor.shape} -> {new_tensor.shape}")
        if any(tag in key for tag in row_expandable):
            # Copy the old rows; extra rows keep their fresh initialization.
            new_tensor[: old_tensor.size(0)].copy_(old_tensor)
        elif "mask" in key:
            # The mask is a buffer the new model already built at its own size.
            pass
        else:
            print(f"Warning: unexpected shape mismatch for {key}")

    new_model.load_state_dict(new_state)
    return new_model
# ============== 3. SFT数据集 ==============
class SFTDataset(IterableDataset):
    """Endless stream of fixed-length SFT examples drawn at random.

    Every non-empty line of ``data_file`` is parsed as a JSON object with a
    ``"messages"`` list; each message provides ``"role"`` and ``"content"``.
    All conversations are tokenized up front and kept in memory.

    Iterating yields ``(x, y, mask)`` tensors of length ``block_size``:
    ``x`` is the input ids, ``y`` the next-token targets, and ``mask`` the
    per-target loss weight (1.0 = contributes to the loss).

    Args:
        data_file: Path to the JSONL file.
        tokenizer: Tokenizer exposing ``encode``, ``encode_all`` and
            ``special_tokens``.
        block_size: Length of every yielded sequence.
        mask_prob: Probability of restricting the loss to assistant tokens;
            otherwise the loss covers every non-padding target.
    """

    def __init__(self, data_file, tokenizer, block_size=2048, mask_prob=0.8):
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.mask_prob = mask_prob
        self.eos_id = tokenizer.special_tokens[SpecialToken("<|endoftext|>")]

        # Load and encode the whole file once.
        self.samples = []
        with open(data_file, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                item = json.loads(line)
                tokens, mask = self._encode_messages(item["messages"])
                # Drop samples without assistant content. The trailing EOS
                # always carries mask 1, so it is excluded from the check.
                if any(mask[:-1]):
                    self.samples.append((tokens, mask))
        print(f"Loaded {len(self.samples)} valid SFT samples.")

    def _encode_messages(self, messages):
        """Encode a conversation into token ids and a per-token loss mask.

        Each message becomes ``<|im_start|>{role}\\n{content}<|im_end|>\\n``.
        Every token of an assistant message gets mask 1, all other message
        tokens get 0. A final ``<|endoftext|>`` with mask 1 is appended.

        Returns:
            ``(token_ids, loss_mask)``, two lists of equal length.
        """
        token_ids = []
        loss_mask = []
        for msg in messages:
            role = msg["role"]
            content = msg["content"]
            prefix = self.tokenizer.encode_all([
                SpecialToken("<|im_start|>"),
                f"{role}\n",
            ])
            content_ids = self.tokenizer.encode(content)
            suffix = self.tokenizer.encode_all([
                SpecialToken("<|im_end|>"),
                "\n",
            ])
            msg_tokens = prefix + content_ids + suffix
            flag = 1 if role == "assistant" else 0
            token_ids.extend(msg_tokens)
            loss_mask.extend([flag] * len(msg_tokens))

        # Terminate the conversation with EOS, which is always a loss target.
        token_ids.append(self.eos_id)
        loss_mask.append(1)
        return token_ids, loss_mask

    def __iter__(self):
        """Yield ``(x, y, mask)`` training tensors forever.

        Raises:
            ValueError: If the dataset holds no samples.
        """
        if not self.samples:
            raise ValueError("SFTDataset has no samples to draw from.")

        # One extra token so that x and y can be shifted views of each other.
        max_len = self.block_size + 1
        while True:
            tokens, assistant_mask = random.choice(self.samples)
            tokens = tokens[:max_len]
            assistant_mask = assistant_mask[:max_len]

            x = tokens[:-1]
            y = tokens[1:]
            # The loss at position i scores the prediction of y[i], so the
            # mask has to be aligned with the targets, not with the inputs.
            mask = assistant_mask[1:]

            real_len = len(x)
            # With probability mask_prob keep the assistant-only mask;
            # otherwise train on every real (non-padding) target.
            if random.random() >= self.mask_prob:
                mask = [1] * real_len

            # Right-pad to block_size; padded positions never carry loss.
            pad_len = self.block_size - real_len
            if pad_len > 0:
                padding = [self.eos_id] * pad_len
                x = x + padding
                y = y + padding
                mask = mask + [0] * pad_len

            yield (
                torch.tensor(x, dtype=torch.int64),
                torch.tensor(y, dtype=torch.int64),
                torch.tensor(mask, dtype=torch.float32),
            )
# ============== 4. 文本生成测试 ==============
@torch.no_grad()
def gen_text(model, tokenizer, text, device="cuda:0", max_new_tokens=200):
    """Generate an assistant reply for a single user message.

    The prompt is built as a user turn followed by an opened assistant turn
    (``<|im_start|>user\\n{text}<|im_end|>\\n<|im_start|>assistant\\n``), then
    passed to ``model.generate``.

    Args:
        model: Model exposing ``generate(ids, max_new_tokens=...)``.
        tokenizer: Tokenizer exposing ``encode_all`` and ``decode``.
        text: The user message.
        device: Device the prompt tensor is moved to.
        max_new_tokens: Number of tokens requested from ``model.generate``.

    Returns:
        Whatever ``tokenizer.decode`` returns for the first row of the
        generated ids.
    """
    # Remember the current mode so it can be restored even if generation
    # fails; callers that swallow the exception must not be left in eval mode.
    was_training = model.training
    model.eval()
    try:
        prompt_ids = tokenizer.encode_all([
            SpecialToken("<|im_start|>"),
            "user\n",
            text,
            SpecialToken("<|im_end|>"),
            "\n",
            SpecialToken("<|im_start|>"),
            "assistant\n",
        ])
        ids = torch.tensor(prompt_ids, dtype=torch.int64).to(device).view(1, -1)
        output_ids = model.generate(ids, max_new_tokens=max_new_tokens)[0, :]
        return tokenizer.decode(output_ids.tolist())
    finally:
        if was_training:
            model.train()
# ============== 5. 训练 ==============
def train():
    """Run the full SFT pipeline with the hyperparameters defined below.

    Steps: prepare or load the SFT tokenizer, expand the pretrained
    checkpoint to the new vocab size, stream batches from the JSONL dataset,
    and train with fp16 autocast, gradient accumulation and gradient
    clipping. A sample generation is printed every ``eval_interval`` steps
    and a checkpoint is written every ``save_interval`` steps, at the last
    step, and once more as ``sft_final.pt``.
    """
    device = "cuda:0"
    block_size = 2048
    new_vocab_size = 50306
    batch_size = 4
    gradient_accumulation_steps = 4
    learning_rate = 1e-5
    max_iters = 2000
    save_interval = 200
    eval_interval = 50

    # Build the SFT tokenizer on the first run, reuse the saved one afterwards.
    if not os.path.exists("tokenizer_sft.model"):
        tokenizer = prepare_tokenizer()
    else:
        tokenizer = load_tokenizer("tokenizer_sft.model")
    print(f"Loaded tokenizer_sft.model with vocab_size={tokenizer.vocab_size}")

    # Model configuration.
    config = TransformerConfig(
        vocab_size=new_vocab_size,
        block_size=block_size,
        n_embed=768,
        n_heads=12,
        n_layers=12,
        dropout=0.0,
        bias=True,
    )

    # Expand the vocab and load the pretrained weights.
    print("Expanding model vocab and loading checkpoint 150000.pt...")
    model = expand_model_for_new_tokens("checkpoints/new/150000.pt", new_vocab_size, config)
    model = model.to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Model loaded. Total parameters: {total_params / 1e6:.2f}M")

    # Dataset: an endless iterator, so next(data_iter) never runs dry.
    dataset = SFTDataset(
        "data/novels_sft_dataset.jsonl",
        tokenizer,
        block_size=block_size,
        mask_prob=0.8,
    )
    loader = DataLoader(dataset, batch_size=batch_size)
    data_iter = iter(loader)

    # Optimizer.
    optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

    # Mixed precision: fp16 autocast with dynamic loss scaling.
    scaler = GradScaler("cuda")

    os.makedirs("checkpoints/sft", exist_ok=True)
    model.train()

    # The loop covers steps 0..max_iters inclusive, so the bar needs
    # max_iters + 1 ticks to finish exactly at 100%.
    pbar = tqdm(total=max_iters + 1, desc="SFT Training")
    for iter_num in range(max_iters + 1):
        optimizer.zero_grad(set_to_none=True)
        accum_loss = 0.0
        for _ in range(gradient_accumulation_steps):
            x, y, mask = next(data_iter)
            x = x.to(device)
            y = y.to(device)
            mask = mask.to(device)
            with autocast("cuda", dtype=torch.float16):
                logits, _ = model(x, device=device)
                logits = logits.view(-1, config.vocab_size)
                y_flat = y.view(-1)
                mask_flat = mask.view(-1)
                loss = F.cross_entropy(logits, y_flat, reduction="none")
                # Masked mean; the epsilon guards against an all-zero mask.
                loss = (loss * mask_flat).sum() / (mask_flat.sum() + 1e-8)
                # Scale so the accumulated gradient is an average.
                loss = loss / gradient_accumulation_steps
            scaler.scale(loss).backward()
            accum_loss += loss.item()

        # Gradient clipping needs unscaled gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()

        pbar.update(1)
        pbar.set_postfix(loss=f"{accum_loss:.4f}")

        if iter_num % eval_interval == 0:
            print(f"\n[Step {iter_num}] Loss: {accum_loss:.4f}")
            # Sampling is best-effort: a failure is reported, not fatal.
            try:
                decoded = gen_text(model, tokenizer, "写一个恋爱喜剧轻小说,主角是能听到物品心声的高中生。", device=device)
                # Keep only the string pieces of the decoded output.
                text_out = "".join(tok for tok in decoded if isinstance(tok, str))
                print(f"Sample output: {text_out[:200]}...")
            except Exception as e:
                print(f"Generation error: {e}")

        if iter_num > 0 and (iter_num % save_interval == 0 or iter_num == max_iters):
            ckpt_path = f"checkpoints/sft/sft_{iter_num}.pt"
            torch.save(model.state_dict(), ckpt_path)
            print(f"\nSaved checkpoint: {ckpt_path}")

    pbar.close()
    final_path = "checkpoints/sft/sft_final.pt"
    torch.save(model.state_dict(), final_path)
    print(f"Training complete. Final model saved to {final_path}")
# Script entry point: start SFT training only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    train()