{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "a2508432", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Loading tokenizer from ./ChemQ...\n", "āœ… Special tokens bound: 0 1 2 3 4\n", "Loading ChemQ3MTP model from ./ChemQ...\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Generated SELFIES:\n", "[C] [=C] [C] [=C] [NH1] [C] [Branch2] [Ring2] [#Branch1] [C] [=C] [C] [=C] [Branch2] [Ring1] [O] [C] [=N] [N] [=C] [N] [Ring1] [Branch1] [C] [=C] [C] [=C] [N] [=C] [Ring1] [=Branch1] [N] [C] [=C] [C] [=C] [C] [=C] [Ring1] [=Branch1] [Ring1] [#C] [C] [=C] [Ring2] [Ring1] [Branch2] [=N] [C] [Ring2] [Ring1] [=N] [=C] [Ring2] [Ring1] [P]\n" ] } ], "source": [ "# ==============================\n", "# Generate SELFIES from ChemQ3MTP checkpoint\n", "# LOADING THE MODEL & TOKENIZER\n", "# ================================\n", "\n", "import sys\n", "import os\n", "import torch\n", "\n", "# --- Replicate local module loading exactly as in training ---\n", "notebook_dir = os.getcwd()\n", "chemq3mtp_path = os.path.join(notebook_dir, \"ChemQ3MTP\")\n", "\n", "if chemq3mtp_path not in sys.path:\n", " sys.path.insert(0, chemq3mtp_path)\n", "\n", "# Optional: clean up duplicate paths (as in your training script)\n", "existing_paths = [p for p in sys.path if p.endswith(\"ChemQ3MTP\")]\n", "for path in existing_paths[:-1]: # keep only the most recently added\n", " sys.path.remove(path)\n", "\n", "# Now import from local ChemQ3MTP folder\n", "from FastChemTokenizerHF import FastChemTokenizerSelfies\n", "from ChemQ3MTP import ChemQ3MTPForCausalLM # <-- your custom model\n", "\n", "# --- Load from checkpoint (same as saved in training) ---\n", "checkpoint_dir = \"./ChemQ\" # or your actual checkpoint path\n", "\n", "print(f\"Loading tokenizer from {checkpoint_dir}...\")\n", "tokenizer = FastChemTokenizerSelfies.from_pretrained('./selftok_core/')\n", "\n", "print(f\"Loading ChemQ3MTP model from {checkpoint_dir}...\")\n", "model = ChemQ3MTPForCausalLM.from_pretrained(checkpoint_dir)\n", "\n", "# --- Prepare for generation ---\n", "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "model.to(device)\n", "model.eval()\n", "\n", "# Disable MTP mode for standard autoregressive generation\n", "if hasattr(model, 'set_mtp_training'):\n", " model.set_mtp_training(False)\n", "\n", "try:\n", " # Tokenize start token\n", " input_ids = tokenizer(\"\", return_tensors=\"pt\").input_ids.to(device)\n", " \n", " with torch.no_grad():\n", " gen = model.generate(\n", " input_ids=input_ids,\n", " max_length=256,\n", " top_k=50,\n", " temperature=1.0,\n", " do_sample=True,\n", " pad_token_id=tokenizer.pad_token_id,\n", " eos_token_id=tokenizer.eos_token_id,\n", " early_stopping=True\n", " )\n", " \n", " result = tokenizer.decode(gen[0], skip_special_tokens=True)\n", " print(\"Generated SELFIES:\")\n", " print(result)\n", "\n", "except Exception as e:\n", " print(f\"Generation failed: {e}\")\n", " import traceback\n", " traceback.print_exc()" ] }, { "cell_type": "code", "execution_count": 3, "id": "5d2f7934", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "šŸš€ Starting molecular generation pipeline...\n", "\n", "1. 
 ] }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d74da4c0ea8c4ae9909843c02da92795", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Generating batches: 0%| | 0/200 [00:00<?, ?batch/s]" ] }, "metadata": {}, "output_type": "display_data" }, { "name": "stdout", "output_type": "stream", "text": [ "\n", "============================================================\n", "📊 FINAL EVALUATION SUMMARY\n", "============================================================\n", "Total generated: 1000\n", "Valid SMILES: 976 (97.6%)\n", "Lipinski-compliant: 687 (70.4% of valid)\n", "Internal diversity: 0.6387\n", "MACCS clusters (≥0.7): 448\n", "Average cluster size: 2.18\n", "Largest cluster size: 15\n", "============================================================\n", "\n", "✅ Results dictionary created\n", "✅ Valid SMILES saved to 'generated_valid_2500.smi'\n", "✅ Cluster visualization saved as 'molecular_clusters.png'\n" ] } ], "source": [ "# ================================\n", "# MOLECULAR GENERATION & EVALUATION PIPELINE\n", "# ================================\n", "\n", "# Core imports\n", "import numpy as np\n", "import pandas as pd\n", "from rdkit import Chem, RDLogger, DataStructs\n", "from rdkit.Chem import Descriptors, MACCSkeys\n", "from rdkit.ML.Cluster import Butina\n", "from tqdm.notebook import tqdm\n", "import matplotlib.pyplot as plt\n", "from sklearn.decomposition import PCA\n", "from sklearn.manifold import TSNE\n", "import random\n", "import selfies as sf\n", "import os\n", "os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'\n", "os.environ['OMP_NUM_THREADS'] = '1'\n", "# Suppress RDKit warnings\n", "RDLogger.DisableLog('rdApp.*')\n", "\n", "# ================================\n", "# 1. GENERATE MOLECULES IN BATCHES\n", "# ================================\n", "print(\"🚀 Starting molecular generation pipeline...\\n\")\n", "\n", "num_samples = 1000\n", "batch_size = 5\n", "num_batches = (num_samples + batch_size - 1) // batch_size  # ceiling division\n", "\n", "print(f\"1. Generating {num_samples} molecules in batches of {batch_size}...\")\n", "\n", "gen = []\n", "for batch_idx in tqdm(range(num_batches), desc=\"Generating batches\", unit=\"batch\"):\n", "    # Calculate how many to generate in this batch\n", "    remaining = num_samples - len(gen)\n", "    current_batch_size = min(batch_size, remaining)\n", "\n", "    batch_gen = model.generate(\n", "        max_length=512,\n", "        num_return_sequences=current_batch_size,\n", "        do_sample=True,\n", "        top_k=50,\n", "        top_p=0.95,\n", "        temperature=1.0,\n", "        pad_token_id=tokenizer.pad_token_id,\n", "        eos_token_id=tokenizer.eos_token_id\n", "    )\n", "    gen.extend(batch_gen)\n", "\n", "print(f\"✅ Generated {len(gen)} sequences total\")\n", "\n", "# ================================\n", "# 2. DECODE TO SMILES & VALIDATE\n", "# ================================\n", "print(\"\\n2. Decoding SELFIES → SMILES and validating...\")\n", "valid_smiles = []\n", "\n", "for i in tqdm(range(num_samples), desc=\"Decoding & validating\", unit=\"mol\"):\n", "    try:\n", "        selfies_str = tokenizer.decode(gen[i], skip_special_tokens=True)\n", "        selfies_str = selfies_str.replace(' ', '')\n", "        smiles = sf.decoder(selfies_str)\n", "        mol = Chem.MolFromSmiles(smiles)\n", "        # Keep only parsable, non-empty, single-fragment molecules\n", "        if mol is not None and smiles.strip() != '' and '.' not in smiles:\n",
"            valid_smiles.append(smiles)\n", "    except Exception:\n", "        continue\n", "\n", "print(f\"✅ Valid SMILES: {len(valid_smiles)} ({100 * len(valid_smiles)/num_samples:.2f}%)\")\n", "\n", "# ================================\n", "# 3. LIPINSKI EVALUATION\n", "# ================================\n", "def passes_lipinski(mol):\n", "    \"\"\"Return True if the molecule has zero Lipinski Rule-of-5 violations.\"\"\"\n", "    mw = Descriptors.MolWt(mol)\n", "    logp = Descriptors.MolLogP(mol)\n", "    hbd = Descriptors.NumHDonors(mol)\n", "    hba = Descriptors.NumHAcceptors(mol)\n", "    violations = 0\n", "    if mw > 500: violations += 1\n", "    if logp > 5: violations += 1\n", "    if hbd > 5: violations += 1\n", "    if hba > 10: violations += 1\n", "    return violations == 0\n", "\n", "print(\"\\n3. Evaluating Lipinski's Rule of 5...\")\n", "lipinski_pass = 0\n", "for smiles in tqdm(valid_smiles, desc=\"Lipinski evaluation\", unit=\"mol\"):\n", "    mol = Chem.MolFromSmiles(smiles)\n", "    if passes_lipinski(mol):\n", "        lipinski_pass += 1\n", "\n", "print(f\"✅ Lipinski-compliant: {lipinski_pass} ({100 * lipinski_pass/len(valid_smiles):.2f}% of valid)\")\n", "\n", "# ================================\n", "# 4. MACCS FINGERPRINTS\n", "# ================================\n", "def calculate_maccs_fingerprints(smiles_list):\n", "    \"\"\"Calculate MACCS fingerprints for a list of SMILES\"\"\"\n", "    fingerprints = []\n", "    mols = []\n", "\n", "    for smiles in tqdm(smiles_list, desc=\"Computing MACCS fingerprints\", unit=\"mol\"):\n", "        mol = Chem.MolFromSmiles(smiles)\n", "        if mol is not None:\n", "            fp = MACCSkeys.GenMACCSKeys(mol)\n", "            fingerprints.append(fp)\n", "            mols.append(mol)\n", "\n", "    return fingerprints, mols\n", "\n", "print(\"\\n4. Computing molecular fingerprints...\")\n", "generated_fingerprints, generated_mols = calculate_maccs_fingerprints(valid_smiles)\n", "\n", "# ================================\n", "# 5. INTERNAL DIVERSITY\n", "# ================================\n", "def calculate_internal_diversity(fingerprints, sample_size=1000):\n", "    \"\"\"Mean pairwise Tanimoto distance over (a sample of) the generated molecules.\"\"\"\n", "    if len(fingerprints) < 2:\n", "        return 0.0\n", "\n", "    if len(fingerprints) > sample_size:\n", "        random.seed(42)\n", "        sampled_fps = random.sample(fingerprints, sample_size)\n", "    else:\n", "        sampled_fps = fingerprints\n", "\n", "    distances = []\n", "    n = len(sampled_fps)\n", "    total_pairs = n * (n - 1) // 2\n", "\n", "    with tqdm(total=total_pairs, desc=\"Computing diversity\", unit=\"pair\") as pbar:\n", "        for i in range(n):\n", "            for j in range(i+1, n):\n", "                dist = 1 - DataStructs.TanimotoSimilarity(sampled_fps[i], sampled_fps[j])\n", "                distances.append(dist)\n", "                pbar.update(1)\n", "\n", "    return np.mean(distances) if distances else 0.0\n", "\n", "print(\"\\n5. Calculating internal diversity...\")\n", "internal_div = calculate_internal_diversity(generated_fingerprints, sample_size=1000)\n", "print(f\"✅ Internal diversity: {internal_div:.4f}\")\n", "\n", "# ================================\n", "# 6. CLUSTERING\n",
"# ================================\n", "def cluster_molecules(fingerprints, cutoff=0.7):\n", "    \"\"\"Cluster molecules using Butina clustering\"\"\"\n", "    sample_size = min(2000, len(fingerprints))\n", "    if len(fingerprints) > sample_size:\n", "        random.seed(42)\n", "        sampled_fps = random.sample(fingerprints, sample_size)\n", "    else:\n", "        sampled_fps = fingerprints\n", "\n", "    dists = []\n", "    n = len(sampled_fps)\n", "    total_pairs = n * (n - 1) // 2\n", "\n", "    with tqdm(total=total_pairs, desc=\"Computing distance matrix\", unit=\"pair\") as pbar:\n", "        for i in range(n):\n", "            for j in range(i+1, n):\n", "                dist = 1 - DataStructs.TanimotoSimilarity(sampled_fps[i], sampled_fps[j])\n", "                dists.append(dist)\n", "                pbar.update(1)\n", "\n", "    print(\"Performing Butina clustering...\")\n", "    cluster_indices = Butina.ClusterData(dists, n, 1 - cutoff, isDistData=True)\n", "\n", "    return cluster_indices\n", "\n", "print(\"\\n6. Performing molecular clustering...\")\n", "cluster_indices = cluster_molecules(generated_fingerprints, cutoff=0.7)\n", "n_clusters = len(cluster_indices)\n", "cluster_sizes = [len(cluster) for cluster in cluster_indices]\n", "largest_cluster = max(cluster_sizes) if cluster_sizes else 0\n", "\n", "print(f\"✅ Number of clusters: {n_clusters}\")\n", "print(f\"✅ Average cluster size: {np.mean(cluster_sizes):.2f}\")\n", "print(f\"✅ Largest cluster: {largest_cluster} molecules\")\n", "\n", "# ================================\n", "# 7. VISUALIZATION\n", "# ================================\n", "def visualize_clusters(fingerprints, smiles_list):\n", "    \"\"\"Visualize clusters using PCA and t-SNE\"\"\"\n", "    print(\"Preparing fingerprint array...\")\n", "    fp_array = []\n", "    for fp in tqdm(fingerprints, desc=\"Converting fingerprints\", unit=\"fp\", leave=False):\n", "        arr = np.zeros((1,))\n", "        DataStructs.ConvertToNumpyArray(fp, arr)\n", "        fp_array.append(arr)\n", "\n", "    fp_array = np.array(fp_array)\n", "\n", "    if len(fp_array) > 1000:\n", "        random.seed(42)\n", "        indices = random.sample(range(len(fp_array)), 1000)\n", "        fp_array = fp_array[indices]\n", "        sampled_smiles = [smiles_list[i] for i in indices]\n", "    else:\n", "        sampled_smiles = smiles_list\n", "\n", "    print(\"Computing PCA...\")\n", "    pca = PCA(n_components=2)\n", "    pca_coords = pca.fit_transform(fp_array)\n", "\n", "    print(\"Computing t-SNE...\")\n", "    tsne = TSNE(n_components=2, random_state=42,\n", "                perplexity=min(30, len(fp_array)-1), max_iter=300)\n", "    tsne_coords = tsne.fit_transform(fp_array)\n", "\n", "    fig, axes = plt.subplots(1, 2, figsize=(15, 6))\n", "\n", "    axes[0].scatter(pca_coords[:, 0], pca_coords[:, 1], alpha=0.6)\n", "    axes[0].set_title(f'PCA of MACCS Fingerprints\\n({len(fp_array)} molecules)')\n", "    axes[0].set_xlabel('PC1')\n", "    axes[0].set_ylabel('PC2')\n", "\n", "    axes[1].scatter(tsne_coords[:, 0], tsne_coords[:, 1], alpha=0.6)\n", "    axes[1].set_title(f't-SNE of MACCS Fingerprints\\n({len(fp_array)} molecules)')\n", "    axes[1].set_xlabel('t-SNE 1')\n", "    axes[1].set_ylabel('t-SNE 2')\n", "\n", "    plt.tight_layout()\n", "    plt.savefig('molecular_clusters.png', dpi=300, bbox_inches='tight')\n", "    plt.show()\n", "\n", "print(\"\\n7. Creating visualizations...\")\n", "if len(generated_fingerprints) > 10:\n", "    visualize_clusters(generated_fingerprints, valid_smiles)\n", "\n", "# ================================\n", "# 8. FINAL SUMMARY\n",
"# ================================\n", "print(\"\\n\" + \"=\"*60)\n", "print(\"📊 FINAL EVALUATION SUMMARY\")\n", "print(\"=\"*60)\n", "print(f\"Total generated: {num_samples}\")\n", "print(f\"Valid SMILES: {len(valid_smiles)} ({100 * len(valid_smiles)/num_samples:.1f}%)\")\n", "print(f\"Lipinski-compliant: {lipinski_pass} ({100 * lipinski_pass/len(valid_smiles):.1f}% of valid)\")\n", "print(f\"Internal diversity: {internal_div:.4f}\")\n", "print(f\"MACCS clusters (≥0.7): {n_clusters}\")\n", "print(f\"Average cluster size: {np.mean(cluster_sizes):.2f}\")\n", "print(f\"Largest cluster size: {largest_cluster}\")\n", "print(\"=\"*60)\n", "\n", "results = {\n", "    'total_generated': num_samples,\n", "    'valid_smiles': len(valid_smiles),\n", "    'validity_rate': len(valid_smiles)/num_samples,\n", "    'lipinski_compliant': lipinski_pass,\n", "    'lipinski_rate_valid': lipinski_pass/len(valid_smiles) if valid_smiles else 0,\n", "    'internal_diversity': internal_div,\n", "    'num_clusters': n_clusters,\n", "    'avg_cluster_size': np.mean(cluster_sizes) if cluster_sizes else 0,\n", "    'largest_cluster_size': largest_cluster\n", "}\n", "\n", "with open(\"generated_valid_2500.smi\", \"w\") as f:\n", "    for smi in valid_smiles:\n", "        f.write(smi + \"\\n\")\n", "\n", "print(\"\\n✅ Results dictionary created\")\n", "print(\"✅ Valid SMILES saved to 'generated_valid_2500.smi'\")\n", "print(\"✅ Cluster visualization saved as 'molecular_clusters.png'\")" ] }, { "cell_type": "code", "execution_count": null, "id": "c1ae9a02", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.13.0" } }, "nbformat": 4, "nbformat_minor": 5 }