|
|
|
|
|
|
|
|
""" |
|
|
Title : ray_search.py |
|
|
project : minimind_RiboUTR |
|
|
Created by: julse |
|
|
Created on: 2025/9/15 16:58 |
|
|
des: https://docs.ray.io/en/latest/tune/index.html |
|
|
""" |
|
|
import argparse
import math
import os
import random
import sys
import time
import warnings
from collections import defaultdict

import numpy as np
import pandas as pd
from fairseq.data import Dictionary

from model.codon_tables import AA_str
from model.model_exp import MiniMindLM_Maotao
from utils.ernie_rna.dataset import MaotaoDataset

username = os.environ['HOME']  # note: $HOME is the home directory path, not the user name
|
|
|
|
|
|
|
|
|
|
|
|
# run everything relative to this script's directory
script_dir = os.path.dirname(os.path.abspath(__file__))
os.chdir(script_dir)

import torch
import torch.nn.functional as F
import torch.distributed as dist
from torch import optim, nn
from torch.nn.parallel import DistributedDataParallel
from torch.utils.data import DataLoader, DistributedSampler, random_split
from contextlib import nullcontext
|
|
|
|
|
|
|
|
from model.model_ribo import MiniMindLM |
|
|
from model.LMConfig import LMConfig, LMaoTaoConfig |
|
|
# init_config is intentionally not imported here: a local definition below overrides it
from model.tools import EarlyStopping, get_pretraining_args, ddp_broadcast_early_stopping, compute_metrics_dict
|
|
from utils.ernie_rna.dataset_dst import RNADataset |
|
|
|
|
|
from src.utils import load_pretrained_ernierna |
|
|
|
|
|
warnings.filterwarnings('ignore') |
|
|
|
|
|
ddp = int(os.environ.get("RANK", -1)) != -1 |
|
|
print('Setting running environment') |
|
|
|
|
|
def copy_current_code(path): |
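    """Snapshot the source tree and the exact launch command into `path` for reproducibility."""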
|
|
if not ddp or ddp_local_rank == 0: |
|
|
os.makedirs(path, exist_ok=True) |
|
|
os.system(f'cp -r {script_dir}/*.py {path}') |
|
|
os.system(f'cp -r {script_dir}/model {path}') |
|
|
os.system(f'cp -r {script_dir}/utils {path}') |
|
|
with open(os.path.join(path, 'run.sh'), 'w') as f: |
|
|
gpu_count = torch.cuda.device_count() |
|
|
f.write('#!/bin/bash\n') |
|
|
cuda_visible = os.environ.get('CUDA_VISIBLE_DEVICES', 'NOT_SET') |
|
|
f.write(f'# CUDA_VISIBLE_DEVICES={cuda_visible}\n') |
|
|
f.write(f'# RANK={os.environ.get("RANK", -1)}\n') |
|
|
f.write(f'# LOCAL_RANK={os.environ.get("LOCAL_RANK", -1)}\n') |
|
|
f.write(f'# WORLD_SIZE={os.environ.get("WORLD_SIZE", -1)}\n') |
|
|
            f.write('cd ' + script_dir + '\n')
|
|
f.write(' \\\n'.join([str(sys.executable)]+sys.argv)) |
|
|
def Logger(*content): |
|
|
if not ddp or dist.get_rank() == 0: |
|
|
print(*content) |
|
|
def init_distributed_mode(ddp=True): |
|
|
print("init distributed mode,ddp=",ddp) |
|
|
|
|
|
    if not ddp: return None, None
|
|
global ddp_local_rank, DEVICE |
|
|
dist.init_process_group(backend="nccl") |
|
|
ddp_rank = int(os.environ["RANK"]) |
|
|
ddp_local_rank = int(os.environ["LOCAL_RANK"]) |
|
|
ddp_world_size = int(os.environ["WORLD_SIZE"]) |
|
|
DEVICE = f"cuda:{ddp_local_rank}" |
|
|
torch.cuda.set_device(DEVICE) |
|
|
|
|
|
print('init distributed mode, ddp_rank:', ddp_rank, 'ddp_local_rank:', ddp_local_rank, 'ddp_world_size:', ddp_world_size) |
|
|
return ddp_local_rank,DEVICE |
|
|
def all_gather(tensor, world_size): |
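    """Gather `tensor` from every rank and concatenate along dim 0 (assumes identical shapes across ranks)."""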
|
|
tensor_list = [torch.zeros_like(tensor) for _ in range(world_size)] |
|
|
torch.distributed.all_gather(tensor_list, tensor) |
|
|
return torch.cat(tensor_list) |
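# defaultdict with attribute-style access, so model outputs can be read as
# res.logits / res.te while missing keys still fall back to the default factory:
#   d = DotDefaultDict(list); d.preds = [1]; print(d.preds)  # [1]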
|
|
|
|
|
class DotDefaultDict(defaultdict): |
|
|
def __getattr__(self, name): |
|
|
if name in self: |
|
|
return self[name] |
|
|
|
|
|
return super().__getattribute__(name) |
|
|
|
|
|
__setattr__ = defaultdict.__setitem__ |
|
|
__delattr__ = defaultdict.__delitem__ |
|
|
def metric_monitor(all_preds, all_labels, epoch, prefix, start_time, ddp=None, wandb=None, loss_fct=None, Logger=None, cls='regression'):
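    """Concatenate per-step predictions and labels, mask padding, compute loss plus task metrics, and log them under `prefix`."""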
|
|
|
|
|
if isinstance(all_preds, list): |
|
|
all_preds = torch.cat(all_preds) |
|
|
all_labels = torch.cat(all_labels) |
|
|
|
|
|
    '''Data is fully prepared in the dataloader: TR is sharded across GPUs, while VL and TS run on every GPU.'''
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # mask out padding positions (fairseq Dictionary pad index is 1)
    loss_mask = all_labels != 1
    all_preds = all_preds[loss_mask]
    all_labels = all_labels[loss_mask]
|
|
epoch_loss = loss_fct(all_preds, all_labels) |
|
|
|
|
|
all_preds = all_preds.cpu().numpy() |
|
|
all_labels = all_labels.cpu().numpy() |
|
|
epoch_loss = epoch_loss.cpu().item() |
|
|
ans = compute_metrics_dict(all_preds, all_labels,cls=cls) |
|
|
|
|
|
wandb_ans = dict(zip([f"{prefix}_{k}" for k in ans.keys()], ans.values())) |
|
|
wandb_ans[f"{prefix}_epoch_loss"] = epoch_loss |
|
|
if wandb: |
|
|
wandb.log(wandb_ans) |
|
|
|
|
|
|
|
|
return wandb_ans |
|
|
|
|
|
|
|
|
class maotao(): |
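    """Training / evaluation / prediction pipeline for the AA2CDS (codon generation) downstream task."""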
|
|
def step_loss(self,target_nn=None,tgt_te=None, |
|
|
masked_logits_list=None,nn_prob=None, |
|
|
res=None,loss_fct=None,loss_mse=None,args=None): |
|
|
        loss_mask = target_nn != 1  # ignore padding positions
        tgt_te = tgt_te.view(-1)
        te_loss = loss_mse(res.te, tgt_te)

        # codon cross-entropy: one term per masked-logits variant plus one on the raw logits
        loss = torch.tensor(0, dtype=torch.float32, device=args.device)
        for i in range(masked_logits_list.size(1)):
            masked_logits = masked_logits_list[:, i, ...]
            loss += self.calculate_loss(res.logits + nn_prob + masked_logits, loss_mask, target_nn, loss_fct)
        loss += self.calculate_loss(res.logits, loss_mask, target_nn, loss_fct)

        # total objective: codon CE + TE regression + MoE auxiliary loss
        loss = loss + te_loss + res.aux_loss
        return loss
|
|
def forward_step(self,model=None,src_data=None,twod_data=None,aa_idx=None,continuous_features=None,species_features=None,truncated_features=None,target_nn=None,tgt_te=None,masked_logits_list=None,nn_prob=None,loss_fct=None,loss_mse=None,args=None): |
|
|
res = model(input_ids=src_data, |
|
|
twod_tokens=twod_data, |
|
|
aa_idx=aa_idx, |
|
|
continuous_features=continuous_features, |
|
|
species_features=species_features, |
|
|
truncated_features=truncated_features, |
|
|
|
|
|
) |
|
|
|
|
|
res.te = res.te.view(-1) |
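        # zero entries in nn_prob mark disallowed tokens; adding -inf removes them from the softmax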
|
|
nn_prob = torch.masked_fill(nn_prob, mask=nn_prob == 0, value=float('-inf')) |
|
|
masked_logits = masked_logits_list[:,-1,...] |
|
|
res.logits = res.logits + nn_prob |
|
|
ans = dict() |
|
|
step_loss = self.step_loss(target_nn=target_nn,tgt_te=tgt_te, |
|
|
masked_logits_list=masked_logits_list,nn_prob=nn_prob, |
|
|
res=res,loss_fct=loss_fct,loss_mse=loss_mse,args=args) |
|
|
ans.update({'loss':step_loss}) |
|
|
res.logits = res.logits + masked_logits |
|
|
ans.update({'res':res}) |
|
|
return ans |
|
|
def train_epoch(self,model=None,wandb=None,ddp=None,train_loader=None,optimizer=None, |
|
|
epoch=None,prefix="TR",loss_fct=None,loss_mse=None,args=None,scaler=None,Logger=None,lr_scheduler=None): |
|
|
model.train() |
|
|
epoch_loss = 0 |
|
|
|
|
|
all_preds, all_labels = defaultdict(list), defaultdict(list) |
|
|
|
|
|
start_time = time.time() |
|
|
|
|
|
ans = {} |
|
|
|
|
|
for step,(src_data, twod_data, aa_idx,continuous_features, species_features, truncated_features, \ |
|
|
target_nn, target, masked_logits_list,nn_prob,maotao_id) in enumerate(train_loader): |
|
|
src_data = src_data.to(args.device) |
|
|
twod_data = twod_data.to(args.device) |
|
|
aa_idx = aa_idx.to(args.device) |
|
|
continuous_features = continuous_features.to(args.device) |
|
|
species_features = species_features.to(args.device) |
|
|
truncated_features = truncated_features.to(args.device) |
|
|
|
|
|
masked_logits_list = masked_logits_list.to(args.device) |
|
|
nn_prob = nn_prob.to(args.device) |
|
|
target_nn = target_nn.to(args.device) |
|
|
tgt_te = target.to(args.device) |
|
|
|
|
|
with torch.cuda.amp.autocast(enabled=scaler.is_enabled()): |
|
|
results = self.forward_step(model=model,src_data=src_data, |
|
|
twod_data=twod_data,aa_idx=aa_idx, |
|
|
continuous_features=continuous_features, |
|
|
species_features=species_features, |
|
|
truncated_features=truncated_features, |
|
|
target_nn=target_nn,tgt_te=tgt_te, |
|
|
masked_logits_list=masked_logits_list, |
|
|
nn_prob=nn_prob,loss_fct=loss_fct, |
|
|
loss_mse=loss_mse,args=args) |
|
|
res, loss = results['res'], results['loss'] |
|
|
|
|
|
scaler.scale(loss).backward() |
|
|
scaler.step(optimizer) |
|
|
scaler.update() |
|
|
optimizer.zero_grad() |
|
|
|
|
|
|
|
|
Logger('Epoch:[{}/{}][batch:{}/{}] loss:{:.4f}'.format(epoch, args.epochs, step, len(train_loader), loss.item())) |
|
|
epoch_loss += loss.item() |
|
|
|
|
|
|
|
|
|
|
|
all_preds['logits'].append(res.logits.detach()) |
|
|
all_preds['te'].append(res.te.detach()) |
|
|
all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device)) |
|
|
|
|
|
all_labels['logits'].append(target_nn.detach()) |
|
|
all_labels['te'].append(tgt_te.detach()) |
|
|
|
|
|
            if step % 100 == 1:  # periodic in-epoch metric logging
|
|
ans = {} |
|
|
|
|
|
ans.update(metric_monitor(all_preds['logits'], all_labels['logits'], epoch, prefix+'_codon',start_time,ddp=ddp,wandb=wandb,loss_fct=loss_fct,Logger=Logger,cls='identity')) |
|
|
|
|
|
ans.update({f'{prefix}_loss':epoch_loss}) |
|
|
all_preds, all_labels = defaultdict(list), defaultdict(list) |
|
|
|
|
|
epoch_loss = epoch_loss / len(train_loader) if len(train_loader) > 0 else 0 |
|
|
Logger('\n'+'#'*10+f' {prefix}_loss: {epoch_loss:.4f} '+'#'*10+'\n') |
|
|
|
|
|
return ans |
|
|
def Maotao_DataLoader(self,file_path, args, tokenizer, data_tag='TR', ddp=None, Logger=None, returnid=None): |
|
|
if args.debug: args.limit = 320 |
|
|
train_ds = MaotaoDataset(file_path, tokenizer, args=args, limit=args.limit, seq_len=args.seq_len,returnid=returnid,codon_table_path=args.codon_table_path) |
|
|
train_sampler = DistributedSampler(train_ds) if ddp else None |
|
|
drop_last = True if ddp and data_tag == 'TR' else False |
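        # shuffle=False throughout: under DDP the DistributedSampler shuffles TR per epoch
        # (sampler.set_epoch is called in the training loop); VL/TS run in full on every rank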
|
|
        train_loader = DataLoader(
            train_ds,
            batch_size=args.batch_size,
            pin_memory=True,
            drop_last=drop_last,
            shuffle=False,
            num_workers=args.num_workers,
            sampler=train_sampler if data_tag == 'TR' else None,
        )
|
|
        Logger('loading data from', file_path, args.seq_len, args.column, args.label, len(train_loader.dataset))
|
|
|
|
|
|
|
|
return train_loader |
|
|
|
|
|
def load_data(self,args=None,task='sft',tokenizer=None,ddp=None,Logger=None): |
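        """Build TR/VL/TS loaders for `task`; missing splits yield None (the TS split is mandatory)."""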
|
|
|
|
|
csv_path = os.path.join(args.downstream_data_path, task, "{}.csv") |
|
|
check_file_flag = [os.access(csv_path.format(tag),os.F_OK) for tag in ['TR', 'VL', 'TS']] |
|
|
loaders = [] |
|
|
for flag,tag in zip(check_file_flag,['TR', 'VL', 'TS']): |
|
|
if flag: |
|
|
train_loader = self.Maotao_DataLoader(os.path.join(csv_path.format(tag)), args, |
|
|
tokenizer, data_tag=tag, ddp=ddp, |
|
|
Logger=Logger) |
|
|
else: |
|
|
print(f'Warning: {tag} data not found, skip it.',csv_path.format(tag)) |
|
|
train_loader = None |
|
|
loaders.append(train_loader) |
|
|
|
|
|
        assert check_file_flag[-1], f'Please check the dataset path; at least a TS split is required: {os.path.abspath(csv_path)}, {csv_path.format("TS")}'
|
|
return loaders |
|
|
|
|
|
def calculate_loss(self,logits, loss_mask, tgt_data, loss_ce): |
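        """Masked cross-entropy: select valid positions with loss_mask, flattening logits to (n_valid, vocab) and targets to (n_valid,)."""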
|
|
|
|
|
|
|
|
|
|
|
        loss_mask = loss_mask == 1  # ensure a boolean mask
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
        loss = loss_ce(
            torch.masked_select(logits, loss_mask.unsqueeze(-1).repeat(1, 1, logits.size(-1))).view(-1, logits.size(-1)),
            torch.masked_select(tgt_data, loss_mask).view(-1)
        ).mean()
|
|
return loss |
|
|
def evaluate(self,model, data_loader,stage='epoch', prefix="VL",args=None,loss_fct=None, |
|
|
loss_mse=None,wandb=None,ddp=None,ctx=None,Logger=None,fpred_out=None,tokenizer=None): |
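        """Evaluate on `data_loader`; in predict mode, additionally decode and write per-sample predictions to fpred_out."""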
|
|
start_time = time.time() |
|
|
epoch_loss = 0 |
|
|
model.eval() |
|
|
|
|
|
all_preds, all_labels = defaultdict(list), defaultdict(list) |
|
|
all_num = args.batch_size * len(data_loader) |
|
|
with torch.no_grad(): |
|
|
for idx,(src_data, twod_data, aa_idx, continuous_features, species_features, truncated_features, \ |
|
|
target_nn, target, masked_logits_list, nn_prob, maotao_id) in enumerate(data_loader): |
|
|
src_data = src_data.to(args.device) |
|
|
twod_data = twod_data.to(args.device) |
|
|
aa_idx = aa_idx.to(args.device) |
|
|
continuous_features = continuous_features.to(args.device) |
|
|
species_features = species_features.to(args.device) |
|
|
truncated_features = truncated_features.to(args.device) |
|
|
|
|
|
masked_logits_list = masked_logits_list.to(args.device) |
|
|
|
|
|
nn_prob = nn_prob.to(args.device) |
|
|
target_nn = target_nn.to(args.device) |
|
|
tgt_te = target.to(args.device) |
|
|
|
|
|
results = self.forward_step(model=model,src_data=src_data, |
|
|
twod_data=twod_data,aa_idx=aa_idx, |
|
|
continuous_features=continuous_features, |
|
|
species_features=species_features, |
|
|
truncated_features=truncated_features, |
|
|
target_nn=target_nn,tgt_te=tgt_te, |
|
|
masked_logits_list=masked_logits_list, |
|
|
nn_prob=nn_prob,loss_fct=loss_fct, |
|
|
loss_mse=loss_mse,args=args) |
|
|
res, loss = results['res'], results['loss'] |
|
|
epoch_loss += loss.item() |
|
|
|
|
|
all_preds['logits'].append(res.logits.detach()) |
|
|
all_preds['te'].append(res.te.reshape(-1).detach()) |
|
|
all_preds['aux_loss'].append(torch.tensor(res.aux_loss, dtype=torch.float32).reshape(1, 1).to(args.device)) |
|
|
|
|
|
all_labels['logits'].append(target_nn) |
|
|
all_labels['te'].append(tgt_te.reshape(-1)) |
|
|
|
|
|
if args.predict and (not ddp or dist.get_rank() == 0): |
|
|
                    pred_logits = res.logits.detach()
|
|
pred_te = res.te.reshape(-1).detach() |
|
|
                    for idj, (_id, logits, nn, te) in enumerate(zip(maotao_id, pred_logits, target_nn, pred_te)):
|
|
pred_cai = te.item() |
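                        # temperature sampling over the codon logits (fixed temperature of 0.8)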
|
|
|
|
|
temperature = 0.8 |
|
|
probs = torch.softmax(logits / temperature, dim=-1) |
|
|
probs = torch.nan_to_num(probs, nan=1e-9) |
|
|
tokens = torch.multinomial(probs, num_samples=1) |
|
|
tokens = tokens.squeeze(-1) |
|
|
tokens_hard = logits.argmax(1) |
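                        # keep sampled tokens only where the greedy token index is non-zero, then map U -> T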
|
|
pred_nn = ''.join([tokenizer.symbols[x] for x,y in zip(tokens.cpu().numpy(), tokens_hard.cpu().numpy()) if y]).replace( |
|
|
'U', 'T') |
|
|
                        ans = metric_monitor(logits, nn, stage, '',
|
|
start_time, ddp=ddp, |
|
|
wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity') |
|
|
df = pd.DataFrame([ans]) |
|
|
df['maotao_id'] = _id |
|
|
df['pred_cai'] = pred_cai |
|
|
df['pred_nn'] = pred_nn |
|
|
if os.access(fpred_out, os.F_OK): |
|
|
df.to_csv(fpred_out, mode='a', header=False, index=False) |
|
|
else: |
|
|
df.to_csv(fpred_out, mode='w', header=True, index=False) |
|
|
Logger(f'{args.batch_size*idx+idj}/{data_loader.dataset.__len__()},fpred_out:{fpred_out}') |
|
|
|
|
|
epoch_loss /=len(data_loader) |
|
|
if (wandb is not None) and (not ddp or dist.get_rank() == 0): |
|
|
wandb.log({f'{prefix}_epoch_loss': epoch_loss}) |
|
|
|
|
|
ans = metric_monitor(all_preds['logits'], all_labels['logits'], stage, prefix + '_logits', start_time, ddp=ddp, |
|
|
wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='binary') |
|
|
ans.update( |
|
|
metric_monitor(all_preds['logits'], all_labels['logits'], stage, prefix + '_codon', start_time, ddp=ddp, |
|
|
wandb=wandb, loss_fct=loss_fct, Logger=Logger, cls='identity')) |
|
|
        if not args.predict:
            ans.update(
                metric_monitor(all_preds['te'], all_labels['te'], stage, prefix + '_cai', start_time, ddp=ddp,
                               wandb=wandb, loss_fct=loss_mse, Logger=Logger, cls='regression'))
|
|
ans.update({f'{prefix}_loss': epoch_loss}) |
|
|
|
|
|
wandb_ans = ans |
|
|
|
|
|
if not ddp or dist.get_rank() == 0: |
|
|
if wandb: |
|
|
wandb.log(wandb_ans) |
|
|
|
|
|
|
|
|
return wandb_ans |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def init_model(args,ckp = f'./out/full_dist_256_epoch.pth',lm_config=None,tokenizer=None,Logger=None,require_ckp=False): |
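    """Build MiniMindLM_Maotao, optionally load weights from `ckp` (strict=False), and freeze transformer layers when fine-tuning."""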
|
|
if args.debug:lm_config.n_layers = 1 |
|
|
model = MiniMindLM_Maotao(lm_config) |
|
|
print(model) |
|
|
print(lm_config) |
|
|
print('ckp=',ckp) |
|
|
if ckp is not None and os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0: |
|
|
print('loading model from', ckp) |
|
|
state_dict = torch.load(ckp, map_location=args.device) |
|
|
model.load_state_dict(state_dict, strict=False) |
|
|
        print(f'finetune ({args.finetune}) from {os.path.abspath(ckp)}')
|
|
    else:
        if require_ckp:
            sys.exit(f'checkpoint not found: {ckp}')
|
|
|
|
|
|
|
|
if args.finetune: |
|
|
for name, value in model.named_parameters(): |
|
|
if 'layers' in name: |
|
|
value.requires_grad = False |
|
|
print(name, value.numel(), value.requires_grad) |
|
|
else: |
|
|
for name, value in model.named_parameters(): |
|
|
print(name, value.numel(), value.requires_grad) |
|
|
    print(
        f'LLM parameters trainable/total: {sum(p.numel() for p in model.parameters() if p.requires_grad) / 1e6:.3f}M / {sum(p.numel() for p in model.parameters()) / 1e6:.3f}M')
    print(
        f'LLM parameters trainable/total: {sum(p.numel() for p in model.parameters() if p.requires_grad)} / {sum(p.numel() for p in model.parameters())}')
|
|
|
|
|
model = model.to(args.device) |
|
|
|
|
|
|
|
|
return model, tokenizer,lm_config |
|
|
|
|
|
|
|
|
def save_metrics(ans, epoch, out_dir, filename='history_metrics.csv'): |
|
|
if not ddp or ddp_local_rank == 0: |
|
|
df = pd.DataFrame([ans]) |
|
|
df['epoch'] = epoch |
|
|
df['path'] = out_dir |
|
|
if epoch == 0: |
|
|
df.to_csv(os.path.join(out_dir, filename), mode='w') |
|
|
elif epoch == 'TS': |
|
|
df = df.T |
|
|
df.to_csv(os.path.join(out_dir, filename), mode='w') |
|
|
else: |
|
|
df.to_csv(os.path.join(out_dir, filename), mode='a', header=False) |
|
|
|
|
|
def sft_process_maotao(max_seq_len=-1,ctx=None,ddp=False,ddp_local_rank=0,args=None,ckp=None,out_ckp=None,lm_config=None, |
|
|
tokenizer=None,Logger=None,task=None,seq_pkl_path=None,sft=None,require_ckp=False): |
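    """Fine-tune (or zero-shot evaluate / predict) on a downstream task; returns (ckp, final_metrics, last_epoch)."""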
|
|
print('3. fine-tune on downstream tasks...') |
|
|
print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) |
|
|
start = time.time() |
|
|
|
|
|
|
|
|
model, tokenizer, lm_config_student = init_model(args,ckp=ckp,lm_config=lm_config,tokenizer=tokenizer,Logger=Logger,require_ckp=require_ckp) |
|
|
|
|
|
if ddp: |
|
|
model._ddp_params_and_buffers_to_ignore = {"pos_cis"} |
|
|
model = DistributedDataParallel(model, device_ids=[ddp_local_rank],find_unused_parameters=True) |
|
|
|
|
|
|
|
|
train_loader, val_loader, test_loader = sft.load_data(args=args,tokenizer=tokenizer, |
|
|
task=task,ddp=ddp,Logger=Logger) |
|
|
|
|
|
if train_loader: |
|
|
args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{len(train_loader.dataset)}" |
|
|
else: |
|
|
args.wandb_run_name = f"{args.wandb_project}_EP_{args.epochs}_BS_{args.batch_size}_LR_{args.learning_rate}_FT_{args.finetune}_TR_{0}" |
|
|
|
|
|
|
|
|
if args.use_wandb and (not ddp or ddp_local_rank == 0): |
|
|
import wandb |
|
|
wandb.init(project=args.wandb_project, name=args.wandb_run_name, mode="offline",config=args,anonymous="allow") |
|
|
|
|
|
Logger(f'init wandb with id {wandb.run.id}') |
|
|
else: |
|
|
wandb = None |
|
|
|
|
|
loss_ce = nn.CrossEntropyLoss() |
|
|
loss_mse = nn.MSELoss() |
|
|
|
|
|
if args.predict: |
|
|
fpred_out = args.out_dir + f'/{args.task}/TS_pred.csv' |
|
|
os.system(f'rm -f {fpred_out}') |
|
|
os.makedirs(os.path.dirname(fpred_out), exist_ok=True) |
|
|
Logger(f'predicting {fpred_out} with {os.path.abspath(ckp)}') |
|
|
final_metrics = sft.evaluate(model, test_loader, stage='predict', prefix='TS', args=args, loss_fct=loss_ce,loss_mse=loss_mse, wandb=wandb, ddp=ddp, |
|
|
ctx=ctx, Logger=Logger,fpred_out=fpred_out,tokenizer=tokenizer) |
|
|
return ckp, final_metrics, -1 |
|
|
|
|
|
final_metrics = None |
|
|
if not ddp or ddp_local_rank == 0: |
|
|
        Logger('predict and saving to file...', os.path.join(args.out_dir, 'zeroshot_metrics.csv'))
|
|
for tag,loader in zip(['TR','VL','TS'],[train_loader,val_loader,test_loader]): |
|
|
            if tag == 'TR': continue  # zero-shot evaluation only covers VL and TS
|
|
if loader is not None: |
|
|
                fpred_out = args.out_dir + f'/{tag}_pred.csv'
|
|
os.system(f'rm -f {fpred_out}') |
|
|
                final_metrics = sft.evaluate(model, loader, stage='predict', prefix=tag, args=args, loss_fct=loss_ce, loss_mse=loss_mse, wandb=wandb, ddp=ddp, ctx=ctx, Logger=Logger, fpred_out=fpred_out, tokenizer=tokenizer)
|
|
save_metrics(final_metrics,tag,args.out_dir,filename=f"zeroshot_metrics.csv") |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # NOTE: this early return ends the zero-shot path; the fine-tuning loop below runs only if it is removed
    return ckp, final_metrics, -1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
with open(os.path.join(args.out_dir, f"{args.wandb_run_name}.params"), 'w') as f: |
|
|
f.write(f'#args={args}' + '\n') |
|
|
f.write(f'#model={model}' + '\n') |
|
|
f.write(f'#lm_config={lm_config}' + '\n') |
|
|
f.write(f'#tokenizer={tokenizer.indices}' + '\n') |
|
|
|
|
|
scaler = torch.cuda.amp.GradScaler(enabled=(args.dtype in ['float16', 'bfloat16'])) |
|
|
|
|
|
lr_scheduler = None |
|
|
optimizer = optim.AdamW(model.parameters(), lr=args.learning_rate) |
|
|
|
|
|
if out_ckp is not None: |
|
|
ckp = out_ckp |
|
|
else: |
|
|
ckp = f'{args.save_dir}/{args.wandb_run_name}.pth' |
|
|
|
|
|
|
|
|
early_stopping = EarlyStopping(patience=3, verbose=True, path=ckp) |
|
|
early_stopping.save_model(model, ckp) |
|
|
|
|
|
|
|
|
epoch = 0 |
|
|
for epoch in range(args.epochs): |
|
|
if ddp: train_loader.sampler.set_epoch(epoch) |
|
|
        if epoch == 0:
|
|
Logger(f'iterating over {len(train_loader)} batches per epoch, {len(train_loader.dataset)} samples per epoch, batch_size={train_loader.batch_size}, gpu_num={torch.cuda.device_count()}') |
|
|
|
|
|
|
|
|
ans = sft.train_epoch(model=model,wandb=wandb, |
|
|
train_loader=train_loader, prefix="TR", |
|
|
optimizer=optimizer, ddp=ddp,epoch=epoch, |
|
|
loss_fct=loss_ce,loss_mse=loss_mse, |
|
|
scaler=scaler,args=args,Logger=Logger,lr_scheduler=lr_scheduler) |
|
|
ans_vl = sft.evaluate(model, val_loader, stage=f'{epoch}',prefix="VL",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer) |
|
|
ans.update(ans_vl) |
|
|
|
|
|
val_mse = ans['VL_logits_epoch_loss'] |
|
|
save_metrics(ans,epoch,args.out_dir,filename=f"{args.wandb_run_name}_history_metrics.csv") |
|
|
|
|
|
if ddp: |
|
|
|
|
|
if ddp_local_rank == 0: |
|
|
early_stopping(val_mse, model) |
|
|
|
|
|
to_broadcast = torch.tensor([early_stopping.early_stop], dtype=torch.bool, device=args.device) |
|
|
dist.broadcast(to_broadcast, 0) |
|
|
else: |
|
|
|
|
|
|
|
|
to_broadcast = torch.tensor([False], dtype=torch.bool, device=args.device) |
|
|
dist.broadcast(to_broadcast, 0) |
|
|
early_stopping.early_stop = bool(to_broadcast.item()) |
|
|
else: |
|
|
|
|
|
early_stopping(val_mse, model) |
|
|
        if early_stopping.early_stop: break
|
|
|
|
|
if os.access(ckp, os.F_OK) and os.path.getsize(ckp) > 0: |
|
|
state_dict = torch.load(ckp, map_location=args.device) |
|
|
model.load_state_dict(state_dict, strict=False) |
|
|
Logger("Loaded best model for final evaluation.") |
|
|
    else:
        early_stopping.save_model(model, ckp)
|
|
final_metrics = sft.evaluate(model, test_loader,stage='final', prefix="TS",args=args,loss_fct=loss_ce,loss_mse=loss_mse,wandb=wandb,ddp=ddp,ctx=ctx,Logger=Logger,tokenizer=tokenizer) |
|
|
save_metrics(final_metrics, 'TS', args.out_dir, filename=f"{args.wandb_run_name}_final_metrics.csv") |
|
|
|
|
|
|
|
|
Logger('final evaluation on TS') |
|
|
|
|
|
if args.use_wandb and (not ddp or ddp_local_rank == 0):wandb.finish() |
|
|
    Logger(f'the end of experiment sft process: time = {(time.time() - start) // 60:.0f} min')
|
|
return ckp,final_metrics,epoch |
|
|
|
|
|
def trainable(config): |
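    """Ray Tune trainable: applies `config` to the global args and runs one fine-tuning trial, reporting TS codon identity."""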
|
|
|
|
|
sft = maotao() |
|
|
|
|
|
|
|
|
args.batch_size = config["batch_size"] |
|
|
args.learning_rate = config["learning_rate"] |
|
|
max_seq_len = args.max_seq_len |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ckp, final_metrics, _ = sft_process_maotao(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp, |
|
|
ddp_local_rank=ddp_local_rank, |
|
|
args=args, ckp=in_ckp, out_ckp=out_ckp, |
|
|
lm_config=lm_config, tokenizer=tokenizer, Logger=Logger,task=task, |
|
|
sft=sft) |
|
|
|
|
|
|
|
|
return {"identity": final_metrics['TS_codon_identity_codon']} |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_ray_tune(): |
|
|
    Logger('hello world')
|
|
|
|
|
from ray import tune |
|
|
|
|
|
def objective(config): |
|
|
score = config["a"] ** 2 + config["b"] |
|
|
return {"score": score} |
|
|
|
|
|
|
|
|
search_space = { |
|
|
"a": tune.grid_search([0.001, 0.01, 0.1, 1.0,0.4]), |
|
|
|
|
|
"b": tune.choice([1, 2, 3]), |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
    tuner = tune.Tuner(objective, param_space=search_space)
|
|
|
|
|
results = tuner.fit() |
|
|
Logger(results.get_best_result(metric="score", mode="max").config) |
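# A minimal sketch of wiring `trainable` into a real Ray Tune search (assumes
# ray[tune] is installed; the search values below are illustrative placeholders,
# not tuned settings):
#
#   from ray import tune
#   tuner = tune.Tuner(
#       trainable,
#       param_space={
#           "batch_size": tune.choice([16, 32, 64]),
#           "learning_rate": tune.loguniform(1e-5, 1e-3),
#       },
#       tune_config=tune.TuneConfig(metric="identity", mode="max", num_samples=8),
#   )
#   best = tuner.fit().get_best_result(metric="identity", mode="max")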
|
|
|
|
|
def init_config(vocab_path,n_layers,max_seq_len): |
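    """Load the fairseq vocabulary, alias T -> U and _ -> pad, size the LMaoTaoConfig, then append amino-acid symbols."""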
|
|
tokenizer = Dictionary.load(vocab_path) |
|
|
tokenizer.mask_index = tokenizer.add_symbol('<mask>') |
|
|
tokenizer.indices['T'] = tokenizer.indices['U'] |
|
|
tokenizer.indices['_'] = tokenizer.pad_index |
|
|
lm_config = LMaoTaoConfig(dim=256, logit_dim=len(tokenizer),n_layers=n_layers, max_seq_len=max_seq_len, vocab_size=len(tokenizer),padding_idx=tokenizer.pad_index) |
|
|
|
|
|
    for word in AA_str:
        tokenizer.add_symbol(word)  # amino-acid symbols are added after the config is sized
|
|
return lm_config,tokenizer |
|
|
|
|
|
def seed_everything(seed=2022): |
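    """Seed all RNGs (Python, NumPy, torch, CUDA); deterministic cuDNN trades speed for reproducibility."""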
|
|
print('seed_everything to ',seed) |
|
|
random.seed(seed) |
|
|
os.environ['PYTHONHASHSEED'] = str(seed) |
|
|
np.random.seed(seed) |
|
|
torch.manual_seed(seed) |
|
|
torch.cuda.manual_seed(seed) |
|
|
torch.backends.cudnn.deterministic = True |
|
|
torch.backends.cudnn.benchmark = False |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__': |
|
|
print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) |
|
|
start = time.time() |
|
|
|
|
|
|
|
|
|
|
sft = maotao() |
|
|
parser = get_pretraining_args() |
|
|
args = parser.parse_args() |
|
|
args.downstream_data_path = 'maotao_file/' |
|
|
args.seed = int(args.seed) |
|
|
|
|
|
|
|
|
seed_everything(seed=args.seed) |
|
|
if args.predict: |
|
|
task= args.task |
|
|
else: |
|
|
task = 'AA2CDS_data' |
|
|
device_type = "cuda" if "cuda" in args.device else "cpu" |
|
|
ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast() |
|
|
|
|
|
|
|
|
ddp_local_rank, DEVICE = 0, "cuda:0" |
|
|
if ddp: |
|
|
print('init distributed mode') |
|
|
ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp) |
|
|
args.device = torch.device(DEVICE) |
|
|
Logger('args.device:',args.device) |
|
|
Logger('setting args',args) |
|
|
max_seq_len = 1200 |
|
|
args.seq_len = max_seq_len |
|
|
|
|
|
    args.save_dir = args.out_dir
|
|
|
|
|
os.makedirs(args.save_dir, exist_ok=True) |
|
|
os.makedirs(args.out_dir, exist_ok=True) |
|
|
tokens_per_iter = args.batch_size * max_seq_len |
|
|
|
|
|
lm_config,tokenizer = init_config(args.arg_overrides['data'] + '/small_dict.txt', args.n_layers, max_seq_len) |
|
|
lm_config.use_moe = args.use_moe |
|
|
wandb_project = args.wandb_project |
|
|
|
|
|
'''3. benchmark downstream tasks''' |
|
|
prefix = 'TS' |
|
|
|
|
|
|
|
|
epochs = args.epochs |
|
|
args.out_dir = os.path.abspath(args.out_dir) |
|
|
os.makedirs(args.out_dir, exist_ok=True) |
|
|
model_dir = args.out_dir |
|
|
model_dir = os.path.abspath(model_dir) |
|
|
|
|
|
data_dir = args.downstream_data_path |
|
|
data_dir = os.path.abspath(data_dir) |
|
|
|
|
|
Logger(f'model_dir:{model_dir}') |
|
|
|
|
|
os.makedirs(model_dir, exist_ok=True) |
|
|
args.save_dir = os.path.abspath(args.save_dir) |
|
|
args.downstream_data_path = os.path.abspath(args.downstream_data_path) |
|
|
|
|
|
Logger('args.downstream_data_path:', args.downstream_data_path) |
|
|
    out_ckp = os.path.abspath(os.path.join(args.save_dir, 'AA2CDS.pth'))

    in_ckp = os.path.join(model_dir, 'AA2CDS.pth')
|
|
|
|
|
|
|
|
|
|
|
    # a single configuration is run directly here; swap in tune.choice /
    # tune.loguniform values and a tune.Tuner to launch a real search
    search_space = {
        "batch_size": args.batch_size,
        "learning_rate": args.learning_rate,
    }
    trainable(search_space)
|
|
|
|
|
print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime())) |
|
|
print('time', time.time() - start) |
|
|
|