# torchrun --nproc_per_node 8 train_distillation_pt.py --limit=-1 --batch_size=32 --n_layers=4 --use_wandb --ddp --local_rank=0 --epochs=100
import os
from collections import defaultdict
import numpy as np
# env | grep -E "PATH|LIBRARY_PATH|LD_LIBRARY_PATH"
username = os.environ['HOME']
import argparse
import time
import math
import warnings
# import pandas as pd
# np.float = float
# np.int = int #module 'numpy' has no attribute 'int'
# np.object = object #module 'numpy' has no attribute 'object'
# np.bool = bool #module 'numpy' has no attribute 'bool'
# Get the directory containing this script
script_dir = os.path.dirname(os.path.abspath(__file__))
# Change the working directory to the script's directory
os.chdir(script_dir)
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split
from contextlib import nullcontext
from model.model_ribo import MiniMindLM
from model.LMConfig import LMConfig
from model.tools import EarlyStopping, get_pretraining_args,get_dataset_args
from utils.ernie_rna.dictionary import Dictionary
from utils.ernie_rna.dataset import RNADataset
from src.utils import load_pretrained_ernierna
from src.mRNA2vec.model import mRNA2vec, T5_encoder
warnings.filterwarnings('ignore')
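# Rank-0-only print helper: under DDP, only the process with rank 0 writes to stdout.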
def Logger(*content):
if not ddp or dist.get_rank() == 0:
print(*content)
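# Cosine learning-rate schedule: starts near 1.1 * lr at step 0 and decays to a floor of lr / 10 over total_steps.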
def get_lr(current_step, total_steps, lr):
return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))
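# Soft-label distillation (Hinton-style): F.kl_div(student_log_probs, teacher_probs) gives
# KL(teacher || student) on temperature-softened distributions; the temperature**2 factor keeps
# its gradient scale comparable to the hard-label CE term when the two are mixed.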
def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
with torch.no_grad():
teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
kl = F.kl_div(
student_log_probs,
teacher_probs,
reduction=reduction
)
return (temperature ** 2) * kl
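# One training pass: loss = alpha * CE(student, target) + (1 - alpha) * distillation KL against the
# frozen teacher, with AMP autocast, gradient accumulation every args.accumulation_steps steps,
# gradient clipping, and periodic logging / checkpointing on rank 0.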
def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
start_time = time.time()
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
model.train()
for step, (src_data,tgt_data,twod_data,loss_mask) in enumerate(train_loader):
# print(f'train step:{step}/{len(train_loader)}')
src_data, tgt_data, twod_data, loss_mask = src_data.to(args.device),tgt_data.to(args.device),twod_data.to(args.device),loss_mask.to(args.device)
lr = get_lr(epoch * iter_per_epoch + step,
args.epochs * iter_per_epoch,
args.learning_rate)
for param_group in optimizer.param_groups:
param_group['lr'] = lr
# Forward pass (student model)
with ctx:
res = model(src_data,twod_data)
student_logits = res.logits[...,:-1] # drop the <mask> token logit
# Teacher forward pass (eval mode, no_grad only)
if teacher_model is not None:
with torch.no_grad():
# teacher_logits = teacher_model(src_data,twod_data,is_twod=args.is_twod).logits
if "ernierna" in args.mlm_pretrained_model_path:
teacher_logits,attn_map_lst,out_dict = teacher_model(src_data,twod_data,is_twod=args.is_twod)
# teacher_logits = out_dict['sentence_logits'] # [11,1205, 25]
# teacher_logits = teacher_logits#[:,:,4:8] # [11,1205, 25]
vocab_size_student = student_logits.size(-1) # N 10
teacher_logits = teacher_logits[..., :vocab_size_student] # [11*1205, 25]
elif "mrna2vec" in args.mlm_pretrained_model_path:
teacher_logits = teacher_model(src_data)
vocab_size_student = student_logits.size(-1)
teacher_logits = teacher_logits[..., :vocab_size_student] # [11*1205, 25]
# ========== Compute losses ==========
# 1) Ground-truth CE loss (optional)
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy( # applies log-softmax internally
student_logits.view(-1, student_logits.size(-1)),
tgt_data.view(-1),
ignore_index=0,
reduction='none'
)
ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
if lm_config_student.use_moe:
ce_loss += res.aux_loss
# 2) Distillation loss (optional)
if teacher_model is not None:
# Distill only at positions where loss_mask == 1
distill_loss = distillation_loss_fn(
student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1], # [1394, 9] mask token
temperature=temperature
)
else:
distill_loss = torch.tensor(0.0, device=args.device)
# 3) Total loss = alpha * CE + (1 - alpha) * distillation
loss = alpha * ce_loss + (1 - alpha) * distill_loss
scaler.scale(loss).backward()
if (step + 1) % args.accumulation_steps == 0:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if (step % args.log_interval == 0 or args.debug) and (not ddp or dist.get_rank() == 0):
spend_time = time.time() - start_time
ans = {
"tr_loss": loss.item(),
"tr_ce_loss": ce_loss.item(),
"tr_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
"tr_lr": optimizer.param_groups[-1]['lr'],
"tr_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
}
Logger(
'Epoch:[{}/{}][step:{}/{}] tr_loss:{:.4f} tr_ce_loss:{:.4f} tr_distill_loss:{:.4f} tr_step_time:{} min:'.format(
epoch,
args.epochs - 1,
step,train_loader.__len__(),
ans['tr_loss'],
ans['tr_ce_loss'],
ans['tr_distill_loss'],
ans['tr_step_time']
)
)
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log(ans)
if (step+1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
early_stopping.save_model(model, ckp.replace('epoch','step'))
model.train()
# print('end of a train step')
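# Validation pass: same loss composition as train_epoch, computed under no_grad; returns the mean
# val_loss over the validation loader, which drives early stopping.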
def val_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
start_time = time.time()
if teacher_model is not None:
teacher_model.eval()
teacher_model.requires_grad_(False)
result = defaultdict(list)
# for step, (X, Y, loss_mask) in enumerate(train_loader):
for step, (src_data,tgt_data,twod_data,loss_mask) in enumerate(val_loader):
# print(f'val step:{step}/{len(val_loader)}')
src_data, tgt_data, twod_data, loss_mask = src_data.to(args.device),tgt_data.to(args.device),twod_data.to(args.device),loss_mask.to(args.device)
# Forward pass (student model)
with ctx:
with torch.no_grad():
res = model(src_data,twod_data)
student_logits = res.logits[...,:-1] # drop mask
# Teacher forward pass (eval mode, no_grad only)
if teacher_model is not None:
with torch.no_grad():
# teacher_logits = teacher_model(src_data,twod_data,is_twod=args.is_twod).logits
teacher_logits,attn_map_lst,out_dict = teacher_model(src_data,twod_data,is_twod=args.is_twod)
# teacher_logits = out_dict['sentence_logits'] # [11,1205, 25]
# teacher_logits = teacher_logits#[:,:,4:8] # [11,1205, 25]
vocab_size_student = student_logits.size(-1) # N 10
teacher_logits = teacher_logits[..., :vocab_size_student] # [11*1205, 25]
# ========== Compute losses ==========
# 1) Ground-truth CE loss (optional)
loss_mask_flat = loss_mask.view(-1)
ce_loss = F.cross_entropy( # applies log-softmax internally
student_logits.view(-1, student_logits.size(-1)),
tgt_data.view(-1),
ignore_index=0,
reduction='none'
)
ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
if lm_config_student.use_moe:
ce_loss += res.aux_loss
# 2) Distillation loss (optional)
if teacher_model is not None:
# Distill only at positions where loss_mask == 1
distill_loss = distillation_loss_fn(
student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1], # [1394, 9] mask token
temperature=temperature
)
else:
distill_loss = torch.tensor(0.0, device=args.device)
# 3) Total loss = alpha * CE + (1 - alpha) * distillation
loss = alpha * ce_loss + (1 - alpha) * distill_loss
spend_time = time.time() - start_time
rs = {
"val_loss": loss.item(),
"val_ce_loss": ce_loss.item(),
"val_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
"val_lr": optimizer.param_groups[-1]['lr'],
"val_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
}
for key, value in rs.items(): result[key].append(value)  # accumulate per-step metrics
ans = {key: float(np.mean(values)) for key, values in result.items()}  # mean over validation steps
# print('dist.get_rank()',dist.get_rank())
if not ddp or dist.get_rank() == 0:
Logger(
'Epoch:[{}/{}] val_loss:{:.4f} val_ce_loss:{:.4f} val_distill_loss:{:.4f} val_step_time:{} min:'.format(
epoch,
args.epochs - 1,
ans['val_loss'],
ans['val_ce_loss'],
ans['val_distill_loss'],
ans['val_step_time']
)
)
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log(ans)
# print('end of val epoch')
return ans['val_loss'] #val_loss
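# Build the student: a small MiniMindLM over the nucleotide dictionary (small_dict.txt) with an
# extra <mask> symbol appended to its vocabulary.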
def init_student_model():
# tokenizer = AutoTokenizer.from_pretrained('./model/minimind_tokenizer')
vocab_path = args.arg_overrides['data'] + '/small_dict.txt'
tokenizer = Dictionary.load(vocab_path)
tokenizer.mask_index = tokenizer.add_symbol('<mask>')
# ['<s>', '<pad>', '</s>', '<unk>', 'G', 'A', 'U', 'C', 'N', '<mask>']
if args.debug:args.n_layers = 1
lm_config = LMConfig(dim=256, n_layers=args.n_layers, max_seq_len=max_seq_len, vocab_size=len(tokenizer),padding_idx=tokenizer.pad_index) # n_layers 8, <s> <unk><unk><unk> </s>
model = MiniMindLM(lm_config)
# moe_path = '_moe' if lm_config.use_moe else ''
# ckp = f'./out/full_sft_{lm_config.dim}{moe_path}.pth'
# state_dict = torch.load(ckp, map_location=args.device)
# model.load_state_dict(state_dict, strict=False)
Logger(f'Student model (LLM) trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million, vocab_size={len(tokenizer)}')
model = model.to(args.device)
print(model)
return model, tokenizer,lm_config
# def load_pretrained_ernierna(mlm_pretrained_model_path,arg_overrides):
# rna_models, _, _ = checkpoint_utils.load_model_ensemble_and_task(mlm_pretrained_model_path.split(os.pathsep),arg_overrides=arg_overrides)
# model_pretrained = rna_models[0]
# return model_pretrained
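# Build the frozen teacher from args.mlm_pretrained_model_path: either the pretrained ERNIE-RNA
# encoder or an mRNA2vec T5 encoder restored from a checkpoint.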
def init_teacher_model(args,Logger=None):
# model = MiniMindLM(lm_config)
# moe_path = '_moe' if lm_config.use_moe else ''
# ckp = args.mlm_pretrained_model_path
vocab_path = args.arg_overrides['data'] + '/dict.txt'
tokenizer = Dictionary.load(vocab_path)
tokenizer.add_symbol('<mask>')
# tokenizer = None
if 'ernierna' in args.mlm_pretrained_model_path:
model_pre = load_pretrained_ernierna(args.mlm_pretrained_model_path, args.arg_overrides)
model = model_pre.encoder
if args.debug:
print('debug mode')
num_layers_to_keep = 1 # debug: keep only the first encoder layer (TODO: make configurable)
model.sentence_encoder.layers = model.sentence_encoder.layers[
:num_layers_to_keep]
# torch.save(model,args.save_dir+'/pretraining0215.pt')
# print('save small ERNIE-RNA model in',args.save_dir+'/pretraining0215.pt')
# state_dict = torch.load(ckp, map_location=args.device)
# model.load_state_dict(state_dict, strict=False)
elif "mrna2vec" in args.mlm_pretrained_model_path:
encoder = T5_encoder(
hidden_size=256,
num_attention_heads=4,
num_hidden_layers=4,
)
model = mRNA2vec(encoder=encoder)
ckpt = torch.load(args.mlm_pretrained_model_path, map_location=args.device)
model.encoder.load_state_dict(ckpt['encoder'], strict=True)
Logger(f'Teacher model (LLM) trainable parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million, vocab_size={len(tokenizer)}')
model = model.to(args.device)
print(model)
return model,tokenizer
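# DDP setup: torchrun provides RANK / LOCAL_RANK / WORLD_SIZE; each process joins the NCCL
# process group and binds to its local GPU.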
def init_distributed_mode():
print("init distributed mode,ddp=",ddp)
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
if __name__ == '__main__':
# Parse the pretraining and dataset args
os.chdir(os.path.dirname(os.path.abspath(__file__)))
pretraining_parser = get_pretraining_args()
dataset_parser = get_dataset_args()
# Merge the two argument parsers
parser = argparse.ArgumentParser(parents=[pretraining_parser, dataset_parser], add_help=False,conflict_handler='resolve')
args = parser.parse_args()
torch.manual_seed(1337)
device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
ddp_local_rank, DEVICE = 0, "cuda:0"
if ddp:
init_distributed_mode()
args.device = torch.device(DEVICE)
Logger(f'loading model')
# Define the student and teacher models
max_seq_len = args.region*4+5
args.save_dir = os.path.join(args.out_dir)
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * max_seq_len
# Initialize the teacher and student models
# teacher_model = init_teacher_model(lm_config_teacher)
teacher_model,tokenizer = init_teacher_model(args,Logger)
# teacher_model, tokenizer = init_student_model(lm_config_student)
model, tokenizer,lm_config_student = init_student_model()
if args.debug:args.limit=320
# train_ds = SFTDataset(args.data_path, tokenizer, max_length=max_seq_len)
dataset = RNADataset(args.ffasta, region=args.region,tokenizer=tokenizer,limit=args.limit)
train_size = int(0.99 * len(dataset)) # 99% for training
val_size = len(dataset) - train_size # 1% for validation
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_sampler = DistributedSampler(train_dataset) if ddp else None
Logger(f'loading {train_size} training samples from',args.ffasta)
Logger(f'loading {val_size} validation samples from',args.ffasta)
train_loader = DataLoader(
train_dataset,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers,
sampler=train_sampler
)
val_loader = DataLoader(
val_dataset,
batch_size=args.batch_size,
pin_memory=True,
drop_last=False,
shuffle=False,
num_workers=args.num_workers
)
args.wandb_run_name = f"{args.wandb_project}-EP-{args.epochs}-BS-{args.batch_size}-LR-{args.learning_rate}-{len(train_loader.dataset)/ 1e6:.0f}"
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name,mode="offline")
else:
wandb = None
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
Logger('local_rank',args.local_rank,'ddp',ddp,'ddp_local_rank',ddp_local_rank)
iter_per_epoch = len(train_loader)
epoch = 0
# Set up the EarlyStopping instance and checkpoint path
moe_path = '_moe' if lm_config_student.use_moe else ''
ckp = f'{args.save_dir}/full_dist_{lm_config_student.dim}{moe_path}_epoch.pth'
Logger(f'save model to', os.path.abspath(ckp))
early_stopping = EarlyStopping(patience=5, verbose=True, path=ckp)
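# Main loop: one training pass plus one validation pass per epoch; under DDP, rank 0 evaluates
# the early-stopping criterion and broadcasts the stop flag to all other ranks.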
for epoch in range(args.epochs):
if ddp: train_loader.sampler.set_epoch(epoch)
# print(f'start training epoch: {epoch}/{args.epochs}')
train_epoch(epoch, wandb,alpha=args.celoss_alpha)
current_loss = val_epoch(epoch, wandb, alpha=args.celoss_alpha)
if ddp:
# Distributed (DDP) branch
if ddp_local_rank == 0:
early_stopping(current_loss, model) # if monitoring SPR instead, pass -SPR so that lower is better
# Broadcast the should_stop flag to the other processes
to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.float32, device=args.device)
dist.broadcast(to_broadcast, 0)
else:
# Non-zero ranks wait for the broadcast from rank 0
to_broadcast = torch.tensor([False], dtype=torch.float32, device=args.device)
dist.broadcast(to_broadcast, 0)
early_stopping.early_stop = bool(to_broadcast.item())
else:
# Single-process (non-DDP) branch
early_stopping(current_loss, model) # if monitoring SPR instead, pass -SPR so that lower is better
if early_stopping.early_stop:break
print('the end')