"""
Title     : maotao_inference.py
Project   : minimind_RiboUTR
Created by: julse
Created on: 2025/10/23 16:49
des       : Inference entry point for the AA2CDS downstream task of
            minimind_RiboUTR; loads an MLM-pretrained checkpoint and runs
            sft_process_maotao in predict mode, optionally under DDP.
"""

import os
import sys
import time
from contextlib import nullcontext

import numpy as np
import pandas as pd
import torch
import torch.distributed as dist

from model.tools import get_pretraining_args, find_unused_parameters
from train import sft_process_maotao, init_config, maotao

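# Distributed launchers such as torchrun export RANK (plus LOCAL_RANK and
# WORLD_SIZE); the presence of RANK is used below to detect a DDP run.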
ddp = int(os.environ.get("RANK", -1)) != -1
print('Setting running environment')


def Logger(*content):
    """Print only from the main process (rank 0) when running under DDP."""
    if not ddp or dist.get_rank() == 0:
        print(*content)


def init_distributed_mode(ddp=True):
    print('init distributed mode, ddp=', ddp)

    if not ddp:
        return
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)

    print('init distributed mode, ddp_rank:', ddp_rank, 'ddp_local_rank:', ddp_local_rank,
          'ddp_world_size:', ddp_world_size)
    return ddp_local_rank, DEVICE

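# Note: with ddp=False this function returns None; inference() below only
# calls it when ddp is True, so the (ddp_local_rank, DEVICE) unpacking is safe.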

def inference(args):
    sft = maotao()

    # Choose the downstream task; default to AA2CDS when not in predict mode.
    if args.predict:
        task = args.task
    else:
        task = 'AA2CDS_data'

    device_type = "cuda" if "cuda" in args.device else "cpu"
    # Mixed-precision autocast on GPU, a no-op context on CPU.
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()

    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        print('init distributed mode')
        ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp)
    args.device = torch.device(DEVICE)
    Logger('args.device:', args.device)
    Logger('setting args', args)

    max_seq_len = 1200
    args.seq_len = max_seq_len

    args.save_dir = args.out_dir
    os.makedirs(args.save_dir, exist_ok=True)
    os.makedirs(args.out_dir, exist_ok=True)
    tokens_per_iter = args.batch_size * max_seq_len

    # Build the model config and tokenizer from the vocabulary file.
    lm_config, tokenizer = init_config(args.arg_overrides['data'] + '/small_dict.txt',
                                       args.n_layers, max_seq_len)
    lm_config.use_moe = args.use_moe
    wandb_project = args.wandb_project

    '''3. benchmark downstream tasks'''
    prefix = 'TS'

    epochs = args.epochs
    args.out_dir = os.path.abspath(args.out_dir)
    os.makedirs(args.out_dir, exist_ok=True)
    model_dir = os.path.abspath(args.out_dir)
    data_dir = args.downstream_data_path
    Logger(f'model_dir:{model_dir}')

    os.makedirs(model_dir, exist_ok=True)
    args.save_dir = os.path.abspath(args.save_dir)
    args.downstream_data_path = os.path.abspath(args.downstream_data_path)

    Logger('args.downstream_data_path:', args.downstream_data_path)
    out_ckp = os.path.abspath(os.path.join(args.save_dir, 'AA2CDS.pth'))

    # Run fine-tuning / prediction starting from the MLM-pretrained checkpoint.
    in_ckp = args.mlm_pretrained_model_path
    ckp, final_metrics, _ = sft_process_maotao(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
                                               ddp_local_rank=ddp_local_rank,
                                               args=args, ckp=in_ckp, out_ckp=out_ckp,
                                               lm_config=lm_config, tokenizer=tokenizer,
                                               Logger=Logger, task=task,
                                               sft=sft, require_ckp=True)

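# The returned ckp and final_metrics are not consumed further in this script;
# prediction output handling happens inside sft_process_maotao.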

if __name__ == '__main__':
    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()

    parser = get_pretraining_args()
    args = parser.parse_args()

    # Example configuration for AA2CDS prediction.
    args.downstream_data_path = 'maotao_file/'
    args.task = 'AA2CDS_data'
    args.predict = True
    args.mlm_pretrained_model_path = 'checkpoint/AA2CDS.pth'
    args.out_dir = 'example/out_TR_TS'

    inference(args)
    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)
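
# Hypothetical launch examples (flags assumed, not taken from the repo docs):
#   python maotao_inference.py                        # single-process run
#   torchrun --nproc_per_node=2 maotao_inference.py   # sets RANK -> DDP branch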