{ "cells": [ { "cell_type": "markdown", "id": "5fa7908c-b785-49ce-9e5d-7c6ad6b4378b", "metadata": {}, "source": [ "## Imports and setup" ] }, { "cell_type": "code", "execution_count": 3, "id": "d0c96204-ea08-4330-b1bb-784b259ec32e", "metadata": {}, "outputs": [], "source": [ "import os\n", "import huggingface_hub" ] }, { "cell_type": "code", "execution_count": 4, "id": "6813e76b", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "The token has not been saved to the git credentials helper. Pass `add_to_git_credential=True` in this function directly or `--add-to-git-credential` if using via `huggingface-cli` if you want to set the git credential as well.\n", "Token is valid (permission: write).\n", "Your token has been saved to /share/kuleshov/yzs2/discrete-guidance/.hf_cache/token\n", "Login successful\n" ] } ], "source": [ "if os.path.exists(os.path.join(os.environ['HF_HOME'], 'token')):\n", " with open(os.path.join(os.environ['HF_HOME'], 'token'), 'r') as f:\n", " token = f.read().strip()\n", "else:\n", " token = None\n", "huggingface_hub.login(token=token)" ] }, { "cell_type": "code", "execution_count": 5, "id": "61cb2ac4", "metadata": {}, "outputs": [], "source": [ "import json\n", "import typing\n", "\n", "import datasets\n", "import numpy as np\n", "import pandas as pd\n", "import rdkit\n", "import transformers\n", "from rdkit import Chem as rdChem\n", "from rdkit.Chem import Crippen, QED\n", "from rdkit.Contrib.NP_Score import npscorer\n", "from rdkit.Contrib.SA_Score import sascorer\n", "from tqdm.auto import tqdm" ] }, { "cell_type": "code", "execution_count": 6, "id": "24444c85", "metadata": {}, "outputs": [], "source": [ "# TODO: Update to 2024.03.6 release when available instead of suppressing warning!\n", "# See: https://github.com/rdkit/rdkit/issues/7625#\n", "rdkit.rdBase.DisableLog('rdApp.warning')" ] }, { "cell_type": "markdown", "id": "902de4c5-dda5-4e4c-a4dd-f3b88015464e", "metadata": {}, "source": [ "## Create dataset" ] }, { "cell_type": "code", "execution_count": null, "id": "7b7a8986", "metadata": {}, "outputs": [], "source": [ "def parse_float(\n", " s: str\n", ") -> float:\n", " \"\"\"Parses floats potentially written as exponentiated values.\n", " \n", " Copied from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n", " \"\"\"\n", " try:\n", " return float(s)\n", " except ValueError:\n", " base, power = s.split('*^')\n", " return float(base) * 10**float(power)\n", "\n", "\n", "def count_rings_and_bonds(\n", " mol: rdChem.Mol, max_ring_size: int = -1\n", ") -> typing.Dict[str, int]:\n", " \"\"\"Counts bond and ring (by type).\"\"\"\n", " \n", " # Counting rings\n", " ssr = rdChem.GetSymmSSSR(mol)\n", " ring_count = len(ssr)\n", " \n", " ring_sizes = {} if max_ring_size < 0 else {i: 0 for i in range(3, max_ring_size+1)}\n", " for ring in ssr:\n", " ring_size = len(ring)\n", " if ring_size not in ring_sizes:\n", " ring_sizes[ring_size] = 0\n", " ring_sizes[ring_size] += 1\n", " \n", " # Counting bond types\n", " bond_counts = {\n", " 'single': 0,\n", " 'double': 0,\n", " 'triple': 0,\n", " 'aromatic': 0\n", " }\n", " \n", " for bond in mol.GetBonds():\n", " if bond.GetIsAromatic():\n", " bond_counts['aromatic'] += 1\n", " elif bond.GetBondType() == rdChem.BondType.SINGLE:\n", " bond_counts['single'] += 1\n", " elif bond.GetBondType() == rdChem.BondType.DOUBLE:\n", " bond_counts['double'] += 1\n", " elif bond.GetBondType() == rdChem.BondType.TRIPLE:\n", " bond_counts['triple'] += 1\n", " result = {\n", " 'ring_count': ring_count,\n", " }\n", " for k, v in ring_sizes.items():\n", " result[f\"R{k}\"] = v\n", "\n", " for k, v in bond_counts.items():\n", " result[f\"{k}_bond\"] = v\n", " return result\n", "\n", "\n", "def parse_xyz(\n", " filename: str,\n", " max_ring_size: int = -1,\n", " npscorer_model: typing.Optional[dict] = None,\n", " array_format: str = 'np'\n", ") -> typing.Dict[str, typing.Any]:\n", " \"\"\"Parses QM9 specific xyz files. \n", " \n", " See https://www.nature.com/articles/sdata201422/tables/2 for reference.\n", " Adapted from https://www.kaggle.com/code/tawe141/extracting-data-from-qm9-xyz-files/code\n", " \"\"\"\n", " assert array_format in ['np', 'pt'], \\\n", " f\"Invalid array_format: `{array_format}` provided. Must be one of `np` (numpy.array), `pt` (torch.tensor).\"\n", " \n", " num_atoms = 0\n", " scalar_properties = []\n", " atomic_symbols = []\n", " xyz = []\n", " charges = []\n", " harmonic_vibrational_frequencies = []\n", " smiles = ''\n", " inchi = ''\n", " with open(filename, 'r') as f:\n", " for line_num, line in enumerate(f):\n", " if line_num == 0:\n", " num_atoms = int(line)\n", " elif line_num == 1:\n", " scalar_properties = [float(i) for i in line.split()[2:]]\n", " elif 2 <= line_num <= 1 + num_atoms:\n", " atom_symbol, x, y, z, charge = line.split()\n", " atomic_symbols.append(atom_symbol)\n", " xyz.append([parse_float(x), parse_float(y), parse_float(z)])\n", " charges.append(parse_float(charge))\n", " elif line_num == num_atoms + 2:\n", " harmonic_vibrational_frequencies = [float(i) for i in line.split()]\n", " elif line_num == num_atoms + 3:\n", " smiles = line.split()[0]\n", " elif line_num == num_atoms + 4:\n", " inchi = line.split()[0]\n", "\n", " array_wrap = np.array if array_format == 'np' else torch.tensor\n", " result = {\n", " 'num_atoms': num_atoms,\n", " 'atomic_symbols': atomic_symbols,\n", " 'pos': array_wrap(xyz),\n", " 'charges': array_wrap(charges),\n", " 'harmonic_oscillator_frequencies': array_wrap(harmonic_vibrational_frequencies),\n", " 'smiles': smiles,\n", " 'inchi': inchi\n", " }\n", " scalar_property_labels = [\n", " 'A', 'B', 'C', 'mu', 'alpha', 'homo', 'lumo', 'gap', 'r2', 'zpve', 'u0', 'u', 'h', 'g', 'cv'\n", " ] \n", " scalar_properties = dict(zip(scalar_property_labels, scalar_properties))\n", " result.update(scalar_properties)\n", "\n", " # RdKit\n", " result['canonical_smiles'] = rdChem.CanonSmiles(result['smiles'])\n", " m = rdChem.MolFromSmiles(result['canonical_smiles'])\n", " result['logP'] = Crippen.MolLogP(m)\n", " result['qed'] = QED.qed(m)\n", " if npscorer_model is not None:\n", " result['np_score'] = npscorer.scoreMol(m, npscorer_model)\n", " result['sa_score'] = sascorer.calculateScore(m)\n", " result.update(count_rings_and_bonds(m, max_ring_size=max_ring_size))\n", " \n", " return result" ] }, { "cell_type": "code", "execution_count": null, "id": "72254d85", "metadata": {}, "outputs": [], "source": [ "\"\"\"\n", " Download xyz files from:\n", " https://figshare.com/collections/Quantum_chemistry_structures_and_properties_of_134_kilo_molecules/978904\n", " > wget https://figshare.com/ndownloader/files/3195389/dsgdb9nsd.xyz.tar.bz2\n", " > mkdir dsgdb9nsd.xyz\n", " > tar -xvjf dsgdb9nsd.xyz.tar.bz2 -C dsgdb9nsd.xyz\n", "\"\"\"\n", "MAX_RING_SIZE = 9\n", "fscore = npscorer.readNPModel()\n", "xyz_dir_path = '/Users/yairschiff/Downloads/dsgdb9nsd.xyz'\n", "parsed_xyz = []\n", "for file in tqdm(sorted(os.listdir(xyz_dir_path)), desc='Parsing'):\n", " parsed = parse_xyz(os.path.join(xyz_dir_path, file),\n", " max_ring_size=MAX_RING_SIZE,\n", " npscorer_model=fscore,\n", " array_format='np')\n", " parsed_xyz.append(parsed)" ] }, { "cell_type": "code", "execution_count": null, "id": "12969dd2", "metadata": {}, "outputs": [], "source": [ "qm9_df = pd.DataFrame(data=parsed_xyz)" ] }, { "cell_type": "code", "execution_count": null, "id": "eed4f163", "metadata": {}, "outputs": [], "source": [ "# Conversion below is needed to avoid:\n", "# `ArrowInvalid: ('Can only convert 1-dimensional array values',\n", "# 'Conversion failed for column pos with type object')`\n", "qm9_df['pos'] = qm9_df['pos'].apply(lambda x: [xi for xi in x])" ] }, { "cell_type": "code", "execution_count": null, "id": "c912d23a", "metadata": {}, "outputs": [], "source": [ "dataset = datasets.Dataset.from_pandas(qm9_df)" ] }, { "cell_type": "code", "execution_count": null, "id": "7a7df506", "metadata": {}, "outputs": [], "source": [ "dataset.push_to_hub('yairschiff/qm9')" ] }, { "cell_type": "code", "execution_count": null, "id": "86c4e1ae", "metadata": {}, "outputs": [], "source": [ "# # Random train/test splits as recommended by:\n", "# # https://moleculenet.org/datasets-1\n", "# test_size = 0.1\n", "# seed = 1\n", "# dataset.train_test_split(test_size=test_size, seed=seed)" ] }, { "cell_type": "markdown", "id": "e982da1b-05ab-493b-bb82-8bf1225dcb2b", "metadata": {}, "source": [ "## Create tokenizer" ] }, { "cell_type": "code", "execution_count": 7, "id": "b0504e77", "metadata": {}, "outputs": [], "source": [ "def smi_tokenizer(smi):\n", " \"\"\"Tokenize a SMILES molecule or reaction.\n", "\n", " Copied from https://github.com/pschwllr/MolecularTransformer.\n", " \"\"\"\n", " import re\n", " pattern = \"(\\[[^\\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\\(|\\)|\\.|=|#|-|\\+|\\\\\\\\|\\/|:|~|@|\\?|>|\\*|\\$|\\%[0-9]{2}|[0-9])\"\n", " regex = re.compile(pattern)\n", " tokens = [token for token in regex.findall(smi)]\n", " assert smi == ''.join(tokens)\n", " return tokens" ] }, { "cell_type": "code", "execution_count": 8, "id": "b89a4def-ea08-466a-8779-24acf75a2bd0", "metadata": {}, "outputs": [], "source": [ "dataset = datasets.load_dataset('yairschiff/qm9', split='train')" ] }, { "cell_type": "code", "execution_count": 9, "id": "6ef61481-9384-4c1c-8361-ab858cb157ba", "metadata": {}, "outputs": [], "source": [ "# # If vocab file not created yet, uncomment and run this cell\n", "\n", "# tokens = []\n", "# for smi in dataset['canonical_smiles']:\n", "# tokens.extend(smi_tokenizer(smi))\n", "\n", "# with open('qm9_vocab.json', 'w', encoding='utf-8') as f:\n", "# f.write(\n", "# json.dumps(\n", "# {t: i for i, t in enumerate(sorted(set(tokens)))},\n", "# indent=2,\n", "# sort_keys=True,\n", "# ensure_ascii=False\n", "# ) + '\\n')" ] }, { "cell_type": "code", "execution_count": 9, "id": "6af7fccb-08ee-4dc6-99dc-cfa4fc38074c", "metadata": {}, "outputs": [], "source": [ "# # If HF tokenizer not yet published, uncomment and run this cell\n", "# import tokenizer\n", "\n", "# tokenizer.QM9Tokenizer.register_for_auto_class()\n", "# qm9_tokenizer = tokenizer.QM9Tokenizer(vocab_file='qm9_vocab.json')\n", "# qm9_tokenizer.push_to_hub('yairschiff/qm9-tokenizer')" ] }, { "cell_type": "code", "execution_count": 23, "id": "4cc39f16-b53c-481a-a35e-a42fb1b08378", "metadata": {}, "outputs": [], "source": [ "# Test tokenizer\n", "qm9_tokenizer = transformers.AutoTokenizer.from_pretrained(\n", " 'yairschiff/qm9-tokenizer', trust_remote_code=True, resume_download=None)\n", "print(dataset[1000]['canonical_smiles'])\n", "print(qm9_tokenizer.encode(dataset[1000]['canonical_smiles']))\n", "print(qm9_tokenizer.decode(qm9_tokenizer.encode(dataset[1000]['canonical_smiles'])))" ] }, { "cell_type": "code", "execution_count": null, "id": "41752e94-175e-4f40-b9d2-496241eab0c0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }