diff --git "a/2_train_GM12878_DNase.ipynb" "b/2_train_GM12878_DNase.ipynb" new file mode 100644--- /dev/null +++ "b/2_train_GM12878_DNase.ipynb" @@ -0,0 +1,1073 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9b5fb36a-bc93-4f66-b91f-8cc586c56e87", + "metadata": {}, + "source": [ + "# Train a single-task regression model from scratch" + ] + }, + { + "cell_type": "markdown", + "id": "f945eb76-8120-48cc-8983-073f70905d28", + "metadata": {}, + "source": [ + "In this tutorial, we train a single-task convolutional regression model to predict total coverage over DNase-seq peaks, starting from ENCODE DNase-seq read coverage file." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "db57f1fc-672a-4639-bed7-97b88d042762", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Using wandb-core as the SDK backend. Please refer to https://wandb.me/wandb-core for more information.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: You can find your API key in your browser here: https://wandb.ai/authorize\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Paste an API key from your profile and hit enter:" + ] + }, + { + "name": "stdin", + "output_type": "stream", + "text": [ + " ········\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m If you're specifying your api key in code, ensure this code is not shared publicly.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m Consider setting the WANDB_API_KEY environment variable, or running `wandb login` from the command line.\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mavantikalal\u001b[0m (\u001b[33mgrelu\u001b[0m) to \u001b[32mhttps://api.wandb.ai\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + } + ], + "source": [ + "import os\n", + "import numpy as np\n", + "import pandas as pd\n", + "import torch\n", + "\n", + "from plotnine import *\n", + "%matplotlib inline\n", + "\n", + "from grelu.io.bed import read_bed\n", + "from grelu.data.preprocess import split\n", + "from grelu.lightning import LightningModel\n", + "from grelu.data.dataset import BigWigSeqDataset\n", + "\n", + "import wandb\n", + "wandb.login(host='https://api.wandb.ai', relogin=True)\n", + "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"/code/github/gReLU-applications/VEP_benchmark/2_train_GM12878_DNase.ipynb\"\n", + "project_name='GM12878_dnase'" + ] + }, + { + "cell_type": "markdown", + "id": "f7c1abdb-00e7-4ebe-8d27-40e44ee4e69d", + "metadata": {}, + "source": [ + "## Model parameters" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "eb8ae740-cfce-4832-a2c4-b6c62eaec31f", + "metadata": {}, + "outputs": [], + "source": [ + "model_params = {\n", + " 'model_type':'DilatedConvModel',\n", + " 'crop_len':(2114-1000)//2,\n", + " 'n_tasks':1,\n", + " 'channels':512,\n", + " 'n_conv':9,\n", + "}\n", + "\n", + "train_params = {\n", + " 'task':'regression',\n", + " 'loss': 'mse', \n", + " 'logger':'wandb',\n", + " 'lr':1e-4,\n", + " 'batch_size':512,\n", + " 'max_epochs':15,\n", + " 'devices':0,\n", + " 'num_workers':16,\n", + " 'save_dir':'.',\n", + " 'checkpoint': {\"monitor\":\"val_pearson\", \"mode\":\"max\", \"save_last\":True}\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "432c3ff9-1fb5-4aee-943b-4f040f0cf4d9", + "metadata": {}, + "source": [ + "## Set up sweep" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "53c2f2cf-72f6-4b52-bcfd-cd4177f9eb2c", + "metadata": {}, + "outputs": [], + "source": [ + "sweep_params = {\n", + " 'method': 'grid',\n", + " 'name': project_name,\n", + " \n", + " 'metric': {\n", + " 'goal': 'minimize', \n", + " 'name': 'validation_loss',\n", + " },\n", + "\n", + " 'parameters': {\n", + " 'rc':{\n", + " 'distribution': 'categorical',\n", + " 'values':[False, True],\n", + " },\n", + " 'max_seq_shift':{\n", + " 'distribution': 'categorical',\n", + " 'values':[0, 1, 3],\n", + " },\n", + " 'max_pair_shift':{\n", + " 'distribution': 'categorical',\n", + " 'values':[0, 10, 50, 100],\n", + " },\n", + " }\n", + "}" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "1ed82124-b7ab-42ed-bafa-303cf4e4ada1", + "metadata": {}, + "outputs": [], + "source": [ + "def build_and_train_model(model_params=model_params, train_params=train_params):\n", + " \n", + " run = wandb.init(dir='.')\n", + "\n", + " artifact = run.use_artifact('GM12878_dnase/dataset:latest')\n", + " dir = artifact.download()\n", + " intervals = read_bed(os.path.join(dir, \"intervals.bed\"))\n", + "\n", + " train, val, test = split(intervals, val_chroms=[\"chr10\"], test_chroms=[\"chr11\"])\n", + " train_dataset = BigWigSeqDataset(\n", + " intervals = train, \n", + " bw_files=[\"ENCFF093VXI.bigWig\"],\n", + " label_len=1000, \n", + " label_aggfunc=\"sum\",\n", + " label_transform_func=np.log1p,\n", + " rc=wandb.config[\"rc\"],\n", + " max_seq_shift=wandb.config[\"max_seq_shift\"], \n", + " max_pair_shift=wandb.config[\"max_pair_shift\"],\n", + " augment_mode=\"random\", \n", + " seed=0, \n", + " genome='hg38'\n", + " )\n", + " val_dataset = BigWigSeqDataset(\n", + " intervals = val, \n", + " bw_files=[\"ENCFF093VXI.bigWig\"],\n", + " label_len=1000, label_aggfunc=\"sum\", \n", + " label_transform_func=np.log1p, genome='hg38'\n", + " )\n", + "\n", + " model = LightningModel(model_params=model_params, train_params=train_params)\n", + " model.train_on_dataset(train_dataset, val_dataset)" + ] + }, + { + "cell_type": "markdown", + "id": "24a8c194-d699-4146-a52d-96124beacc69", + "metadata": {}, + "source": [ + "## Run sweep" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ac4d52b7-b73b-4302-b747-78c52ba06712", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Create sweep with ID: udnpa77q\n", + "Sweep URL: https://wandb.ai/grelu/GM12878_dnase/sweeps/udnpa77q\n" + ] + } + ], + "source": [ + "sweep_id = wandb.sweep(sweep=sweep_params, project=project_name) " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "30688396-857a-4374-8c5f-7318261dae40", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Agent Starting Run: dcq2sg8p with config:\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tmax_pair_shift: 0\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \tmax_seq_shift: 0\n", + "\u001b[34m\u001b[1mwandb\u001b[0m: \trc: False\n" + ] + }, + { + "data": { + "text/html": [ + "Tracking run with wandb version 0.19.7" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Run data is saved locally in ./wandb/run-20250312_202916-dcq2sg8p" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "Syncing run stellar-sweep-1 to Weights & Biases (docs)
Sweep page: https://wandb.ai/grelu/GM12878_dnase/sweeps/udnpa77q" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View project at https://wandb.ai/grelu/GM12878_dnase" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View sweep at https://wandb.ai/grelu/GM12878_dnase/sweeps/udnpa77q" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + " View run at https://wandb.ai/grelu/GM12878_dnase/runs/dcq2sg8p" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: 1 of 1 files downloaded. \n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Selecting training samples\n", + "Keeping 390473 intervals\n", + "\n", + "\n", + "Selecting validation samples\n", + "Keeping 21987 intervals\n", + "\n", + "\n", + "Selecting test samples\n", + "Keeping 22595 intervals\n", + "Final sizes: train: (390473, 3), val: (21987, 3), test: (22595, 3)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "GPU available: True (cuda), used: True\n", + "TPU available: False, using: 0 TPU cores\n", + "HPU available: False, using: 0 HPUs\n", + "/opt/conda/lib/python3.11/site-packages/pytorch_lightning/loggers/wandb.py:396: There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`.\n", + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 43/43 [00:09<00:00, 4.44it/s]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The variance of predictions or target is close to zero. This can cause instability in Pearson correlationcoefficient, leading to wrong results. Consider re-scaling the input if possible or computing using alarger dtype (currently using torch.float32).\n" + ] + }, + { + "data": { + "text/html": [ + "
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃      Validate metric             DataLoader 0        ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│         val_loss               18.51033592224121     │\n",
+       "│          val_mse              18.508310317993164     │\n",
+       "│        val_pearson           -0.05173555389046669    │\n",
+       "└───────────────────────────┴───────────────────────────┘\n",
+       "
\n" + ], + "text/plain": [ + "┏━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1m Validate metric \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1m DataLoader 0 \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36m val_loss \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 18.51033592224121 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_mse \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m 18.508310317993164 \u001b[0m\u001b[35m \u001b[0m│\n", + "│\u001b[36m \u001b[0m\u001b[36m val_pearson \u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35m -0.05173555389046669 \u001b[0m\u001b[35m \u001b[0m│\n", + "└─────────────────��─────────┴───────────────────────────┘\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3,4,5,6,7]\n", + "\n", + " | Name | Type | Params | Mode \n", + "----------------------------------------------------------\n", + "0 | model | DilatedConvModel | 6.3 M | train\n", + "1 | loss | MSELoss | 0 | train\n", + "2 | activation | Identity | 0 | train\n", + "3 | val_metrics | MetricCollection | 0 | train\n", + "4 | test_metrics | MetricCollection | 0 | train\n", + "5 | transform | Identity | 0 | train\n", + "----------------------------------------------------------\n", + "6.3 M Trainable params\n", + "0 Non-trainable params\n", + "6.3 M Total params\n", + "25.358 Total estimated model params size (MB)\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + " \r" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/opt/conda/lib/python3.11/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The variance of predictions or target is close to zero. This can cause instability in Pearson correlationcoefficient, leading to wrong results. Consider re-scaling the input if possible or computing using alarger dtype (currently using torch.float32).\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 0: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 763/763 [07:41<00:00, 1.65it/s, v_num=sg8p, train_loss_step=0.591]\n", + "Validation: | | 0/? [00:00