diff --git "a/notebooks/03_fine_tuning.ipynb" "b/notebooks/03_fine_tuning.ipynb" --- "a/notebooks/03_fine_tuning.ipynb" +++ "b/notebooks/03_fine_tuning.ipynb" @@ -1,1433 +1,1371 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Simple PyTorch Tracks Fine-Tuning Pipeline\n", - "\n", - "This notebook implements a simple PyTorch-based deep learning pipeline for tracks prediction fine-tuning.\n", - "\n", - "## Overview\n", - "- Loads a HuggingFace model (NTv3) as backbone\n", - "- Adds a prediction head for bigwig tracks\n", - "- Fine-tunes on tracks prediction with a simple training loop\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# Install useful dependencies\n", - "# !pip install pyBigWig\n", - "# !pip install pyfaidx\n", - "# !pip install torchmetrics" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "metadata": {}, - "outputs": [], - "source": [ - "# 0. Imports\n", - "import random\n", - "import functools\n", - "from typing import List, Dict, Optional, Callable\n", - "import os\n", - "import subprocess\n", - "\n", - "import torch\n", - "import torch.nn as nn\n", - "import torch.nn.functional as F\n", - "from torch.utils.data import Dataset, DataLoader\n", - "from torch.optim import AdamW\n", - "from torch.optim.lr_scheduler import LambdaLR\n", - "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n", - "import numpy as np\n", - "import pyBigWig\n", - "from pyfaidx import Fasta\n", - "from torchmetrics import PearsonCorrCoef" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 1. Configuration setup" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Using device: cpu\n" - ] - } - ], - "source": [ - "config = {\n", - " # Model\n", - " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\", # NTv3 model\n", - " \n", - " # Data\n", - " \"data_cache_dir\": \"./data\",\n", - " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\",\n", - " \"bigwig_url_list\": [\"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"],\n", - " \"sequence_length\": 1_024,\n", - " \"keep_target_center_fraction\": 0.375,\n", - " \n", - " # Training\n", - " \"batch_size\": 2,\n", - " \"learning_rate\": 1e-5,\n", - " \"schedule\": True,\n", - " \"num_tokens_warmup\": 10000,\n", - " \"end_learning_rate\": 5e-5,\n", - " \"weight_decay\": 0.01,\n", - " \n", - " \"num_tokens_training\": 131_072, # Total training tokens budget\n", - " \"num_tokens_per_update\": 4_096, # Target tokens per optimizer update (batch_size * seq_len * grad_accum)\n", - " \"num_tokens_per_log\": 8_192, # Tokens between training logs\n", - " \"num_tokens_per_validation\": 16_384, # Tokens between validations\n", - " \n", - " # Validation\n", - " \"num_validation_samples\": 10,\n", - " \n", - " # Loss\n", - " \"bigwig_loss_weight\": 1.0,\n", - " \"bigwig_scalar_loss_function\": \"poisson-multinomial\",\n", - " \"bigwig_shape_loss_coefficient\": 5.0,\n", - " \n", - " # General\n", - " \"seed\": 42,\n", - " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\",\n", - " \"num_workers\": 0, # Number of worker processes for DataLoader\n", - "}\n", - "\n", - "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n", - "\n", - "# Extract filenames from URLs\n", - "def extract_filename_from_url(url: 
str) -> str:\n", - " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n", - " # Remove query parameters if present\n", - " url_clean = url.split('?')[0]\n", - " # Get the last part of the URL path\n", - " return url_clean.split('/')[-1]\n", - "\n", - "# Create paths for downloaded files\n", - "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n", - "bigwig_path_list = [\n", - " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n", - " for url in config[\"bigwig_url_list\"]\n", - "]\n", - "\n", - "# Create bigwig_file_ids from filenames (without extension)\n", - "config[\"bigwig_file_ids\"] = [\n", - " os.path.splitext(extract_filename_from_url(url))[0]\n", - " for url in config[\"bigwig_url_list\"]\n", - "]\n", - "\n", - "# Set random seed\n", - "torch.manual_seed(config[\"seed\"])\n", - "np.random.seed(config[\"seed\"])\n", - "\n", - "# Set device\n", - "device = torch.device(config[\"device\"])\n", - "print(f\"Using device: {device}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 2. Data download" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "--2025-12-10 14:47:06-- https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\n", - "Resolving hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)... 128.114.119.163\n", - "Connecting to hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)|128.114.119.163|:443... connected.\n", - "HTTP request sent, awaiting response... 200 OK\n", - "Length: 983659424 (938M) [application/x-gzip]\n", - "Saving to: './data/hg38.fa.gz'\n", - "\n", - "hg38.fa.gz 100%[===================>] 938.09M 10.4MB/s in 1m 43s \n", - "\n", - "2025-12-10 14:48:50 (9.09 MB/s) - './data/hg38.fa.gz' saved [983659424/983659424]\n", - "\n" - ] - } - ], - "source": [ - "# Download fasta file\n", - "!wget -c {config[\"fasta_url\"]} -P {config[\"data_cache_dir\"]}/ && gunzip -f {config[\"data_cache_dir\"]}/{config[\"fasta_url\"].split(os.path.sep)[-1]}" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading ENCFF884LDL.bigWig...\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "--2025-12-10 14:54:41-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\n", - "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n", - "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n", - "HTTP request sent, awaiting response... 
307 Temporary Redirect\n", - "Location: https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482 [following]\n", - "--2025-12-10 14:54:42-- https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482\n", - "Resolving encode-public.s3.amazonaws.com 
(encode-public.s3.amazonaws.com)... 52.92.248.169, 52.92.211.49, 3.5.80.18, ...\n", - "Connecting to encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)|52.92.248.169|:443... connected.\n", - "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n", - "\n", - " The file is already fully retrieved; nothing to do.\n", - "\n" - ] - } - ], - "source": [ - "# Download bigwig files\n", - "for bigwig_url in config[\"bigwig_url_list\"]:\n", - " filename = extract_filename_from_url(bigwig_url)\n", - " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n", - " print(f\"Downloading {filename}...\")\n", - " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": {}, - "outputs": [], - "source": [ - "chrom_splits = {\n", - " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n", - " \"val\": ['chr22'],\n", - " \"test\": ['chr21']\n", - "}" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 3. Model and tokenizer setup" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "metadata": {}, - "outputs": [], - "source": [ - "class LinearHead(nn.Module):\n", - " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n", - " def __init__(self, embed_dim: int, num_labels: int):\n", - " super().__init__()\n", - " self.layer_norm = nn.LayerNorm(embed_dim)\n", - " self.head = nn.Linear(embed_dim, num_labels)\n", - " \n", - " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", - " x = self.layer_norm(x)\n", - " x = self.head(x)\n", - " x = F.softplus(x) # Ensure positive values\n", - " return x\n", - "\n", - "\n", - "class HFModelWithHead(nn.Module):\n", - " \"\"\"Simple model wrapper: HF backbone + bigwig head.\"\"\"\n", - " \n", - " def __init__(\n", - " self,\n", - " model_name: str,\n", - " bigwig_track_names: List[str],\n", - " keep_target_center_fraction: float = 0.375,\n", - " ):\n", - " super().__init__()\n", - " \n", - " # Load config and model\n", - " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n", - " self.backbone = AutoModelForMaskedLM.from_pretrained(\n", - " model_name, \n", - " trust_remote_code=True,\n", - " config=self.config\n", - " )\n", - " \n", - " self.keep_target_center_fraction = keep_target_center_fraction\n", - "\n", - " if hasattr(self.config, \"embed_dim\"):\n", - " embed_dim = self.config.embed_dim\n", - " else:\n", - " raise ValueError(f\"Could not determine embed_dim for {model_name}\")\n", - " \n", - " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n", - " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n", - " self.model_name = model_name\n", - " \n", - " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n", - " # Forward through backbone\n", - " outputs = self.backbone(input_ids=tokens)\n", - " embedding = outputs.hidden_states[-1] # Last hidden state\n", - " \n", - " # Crop to center fraction\n", - " if self.keep_target_center_fraction < 1.0:\n", - " seq_len = embedding.shape[1]\n", - " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n", - " target_length = seq_len - 2 * target_offset\n", - " embedding = embedding[:, target_offset:target_offset + target_length, :]\n", - " \n", - " # Predict bigwig tracks\n", - " bigwig_logits = self.bigwig_head(embedding)\n", - " \n", - " return {\"bigwig_tracks_logits\": bigwig_logits}" - ] - }, - { - "cell_type": 
"code", - "execution_count": 12, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Model loaded: InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\n", - "Number of bigwig tracks: 1\n", - "Model parameters: 7,693,244\n" - ] - } - ], - "source": [ - "# Load tokenizer\n", - "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n", - "if tokenizer.pad_token is None:\n", - " if tokenizer.eos_token is not None:\n", - " tokenizer.pad_token = tokenizer.eos_token\n", - " else:\n", - " tokenizer.add_special_tokens({\"pad_token\": \"[PAD]\"})\n", - "\n", - "# Create model\n", - "model = HFModelWithHead(\n", - " model_name=config[\"model_name\"],\n", - " bigwig_track_names=config[\"bigwig_file_ids\"],\n", - " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n", - ")\n", - "model = model.to(device)\n", - "model.train()\n", - "\n", - "print(f\"Model loaded: {config['model_name']}\")\n", - "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n", - "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 4. Data loading" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "metadata": {}, - "outputs": [], - "source": [ - "class GenomeBigWigDataset(Dataset):\n", - " \"\"\"\n", - " Random genomic windows from a reference genome + bigWig signal.\n", - "\n", - " Each sample:\n", - " - picks a chromosome from `chroms`,\n", - " - picks a random window of length `window_size`,\n", - " - returns (sequence, signal, chrom, start, end).\n", - "\n", - " Args\n", - " ----\n", - " fasta_path : str\n", - " Path to the reference genome FASTA (e.g. hg38.fna).\n", - " bigwig_path : str\n", - " Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n", - " chroms : List[str]\n", - " Chromosome names as they appear in the bigWig (e.g. 
[\"chr1\", \"chr2\", ...]).\n", - " window_size : int\n", - " Length of each random window (in bp).\n", - " num_samples : int\n", - " Number of samples the dataset will provide (len(dataset)).\n", - " chrom_mapping : Optional[Dict[str, str]]\n", - " Optional mapping from bigWig chrom name -> FASTA chrom name.\n", - " If None, assumes the same names in both.\n", - " Example for hg38 RefSeq FASTA:\n", - " {\n", - " \"chr1\": \"NC_000001.11\",\n", - " \"chr2\": \"NC_000002.12\",\n", - " ...\n", - " }\n", - " \"\"\"\n", - "\n", - " def __init__(\n", - " self,\n", - " fasta_path: str,\n", - " bigwig_path_list: list[str],\n", - " chroms: List[str],\n", - " sequence_length: int,\n", - " num_samples: int,\n", - " tokenizer: AutoTokenizer,\n", - " keep_target_center_fraction: float = 1.0,\n", - " num_tracks: int = 1,\n", - " ):\n", - " super().__init__()\n", - "\n", - " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n", - " self.bw_list = [\n", - " pyBigWig.open(bigwig_path)\n", - " for bigwig_path in bigwig_path_list\n", - " ]\n", - " self.sequence_length = sequence_length\n", - " self.num_samples = num_samples\n", - " self.tokenizer = tokenizer\n", - " self.keep_target_center_fraction = keep_target_center_fraction\n", - " self.num_tracks = num_tracks\n", - " self.chroms = chroms\n", - "\n", - " # Intersect lengths between FASTA and bigWig for safety\n", - " bw_chrom_lengths = self.bw_list[0].chroms() # dict: chrom -> length\n", - "\n", - " self.valid_chroms = []\n", - " self.chrom_lengths = {}\n", - "\n", - " for c in chroms:\n", - " if c not in bw_chrom_lengths or c not in self.fasta:\n", - " continue\n", - "\n", - " fa_len = len(self.fasta[c])\n", - " bw_len = bw_chrom_lengths[c]\n", - " L = min(fa_len, bw_len)\n", - "\n", - " if L > self.sequence_length:\n", - " self.valid_chroms.append(c)\n", - " self.chrom_lengths[c] = L\n", - "\n", - " if not self.valid_chroms:\n", - " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n", - "\n", - " def __len__(self):\n", - " return self.num_samples\n", - "\n", - " def __getitem__(self, idx):\n", - " # Ignore idx, sample randomly\n", - " chrom = random.choice(self.valid_chroms)\n", - " chrom_len = self.chrom_lengths[chrom]\n", - "\n", - " max_start = chrom_len - self.sequence_length\n", - " start = random.randint(0, max_start)\n", - " end = start + self.sequence_length\n", - "\n", - " # Sequence\n", - " seq = self.fasta[chrom][start:end] # string slice\n", - " tokens = self.tokenizer(\n", - " seq,\n", - " return_tensors=\"pt\", # Returns a dict of PyTorch tensors\n", - " )[\"input_ids\"][0]\n", - " # The 'input_ids' field contains the tokenized sequence.\n", - " # For a single input string, its shape is typically (1, len(seq))\n", - "\n", - " # Signal from bigWig tracks (numpy array) -> torch tensor\n", - " bigwig_targets = np.array([\n", - " self.bw_list[i].values(chrom, start, end, numpy=True)\n", - " for i in range(len(self.bw_list))\n", - " ]) # shape (num_tracks, seq_len)\n", - " # Transpose to (seq_len, num_tracks)\n", - " bigwig_targets = bigwig_targets.T\n", - " # pyBigWig returns NaN where no data; turn NaN into 0\n", - " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n", - " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n", - " \n", - " # Crop targets to center fraction\n", - " if self.keep_target_center_fraction < 1.0:\n", - " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n", - " target_offset = int(seq_len * (1 - 
self.keep_target_center_fraction) // 2)\n", - " target_length = seq_len - 2 * target_offset\n", - " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n", - "\n", - " sample = {\n", - " \"tokens\": tokens,\n", - " \"bigwig_targets\": bigwig_targets,\n", - " \"chrom\": chrom,\n", - " \"start\": start,\n", - " \"end\": end,\n", - " }\n", - " return sample" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train samples: 100\n", - "Val samples: 10\n", - "Test samples: 10\n" - ] - } - ], - "source": [ - "create_dataset_fn = functools.partial(\n", - " GenomeBigWigDataset,\n", - " fasta_path=fasta_path,\n", - " bigwig_path_list=bigwig_path_list,\n", - " sequence_length=config[\"sequence_length\"],\n", - " tokenizer=tokenizer,\n", - " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n", - " num_tracks=len(config[\"bigwig_file_ids\"]),\n", - ")\n", - "\n", - "train_dataset = create_dataset_fn(\n", - " chroms=chrom_splits[\"train\"],\n", - " num_samples=100,\n", - ")\n", - "\n", - "val_dataset = create_dataset_fn(\n", - " chroms=chrom_splits[\"val\"],\n", - " num_samples=config[\"num_validation_samples\"],\n", - ")\n", - "\n", - "test_dataset = create_dataset_fn(\n", - " chroms=chrom_splits[\"test\"],\n", - " num_samples=config[\"num_validation_samples\"],\n", - ")\n", - "\n", - "# Create dataloaders\n", - "train_loader = DataLoader(\n", - " train_dataset,\n", - " batch_size=config[\"batch_size\"],\n", - " shuffle=True,\n", - " num_workers=config[\"num_workers\"],\n", - ")\n", - "\n", - "val_loader = DataLoader(\n", - " val_dataset,\n", - " batch_size=config[\"batch_size\"],\n", - " shuffle=False,\n", - " num_workers=config[\"num_workers\"],\n", - ")\n", - "\n", - "test_loader = DataLoader(\n", - " test_dataset,\n", - " batch_size=config[\"batch_size\"],\n", - " shuffle=False,\n", - " num_workers=config[\"num_workers\"],\n", - ")\n", - "\n", - "print(f\"Train samples: {len(train_dataset)}\")\n", - "print(f\"Val samples: {len(val_dataset)}\")\n", - "print(f\"Test samples: {len(test_dataset)}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 5. 
Optimizer and Learning Rate Scheduler" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "metadata": {}, - "outputs": [], - "source": [ - "# Learning rate scheduler utils\n", - "def _modified_square_decay(\n", - " current_step: int,\n", - " lr_at_step_0: float,\n", - " lr_peak_after_warmup: float,\n", - " num_warmup_steps: int,\n", - " num_training_steps: int,\n", - ") -> float:\n", - " \"\"\"\n", - " Learning rate schedule with linear warmup and square root decay.\n", - " Simplified version of the pipeline's scheduler.\n", - " \"\"\"\n", - " if current_step < num_warmup_steps:\n", - " # Linear warmup\n", - " return lr_at_step_0 + (lr_peak_after_warmup - lr_at_step_0) * (current_step / num_warmup_steps)\n", - " else:\n", - " # Square root decay\n", - " progress = (current_step - num_warmup_steps) / (num_training_steps - num_warmup_steps)\n", - " decay_factor = (1.0 - progress) ** 0.5\n", - " return lr_peak_after_warmup * decay_factor" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Gradient accumulation steps: 2\n", - "Effective batch size: 4\n", - "Effective tokens per update: 4096\n", - "\n", - "Training constants:\n", - " Total training steps: 32\n", - " Log training metrics every: 2 steps\n", - " Run validation every: 4 steps\n", - " Warmup steps: 3\n", - "\n", - "Optimizer setup:\n", - " Initial LR: 1e-05\n", - " Peak LR: 5e-05\n" - ] - } - ], - "source": [ - "# Calculate gradient accumulation steps and effective batch size\n", - "num_devices = 1 # Single device for now\n", - "sequence_length = config[\"sequence_length\"]\n", - "batch_size = config[\"batch_size\"]\n", - "\n", - "# Calculate gradient accumulation steps\n", - "num_accumulation_gradient = max(1, int(config[\"num_tokens_per_update\"] // (batch_size * num_devices * sequence_length)))\n", - "\n", - "# Calculate effective batch size and tokens per update\n", - "effective_batch_size = batch_size * num_devices * num_accumulation_gradient\n", - "effective_num_tokens_per_update = effective_batch_size * sequence_length\n", - "\n", - "print(f\"Gradient accumulation steps: {num_accumulation_gradient}\")\n", - "print(f\"Effective batch size: {effective_batch_size}\")\n", - "print(f\"Effective tokens per update: {effective_num_tokens_per_update}\")\n", - "\n", - "# Compute logging constants (based on deepspeed pipeline: compute_logging_constants)\n", - "num_train_samples = len(train_dataset)\n", - "num_tokens_per_update = effective_num_tokens_per_update # Same as effective_num_tokens_per_update\n", - "\n", - "# Total training steps based on token budget\n", - "num_steps_training = config[\"num_tokens_training\"] // num_tokens_per_update\n", - "\n", - "# Steps for logging and validation\n", - "log_train_step = int(np.ceil(config[\"num_tokens_per_log\"] / num_tokens_per_update))\n", - "log_validation_step = int(np.ceil(config[\"num_tokens_per_validation\"] / num_tokens_per_update))\n", - "\n", - "# Warmup steps\n", - "num_warmup_steps = max(1, int(np.ceil(config[\"num_tokens_warmup\"] / effective_num_tokens_per_update)))\n", - "\n", - "print(f\"\\nTraining constants:\")\n", - "print(f\" Total training steps: {num_steps_training}\")\n", - "print(f\" Log training metrics every: {log_train_step} steps\")\n", - "print(f\" Run validation every: {log_validation_step} steps\")\n", - "print(f\" Warmup steps: {num_warmup_steps}\")\n", - "\n", - "# Setup optimizer\n", - "optimizer = AdamW(\n", - " model.parameters(),\n", - 
" lr=config[\"end_learning_rate\"] if config[\"schedule\"] else config[\"learning_rate\"],\n", - " weight_decay=config[\"weight_decay\"],\n", - ")\n", - "\n", - "# Setup scheduler\n", - "if config[\"schedule\"]:\n", - " lr_scheduler_fn = lambda step: _modified_square_decay(\n", - " current_step=step,\n", - " lr_at_step_0=config[\"learning_rate\"],\n", - " lr_peak_after_warmup=config[\"end_learning_rate\"],\n", - " num_warmup_steps=num_warmup_steps,\n", - " num_training_steps=num_steps_training,\n", - " )\n", - " scheduler = LambdaLR(optimizer, lr_lambda=lr_scheduler_fn)\n", - "else:\n", - " scheduler = None\n", - "\n", - "print(f\"\\nOptimizer setup:\")\n", - "print(f\" Initial LR: {config['learning_rate']}\")\n", - "print(f\" Peak LR: {config['end_learning_rate']}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 6. Metrics setup (using TorchMetrics)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [], - "source": [ - "class TracksMetrics:\n", - " \"\"\"Simple metrics tracker for tracks prediction with both scaled and raw metrics.\"\"\"\n", - " \n", - " def __init__(self, track_names: List[str]):\n", - " self.track_names = track_names\n", - " self.num_tracks = len(track_names)\n", - " # Scaled metrics: comparing scaled targets with scaled predictions\n", - " self.pearson_metrics_scaled = [\n", - " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n", - " ]\n", - " # Raw metrics: comparing raw targets with unscaled predictions\n", - " self.pearson_metrics_raw = [\n", - " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n", - " ]\n", - " self.losses = []\n", - " \n", - " def reset(self):\n", - " for metric in self.pearson_metrics_scaled:\n", - " metric.reset()\n", - " for metric in self.pearson_metrics_raw:\n", - " metric.reset()\n", - " self.losses = []\n", - " \n", - " def update(\n", - " self, \n", - " predictions_scaled: torch.Tensor, \n", - " targets_scaled: torch.Tensor,\n", - " predictions_raw: torch.Tensor,\n", - " targets_raw: torch.Tensor,\n", - " loss: float\n", - " ):\n", - " \"\"\"\n", - " Update both scaled and raw metrics.\n", - " Args:\n", - " predictions_scaled: (batch, seq_len, num_tracks) - scaled predictions\n", - " targets_scaled: (batch, seq_len, num_tracks) - scaled targets\n", - " predictions_raw: (batch, seq_len, num_tracks) - raw/unscaled predictions\n", - " targets_raw: (batch, seq_len, num_tracks) - raw targets\n", - " loss: scalar loss value\n", - " \"\"\"\n", - " # Flatten batch and sequence dimensions\n", - " pred_scaled_flat = predictions_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", - " target_scaled_flat = targets_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", - " pred_raw_flat = predictions_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", - " target_raw_flat = targets_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", - " \n", - " # Update scaled metrics\n", - " for i, metric in enumerate(self.pearson_metrics_scaled):\n", - " metric.update(pred_scaled_flat[:, i], target_scaled_flat[:, i])\n", - " \n", - " # Update raw metrics\n", - " for i, metric in enumerate(self.pearson_metrics_raw):\n", - " metric.update(pred_raw_flat[:, i], target_raw_flat[:, i])\n", - " \n", - " self.losses.append(loss)\n", - " \n", - " def compute(self) -> Dict[str, float]:\n", - " \"\"\"Compute and return all metrics (both scaled and raw).\"\"\"\n", - " metrics_dict = {}\n", - " \n", - " # Scaled metrics: per-track 
Pearson correlations\n", - " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_scaled)):\n", - " corr = metric.compute().item()\n", - " metrics_dict[f\"metrics_scaled/{track_name}/pearson\"] = corr\n", - " \n", - " # Scaled metrics: mean Pearson correlation\n", - " correlations_scaled = [metric.compute().item() for metric in self.pearson_metrics_scaled]\n", - " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n", - " \n", - " # Raw metrics: per-track Pearson correlations\n", - " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n", - " corr = metric.compute().item()\n", - " metrics_dict[f\"metrics_raw/{track_name}/pearson\"] = corr\n", - " \n", - " # Raw metrics: mean Pearson correlation\n", - " correlations_raw = [metric.compute().item() for metric in self.pearson_metrics_raw]\n", - " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n", - " \n", - " # Mean loss\n", - " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n", - " \n", - " return metrics_dict" - ] - }, + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Simple PyTorch Tracks Fine-Tuning Pipeline\n", + "\n", + "This notebook implements a simple PyTorch-based deep learning pipeline for tracks prediction fine-tuning.\n", + "\n", + "## Overview\n", + "- Loads a HuggingFace model (NTv3) as backbone\n", + "- Adds a prediction head for bigwig tracks\n", + "- Fine-tunes on tracks prediction with a simple training loop\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Install useful dependencies\n", + "# !pip install pyBigWig\n", + "# !pip install pyfaidx\n", + "# !pip install torchmetrics" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "# 0. Imports\n", + "import random\n", + "import functools\n", + "from typing import List, Dict, Optional, Callable\n", + "import os\n", + "import subprocess\n", + "\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import Dataset, DataLoader\n", + "from torch.optim import AdamW\n", + "from transformers import AutoConfig, AutoModelForMaskedLM, AutoTokenizer\n", + "import numpy as np\n", + "import pyBigWig\n", + "from pyfaidx import Fasta\n", + "from torchmetrics import PearsonCorrCoef" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. 
Configuration setup" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 22, - "metadata": {}, - "outputs": [], - "source": [ - "train_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n", - "val_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n", - "test_metrics = TracksMetrics(config[\"bigwig_file_ids\"])" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Using device: cpu\n" + ] + } + ], + "source": [ + "config = {\n", + " # Model\n", + " \"model_name\": \"InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\", # HuggingFace model name/identifier\n", + " \n", + " # Data\n", + " \"data_cache_dir\": \"./data\", # Directory where downloaded data files (FASTA, bigWig) will be stored\n", + " \"fasta_url\": \"https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\", # URL to download reference genome FASTA file\n", + " \"bigwig_url_list\": [\"https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\"], # List of URLs for bigWig track files to download\n", + " \"sequence_length\": 1_024, # Length of input sequences in base pairs (bp)\n", + " \"keep_target_center_fraction\": 0.375, # Fraction of center sequence to keep for target prediction (crops edges to focus on center)\n", + " \n", + " # Training\n", + " \"batch_size\": 2, # Number of samples per batch\n", + " \"learning_rate\": 1e-5, # Constant learning rate for optimizer\n", + " \"weight_decay\": 0.01, # L2 regularization coefficient for optimizer\n", + " \n", + " \"num_tokens_training\": 131_072, # Total training tokens budget (determines total training steps)\n", + " \"num_tokens_per_update\": 4_096, # Target tokens per optimizer update (batch_size * seq_len * grad_accum)\n", + " \"num_tokens_per_log\": 8_192, # Tokens between training logs (how often to print metrics)\n", + " \"num_tokens_per_validation\": 16_384, # Tokens between validation runs (how often to evaluate on validation set)\n", + " \n", + " # Validation\n", + " \"num_validation_samples\": 10, # Number of samples to use for validation set\n", + " \n", + " # Loss\n", + " \"bigwig_loss_weight\": 1.0, # Weight multiplier for bigwig prediction loss\n", + " \"bigwig_scalar_loss_function\": \"poisson-multinomial\", # Loss function type for bigwig tracks\n", + " \"bigwig_shape_loss_coefficient\": 5.0, # Coefficient balancing shape loss vs scale loss in poisson-multinomial loss\n", + " \n", + " # General\n", + " \"seed\": 42, # Random seed for reproducibility\n", + " \"device\": \"cuda\" if torch.cuda.is_available() else \"cpu\", # Device to run training on (\"cuda\" or \"cpu\")\n", + " \"num_workers\": 0, # Number of worker processes for DataLoader (0 = single-threaded)\n", + "}\n", + "\n", + "os.makedirs(config[\"data_cache_dir\"], exist_ok=True)\n", + "\n", + "# Extract filenames from URLs\n", + "def extract_filename_from_url(url: str) -> str:\n", + " \"\"\"Extract filename from URL, handling query parameters.\"\"\"\n", + " # Remove query parameters if present\n", + " url_clean = url.split('?')[0]\n", + " # Get the last part of the URL path\n", + " return url_clean.split('/')[-1]\n", + "\n", + "# Create paths for downloaded files\n", + "fasta_path = os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(config[\"fasta_url\"]).replace('.gz', ''))\n", + "bigwig_path_list = [\n", + " os.path.join(config[\"data_cache_dir\"], extract_filename_from_url(url))\n", + " for url in config[\"bigwig_url_list\"]\n", + "]\n", + "\n", + 
"# Create bigwig_file_ids from filenames (without extension)\n", + "config[\"bigwig_file_ids\"] = [\n", + " os.path.splitext(extract_filename_from_url(url))[0]\n", + " for url in config[\"bigwig_url_list\"]\n", + "]\n", + "\n", + "# Set random seed\n", + "torch.manual_seed(config[\"seed\"])\n", + "np.random.seed(config[\"seed\"])\n", + "\n", + "# Set device\n", + "device = torch.device(config[\"device\"])\n", + "print(f\"Using device: {device}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. Data download" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 7. Scaling functions setup (copied from pipeline)" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "--2025-12-10 14:47:06-- https://hgdownload.gi.ucsc.edu/goldenPath/hg38/bigZips/hg38.fa.gz\n", + "Resolving hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)... 128.114.119.163\n", + "Connecting to hgdownload.gi.ucsc.edu (hgdownload.gi.ucsc.edu)|128.114.119.163|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 983659424 (938M) [application/x-gzip]\n", + "Saving to: './data/hg38.fa.gz'\n", + "\n", + "hg38.fa.gz 100%[===================>] 938.09M 10.4MB/s in 1m 43s \n", + "\n", + "2025-12-10 14:48:50 (9.09 MB/s) - './data/hg38.fa.gz' saved [983659424/983659424]\n", + "\n" + ] + } + ], + "source": [ + "# Download fasta file\n", + "!wget -c {config[\"fasta_url\"]} -P {config[\"data_cache_dir\"]}/ && gunzip -f {config[\"data_cache_dir\"]}/{config[\"fasta_url\"].split(os.path.sep)[-1]}" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Scaling functions created\n" - ] - } - ], - "source": [ - "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n", - " \"\"\"\n", - " Get track means for normalization.\n", - " For now, return dummy values. In real pipeline, this loads from metadata.\n", - " \"\"\"\n", - " # Dummy values - in real pipeline, this would load from actual metadata\n", - " return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n", - "\n", - "\n", - "def get_rna_seq_track_ids(bigwig_file_ids: List[str]) -> List[int]:\n", - " \"\"\"\n", - " Get RNA-seq track indices.\n", - " For now, return empty list. 
In real pipeline, this identifies RNA-seq tracks.\n", - " \"\"\"\n", - " # Dummy - in real pipeline, this would identify RNA-seq tracks\n", - " return []\n", - "\n", - "\n", - "def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n", - " \"\"\"\n", - " Build a scaling function based on track means and RNA-seq squashing.\n", - " Copied from the supervised tracks pipeline.\n", - " \"\"\"\n", - " # Load track means\n", - " track_means_np = get_track_means(bigwig_file_ids)\n", - " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n", - " \n", - " # Get which tracks use squashing\n", - " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n", - " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n", - " if len(rna_ids) > 0:\n", - " apply_squashing[rna_ids] = True\n", - " \n", - " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n", - " \"\"\"\n", - " device = x.device\n", - " \n", - " # Move constants to correct device\n", - " means = track_means.to(device)\n", - " squash_mask = apply_squashing.to(device)\n", - " \n", - " # Normalize\n", - " scaled = x / means\n", - " \n", - " # Power squashing where needed\n", - " squashed = torch.where(\n", - " squash_mask.view(1, 1, -1),\n", - " scaled.pow(0.75),\n", - " scaled,\n", - " )\n", - " \n", - " # Smooth clipping: if > 10, apply formula\n", - " clipped = torch.where(\n", - " squashed > 10.0,\n", - " 2.0 * torch.sqrt(squashed * 10.0) - 10.0,\n", - " squashed,\n", - " )\n", - " \n", - " return clipped\n", - " \n", - " return transform_fn\n", - "\n", - "\n", - "def create_predictions_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n", - " \"\"\"\n", - " Inverse scaling function to apply on predictions before computing metrics.\n", - " Copied from the supervised tracks pipeline.\n", - " \"\"\"\n", - " # Load means\n", - " track_means_np = get_track_means(bigwig_file_ids)\n", - " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n", - " \n", - " # RNA-seq mask\n", - " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n", - " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n", - " if len(rna_ids) > 0:\n", - " apply_squashing[rna_ids] = True\n", - " \n", - " def inverse_transform_fn(x: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"\n", - " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n", - " \"\"\"\n", - " device = x.device\n", - " means = track_means.to(device)\n", - " squash_mask = apply_squashing.to(device)\n", - " \n", - " # Undo clipping\n", - " unclipped = torch.where(\n", - " x > 10.0,\n", - " (x + 10.0).pow(2) / (4 * 10.0),\n", - " x,\n", - " )\n", - " \n", - " # Undo squashing\n", - " unsquashed = torch.where(\n", - " squash_mask.view(1, 1, -1),\n", - " unclipped.pow(1.0 / 0.75),\n", - " unclipped,\n", - " )\n", - " \n", - " # Undo normalization\n", - " return unsquashed * means\n", - " \n", - " return inverse_transform_fn\n", - "\n", - "\n", - "# Create scaling functions\n", - "scale_targets_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n", - "scale_predictions_fn = create_predictions_scaling_fn(config[\"bigwig_file_ids\"])\n", - "\n", - "print(\"Scaling functions created\")" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading ENCFF884LDL.bigWig...\n" + ] }, { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 8. 
Loss functions" - ] - }, + "name": "stderr", + "output_type": "stream", + "text": [ + "--2025-12-10 14:54:41-- https://www.encodeproject.org/files/ENCFF884LDL/@@download/ENCFF884LDL.bigWig\n", + "Resolving www.encodeproject.org (www.encodeproject.org)... 34.211.244.144\n", + "Connecting to www.encodeproject.org (www.encodeproject.org)|34.211.244.144|:443... connected.\n", + "HTTP request sent, awaiting response... 307 Temporary Redirect\n", + "Location: https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482 [following]\n", + "--2025-12-10 14:54:42-- 
https://encode-public.s3.amazonaws.com/2020/09/19/425880b6-b323-4ee2-95ce-56bdd088d126/ENCFF884LDL.bigWig?response-content-disposition=attachment%3B%20filename%3DENCFF884LDL.bigWig&AWSAccessKeyId=ASIATGZNGCNXU6SGJVOL&Signature=4o0Pp2RvJtnZc9z7HOuCU1k9wwI%3D&x-amz-security-token=IQoJb3JpZ2luX2VjEA0aCXVzLXdlc3QtMiJGMEQCIEdyOOxtHk6rJT06xIjzZR3nVyqbPB1twIFxCDtIQfNXAiAph1lc69CfHzPPglodVnVh9QCjlsXHFyUEU3K0%2Bx%2F%2Bziq8BQjW%2F%2F%2F%2F%2F%2F%2F%2F%2F%2F8BEAAaDDIyMDc0ODcxNDg2MyIMYwkeEaXuk%2BE48EDAKpAFkm4uzCSB40oRz3YT4m%2FZfBSH7XIuSCuzS7nrL5tXb9Q2rfPQSD4PHOyTR0LOOfcr98%2FyF8cJw4NE%2Fwsw8BRs4xPFEEyN6yGqwHmAyxBuwdca4GLSMGRDaSPoleMJw1FcSv96ofbZFYTTSol4b6%2FZj4jJjCa887%2F6S5x9kNIjTAtgX%2Fr3Ci4wi4FXGKTijTU%2FnbuuLZ3Cz2UobD6p732apsayl7avmUdWbUvROl3sHFOWOGCKsmDv0mavyEu2EsHxniBPfECy00BNvf%2Bj2FDaz1BImMIDavVBSwcWk8uCPjbsccsgiuKAfwr3dOXQ7R6y4NwmuFluBqn1GOXw1K13T4LrF%2BrhmqdOWeIVKB%2Bo9vnfQm1Dws6EoyS%2BG0bWDnyuUnLtWGf4cZPA6kjcM14fspFxoMnLjHBfdpYKZ3VmikbgwE8mDaiHODH1WQ36lUPigKbbIeHqOnHTIEw5h6F8D0MfIdVBSV2HCXweIlxCr6%2FV8hy2RzDouzT%2FIH%2FIobhHjGPM%2FlmkLAcfEzS2fioCJwkqQ3F%2BC77alAhtDQ4Oy5OIxRnRHVLpO%2BMA9Ml0SrEegCGPIzLucuCtbj2UTEOnBRQXyMolyySopJZb4p4BpJ6MiitLyCt1C66lvJpX5oMri%2BVD7FcTgdPYxcqM%2FMLD%2B4XqTYh5wdK7EYe3CpsVjpviZSVbn7yVHAb8WqdmFO%2BXRGhjQdN6rMrwGPiMCmQq12tTQftfmEwPGN1CVHG%2BbL1KUpEF4BRE61xDwEu7ZXyycPqTJMKHVn%2BXZ%2BxFsaxpUsp25U6JIVVPiNgt1OyhfjU6oqzwzeXH7KMRIcqz2d%2B3p%2BIbjRvoHcLc8AzgY4RvgWMGlb5gIpv15HQTDvdiLLwwjd3lyQY6sgE9t%2Bhi2Jv1DPgJN0YUGblcTV3Ey95h%2BBIXo6zWGwqhyZhkH%2ByxJKXouv2S1mKS3BM0dp2maJGDp69Mze8UkGjFYvdzxHT1zrCZ4dMRRkRObY3%2F4ZP33ogelhzchd7S76et35vYwYHd9DYycWZnJ%2FIcfpSZURGMJu3gLM3YhIscykGwQKqB21Tmyjufi0AaYyLk4w2OKc31kgjFvs6lNaHhqTuFButuHEiBUMzieixOI%2BX6&Expires=1765504482\n", + "Resolving encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)... 52.92.248.169, 52.92.211.49, 3.5.80.18, ...\n", + "Connecting to encode-public.s3.amazonaws.com (encode-public.s3.amazonaws.com)|52.92.248.169|:443... connected.\n", + "HTTP request sent, awaiting response... 416 Requested Range Not Satisfiable\n", + "\n", + " The file is already fully retrieved; nothing to do.\n", + "\n" + ] + } + ], + "source": [ + "# Download bigwig files\n", + "for bigwig_url in config[\"bigwig_url_list\"]:\n", + " filename = extract_filename_from_url(bigwig_url)\n", + " filepath = os.path.join(config[\"data_cache_dir\"], filename)\n", + " print(f\"Downloading {filename}...\")\n", + " subprocess.run([\"wget\", \"-c\", bigwig_url, \"-O\", filepath], check=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [], + "source": [ + "chrom_splits = {\n", + " \"train\": [f\"chr{i}\" for i in range(1, 21)] + ['chrX', 'chrY'],\n", + " \"val\": ['chr22'],\n", + " \"test\": ['chr21']\n", + "}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. 
Model and tokenizer setup" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [], + "source": [ + "class LinearHead(nn.Module):\n", + " \"\"\"A linear head that predicts one scalar value per track.\"\"\"\n", + " def __init__(self, embed_dim: int, num_labels: int):\n", + " super().__init__()\n", + " self.layer_norm = nn.LayerNorm(embed_dim)\n", + " self.head = nn.Linear(embed_dim, num_labels)\n", + " \n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + " x = self.layer_norm(x)\n", + " x = self.head(x)\n", + " x = F.softplus(x) # Ensure positive values\n", + " return x\n", + "\n", + "\n", + "class HFModelWithHead(nn.Module):\n", + " \"\"\"Simple model wrapper: HF backbone + bigwig head.\"\"\"\n", + " \n", + " def __init__(\n", + " self,\n", + " model_name: str,\n", + " bigwig_track_names: List[str],\n", + " keep_target_center_fraction: float = 0.375,\n", + " ):\n", + " super().__init__()\n", + " \n", + " # Load config and model\n", + " self.config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)\n", + " self.backbone = AutoModelForMaskedLM.from_pretrained(\n", + " model_name, \n", + " trust_remote_code=True,\n", + " config=self.config\n", + " )\n", + " \n", + " self.keep_target_center_fraction = keep_target_center_fraction\n", + "\n", + " if hasattr(self.config, \"embed_dim\"):\n", + " embed_dim = self.config.embed_dim\n", + " else:\n", + " raise ValueError(f\"Could not determine embed_dim for {model_name}\")\n", + " \n", + " # Bigwig head (NTv3 outputs at single-nucleotide resolution)\n", + " self.bigwig_head = LinearHead(embed_dim, len(bigwig_track_names))\n", + " self.model_name = model_name\n", + " \n", + " def forward(self, tokens: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:\n", + " # Forward through backbone\n", + " outputs = self.backbone(input_ids=tokens)\n", + " embedding = outputs.hidden_states[-1] # Last hidden state\n", + " \n", + " # Crop to center fraction\n", + " if self.keep_target_center_fraction < 1.0:\n", + " seq_len = embedding.shape[1]\n", + " target_offset = int(seq_len * (1 - self.keep_target_center_fraction) // 2)\n", + " target_length = seq_len - 2 * target_offset\n", + " embedding = embedding[:, target_offset:target_offset + target_length, :]\n", + " \n", + " # Predict bigwig tracks\n", + " bigwig_logits = self.bigwig_head(embedding)\n", + " \n", + " return {\"bigwig_tracks_logits\": bigwig_logits}" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 24, - "metadata": {}, - "outputs": [], - "source": [ - "def poisson_loss(ytrue: torch.Tensor, ypred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n", - " \"\"\"Poisson loss per element: ypred - ytrue * log(ypred).\"\"\"\n", - " return ypred - ytrue * torch.log(ypred + epsilon)\n", - "\n", - "\n", - "def safe_for_grad_log_torch(x: torch.Tensor) -> torch.Tensor:\n", - " \"\"\"Guarantees that the log is defined for all x > 0 in a differentiable way.\"\"\"\n", - " return torch.log(torch.where(x > 0.0, x, torch.ones_like(x)))\n", - "\n", - "\n", - "def poisson_multinomial_loss(\n", - " logits: torch.Tensor,\n", - " targets: torch.Tensor,\n", - " mask: torch.Tensor | None = None,\n", - " shape_loss_coefficient: float = 5.0,\n", - " epsilon: float = 1e-7,\n", - ") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n", - " \"\"\"\n", - " Regression loss for bigwig tracks (MSE, Poisson, or Poisson-Multinomial).\n", - " \"\"\"\n", - " 
scale_loss, shape_loss = None, None\n", - " \n", - " if mask is None:\n", - " mask = torch.ones_like(targets, dtype=torch.float32, device=targets.device)\n", - " else:\n", - " mask = mask.float()\n", - " \n", - " mask_sum = mask.sum() + epsilon\n", - " masked_logits = logits * mask\n", - " masked_targets = targets * mask\n", - "\n", - " # Scale loss\n", - " mask_sum_per_track_per_seq = mask.sum(dim=1) # (batch, num_tracks)\n", - " mask_per_sequence = mask_sum_per_track_per_seq > 0.0 # (batch, num_tracks)\n", - " \n", - " sum_pred = masked_logits.sum(dim=1) # (batch, num_tracks)\n", - " sum_true = masked_targets.sum(dim=1) # (batch, num_tracks)\n", - " \n", - " scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n", - " scale_loss = scale_loss / (mask_sum_per_track_per_seq + epsilon)\n", - " \n", - " if mask_per_sequence.any():\n", - " scale_loss_filtered = scale_loss[mask_per_sequence]\n", - " scale_loss = scale_loss_filtered.mean()\n", - " else:\n", - " scale_loss = torch.tensor(0.0, device=targets.device, dtype=targets.dtype)\n", - " \n", - " # Shape loss\n", - " predicted_counts = masked_logits + (epsilon * mask)\n", - " masked_targets_with_epsilon = masked_targets + (epsilon * mask)\n", - " \n", - " denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon\n", - " p_pred = predicted_counts / denom\n", - " \n", - " pl_pred = safe_for_grad_log_torch(p_pred)\n", - " shape_loss = -(masked_targets_with_epsilon * pl_pred).sum() / mask_sum\n", - " \n", - " # Combine\n", - " loss = shape_loss + scale_loss / shape_loss_coefficient\n", - "\n", - " return loss, scale_loss, shape_loss\n" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Model loaded: InstaDeepAI/ntv3_8M_7downsample_pretrained_le_1mb\n", + "Number of bigwig tracks: 1\n", + "Model parameters: 7,693,244\n" + ] + } + ], + "source": [ + "# Load tokenizer\n", + "tokenizer = AutoTokenizer.from_pretrained(config[\"model_name\"], trust_remote_code=True)\n", + "\n", + "# Create model\n", + "model = HFModelWithHead(\n", + " model_name=config[\"model_name\"],\n", + " bigwig_track_names=config[\"bigwig_file_ids\"],\n", + " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n", + ")\n", + "model = model.to(device)\n", + "model.train()\n", + "\n", + "print(f\"Model loaded: {config['model_name']}\")\n", + "print(f\"Number of bigwig tracks: {len(config['bigwig_file_ids'])}\")\n", + "print(f\"Model parameters: {sum(p.numel() for p in model.parameters()):,}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 4. Data loading" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "class GenomeBigWigDataset(Dataset):\n", + " \"\"\"\n", + " Random genomic windows from a reference genome + bigWig signal.\n", + "\n", + " Each sample:\n", + " - picks a chromosome from `chroms`,\n", + " - picks a random window of length `window_size`,\n", + " - returns (sequence, signal, chrom, start, end).\n", + "\n", + " Args\n", + " ----\n", + " fasta_path : str\n", + " Path to the reference genome FASTA (e.g. hg38.fna).\n", + " bigwig_path : str\n", + " Path to the bigWig file (e.g. ENCFF884LDL.bigWig).\n", + " chroms : List[str]\n", + " Chromosome names as they appear in the bigWig (e.g. 
[\"chr1\", \"chr2\", ...]).\n", + " window_size : int\n", + " Length of each random window (in bp).\n", + " num_samples : int\n", + " Number of samples the dataset will provide (len(dataset)).\n", + " chrom_mapping : Optional[Dict[str, str]]\n", + " Optional mapping from bigWig chrom name -> FASTA chrom name.\n", + " If None, assumes the same names in both.\n", + " Example for hg38 RefSeq FASTA:\n", + " {\n", + " \"chr1\": \"NC_000001.11\",\n", + " \"chr2\": \"NC_000002.12\",\n", + " ...\n", + " }\n", + " \"\"\"\n", + "\n", + " def __init__(\n", + " self,\n", + " fasta_path: str,\n", + " bigwig_path_list: list[str],\n", + " chroms: List[str],\n", + " sequence_length: int,\n", + " num_samples: int,\n", + " tokenizer: AutoTokenizer,\n", + " keep_target_center_fraction: float = 1.0,\n", + " num_tracks: int = 1,\n", + " ):\n", + " super().__init__()\n", + "\n", + " self.fasta = Fasta(fasta_path, as_raw=True, sequence_always_upper=True)\n", + " self.bw_list = [\n", + " pyBigWig.open(bigwig_path)\n", + " for bigwig_path in bigwig_path_list\n", + " ]\n", + " self.sequence_length = sequence_length\n", + " self.num_samples = num_samples\n", + " self.tokenizer = tokenizer\n", + " self.keep_target_center_fraction = keep_target_center_fraction\n", + " self.num_tracks = num_tracks\n", + " self.chroms = chroms\n", + "\n", + " # Intersect lengths between FASTA and bigWig for safety\n", + " bw_chrom_lengths = self.bw_list[0].chroms() # dict: chrom -> length\n", + "\n", + " self.valid_chroms = []\n", + " self.chrom_lengths = {}\n", + "\n", + " for c in chroms:\n", + " if c not in bw_chrom_lengths or c not in self.fasta:\n", + " continue\n", + "\n", + " fa_len = len(self.fasta[c])\n", + " bw_len = bw_chrom_lengths[c]\n", + " L = min(fa_len, bw_len)\n", + "\n", + " if L > self.sequence_length:\n", + " self.valid_chroms.append(c)\n", + " self.chrom_lengths[c] = L\n", + "\n", + " if not self.valid_chroms:\n", + " raise ValueError(\"No valid chromosomes after intersecting FASTA and bigWig.\")\n", + "\n", + " def __len__(self):\n", + " return self.num_samples\n", + "\n", + " def __getitem__(self, idx):\n", + " # Ignore idx, sample randomly\n", + " chrom = random.choice(self.valid_chroms)\n", + " chrom_len = self.chrom_lengths[chrom]\n", + "\n", + " max_start = chrom_len - self.sequence_length\n", + " start = random.randint(0, max_start)\n", + " end = start + self.sequence_length\n", + "\n", + " # Sequence\n", + " seq = self.fasta[chrom][start:end] # string slice\n", + " tokens = self.tokenizer(\n", + " seq,\n", + " return_tensors=\"pt\", # Returns a dict of PyTorch tensors\n", + " )[\"input_ids\"][0]\n", + " # The 'input_ids' field contains the tokenized sequence.\n", + " # For a single input string, its shape is typically (1, len(seq))\n", + "\n", + " # Signal from bigWig tracks (numpy array) -> torch tensor\n", + " bigwig_targets = np.array([\n", + " self.bw_list[i].values(chrom, start, end, numpy=True)\n", + " for i in range(len(self.bw_list))\n", + " ]) # shape (num_tracks, seq_len)\n", + " # Transpose to (seq_len, num_tracks)\n", + " bigwig_targets = bigwig_targets.T\n", + " # pyBigWig returns NaN where no data; turn NaN into 0\n", + " bigwig_targets = torch.tensor(bigwig_targets, dtype=torch.float32)\n", + " bigwig_targets = torch.nan_to_num(bigwig_targets, nan=0.0)\n", + " \n", + " # Crop targets to center fraction\n", + " if self.keep_target_center_fraction < 1.0:\n", + " seq_len = bigwig_targets.shape[0] # First dimension is sequence length\n", + " target_offset = int(seq_len * (1 - 
self.keep_target_center_fraction) // 2)\n", + " target_length = seq_len - 2 * target_offset\n", + " bigwig_targets = bigwig_targets[target_offset:target_offset + target_length, :]\n", + "\n", + " sample = {\n", + " \"tokens\": tokens,\n", + " \"bigwig_targets\": bigwig_targets,\n", + " \"chrom\": chrom,\n", + " \"start\": start,\n", + " \"end\": end,\n", + " }\n", + " return sample" + ] + },
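 + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To make the target cropping concrete: with `sequence_length = 1024` and `keep_target_center_fraction = 0.375`, each window keeps only the central 384 bp of signal. A minimal check of that arithmetic (a sketch, assuming this notebook's default config):" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Worked example of the target cropping above, using the default config values.\n", + "seq_len = 1_024 # config[\"sequence_length\"]\n", + "frac = 0.375 # config[\"keep_target_center_fraction\"]\n", + "target_offset = int(seq_len * (1 - frac) // 2) # 1024 * 0.625 // 2 = 320\n", + "target_length = seq_len - 2 * target_offset # 1024 - 640 = 384\n", + "assert (target_offset, target_length) == (320, 384)\n", + "assert target_length / seq_len == 0.375 # exactly the configured fraction" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [ { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train samples: 100\n", + "Val samples: 10\n", + "Test samples: 10\n" + ] + } + ], + "source": [ + "create_dataset_fn = functools.partial(\n", + " GenomeBigWigDataset,\n", + " fasta_path=fasta_path,\n", + " bigwig_path_list=bigwig_path_list,\n", + " sequence_length=config[\"sequence_length\"],\n", + " tokenizer=tokenizer,\n", + " keep_target_center_fraction=config[\"keep_target_center_fraction\"],\n", + " num_tracks=len(config[\"bigwig_file_ids\"]),\n", + ")\n", + "\n", + "train_dataset = create_dataset_fn(\n", + " chroms=chrom_splits[\"train\"],\n", + " num_samples=100,\n", + ")\n", + "\n", + "val_dataset = create_dataset_fn(\n", + " chroms=chrom_splits[\"val\"],\n", + " num_samples=config[\"num_validation_samples\"],\n", + ")\n", + "\n", + "test_dataset = create_dataset_fn(\n", + " chroms=chrom_splits[\"test\"],\n", + " num_samples=config[\"num_validation_samples\"],\n", + ")\n", + "\n", + "# Create dataloaders\n", + "train_loader = DataLoader(\n", + " train_dataset,\n", + " batch_size=config[\"batch_size\"],\n", + " shuffle=True,\n", + " num_workers=config[\"num_workers\"],\n", + ")\n", + "\n", + "val_loader = DataLoader(\n", + " val_dataset,\n", + " batch_size=config[\"batch_size\"],\n", + " shuffle=False,\n", + " num_workers=config[\"num_workers\"],\n", + ")\n", + "\n", + "test_loader = DataLoader(\n", + " test_dataset,\n", + " batch_size=config[\"batch_size\"],\n", + " shuffle=False,\n", + " num_workers=config[\"num_workers\"],\n", + ")\n", + "\n", + "print(f\"Train samples: {len(train_dataset)}\")\n", + "print(f\"Val samples: {len(val_dataset)}\")\n", + "print(f\"Test samples: {len(test_dataset)}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 5. 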
Optimizer setup\n" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 25, - "metadata": {}, - "outputs": [], - "source": [ - "def train_step(\n", - " model: nn.Module,\n", - " batch: Dict[str, torch.Tensor],\n", - " optimizer: torch.optim.Optimizer,\n", - " scale_targets_fn: Callable,\n", - " config: Dict,\n", - " num_accumulation_steps: int = 1,\n", - ") -> float:\n", - " \"\"\"Single training step with gradient accumulation support.\"\"\"\n", - " tokens = batch[\"tokens\"].to(device)\n", - " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n", - " \n", - " # Forward pass\n", - " outputs = model(tokens=tokens)\n", - " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n", - " \n", - " # Scale targets\n", - " scaled_targets = scale_targets_fn(bigwig_targets)\n", - " \n", - " # Compute loss\n", - " loss, _, _ = poisson_multinomial_loss(\n", - " logits=bigwig_logits,\n", - " targets=scaled_targets,\n", - " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n", - " )\n", - " \n", - " # Scale loss by accumulation steps (for gradient accumulation)\n", - " loss = loss / num_accumulation_steps\n", - " \n", - " # Backward pass (accumulate gradients)\n", - " loss.backward()\n", - " \n", - " return loss.item() * num_accumulation_steps # Return unscaled loss for logging\n", - "\n", - "\n", - "def validation_step(\n", - " model: nn.Module,\n", - " batch: Dict[str, torch.Tensor],\n", - " scale_targets_fn: Callable,\n", - " scale_predictions_fn: Callable,\n", - " metrics: TracksMetrics,\n", - " config: Dict,\n", - ") -> float:\n", - " \"\"\"Single validation step.\"\"\"\n", - " model.eval()\n", - " \n", - " tokens = batch[\"tokens\"].to(device)\n", - " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n", - " \n", - " with torch.no_grad():\n", - " # Forward pass\n", - " outputs = model(tokens=tokens)\n", - " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n", - " \n", - " # Scale targets for loss computation\n", - " scaled_targets = scale_targets_fn(bigwig_targets)\n", - " \n", - " # Compute loss (using scaled targets)\n", - " loss, _, _ = poisson_multinomial_loss(\n", - " logits=bigwig_logits,\n", - " targets=scaled_targets,\n", - " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n", - " )\n", - " \n", - " # Scale predictions back to original space for metrics\n", - " # (predictions are in scaled space, need to inverse transform)\n", - " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n", - " \n", - " # Update metrics (using original space targets and predictions)\n", - " metrics.update(\n", - " predictions_scaled=bigwig_logits,\n", - " targets_scaled=scaled_targets,\n", - " predictions_raw=unscaled_predictions,\n", - " targets_raw=bigwig_targets,\n", - " loss=loss.item()\n", - " )\n", - " \n", - " return loss.item()" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Gradient accumulation steps: 2\n", + "Effective batch size: 4\n", + "Effective tokens per update: 4096\n", + "\n", + "Training constants:\n", + " Total training steps: 32\n", + " Log training metrics every: 2 steps\n", + " Run validation every: 4 steps\n", + "\n", + "Optimizer setup:\n", + " Learning rate: 1e-05\n" + ] + } + ], + "source": [ + "# Calculate gradient accumulation steps and effective batch size\n", + "num_devices = 1 # Single device for now\n", + "sequence_length = 
config[\"sequence_length\"]\n", + "batch_size = config[\"batch_size\"]\n", + "\n", + "# Calculate gradient accumulation steps\n", + "num_accumulation_gradient = max(1, int(config[\"num_tokens_per_update\"] // (batch_size * num_devices * sequence_length)))\n", + "\n", + "# Calculate effective batch size and tokens per update\n", + "effective_batch_size = batch_size * num_devices * num_accumulation_gradient\n", + "effective_num_tokens_per_update = effective_batch_size * sequence_length\n", + "\n", + "print(f\"Gradient accumulation steps: {num_accumulation_gradient}\")\n", + "print(f\"Effective batch size: {effective_batch_size}\")\n", + "print(f\"Effective tokens per update: {effective_num_tokens_per_update}\")\n", + "\n", + "# Compute logging constants (based on deepspeed pipeline: compute_logging_constants)\n", + "num_train_samples = len(train_dataset)\n", + "num_tokens_per_update = effective_num_tokens_per_update # Same as effective_num_tokens_per_update\n", + "\n", + "# Total training steps based on token budget\n", + "num_steps_training = config[\"num_tokens_training\"] // num_tokens_per_update\n", + "\n", + "# Steps for logging and validation\n", + "log_train_step = int(np.ceil(config[\"num_tokens_per_log\"] / num_tokens_per_update))\n", + "log_validation_step = int(np.ceil(config[\"num_tokens_per_validation\"] / num_tokens_per_update))\n", + "\n", + "print(f\"\\nTraining constants:\")\n", + "print(f\" Total training steps: {num_steps_training}\")\n", + "print(f\" Log training metrics every: {log_train_step} steps\")\n", + "print(f\" Run validation every: {log_validation_step} steps\")\n", + "\n", + "# Setup optimizer\n", + "optimizer = AdamW(\n", + " model.parameters(),\n", + " lr=config[\"learning_rate\"],\n", + " weight_decay=config[\"weight_decay\"],\n", + ")\n", + "\n", + "print(f\"\\nOptimizer setup:\")\n", + "print(f\" Learning rate: {config['learning_rate']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 6. 
Metrics setup (using TorchMetrics)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "class TracksMetrics:\n", + " \"\"\"Simple metrics tracker for tracks prediction with both scaled and raw metrics.\"\"\"\n", + " \n", + " def __init__(self, track_names: List[str]):\n", + " self.track_names = track_names\n", + " self.num_tracks = len(track_names)\n", + " # Scaled metrics: comparing scaled targets with scaled predictions\n", + " self.pearson_metrics_scaled = [\n", + " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n", + " ]\n", + " # Raw metrics: comparing raw targets with unscaled predictions\n", + " self.pearson_metrics_raw = [\n", + " PearsonCorrCoef().to(device) for _ in range(self.num_tracks)\n", + " ]\n", + " self.losses = []\n", + " \n", + " def reset(self):\n", + " for metric in self.pearson_metrics_scaled:\n", + " metric.reset()\n", + " for metric in self.pearson_metrics_raw:\n", + " metric.reset()\n", + " self.losses = []\n", + " \n", + " def update(\n", + " self, \n", + " predictions_scaled: torch.Tensor, \n", + " targets_scaled: torch.Tensor,\n", + " predictions_raw: torch.Tensor,\n", + " targets_raw: torch.Tensor,\n", + " loss: float\n", + " ):\n", + " \"\"\"\n", + " Update both scaled and raw metrics.\n", + " Args:\n", + " predictions_scaled: (batch, seq_len, num_tracks) - scaled predictions\n", + " targets_scaled: (batch, seq_len, num_tracks) - scaled targets\n", + " predictions_raw: (batch, seq_len, num_tracks) - raw/unscaled predictions\n", + " targets_raw: (batch, seq_len, num_tracks) - raw targets\n", + " loss: scalar loss value\n", + " \"\"\"\n", + " # Flatten batch and sequence dimensions\n", + " pred_scaled_flat = predictions_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", + " target_scaled_flat = targets_scaled.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", + " pred_raw_flat = predictions_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", + " target_raw_flat = targets_raw.detach().reshape(-1, self.num_tracks) # (N, num_tracks)\n", + " \n", + " # Update scaled metrics\n", + " for i, metric in enumerate(self.pearson_metrics_scaled):\n", + " metric.update(pred_scaled_flat[:, i], target_scaled_flat[:, i])\n", + " \n", + " # Update raw metrics\n", + " for i, metric in enumerate(self.pearson_metrics_raw):\n", + " metric.update(pred_raw_flat[:, i], target_raw_flat[:, i])\n", + " \n", + " self.losses.append(loss)\n", + " \n", + " def compute(self) -> Dict[str, float]:\n", + " \"\"\"Compute and return all metrics (both scaled and raw).\"\"\"\n", + " metrics_dict = {}\n", + " \n", + " # Scaled metrics: per-track Pearson correlations\n", + " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_scaled)):\n", + " corr = metric.compute().item()\n", + " metrics_dict[f\"metrics_scaled/{track_name}/pearson\"] = corr\n", + " \n", + " # Scaled metrics: mean Pearson correlation\n", + " correlations_scaled = [metric.compute().item() for metric in self.pearson_metrics_scaled]\n", + " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n", + " \n", + " # Raw metrics: per-track Pearson correlations\n", + " for i, (track_name, metric) in enumerate(zip(self.track_names, self.pearson_metrics_raw)):\n", + " corr = metric.compute().item()\n", + " metrics_dict[f\"metrics_raw/{track_name}/pearson\"] = corr\n", + " \n", + " # Raw metrics: mean Pearson correlation\n", + " correlations_raw = [metric.compute().item() for 
metric in self.pearson_metrics_raw]\n", + " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n", + " \n", + " # Mean loss\n", + " metrics_dict[\"loss\"] = np.mean(self.losses) if self.losses else 0.0\n", + " \n", + " return metrics_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [], + "source": [ + "train_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n", + "val_metrics = TracksMetrics(config[\"bigwig_file_ids\"])\n", + "test_metrics = TracksMetrics(config[\"bigwig_file_ids\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 7. Scaling functions setup (copied from pipeline)" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 26, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Starting training...\n", - "Training for 32 steps with 2 gradient accumulation steps\n", - "\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n", - "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n", - " warnings.warn(error_message)\n", - "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The variance of predictions or target is close to zero. This can cause instability in Pearson correlationcoefficient, leading to wrong results. Consider re-scaling the input if possible or computing using alarger dtype (currently using torch.float32). 
Setting the correlation coefficient to nan.\n", - " warnings.warn(*args, **kwargs)\n", - "/tmp/ipykernel_1758159/1960846655.py:68: RuntimeWarning: Mean of empty slice\n", - " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n", - "/tmp/ipykernel_1758159/1960846655.py:77: RuntimeWarning: Mean of empty slice\n", - " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Step 1/32 | Loss: 0.8378 | Mean Pearson: nan | LR: 1.17e-09 | Tokens: 4,096\n", - "\n", - "Running validation at step 0...\n", - " Validation Loss: 0.5279\n", - " Validation Mean Pearson: -0.0192\n", - " ENCFF884LDL/pearson: -0.0192\n", - "Step 3/32 | Loss: 0.4650 | Mean Pearson: -0.0149 | LR: 2.50e-09 | Tokens: 12,288\n", - "Step 5/32 | Loss: 0.3369 | Mean Pearson: -0.1350 | LR: 2.41e-09 | Tokens: 20,480\n", - "\n", - "Running validation at step 4...\n", - " Validation Loss: 0.3878\n", - " Validation Mean Pearson: -0.1298\n", - " ENCFF884LDL/pearson: -0.1298\n", - "Step 7/32 | Loss: 0.3609 | Mean Pearson: -0.0102 | LR: 2.32e-09 | Tokens: 28,672\n", - "Step 9/32 | Loss: 0.3301 | Mean Pearson: -0.0902 | LR: 2.23e-09 | Tokens: 36,864\n", - "\n", - "Running validation at step 8...\n", - " Validation Loss: 0.4743\n", - " Validation Mean Pearson: -0.0739\n", - " ENCFF884LDL/pearson: -0.0739\n", - "Step 11/32 | Loss: 0.3905 | Mean Pearson: -0.0113 | LR: 2.13e-09 | Tokens: 45,056\n", - "Step 13/32 | Loss: 0.3181 | Mean Pearson: -0.1564 | LR: 2.02e-09 | Tokens: 53,248\n", - "\n", - "Running validation at step 12...\n", - " Validation Loss: 0.3337\n", - " Validation Mean Pearson: -0.0650\n", - " ENCFF884LDL/pearson: -0.0650\n", - "Step 15/32 | Loss: 0.3638 | Mean Pearson: 0.0295 | LR: 1.91e-09 | Tokens: 61,440\n", - "Step 17/32 | Loss: 0.4170 | Mean Pearson: -0.0442 | LR: 1.80e-09 | Tokens: 69,632\n", - "\n", - "Running validation at step 16...\n", - " Validation Loss: 0.7969\n", - " Validation Mean Pearson: -0.0304\n", - " ENCFF884LDL/pearson: -0.0304\n", - "Step 19/32 | Loss: 0.5033 | Mean Pearson: -0.0173 | LR: 1.67e-09 | Tokens: 77,824\n", - "Step 21/32 | Loss: 0.4084 | Mean Pearson: -0.0516 | LR: 1.54e-09 | Tokens: 86,016\n", - "\n", - "Running validation at step 20...\n", - " Validation Loss: 0.3475\n", - " Validation Mean Pearson: -0.3040\n", - " ENCFF884LDL/pearson: -0.3040\n", - "Step 23/32 | Loss: 0.4915 | Mean Pearson: -0.1727 | LR: 1.39e-09 | Tokens: 94,208\n", - "Step 25/32 | Loss: 0.3654 | Mean Pearson: -0.3257 | LR: 1.23e-09 | Tokens: 102,400\n", - "\n", - "Running validation at step 24...\n", - " Validation Loss: 0.4069\n", - " Validation Mean Pearson: -0.0551\n", - " ENCFF884LDL/pearson: -0.0551\n", - "Step 27/32 | Loss: 0.5344 | Mean Pearson: -0.0604 | LR: 1.04e-09 | Tokens: 110,592\n", - "Step 29/32 | Loss: 0.3671 | Mean Pearson: -0.0290 | LR: 8.04e-10 | Tokens: 118,784\n", - "\n", - "Running validation at step 28...\n", - " Validation Loss: 0.3162\n", - " Validation Mean Pearson: -0.1008\n", - " ENCFF884LDL/pearson: -0.1008\n", - "Step 31/32 | Loss: 0.5994 | Mean Pearson: -0.0107 | LR: 4.64e-10 | Tokens: 126,976\n", - "\n", - "Training completed after 32 steps!\n" - ] - } - ], - "source": [ - "# Training loop (step-based with gradient accumulation)\n", - "print(\"Starting training...\")\n", - "print(f\"Training for {num_steps_training} steps with {num_accumulation_gradient} gradient accumulation steps\\n\")\n", - "\n", - "model.train()\n", - "train_metrics.reset()\n", - 
"optimizer.zero_grad() # Initialize gradients\n", - "\n", - "# Create iterator for training data (will cycle if needed)\n", - "train_iter = iter(train_loader)\n", - "num_tokens_seen = 0\n", - "\n", - "# Main training loop: for loop over optimizer steps (like deepspeed pipeline)\n", - "for optimizer_step_idx in range(num_steps_training):\n", - " # Gradient accumulation loop\n", - " accumulated_loss = 0.0\n", - " for acc_idx in range(num_accumulation_gradient):\n", - " try:\n", - " batch = next(train_iter)\n", - " except StopIteration:\n", - " # Restart iterator if we run out of data\n", - " train_iter = iter(train_loader)\n", - " batch = next(train_iter)\n", - " \n", - " # Forward pass and accumulate gradients\n", - " loss = train_step(\n", - " model, batch, optimizer, scale_targets_fn, config, \n", - " num_accumulation_steps=num_accumulation_gradient\n", - " )\n", - " accumulated_loss += loss\n", - " \n", - " # Update optimizer (after accumulation)\n", - " optimizer.step()\n", - " optimizer.zero_grad()\n", - " \n", - " # Update scheduler\n", - " if scheduler is not None:\n", - " scheduler.step()\n", - " \n", - " # Update tokens seen\n", - " num_tokens_seen += effective_num_tokens_per_update\n", - " \n", - " # Update metrics (on last batch of accumulation)\n", - " tokens = batch[\"tokens\"].to(device)\n", - " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n", - " with torch.no_grad():\n", - " outputs = model(tokens=tokens)\n", - " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n", - " \n", - " # Scale targets for scaled metrics\n", - " scaled_targets = scale_targets_fn(bigwig_targets)\n", - " \n", - " # Unscale predictions for raw metrics\n", - " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n", - " \n", - " avg_loss = accumulated_loss / num_accumulation_gradient\n", - " train_metrics.update(\n", - " predictions_scaled=bigwig_logits,\n", - " targets_scaled=scaled_targets,\n", - " predictions_raw=unscaled_predictions,\n", - " targets_raw=bigwig_targets,\n", - " loss=avg_loss\n", - " )\n", - " \n", - " # Logging\n", - " if optimizer_step_idx % log_train_step == 0:\n", - " train_metrics_dict = train_metrics.compute()\n", - " current_lr = scheduler.get_last_lr()[0] if scheduler else config[\"learning_rate\"]\n", - " print(f\"Step {optimizer_step_idx + 1}/{num_steps_training} | \"\n", - " f\"Loss: {avg_loss:.4f} | \"\n", - " f\"Mean Pearson: {train_metrics_dict['metrics_scaled/mean/pearson']:.4f} | \"\n", - " f\"LR: {current_lr:.2e} | \"\n", - " f\"Tokens: {num_tokens_seen:,}\")\n", - " train_metrics.reset()\n", - " \n", - " # Validation\n", - " if optimizer_step_idx % log_validation_step == 0:\n", - " print(f\"\\nRunning validation at step {optimizer_step_idx}...\")\n", - " val_metrics.reset()\n", - " model.eval()\n", - " \n", - " val_losses = []\n", - " for val_batch in val_loader:\n", - " val_loss = validation_step(\n", - " model, val_batch, scale_targets_fn, scale_predictions_fn, val_metrics, config\n", - " )\n", - " val_losses.append(val_loss)\n", - " \n", - " # Print validation metrics\n", - " val_metrics_dict = val_metrics.compute()\n", - " print(f\" Validation Loss: {np.mean(val_losses):.4f}\")\n", - " print(f\" Validation Mean Pearson: {val_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n", - " for track_name in config[\"bigwig_file_ids\"]:\n", - " print(f\" {track_name}/pearson: {val_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n", - " \n", - " model.train() # Back to training mode\n", - "\n", - "print(f\"\\nTraining completed after 
{num_steps_training} steps!\")\n" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Scaling functions created\n" + ] + } + ], + "source": [ + "def get_track_means(bigwig_file_ids: List[str]) -> np.ndarray:\n", + " \"\"\"\n", + " Get track means for normalization.\n", + " For now, return dummy values. In real pipeline, this loads from metadata.\n", + " \"\"\"\n", + " # Dummy values - in real pipeline, this would load from actual metadata\n", + " return np.ones(len(bigwig_file_ids), dtype=np.float32) * 1.0\n", + "\n", + "\n", + "def get_rna_seq_track_ids(bigwig_file_ids: List[str]) -> List[int]:\n", + " \"\"\"\n", + " Get RNA-seq track indices.\n", + " For now, return empty list. In real pipeline, this identifies RNA-seq tracks.\n", + " \"\"\"\n", + " # Dummy - in real pipeline, this would identify RNA-seq tracks\n", + " return []\n", + "\n", + "\n", + "def create_targets_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n", + " \"\"\"\n", + " Build a scaling function based on track means and RNA-seq squashing.\n", + " Copied from the supervised tracks pipeline.\n", + " \"\"\"\n", + " # Load track means\n", + " track_means_np = get_track_means(bigwig_file_ids)\n", + " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n", + " \n", + " # Get which tracks use squashing\n", + " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n", + " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n", + " if len(rna_ids) > 0:\n", + " apply_squashing[rna_ids] = True\n", + " \n", + " def transform_fn(x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n", + " \"\"\"\n", + " device = x.device\n", + " \n", + " # Move constants to correct device\n", + " means = track_means.to(device)\n", + " squash_mask = apply_squashing.to(device)\n", + " \n", + " # Normalize\n", + " scaled = x / means\n", + " \n", + " # Power squashing where needed\n", + " squashed = torch.where(\n", + " squash_mask.view(1, 1, -1),\n", + " scaled.pow(0.75),\n", + " scaled,\n", + " )\n", + " \n", + " # Smooth clipping: if > 10, apply formula\n", + " clipped = torch.where(\n", + " squashed > 10.0,\n", + " 2.0 * torch.sqrt(squashed * 10.0) - 10.0,\n", + " squashed,\n", + " )\n", + " \n", + " return clipped\n", + " \n", + " return transform_fn\n", + "\n", + "\n", + "def create_predictions_scaling_fn(bigwig_file_ids: List[str]) -> Callable[[torch.Tensor], torch.Tensor]:\n", + " \"\"\"\n", + " Inverse scaling function to apply on predictions before computing metrics.\n", + " Copied from the supervised tracks pipeline.\n", + " \"\"\"\n", + " # Load means\n", + " track_means_np = get_track_means(bigwig_file_ids)\n", + " track_means = torch.tensor(track_means_np, dtype=torch.float32)\n", + " \n", + " # RNA-seq mask\n", + " rna_ids = get_rna_seq_track_ids(bigwig_file_ids)\n", + " apply_squashing = torch.zeros((len(bigwig_file_ids),), dtype=torch.bool)\n", + " if len(rna_ids) > 0:\n", + " apply_squashing[rna_ids] = True\n", + " \n", + " def inverse_transform_fn(x: torch.Tensor) -> torch.Tensor:\n", + " \"\"\"\n", + " x: torch.Tensor, shape (batch, seq_len, num_tracks)\n", + " \"\"\"\n", + " device = x.device\n", + " means = track_means.to(device)\n", + " squash_mask = apply_squashing.to(device)\n", + " \n", + " # Undo clipping\n", + " unclipped = torch.where(\n", + " x > 10.0,\n", + " (x + 10.0).pow(2) / (4 * 10.0),\n", + " x,\n", + " )\n", + " \n", + " # Undo squashing\n", + " unsquashed = torch.where(\n", + " 
squash_mask.view(1, 1, -1),\n",
+ "            unclipped.pow(1.0 / 0.75),\n",
+ "            unclipped,\n",
+ "        )\n",
+ "        \n",
+ "        # Undo normalization\n",
+ "        return unsquashed * means\n",
+ "    \n",
+ "    return inverse_transform_fn\n",
+ "\n",
+ "\n",
+ "# Create scaling functions\n",
+ "scale_targets_fn = create_targets_scaling_fn(config[\"bigwig_file_ids\"])\n",
+ "scale_predictions_fn = create_predictions_scaling_fn(config[\"bigwig_file_ids\"])\n",
+ "\n",
+ "print(\"Scaling functions created\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 8. Loss functions"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def poisson_loss(ytrue: torch.Tensor, ypred: torch.Tensor, epsilon: float = 1e-7) -> torch.Tensor:\n",
+ "    \"\"\"Poisson loss per element: ypred - ytrue * log(ypred).\"\"\"\n",
+ "    return ypred - ytrue * torch.log(ypred + epsilon)\n",
+ "\n",
+ "\n",
+ "def safe_for_grad_log_torch(x: torch.Tensor) -> torch.Tensor:\n",
+ "    \"\"\"log(x) where x > 0, and 0 elsewhere, so gradients stay finite.\"\"\"\n",
+ "    return torch.log(torch.where(x > 0.0, x, torch.ones_like(x)))\n",
+ "\n",
+ "\n",
+ "def poisson_multinomial_loss(\n",
+ "    logits: torch.Tensor,\n",
+ "    targets: torch.Tensor,\n",
+ "    mask: torch.Tensor | None = None,\n",
+ "    shape_loss_coefficient: float = 5.0,\n",
+ "    epsilon: float = 1e-7,\n",
+ ") -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:\n",
+ "    \"\"\"\n",
+ "    Poisson-multinomial loss for bigwig tracks: a Poisson term on the total\n",
+ "    counts per window (scale) plus a multinomial cross-entropy term on their\n",
+ "    positional distribution (shape).\n",
+ "    \"\"\"\n",
+ "    if mask is None:\n",
+ "        mask = torch.ones_like(targets, dtype=torch.float32, device=targets.device)\n",
+ "    else:\n",
+ "        mask = mask.float()\n",
+ "    \n",
+ "    mask_sum = mask.sum() + epsilon\n",
+ "    masked_logits = logits * mask\n",
+ "    masked_targets = targets * mask\n",
+ "\n",
+ "    # Scale loss\n",
+ "    mask_sum_per_track_per_seq = mask.sum(dim=1)  # (batch, num_tracks)\n",
+ "    mask_per_sequence = mask_sum_per_track_per_seq > 0.0  # (batch, num_tracks)\n",
+ "    \n",
+ "    sum_pred = masked_logits.sum(dim=1)  # (batch, num_tracks)\n",
+ "    sum_true = masked_targets.sum(dim=1)  # (batch, num_tracks)\n",
+ "    \n",
+ "    scale_loss = poisson_loss(sum_true, sum_pred, epsilon=epsilon)\n",
+ "    scale_loss = scale_loss / (mask_sum_per_track_per_seq + epsilon)\n",
+ "    \n",
+ "    if mask_per_sequence.any():\n",
+ "        scale_loss_filtered = scale_loss[mask_per_sequence]\n",
+ "        scale_loss = scale_loss_filtered.mean()\n",
+ "    else:\n",
+ "        scale_loss = torch.tensor(0.0, device=targets.device, dtype=targets.dtype)\n",
+ "    \n",
+ "    # Shape loss\n",
+ "    predicted_counts = masked_logits + (epsilon * mask)\n",
+ "    masked_targets_with_epsilon = masked_targets + (epsilon * mask)\n",
+ "    \n",
+ "    denom = predicted_counts.sum(dim=1, keepdim=True) + epsilon\n",
+ "    p_pred = predicted_counts / denom\n",
+ "    \n",
+ "    pl_pred = safe_for_grad_log_torch(p_pred)\n",
+ "    shape_loss = -(masked_targets_with_epsilon * pl_pred).sum() / mask_sum\n",
+ "    \n",
+ "    # Combine: the shape term dominates; the scale term is down-weighted\n",
+ "    # by shape_loss_coefficient\n",
+ "    loss = shape_loss + scale_loss / shape_loss_coefficient\n",
+ "\n",
+ "    return loss, scale_loss, shape_loss\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 9. 
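Training loop"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "Before wiring the loop together, a small smoke test of the pieces from sections 7 and 8 on toy tensors (shapes and values are made up): the target scaling should round-trip through its inverse, and the loss should return finite scalars."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Toy tensors: (batch=2, seq_len=8, num_tracks=1), values straddling the clip point\n",
+ "_targets = torch.rand(2, 8, 1) * 20.0\n",
+ "_scaled = scale_targets_fn(_targets)\n",
+ "_roundtrip = scale_predictions_fn(_scaled)  # inverse transform\n",
+ "print(f\"Max round-trip error: {(_roundtrip - _targets).abs().max().item():.2e}\")\n",
+ "\n",
+ "# Stand-in for (non-negative) model outputs\n",
+ "_fake_logits = torch.rand(2, 8, 1)\n",
+ "_loss, _scale, _shape = poisson_multinomial_loss(\n",
+ "    logits=_fake_logits,\n",
+ "    targets=_scaled,\n",
+ "    shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n",
+ ")\n",
+ "print(f\"loss={_loss.item():.4f} | scale={_scale.item():.4f} | shape={_shape.item():.4f}\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The next cells define the steps of the 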
Training loop" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": {}, + "outputs": [], + "source": [ + "def train_step(\n", + " model: nn.Module,\n", + " batch: Dict[str, torch.Tensor],\n", + " optimizer: torch.optim.Optimizer,\n", + " scale_targets_fn: Callable,\n", + " config: Dict,\n", + " num_accumulation_steps: int = 1,\n", + ") -> float:\n", + " \"\"\"Single training step with gradient accumulation support.\"\"\"\n", + " tokens = batch[\"tokens\"].to(device)\n", + " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n", + " \n", + " # Forward pass\n", + " outputs = model(tokens=tokens)\n", + " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n", + " \n", + " # Scale targets\n", + " scaled_targets = scale_targets_fn(bigwig_targets)\n", + " \n", + " # Compute loss\n", + " loss, _, _ = poisson_multinomial_loss(\n", + " logits=bigwig_logits,\n", + " targets=scaled_targets,\n", + " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n", + " )\n", + " \n", + " # Scale loss by accumulation steps (for gradient accumulation)\n", + " loss = loss / num_accumulation_steps\n", + " \n", + " # Backward pass (accumulate gradients)\n", + " loss.backward()\n", + " \n", + " return loss.item() * num_accumulation_steps # Return unscaled loss for logging\n", + "\n", + "\n", + "def validation_step(\n", + " model: nn.Module,\n", + " batch: Dict[str, torch.Tensor],\n", + " scale_targets_fn: Callable,\n", + " scale_predictions_fn: Callable,\n", + " metrics: TracksMetrics,\n", + " config: Dict,\n", + ") -> float:\n", + " \"\"\"Single validation step.\"\"\"\n", + " model.eval()\n", + " \n", + " tokens = batch[\"tokens\"].to(device)\n", + " bigwig_targets = batch[\"bigwig_targets\"].to(device)\n", + " \n", + " with torch.no_grad():\n", + " # Forward pass\n", + " outputs = model(tokens=tokens)\n", + " bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n", + " \n", + " # Scale targets for loss computation\n", + " scaled_targets = scale_targets_fn(bigwig_targets)\n", + " \n", + " # Compute loss (using scaled targets)\n", + " loss, _, _ = poisson_multinomial_loss(\n", + " logits=bigwig_logits,\n", + " targets=scaled_targets,\n", + " shape_loss_coefficient=config[\"bigwig_shape_loss_coefficient\"],\n", + " )\n", + " \n", + " # Scale predictions back to original space for metrics\n", + " # (predictions are in scaled space, need to inverse transform)\n", + " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n", + " \n", + " # Update metrics (using original space targets and predictions)\n", + " metrics.update(\n", + " predictions_scaled=bigwig_logits,\n", + " targets_scaled=scaled_targets,\n", + " predictions_raw=unscaled_predictions,\n", + " targets_raw=bigwig_targets,\n", + " loss=loss.item()\n", + " )\n", + " \n", + " return loss.item()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# 10. 
Test evaluation" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Starting training...\n", + "Training for 32 steps with 2 gradient accumulation steps\n", + "\n" + ] }, { - "cell_type": "code", - "execution_count": 27, - "metadata": {}, - "outputs": [], - "source": [ - "def test_step(\n", - " model: nn.Module,\n", - " batch: Dict[str, torch.Tensor],\n", - " scale_targets_fn: Callable,\n", - " scale_predictions_fn: Callable,\n", - " metrics: TracksMetrics,\n", - ") -> None:\n", - " \"\"\"\n", - " Pure evaluation step for test set (no loss computation).\n", - " Based on tracks_evaluation_step_torch from deepspeed pipeline.\n", - " \"\"\"\n", - " tokens = batch[\"tokens\"].to(device)\n", - " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n", - " \n", - " with torch.no_grad():\n", - " # Forward pass\n", - " outputs = model(tokens=tokens)\n", - " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n", - " \n", - " # Scale targets for scaled metrics\n", - " scaled_targets = scale_targets_fn(bigwig_targets)\n", - " \n", - " # Unscale predictions for raw metrics\n", - " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n", - " \n", - " # Update metrics with both scaled and raw values\n", - " # Pass 0.0 as loss since we don't compute loss in test evaluation\n", - " metrics.update(\n", - " predictions_scaled=bigwig_logits,\n", - " targets_scaled=scaled_targets,\n", - " predictions_raw=unscaled_predictions,\n", - " targets_raw=bigwig_targets,\n", - " loss=0.0\n", - " )" - ] + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torch/amp/autocast_mode.py:287: UserWarning: In CPU autocast, but the target dtype is not supported. Disabling autocast.\n", + "CPU Autocast only supports dtype of torch.bfloat16, torch.float16 currently.\n", + " warnings.warn(error_message)\n", + "/home/y-bornachot/venvs/ntv3-env/lib/python3.12/site-packages/torchmetrics/utilities/prints.py:43: UserWarning: The variance of predictions or target is close to zero. This can cause instability in Pearson correlationcoefficient, leading to wrong results. Consider re-scaling the input if possible or computing using alarger dtype (currently using torch.float32). 
Setting the correlation coefficient to nan.\n", + " warnings.warn(*args, **kwargs)\n", + "/tmp/ipykernel_1758159/1960846655.py:68: RuntimeWarning: Mean of empty slice\n", + " metrics_dict[\"metrics_scaled/mean/pearson\"] = np.nanmean(correlations_scaled)\n", + "/tmp/ipykernel_1758159/1960846655.py:77: RuntimeWarning: Mean of empty slice\n", + " metrics_dict[\"metrics_raw/mean/pearson\"] = np.nanmean(correlations_raw)\n" + ] }, { - "cell_type": "code", - "execution_count": 28, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "==================================================\n", - "Test Set Evaluation\n", - "==================================================\n", - "Running test evaluation with 5 steps (10 samples)\n", - "\n", - "==================================================\n", - "Test Set Results\n", - "==================================================\n", - "\n", - "Scaled Metrics (scaled predictions vs scaled targets):\n", - " Mean Pearson (scaled): -0.0020\n", - " ENCFF884LDL/pearson: -0.0020\n", - "\n", - "Raw Metrics (raw predictions vs raw targets):\n", - " Mean Pearson (raw): -0.0020\n", - " ENCFF884LDL/pearson: -0.0020\n", - "==================================================\n" - ] - } - ], - "source": [ - "print(\"\\n\" + \"=\"*50)\n", - "print(\"Test Set Evaluation\")\n", - "print(\"=\"*50)\n", - "\n", - "# Calculate number of test steps (based on deepspeed pipeline)\n", - "num_test_samples = len(test_dataset)\n", - "num_test_steps = num_test_samples // config[\"batch_size\"]\n", - "\n", - "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n", - "\n", - "# Set model to eval mode\n", - "model.eval()\n", - "\n", - "# Create iterator for test data\n", - "test_iter = iter(test_loader)\n", - "\n", - "# Run test evaluation (based on deepspeed pipeline: for loop over test steps)\n", - "for _ in range(num_test_steps):\n", - " try:\n", - " test_batch = next(test_iter)\n", - " except StopIteration:\n", - " break\n", - " \n", - " # Perform test evaluation (pure evaluation, no loss computation)\n", - " test_step(\n", - " model, test_batch, scale_targets_fn, scale_predictions_fn, test_metrics\n", - " )\n", - "\n", - "# Compute final test metrics\n", - "test_metrics_dict = test_metrics.compute()\n", - "\n", - "print(\"\\n\" + \"=\"*50)\n", - "print(\"Test Set Results\")\n", - "print(\"=\"*50)\n", - "print(f\"\\nScaled Metrics (scaled predictions vs scaled targets):\")\n", - "print(f\" Mean Pearson (scaled): {test_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n", - "for track_name in config[\"bigwig_file_ids\"]:\n", - " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n", - "\n", - "print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n", - "print(f\" Mean Pearson (raw): {test_metrics_dict['metrics_raw/mean/pearson']:.4f}\")\n", - "for track_name in config[\"bigwig_file_ids\"]:\n", - " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n", - "print(\"=\"*50)" - ] + "name": "stdout", + "output_type": "stream", + "text": [ + "Step 1/32 | Loss: 0.8378 | Mean Pearson: nan | LR: 1.17e-09 | Tokens: 4,096\n", + "\n", + "Running validation at step 0...\n", + " Validation Loss: 0.5279\n", + " Validation Mean Pearson: -0.0192\n", + " ENCFF884LDL/pearson: -0.0192\n", + "Step 3/32 | Loss: 0.4650 | Mean Pearson: -0.0149 | LR: 2.50e-09 | Tokens: 12,288\n", + "Step 5/32 | Loss: 0.3369 | Mean 
Pearson: -0.1350 | LR: 2.41e-09 | Tokens: 20,480\n", + "\n", + "Running validation at step 4...\n", + " Validation Loss: 0.3878\n", + " Validation Mean Pearson: -0.1298\n", + " ENCFF884LDL/pearson: -0.1298\n", + "Step 7/32 | Loss: 0.3609 | Mean Pearson: -0.0102 | LR: 2.32e-09 | Tokens: 28,672\n", + "Step 9/32 | Loss: 0.3301 | Mean Pearson: -0.0902 | LR: 2.23e-09 | Tokens: 36,864\n", + "\n", + "Running validation at step 8...\n", + " Validation Loss: 0.4743\n", + " Validation Mean Pearson: -0.0739\n", + " ENCFF884LDL/pearson: -0.0739\n", + "Step 11/32 | Loss: 0.3905 | Mean Pearson: -0.0113 | LR: 2.13e-09 | Tokens: 45,056\n", + "Step 13/32 | Loss: 0.3181 | Mean Pearson: -0.1564 | LR: 2.02e-09 | Tokens: 53,248\n", + "\n", + "Running validation at step 12...\n", + " Validation Loss: 0.3337\n", + " Validation Mean Pearson: -0.0650\n", + " ENCFF884LDL/pearson: -0.0650\n", + "Step 15/32 | Loss: 0.3638 | Mean Pearson: 0.0295 | LR: 1.91e-09 | Tokens: 61,440\n", + "Step 17/32 | Loss: 0.4170 | Mean Pearson: -0.0442 | LR: 1.80e-09 | Tokens: 69,632\n", + "\n", + "Running validation at step 16...\n", + " Validation Loss: 0.7969\n", + " Validation Mean Pearson: -0.0304\n", + " ENCFF884LDL/pearson: -0.0304\n", + "Step 19/32 | Loss: 0.5033 | Mean Pearson: -0.0173 | LR: 1.67e-09 | Tokens: 77,824\n", + "Step 21/32 | Loss: 0.4084 | Mean Pearson: -0.0516 | LR: 1.54e-09 | Tokens: 86,016\n", + "\n", + "Running validation at step 20...\n", + " Validation Loss: 0.3475\n", + " Validation Mean Pearson: -0.3040\n", + " ENCFF884LDL/pearson: -0.3040\n", + "Step 23/32 | Loss: 0.4915 | Mean Pearson: -0.1727 | LR: 1.39e-09 | Tokens: 94,208\n", + "Step 25/32 | Loss: 0.3654 | Mean Pearson: -0.3257 | LR: 1.23e-09 | Tokens: 102,400\n", + "\n", + "Running validation at step 24...\n", + " Validation Loss: 0.4069\n", + " Validation Mean Pearson: -0.0551\n", + " ENCFF884LDL/pearson: -0.0551\n", + "Step 27/32 | Loss: 0.5344 | Mean Pearson: -0.0604 | LR: 1.04e-09 | Tokens: 110,592\n", + "Step 29/32 | Loss: 0.3671 | Mean Pearson: -0.0290 | LR: 8.04e-10 | Tokens: 118,784\n", + "\n", + "Running validation at step 28...\n", + " Validation Loss: 0.3162\n", + " Validation Mean Pearson: -0.1008\n", + " ENCFF884LDL/pearson: -0.1008\n", + "Step 31/32 | Loss: 0.5994 | Mean Pearson: -0.0107 | LR: 4.64e-10 | Tokens: 126,976\n", + "\n", + "Training completed after 32 steps!\n" + ] } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3.12 (ntv3-env)", - "language": "python", - "name": "ntv3-env" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" + ], + "source": [ + "# Training loop (step-based with gradient accumulation)\n", + "print(\"Starting training...\")\n", + "print(f\"Training for {num_steps_training} steps with {num_accumulation_gradient} gradient accumulation steps\\n\")\n", + "\n", + "model.train()\n", + "train_metrics.reset()\n", + "optimizer.zero_grad() # Initialize gradients\n", + "\n", + "# Create iterator for training data (will cycle if needed)\n", + "train_iter = iter(train_loader)\n", + "num_tokens_seen = 0\n", + "\n", + "# Main training loop: for loop over optimizer steps (like deepspeed pipeline)\n", + "for optimizer_step_idx in range(num_steps_training):\n", + " # Gradient accumulation loop\n", + " accumulated_loss = 0.0\n", + " for acc_idx in range(num_accumulation_gradient):\n", + " try:\n", + " batch = 
next(train_iter)\n",
+ "        except StopIteration:\n",
+ "            # Restart iterator if we run out of data\n",
+ "            train_iter = iter(train_loader)\n",
+ "            batch = next(train_iter)\n",
+ "        \n",
+ "        # Forward pass and accumulate gradients\n",
+ "        loss = train_step(\n",
+ "            model, batch, optimizer, scale_targets_fn, config, \n",
+ "            num_accumulation_steps=num_accumulation_gradient\n",
+ "        )\n",
+ "        accumulated_loss += loss\n",
+ "    \n",
+ "    # Update optimizer (after accumulation)\n",
+ "    optimizer.step()\n",
+ "    optimizer.zero_grad()\n",
+ "    \n",
+ "    # Update tokens seen\n",
+ "    num_tokens_seen += effective_num_tokens_per_update\n",
+ "    \n",
+ "    # Update metrics (on last batch of accumulation)\n",
+ "    tokens = batch[\"tokens\"].to(device)\n",
+ "    bigwig_targets = batch[\"bigwig_targets\"].to(device)\n",
+ "    with torch.no_grad():\n",
+ "        outputs = model(tokens=tokens)\n",
+ "        bigwig_logits = outputs[\"bigwig_tracks_logits\"]\n",
+ "    \n",
+ "    # Scale targets for scaled metrics\n",
+ "    scaled_targets = scale_targets_fn(bigwig_targets)\n",
+ "    \n",
+ "    # Unscale predictions for raw metrics\n",
+ "    unscaled_predictions = scale_predictions_fn(bigwig_logits)\n",
+ "    \n",
+ "    avg_loss = accumulated_loss / num_accumulation_gradient\n",
+ "    train_metrics.update(\n",
+ "        predictions_scaled=bigwig_logits,\n",
+ "        targets_scaled=scaled_targets,\n",
+ "        predictions_raw=unscaled_predictions,\n",
+ "        targets_raw=bigwig_targets,\n",
+ "        loss=avg_loss\n",
+ "    )\n",
+ "    \n",
+ "    # Logging\n",
+ "    if optimizer_step_idx % log_train_step == 0:\n",
+ "        train_metrics_dict = train_metrics.compute()\n",
+ "        print(f\"Step {optimizer_step_idx + 1}/{num_steps_training} | \"\n",
+ "              f\"Loss: {avg_loss:.4f} | \"\n",
+ "              f\"Mean Pearson: {train_metrics_dict['metrics_scaled/mean/pearson']:.4f} | \"\n",
+ "              f\"Tokens: {num_tokens_seen:,}\")\n",
+ "        train_metrics.reset()\n",
+ "    \n",
+ "    # Validation\n",
+ "    if optimizer_step_idx % log_validation_step == 0:\n",
+ "        print(f\"\\nRunning validation at step {optimizer_step_idx}...\")\n",
+ "        val_metrics.reset()\n",
+ "        model.eval()\n",
+ "        \n",
+ "        val_losses = []\n",
+ "        for val_batch in val_loader:\n",
+ "            val_loss = validation_step(\n",
+ "                model, val_batch, scale_targets_fn, scale_predictions_fn, val_metrics, config\n",
+ "            )\n",
+ "            val_losses.append(val_loss)\n",
+ "        \n",
+ "        # Print validation metrics\n",
+ "        val_metrics_dict = val_metrics.compute()\n",
+ "        print(f\"  Validation Loss: {np.mean(val_losses):.4f}\")\n",
+ "        print(f\"  Validation Mean Pearson: {val_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n",
+ "        for track_name in config[\"bigwig_file_ids\"]:\n",
+ "            print(f\"    {track_name}/pearson: {val_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n",
+ "        \n",
+ "        model.train()  # Back to training mode\n",
+ "\n",
+ "print(f\"\\nTraining completed after {num_steps_training} steps!\")\n"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# 10. 
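Test evaluation"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The fine-tuned weights live only in memory at this point. Saving a checkpoint before evaluation is optional but cheap; the filename below is a placeholder under the data cache directory."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Optional: persist the fine-tuned weights (placeholder filename)\n",
+ "ckpt_path = os.path.join(config[\"data_cache_dir\"], \"ntv3_tracks_finetuned.pt\")\n",
+ "torch.save(model.state_dict(), ckpt_path)\n",
+ "print(f\"Saved checkpoint to {ckpt_path}\")\n",
+ "\n",
+ "# To restore later (sketch):\n",
+ "# model.load_state_dict(torch.load(ckpt_path, map_location=device))"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "The remaining cells run the final 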
Test evaluation" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "def test_step(\n", + " model: nn.Module,\n", + " batch: Dict[str, torch.Tensor],\n", + " scale_targets_fn: Callable,\n", + " scale_predictions_fn: Callable,\n", + " metrics: TracksMetrics,\n", + ") -> None:\n", + " \"\"\"\n", + " Pure evaluation step for test set (no loss computation).\n", + " Based on tracks_evaluation_step_torch from deepspeed pipeline.\n", + " \"\"\"\n", + " tokens = batch[\"tokens\"].to(device)\n", + " bigwig_targets = batch[\"bigwig_targets\"].to(device) # Shape: (batch, seq_len_cropped, num_tracks)\n", + " \n", + " with torch.no_grad():\n", + " # Forward pass\n", + " outputs = model(tokens=tokens)\n", + " bigwig_logits = outputs[\"bigwig_tracks_logits\"] # Shape: (batch, cropped_seq_len, num_tracks)\n", + " \n", + " # Scale targets for scaled metrics\n", + " scaled_targets = scale_targets_fn(bigwig_targets)\n", + " \n", + " # Unscale predictions for raw metrics\n", + " unscaled_predictions = scale_predictions_fn(bigwig_logits)\n", + " \n", + " # Update metrics with both scaled and raw values\n", + " # Pass 0.0 as loss since we don't compute loss in test evaluation\n", + " metrics.update(\n", + " predictions_scaled=bigwig_logits,\n", + " targets_scaled=scaled_targets,\n", + " predictions_raw=unscaled_predictions,\n", + " targets_raw=bigwig_targets,\n", + " loss=0.0\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "==================================================\n", + "Test Set Evaluation\n", + "==================================================\n", + "Running test evaluation with 5 steps (10 samples)\n", + "\n", + "==================================================\n", + "Test Set Results\n", + "==================================================\n", + "\n", + "Scaled Metrics (scaled predictions vs scaled targets):\n", + " Mean Pearson (scaled): -0.0020\n", + " ENCFF884LDL/pearson: -0.0020\n", + "\n", + "Raw Metrics (raw predictions vs raw targets):\n", + " Mean Pearson (raw): -0.0020\n", + " ENCFF884LDL/pearson: -0.0020\n", + "==================================================\n" + ] } + ], + "source": [ + "print(\"\\n\" + \"=\"*50)\n", + "print(\"Test Set Evaluation\")\n", + "print(\"=\"*50)\n", + "\n", + "# Calculate number of test steps (based on deepspeed pipeline)\n", + "num_test_samples = len(test_dataset)\n", + "num_test_steps = num_test_samples // config[\"batch_size\"]\n", + "\n", + "print(f\"Running test evaluation with {num_test_steps} steps ({num_test_samples} samples)\")\n", + "\n", + "# Set model to eval mode\n", + "model.eval()\n", + "\n", + "# Create iterator for test data\n", + "test_iter = iter(test_loader)\n", + "\n", + "# Run test evaluation (based on deepspeed pipeline: for loop over test steps)\n", + "for _ in range(num_test_steps):\n", + " try:\n", + " test_batch = next(test_iter)\n", + " except StopIteration:\n", + " break\n", + " \n", + " # Perform test evaluation (pure evaluation, no loss computation)\n", + " test_step(\n", + " model, test_batch, scale_targets_fn, scale_predictions_fn, test_metrics\n", + " )\n", + "\n", + "# Compute final test metrics\n", + "test_metrics_dict = test_metrics.compute()\n", + "\n", + "print(\"\\n\" + \"=\"*50)\n", + "print(\"Test Set Results\")\n", + "print(\"=\"*50)\n", + "print(f\"\\nScaled Metrics (scaled predictions vs scaled targets):\")\n", + 
"print(f\" Mean Pearson (scaled): {test_metrics_dict['metrics_scaled/mean/pearson']:.4f}\")\n", + "for track_name in config[\"bigwig_file_ids\"]:\n", + " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_scaled/{track_name}/pearson']:.4f}\")\n", + "\n", + "print(f\"\\nRaw Metrics (raw predictions vs raw targets):\")\n", + "print(f\" Mean Pearson (raw): {test_metrics_dict['metrics_raw/mean/pearson']:.4f}\")\n", + "for track_name in config[\"bigwig_file_ids\"]:\n", + " print(f\" {track_name}/pearson: {test_metrics_dict[f'metrics_raw/{track_name}/pearson']:.4f}\")\n", + "print(\"=\"*50)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3.12 (ntv3-env)", + "language": "python", + "name": "ntv3-env" }, - "nbformat": 4, - "nbformat_minor": 2 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 }