# maotao/inference.py — last update by julse (commit 6dad468, verified)
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : inference.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/10/23 16:49
des: Standalone inference entry point for the AA2CDS downstream task.
"""
import sys
import os
import time
import pandas as pd
import numpy as np
import sys
import os
import time
import torch
import torch.distributed as dist
from model.tools import get_pretraining_args, find_unused_parameters
from contextlib import nullcontext
from train import sft_process_maotao, init_config, maotao
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
print('Setting running environment')
def Logger(*content):
    """Rank-aware print: under DDP, emit only from rank 0; otherwise always print."""
    if ddp and dist.get_rank() != 0:
        return
    print(*content)
def init_distributed_mode(ddp=True):
    """Initialise torch.distributed (NCCL backend) for a DDP run.

    Reads RANK / LOCAL_RANK / WORLD_SIZE from the environment (set by the
    torchrun launcher — see the module-level `ddp` detection) and pins this
    process to its local CUDA device.

    Args:
        ddp: when False, skip initialisation entirely.

    Returns:
        (ddp_local_rank, DEVICE). Bug fix: the old `if not ddp` path
        returned None, which would crash any caller unpacking the result
        into a 2-tuple; it now returns the same single-process defaults
        (0, "cuda:0") that inference() uses before calling this function.
    """
    print("init distributed mode,ddp=", ddp)
    if not ddp:
        # Single-process fallback; mirrors the defaults in inference().
        return 0, "cuda:0"
    global ddp_local_rank, DEVICE
    dist.init_process_group(backend="nccl")
    ddp_rank = int(os.environ["RANK"])
    ddp_local_rank = int(os.environ["LOCAL_RANK"])
    ddp_world_size = int(os.environ["WORLD_SIZE"])
    DEVICE = f"cuda:{ddp_local_rank}"
    torch.cuda.set_device(DEVICE)
    print('init distributed mode, ddp_rank:', ddp_rank, 'ddp_local_rank:', ddp_local_rank, 'ddp_world_size:', ddp_world_size)
    return ddp_local_rank, DEVICE
def inference(args):
    """Run the AA2CDS downstream task (prediction / SFT) end to end.

    Sets up the device / autocast context (and DDP when launched via
    torchrun), builds the LM config and tokenizer, then delegates the
    actual work to ``sft_process_maotao``.

    Args:
        args: parsed argparse namespace from ``get_pretraining_args``;
            mutated in place (``device``, ``seq_len``, ``save_dir``,
            ``out_dir``, ``downstream_data_path`` are normalised here).

    Returns:
        None. Results are written under ``args.save_dir``
        (checkpoint ``AA2CDS.pth``).
    """
    sft = maotao()
    # When not in predict mode, fall back to the default downstream task.
    task = args.task if args.predict else 'AA2CDS_data'
    # args.device is a string here ("cuda:0" etc.); it becomes a
    # torch.device only in the DDP branch below.
    device_type = "cuda" if "cuda" in args.device else "cpu"
    ctx = nullcontext() if device_type == "cpu" else torch.cuda.amp.autocast()
    # Single-process defaults; overridden when running under torchrun.
    ddp_local_rank, DEVICE = 0, "cuda:0"
    if ddp:
        print('init distributed mode')
        ddp_local_rank, DEVICE = init_distributed_mode(ddp=ddp)
        args.device = torch.device(DEVICE)
    Logger('args.device:', args.device)
    Logger('setting args', args)
    max_seq_len = 1200
    args.seq_len = max_seq_len
    # Outputs (checkpoint, metrics) are written next to out_dir.
    args.save_dir = args.out_dir
    os.makedirs(args.save_dir, exist_ok=True)
    lm_config, tokenizer = init_config(args.arg_overrides['data'] + '/small_dict.txt', args.n_layers, max_seq_len)
    lm_config.use_moe = args.use_moe
    '''3. benchmark downstream tasks'''
    args.out_dir = os.path.abspath(args.out_dir)
    os.makedirs(args.out_dir, exist_ok=True)
    model_dir = os.path.abspath(args.out_dir)
    Logger(f'model_dir:{model_dir}')
    args.save_dir = os.path.abspath(args.save_dir)
    args.downstream_data_path = os.path.abspath(args.downstream_data_path)
    Logger('args.downstream_data_path:', args.downstream_data_path)
    # Output checkpoint for this task; input checkpoint is the pretrained MLM.
    out_ckp = os.path.abspath(os.path.join(args.save_dir, 'AA2CDS.pth'))
    in_ckp = args.mlm_pretrained_model_path
    ckp, final_metrics, _ = sft_process_maotao(max_seq_len=max_seq_len, ctx=ctx, ddp=ddp,
                                               ddp_local_rank=ddp_local_rank,
                                               args=args, ckp=in_ckp, out_ckp=out_ckp,
                                               lm_config=lm_config, tokenizer=tokenizer, Logger=Logger, task=task,
                                               sft=sft, require_ckp=True)
if __name__ == '__main__':
    # Entry point: parse CLI arguments, apply the fixed AA2CDS run
    # configuration, run inference, and report wall-clock timing.
    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()

    args = get_pretraining_args().parse_args()
    # NOTE(review): these hard-coded settings override anything supplied
    # on the command line for this script.
    args.downstream_data_path = 'maotao_file/'  # TS.csv
    args.task = 'AA2CDS_data'
    args.predict = True
    args.mlm_pretrained_model_path = 'checkpoint/AA2CDS.pth'
    args.out_dir = 'example/out_TR_TS'

    inference(args)

    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)