# ================================
# Generate SELFIES from ChemQ3MTP checkpoint
# LOADING THE MODEL & TOKENIZER
# ================================

import sys
import os
import torch

# --- Replicate local module loading exactly as in training ---
notebook_dir = os.getcwd()
chemq3mtp_path = os.path.join(notebook_dir, "ChemQ3MTP")

if chemq3mtp_path not in sys.path:
    sys.path.insert(0, chemq3mtp_path)

# Clean up duplicate ChemQ3MTP entries, keeping only the most recently added.
existing_paths = [p for p in sys.path if p.endswith("ChemQ3MTP")]
for path in existing_paths[:-1]:
    sys.path.remove(path)

# Now import from the local ChemQ3MTP folder.
from FastChemTokenizerHF import FastChemTokenizerSelfies
from ChemQ3MTP import ChemQ3MTPForCausalLM  # custom model class

# --- Load from checkpoint (same layout as saved in training) ---
checkpoint_dir = "./ChemQ"         # model checkpoint path
tokenizer_dir = "./selftok_core/"  # tokenizer is stored separately from the model

# BUG FIX: the original log line claimed the tokenizer was loaded from
# `checkpoint_dir` ("./ChemQ") while the code actually read './selftok_core/'.
print(f"Loading tokenizer from {tokenizer_dir}...")
tokenizer = FastChemTokenizerSelfies.from_pretrained(tokenizer_dir)

print(f"Loading ChemQ3MTP model from {checkpoint_dir}...")
model = ChemQ3MTPForCausalLM.from_pretrained(checkpoint_dir)

# --- Prepare for generation ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# Disable MTP mode for standard autoregressive generation.
if hasattr(model, 'set_mtp_training'):
    model.set_mtp_training(False)

try:
    # Tokenize the empty prompt; the tokenizer supplies the BOS token.
    input_ids = tokenizer("", return_tensors="pt").input_ids.to(device)

    with torch.no_grad():
        gen = model.generate(
            input_ids=input_ids,
            max_length=256,
            top_k=50,
            temperature=1.0,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            # BUG FIX: `early_stopping=True` removed -- it only applies to beam
            # search; transformers warned it was "not valid and may be ignored"
            # (see the saved stderr output of this cell).
        )

    result = tokenizer.decode(gen[0], skip_special_tokens=True)
    print("Generated SELFIES:")
    print(result)

except Exception as e:
    print(f"Generation failed: {e}")
    import traceback
    traceback.print_exc()
# ================================
# MOLECULAR GENERATION & EVALUATION PIPELINE
# ================================

# Core imports
import numpy as np
import pandas as pd
from rdkit import Chem, RDLogger, DataStructs
from rdkit.Chem import Descriptors, MACCSkeys
from rdkit.ML.Cluster import Butina
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
import random
import selfies as sf
import os

# Work around OpenMP runtime clashes (common with RDKit + torch on Windows/conda).
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
os.environ['OMP_NUM_THREADS'] = '1'
# Suppress RDKit warnings
RDLogger.DisableLog('rdApp.*')

# ================================
# 1. GENERATE MOLECULES IN BATCHES
# (BUG FIX: header previously said "2.5K" but num_samples is 1000)
# ================================
print("🚀 Starting molecular generation pipeline...\n")

num_samples = 1000
batch_size = 5
num_batches = (num_samples + batch_size - 1) // batch_size  # ceil division

print(f"1. Generating {num_samples} molecules in batches of {batch_size}...")

gen = []
for batch_idx in tqdm(range(num_batches), desc="Generating batches", unit="batch"):
    # Never overshoot num_samples on the final batch.
    remaining = num_samples - len(gen)
    current_batch_size = min(batch_size, remaining)

    # NOTE(review): no input_ids are passed, so generate() starts from the
    # model's default BOS token -- presumably intended; confirm against the
    # model config.
    batch_gen = model.generate(
        max_length=512,
        num_return_sequences=current_batch_size,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        temperature=1.0,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id
    )
    gen.extend(batch_gen)

print(f"✅ Generated {len(gen)} sequences total")

# ================================
# 2. DECODE TO SMILES & VALIDATE
# ================================
print("\n2. Decoding SELFIES → SMILES and validating...")
valid_smiles = []

# BUG FIX: iterate over what was actually generated, not range(num_samples) --
# the original raised IndexError whenever len(gen) < num_samples.
for i in tqdm(range(len(gen)), desc="Decoding & validating", unit="mol"):
    try:
        selfies_str = tokenizer.decode(gen[i], skip_special_tokens=True)
        selfies_str = selfies_str.replace(' ', '')  # tokens come space-separated
        smiles = sf.decoder(selfies_str)
        mol = Chem.MolFromSmiles(smiles)
        # Keep only single-fragment, non-empty, RDKit-parseable molecules.
        if mol is not None and smiles.strip() != '' and '.' not in smiles:
            valid_smiles.append(smiles)
    except Exception:
        # Best-effort: malformed SELFIES/SMILES are simply skipped.
        # (Narrowed from a bare `except:` so KeyboardInterrupt still works.)
        continue

print(f"✅ Valid SMILES: {len(valid_smiles)} ({100 * len(valid_smiles)/num_samples:.2f}%)")
# ================================
# 3. LIPINSKI EVALUATION
# ================================
def passes_lipinski(mol):
    """Return True when `mol` has zero violations of Lipinski's Rule of 5.

    Rules checked: MW <= 500, logP <= 5, H-bond donors <= 5,
    H-bond acceptors <= 10.
    """
    mw = Descriptors.MolWt(mol)
    logp = Descriptors.MolLogP(mol)
    hbd = Descriptors.NumHDonors(mol)
    hba = Descriptors.NumHAcceptors(mol)
    violations = 0
    if mw > 500: violations += 1
    if logp > 5: violations += 1
    if hbd > 5: violations += 1
    if hba > 10: violations += 1
    return violations == 0

print("\n3. Evaluating Lipinski's Rule of 5...")
lipinski_pass = 0
for smiles in tqdm(valid_smiles, desc="Lipinski evaluation", unit="mol"):
    mol = Chem.MolFromSmiles(smiles)
    # BUG FIX: guard against a re-parse failure -- the original passed a
    # possible None straight into the RDKit descriptor calls.
    if mol is not None and passes_lipinski(mol):
        lipinski_pass += 1

# BUG FIX: guard division -- valid_smiles may be empty if generation failed.
if valid_smiles:
    print(f"✅ Lipinski-compliant: {lipinski_pass} ({100 * lipinski_pass/len(valid_smiles):.2f}% of valid)")
else:
    print("✅ Lipinski-compliant: 0 (no valid SMILES to evaluate)")

# ================================
# 4. MACCS FINGERPRINTS
# ================================
def calculate_maccs_fingerprints(smiles_list):
    """Compute MACCS keys for each parseable SMILES.

    Returns (fingerprints, mols): parallel lists containing only the entries
    whose SMILES parsed successfully; unparseable inputs are skipped.
    """
    fingerprints = []
    mols = []

    for smiles in tqdm(smiles_list, desc="Computing MACCS fingerprints", unit="mol"):
        mol = Chem.MolFromSmiles(smiles)
        if mol is not None:
            fp = MACCSkeys.GenMACCSKeys(mol)
            fingerprints.append(fp)
            mols.append(mol)

    return fingerprints, mols

print("\n4. Computing molecular fingerprints...")
generated_fingerprints, generated_mols = calculate_maccs_fingerprints(valid_smiles)
# ================================
# 5. INTERNAL DIVERSITY
# ================================
def calculate_internal_diversity(fingerprints, sample_size=1000):
    """Mean pairwise Tanimoto *distance* (1 - similarity) over the set.

    Down-samples to at most `sample_size` fingerprints (seeded, so the
    result is reproducible) before the O(n^2) pairwise sweep. Returns 0.0
    for fewer than two fingerprints.
    """
    if len(fingerprints) < 2:
        return 0.0

    # Reproducible down-sampling keeps the quadratic sweep affordable.
    pool = fingerprints
    if len(pool) > sample_size:
        random.seed(42)
        pool = random.sample(pool, sample_size)

    count = len(pool)
    pair_total = count * (count - 1) // 2
    pair_distances = []

    with tqdm(total=pair_total, desc="Computing diversity", unit="pair") as progress:
        for left in range(count):
            for right in range(left + 1, count):
                similarity = DataStructs.TanimotoSimilarity(pool[left], pool[right])
                pair_distances.append(1 - similarity)
                progress.update(1)

    return np.mean(pair_distances) if pair_distances else 0.0

print("\n5. Calculating internal diversity...")
internal_div = calculate_internal_diversity(generated_fingerprints, sample_size=1000)
print(f"✅ Internal diversity: {internal_div:.4f}")
# ================================
# 6. CLUSTERING
# ================================
def cluster_molecules(fingerprints, cutoff=0.7):
    """Butina-cluster fingerprints at a Tanimoto *similarity* cutoff.

    At most 2000 fingerprints are clustered; larger inputs are down-sampled
    reproducibly (seed 42). Returns the tuple of clusters (each a tuple of
    indices into the sampled pool) from RDKit's Butina.ClusterData.
    """
    sample_size = min(2000, len(fingerprints))
    if len(fingerprints) > sample_size:
        random.seed(42)
        sampled_fps = random.sample(fingerprints, sample_size)
    else:
        sampled_fps = fingerprints

    dists = []
    n = len(sampled_fps)
    total_pairs = n * (n - 1) // 2

    with tqdm(total=total_pairs, desc="Computing distance matrix", unit="pair") as pbar:
        # BUG FIX: Butina.ClusterData(isDistData=True) expects the *lower
        # triangle* of the distance matrix in the order
        # [d(1,0), d(2,0), d(2,1), d(3,0), ...]. The original filled the list
        # in upper-triangle row order, which scrambles the distances for
        # n >= 4 and corrupts the clustering.
        for i in range(1, n):
            for j in range(i):
                dist = 1 - DataStructs.TanimotoSimilarity(sampled_fps[i], sampled_fps[j])
                dists.append(dist)
                pbar.update(1)

    print("Performing Butina clustering...")
    # Distance threshold = 1 - similarity cutoff.
    cluster_indices = Butina.ClusterData(dists, n, 1 - cutoff, isDistData=True)

    return cluster_indices

print("\n6. Performing molecular clustering...")
cluster_indices = cluster_molecules(generated_fingerprints, cutoff=0.7)
n_clusters = len(cluster_indices)
cluster_sizes = [len(cluster) for cluster in cluster_indices]
largest_cluster = max(cluster_sizes) if cluster_sizes else 0

print(f"✅ Number of clusters: {n_clusters}")
print(f"✅ Average cluster size: {np.mean(cluster_sizes):.2f}")
print(f"✅ Largest cluster: {largest_cluster} molecules")
# ================================
# 7. VISUALIZATION
# ================================
def visualize_clusters(fingerprints, smiles_list):
    """Plot PCA and t-SNE projections of the MACCS fingerprints.

    Samples at most 1000 fingerprints (seed 42), saves the figure to
    'molecular_clusters.png', and shows it. `smiles_list` is kept for
    interface compatibility but is not currently used in the plots.
    """
    print("Preparing fingerprint array...")
    fp_rows = []
    for fp in tqdm(fingerprints, desc="Converting fingerprints", unit="fp", leave=False):
        arr = np.zeros((1,))
        # RDKit resizes `arr` in place to the fingerprint length.
        DataStructs.ConvertToNumpyArray(fp, arr)
        fp_rows.append(arr)

    fp_array = np.array(fp_rows)

    if len(fp_array) > 1000:
        random.seed(42)
        indices = random.sample(range(len(fp_array)), 1000)
        fp_array = fp_array[indices]
        # (removed the original's dead `sampled_smiles` bookkeeping -- it was
        # assigned but never used)

    print("Computing PCA...")
    pca = PCA(n_components=2)
    pca_coords = pca.fit_transform(fp_array)

    print("Computing t-SNE...")
    tsne = TSNE(n_components=2, random_state=42,
                perplexity=min(30, len(fp_array) - 1), max_iter=300)
    tsne_coords = tsne.fit_transform(fp_array)

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))

    axes[0].scatter(pca_coords[:, 0], pca_coords[:, 1], alpha=0.6)
    axes[0].set_title(f'PCA of MACCS Fingerprints\n({len(fp_array)} molecules)')
    axes[0].set_xlabel('PC1')
    axes[0].set_ylabel('PC2')

    axes[1].scatter(tsne_coords[:, 0], tsne_coords[:, 1], alpha=0.6)
    axes[1].set_title(f't-SNE of MACCS Fingerprints\n({len(fp_array)} molecules)')
    axes[1].set_xlabel('t-SNE 1')
    axes[1].set_ylabel('t-SNE 2')

    plt.tight_layout()
    plt.savefig('molecular_clusters.png', dpi=300, bbox_inches='tight')
    plt.show()

print("\n7. Creating visualizations...")
if len(generated_fingerprints) > 10:
    visualize_clusters(generated_fingerprints, valid_smiles)

# ================================
# 8. FINAL SUMMARY
# ================================
print("\n" + "=" * 60)
print("📊 FINAL EVALUATION SUMMARY")
print("=" * 60)
print(f"Total generated: {num_samples}")
print(f"Valid SMILES: {len(valid_smiles)} ({100 * len(valid_smiles)/num_samples:.1f}%)")
# BUG FIX: guard division -- the original crashed with ZeroDivisionError when
# no SMILES were valid.
if valid_smiles:
    print(f"Lipinski-compliant: {lipinski_pass} ({100 * lipinski_pass/len(valid_smiles):.1f}% of valid)")
else:
    print("Lipinski-compliant: 0 (no valid SMILES)")
print(f"Internal diversity: {internal_div:.4f}")
print(f"MACCS clusters (≥0.7): {n_clusters}")
print(f"Average cluster size: {np.mean(cluster_sizes):.2f}")
print(f"Largest cluster size: {largest_cluster}")
print("=" * 60)

results = {
    'total_generated': num_samples,
    'valid_smiles': len(valid_smiles),
    'validity_rate': len(valid_smiles) / num_samples,
    'lipinski_compliant': lipinski_pass,
    'lipinski_rate_valid': lipinski_pass / len(valid_smiles) if valid_smiles else 0,
    'internal_diversity': internal_div,
    'num_clusters': n_clusters,
    'avg_cluster_size': np.mean(cluster_sizes) if cluster_sizes else 0,
    'largest_cluster_size': largest_cluster
}

# BUG FIX: the filename was hard-coded to "generated_valid_2500.smi" even
# though only `num_samples` (1000) molecules are generated; derive it instead.
output_smi = f"generated_valid_{num_samples}.smi"
with open(output_smi, "w") as f:
    for smi in valid_smiles:
        f.write(smi + "\n")

print("\n✅ Results dictionary created")
print(f"✅ Valid SMILES saved to '{output_smi}'")
print("✅ Cluster visualization saved as 'molecular_clusters.png'")
"language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.13.0" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +}