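"""Knowledge-distillation pretraining for an RNA language model.

Trains a compact MiniMindLM student against a frozen pretrained teacher
(ERNIE-RNA or mRNA2vec, selected via args.mlm_pretrained_model_path),
combining a masked-token cross-entropy loss with a temperature-scaled KL
distillation loss. Supports single-GPU and DDP execution, mixed precision,
gradient accumulation, offline wandb logging, and early stopping.
"""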
import argparse
import math
import os
import time
import warnings
from collections import defaultdict
from contextlib import nullcontext

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split

from model.model_ribo import MiniMindLM
from model.LMConfig import LMConfig
from model.tools import EarlyStopping, get_pretraining_args, get_dataset_args
from utils.ernie_rna.dictionary import Dictionary
from utils.ernie_rna.dataset import RNADataset
from src.utils import load_pretrained_ernierna
from src.mRNA2vec.model import mRNA2vec, T5_encoder
home_dir = os.environ['HOME']  # home-directory path (was misnamed "username"); currently unused

warnings.filterwarnings('ignore')

# Run from the script's own directory so relative paths resolve consistently.
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)


def Logger(*content):
    """Print only on the main process (rank 0 under DDP)."""
    if not ddp or dist.get_rank() == 0:
        print(*content)


def get_lr(current_step, total_steps, lr):
    """Cosine schedule: decays from 1.1*lr toward a floor of lr/10."""
    return lr / 10 + 0.5 * lr * (1 + math.cos(math.pi * current_step / total_steps))


def distillation_loss_fn(student_logits, teacher_logits, temperature=1.0, reduction='batchmean'):
    """KL divergence between temperature-softened teacher and student distributions."""
    with torch.no_grad():
        teacher_probs = F.softmax(teacher_logits / temperature, dim=-1).detach()

    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)

    kl = F.kl_div(
        student_log_probs,
        teacher_probs,
        reduction=reduction
    )
    return (temperature ** 2) * kl
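# The temperature**2 factor follows Hinton et al. (2015): gradients from
# temperature-softened targets scale as 1/T**2, so multiplying the KL term by
# T**2 keeps it on a comparable footing with the hard cross-entropy loss when
# the two are mixed as alpha * ce_loss + (1 - alpha) * distill_loss below.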

def train_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
    start_time = time.time()
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)
    model.train()
    for step, (src_data, tgt_data, twod_data, loss_mask) in enumerate(train_loader):

        src_data, tgt_data, twod_data, loss_mask = (
            src_data.to(args.device),
            tgt_data.to(args.device),
            twod_data.to(args.device),
            loss_mask.to(args.device),
        )

        # Per-step cosine learning-rate schedule.
        lr = get_lr(epoch * iter_per_epoch + step,
                    args.epochs * iter_per_epoch,
                    args.learning_rate)
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        with ctx:
            res = model(src_data, twod_data)
            # Drop the last vocab column (<mask> is added last to the student
            # dictionary, so it is never predicted).
            student_logits = res.logits[..., :-1]

            if teacher_model is not None:
                with torch.no_grad():
                    if "ernierna" in args.mlm_pretrained_model_path:
                        teacher_logits, attn_map_lst, out_dict = teacher_model(
                            src_data, twod_data, is_twod=args.is_twod)
                    elif "mrna2vec" in args.mlm_pretrained_model_path:
                        teacher_logits = teacher_model(src_data)
                    # Teacher and student vocabularies differ in size; truncate
                    # the teacher's logits to the student's vocabulary.
                    vocab_size_student = student_logits.size(-1)
                    teacher_logits = teacher_logits[..., :vocab_size_student]

            # Hard-label cross-entropy, averaged over masked positions only.
            loss_mask_flat = loss_mask.view(-1)
            ce_loss = F.cross_entropy(
                student_logits.view(-1, student_logits.size(-1)),
                tgt_data.view(-1),
                ignore_index=0,
                reduction='none'
            )
            ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
            if lm_config_student.use_moe:
                ce_loss += res.aux_loss

            # Soft-label distillation loss, restricted to masked positions.
            if teacher_model is not None:
                distill_loss = distillation_loss_fn(
                    student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
                    teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                    temperature=temperature
                )
            else:
                distill_loss = torch.tensor(0.0, device=args.device)

            loss = alpha * ce_loss + (1 - alpha) * distill_loss

        scaler.scale(loss).backward()

        # Step the optimizer every `accumulation_steps` micro-batches, with
        # gradient clipping applied to the unscaled gradients.
        if (step + 1) % args.accumulation_steps == 0:
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)

        if (step % args.log_interval == 0 or args.debug) and (not ddp or dist.get_rank() == 0):
            spend_time = time.time() - start_time
            ans = {
                "tr_loss": loss.item(),
                "tr_ce_loss": ce_loss.item(),
                "tr_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
                "tr_lr": optimizer.param_groups[-1]['lr'],
                # Estimated minutes remaining in the current epoch.
                "tr_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
            }
            Logger(
                'Epoch:[{}/{}][step:{}/{}] tr_loss:{:.4f} tr_ce_loss:{:.4f} tr_distill_loss:{:.4f} tr_step_time:{} min'.format(
                    epoch,
                    args.epochs - 1,
                    step, len(train_loader),
                    ans['tr_loss'],
                    ans['tr_ce_loss'],
                    ans['tr_distill_loss'],
                    ans['tr_step_time']
                )
            )
            if (wandb is not None) and (not ddp or dist.get_rank() == 0):
                wandb.log(ans)

        if (step + 1) % args.save_interval == 0 and (not ddp or dist.get_rank() == 0):
            early_stopping.save_model(model, ckp.replace('epoch', 'step'))
            model.train()
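# Note: the loss is not divided by accumulation_steps before backward(), so
# effective gradient magnitudes scale with the number of accumulated
# micro-batches.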

def val_epoch(epoch, wandb, alpha=0.0, temperature=1.0):
    start_time = time.time()
    if teacher_model is not None:
        teacher_model.eval()
        teacher_model.requires_grad_(False)
    result = defaultdict(list)

    for step, (src_data, tgt_data, twod_data, loss_mask) in enumerate(val_loader):

        src_data, tgt_data, twod_data, loss_mask = (
            src_data.to(args.device),
            tgt_data.to(args.device),
            twod_data.to(args.device),
            loss_mask.to(args.device),
        )

        with ctx, torch.no_grad():
            res = model(src_data, twod_data)
            student_logits = res.logits[..., :-1]  # exclude the trailing <mask> column, as in train_epoch

            if teacher_model is not None:
                if "ernierna" in args.mlm_pretrained_model_path:
                    teacher_logits, attn_map_lst, out_dict = teacher_model(
                        src_data, twod_data, is_twod=args.is_twod)
                elif "mrna2vec" in args.mlm_pretrained_model_path:
                    teacher_logits = teacher_model(src_data)
                # Truncate to the student's vocabulary, as in train_epoch.
                vocab_size_student = student_logits.size(-1)
                teacher_logits = teacher_logits[..., :vocab_size_student]

            loss_mask_flat = loss_mask.view(-1)
            ce_loss = F.cross_entropy(
                student_logits.view(-1, student_logits.size(-1)),
                tgt_data.view(-1),
                ignore_index=0,
                reduction='none'
            )
            ce_loss = torch.sum(ce_loss * loss_mask_flat) / loss_mask_flat.sum()
            if lm_config_student.use_moe:
                ce_loss += res.aux_loss

            if teacher_model is not None:
                distill_loss = distillation_loss_fn(
                    student_logits.reshape(-1, student_logits.size(-1))[loss_mask_flat == 1],
                    teacher_logits.reshape(-1, teacher_logits.size(-1))[loss_mask_flat == 1],
                    temperature=temperature
                )
            else:
                distill_loss = torch.tensor(0.0, device=args.device)

            loss = alpha * ce_loss + (1 - alpha) * distill_loss

        spend_time = time.time() - start_time
        rs = {
            "val_loss": loss.item(),
            "val_ce_loss": ce_loss.item(),
            "val_distill_loss": distill_loss.item() if teacher_model is not None else 0.0,
            "val_lr": optimizer.param_groups[-1]['lr'],
            "val_step_time": spend_time / (step + 1) * iter_per_epoch // 60 - spend_time // 60
        }
        for key, value in rs.items():
            result[key].append(value)

    ans = {key: np.mean(values) for key, values in result.items()}

    if not ddp or dist.get_rank() == 0:
        Logger(
            'Epoch:[{}/{}] val_loss:{:.4f} val_ce_loss:{:.4f} val_distill_loss:{:.4f} val_step_time:{} min'.format(
                epoch,
                args.epochs - 1,
                ans['val_loss'],
                ans['val_ce_loss'],
                ans['val_distill_loss'],
                ans['val_step_time']
            )
        )

    if (wandb is not None) and (not ddp or dist.get_rank() == 0):
        wandb.log(ans)

    return ans['val_loss']
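# Note: val_loader is built without a DistributedSampler, so under DDP every
# rank evaluates the full validation split; only rank 0's averaged loss is
# logged and fed to early stopping.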

def init_student_model():
    """Build the MiniMindLM student and its dictionary (small_dict.txt)."""
    vocab_path = args.arg_overrides['data'] + '/small_dict.txt'
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('<mask>')

    if args.debug:
        args.n_layers = 1
    lm_config = LMConfig(dim=256, n_layers=args.n_layers, max_seq_len=max_seq_len,
                         vocab_size=len(tokenizer), padding_idx=tokenizer.pad_index)
    model = MiniMindLM(lm_config)

    Logger(f'Student model (LLM) total trainable parameters: '
           f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million, '
           f'vocab_size={len(tokenizer)}')
    model = model.to(args.device)
    print(model)
    return model, tokenizer, lm_config

def init_teacher_model(args, Logger=None):
    """Load the frozen teacher (ERNIE-RNA or mRNA2vec, selected by checkpoint path)."""
    vocab_path = args.arg_overrides['data'] + '/dict.txt'
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.add_symbol('<mask>')

    if 'ernierna' in args.mlm_pretrained_model_path:
        model_pre = load_pretrained_ernierna(args.mlm_pretrained_model_path, args.arg_overrides)
        model = model_pre.encoder

        if args.debug:
            print('debug mode')
            # Keep a single transformer layer to speed up debugging runs.
            num_layers_to_keep = 1
            model.sentence_encoder.layers = model.sentence_encoder.layers[:num_layers_to_keep]

    elif "mrna2vec" in args.mlm_pretrained_model_path:
        encoder = T5_encoder(
            hidden_size=256,
            num_attention_heads=4,
            num_hidden_layers=4,
        )
        model = mRNA2vec(encoder=encoder)
        ckpt = torch.load(args.mlm_pretrained_model_path, map_location=args.device)
        model.encoder.load_state_dict(ckpt['encoder'], strict=True)

    Logger(f'Teacher model (LLM) total trainable parameters: '
           f'{sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f} million, '
           f'vocab_size={len(tokenizer)}')
    model = model.to(args.device)
    print(model)
    return model, tokenizer
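# The teacher loads the full dict.txt while the student uses small_dict.txt,
# so their vocabularies differ in size. The training/validation loops truncate
# the teacher's logits to the student's vocabulary, which assumes the
# student's symbols form a prefix of the teacher's dictionary.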

def init_distributed_mode():
    print("init distributed mode, ddp =", ddp)

    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
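# RANK / LOCAL_RANK / WORLD_SIZE are expected to be exported by the launcher,
# e.g. (the script name here is illustrative):
#   torchrun --nproc_per_node=4 train_distill.py --epochs 3 ...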

if __name__ == '__main__':
    os.chdir(os.path.dirname(os.path.abspath(__file__)))
    pretraining_parser = get_pretraining_args()
    dataset_parser = get_dataset_args()

    parser = argparse.ArgumentParser(parents=[pretraining_parser, dataset_parser],
                                     add_help=False, conflict_handler='resolve')
    args = parser.parse_args()

    torch.manual_seed(1337)
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    # DDP is active when the launcher has exported RANK.
    ddp = int(os.environ.get("RANK", -1)) != -1
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        init_distributed_mode()
        args.device = torch.device(DEVICE)
    Logger('loading model')

    # Maximum sequence length is derived from the region size (the token
    # layout is defined by RNADataset).
    max_seq_len = args.region * 4 + 5

    args.save_dir = args.out_dir
    os.makedirs(args.save_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * max_seq_len

    teacher_model, tokenizer = init_teacher_model(args, Logger)
    # NOTE: the student tokenizer overwrites the teacher's; it is the one
    # RNADataset uses below.
    model, tokenizer, lm_config_student = init_student_model()

    if args.debug:
        args.limit = 320

    dataset = RNADataset(args.ffasta, region=args.region, tokenizer=tokenizer, limit=args.limit)

    # 99/1 train/validation split.
    train_size = int(0.99 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    train_sampler = DistributedSampler(train_dataset) if ddp else None
    Logger(f'loading {train_size} training samples from', args.ffasta)
    Logger(f'loading {val_size} validation samples from', args.ffasta)

    # Note: shuffle=False here; under DDP the DistributedSampler handles
    # shuffling (reseeded each epoch via set_epoch below), while single-GPU
    # runs iterate the split in order.
    train_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers,
        sampler=train_sampler
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        pin_memory=True,
        drop_last=False,
        shuffle=False,
        num_workers=args.num_workers
    )

    args.wandb_run_name = (f"{args.wandb_project}-EP-{args.epochs}-BS-{args.batch_size}"
                           f"-LR-{args.learning_rate}-{len(train_loader.dataset) / 1e6:.0f}")
    if args.use_wandb and (not ddp or ddp_local_rank == 0):
        import wandb
        wandb.init(project=args.wandb_project, name=args.wandb_run_name, mode="offline")
    else:
        wandb = None

    scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
    optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)

    if ddp:
        # "pos_cis" is a deterministic positional buffer; exclude it from DDP
        # parameter/buffer synchronization.
        model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
        model = DistributedDataParallel(model, device_ids=[ddp_local_rank])
    Logger('local_rank', args.local_rank, 'ddp', ddp, 'ddp_local_rank', ddp_local_rank)

    iter_per_epoch = len(train_loader)
    epoch = 0

    moe_path = '_moe' if lm_config_student.use_moe else ''
    ckp = f'{args.save_dir}/full_dist_{lm_config_student.dim}{moe_path}_epoch.pth'
    Logger('save model to', os.path.abspath(ckp))
    early_stopping = EarlyStopping(patience=5, verbose=True, path=ckp)
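    # One train + validation pass per epoch. Under DDP only rank 0 updates the
    # EarlyStopping state; the resulting stop flag is broadcast so every rank
    # leaves the loop on the same epoch.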
    for epoch in range(args.epochs):
        if ddp:
            train_loader.sampler.set_epoch(epoch)

        train_epoch(epoch, wandb, alpha=args.celoss_alpha)
        current_loss = val_epoch(epoch, wandb, alpha=args.celoss_alpha)
        if ddp:
            if ddp_local_rank == 0:
                early_stopping(current_loss, model)
                to_broadcast = torch.tensor([early_stopping.early_stop],
                                            dtype=torch.float32, device=args.device)
                dist.broadcast(to_broadcast, 0)
            else:
                # Non-zero ranks supply a placeholder that is overwritten by
                # rank 0's broadcast value.
                to_broadcast = torch.tensor([False], dtype=torch.float32, device=args.device)
                dist.broadcast(to_broadcast, 0)
            early_stopping.early_stop = bool(to_broadcast.item())
        else:
            early_stopping(current_loss, model)
        if early_stopping.early_stop:
            break

    print('the end')