# maotao / predict.py
# julse's picture
# Upload 4 files
# d50ee88 verified
#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
Title : maotao_predict.py
project : minimind_RiboUTR
Created by: julse
Created on: 2025/10/23 16:02
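des: Generate CDS nucleotide sequences from amino-acid sequences for several
     expression hosts: prepare model inputs, run inference, assemble fragments.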
"""
import os
import time
import pandas as pd
from inference import inference
from model.assemble_fragment import assemble_fragments
from model.codon_attr import Codon
from model.sliding_windows import process_nucleotide_sequences
from model.tools import get_pretraining_args
def check_path(dirout, file=False):
    """Make sure a directory exists, creating it (and any parents) if needed.

    Args:
        dirout: Directory path, or a file path when ``file`` is true.
        file: When true, ``dirout`` is treated as a file path and its parent
            directory is the one that gets created.

    Raises:
        OSError: If the directory is missing and cannot be created (for
            example because of missing permissions).
    """
    if file:
        # dirname() yields '' for a bare file name; rsplit('/', 1)[0] would
        # hand back the file name itself and turn it into a directory.
        dirout = os.path.dirname(dirout)
    if not dirout:
        # Nothing to create: the target lives in the current directory.
        return
    if os.path.exists(dirout):
        return
    print('make dir -p ' + dirout)
    # exist_ok=True absorbs the race where another process creates the
    # directory between the existence check above and this call.
    os.makedirs(dirout, exist_ok=True)
def translate(nucleotide_seq):
    """Translate a nucleotide string into an amino-acid string.

    Every 'T' is replaced by 'U', the sequence is read in consecutive
    triplets starting at position 0, and each triplet is looked up in
    ``Codon.CODON_TO_AA``. A triplet that is not a key of that table
    (including a shorter trailing chunk) contributes '_'.

    Args:
        nucleotide_seq: Nucleotide sequence as a string.

    Returns:
        The concatenated lookup results, one entry per triplet.
    """
    rna = nucleotide_seq.replace('T', 'U')
    lookup = Codon.CODON_TO_AA.get
    residues = []
    for start in range(0, len(rna), 3):
        residues.append(lookup(rna[start:start + 3], '_'))
    return ''.join(residues)
def process_inputs(fin=None, dirout=None, codon_table=None):
    """Build the fragment table ``<dirout>/TS.csv`` consumed by inference.

    Steps performed:
      1. Read the ``id`` and ``RefSeq_aa`` columns from the Excel file.
      2. For each species, compute ``cai_best_nn`` for every row with that
         species' ``Codon.cai_opt_codon`` and stage the rows in TS.csv
         (first species writes the header, the others append).
      3. Read the staged table back, split every sequence with
         ``process_nucleotide_sequences`` (max_nn_length=1200, step=300,
         pad_char='_'), translate each fragment, and overwrite TS.csv
         with one row per fragment.

    Args:
        fin: Path of the Excel input; must provide ``id`` and ``RefSeq_aa``.
        dirout: Directory that receives TS.csv; created when missing.
        codon_table: String template that is formatted with
            ``species=<name>`` and passed to ``Codon``.
    """
    # example values used during development:
    # codon_table = '/Users/gz_julse/code/minimind_RiboUTR/maotao_file/codon_table/codon_usage_{species}.csv'
    # fin = '/Users/gz_julse/Data/maotao/2025_bio-os_data/Tests.xlsx'
    check_path(dirout)
    ts_path = dirout + '/TS.csv'

    # .copy() detaches the two-column selection from the frame returned by
    # read_excel, so the column assignments in the loop below operate on an
    # independent frame instead of triggering pandas' SettingWithCopy path.
    df = pd.read_excel(fin)[['id', 'RefSeq_aa']].copy()

    species_list = """mouse,Ec,Sac,Pic,Human""".split(',')
    print(f'loading {len(df)} AA from {fin}\nprepare inputs for generating to CDS for expression in {species_list}')
    codon_instance = {species: Codon(codon_table.format(species=species), rna=False) for species in species_list}

    for i, species in enumerate(species_list):
        # The species is constant within one pass, so resolve its Codon
        # object once instead of once per row.
        codon = codon_instance[species]
        df['species'] = species
        df['cai_best_nn'] = df['RefSeq_aa'].map(codon.cai_opt_codon)
        first = i == 0
        df.to_csv(ts_path, mode='w' if first else 'a', index=False, header=first)

    # NOTE(review): the per-species rows are staged through TS.csv and read
    # back rather than concatenated in memory; kept as-is because the CSV
    # round trip may normalise dtypes that later steps rely on — confirm
    # before replacing it with pd.concat.
    data = pd.read_csv(ts_path)
    data['RefSeq_nn'] = data['cai_best_nn']
    fragments_list = data.apply(
        lambda x: process_nucleotide_sequences(x['RefSeq_nn'], max_nn_length=1200, step=300, pad_char='_',
                                               meta_dict={'_id': x['id'], 'species': x['species']}), axis=1)

    # Each item returned by process_nucleotide_sequences becomes one row;
    # the rows are expected to carry a 'truncated_nn' column.
    expanded_data = pd.DataFrame([item for sublist in fragments_list for item in sublist])
    expanded_data['truncated_aa'] = expanded_data['truncated_nn'].apply(translate)
    expanded_data = expanded_data.rename(columns={'truncated_nn': 'cai_best_nn'})
    expanded_data.to_csv(ts_path, mode='w', index=False, header=True)
    print(f'process {len(expanded_data)} data and saving to {dirout}/TS.csv')
def process_result(fin=None, fpred=None, ftruncated=None, fout=None):
    """Assemble predicted fragments and export one sequence column per species.

    The prediction table and the fragment table are merged (on every column
    they share, since no key is given), grouped by ``(_id, species)``, and
    each group is passed to ``assemble_fragments``. The assembled sequences
    are pivoted to one column per species (using the long display names),
    joined to ``id`` / ``RefSeq_aa`` from the Excel input and written to
    ``fout`` with openpyxl.

    Args:
        fin: Excel input providing ``id`` and ``RefSeq_aa``.
        fpred: CSV with the model predictions.
        ftruncated: CSV with the fragment table.
        fout: Path of the Excel file to write.
    """
    # example values used during development:
    # fin = '/Users/gz_julse/Data/maotao/2025_bio-os_data/Tests.xlsx'
    # fpred = '/Users/gz_julse/Data/maotao/2025_bio-os_data/predict_web/TS_pred.csv'
    # ftruncated = '/Users/gz_julse/Data/maotao/2025_bio-os_data/predict_web/TS.csv'
    # fout = '/Users/gz_julse/Data/maotao/2025_bio-os_data/predict_web/TS_assemble.xlsx'
    source_df = pd.read_excel(fin)
    merged = pd.read_csv(fpred).merge(pd.read_csv(ftruncated))

    # One assembled sequence per (_id, species) group.
    assembled = [
        [_id, species, assemble_fragments(group)]
        for (_id, species), group in merged.groupby(by=['_id', 'species'])
    ]
    long_df = pd.DataFrame(assembled, columns=['_id', 'species', 'seq'])

    display_names = {
        'Ec': 'Escherichia coli',
        'Human': 'Homo sapiens (Human)',
        'Pic': 'Pichia angusta',
        'Sac': 'Saccharomyces cerevisiae',
        'mouse': 'Mus musculus (Mouse)'
    }
    long_df['species'] = long_df['species'].replace(display_names)
    full_name = ['Homo sapiens (Human)', 'Mus musculus (Mouse)', 'Escherichia coli',
                 'Saccharomyces cerevisiae', 'Pichia angusta']

    wide = long_df.pivot(index=['_id'], columns='species', values='seq')
    wide = wide.reset_index()  # turn the index back into a column

    def _to_amino_acids(cds):
        # Strict lookup: an unknown or incomplete triplet raises KeyError.
        codon_starts = range(0, len(cds), 3)
        rna = cds.replace('T', 'U')
        return ''.join([Codon.CODON_TO_AA[rna[start:start + 3]] for start in codon_starts])

    # This column is computed but not part of the exported columns below; the
    # comparison against RefSeq_aa that would use it is currently disabled.
    wide['RefSeq_aa_translate'] = wide['Homo sapiens (Human)'].apply(_to_amino_acids)
    wide = wide.rename(columns={'_id': 'id'})

    exported_columns = ['id', 'RefSeq_aa'] + full_name
    report = source_df[['id', 'RefSeq_aa']].merge(wide, on=['id'])[exported_columns]
    report.to_excel(fout, index=False, engine='openpyxl')
def predict(fin, dirout):
    """Run the whole pipeline: prepare inputs, run inference, assemble output.

    Intermediate files are written below ``<dirout>/tmp/AA2CDS_data/`` and
    the final workbook is written to ``<dirout>/Tests.xlsx``.

    Args:
        fin: Path of the Excel input, forwarded to ``process_inputs`` and
            ``process_result``.
        dirout: Output directory.
    """
    args = get_pretraining_args().parse_args()

    scratch_dir = dirout + '/tmp/'
    check_path(scratch_dir)

    # Fixed settings for prediction; they override whatever was parsed.
    overrides = {
        'downstream_data_path': scratch_dir,
        'predict': True,
        'out_dir': scratch_dir,
        'task': 'AA2CDS_data/',
        'mlm_pretrained_model_path': 'checkpoint/AA2CDS.pth',
    }
    for option, value in overrides.items():
        setattr(args, option, value)

    work_dir = os.path.join(args.downstream_data_path, args.task)
    check_path(work_dir)

    fpred = f'{work_dir}/TS_pred.csv'
    ftruncated = f'{work_dir}/TS.csv'
    fout = f'{dirout}/Tests.xlsx'

    # process inputs
    process_inputs(fin=fin, dirout=os.path.dirname(fpred), codon_table=args.codon_table_path)

    # predict
    inference(args)

    # assemble
    process_result(fin=fin, fpred=fpred, ftruncated=ftruncated, fout=fout)
if __name__ == '__main__':
    import shutil

    print('start', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    start = time.time()

    # Round 1 (disabled, kept for reference):
    #   fin = 'example/Tests.xlsx'
    #   dirout = os.path.abspath('example/out')
    #   predict(fin, dirout), then copy <dirout>/Tests.xlsx to ./Tests.xlsx
    #   flags: --limit=320 --batch_size=12 --epoch=2 --out_dir=maotao_exp/test
    #          --learning_rate=0.000001 --predict --debug

    # Round 2 for experiment.
    # NOTE(review): the label says round 2 but the paths say round3 — confirm.
    fin = 'example/Tests_round3.xlsx'
    dirout = os.path.abspath('example/out_round3')

    # Start from a clean output directory. shutil.rmtree is used instead of a
    # shell `rm -rf` string so that an absolute path containing spaces or
    # shell metacharacters is always treated as exactly one path.
    shutil.rmtree(dirout, ignore_errors=True)
    predict(fin, dirout)

    print('stop', time.strftime('%Y-%m-%d %H:%M:%S', time.localtime()))
    print('time', time.time() - start)