#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : ray_search.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/9/15 16:58
des: https://docs.ray.io/en/latest/tune/index.html
"""
import random
from collections import defaultdict
import sys
import os
import time
import argparse
import math
import warnings
import pandas as pd
import numpy as np
from fairseq.data import Dictionary
from model.codon_tables import AA_str
from model.model_exp import MiniMindLM_Maotao
from utils.ernie_rna.dataset import MaotaoDataset

username = os.environ['HOME']  # home directory of the current user
# directory containing this script
script_dir = os.path.dirname(os.path.abspath(__file__))
# switch the working directory to the script's directory
os.chdir(script_dir)
import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split
from contextlib import nullcontext
from model.model_ribo import MiniMindLM
from model.LMConfig import LMConfig, LMaoTaoConfig
# init_config is not imported from model.tools: a local version defined below
# would shadow it anyway.
from model.tools import EarlyStopping, get_pretraining_args, ddp_broadcast_early_stopping, compute_metrics_dict
from utils.ernie_rna.dataset_dst import RNADataset
from src.utils import load_pretrained_ernierna
warnings.filterwarnings('ignore')
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
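# RANK is set by the distributed launcher, e.g. (illustrative command):
#   torchrun --nproc_per_node=4 train.py ...
# which exports RANK, LOCAL_RANK and WORLD_SIZE for every worker process.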
print('Setting running environment')
def copy_current_code(path):
if not ddp or ddp_local_rank == 0:
os.makedirs(path, exist_ok=True)
os.system(f'cp -r {script_dir}/*.py {path}')
os.system(f'cp -r {script_dir}/model {path}')
os.system(f'cp -r {script_dir}/utils {path}')
with open(os.path.join(path, 'run.sh'), 'w') as f:
gpu_count = torch.cuda.device_count()
f.write('#!/bin/bash\n')
cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT_SET')
f.write(f'# CUDA_VISIBLE_DEVICES={cuda_visible}\n')
f.write(f'# RANK={os.environ.get("RANK", -1)}\n')
f.write(f'# LOCAL_RANK={os.environ.get("LOCAL_RANK", -1)}\n')
f.write(f'# WORLD_SIZE={os.environ.get("WORLD_SIZE", -1)}\n')
f.write('cd '+os.path.abspath(os.path.dirname(os.path.abspath(__file__)))+'\n')
f.write(' \\\n'.join([str(sys.executable)]+sys.argv))
def Logger(*content):
if not ddp or dist.get_rank() == 0:
print(*content)
def init_distributed_mode(ddp=True):
print("init distributed mode,ddp=",ddp)
if not ddp: return
global ddp_local_rank, DEVICE
dist.init_process_group(backend="nccl")
ddp_rank = int(os.environ["RANK"])
ddp_local_rank = int(os.environ["LOCAL_RANK"])
ddp_world_size = int(os.environ["WORLD_SIZE"])
DEVICE = f"cuda:{ddp_local_rank}"
torch.cuda.set_device(DEVICE)
print('init distributed mode, ddp_rank:', ddp_rank, 'ddp_local_rank:', ddp_local_rank, 'ddp_world_size:', ddp_world_size)
return ddp_local_rank,DEVICE
def all_gather(tensor, world_size):
tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)]
torch.distributed.all_gather(tensor_list, tensor)
return torch.cat(tensor_list)
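# Usage sketch (illustrative): with world_size=2 and per-rank tensors of shape
# [B, ...], the returned tensor has shape [2*B, ...]. Note that
# torch.distributed.all_gather does not propagate gradients.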
class DotDefaultDict(defaultdict):
    def __getattr__(self, name):
        if name in self:
            return self[name]
        # fall back to normal attribute lookup; defaultdict's default-factory
        # behaviour is preserved for item access (d['missing'])
        return super().__getattribute__(name)
    __setattr__ = defaultdict.__setitem__
    __delattr__ = defaultdict.__delitem__
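# Usage sketch (illustrative):
#   d = DotDefaultDict(list)
#   d['preds'].append(1)   # default factory creates the list on item access
#   d.preds                # -> [1]; attribute access mirrors item access
#   d.lr = 3e-4            # same as d['lr'] = 3e-4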
def metric_monitor(all_preds, all_labels, epoch, prefix, start_time, ddp=None, wandb=None, loss_fct=None, Logger=None, cls='regression'):
    # merge the current process's results
    if isinstance(all_preds, list):
        all_preds = torch.cat(all_preds)
        all_labels = torch.cat(all_labels)
    # Sharding is handled in the DataLoader: TR is split across GPUs, while VL
    # and TS run in full on every GPU, so no cross-process gather is needed:
    # if ddp:
    #     world_size = dist.get_world_size()
    #     all_preds = all_gather(all_preds, world_size)
    #     all_labels = all_gather(all_labels, world_size)
    loss_mask = all_labels != 1  # 1 is the padding index of the fairseq Dictionary
    all_preds = all_preds[loss_mask]
    all_labels = all_labels[loss_mask]
    epoch_loss = loss_fct(all_preds, all_labels)
    all_preds = all_preds.cpu().numpy()
    all_labels = all_labels.cpu().numpy()
    epoch_loss = epoch_loss.cpu().item()
ans = compute_metrics_dict(all_preds, all_labels,cls=cls)
# ans = compute_metrics_dict(np.array(all_preds), np.array(all_labels),cls='binary')
wandb_ans = dict(zip([f"{prefix}_{k}" for k in ans.keys()], ans.values()))
wandb_ans[f"{prefix}_epoch_loss"] = epoch_loss
if wandb:
wandb.log(wandb_ans)
# Logger(f"Epoch {epoch} - {prefix} Loss: {epoch_loss:.4f} "+', '.join([f"{k}: {v:.4f}" for k,v in ans.items()]))
# Logger(f"{prefix} time: {time.time() - start_time:.2f}s")
return wandb_ans
class maotao():
    def step_loss(self, target_nn=None, tgt_te=None,
                  masked_logits_list=None, nn_prob=None,
                  res=None, loss_fct=None, loss_mse=None, args=None):
        loss_mask = target_nn != 1  # 1 is the padding index
        tgt_te = tgt_te.view(-1)
        te_loss = loss_mse(res.te, tgt_te)
        # res.logits = res.logits + nn_prob + masked_logits  # frame 1,2 is the best cai
        # total loss = sum_i CE(logits + nn_prob + mask_i) + CE(raw logits)
        #              + TE regression loss + auxiliary loss
        loss = torch.tensor(0, dtype=torch.float32, device=args.device)
        for i in range(masked_logits_list.size(1)):
            masked_logits = masked_logits_list[:, i, ...]
            loss += self.calculate_loss(res.logits + nn_prob + masked_logits, loss_mask, target_nn,
                                        loss_fct)
        loss += self.calculate_loss(res.logits, loss_mask, target_nn, loss_fct)
        # was `loss += loss + te_loss + res.aux_loss`, which double-counted the CE terms
        loss = loss + te_loss + res.aux_loss
        return loss
    def forward_step(self, model=None, src_data=None, twod_data=None, aa_idx=None,
                     continuous_features=None, species_features=None, truncated_features=None,
                     target_nn=None, tgt_te=None, masked_logits_list=None, nn_prob=None,
                     loss_fct=None, loss_mse=None, args=None):
res = model(input_ids=src_data,
twod_tokens=twod_data,
aa_idx=aa_idx,
continuous_features=continuous_features,
species_features=species_features,
truncated_features=truncated_features,
# targets_nn=target_nn,targets_te=tgt_te
)
# find_unused_parameters(model, res.te)
res.te = res.te.view(-1)
        # zero entries in the prior are hard-masked to -inf before being added to the logits
        nn_prob = torch.masked_fill(nn_prob, mask=nn_prob == 0, value=float('-inf'))
masked_logits = masked_logits_list[:,-1,...]
res.logits = res.logits + nn_prob
ans = dict()
step_loss = self.step_loss(target_nn=target_nn,tgt_te=tgt_te,
masked_logits_list=masked_logits_list,nn_prob=nn_prob,
res=res,loss_fct=loss_fct,loss_mse=loss_mse,args=args)
ans.update({'loss':step_loss})
res.logits = res.logits + masked_logits
ans.update({'res':res})
return ans
def train_epoch(self,model=None,wandb=None,ddp=None,train_loader=None,optimizer=None,
epoch=None,prefix="TR",loss_fct=None,loss_mse=None,args=None,scaler=None,Logger=None,lr_scheduler=None):
model.train()
epoch_loss = 0
# all_preds, all_labels = [], []
all_preds, all_labels = defaultdict(list), defaultdict(list)
start_time = time.time()
ans = {}
for step,(src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, \
target_nn, target, masked_logits_list,nn_prob,maotao_id) in enumerate(train_loader):
src_data = src_data.to(args.device)
twod_data = twod_data.to(args.device)
aa_idx = aa_idx.to(args.device)
continuous_features = continuous_features.to(args.device)
species_features = species_features.to(args.device)
truncated_features = truncated_features.to(args.device)
masked_logits_list = masked_logits_list.to(args.device) # [12, 1200, 32]
nn_prob = nn_prob.to(args.device) # [12, 1200, 32]
target_nn = target_nn.to(args.device)
tgt_te = target.to(args.device)
            with torch.cuda.amp.autocast(enabled=scaler.is_enabled()):  # torch.cuda.amp only works on CUDA devices
results = self.forward_step(model=model,src_data=src_data,
twod_data=twod_data,aa_idx=aa_idx,
continuous_features=continuous_features,
species_features=species_features,
truncated_features=truncated_features,
target_nn=target_nn,tgt_te=tgt_te,
masked_logits_list=masked_logits_list,
nn_prob=nn_prob,loss_fct=loss_fct,
loss_mse=loss_mse,args=args)
res, loss = results['res'], results['loss']
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
# Logger('Epoch:[{}/{}][step:{}/{}] loss:{:.4f} lr:{:.6f}'.format(epoch, args.epochs, step, len(train_loader), loss.item(), lr))
Logger('Epoch:[{}/{}][batch:{}/{}] loss:{:.4f}'.format(epoch, args.epochs, step, len(train_loader), loss.item()))
epoch_loss += loss.item()
# masked_logits = masked_logits_list[:, 2, ...]
# res.logits = res.logits +nn_prob+masked_logits
all_preds['logits'].append(res.logits.detach())
all_preds['te'].append(res.te.detach())
all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device))
all_labels['logits'].append(target_nn.detach())
all_labels['te'].append(tgt_te.detach())
if step%100==1:
ans = {}
# ans = metric_monitor(all_preds['logits'], all_labels['logits'], epoch, prefix+'_logits',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_fct,Logger=Logger,cls='binary')
ans.update(metric_monitor(all_preds['logits'], all_labels['logits'], epoch, prefix+'_codon',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_fct,Logger=Logger,cls='identity'))
# ans.update(metric_monitor(all_preds['te'], all_labels['te'], epoch, prefix+'_cai',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_mse,Logger=Logger,cls='regression'))
ans.update({f'{prefix}_loss':epoch_loss})
all_preds, all_labels = defaultdict(list), defaultdict(list)
epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0
Logger('\n'+'#'*10+f' {prefix}_loss: {epoch_loss:.4f} '+'#'*10+'\n')
return ans
def Maotao_DataLoader(self,file_path, args, tokenizer, data_tag='TR', ddp=None, Logger=None, returnid=None):
if args.debug: args.limit = 320
train_ds = MaotaoDataset(file_path, tokenizer, args=args, limit=args.limit, seq_len=args.seq_len,returnid=returnid,codon_table_path=args.codon_table_path)
        train_sampler = DistributedSampler(train_ds) if ddp else None
        # drop the last incomplete batch under DDP so every GPU sees the same number of batches
        drop_last = ddp and data_tag == 'TR'
        train_loader = DataLoader(
            train_ds,
            batch_size=args.batch_size,
            pin_memory=True,
            drop_last=drop_last,
            shuffle=False,
            num_workers=args.num_workers,
            sampler=train_sampler if data_tag == 'TR' else None,  # only the training split is sharded
        )
        Logger('loading data from', file_path, args.seq_len, args.column, args.label, len(train_loader.dataset))
        # The validation set (VL) does not use a DistributedSampler, so every GPU
        # processes the full validation set rather than a shard of it.
        return train_loader
def load_data(self,args=None,task='sft',tokenizer=None,ddp=None,Logger=None):
# csv_path = os.path.join(os.path.dirname(__file__),args.downstream_data_path, task, "{}.csv")
csv_path = os.path.join(args.downstream_data_path, task, "{}.csv")
check_file_flag = [os.access(csv_path.format(tag),os.F_OK) for tag in ['TR', 'VL', 'TS']]
loaders = []
for flag,tag in zip(check_file_flag,['TR', 'VL', 'TS']):
if flag:
train_loader = self.Maotao_DataLoader(os.path.join(csv_path.format(tag)), args,
tokenizer, data_tag=tag, ddp=ddp,
Logger=Logger) # --wandb_project=IRES_circle
else:
print(f'Warning: {tag} data not found, skip it.',csv_path.format(tag))
train_loader = None
loaders.append(train_loader)
        assert check_file_flag[-1], f'Please check the dataset path; at least a TS split is required: {os.path.abspath(csv_path)}, {csv_path.format("TS")}'
return loaders
    def calculate_loss(self, logits, loss_mask, tgt_data, loss_ce):
        # logits: [B, L, V], tgt_data: [B, L], loss_mask: [B, L]
        loss_mask = loss_mask == 1
        # logits = F.softmax(logits, dim=-1)  # unnecessary: CrossEntropyLoss applies a
        # numerically stable log-softmax internally
        loss = loss_ce(
            torch.masked_select(logits, loss_mask.unsqueeze(-1).repeat(1, 1, logits.size(-1))).view(-1, logits.size(-1)),
            torch.masked_select(tgt_data, loss_mask).view(-1)
        ).mean()
        return loss
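    # Shape sketch (illustrative): with B=5, L=30, V=10 and 15 unmasked positions,
    # masked_select yields logits of shape [15, 10] and targets of shape [15],
    # i.e. a standard token-level cross-entropy over the non-padding positions.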
def evaluate(self,model, data_loader,stage='epoch', prefix="VL",args=None,loss_fct=None,
loss_mse=None,wandb=None,ddp=None,ctx=None,Logger=None,fpred_out=None,tokenizer=None):
start_time = time.time()
epoch_loss = 0
model.eval()
# all_preds, all_labels = [], []
all_preds, all_labels = defaultdict(list), defaultdict(list)
all_num = args.batch_size * len(data_loader)
with torch.no_grad():
for idx,(src_data, twod_data, aa_idx, continuous_features, species_features, truncated_features, \
target_nn, target, masked_logits_list, nn_prob, maotao_id) in enumerate(data_loader):
src_data = src_data.to(args.device)
twod_data = twod_data.to(args.device)
aa_idx = aa_idx.to(args.device)
continuous_features = continuous_features.to(args.device)
species_features = species_features.to(args.device)
truncated_features = truncated_features.to(args.device)
masked_logits_list = masked_logits_list.to(args.device) # [12, 1200, 32]
# masked_logits = masked_logits.to(args.device) # [12, 1200, 32]
nn_prob = nn_prob.to(args.device) # [12, 1200, 32]
target_nn = target_nn.to(args.device)
tgt_te = target.to(args.device)
results = self.forward_step(model=model,src_data=src_data,
twod_data=twod_data,aa_idx=aa_idx,
continuous_features=continuous_features,
species_features=species_features,
truncated_features=truncated_features,
target_nn=target_nn,tgt_te=tgt_te,
masked_logits_list=masked_logits_list,
nn_prob=nn_prob,loss_fct=loss_fct,
loss_mse=loss_mse,args=args)
res, loss = results['res'], results['loss']
epoch_loss += loss.item()
all_preds['logits'].append(res.logits.detach())
all_preds['te'].append(res.te.reshape(-1).detach())
all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device))
all_labels['logits'].append(target_nn)
all_labels['te'].append(tgt_te.reshape(-1))
if args.predict and (not ddp or dist.get_rank() == 0):
                    pred_logits = res.logits.detach()
                    pred_te = res.te.reshape(-1).detach()
                    for idj, (_id, logits, nn, te) in enumerate(zip(maotao_id, pred_logits, target_nn, pred_te)):
                        pred_cai = te.item()
                        # logits.argmax(1) alone is deterministic; sample with a
                        # temperature so different seeds can give different sequences
                        temperature = 0.8
                        probs = torch.softmax(logits / temperature, dim=-1)
                        probs = torch.nan_to_num(probs, nan=1e-9)
                        tokens = torch.multinomial(probs, num_samples=1)
                        tokens = tokens.squeeze(-1)
                        tokens_hard = logits.argmax(1)
                        # keep sampled tokens only where the argmax token is nonzero
                        # (non-padding), then map RNA back to DNA (U -> T)
                        pred_nn = ''.join([tokenizer.symbols[x] for x, y in zip(tokens.cpu().numpy(), tokens_hard.cpu().numpy()) if y]).replace('U', 'T')
                        ans = metric_monitor(logits, nn, stage, '',
                                             start_time, ddp=ddp,
                                             wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity')
df = pd.DataFrame([ans])
df['maotao_id'] = _id
df['pred_cai'] = pred_cai
df['pred_nn'] = pred_nn
if os.access(fpred_out, os.F_OK):
df.to_csv(fpred_out, mode='a', header=False, index=False)
else:
df.to_csv(fpred_out, mode='w', header=True, index=False)
Logger(f'{args.batch_size*idx+idj}/{data_loader.dataset.__len__()},fpred_out:{fpred_out}')
epoch_loss /=len(data_loader)
if (wandb is not None) and (not ddp or dist.get_rank() == 0):
wandb.log({f'{prefix}_epoch_loss': epoch_loss})
ans = metric_monitor(all_preds['logits'], all_labels['logits'], stage, prefix + '_logits', start_time, ddp=ddp,
wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='binary')
ans.update(
metric_monitor(all_preds['logits'], all_labels['logits'], stage, prefix + '_codon', start_time, ddp=ddp,
wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity'))
if not args.predict:ans.update(
metric_monitor(all_preds['te'], all_labels['te'], stage, prefix + '_cai', start_time, ddp=ddp, wandb=wandb,
loss_fct=loss_mse, Logger=Logger, cls='regression'))
ans.update({f'{prefix}_loss': epoch_loss})
wandb_ans = ans
if not ddp or dist.get_rank() == 0:
if wandb:
wandb.log(wandb_ans)
# Logger(
# f"Metrics - {prefix} " + ', '.join([f"{k}: {v:.4f}" for k, v in wandb_ans.items()]))
return wandb_ans
def init_model(args, ckp='./out/full_dist_256_epoch.pth', lm_config=None, tokenizer=None, Logger=None, require_ckp=False):
if args.debug:lm_config.n_layers = 1
model = MiniMindLM_Maotao(lm_config)
print(model)
print(lm_config)
print('ckp=',ckp)
if ckp is not None and os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0:
print('loading model from', ckp)
state_dict = torch.load(ckp, map_location=args.device)
model.load_state_dict(state_dict, strict=False)
print(f'finetune ({args.finetune}) from, {os.path.abspath(ckp)}')
    else:
        if require_ckp:
            exit('checkpoint not found: ' + ckp)
        # otherwise train from scratch
if args.finetune:
for name, value in model.named_parameters():
if 'layers' in name:
value.requires_grad = False
print(name, value.numel(), value.requires_grad)
else:
for name, value in model.named_parameters():
print(name, value.numel(), value.requires_grad)
    print(
        f'LLM parameters trainable/total: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M / {sum(p.numel() for p in model.parameters()) / 1e6:.3f}M')
    print(
        f'LLM parameters trainable/total: {sum(p.numel() for p in model.parameters() if p.requires_grad)} / {sum(p.numel() for p in model.parameters())}')
model = model.to(args.device)
return model, tokenizer,lm_config
def save_metrics(ans, epoch, out_dir, filename='history_metrics.csv'):
if not ddp or ddp_local_rank == 0:
df = pd.DataFrame([ans])
df['epoch'] = epoch
df['path'] = out_dir
if epoch == 0:
df.to_csv(os.path.join(out_dir, filename), mode='w')
elif epoch == 'TS':
df = df.T
df.to_csv(os.path.join(out_dir, filename), mode='w')
else:
df.to_csv(os.path.join(out_dir, filename), mode='a', header=False)
def sft_process_maotao(max_seq_len=-1,ctx=None,ddp=False,ddp_local_rank=0,args=None,ckp=None,out_ckp=None,lm_config=None,
tokenizer=None,Logger=None,task=None,seq_pkl_path=None,sft=None,require_ckp=False):
print('3. fine-tune on downstream tasks...')
print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
start = time.time()
# if args.predict:ckp = out_ckp
# model, tokenizer = sft.init_model(args=args,ddp=ddp,ddp_local_rank=ddp_local_rank,Logger=Logger,ckp=ckp,lm_config=lm_config,tokenizer=tokenizer)
model, tokenizer, lm_config_student = init_model(args,ckp=ckp,lm_config=lm_config,tokenizer=tokenizer,Logger=Logger,require_ckp=require_ckp)
if ddp:
model._ddp_params_and_buffers_to_ignore = {"pos_cis"}
model = DistributedDataParallel(model, device_ids=[ddp_local_rank],find_unused_parameters=True)
train_loader, val_loader, test_loader = sft.load_data(args=args,tokenizer=tokenizer,
task=task,ddp=ddp,Logger=Logger)
if train_loader:
args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{len(train_loader.dataset)}"
else:
args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{0}"
if args.use_wandb and (not ddp or ddp_local_rank == 0):
import wandb
wandb.init(project=args.wandb_project, name=args.wandb_run_name, mode="offline",config=args,anonymous="allow")
Logger(f'init wandb with id {wandb.run.id}')
else:
wandb = None
loss_ce = nn.CrossEntropyLoss()
loss_mse = nn.MSELoss()
if args.predict:
fpred_out = args.out_dir + f'/{args.task}/TS_pred.csv'
os.system(f'rm -f {fpred_out}')
os.makedirs(os.path.dirname(fpred_out), exist_ok=True)
Logger(f'predicting {fpred_out} with {os.path.abspath(ckp)}')
final_metrics = sft.evaluate(model, test_loader, stage='predict', prefix='TS', args=args, loss_fct=loss_ce,loss_mse=loss_mse, wandb=wandb, ddp=ddp,
ctx=ctx, Logger=Logger,fpred_out=fpred_out,tokenizer=tokenizer)
return ckp, final_metrics, -1
    final_metrics = None
    if not ddp or ddp_local_rank == 0:
        Logger('predict and saving to file...', os.path.join(args.out_dir, "zeroshot_metrics.csv"))
        for tag, loader in zip(['TR', 'VL', 'TS'], [train_loader, val_loader, test_loader]):
            if tag == 'TR': continue
            if loader is not None:
                fpred_out = args.out_dir + f'/{tag}_pred.csv'  # was the module-level global `prefix`
                os.system(f'rm -f {fpred_out}')
                final_metrics = sft.evaluate(model, loader, stage='predict', prefix=tag, args=args,
                                             loss_fct=loss_ce, loss_mse=loss_mse,  # loss_fct/loss_mse were undefined names here
                                             wandb=wandb, ddp=ddp, ctx=ctx, Logger=Logger,
                                             fpred_out=fpred_out, tokenizer=tokenizer)
                save_metrics(final_metrics, tag, args.out_dir, filename="zeroshot_metrics.csv")
    # NOTE: returns here after the zero-shot evaluation; the training loop below
    # does not run unless this return is removed.
    return ckp, final_metrics, -1
    # final test
# Logger('first evaluation on TS')
# if not ddp or ddp_local_rank == 0:
# df = pd.DataFrame(final_metrics).T
# df.to_csv(os.path.join(args.out_dir, f"{args.wandb_run_name}_zeroshot_metrics.csv"),index=False)
with open(os.path.join(args.out_dir, f"{args.wandb_run_name}.params"), 'w') as f:
f.write(f'#args={args}' + '\n')
f.write(f'#model={model}' + '\n')
f.write(f'#lm_config={lm_config}' + '\n')
f.write(f'#tokenizer={tokenizer.indices}' + '\n')
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16']))
# lr_scheduler = LearningRateScheduler(base_lr=args.learning_rate, total_epochs=args.epochs, total_steps_per_epoch=len(train_loader))
lr_scheduler = None
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate)
if out_ckp is not None:
ckp = out_ckp
else:
ckp = f'{args.save_dir}/{args.wandb_run_name}.pth'
    # initialize the early stopper (here it monitors the validation loss; lower is better)
    early_stopping = EarlyStopping(patience=3, verbose=True, path=ckp)
    early_stopping.save_model(model, ckp)
    # main training loop
epoch = 0
for epoch in range(args.epochs):
if ddp: train_loader.sampler.set_epoch(epoch)
if epoch ==0:
Logger(f'iterating over {len(train_loader)} batches per epoch, {len(train_loader.dataset)} samples per epoch, batch_size={train_loader.batch_size}, gpu_num={torch.cuda.device_count()}')
        # train for one epoch
ans = sft.train_epoch(model=model,wandb=wandb,
train_loader=train_loader, prefix="TR",
optimizer=optimizer, ddp=ddp,epoch=epoch,
loss_fct=loss_ce,loss_mse=loss_mse,
scaler=scaler,args=args,Logger=Logger,lr_scheduler=lr_scheduler)
ans_vl = sft.evaluate(model, val_loader, stage=f'{epoch}',prefix="VL",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer)
ans.update(ans_vl)
        val_metric = ans['VL_logits_epoch_loss']  # early-stopping criterion (lower is better)
        save_metrics(ans, epoch, args.out_dir, filename=f"{args.wandb_run_name}_history_metrics.csv")
        if ddp:
            # distributed branch: rank 0 decides, then broadcasts the decision
            if ddp_local_rank == 0:
                early_stopping(val_metric, model)  # to monitor SPR instead, pass -SPR
                # broadcast should_stop to the other ranks
                to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device)
                dist.broadcast(to_broadcast, 0)
            else:
                # non-main ranks wait for the broadcast from rank 0
                to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device)
                dist.broadcast(to_broadcast, 0)
            early_stopping.early_stop = bool(to_broadcast.item())
        else:
            # single-process branch
            early_stopping(val_metric, model)  # to monitor SPR instead, pass -SPR
        if early_stopping.early_stop: break
    # restore the best model (optional)
    if os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0:  # no checkpoint is saved when epoch == 0
        state_dict = torch.load(ckp, map_location=args.device)
        model.load_state_dict(state_dict, strict=False)
        Logger("Loaded best model for final evaluation.")
    else:
        early_stopping.save_model(model, ckp)
final_metrics = sft.evaluate(model, test_loader,stage='final', prefix="TS",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer)
save_metrics(final_metrics, 'TS', args.out_dir, filename=f"{args.wandb_run_name}_final_metrics.csv")
    # final test
    Logger('final evaluation on TS')
    # if ddp and ddp_local_rank == 0: dist.destroy_process_group()
    if args.use_wandb and (not ddp or ddp_local_rank == 0): wandb.finish()
    Logger(f'the end of the sft process: time = {(time.time() - start) // 60:.0f} min')
return ckp,final_metrics,epoch
def trainable(config):
    # pull the tuned hyperparameters from the Ray Tune config
    sft = maotao()
    # the remaining parameters are fixed module-level globals set in __main__
    args.batch_size = config["batch_size"]
    args.learning_rate = config["learning_rate"]
    max_seq_len = args.max_seq_len
ckp, final_metrics, _ = sft_process_maotao(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
ddp_local_rank=ddp_local_rank,
args=args, ckp=in_ckp, out_ckp=out_ckp,
lm_config=lm_config, tokenizer=tokenizer, Logger=Logger,task=task,
sft=sft)
    # # report the result to Ray Tune
    # tune.report(accuracy=spr)
    return {"identity": final_metrics['TS_codon_identity_codon']}
# get_fixed_params() (not defined here) would supply the parameters that are not tuned by Ray Tune.
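# Hypothetical Ray Tune launch (a sketch only; __main__ below calls trainable()
# directly with a single config instead of running a search):
#   from ray import tune
#   tuner = tune.Tuner(trainable, param_space={
#       "batch_size": tune.choice([8, 16, 32]),
#       "learning_rate": tune.loguniform(1e-5, 1e-3),
#   })
#   best = tuner.fit().get_best_result(metric="identity", mode="max").config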
def test_ray_tune():
    Logger('hello world')
from ray import tune
def objective(config): # ①
score = config["a"] ** 2 + config["b"]
return {"score": score}
search_space = { # ②
"a": tune.grid_search([0.001, 0.01, 0.1, 1.0,0.4]),
"b": tune.choice([1, 2, 3]),
}
tuner = tune.Tuner(objective, param_space=search_space,
# run_config=ray.air.RunConfig(
# storage_path="./ray_results/tune_results", # 存储路径
# name="my_experiment" # 实验名称
# )
) # ③
results = tuner.fit()
Logger(results.get_best_result(metric="score", mode="max").config)
def init_config(vocab_path, n_layers, max_seq_len):
    """Local init_config used by this script (it replaces model.tools.init_config)."""
    tokenizer = Dictionary.load(vocab_path)
    tokenizer.mask_index = tokenizer.add_symbol('<mask>')  # ['<s>', '<pad>', '</s>', '<unk>', 'G', 'A', 'U', 'C', 'N', '<mask>']
    tokenizer.indices['T'] = tokenizer.indices['U']
    tokenizer.indices['_'] = tokenizer.pad_index
    lm_config = LMaoTaoConfig(dim=256, logit_dim=len(tokenizer), n_layers=n_layers, max_seq_len=max_seq_len,
                              vocab_size=len(tokenizer), padding_idx=tokenizer.pad_index)  # n_layers 8, <s> <unk><unk><unk> </s>
    [tokenizer.add_symbol(word) for word in AA_str]  # amino-acid symbols take indices 10-31
    return lm_config, tokenizer
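# Resulting vocabulary layout (illustrative, assuming fairseq Dictionary defaults):
#   0 '<s>', 1 '<pad>', 2 '</s>', 3 '<unk>', then the entries from small_dict.txt
#   ('G', 'A', 'U', 'C', 'N'), '<mask>', and finally the AA_str amino-acid symbols.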
def seed_everything(seed=2022):
    print('seed_everything to', seed)
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # makes runs reproducible, though successive draws within a run still differ  # https://blog.csdn.net/qq_42951560/article/details/112174334
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False  # minibatch lengths keep changing, so cudnn autotuning would waste time
if __name__ == '__main__':
print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
start = time.time()
'''
test
'''
# test_ray_tune()
# import train_full_sft as sft
sft = maotao()
parser = get_pretraining_args()
args = parser.parse_args()
args.downstream_data_path = 'maotao_file/'
args.seed = int(args.seed)
# torch.manual_seed(args.seed)
# torch.manual_seed(1337)
seed_everything(seed=args.seed)
if args.predict:
task= args.task
else:
task = 'AA2CDS_data'
device_type = "cuda" if "cuda" in args.device else "cpu"
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
ddp_local_rank, DEVICE = 0, "cuda:0"
if ddp:
print('init distributed mode')
ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp)
args.device = torch.device(DEVICE)
Logger('args.device:',args.device)
Logger('setting args',args)
max_seq_len = 1200
args.seq_len = max_seq_len
args.save_dir = os.path.join(args.out_dir)
# os.system(f"rm -rf {args.save_dir}") # todo
os.makedirs(args.save_dir, exist_ok=True)
os.makedirs(args.out_dir, exist_ok=True)
tokens_per_iter = args.batch_size * max_seq_len
lm_config,tokenizer = init_config(args.arg_overrides['data'] + '/small_dict.txt', args.n_layers, max_seq_len)
lm_config.use_moe = args.use_moe
wandb_project = args.wandb_project
'''3. benchmark downstream tasks'''
prefix = 'TS'
# with open(args.save_dir+'/benchmark_result.tsv','w') as f:
# f.write('Project\tModel\tTask\tSPR\tPR\tMSE\tRMSE\tR2\tckp\tepoch\n')
epochs = args.epochs
args.out_dir = os.path.abspath(args.out_dir)
os.makedirs(args.out_dir, exist_ok=True)
model_dir = args.out_dir # 'exp_log/out_demo4/'
model_dir = os.path.abspath(model_dir)
# model_dir = 'exp_log/out_demo250810/'
data_dir = args.downstream_data_path # 'dataset/downstreamV4/'
data_dir = os.path.abspath(data_dir)
Logger(f'model_dir:{model_dir}')
os.makedirs(model_dir, exist_ok=True)
args.save_dir = os.path.abspath(args.save_dir)
args.downstream_data_path = os.path.abspath(args.downstream_data_path)
# args.codon_table_path = 'maotao_file/codon_table/codon_usage_{species}.csv'
Logger('args.downstream_data_path:', args.downstream_data_path)
    out_ckp = args.save_dir + '/AA2CDS.pth'
out_ckp = os.path.abspath(out_ckp)
# os.system(f"rm -rf {out_ckp}")
in_ckp = model_dir+'/AA2CDS.pth'
# copy_current_code(args.out_dir+'/code/')
search_space = {
# "max_seq_len": 1200,
"batch_size": args.batch_size,
"learning_rate": args.learning_rate,
}
trainable(search_space)
print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
print('time', time.time() - start)