| | |
| | |
import os
import sys
import json
import argparse
import random
import functools
from typing import List, Optional

from tqdm import tqdm
import torch
from rdkit import Chem
from rdkit.Chem import AllChem
from rdkit import RDLogger
import selfies as sf
import pandas as pd
| |
|
| | |
# Silence every RDKit "rdApp.*" log channel (e.g. parse errors for invalid
# generated SMILES) so they do not clutter the evaluation output.
RDLogger.DisableLog('rdApp.*')

# Put this script's own directory first on the import path so the sibling
# modules imported below resolve regardless of the current working directory.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, _SCRIPT_DIR)
| |
|
| | from FastChemTokenizerHF import FastChemTokenizerSelfies |
| | from ChemQ3MTP import ChemQ3MTPForCausalLM |
| |
|
| | |
| | |
| | |
| |
|
def selfies_to_smiles(selfies_str: str) -> Optional[str]:
    """Decode a SELFIES string into SMILES.

    Spaces are stripped first, because they are tokenizer artifacts and not part of
    the SELFIES grammar.

    Returns:
        The decoded SMILES, or None if anything goes wrong, including non-string input.
    """
    try:
        compact = "".join(selfies_str.split(" "))
        return sf.decoder(compact)
    except Exception:
        return None
| |
|
| |
|
def is_valid_smiles(smiles: str) -> bool:
    """Return True if *smiles* is a single connected molecule with at least 3 heavy atoms.

    The input is rejected if:
    - it is not a string, or is blank after stripping whitespace;
    - it contains '.', which separates disconnected components in SMILES;
    - RDKit cannot parse it, or raises while doing so;
    - the molecule has fewer than 3 non-hydrogen atoms.
    """
    if not (isinstance(smiles, str) and smiles.strip()):
        return False

    candidate = smiles.strip()

    # Reject multi-component SMILES before paying for a parse.
    if "." in candidate:
        return False

    try:
        molecule = Chem.MolFromSmiles(candidate)
        return molecule is not None and molecule.GetNumHeavyAtoms() >= 3
    except Exception:
        return False
| |
|
| | def passes_durrant_lab_filter(smiles: str) -> bool: |
| | """ |
| | Apply Durrant's lab filter to remove improbable substructures. |
| | FIXED: More robust error handling, pattern checking, and disconnected molecule rejection. |
| | Returns True if molecule passes the filter (is acceptable), False otherwise. |
| | """ |
| | if not smiles or not isinstance(smiles, str) or len(smiles.strip()) == 0: |
| | return False |
| | |
| | try: |
| | mol = Chem.MolFromSmiles(smiles.strip()) |
| | if mol is None: |
| | return False |
| | |
| | |
| | if mol.GetNumHeavyAtoms() < 3: |
| | return False |
| | |
| | |
| | fragments = Chem.rdmolops.GetMolFrags(mol, asMols=False) |
| | if len(fragments) > 1: |
| | return False |
| | |
| | |
| | problematic_patterns = [ |
| | "C=[N-]", |
| | "[N-]C=[N+]", |
| | "[nH+]c[n-]", |
| | "[#7+]~[#7+]", |
| | "[#7-]~[#7-]", |
| | "[!#7]~[#7+]~[#7-]~[!#7]", |
| | "[#5]", |
| | "O=[PH](=O)([#8])([#8])", |
| | "N=c1cc[#7]c[#7]1", |
| | "[$([NX2H1]),$([NX3H2])]=C[$([OH]),$([O-])]", |
| | ] |
| | |
| | |
| | metal_exclusions = {11, 12, 19, 20} |
| | for atom in mol.GetAtoms(): |
| | atomic_num = atom.GetAtomicNum() |
| | |
| | if atomic_num > 20 and atomic_num not in metal_exclusions: |
| | return False |
| | |
| | |
| | for pattern in problematic_patterns: |
| | try: |
| | patt_mol = Chem.MolFromSmarts(pattern) |
| | if patt_mol is not None: |
| | matches = mol.GetSubstructMatches(patt_mol) |
| | if matches: |
| | return False |
| | except Exception: |
| | |
| | continue |
| | |
| | return True |
| | |
| | except Exception: |
| | return False |
| |
|
| |
|
def get_sa_label_and_confidence(selfies_str: str) -> tuple[str, float]:
    """Classify the synthetic accessibility (SA) of a SELFIES string.

    Uses the SA classifier from ``ChemQ3MTP.rl_utils.get_sa_classifier``. The input is
    truncated to 128 tokens.

    Returns:
        ``(label, score)`` for the top prediction, for example an "Easy"/"Hard" label
        with its confidence. Returns ``("Unknown", 0.0)`` in these best-effort cases:
        the classifier module cannot be imported, no classifier is available, or
        inference fails.
    """
    fallback = ("Unknown", 0.0)
    try:
        from ChemQ3MTP.rl_utils import get_sa_classifier

        sa_classifier = get_sa_classifier()
        if sa_classifier is None:
            return fallback

        predictions = sa_classifier(selfies_str, truncation=True, max_length=128)
        best = predictions[0]
        return best["label"], best["score"]
    except Exception:
        return fallback
| |
|
def get_morgan_fingerprint_from_smiles(smiles: str, radius=2, n_bits=2048):
    """Return the Morgan bit-vector fingerprint of *smiles*.

    Args:
        smiles: SMILES string to fingerprint.
        radius: Morgan radius.
        n_bits: Length of the bit vector.

    Returns:
        The fingerprint, or None if RDKit cannot parse *smiles*.
    """
    molecule = Chem.MolFromSmiles(smiles)
    if molecule is not None:
        return AllChem.GetMorganFingerprintAsBitVect(molecule, radius, nBits=n_bits)
    return None
| |
|
def tanimoto_sim(fp1, fp2):
    """Return the Tanimoto similarity between two RDKit fingerprints."""
    from rdkit import DataStructs

    return DataStructs.TanimotoSimilarity(fp1, fp2)
| |
|
| | |
| | |
| | |
| |
|
# Sampling settings shared by the MTP-aware and the standard generation paths.
_SAMPLING_KWARGS = {"do_sample": True, "top_k": 50, "top_p": 0.95, "temperature": 1.0}


def _sample_standard(model, tokenizer, input_ids, max_gen_len: int) -> List[str]:
    """Sample one batch with ``model.generate`` and decode each sequence to text."""
    gen_tokens = model.generate(
        input_ids,
        max_length=max_gen_len,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        **_SAMPLING_KWARGS,
    )
    return [tokenizer.decode(seq, skip_special_tokens=True) for seq in gen_tokens]


def _generate_raw_selfies(model, tokenizer, device, n_samples: int, batch_size: int,
                          max_gen_len: int, max_new_tokens: int):
    """Generate up to ``n_samples`` raw decoded SELFIES strings, in batches.

    Uses ``model.generate_with_logprobs`` when the model provides it, and falls back to
    ``model.generate`` otherwise. After the first MTP failure, the remaining batches go
    straight to the standard path instead of retrying MTP and warning on every batch.

    Returns:
        ``(samples, mode)``. ``mode`` is "MTP-aware", "standard" or "mixed", depending
        on which path actually produced the batches.
    """
    has_mtp = hasattr(model, "generate_with_logprobs")
    use_mtp = has_mtp
    mtp_batches = 0
    standard_batches = 0
    samples: List[str] = []
    num_batches = (n_samples + batch_size - 1) // batch_size

    with torch.no_grad():
        for _ in tqdm(range(num_batches), desc="Generating"):
            current_batch_size = min(batch_size, n_samples - len(samples))
            if current_batch_size <= 0:
                break

            # Every sequence starts from a single BOS token.
            input_ids = torch.full(
                (current_batch_size, 1),
                tokenizer.bos_token_id,
                dtype=torch.long,
                device=device,
            )

            batch_selfies = None
            if use_mtp:
                try:
                    outputs = model.generate_with_logprobs(
                        input_ids=input_ids,
                        max_new_tokens=max_new_tokens,
                        return_probs=True,
                        tokenizer=tokenizer,
                        **_SAMPLING_KWARGS,
                    )
                    # NOTE(review): assumes outputs[0] is the list of decoded strings.
                    batch_selfies = outputs[0]
                except Exception as e:
                    print(f"⚠️ MTP generation failed: {e}. Falling back.")
                    print(" Using standard generation for the remaining batches.")
                    use_mtp = False

            if batch_selfies is None:
                batch_selfies = _sample_standard(model, tokenizer, input_ids, max_gen_len)
                standard_batches += 1
            else:
                mtp_batches += 1

            samples.extend(batch_selfies)
            if len(samples) >= n_samples:
                break

    if mtp_batches and standard_batches:
        mode = "mixed"
    elif mtp_batches or (has_mtp and not standard_batches):
        mode = "MTP-aware"
    else:
        mode = "standard"
    return samples, mode


def _collect_valid_records(raw_samples):
    """Decode raw SELFIES and keep those whose SMILES pass both validity filters."""
    records = []
    for raw_selfies in tqdm(raw_samples, desc="Converting"):
        clean_selfies = raw_selfies.replace(" ", "")
        smiles = selfies_to_smiles(clean_selfies)
        if smiles is not None and is_valid_smiles(smiles) and passes_durrant_lab_filter(smiles):
            records.append({
                "raw_selfies": raw_selfies,
                "selfies_clean": clean_selfies,
                "selfies": clean_selfies,
                "smiles": smiles.strip(),
            })
    return records


def _print_debug_samples(valid_records, sa_lookup) -> None:
    """Print up to 5 example molecules and an SA-label breakdown of up to the first 100."""
    print("\n🔍 DEBUG: Sample generated molecules")
    print("-" * 70)
    for i, example in enumerate(valid_records[:5]):
        raw = example["raw_selfies"]
        print(f"Example {i+1}:")
        print(f" Raw SELFIES : {raw[:80]}{'...' if len(raw) > 80 else ''}")
        print(f" SMILES : {example['smiles']}")

        label, confidence = sa_lookup(raw)
        print(f" SA Label : {label} (confidence: {confidence:.3f})")

        if i == 0:
            # Sanity-check the SA classifier once on two fixed reference inputs.
            simple_label, simple_conf = sa_lookup('[C]')
            benzene_label, benzene_conf = sa_lookup('[c] [c] [c] [c] [c] [c] [Ring1] [=Branch1]')
            print(f" 🧪 SA Test - Simple molecule: {simple_label} ({simple_conf:.3f})")
            print(f" 🧪 SA Test - Benzene: {benzene_label} ({benzene_conf:.3f})")

        mol = Chem.MolFromSmiles(example['smiles'])
        if mol:
            print(f" Atoms : {mol.GetNumAtoms()}")
            print(f" Bonds : {mol.GetNumBonds()}")
        print()
    print("-" * 70)

    sa_labels = [sa_lookup(r["raw_selfies"])[0] for r in valid_records[:100]]
    # The denominator is the number actually labelled, which can be below 100.
    n_labeled = len(sa_labels)
    denom = max(1, n_labeled)
    easy_count = sa_labels.count("Easy")
    hard_count = sa_labels.count("Hard")
    unknown_count = sa_labels.count("Unknown")

    print(f"🔍 SA Label Analysis (first {n_labeled} molecules):")
    print(f" Easy to synthesize: {easy_count}/{n_labeled} ({easy_count / denom * 100:.1f}%)")
    print(f" Hard to synthesize: {hard_count}/{n_labeled} ({hard_count / denom * 100:.1f}%)")
    if unknown_count > 0:
        print(f" Unknown/Failed: {unknown_count}/{n_labeled} ({unknown_count / denom * 100:.1f}%)")


def _internal_diversity(records) -> float:
    """Return 1 - mean pairwise Tanimoto similarity of the records' Morgan fingerprints.

    Records whose SMILES cannot be fingerprinted are skipped. Returns 0.0 when fewer
    than two fingerprints remain.
    """
    from rdkit import DataStructs

    fps = [
        fp for fp in (get_morgan_fingerprint_from_smiles(r["smiles"]) for r in records)
        if fp is not None
    ]
    if len(fps) < 2:
        return 0.0

    # Compare each fingerprint with all earlier ones, so every unordered pair is counted once.
    total_sim = 0.0
    for i in range(1, len(fps)):
        total_sim += sum(DataStructs.BulkTanimotoSimilarity(fps[i], fps[:i]))
    n_pairs = len(fps) * (len(fps) - 1) // 2
    return 1.0 - total_sim / n_pairs


def evaluate_model(
    model_path: str,
    train_data_path: str = "../data/chunk_5.csv",
    n_samples: int = 1000,
    seed: int = 42,
    max_gen_len: int = 32,
    max_new_tokens: int = 25,
    batch_size: int = 32,
    tokenizer_path: str = "../selftok_core",
):
    """Sample molecules from a ChemQ3MTP checkpoint and report generation metrics.

    The reported metrics are validity, uniqueness, novelty, SA labels and internal
    diversity. They are printed as a summary and written to
    ``<model_path>/evaluation_summary.json``.

    Args:
        model_path: Checkpoint directory, loaded with ``from_pretrained``. The JSON
            summary is also written here.
        train_data_path: CSV with a ``SELFIES`` column, used for the novelty metric.
        n_samples: Number of sequences to sample.
        seed: Seed for ``torch`` and ``random``.
        max_gen_len: ``max_length`` for the standard ``model.generate`` path.
        max_new_tokens: ``max_new_tokens`` for the MTP-aware path.
        batch_size: Sequences sampled per generation call. Must be >= 1.
        tokenizer_path: Path passed to ``FastChemTokenizerSelfies.from_pretrained``.

    Returns:
        The dict of metrics that is also saved as JSON.

    Raises:
        ValueError: If ``batch_size`` < 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    torch.manual_seed(seed)
    random.seed(seed)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Evaluating model at: {model_path}")
    print(f" Device: {device} | Samples: {n_samples} | Seed: {seed}\n")

    tokenizer = FastChemTokenizerSelfies.from_pretrained(tokenizer_path)
    model = ChemQ3MTPForCausalLM.from_pretrained(model_path)
    model.to(device)
    model.eval()

    print("📂 Loading and normalizing training set for novelty...")
    train_df = pd.read_csv(train_data_path)
    train_selfies_clean = {s.replace(" ", "") for s in train_df["SELFIES"].dropna().astype(str)}
    print(f" Training set size: {len(train_selfies_clean)} unique (space-free) SELFIES\n")

    has_mtp = hasattr(model, "generate_with_logprobs")
    print(f"GenerationStrategy: Using {'MTP-aware' if has_mtp else 'standard'} generation...")
    all_selfies_raw, generation_mode = _generate_raw_selfies(
        model, tokenizer, device, n_samples, batch_size, max_gen_len, max_new_tokens
    )
    all_selfies_raw = all_selfies_raw[:n_samples]
    n_generated = len(all_selfies_raw)
    print(f"\n✅ Generated {n_generated} raw SELFIES strings.\n")

    print("🧪 Processing SELFIES and converting to SMILES...")
    valid_records = _collect_valid_records(all_selfies_raw)

    # Memoise SA predictions, because the debug section and the summary classify the
    # same strings. Assumes the classifier is deterministic for a given input.
    sa_cache = {}

    def sa_lookup(selfies_str):
        if selfies_str not in sa_cache:
            sa_cache[selfies_str] = get_sa_label_and_confidence(selfies_str)
        return sa_cache[selfies_str]

    if valid_records:
        _print_debug_samples(valid_records, sa_lookup)
    else:
        print("\n⚠️ WARNING: No valid molecules generated in sample!")

    validity = len(valid_records) / n_samples if n_samples > 0 else 0.0

    # Deduplicate on space-free SELFIES; for duplicate keys the last record is kept.
    unique_valid = list({r["selfies_clean"]: r for r in valid_records}.values())
    uniqueness = len(unique_valid) / len(valid_records) if valid_records else 0.0

    novel_count = sum(1 for r in unique_valid if r["selfies_clean"] not in train_selfies_clean)
    novelty = novel_count / len(unique_valid) if unique_valid else 0.0

    sa_labels_all = [sa_lookup(r["raw_selfies"])[0] for r in unique_valid]
    easy_total = sa_labels_all.count("Easy")
    hard_total = sa_labels_all.count("Hard")
    unknown_total = sa_labels_all.count("Unknown")
    total_labeled = len(sa_labels_all)
    denom = max(1, total_labeled)
    easy_pct = easy_total / denom * 100
    hard_pct = hard_total / denom * 100
    unknown_pct = unknown_total / denom * 100

    internal_diversity = _internal_diversity(unique_valid)

    mode_display = {"MTP-aware": "MTP-aware", "standard": "Standard", "mixed": "Mixed"}[generation_mode]
    print("\n" + "=" * 55)
    print("📊 MOLECULAR GENERATION EVALUATION SUMMARY")
    print("=" * 55)
    print(f"Model Path : {model_path}")
    print(f"Generation Mode : {mode_display}")
    print(f"Samples Generated: {n_generated}")
    print("-" * 55)
    print(f"Validity : {validity:.4f} ({len(valid_records)}/{n_samples})")
    print(f"Uniqueness : {uniqueness:.4f} (unique valid)")
    print(f"Novelty (vs train): {novelty:.4f} (space-free SELFIES)")
    print(f"Synthesis Labels : Easy: {easy_total}/{total_labeled} ({easy_pct:.1f}%) | Hard: {hard_total}/{total_labeled} ({hard_pct:.1f}%)")
    if unknown_total > 0:
        print(f" Unknown: {unknown_total}/{total_labeled} ({unknown_pct:.1f}%)")
    print(f"Internal Diversity: {internal_diversity:.4f} (1 - avg Tanimoto)")
    print("=" * 55)

    results = {
        "model_path": model_path,
        "generation_mode": generation_mode,
        "n_samples": n_samples,
        "n_generated": n_generated,
        "validity": validity,
        "uniqueness": uniqueness,
        "novelty": novelty,
        "sa_easy_count": easy_total,
        "sa_hard_count": hard_total,
        "sa_easy_percentage": easy_pct,
        "sa_hard_percentage": hard_pct,
        "internal_diversity": internal_diversity,
        "valid_molecules_count": len(valid_records),
    }
    if unknown_total > 0:
        results["sa_unknown_count"] = unknown_total
        results["sa_unknown_percentage"] = unknown_pct

    output_json = os.path.join(model_path, "evaluation_summary.json")
    with open(output_json, "w") as f:
        json.dump(results, f, indent=2)
    print(f"\n💾 Results saved to: {output_json}")

    return results
| |
|
| | |
| | |
| | |
| |
|
| | if __name__ == "__main__": |
| | parser = argparse.ArgumentParser(description="Evaluate molecular generative model with MTP-aware generation") |
| | parser.add_argument("--model_path", type=str, required=True, help="Path to model checkpoint") |
| | parser.add_argument("--n_samples", type=int, default=1000, help="Number of molecules to generate") |
| | parser.add_argument("--seed", type=int, default=42, help="Random seed") |
| | parser.add_argument("--train_data", type=str, default="../data/chunk_5.csv", help="Training data CSV") |
| |
|
| | args = parser.parse_args() |
| | evaluate_model( |
| | model_path=args.model_path, |
| | train_data_path=args.train_data, |
| | n_samples=args.n_samples, |
| | seed=args.seed |
| | ) |