"""
Evaluate a finetuned molecule model checkpoint by sampling sequences
and computing metrics for the De Novo Small Molecule Generation table:
Validity (%), Uniqueness (%), QED (↑), SA (↓), Quality (%), Diversity (↑), Sampling Time (↓)
"""
import os
import sys
import argparse
import time
import torch
import numpy as np
import pandas as pd
from tdc import Oracle, Evaluator
# add repo root (A2D2/) to sys.path so top-level packages like lightning_modules resolve
REPO_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.insert(0, REPO_ROOT)
from lightning_modules.any_length_remask import AnyOrderInsertionFlowModuleFT
from lightning_modules import AnyOrderInsertionFlowModule
from inference_quality_mol import sample_mol_eval
from mol_scoring.scoring_functions import MolScoringFunctions
from finetune_mol import MolFinetuner, get_tokenizer
from mol_utils.utils import str2bool, set_seed
def load_finetuned_model(checkpoint_path, pretrained_ckpt_path, device='cuda'):
    """Load the finetuned policy model stored in a Lightning checkpoint.

    The model is rebuilt from the config found in the pretrained base
    checkpoint. Adaptive-schedule settings are taken from the finetuning
    ``args`` saved in the finetuned checkpoint (falling back to defaults when
    a setting, or ``args`` itself, is absent). The ``policy_model.*`` weights
    of the finetuned checkpoint are then loaded on top.

    Args:
        checkpoint_path: Path to the finetuned Lightning checkpoint.
        pretrained_ckpt_path: Path to the pretrained base checkpoint. It
            supplies the model config and is also forwarded to the model
            constructor as ``pretrained_checkpoint``.
        device: Device the model is moved to before it is returned.

    Returns:
        Tuple ``(policy_model, args, config)``: the model in eval mode, the
        finetuning args read from the checkpoint (``None`` if not stored),
        and the config used to build the model.

    Raises:
        ValueError: If the base checkpoint holds no config, or if the
            finetuned checkpoint holds no ``policy_model.*`` weights (loading
            would otherwise silently return the un-finetuned base model).
    """
    from collections.abc import Mapping

    # NOTE: weights_only=False unpickles arbitrary Python objects; only load
    # checkpoints that come from a trusted source.
    ckpt = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
    hparams = ckpt.get('hyper_parameters', {})
    args = hparams.get('args', None)

    def arg_or_default(name, default):
        # The saved args may be a Namespace-like object or a plain mapping;
        # a plain dict does not expose its keys as attributes, so getattr
        # alone would silently fall back to the default.
        if isinstance(args, Mapping) and name in args:
            return args[name]
        return getattr(args, name, default)

    # The base checkpoint is only read here to recover the model config.
    base_ckpt = torch.load(pretrained_ckpt_path, map_location='cpu', weights_only=False)
    if 'hyper_parameters' in base_ckpt:
        config = base_ckpt['hyper_parameters']['config']
    elif 'config' in base_ckpt:
        config = base_ckpt['config']
    else:
        raise ValueError("Cannot find config in base checkpoint")

    from omegaconf import OmegaConf, DictConfig

    if not OmegaConf.is_config(config):
        config = DictConfig(config)

    # Temporarily lift struct mode so that new training keys can be added.
    OmegaConf.set_struct(config, False)
    schedule_defaults = {
        'use_adaptive_schedule': True,
        'schedule_hidden_dim': 256,
        'schedule_num_layers': 2,
        'schedule_loss_weight': 0.1,
        'freeze_base_model': False,
        'schedule_warmup_epochs': 0,
    }
    for name, default in schedule_defaults.items():
        config.training[name] = arg_or_default(name, default)
    config.training.use_bracket_safe = True
    OmegaConf.set_struct(config, True)

    # The insertion planner is built unless finetuning disabled the planner.
    disable_planner = arg_or_default('disable_planner', False)

    policy_model = AnyOrderInsertionFlowModuleFT(
        config=config,
        args=args,
        pretrained_checkpoint=pretrained_ckpt_path,
        insertion_planner=not disable_planner,
    )

    # Keep only the 'policy_model.*' entries of the Lightning state dict and
    # strip that prefix so the keys match the sub-module itself.
    prefix = 'policy_model.'
    policy_state = {
        key[len(prefix):]: value
        for key, value in ckpt['state_dict'].items()
        if key.startswith(prefix)
    }
    if not policy_state:
        raise ValueError(
            f"No '{prefix}*' weights found in checkpoint {checkpoint_path}; "
            "refusing to return a model without finetuned weights"
        )

    # strict=False tolerates key mismatches, so report them instead of
    # discarding the information.
    load_result = policy_model.load_state_dict(policy_state, strict=False)
    missing_keys = list(getattr(load_result, 'missing_keys', None) or [])
    unexpected_keys = list(getattr(load_result, 'unexpected_keys', None) or [])
    if missing_keys:
        print(f"Warning: {len(missing_keys)} model keys missing from checkpoint "
              f"(e.g. {missing_keys[:5]})")
    if unexpected_keys:
        print(f"Warning: {len(unexpected_keys)} checkpoint keys not used by model "
              f"(e.g. {unexpected_keys[:5]})")

    policy_model = policy_model.to(device)
    policy_model.eval()
    return policy_model, args, config
@torch.no_grad()
def evaluate_checkpoint(policy_model, tokenizer, reward_model, evaluator,
                        num_samples=1000, batch_size=50, max_length=256,
                        total_num_steps=256, quality_mode="both", num_remasking=2,
                        quality_threshold=0.5, unmask_quality_threshold=None,
                        device='cuda'):
    """Sample molecules in batches and compute the table metrics.

    Args:
        policy_model: Model to sample from; its ``interpolant.mask_token`` and
            ``interpolant.pad_token`` are passed to the sampler.
        tokenizer: Tokenizer forwarded to ``sample_mol_eval``.
        reward_model: Reward model forwarded to ``sample_mol_eval``.
        evaluator: Callable applied to the list of unique sequences to obtain
            the diversity score; also forwarded to ``sample_mol_eval``.
        num_samples: Total number of molecules to request.
        batch_size: Maximum number of molecules requested per sampler call.
        max_length: Forwarded to ``sample_mol_eval``.
        total_num_steps: Forwarded to ``sample_mol_eval`` as ``steps``.
        quality_mode: Forwarded to ``sample_mol_eval``.
        num_remasking: Forwarded to ``sample_mol_eval``.
        quality_threshold: Forwarded to ``sample_mol_eval``.
        unmask_quality_threshold: Forwarded to ``sample_mol_eval``.
        device: Accepted for interface compatibility; not used by this
            function.

    Returns:
        Tuple ``(metrics, unique_sequences, qed_values, sa_values)``.
        ``metrics`` is a dict of the table columns plus the raw counts. The
        QED/SA values are aligned with ``unique_sequences``, and both are
        empty lists when nothing valid was sampled.

    Raises:
        ValueError: If ``batch_size`` is not positive.

    NOTE(review): "Num Valid" is the sum of the per-batch sequence lists
    returned by the sampler. Those appear to be already de-duplicated within
    each batch, so validity may be under-counted when a batch contains
    duplicates; confirm against ``sample_mol_eval``.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    collected_seqs = []
    total_generated = 0
    total_time = 0.0

    num_batches = (num_samples + batch_size - 1) // batch_size
    remaining = num_samples

    for b in range(num_batches):
        bs = min(batch_size, remaining)
        remaining -= bs

        # perf_counter is monotonic, so the elapsed time cannot be distorted
        # by wall-clock adjustments.
        t_start = time.perf_counter()
        result = sample_mol_eval(
            model=policy_model,
            reward_model=reward_model,
            tokenizer=tokenizer,
            steps=total_num_steps,
            mask=policy_model.interpolant.mask_token,
            pad=policy_model.interpolant.pad_token,
            batch_size=bs,
            max_length=max_length,
            quality_mode=quality_mode,
            num_remasking=num_remasking,
            quality_threshold=quality_threshold,
            unmask_quality_threshold=unmask_quality_threshold,
            evaluator=evaluator,
            dataframe=True,
        )
        elapsed = time.perf_counter() - t_start

        # Only the sequences are needed; every metric is recomputed below
        # over the pooled samples rather than averaged per batch.
        unique_seqs, _qed, _sa, _valid_frac, _uniq, _div, _qual, _df = result
        collected_seqs.extend(unique_seqs)
        total_generated += bs
        total_time += elapsed
        print(f" Batch {b+1}/{num_batches}: {len(unique_seqs)} valid unique, "
              f"time={elapsed:.1f}s")

    # --- Aggregate metrics over all samples ---
    # dict.fromkeys de-duplicates while keeping first-seen order, so the
    # output is reproducible for a fixed seed (set ordering is not).
    all_unique = list(dict.fromkeys(collected_seqs))
    num_valid = len(collected_seqs)
    num_unique = len(all_unique)

    # Guard both ratios so that an empty run yields zeros instead of raising.
    validity = num_valid / total_generated * 100.0 if total_generated > 0 else 0.0
    uniqueness = num_unique / num_valid * 100.0 if num_valid > 0 else 0.0

    # Diversity is only computed when there are at least two molecules.
    diversity = evaluator(all_unique) if num_unique > 1 else 0.0

    qed_vals = []
    sa_vals = []
    mean_qed = 0.0
    mean_sa = 0.0
    quality = 0.0
    if num_unique > 0:
        qed_vals = Oracle('qed')(all_unique)
        sa_vals = Oracle('sa')(all_unique)
        mean_qed = float(np.mean(qed_vals))
        mean_sa = float(np.mean(sa_vals))
        # Quality: unique sequences with QED >= 0.6 AND SA <= 4, expressed
        # as a share of everything that was generated.
        num_quality = sum(1 for q, s in zip(qed_vals, sa_vals) if q >= 0.6 and s <= 4)
        quality = num_quality / total_generated * 100.0

    metrics = {
        'Validity (%)': validity,
        'Uniqueness (%)': uniqueness,
        'QED': mean_qed,
        'Synthetic Accessibility': mean_sa,
        'Quality (%)': quality,
        'Diversity': diversity,
        'Sampling Time (s)': total_time,
        'Num Generated': total_generated,
        'Num Valid': num_valid,
        'Num Unique': num_unique,
    }
    return metrics, all_unique, qed_vals, sa_vals
def main(argv=None):
    """Evaluate a finetuned checkpoint from the command line and save results.

    Parses the command-line options, loads the finetuned policy model,
    samples molecules, prints the metric table, and writes
    ``eval_metrics_<tag>.csv`` (plus ``eval_smiles_<tag>.csv`` when at least
    one unique molecule was produced) to the output directory.

    Args:
        argv: Optional list of argument strings; ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description="Evaluate a finetuned mol checkpoint")
    parser.add_argument('--checkpoint_path', type=str, required=True,
                        help='Path to the finetuned Lightning checkpoint (e.g., last.ckpt)')
    parser.add_argument('--pretrained_ckpt', type=str,
                        default=os.path.join(REPO_ROOT, 'pretrained', 'anylength_mol.ckpt'),
                        help='Path to the pretrained base model checkpoint '
                             '(defaults to <repo>/pretrained/anylength_mol.ckpt)')
    parser.add_argument('--num_samples', type=int, default=1000,
                        help='Number of molecules to sample')
    parser.add_argument('--batch_size', type=int, default=50,
                        help='Batch size for sampling')
    parser.add_argument('--max_length', type=int, default=256)
    parser.add_argument('--total_num_steps', type=int, default=256)
    parser.add_argument('--num_remasking', type=int, default=2)
    parser.add_argument('--disable_planner', action='store_true',
                        help='If set, disable remasking during evaluation (matches training mode)')
    parser.add_argument('--disable_insertion_planner', action='store_true',
                        help='If set, disable insertion quality filtering during evaluation')
    parser.add_argument('--disable_unmasking_planner', action='store_true',
                        help='If set, disable unmasking confidence planner during evaluation')
    parser.add_argument('--quality_threshold', type=float, default=0.5,
                        help='Threshold for insertion quality filtering during sampling')
    parser.add_argument('--unmask_quality_threshold', type=float, default=None,
                        help='If set, gate unmasking remasking on confidence: remask clean '
                             'tokens whose remasking_conf < threshold (overrides the '
                             'schedule-driven count). Default None = schedule-driven behavior.')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Directory to save results CSV. Defaults to checkpoint directory.')
    parser.add_argument('--device', type=str, default='cuda:0')
    parser.add_argument('--seed', type=int, default=42)
    args = parser.parse_args(argv)

    set_seed(args.seed, use_cuda=True)
    # Fall back to CPU when CUDA is not available on this machine.
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')

    print(f"Loading checkpoint: {args.checkpoint_path}")
    print(f"Pretrained base: {args.pretrained_ckpt}")
    print(f"Disable planner (no remasking): {args.disable_planner}")
    print(f"Disable insertion planner: {args.disable_insertion_planner}")
    print(f"Disable unmasking planner: {args.disable_unmasking_planner}")

    policy_model, _train_args, _config = load_finetuned_model(
        args.checkpoint_path, args.pretrained_ckpt, device=device
    )
    tokenizer = get_tokenizer()
    reward_model = MolScoringFunctions(['qed', 'sa'], device=device)
    evaluator = Evaluator('diversity')

    # Map flags to quality_mode: --disable_planner switches both planners off.
    insertion_enabled = not (args.disable_planner or args.disable_insertion_planner)
    unmasking_enabled = not (args.disable_planner or args.disable_unmasking_planner)
    if insertion_enabled and unmasking_enabled:
        quality_mode = "both"
    elif insertion_enabled:
        quality_mode = "insertion_only"
    elif unmasking_enabled:
        quality_mode = "unmasking_only"
    else:
        quality_mode = "none"

    print(f"\nSampling {args.num_samples} molecules (quality_mode={quality_mode})...")
    metrics, unique_smiles, qed_vals, sa_vals = evaluate_checkpoint(
        policy_model=policy_model,
        tokenizer=tokenizer,
        reward_model=reward_model,
        evaluator=evaluator,
        num_samples=args.num_samples,
        batch_size=args.batch_size,
        max_length=args.max_length,
        total_num_steps=args.total_num_steps,
        quality_mode=quality_mode,
        num_remasking=args.num_remasking,
        quality_threshold=args.quality_threshold,
        unmask_quality_threshold=args.unmask_quality_threshold,
        device=device,
    )

    # Print summary table
    print("\n" + "=" * 60)
    print(" De Novo Small Molecule Generation Results")
    print("=" * 60)
    for k, v in metrics.items():
        if isinstance(v, float):
            print(f" {k:<30s}: {v:.4f}")
        else:
            print(f" {k:<30s}: {v}")
    print("=" * 60)

    # Save results. The checkpoint path is made absolute first: for a bare
    # filename such as 'last.ckpt', dirname() alone would be '' and
    # os.makedirs('') raises.
    output_dir = args.output_dir or os.path.dirname(os.path.abspath(args.checkpoint_path))
    os.makedirs(output_dir, exist_ok=True)

    # The file tag is derived from the effective quality_mode, so a run with
    # both sub-planners disabled (mode "none") is not saved under the
    # "no_insertion_planner" name used by a different configuration.
    tag_by_mode = {
        "none": "no_planner",
        "unmasking_only": "no_insertion_planner",
        "insertion_only": "no_unmasking_planner",
        "both": "with_planner",
    }
    tag = tag_by_mode[quality_mode]

    metrics_path = os.path.join(output_dir, f'eval_metrics_{tag}.csv')
    pd.DataFrame([metrics]).to_csv(metrics_path, index=False)
    print(f"Metrics saved to: {metrics_path}")

    if unique_smiles:
        smiles_path = os.path.join(output_dir, f'eval_smiles_{tag}.csv')
        df = pd.DataFrame({
            'SMILES': unique_smiles,
            'QED': qed_vals,
            'SA': sa_vals,
        })
        df.to_csv(smiles_path, index=False)
        print(f"SMILES saved to: {smiles_path}")
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|