"""
Evaluate a finetuned peptide model checkpoint by sampling sequences
and computing metrics for the De Novo Peptide Generation table:
Validity (%), Affinity (↑), Solubility (↑), Hemolysis (↑),
Nonfouling (↑), Permeability (↑), Sampling Time (↓).

Uniqueness (%) and Diversity of the valid samples are reported as well,
and all metrics are written to a single-row CSV file.
"""
import os
import sys
import argparse
import time
import torch
import numpy as np
import pandas as pd
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
from lightning_modules import AnyOrderInsertionFlowModule
from inference_quality import sample_peptides_eval
from pep_scoring.scoring_functions import ScoringFunctions
from pep_utils.analyzer import PeptideAnalyzer
from pep_scoring.tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
from finetune_quality import PeptideFinetuner
from pep_utils.utils import str2bool, set_seed
from tdc import Evaluator
# Target protein sequences (one-letter amino-acid codes), keyed by the short
# names accepted by --prot_name.
PROTEINS = {
'amhr': 'MLGSLGLWALLPTAVEAPPNRRTCVFFEAPGVRGSTKTLGELLDTGTELPRAIRCLYSRCCFGIWNLTQDRAQVEMQGCRDSDEPGCESLHCDPSPRAHPSPGSTLFTCSCGTDFCNANYSHLPPPGSPGTPGSQGPQAAPGESIWMALVLLGLFLLLLLLLGSIILALLQRKNYRVRGEPVPEPRPDSGRDWSVELQELPELCFSQVIREGGHAVVWAGQLQGKLVAIKAFPPRSVAQFQAERALYELPGLQHDHIVRFITASRGGPGRLLSGPLLVLELHPKGSLCHYLTQYTSDWGSSLRMALSLAQGLAFLHEERWQNGQYKPGIAHRDLSSQNVLIREDGSCAIGDLGLALVLPGLTQPPAWTPTQPQGPAAIMEAGTQRYMAPELLDKTLDLQDWGMALRRADIYSLALLLWEILSRCPDLRPDSSPPPFQLAYEAELGNTPTSDELWALAVQERRRPYIPSTWRCFATDPDGLRELLEDCWDADPEARLTAECVQQRLAALAHPQESHPFPESCPRGCPPLCPEDCTSIPAPTILPCRPQRSACHFSVQQGPCSRNPQPACTLSPV',
'tfr': 'MMDQARSAFSNLFGGEPLSYTRFSLARQVDGDNSHVEMKLAVDEEENADNNTKANVTKPKRCSGSICYGTIAVIVFFLIGFMIGYLGYCKGVEPKTECERLAGTESPVREEPGEDFPAARRLYWDDLKRKLSEKLDSTDFTGTIKLLNENSYVPREAGSQKDENLALYVENQFREFKLSKVWRDQHFVKIQVKDSAQNSVIIVDKNGRLVYLVENPGGYVAYSKAATVTGKLVHANFGTKKDFEDLYTPVNGSIVIVRAGKITFAEKVANAESLNAIGVLIYMDQTKFPIVNAELSFFGHAHLGTGDPYTPGFPSFNHTQFPPSRSSGLPNIPVQTISRAAAEKLFGNMEGDCPSDWKTDSTCRMVTSESKNVKLTVSNVLKEIKILNIFGVIKGFVEPDHYVVVGAQRDAWGPGAAKSGVGTALLLKLAQMFSDMVLKDGFQPSRSIIFASWSAGDFGSVGATEWLEGYLSSLHLKAFTYINLDKAVLGTSNFKVSASPLLYTLIEKTMQNVKHPVTGQFLYQDSNWASKVEKLTLDNAAFPFLAYSGIPAVSFCFCEDTDYPYLGTTMDTYKELIERIPELNKVARAAAEVAGQFVIKLTHDVELNLDYERYNSQLLSFVRDLNQYRADIKEMGLSLQWLYSARGDFFRATSRLTTDFGNAEKTDRFVMKKLNDRVMRVEYHFLSPYVSPKESPFRHVFWGSGSHTLPALLENLKLRKQNNGAFNETLFRNQLALATWTIQGAANALSGDVWDIDNEF',
'gfap': 'MERRRITSAARRSYVSSGEMMVGGLAPGRRLGPGTRLSLARMPPPLPTRVDFSLAGALNAGFKETRASERAEMMELNDRFASYIEKVRFLEQQNKALAAELNQLRAKEPTKLADVYQAELRELRLRLDQLTANSARLEVERDNLAQDLATVRQKLQDETNLRLEAENNLAAYRQEADEATLARLDLERKIESLEEEIRFLRKIHEEEVRELQEQLARQQVHVELDVAKPDLTAALKEIRTQYEAMASSNMHEAEEWYRSKFADLTDAAARNAELLRQAKHEANDYRRQLQSLTCDLESLRGTNESLERQMREQEERHVREAASYQEALARLEEEGQSLKDEMARHLQEYQDLLNVKLALDIEIATYRKLLEGEENRITIPVQTFSNLQIRETSLDTKSVSEGHLKRNIVVKTVEMRDGEVIKESKQEHKDVM',
'glp1': 'MAGAPGPLRLALLLLGMVGRAGPRPQGATVSLWETVQKWREYRRQCQRSLTEDPPPATDLFCNRTFDEYACWPDGEPGSFVNVSCPWYLPWASSVPQGHVYRFCTAEGLWLQKDNSSLPWRDLSECEESKRGERSSPEEQLLFLYIIYTVGYALSFSALVIASAILLGFRHLHCTRNYIHLNLFASFILRALSVFIKDAALKWMYSTAAQQHQWDGLLSYQDSLSCRLVFLLMQYCVAANYYWLLVEGVYLYTLLAFSVLSEQWIFRLYVSIGWGVPLLFVVPWGIVKYLYEDEGCWTRNSNMNYWLIIRLPILFAIGVNFLIFVRVICIVVSKLKANLMCKTDIKCRLAKSTLTLIPLLGTHEVIFAFVMDEHARGTLRFIKLFTELSFTSFQGLMVAILYCFVNNEVQLEFRKSWERWRLEHLHIQRDSSMKPLKCPTSSLSSGATAGSSMYTATCQASCS',
'glast': 'MTKSNGEEPKMGGRMERFQQGVRKRTLLAKKKVQNITKEDVKSYLFRNAFVLLTVTAVIVGTILGFTLRPYRMSYREVKYFSFPGELLMRMLQMLVLPLIISSLVTGMAALDSKASGKMGMRAVVYYMTTTIIAVVIGIIIVIIIHPGKGTKENMHREGKIVRVTAADAFLDLIRNMFPPNLVEACFKQFKTNYEKRSFKVPIQANETLVGAVINNVSEAMETLTRITEELVPVPGSVNGVNALGLVVFSMCFGFVIGNMKEQGQALREFFDSLNEAIMRLVAVIMWYAPVGILFLIAGKIVEMEDMGVIGGQLAMYTVTVIVGLLIHAVIVLPLLYFLVTRKNPWVFIGGLLQALITALGTSSSSATLPITFKCLEENNGVDKRVTRFVLPVGATINMDGTALYEALAAIFIAQVNNFELNFGQIITISITATAASIGAAGIPQAGLVTMVIVLTSVGLPTDDITLIIAVDWFLDRLRTTTNVLGDSLGAGIVEHLSRHELKNRDVEMGNSVIEENEMKKPYQLIAQDNETEKPIDSETKM',
'ncam': 'LQTKDLIWTLFFLGTAVSLQVDIVPSQGEISVGESKFFLCQVAGDAKDKDISWFSPNGEKLTPNQQRISVVWNDDSSSTLTIYNANIDDAGIYKCVVTGEDGSESEATVNVKIFQKLMFKNAPTPQEFREGEDAVIVCDVVSSLPPTIIWKHKGRDVILKKDVRFIVLSNNYLQIRGIKKTDEGTYRCEGRILARGEINFKDIQVIVNVPPTIQARQNIVNATANLGQSVTLVCDAEGFPEPTMSWTKDGEQIEQEEDDEKYIFSDDSSQLTIKKVDKNDEAEYICIAENKAGEQDATIHLKVFAKPKITYVENQTAMELEEQVTLTCEASGDPIPSITWRTSTRNISSEEKASWTRPEKQETLDGHMVVRSHARVSSLTLKSIQYTDAGEYICTASNTIGQDSQSMYLEVQYAPKLQGPVAVYTWEGNQVNITCEVFAYPSATISWFRDGQLLPSSNYSNIKIYNTPSASYLEVTPDSENDFGNYNCTAVNRIGQESLEFILVQADTPSSPSIDQVEPYSSTAQVQFDEPEATGGVPILKYKAEWRAVGEEVWHSKWYDAKEASMEGIVTIVGLKPETTYAVRLAALNGKGLGEISAASEF',
'cereblon': 'MAGEGDQQDAAHNMGNHLPLLPAESEEEDEMEVEDQDSKEAKKPNIINFDTSLPTSHTYLGADMEEFHGRTLHDDDSCQVIPVLPQVMMILIPGQTLPLQLFHPQEVSMVRNLIQKDRTFAVLAYSNVQEREAQFGTTAEIYAYREEQDFGIEIVKVKAIGRQRFKVLELRTQSDGIQQAKVQILPECVLPSTMSAVQLESLNKCQIFPSKPVSREDQCSYKWWQKYQKRKFHCANLTSWPRWLYSLYDAETLMDRIKKQLREWDENLKDDSLPSNPIDFSYRVAACLPIDDVLRIQLLKIGSAIQRLRCELDIMNKCTSLCCKQCQETEITTKNEIFSLSLCGPMAAYVNPHGYVHETLTVYKACNLNLIGRPSTEHSWFPGYAWTVAQCKICASHIGWKFTATKKDMSPQKFWGLTRSALLPTIPDTEDEISPDKVILCL',
'ligase': 'MASQPPEDTAESQASDELECKICYNRYNLKQRKPKVLECCHRVCAKCLYKIIDFGDSPQGVIVCPFCRFETCLPDDEVSSLPDDNNILVNLTCGGKGKKCLPENPTELLLTPKRLASLVSPSHTSSNCLVITIMEVQRESSPSLSSTPVVEFYRPASFDSVTTVSHNWTVWNCTSLLFQTSIRVLVWLLGLLYFSSLPLGIYLLVSKKVTLGVVFVSLVPSSLVILMVYGFCQCVCHEFLDCMAPPS',
'skp2': 'MHRKHLQEIPDLSSNVATSFTWGWDSSKTSELLSGMGVSALEKEEPDSENIPQELLSNLGHPESPPRKRLKSKGSDKDFVIVRRPKLNRENFPGVSWDSLPDELLLGIFSCLCLPELLKVSGVCKRWYRLASDESLWQTLDLTGKNLHPDVTGRLLSQGVIAFRCPRSFMDQPLAEHFSPFRVQHMDLSNSVIEVSTLHGILSQCSKLQNLSLEGLRLSDPIVNTLAKNSNLVRLNLSGCSGFSEFALQTLLSSCSRLDELNLSWCFDFTEKHVQVAVAHVSETITQLNLSGYRKNLQKSDLSTLVRRCPNLVHLDLSDSVMLKNDCFQEFFQLNYLQHLSLSRCYDIIPETLLELGEIPTLKTLQVFGIVPDGTLQLLKEALPHLQINCSHFTTIARPTIGNKKNQEIWGIKCRLTLQKPSCL',
}
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Load the finetuned policy model stored inside a Lightning checkpoint.

    The config is read from the pretrained base checkpoint and patched with
    the schedule-related options recorded in the finetuning checkpoint's
    ``hyper_parameters['args']`` (falling back to fixed defaults). A fresh
    ``AnyOrderInsertionFlowModuleFT`` is then built and the weights stored
    under the ``policy_model.`` prefix are loaded into it.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. It
            supplies the config and is also passed to the model constructor.
        device: Device the returned model is moved to.

    Returns:
        Tuple ``(policy_model, args, config)``: the model in eval mode, the
        ``args`` object stored in the finetuned checkpoint (``None`` when
        absent) and the patched config.

    Raises:
        ValueError: If the base checkpoint holds no config, or if the
            finetuned checkpoint contains no ``policy_model.*`` weights.
    """
    # NOTE(security): weights_only=False unpickles arbitrary Python objects.
    # Only load checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    # Extract the finetuned weights first so that a wrong checkpoint fails
    # before the (potentially expensive) model construction below.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        # With strict=False an empty state dict would load "successfully" and
        # the metrics would silently describe the un-finetuned model.
        raise ValueError(
            f"No '{prefix}*' weights found in checkpoint {checkpoint_path}; "
            "cannot load a finetuned policy model from it"
        )

    # Load pretrained base checkpoint to get config
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")
    # Only the config is needed from the base checkpoint here; drop the rest.
    del base_ckpt

    from omegaconf import OmegaConf, DictConfig
    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Training-time options copied onto config.training, with the default
    # used when the finetuning args do not carry the attribute.
    training_overrides = (
        ('use_adaptive_schedule', True),
        ('schedule_hidden_dim', 256),
        ('schedule_num_layers', 2),
        ('schedule_loss_weight', 0.1),
        ('freeze_base_model', False),
        ('schedule_warmup_epochs', 0),
    )
    # Struct mode is lifted temporarily so new keys can be added.
    OmegaConf.set_struct(config, False)
    for option, default in training_overrides:
        setattr(config.training, option, getattr(args, option, default))
    OmegaConf.set_struct(config, True)

    disable_planner = getattr(args, 'disable_planner', False)
    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Load finetuned weights; strict=False tolerates partial mismatches, so
    # report them instead of letting them pass unnoticed.
    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing_keys = list(getattr(load_result, 'missing_keys', ()))
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', ()))
    if missing_keys or unexpected_keys:
        print(f"Warning: finetuned weights loaded with {len(missing_keys)} missing "
              f"and {len(unexpected_keys)} unexpected keys")

    policy_model = policy_model.to(device)
    policy_model.eval()
    return policy_model, args, config
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, analyzer,
                        num_samples=1000, batch_size=50, max_length=512,
                        total_num_steps=256, quality_mode="both", num_remasking=3,
                        quality_threshold=0.5, unmask_quality_threshold=None, device='cuda'):
    """Sample ``num_samples`` peptides in batches and compute all table metrics.

    Args:
        policy_model: Model passed to ``sample_peptides_eval``; its
            ``interpolant.mask_token`` / ``interpolant.pad_token`` are used.
        tokenizer: Tokenizer forwarded to ``sample_peptides_eval``.
        reward_model: Scoring model forwarded to ``sample_peptides_eval``.
        analyzer: Peptide analyzer forwarded to ``sample_peptides_eval``.
        num_samples: Total number of peptides to generate.
        batch_size: Maximum number generated per call (must be >= 1).
        max_length: Forwarded to ``sample_peptides_eval``.
        total_num_steps: Forwarded as ``steps``.
        quality_mode: Forwarded to ``sample_peptides_eval``.
        num_remasking: Forwarded to ``sample_peptides_eval``.
        quality_threshold: Forwarded to ``sample_peptides_eval``.
        unmask_quality_threshold: Forwarded to ``sample_peptides_eval``.
        device: Currently unused; kept for interface compatibility.

    Returns:
        Dict with validity, uniqueness, diversity, mean/std of affinity,
        solubility, hemolysis, nonfouling and permeability, the total
        wall-clock time spent inside ``sample_peptides_eval`` and the
        generated/valid/unique counts.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        # A zero batch size would divide by zero below and a negative one
        # would silently produce an all-zero metrics row.
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    all_affinity = []
    all_sol = []
    all_hemo = []
    all_nf = []
    all_permeability = []
    all_valid_seqs = []
    total_valid = 0
    total_generated = 0
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples
    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        # perf_counter is monotonic, unlike time.time(), so the measured
        # interval cannot be distorted by system clock adjustments.
        t_start = time.perf_counter()
        result = sample_peptides_eval(
            model=policy_model,
            reward_model=reward_model,
            analyzer=analyzer,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            return_valid=True,
        )
        t_end = time.perf_counter()

        # Unpack: validSequences, affinity, sol, hemo, nf, permeability, valid_fraction
        valid_seqs, affinity, sol, hemo, nf, permeability, valid_fraction = result

        batch_valid = len(valid_seqs)
        total_valid += batch_valid
        total_generated += bs
        total_time += (t_end - t_start)
        all_valid_seqs.extend(valid_seqs)

        # Scores are only collected when the batch produced any; the
        # affinity container is used as the indicator for all five.
        if isinstance(affinity, (list, np.ndarray, torch.Tensor)) and len(affinity) > 0:
            all_affinity.extend(_scores_to_list(affinity))
            all_sol.extend(_scores_to_list(sol))
            all_hemo.extend(_scores_to_list(hemo))
            all_nf.extend(_scores_to_list(nf))
            all_permeability.extend(_scores_to_list(permeability))

        print(f" Batch {b+1}/{num_batches}: {batch_valid}/{bs} valid, "
              f"time={t_end - t_start:.1f}s")

    validity = total_valid / total_generated * 100.0 if total_generated > 0 else 0.0

    # Uniqueness (% of valid sequences that are unique) and
    # Diversity (1 - mean pairwise Tanimoto on Morgan FPs of unique sequences).
    # Matches the convention used in evaluate_mol_table.py.
    all_unique = list(set(all_valid_seqs))
    num_unique = len(all_unique)
    uniqueness = num_unique / total_valid * 100.0 if total_valid > 0 else 0.0
    if num_unique > 1:
        diversity = Evaluator('diversity')(all_unique)
    else:
        diversity = 0.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'Diversity': diversity,
    }
    # Mean/std per property; 0.0 when no valid sample was scored. Insertion
    # order is kept because it determines the CSV column order.
    for label, scores in (
        ('Affinity', all_affinity),
        ('Solubility', all_sol),
        ('Hemolysis', all_hemo),
        ('Nonfouling', all_nf),
        ('Permeability', all_permeability),
    ):
        metrics[label] = np.mean(scores) if scores else 0.0
        metrics[f'{label} Std'] = np.std(scores) if scores else 0.0
    metrics['Sampling Time (s)'] = total_time
    metrics['Num Generated'] = total_generated
    metrics['Num Valid'] = total_valid
    metrics['Num Unique'] = num_unique
    return metrics


def _scores_to_list(values):
    """Convert one batch of scores (list, NumPy array or tensor) to a list."""
    if isinstance(values, list):
        return values
    if isinstance(values, torch.Tensor):
        # Move off the accelerator and flatten so extend() gets plain numbers.
        return values.detach().cpu().reshape(-1).tolist()
    return values.tolist()
def main():
    """Command-line entry point: load a checkpoint, sample, score and save.

    Parses the CLI arguments, loads the finetuned model, builds the
    tokenizer / reward model / analyzer, runs ``evaluate_checkpoint``, prints
    the metrics and writes them as a single-row CSV named
    ``eval_metrics_<tag>_<prot_name>.csv``.

    Raises:
        ValueError: If ``--prot_name`` is not a key of ``PROTEINS`` and no
            ``--prot_seq`` is given.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned peptide checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_pep.ckpt'),
                        help='Path to the pretrained base model checkpoint')
    parser.add_argument('--num_samples', type=int, default=500,
                        help='Number of peptides to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=512)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=3)
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking/remasking by confidence: remask '
                             'ALL clean tokens whose unmasking confidence is below this '
                             'threshold, regardless of the schedule budget. If unset '
                             '(default), remasking is purely schedule-driven (count-based).')
    parser.add_argument('--prot_name', type=str, default='glast',
                        help='Target protein name (must be one of: ' + ', '.join(PROTEINS.keys()) + ')')
    parser.add_argument('--prot_seq', type=str, default=None,
                        help='Custom protein sequence (overrides --prot_name)')
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args()

    set_seed(args.seed, use_cuda=True)
    # Fall back to CPU when no CUDA device is available.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    # Map flags to quality_mode
    if args.disable_planner:
        quality_mode = "none"
    elif args.disable_insertion_planner and args.disable_unmasking_planner:
        quality_mode = "none"
    elif args.disable_insertion_planner:
        quality_mode = "unmasking_only"
    elif args.disable_unmasking_planner:
        quality_mode = "insertion_only"
    else:
        quality_mode = "both"

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Quality mode: {quality_mode}")

    policy_model, _, _ = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )

    # Setup tokenizer, reward model, analyzer
    tokenizer_dir = os.path.join(REPO_ROOT, 'a2d2_pep', 'pep_scoring', 'tokenizer')
    tokenizer = SMILES_SPE_Tokenizer(
        os.path.join(tokenizer_dir, 'new_vocab.txt'),
        os.path.join(tokenizer_dir, 'new_splits.txt')
    )

    # NOTE: with --prot_seq the custom sequence is scored, but the log line
    # and output filename still carry --prot_name (default 'glast').
    prot_name = args.prot_name
    if args.prot_seq is not None:
        prot = args.prot_seq
    else:
        if prot_name not in PROTEINS:
            raise ValueError(f"Unknown protein: {prot_name}. Choose from: {list(PROTEINS.keys())}")
        prot = PROTEINS[prot_name]

    score_func_names = ['binding_affinity1', 'solubility', 'hemolysis', 'nonfouling', 'permeability']
    reward_model = ScoringFunctions(score_func_names, prot_seqs=[prot], device=device)
    analyzer = PeptideAnalyzer()

    print(f"\nSampling {args.num_samples} peptides (quality_mode={quality_mode}, target={prot_name})...")
    metrics = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        analyzer=analyzer,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Peptide Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. dirname() of a bare filename is '', which os.makedirs
    # rejects, so fall back to the current directory.
    output_dir = args.output_dir or os.path.dirname(args.checkpoint_path) or '.'
    os.makedirs(output_dir, exist_ok=True)

    # Derive the file tag from the effective quality_mode so the filename
    # always matches the run: disabling both planners is a "none" run and
    # must not overwrite the insertion-only ablation's results.
    tag_by_mode = {
        "none": "no_planner",
        "unmasking_only": "no_insertion_planner",
        "insertion_only": "no_unmasking_planner",
        "both": "with_planner",
    }
    tag = tag_by_mode[quality_mode]
    if args.unmask_quality_threshold is not None:
        tag += f"_ut{args.unmask_quality_threshold:g}"

    # Record the sweep parameter in the saved row for traceability.
    metrics['unmask_quality_threshold'] = args.unmask_quality_threshold
    metrics['quality_threshold'] = args.quality_threshold

    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}_{prot_name}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")
# Run the evaluation only when executed as a script, not on import.
if __name__ == "__main__":
    main()