# torchrun --nproc_per_node 8 train_distillation_pt.py --limit=-1 --batch_size=32 --n_layers=4 --use_wandb --ddp --local_rank=0 --epochs=100
import os
import argparse
import time
import math
import warnings
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
# import pandas as pd
# numpy >= 1.24 removed these aliases ("module 'numpy' has no attribute ..."):
# np.float = float
# np.int = int
# np.object = object
# np.bool = bool

# Load modules
# env | grep -E "PATH|LIBRARY_PATH|LD_LIBRARY_PATH"
username = os.environ['HOME']

# Change the working directory to the directory containing this script
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split

from model.model_ribo import MiniMindLM
from model.LMConfig import LMConfig
from model.tools import EarlyStopping, get_pretraining_args, get_dataset_args
from utils.ernie_rna.dictionary import Dictionary
from utils.ernie_rna.dataset import RNADataset
from src.utils import load_pretrained_ernierna
from src.mRNA2vec.model import mRNA2vec, T5_encoder

warnings.filterwarnings('ignore')


def Logger(*content):
    # Print only on the main process when running under DDP
    if not ddp or dist.get_rank() == 0:
        print(*content)


def get_lr(current_step, total_steps, lr):
    # Cosine schedule: decays from 1.1 * lr at step 0 to a floor of 0.1 * lr at total_steps
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    # KL divergence between the temperature-softened teacher and student distributions,
    # scaled by T^2 so gradient magnitudes stay comparable across temperatures
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    return (temperature ** 2) * kl
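
# Illustrative usage of distillation_loss_fn (a sketch, never called by the
# training loop; the shapes, temperature and alpha below are assumed values,
# not taken from the argument parser). It mirrors the combined objective used
# in train_epoch / val_epoch: loss = alpha * ce_loss + (1 - alpha) * distill_loss.
def _distillation_loss_example():
    student_logits = torch.randn(8, 10)   # 8 tokens, 10-symbol vocabulary (hypothetical)
    teacher_logits = torch.randn(8, 10)
    targets = torch.randint(0, 10, (8,))
    kd = distillation_loss_fn(student_logits, teacher_logits, temperature=2.0)
    ce = F.cross_entropy(student_logits, targets)
    alpha = 0.5                            # hypothetical CE weight
    return alpha * ce + (1 - alpha) * kd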

def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
    start_time = time.time()
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)
    model.train()
    for step, (src_data, tgt_data, twod_data, loss_mask) in enumerate(train_loader):
        src_data, tgt_data, twod_data, loss_mask = (
            src_data.to(args.device),
            tgt_data.to(args.device),
            twod_data.to(args.device),
            loss_mask.to(args.device),
        )

        lr = get_lr(epoch * iter_per_epoch + step, args.epochs * iter_per_epoch, args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Forward pass (student model)
        with ctx:
            res = model(src_data, twod_data)
            student_logits = res.logits[..., :-1]  # drop the mask-token logit

        # Teacher forward pass (eval mode, no gradients)
        if teacher_model is not None:
            with torch.no_grad():
                if "ernierna" in args.mlm_pretrained_model_path:
                    teacher_logits, attn_map_lst, out_dict = teacher_model(src_data, twod_data, is_twod=args.is_twod)
                elif "mrna2vec" in args.mlm_pretrained_model_path:
                    teacher_logits = teacher_model(src_data)
                # Keep only the vocabulary entries shared with the student
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== Compute losses ==========
        # 1) Ground-truth CE loss (optional)
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(  # applies softmax internally
            student_logits.view(-1, student_logits.size(-1)),
            tgt_data.view(-1),
            ignore_index=0,
            reduction='none'
        )
        ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
        if lm_config_student.use_moe:
            ce_loss += res.aux_loss

        # 2) Distillation loss (optional)
        if teacher_model is not None:
            # Distill only at valid (masked) token positions
            distill_loss = distillation_loss_fn(
                student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) Total loss = alpha * CE + (1 - alpha) * distill
        loss = alpha * ce_loss + (1 - alpha) * distill_loss

        scaler.scale(loss).backward()

        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if (step % args.log_interval == 0 or args.debug) and (not ddp or dist.get_rank() == 0):
            spend_time = time.time() - start_time
            ans = {
                "tr_loss": loss.item(),
                "tr_ce_loss": ce_loss.item(),
                "tr_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
                "tr_lr": optimizer.param_groups[-1]['lr'],
                # rough estimate of the minutes remaining in this epoch
                "tr_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
            }
            Logger(
                'Epoch:[{}/{}][step:{}/{}] tr_loss:{:.4f} tr_ce_loss:{:.4f} tr_distill_loss:{:.4f} tr_step_time:{} min'.format(
                    epoch, args.epochs - 1, step, len(train_loader),
                    ans['tr_loss'], ans['tr_ce_loss'], ans['tr_distill_loss'], ans['tr_step_time']
                )
            )
            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log(ans)

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            early_stopping.save_model(model, ckp.replace('epoch', 'step'))
            model.train()
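
# Gradient-accumulation arithmetic (a sketch; the 8-GPU world size comes from the
# torchrun command at the top of this file, and the batch_size / accumulation_steps
# values here are assumptions, not the parser defaults). Because optimizer.step()
# runs only every args.accumulation_steps iterations, the effective batch size is
# batch_size * accumulation_steps * world_size.
def _effective_batch_size(batch_size=32, accumulation_steps=1, world_size=8):
    return batch_size * accumulation_steps * world_size  # e.g. 32 * 1 * 8 = 256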

def val_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
    start_time = time.time()
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)
    model.eval()  # put the student in eval mode for validation
    result = defaultdict(list)
    for step, (src_data, tgt_data, twod_data, loss_mask) in enumerate(val_loader):
        src_data, tgt_data, twod_data, loss_mask = (
            src_data.to(args.device),
            tgt_data.to(args.device),
            twod_data.to(args.device),
            loss_mask.to(args.device),
        )

        # Forward pass (student model)
        with ctx:
            with torch.no_grad():
                res = model(src_data, twod_data)
                student_logits = res.logits[..., :-1]  # drop the mask-token logit

        # Teacher forward pass (eval mode, no gradients)
        if teacher_model is not None:
            with torch.no_grad():
                if "ernierna" in args.mlm_pretrained_model_path:
                    teacher_logits, attn_map_lst, out_dict = teacher_model(src_data, twod_data, is_twod=args.is_twod)
                elif "mrna2vec" in args.mlm_pretrained_model_path:
                    teacher_logits = teacher_model(src_data)
                # Keep only the vocabulary entries shared with the student
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

        # ========== Compute losses ==========
        # 1) Ground-truth CE loss (optional)
        loss_mask_flat = loss_mask.view(-1)
        ce_loss = F.cross_entropy(  # applies softmax internally
            student_logits.view(-1, student_logits.size(-1)),
            tgt_data.view(-1),
            ignore_index=0,
            reduction='none'
        )
        ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
        if lm_config_student.use_moe:
            ce_loss += res.aux_loss

        # 2) Distillation loss (optional)
        if teacher_model is not None:
            # Distill only at valid (masked) token positions
            distill_loss = distillation_loss_fn(
                student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
                teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                temperature=temperature
            )
        else:
            distill_loss = torch.tensor(0.0, device=args.device)

        # 3) Total loss = alpha * CE + (1 - alpha) * distill
        loss = alpha * ce_loss + (1 - alpha) * distill_loss

        spend_time = time.time() - start_time
        rs = {
            "val_loss": loss.item(),
            "val_ce_loss": ce_loss.item(),
            "val_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
            "val_lr": optimizer.param_groups[-1]['lr'],
            # rough remaining-time estimate in minutes
            "val_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
        }
        for key, value in rs.items():
            result[key].append(value)

    ans = {key: np.array(values).mean() for key, values in result.items()}
    if not ddp or dist.get_rank() == 0:
        Logger(
            'Epoch:[{}/{}] val_loss:{:.4f} val_ce_loss:{:.4f} val_distill_loss:{:.4f} val_step_time:{} min'.format(
                epoch, args.epochs - 1,
                ans['val_loss'], ans['val_ce_loss'], ans['val_distill_loss'], ans['val_step_time']
            )
        )
    if (wandb is not None) and (not ddp or dist.get_rank() == 0):
        wandb.log(ans)
    return ans['val_loss']


def init_student_model():
    vocab_path = args.arg_overrides['data'] + '/small_dict.txt'
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('<mask>')
    # vocabulary: special tokens followed by 'G', 'A', 'U', 'C', 'N' and '<mask>'
    if args.debug:
        args.n_layers = 1
    lm_config = LMConfig(
        dim=256,
        n_layers=args.n_layers,
        max_seq_len=max_seq_len,
        vocab_size=len(tokenizer),
        padding_idx=tokenizer.pad_index,
    )
    model = MiniMindLM(lm_config)
    Logger(f'Student model (LLM) trainable parameters: '
           f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M, '
           f'vocab_size={len(tokenizer)}')
    model = model.to(args.device)
    print(model)
    return model, tokenizer, lm_config


def init_teacher_model(args, Logger=None):
    vocab_path = args.arg_overrides['data'] + '/dict.txt'
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.add_symbol('<mask>')
    if 'ernierna' in args.mlm_pretrained_model_path:
        model_pre = load_pretrained_ernierna(args.mlm_pretrained_model_path, args.arg_overrides)
        model = model_pre.encoder
        if args.debug:
            print('debug mode')
            num_layers_to_keep = 1  # keep only the first transformer layer in debug mode
            model.sentence_encoder.layers = model.sentence_encoder.layers[:num_layers_to_keep]
    elif "mrna2vec" in args.mlm_pretrained_model_path:
        encoder = T5_encoder(
            hidden_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
        )
        model = mRNA2vec(encoder=encoder)
        ckpt = torch.load(args.mlm_pretrained_model_path, map_location=args.device)
        model.encoder.load_state_dict(ckpt['encoder'], strict=True)
    Logger(f'Teacher model (LLM) trainable parameters: '
           f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} M, '
           f'vocab_size={len(tokenizer)}')
    model = model.to(args.device)
    print(model)
    return model, tokenizer
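
# Note on vocabulary alignment (making an implicit assumption explicit): the
# training loop truncates teacher_logits[..., :vocab_size_student], which only
# makes sense if the first len(student tokenizer) symbols of the teacher
# dictionary appear in the same order as the student dictionary. The helper
# below is an optional sanity-check sketch; it assumes Dictionary supports
# integer indexing (index -> symbol) like fairseq's Dictionary, which is an
# assumption about utils.ernie_rna.dictionary.Dictionary.
def _check_vocab_alignment(student_tokenizer, teacher_tokenizer):
    n = len(student_tokenizer)
    return all(student_tokenizer[i] == teacher_tokenizer[i] for i in range(n))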

def init_distributed_mode():
    print("init distributed mode, ddp =", ddp)
    if not ddp:
        return
    global ddp_local_rank, DEVICE

    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)


if __name__ == '__main__':
    os.chdir(os.path.dirname(os.path.abspath(__file__)))

    # Build and merge the pretraining and dataset argument parsers
    pretraining_parser = get_pretraining_args()
    dataset_parser = get_dataset_args()
    parser = argparse.ArgumentParser(parents=[pretraining_parser, dataset_parser],
                                     add_help=False, conflict_handler='resolve')
    args = parser.parse_args()

    torch.manual_seed(1337)
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    ddp = int(os.environ.get("RANK", -1)) != -1  # is this a DDP run?
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)

    Logger('loading model')
    # Student / teacher model configuration
    max_seq_len = args.region * 4 + 5
    args.save_dir = os.path.join(args.out_dir)
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * max_seq_len

    # Initialize the teacher and student models (the student tokenizer is the one
    # the dataset below uses)
    teacher_model, tokenizer = init_teacher_model(args, Logger)
    model, tokenizer, lm_config_student = init_student_model()

    if args.debug:
        args.limit = 320
    dataset = RNADataset(args.ffasta, region=args.region, tokenizer=tokenizer, limit=args.limit)
    train_size = int(0.99 * len(dataset))   # 99% for training
    val_size = len(dataset) - train_size    # 1% for validation
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
    train_sampler = DistributedSampler(train_dataset) if ddp else None
    Logger(f'loading {train_size} training samples from', args.ffasta)
    Logger(f'loading {val_size} validation samples from', args.ffasta)

    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers
    )

    args.wandb_run_name = (f"{args.wandb_project}-EP-{args.epochs}-BS-{args.batch_size}"
                           f"-LR-{args.learning_rate}-{len(train_loader.dataset) / 1e6:.0f}")
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, mode="offline")
    else:
        wandb = None

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    if ddp:
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
    Logger('local_rank', args.local_rank, 'ddp', ddp, 'ddp_local_rank', ddp_local_rank)

    iter_per_epoch = len(train_loader)
    epoch = 0

    # Create the EarlyStopping instance
    moe_path = '_moe' if lm_config_student.use_moe else ''
    ckp = f'{args.save_dir}/full_dist_{lm_config_student.dim}{moe_path}_epoch.pth'
    Logger('save model to', os.path.abspath(ckp))
    early_stopping = EarlyStopping(patience=5, verbose=True, path=ckp)
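
    # Training loop overview (commentary): each epoch runs train_epoch then
    # val_epoch, and early stopping is evaluated on the validation loss. Under
    # DDP only rank 0 updates EarlyStopping, then broadcasts its early_stop
    # flag so every process leaves the loop on the same epoch.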
    for epoch in range(args.epochs):
        if ddp:
            train_loader.sampler.set_epoch(epoch)
        train_epoch(epoch, wandb, alpha=args.celoss_alpha)
        current_loss = val_epoch(epoch, wandb, alpha=args.celoss_alpha)

        if ddp:
            # Distributed early stopping: rank 0 decides, then broadcasts the flag
            if ddp_local_rank == 0:
                early_stopping(current_loss, model)  # if monitoring SPR instead of loss, pass -SPR
                to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.float32, device=args.device)
                dist.broadcast(to_broadcast, 0)
            else:
                # Non-main processes receive the flag from rank 0
                to_broadcast = torch.tensor([False], dtype=torch.float32, device=args.device)
                dist.broadcast(to_broadcast, 0)
            early_stopping.early_stop = bool(to_broadcast.item())
        else:
            # Single-process training
            early_stopping(current_loss, model)  # if monitoring SPR instead of loss, pass -SPR

        if early_stopping.early_stop:
            break

    print('the end')
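
# Single-process launch example (a sketch; the flag names mirror the torchrun
# command at the top of this file, and their defaults and choices are defined in
# model.tools.get_pretraining_args / get_dataset_args rather than here):
#
#   python train_distillation_pt.py --limit=-1 --batch_size=32 --n_layers=4 --epochs=100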