{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "b0e1d2c3-0000-4000-8000-000000000001",
   "metadata": {},
   "source": [
    "# SELFIES-based molecule generation\n",
    "\n",
    "Encode a SMILES string to SELFIES, use the `lamthuy/SelfiesGen` causal language\n",
    "model to complete or infill the token sequence, decode the result back to\n",
    "SMILES, and render the molecules in 3D with py3Dmol."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a1af2321-8860-4a3e-8406-a9ae587b97bf",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports: model/tokenizer loading, SELFIES <-> SMILES conversion, 3D rendering.\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer\n",
    "import selfies as sf\n",
    "from rdkit import Chem\n",
    "from rdkit.Chem import AllChem\n",
    "import py3Dmol\n",
    "\n",
    "\n",
    "def smiles_to_3d(smiles_list, width=400, height=300):\n",
    "    \"\"\"Render a list of SMILES strings as 3D stick models in a py3Dmol view.\n",
    "\n",
    "    Each molecule is parsed, protonated, embedded with a fixed random seed\n",
    "    (reproducible conformers), and UFF-optimized before rendering.\n",
    "\n",
    "    Raises ValueError if a SMILES string cannot be parsed or embedded.\n",
    "    \"\"\"\n",
    "    view = py3Dmol.view(width=width, height=height)\n",
    "    for smiles in smiles_list:\n",
    "        mol = Chem.MolFromSmiles(smiles)\n",
    "        if mol is None:\n",
    "            raise ValueError(f\"Invalid SMILES string: {smiles!r}\")\n",
    "\n",
    "        # Add explicit hydrogens so the 3D geometry is chemically sensible.\n",
    "        mol = Chem.AddHs(mol)\n",
    "\n",
    "        # EmbedMolecule returns -1 on failure; check it instead of feeding a\n",
    "        # conformer-less molecule to the optimizer/renderer.\n",
    "        if AllChem.EmbedMolecule(mol, randomSeed=42) != 0:\n",
    "            raise ValueError(f\"3D embedding failed for SMILES: {smiles!r}\")\n",
    "        AllChem.UFFOptimizeMolecule(mol)\n",
    "\n",
    "        view.addModel(Chem.MolToPDBBlock(mol), \"pdb\")\n",
    "    # Style/zoom apply to the whole scene, so set them once after the loop.\n",
    "    view.setStyle({\"stick\": {}})\n",
    "    view.zoomTo()\n",
    "    return view\n",
    "\n",
    "\n",
    "# Load the fine-tuned SELFIES generation checkpoint and its tokenizer.\n",
    "checkpoint_path = \"lamthuy/SelfiesGen\"\n",
    "model = AutoModelForCausalLM.from_pretrained(checkpoint_path)\n",
    "tokenizer = AutoTokenizer.from_pretrained(checkpoint_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9b7066a4-6637-4d45-a0d9-3cc5e2ca0409",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Aspirin as the reference molecule; render its 3D structure.\n",
    "smiles = \"CC(=O)OC1=CC=CC=C1C(=O)O\"\n",
    "smiles_to_3d([smiles])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05f9bf21-c998-4d63-870e-c1033ff91b31",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Encode the SMILES to SELFIES and append the separator token the model was\n",
    "# trained with, so generation continues after the prompt.\n",
    "s = sf.encoder(smiles) + \"[SEP]\"\n",
    "print(s)\n",
    "enc = tokenizer(s, return_tensors=\"pt\")\n",
    "input_ids = enc[\"input_ids\"]\n",
    "n = input_ids.size(1)\n",
    "# Beam-search 5 candidate completions. Passing attention_mask and pad_token_id\n",
    "# explicitly avoids the transformers open-end generation warnings.\n",
    "output_ids = model.generate(\n",
    "    input_ids,\n",
    "    attention_mask=enc[\"attention_mask\"],\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_length=128,\n",
    "    num_beams=5,\n",
    "    num_return_sequences=5,\n",
    "    early_stopping=True,\n",
    ")\n",
    "# Inspect the second-best beam; keep only the newly generated suffix.\n",
    "output = tokenizer.decode(output_ids[1][n:], skip_special_tokens=True)\n",
    "print(output)\n",
    "generated_smiles = sf.decoder(output)\n",
    "print(generated_smiles)\n",
    "smiles_to_3d([generated_smiles])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbe6cebd-c7d5-4da9-aac0-114f232cf147",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Infilling: mask a few prompt positions and let the model regenerate them.\n",
    "# NOTE: this mutates `input_ids` from the previous cell in place, so that cell\n",
    "# must be re-run before re-running this one.\n",
    "for pos in (5, 9, 11, 18):\n",
    "    input_ids[0][pos] = tokenizer.mask_token_id\n",
    "output_ids = model.generate(\n",
    "    input_ids,\n",
    "    attention_mask=enc[\"attention_mask\"],\n",
    "    pad_token_id=tokenizer.eos_token_id,\n",
    "    max_length=128,\n",
    "    num_beams=5,\n",
    "    num_return_sequences=5,\n",
    "    early_stopping=True,\n",
    ")\n",
    "# Decode the second-best beam's generated suffix back to SMILES and render it.\n",
    "output = tokenizer.decode(output_ids[1][n:], skip_special_tokens=True)\n",
    "print(output)\n",
    "generated_smiles = sf.decoder(output)\n",
    "print(generated_smiles)\n",
    "smiles_to_3d([generated_smiles])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.10"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}