{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "6a1f3c2e",
   "metadata": {},
   "source": [
    "## 6. Query + Visualization\n",
    "\n",
    "Embed a query SMILES with the same `model` that produced the FAISS `index`, retrieve its `k` nearest neighbours from `smiles_list`, and draw the query next to its hits.\n",
    "\n",
    "**Prerequisites** — this notebook only contains section 6. The following names must already exist in the kernel (they are defined in earlier sections that are not part of this file):\n",
    "\n",
    "- `model`: the embedding model; it must provide `.encode(list_of_smiles, convert_to_numpy=True)`\n",
    "- `index`: the FAISS index built from those embeddings; its row ids are assumed to be positions in `smiles_list`\n",
    "- `smiles_list`: the SMILES strings that were indexed\n",
    "\n",
    "The scores shown are raw FAISS search scores. They are labelled as similarities on the assumption that `index` is an inner-product index over L2-normalised embeddings (i.e. cosine similarity) — **TODO: confirm**. For an L2 index they would be distances (smaller = closer)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "95d2a9e6",
   "metadata": {},
   "outputs": [],
   "source": [
    "from rdkit import Chem\n",
    "from rdkit.Chem import Draw\n",
    "import faiss\n",
    "import matplotlib.pyplot as plt\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4e7d019",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Query configuration\n",
    "query_smiles = [\"O=C1/C=C\\\\C=C2/N1C[C@@H]3CNC[C@H]2C3\"]  # Your query\n",
    "k = 10  # number of neighbours to retrieve\n",
    "\n",
    "# Encode & normalize the query. faiss.normalize_L2 normalizes in place and\n",
    "# requires a C-contiguous float32 array, so coerce defensively (no copy is\n",
    "# made if the encoder already returns a contiguous float32 array).\n",
    "query_emb = np.ascontiguousarray(\n",
    "    model.encode(query_smiles, convert_to_numpy=True), dtype=np.float32\n",
    ")\n",
    "faiss.normalize_L2(query_emb)\n",
    "\n",
    "distances, indices = index.search(query_emb, k)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c9d2a5f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Collect molecules and their panel labels\n",
    "mols = []\n",
    "labels = []\n",
    "\n",
    "# Add query molecule (with label \"Query\")\n",
    "query_mol = Chem.MolFromSmiles(query_smiles[0])\n",
    "if query_mol is None:\n",
    "    print(\"⚠️ Invalid query SMILES!\")\n",
    "else:\n",
    "    mols.append(query_mol)\n",
    "    labels.append(\"Query\")\n",
    "\n",
    "# Add top-k results. FAISS pads missing hits with id -1 (e.g. when the index\n",
    "# holds fewer than k vectors); smiles_list[-1] would silently return the\n",
    "# *last* SMILES in the list, so those slots must be skipped.\n",
    "for rank, (idx, sim_score) in enumerate(zip(indices[0], distances[0]), start=1):\n",
    "    if idx < 0:\n",
    "        continue\n",
    "    smi = smiles_list[idx]\n",
    "    mol = Chem.MolFromSmiles(smi)\n",
    "    if mol is not None:\n",
    "        mols.append(mol)\n",
    "        labels.append(f\"#{rank} Sim: {sim_score:.3f}\")\n",
    "    else:\n",
    "        print(f\"⚠️ Invalid SMILES in result #{rank}: {smi}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3f18b6d",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Plot the query and its hits in a wrapped grid (at most MAX_COLS panels per\n",
    "# row) instead of one very wide single-row strip.\n",
    "MAX_COLS = 6\n",
    "n_mols = len(mols)\n",
    "if n_mols == 0:\n",
    "    print(\"No valid molecules to display.\")\n",
    "else:\n",
    "    n_cols = min(n_mols, MAX_COLS)\n",
    "    n_rows = -(-n_mols // n_cols)  # ceiling division\n",
    "    # squeeze=False always returns a 2-D array of Axes, so a single panel\n",
    "    # needs no special case.\n",
    "    fig, axes = plt.subplots(\n",
    "        n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False\n",
    "    )\n",
    "    for ax in axes.flat:\n",
    "        ax.axis('off')  # also blanks any unused panels in the last row\n",
    "    for ax, mol, label in zip(axes.flat, mols, labels):\n",
    "        ax.imshow(Draw.MolToImage(mol, size=(300, 300)))\n",
    "        ax.set_title(label, fontsize=12)\n",
    "    fig.tight_layout()\n",
    "    plt.show()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}