| |
| import torch |
| import torch.nn.functional as F |
| import math |
| import random |
| import sys |
| from diffusion import Diffusion |
| import hydra |
| from tqdm import tqdm |
| import matplotlib.pyplot as plt |
| import os |
| import seaborn as sns |
| import pandas as pd |
| import numpy as np |
| import argparse |
| |
| from diffusion import Diffusion |
| from hydra import initialize, compose |
| from hydra.core.global_hydra import GlobalHydra |
| import numpy as np |
| import torch |
| import argparse |
| import os |
| import datetime |
| from utils.utils import str2bool, set_seed |
|
|
| |
| from utils.app import PeptideAnalyzer |
| from peptide_mcts import MCTS |
|
|
@torch.no_grad()
def generate_mcts(args, cfg, policy_model, pretrained, prot=None, prot_name=None, filename=None):
    """Run a single MCTS search against one target protein and return its outputs.

    Gradient tracking is disabled for the whole call by the decorator.

    Parameters:
        args: Parsed command-line namespace, forwarded unchanged to ``MCTS``.
        cfg: Composed configuration object, forwarded unchanged to ``MCTS``.
        policy_model: Model handed to ``MCTS`` in the policy position.
        pretrained: Model handed to ``MCTS`` in the pretrained position.
        prot: Target protein; passed to ``MCTS`` as the only entry of ``prot_seqs``.
        prot_name: Accepted for call compatibility; not used by this function.
        filename: Accepted for call compatibility; not used by this function.

    Returns:
        tuple: ``(final_x, log_rnd, final_rewards, score_vectors, sequences)`` in the
        order produced by ``MCTS.forward()``.
    """
    # Objectives scored during the search, in the order MCTS receives them.
    objectives = [
        'binding_affinity1',
        'solubility',
        'hemolysis',
        'nonfouling',
        'permeability',
    ]

    searcher = MCTS(args, cfg, policy_model, pretrained, objectives, prot_seqs=[prot])

    # Unpacking (rather than returning the raw result) enforces that forward()
    # yields exactly five items and always hands callers a tuple.
    final_x, log_rnd, final_rewards, score_vectors, sequences = searcher.forward()
    return final_x, log_rnd, final_rewards, score_vectors, sequences
|
|
def save_logs_to_file(reward_log, logrnd_log,
                      valid_fraction_log, affinity1_log,
                      sol_log, hemo_log, nf_log,
                      permeability_log, output_path):
    """
    Saves the per-iteration logs to a CSV file, one row per iteration.

    All log arguments are written as columns of a single table, so they must
    have the same length; the "Iteration" column is numbered 1..N where N is
    ``len(valid_fraction_log)``.

    Parameters:
        reward_log (list): Written to the "Reward" column.
        logrnd_log (list): Written to the "Log RND" column.
        valid_fraction_log (list): Log of valid fractions over iterations.
        affinity1_log (list): Log of binding affinity over iterations.
        sol_log (list): Written to the "Solubility" column.
        hemo_log (list): Written to the "Hemolysis" column.
        nf_log (list): Written to the "Nonfouling" column.
        permeability_log (list): Log of membrane permeability over iterations.
        output_path (str): Path to save the log CSV file. Missing parent
            directories are created; a bare filename is written to the
            current working directory.

    Raises:
        ValueError: Raised by pandas when the logs differ in length.
    """
    # os.path.dirname() returns '' for a bare filename and os.makedirs('')
    # raises FileNotFoundError, so only create the directory when there is one.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    log_data = {
        "Iteration": list(range(1, len(valid_fraction_log) + 1)),
        "Reward": reward_log,
        "Log RND": logrnd_log,
        "Valid Fraction": valid_fraction_log,
        "Binding Affinity": affinity1_log,
        "Solubility": sol_log,
        "Hemolysis": hemo_log,
        "Nonfouling": nf_log,
        "Permeability": permeability_log,
    }

    df = pd.DataFrame(log_data)
    df.to_csv(output_path, index=False)
|
|
# Command-line interface for the script. Help output shows each option's default
# because of ArgumentDefaultsHelpFormatter.
argparser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# One entry per option: (flag, converter, default). A converter of None marks a
# boolean switch registered with action='store_true'. Entries are registered in
# this exact order, which is also the order shown by --help.
_OPTION_SPECS = (
    ('--base_path', str, ''),
    ('--learning_rate', float, 1e-4),
    ('--num_epochs', int, 1000),
    ('--num_accum_steps', int, 4),
    ('--truncate_steps', int, 50),
    ('--truncate_kl', str2bool, False),
    ('--gumbel_temp', float, 1.0),
    ('--gradnorm_clip', float, 1.0),
    ('--batch_size', int, 32),
    ('--name', str, 'debug'),
    ('--total_num_steps', int, 128),
    ('--copy_flag_temp', float, None),
    ('--save_every_n_epochs', int, 50),
    ('--alpha_schedule_warmup', int, 0),
    ('--seed', int, 0),
    ('--run_name', str, 'drakes'),
    ('--device', str, 'cuda'),
    ('--num_sequences', int, 100),
    ('--num_children', int, 20),
    ('--num_iter', int, 100),
    ('--seq_length', int, 200),
    ('--time_conditioning', None, False),
    ('--mcts_sampling', int, 0),
    ('--buffer_size', int, 100),
    ('--wdce_num_replicates', int, 16),
    ('--noise_removal', None, False),
    ('--exploration', float, 0.1),
    ('--reset_every_n_step', int, 100),
    ('--alpha', float, 0.01),
    ('--scalarization', str, 'sum'),
    ('--no_mcts', None, False),
    ('--centering', None, False),
    ('--num_obj', int, 5),
    ('--prot_seq', str, None),
    ('--prot_name', str, None),
)

for _flag, _converter, _default in _OPTION_SPECS:
    if _converter is None:
        argparser.add_argument(_flag, action='store_true', default=_default)
    else:
        argparser.add_argument(_flag, type=_converter, default=_default)

# Parsing happens at import time; the resulting namespace is echoed for the run log.
args = argparser.parse_args()
print(args)
| |
| |
# Location of the pretrained checkpoint, resolved relative to --base_path.
ckpt_path = f'{args.base_path}/TR2-D2/tr2d2-pep/pretrained/peptune-pretrained.ckpt'

# Reset Hydra's global instance before initializing it for this run, then build
# the configuration from configs/config.yaml.
_hydra_state = GlobalHydra.instance()
_hydra_state.clear()
initialize(config_path="configs", job_name="load_model")
cfg = compose(config_name="config.yaml")

# Timestamp of this run, formatted as YYYYMMDD_HHMMSS.
_started_at = datetime.datetime.now()
curr_time = _started_at.strftime("%Y%m%d_%H%M%S")

# Seed with the value given by --seed; use_cuda is always passed as True here.
set_seed(args.seed, use_cuda=True)
|
|
| |
# Amino-acid sequences of candidate target proteins. In the code visible here
# only `tfr` is referenced (as the default target); the others are module-level
# constants that stay available for manual selection.
amhr = 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV'
tfr = 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF'
gfap = 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM'
glp1 = 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS'
glast = 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM'
ncam = 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF'
cereblon = 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL'
ligase = 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS'
skp2 = 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL'

# Pick the target: the transferrin-receptor constant `tfr` unless --prot_seq
# supplies a sequence, in which case --prot_name is taken as its label as-is
# (it may be None when the flag is omitted).
if args.prot_seq is None:
    prot, prot_name = tfr, "tfr"
else:
    prot, prot_name = args.prot_seq, args.prot_name
# The file label always mirrors the protein label.
filename = prot_name
|
|
| |
# Load two independent copies of the same pretrained checkpoint. `new_model` is
# passed to generate_mcts as the policy model and `old_model` as the pretrained
# reference. strict=False tolerates key mismatches between checkpoint and model.
new_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)
old_model = Diffusion.load_from_checkpoint(ckpt_path, config=cfg, strict=False, map_location=args.device)

# Report loading as soon as it has actually happened. Previously this message
# was printed only after the whole search below had finished. It is flushed so
# it is visible before the long-running generation when stdout is buffered.
print("loaded models...", flush=True)

# Run the MCTS search for the selected target with gradients disabled, then move
# the tensors needed for reporting to the CPU.
with torch.no_grad():
    final_x, log_rnd, final_rewards, score_vectors, sequences = generate_mcts(args, cfg, new_model, old_model,
                                                                               prot=prot, prot_name=prot_name)

    final_x = final_x.detach().to('cpu')
    # Flatten to one float value per generated sequence.
    log_rnd = log_rnd.detach().to('cpu').float().view(-1)

# Helper used downstream to analyze each generated sequence.
analyzer = PeptideAnalyzer()
|
|
# Collect one row per generated sequence and write them all to a CSV file.
generation_results = []

# The first dimension of final_x determines how many results are reported.
for i in range(final_x.shape[0]):
    sequence = sequences[i]
    # log_rnd[i] is a 0-d torch tensor; .item() converts it to a plain float so
    # the "Log RND" CSV column contains numbers instead of "tensor(...)" reprs.
    log_rnd_single = log_rnd[i].item()
    final_reward = final_rewards[i]

    aa_seq, seq_length = analyzer.analyze_structure(sequence)

    # Score order matches the objective order used by generate_mcts:
    # binding affinity, solubility, hemolysis, nonfouling, permeability.
    binding1, solubility, hemo, nonfouling, permeability = score_vectors[i][:5]

    generation_results.append([sequence, aa_seq, final_reward, log_rnd_single, binding1, solubility, hemo, nonfouling, permeability])
    print(f"length: {seq_length} | smiles sequence: {sequence} | amino acid sequence: {aa_seq} | Binding Affinity: {binding1} | Solubility: {solubility} | Hemolysis: {hemo} | Nonfouling: {nonfouling} | Permeability: {permeability}")

sys.stdout.flush()

df = pd.DataFrame(generation_results, columns=['Generated SMILES', 'Peptide Sequence', 'Final Reward', 'Log RND', 'Binding Affinity', 'Solubility', 'Hemolysis', 'Nonfouling', 'Permeability'])

# Create the per-target output directory first: to_csv does not create missing
# directories, and failing here would discard the results of the whole run.
output_dir = f'{args.base_path}/TR2-D2/tr2d2-pep/plots/{prot_name}-peptune-baseline'
os.makedirs(output_dir, exist_ok=True)
df.to_csv(f'{output_dir}/generation_results.csv', index=False)
|
|