{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "E: Could not open lock file /var/lib/dpkg/lock-frontend - open (13: Permission denied)\n", "E: Unable to acquire the dpkg frontend lock (/var/lib/dpkg/lock-frontend), are you root?\n", "Requirement already satisfied: tqdm in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (4.67.1)\n" ] } ], "source": [ "!pip install -q transformers datasets accelerate bitsandbytes pysam pandas pyarrow fastparquet\n", "!apt-get install -q samtools\n", "!pip install tqdm\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Hello\n" ] } ], "source": [ "print(\"Hello\")" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting evaluate\n", " Downloading evaluate-0.4.6-py3-none-any.whl.metadata (9.5 kB)\n", "Requirement already satisfied: datasets>=2.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (4.4.2)\n", "Requirement already satisfied: numpy>=1.17 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (2.2.6)\n", "Requirement already satisfied: dill in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (0.4.0)\n", "Requirement already satisfied: pandas in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (2.3.3)\n", "Requirement already satisfied: requests>=2.19.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (2.32.5)\n", "Requirement already satisfied: tqdm>=4.62.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (4.67.1)\n", "Requirement already satisfied: xxhash in 
/system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (3.6.0)\n", "Requirement already satisfied: multiprocess in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (0.70.18)\n", "Requirement already satisfied: fsspec>=2021.05.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from fsspec[http]>=2021.05.0->evaluate) (2025.10.0)\n", "Requirement already satisfied: huggingface-hub>=0.7.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (0.36.0)\n", "Requirement already satisfied: packaging in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from evaluate) (25.0)\n", "Requirement already satisfied: filelock in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (3.20.3)\n", "Requirement already satisfied: pyarrow>=21.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (22.0.0)\n", "Requirement already satisfied: httpx<1.0.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (0.28.1)\n", "Requirement already satisfied: pyyaml>=5.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from datasets>=2.0.0->evaluate) (6.0.3)\n", "Requirement already satisfied: aiohttp!=4.0.0a0,!=4.0.0a1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from fsspec[http]>=2021.05.0->evaluate) (3.13.3)\n", "Requirement already satisfied: anyio in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (4.12.1)\n", "Requirement already satisfied: certifi in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (2026.1.4)\n", "Requirement already satisfied: httpcore==1.* in 
/system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (1.0.9)\n", "Requirement already satisfied: idna in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpx<1.0.0->datasets>=2.0.0->evaluate) (3.11)\n", "Requirement already satisfied: h11>=0.16 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from httpcore==1.*->httpx<1.0.0->datasets>=2.0.0->evaluate) (0.16.0)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.7.0->evaluate) (4.15.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.3 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from huggingface-hub>=0.7.0->evaluate) (1.2.0)\n", "Requirement already satisfied: aiohappyeyeballs>=2.5.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (2.6.1)\n", "Requirement already satisfied: aiosignal>=1.4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.4.0)\n", "Requirement already satisfied: async-timeout<6.0,>=4.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (5.0.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (25.4.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.8.0)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages 
(from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (6.7.0)\n", "Requirement already satisfied: propcache>=0.2.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (0.4.1)\n", "Requirement already satisfied: yarl<2.0,>=1.17.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from aiohttp!=4.0.0a0,!=4.0.0a1->fsspec[http]>=2021.05.0->evaluate) (1.22.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (3.4.4)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from requests>=2.19.0->evaluate) (2.6.3)\n", "Requirement already satisfied: exceptiongroup>=1.0.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from anyio->httpx<1.0.0->datasets>=2.0.0->evaluate) (1.3.1)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->evaluate) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->evaluate) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from pandas->evaluate) (2025.3)\n", "Requirement already satisfied: six>=1.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from python-dateutil>=2.8.2->pandas->evaluate) (1.17.0)\n", "Downloading evaluate-0.4.6-py3-none-any.whl (84 kB)\n", "Installing collected packages: evaluate\n", "Successfully installed evaluate-0.4.6\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Collecting scikit-learn\n", " Downloading 
scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (11 kB)\n", "Requirement already satisfied: numpy>=1.22.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from scikit-learn) (2.2.6)\n", "Collecting scipy>=1.8.0 (from scikit-learn)\n", " Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n", "Collecting joblib>=1.2.0 (from scikit-learn)\n", " Downloading joblib-1.5.3-py3-none-any.whl.metadata (5.5 kB)\n", "Collecting threadpoolctl>=3.1.0 (from scikit-learn)\n", " Downloading threadpoolctl-3.6.0-py3-none-any.whl.metadata (13 kB)\n", "Downloading scikit_learn-1.7.2-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (9.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m9.7/9.7 MB\u001b[0m \u001b[31m186.6 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25hDownloading joblib-1.5.3-py3-none-any.whl (309 kB)\n", "Downloading scipy-1.15.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m37.7/37.7 MB\u001b[0m \u001b[31m207.9 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25hDownloading threadpoolctl-3.6.0-py3-none-any.whl (18 kB)\n", "Installing collected packages: threadpoolctl, scipy, joblib, scikit-learn\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4/4\u001b[0m [scikit-learn][0m [scikit-learn]\n", "\u001b[1A\u001b[2KSuccessfully installed joblib-1.5.3 scikit-learn-1.7.2 scipy-1.15.3 threadpoolctl-3.6.0\n", "Collecting matplotlib\n", " Downloading matplotlib-3.10.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (52 kB)\n", "Collecting contourpy>=1.0.1 (from matplotlib)\n", " Downloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (5.5 kB)\n", "Collecting cycler>=0.10 (from matplotlib)\n", " 
Downloading cycler-0.12.1-py3-none-any.whl.metadata (3.8 kB)\n", "Collecting fonttools>=4.22.0 (from matplotlib)\n", " Downloading fonttools-4.61.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl.metadata (114 kB)\n", "Collecting kiwisolver>=1.3.1 (from matplotlib)\n", " Downloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (6.3 kB)\n", "Requirement already satisfied: numpy>=1.23 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib) (2.2.6)\n", "Requirement already satisfied: packaging>=20.0 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib) (25.0)\n", "Collecting pillow>=8 (from matplotlib)\n", " Downloading pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl.metadata (8.8 kB)\n", "Collecting pyparsing>=3 (from matplotlib)\n", " Downloading pyparsing-3.3.1-py3-none-any.whl.metadata (5.6 kB)\n", "Requirement already satisfied: python-dateutil>=2.7 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from matplotlib) (2.9.0.post0)\n", "Requirement already satisfied: six>=1.5 in /system/conda/miniconda3/envs/cloudspace/lib/python3.10/site-packages (from python-dateutil>=2.7->matplotlib) (1.17.0)\n", "Downloading matplotlib-3.10.8-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (8.7 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m8.7/8.7 MB\u001b[0m \u001b[31m151.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25hDownloading contourpy-1.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (325 kB)\n", "Downloading cycler-0.12.1-py3-none-any.whl (8.3 kB)\n", "Downloading fonttools-4.61.1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (4.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m4.9/4.9 MB\u001b[0m \u001b[31m169.2 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", 
"\u001b[?25hDownloading kiwisolver-1.4.9-cp310-cp310-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.6 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.6/1.6 MB\u001b[0m \u001b[31m129.4 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pillow-12.1.0-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (7.0 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.0/7.0 MB\u001b[0m \u001b[31m153.0 MB/s\u001b[0m \u001b[33m0:00:00\u001b[0m\n", "\u001b[?25hDownloading pyparsing-3.3.1-py3-none-any.whl (121 kB)\n", "Installing collected packages: pyparsing, pillow, kiwisolver, fonttools, cycler, contourpy, matplotlib\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7/7\u001b[0m [matplotlib]7\u001b[0m [matplotlib]\n", "\u001b[1A\u001b[2KSuccessfully installed contourpy-1.3.2 cycler-0.12.1 fonttools-4.61.1 kiwisolver-1.4.9 matplotlib-3.10.8 pillow-12.1.0 pyparsing-3.3.1\n" ] } ], "source": [ "!pip install evaluate\n", "!pip install scikit-learn\n", "!pip install matplotlib\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/zeus/miniconda3/envs/cloudspace/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "NUM_PROCS: 8 MAX_LEN: 1000\n", "Loaded: 299999 train | 14073 test\n", "Feature means/std computed.\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Setting TOKENIZERS_PARALLELISM=false for forked processes.\n", "Tokenizing train (num_proc=8): 100%|██████████| 299999/299999 [01:30<00:00, 3321.33 examples/s]\n", "Setting TOKENIZERS_PARALLELISM=false for forked processes.\n", "Packing/normalizing features train (num_proc=8): 100%|██████████| 299999/299999 [00:30<00:00, 9687.11 examples/s] \n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[train] sample keys: ['labels', 'input_ids', 'attention_mask', 'features']\n", "[train] shapes/dtypes: input_ids=torch.Size([1000])/torch.int64, features=9/torch.float32, labels=torch.int64\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Setting TOKENIZERS_PARALLELISM=false for forked processes.\n", "Tokenizing test (num_proc=8): 100%|██████████| 14073/14073 [00:05<00:00, 2406.31 examples/s]\n", "Setting TOKENIZERS_PARALLELISM=false for forked processes.\n", "Packing/normalizing features test (num_proc=8): 100%|██████████| 14073/14073 [00:02<00:00, 4779.46 examples/s]" ] }, { "name": "stdout", "output_type": "stream", "text": [ "[test] sample keys: ['labels', 'input_ids', 'attention_mask', 'features']\n", "[test] shapes/dtypes: input_ids=torch.Size([1000])/torch.int64, features=9/torch.float32, labels=torch.int64\n", "✅ Cell 3 complete. 
Train size: 299999 Test size: 14073\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "\n" ] } ], "source": [
 "# ==== CELL 3: Re-tokenize + normalize extra features + build torch-format HF Datasets ====\n",
 "import os\n",
 "import numpy as np\n",
 "import pandas as pd\n",
 "import torch\n",
 "from datasets import Dataset\n",
 "from transformers import AutoTokenizer\n",
 "\n",
 "# NOTE: the recorded stderr of this cell shows\n",
 "# 'Setting TOKENIZERS_PARALLELISM=false for forked processes.' during map().\n",
 "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"true\"\n",
 "# leave two cores free, cap at 8 workers, never go below 1\n",
 "NUM_PROCS = min(8, max(1, (os.cpu_count() or 8) - 2))\n",
 "# length every sequence is padded/truncated to by the tokenizer call\n",
 "MAX_LEN = 1000\n",
 "MODEL_NAME = \"InstaDeepAI/nucleotide-transformer-500m-human-ref\"\n",
 "\n",
 "# --------------- feature list (same order everywhere) ---------------\n",
 "feature_cols = [\n",
 "    \"gnomad_af\",\n",
 "    \"GERP++_RS_rankscore\",\n",
 "    \"GERP_91_mammals_rankscore\",\n",
 "    \"phyloP100way_vertebrate_rankscore\",\n",
 "    \"phyloP470way_mammalian_rankscore\",\n",
 "    \"phyloP17way_primate_rankscore\",\n",
 "    \"phastCons100way_vertebrate_rankscore\",\n",
 "    \"phastCons470way_mammalian_rankscore\",\n",
 "    \"phastCons17way_primate_rankscore\",\n",
 "]\n",
 "\n",
 "print(\"NUM_PROCS:\", NUM_PROCS, \"MAX_LEN:\", MAX_LEN)\n",
 "\n",
 "# 1) load\n",
 "train_df = pd.read_parquet(\"Balanced_300k_SEQUENCES.parquet\")\n",
 "test_df = pd.read_parquet(\"test_enriched_SEQUENCES.parquet\")\n",
 "print(\"Loaded:\", len(train_df), \"train |\", len(test_df), \"test\")\n",
 "\n",
 "# basic check: validate BOTH frames and every column used further down,\n",
 "# so a missing column fails here instead of as a late KeyError after the\n",
 "# train split has already been tokenized\n",
 "required_cols = [\"raw_sequence\", \"clean_label\"] + feature_cols\n",
 "for frame_name, frame in ((\"train\", train_df), (\"test\", test_df)):\n",
 "    for c in required_cols:\n",
 "        if c not in frame.columns:\n",
 "            raise RuntimeError(f\"Missing column in {frame_name} dataframe: {c}\")\n",
 "\n",
 "# 2) compute train mean/std for normalization (stable)\n",
 "# missing values are filled with 0.0 before the statistics are taken;\n",
 "# a zero std is replaced by 1 so the later division cannot divide by zero\n",
 "train_feat = train_df[feature_cols].astype(float).fillna(0.0)\n",
 "feat_means = train_feat.mean(axis=0).astype(np.float32).values\n",
 "feat_stds = train_feat.std(axis=0).replace(0, 1).astype(np.float32).values\n",
 "\n",
 "print(\"Feature means/std computed.\")\n",
 "\n",
 "# 3) tokenizer\n",
 "# NOTE(review): trust_remote_code=True allows code shipped with the model\n",
 "# repository to run locally -- confirm this is intended for MODEL_NAME\n",
 "tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)\n",
 "\n",
 "# 4) helper to build 
dataset\n",
 "from datasets import Dataset\n",
 "\n",
 "def build_and_tokenize(df, split_name=\"train\"):\n",
 "    \"\"\"Turn a variant dataframe into a torch-formatted HF Dataset.\n",
 "\n",
 "    Derives binary ``labels`` from ``clean_label`` (1 when the value equals\n",
 "    \"Pathogenic\", otherwise 0), tokenizes ``raw_sequence`` padded/truncated\n",
 "    to ``MAX_LEN``, and packs the ``feature_cols`` columns into a single\n",
 "    ``features`` vector normalized with ``feat_means`` / ``feat_stds``.\n",
 "\n",
 "    Args:\n",
 "        df: dataframe containing ``raw_sequence``, ``clean_label`` and every\n",
 "            column listed in ``feature_cols``.\n",
 "        split_name: label shown in the progress bars and the sanity print.\n",
 "\n",
 "    Returns:\n",
 "        The processed dataset, with ``input_ids``, ``attention_mask``,\n",
 "        ``features`` and ``labels`` set to torch format.\n",
 "    \"\"\"\n",
 "    # keep only required columns\n",
 "    sub = df[[\"raw_sequence\", \"clean_label\"] + feature_cols].copy()\n",
 "    sub[\"labels\"] = (sub[\"clean_label\"] == \"Pathogenic\").astype(np.int64)\n",
 "    sub = sub.drop(columns=[\"clean_label\"])\n",
 "    ds = Dataset.from_pandas(sub, preserve_index=False)\n",
 "\n",
 "    # tokenization\n",
 "    def tok_fn(batch):\n",
 "        return tokenizer(\n",
 "            batch[\"raw_sequence\"],\n",
 "            truncation=True,\n",
 "            padding=\"max_length\",\n",
 "            max_length=MAX_LEN\n",
 "        )\n",
 "\n",
 "    ds = ds.map(\n",
 "        tok_fn,\n",
 "        batched=True,\n",
 "        batch_size=128,\n",
 "        num_proc=NUM_PROCS,\n",
 "        remove_columns=[\"raw_sequence\"],\n",
 "        desc=f\"Tokenizing {split_name}\"\n",
 "    )\n",
 "\n",
 "    # pack & normalize features\n",
 "    def pack_norm_features(batch):\n",
 "        # stack columns in correct order and normalize by train stats\n",
 "        arrs = [np.array(batch[c], dtype=np.float32) for c in feature_cols]\n",
 "        stacked = np.stack(arrs, axis=1)  # (batch, n_features)\n",
 "        # impute missing values with 0.0 -- the same fill applied before the\n",
 "        # train mean/std were computed -- so NaN never reaches the features\n",
 "        stacked[np.isnan(stacked)] = 0.0\n",
 "        # normalize\n",
 "        stacked = (stacked - feat_means) / feat_stds\n",
 "        return {\"features\": stacked.tolist()}\n",
 "\n",
 "    ds = ds.map(\n",
 "        pack_norm_features,\n",
 "        batched=True,\n",
 "        batch_size=512,\n",
 "        num_proc=NUM_PROCS,\n",
 "        remove_columns=feature_cols,\n",
 "        desc=f\"Packing/normalizing features {split_name}\"\n",
 "    )\n",
 "\n",
 "    # set torch format (ensures Trainer receives torch tensors)\n",
 "    ds.set_format(type=\"torch\", columns=[\"input_ids\", \"attention_mask\", \"features\", \"labels\"])\n",
 "\n",
 "    # sample sanity\n",
 "    s = ds[0]\n",
 "    print(f\"[{split_name}] sample keys: {list(s.keys())}\")\n",
 "    print(f\"[{split_name}] shapes/dtypes: input_ids={s['input_ids'].shape}/{s['input_ids'].dtype}, features={len(s['features'])}/{s['features'].dtype}, labels={s['labels'].dtype}\")\n",
 "    return ds\n",
 "\n",
 "train_fast = build_and_tokenize(train_df, \"train\")\n",
 "test_fast = 
build_and_tokenize(test_df, \"test\")\n",
 "\n",
 "print(\"✅ Cell 3 complete. Train size:\", len(train_fast), \"Test size:\", len(test_fast))\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🔍 RUNNING PRE-FLIGHT CHECKS...\n", "✅ GPU DETECTED: NVIDIA H200\n", " VRAM: 150.12 GB\n", " 🚀 GOD MODE HARDWARE CONFIRMED.\n", "✅ Train Set: 299999 samples\n", "✅ Test Set: 14073 samples\n", " Structure: OK (Torch tensors present)\n", "✅ Detected modern Transformers: Using 'eval_strategy'\n", "\n", "System Ready. Using strategy key: 'eval_strategy'\n" ] } ], "source": [
 "# ===== CELL 3.5 — SYSTEM & VARIABLE VERIFICATION =====\n",
 "import torch\n",
 "import transformers\n",
 "from transformers import TrainingArguments\n",
 "import psutil\n",
 "\n",
 "print(\"🔍 RUNNING PRE-FLIGHT CHECKS...\")\n",
 "\n",
 "# 1. HARDWARE VERIFICATION (Confirm H200)\n",
 "if torch.cuda.is_available():\n",
 "    gpu_name = torch.cuda.get_device_name(0)\n",
 "    # total_memory is divided by 1e9, i.e. reported in decimal gigabytes\n",
 "    vram_gb = torch.cuda.get_device_properties(0).total_memory / 1e9\n",
 "    print(f\"✅ GPU DETECTED: {gpu_name}\")\n",
 "    print(f\" VRAM: {vram_gb:.2f} GB\")\n",
 "\n",
 "    if \"H200\" in gpu_name or \"H100\" in gpu_name:\n",
 "        print(\" 🚀 GOD MODE HARDWARE CONFIRMED.\")\n",
 "    else:\n",
 "        print(\" ⚠️ WARNING: You are not on H100/H200. Adjust batch sizes!\")\n",
 "else:\n",
 "    raise RuntimeError(\"❌ NO GPU DETECTED! Training will fail.\")\n",
 "\n",
 "# 2. DATASET VERIFICATION\n",
 "# Resolve the two names on their own so that only a genuinely missing\n",
 "# variable is reported as \"Cell 3 did not run\"; the original cause is\n",
 "# kept on the exception chain.\n",
 "try:\n",
 "    splits_to_check = ((\"train_fast\", train_fast), (\"test_fast\", test_fast))\n",
 "except NameError as err:\n",
 "    raise RuntimeError(\"❌ Datasets 'train_fast' or 'test_fast' not found. Did Cell 3 run?\") from err\n",
 "\n",
 "print(f\"✅ Train Set: {len(train_fast)} samples\")\n",
 "print(f\"✅ Test Set: {len(test_fast)} samples\")\n",
 "\n",
 "# Check column names (must match model expectations) AND value types on\n",
 "# both splits, so the \"Torch tensors present\" message is actually verified.\n",
 "sample = train_fast[0]\n",
 "required_keys = [\"input_ids\", \"attention_mask\", \"features\", \"labels\"]\n",
 "for split_label, split_ds in splits_to_check:\n",
 "    split_sample = split_ds[0]\n",
 "    missing = [k for k in required_keys if k not in split_sample.keys()]\n",
 "    if missing:\n",
 "        raise ValueError(f\"❌ DATASET MISSING KEYS: {missing}\")\n",
 "    non_tensor = [k for k in required_keys if not isinstance(split_sample[k], torch.Tensor)]\n",
 "    if non_tensor:\n",
 "        raise ValueError(f\"❌ {split_label} HAS NON-TENSOR FIELDS: {non_tensor}\")\n",
 "print(\" Structure: OK (Torch tensors present)\")\n",
 "\n",
 "# 3. HUGGING FACE ARGUMENT CHECK (The \"eval_strategy\" Bug Fix)\n",
 "# Newer transformers use 'eval_strategy', older use 'evaluation_strategy'\n",
 "# The installed TrainingArguments signature decides which keyword is used.\n",
 "import inspect\n",
 "args_sig = inspect.signature(TrainingArguments.__init__)\n",
 "valid_params = args_sig.parameters.keys()\n",
 "\n",
 "if \"eval_strategy\" in valid_params:\n",
 "    EVAL_STRATEGY_KEY = \"eval_strategy\"\n",
 "    print(\"✅ Detected modern Transformers: Using 'eval_strategy'\")\n",
 "else:\n",
 "    EVAL_STRATEGY_KEY = \"evaluation_strategy\"\n",
 "    print(\"⚠️ Detected older Transformers: Using 'evaluation_strategy'\")\n",
 "\n",
 "print(f\"\\nSystem Ready. Using strategy key: '{EVAL_STRATEGY_KEY}'\")" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Some weights of EsmModel were not initialized from the model checkpoint at InstaDeepAI/nucleotide-transformer-500m-human-ref and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']\n", "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "🔥 Encoder gradient checkpointing enabled\n", "🚀 NT-500M biological fine-tuning started (A100, fresh run)\n" ] }, { "data": { "text/html": [ "\n", "
| Step | \n", "Training Loss | \n", "Validation Loss | \n", "Roc Auc | \n", "
|---|---|---|---|
| 500 | \n", "0.527000 | \n", "0.522692 | \n", "0.821170 | \n", "
| 1000 | \n", "0.396500 | \n", "0.508216 | \n", "0.858456 | \n", "
| 1500 | \n", "0.365700 | \n", "0.587682 | \n", "0.874515 | \n", "
| 2000 | \n", "0.327400 | \n", "0.442712 | \n", "0.891527 | \n", "
| 2500 | \n", "0.300100 | \n", "0.435639 | \n", "0.907509 | \n", "
| 3000 | \n", "0.237100 | \n", "0.424697 | \n", "0.911374 | \n", "
| 3500 | \n", "0.228200 | \n", "0.414745 | \n", "0.915350 | \n", "
| 4000 | \n", "0.224300 | \n", "0.423489 | \n", "0.917674 | \n", "
| 4500 | \n", "0.217200 | \n", "0.422278 | \n", "0.918179 | \n", "
"
],
"text/plain": [
"\n",
" \n",
"
\n",
"\n",
" \n",
" \n",
" \n",
" \n",
" AF bin \n",
" N \n",
" Pathogenic % \n",
" ROC-AUC \n",
" \n",
" \n",
" 0 \n",
" Ultra-rare (<1e-6) \n",
" 8027 \n",
" 0.671359 \n",
" 0.923828 \n",
" \n",
" \n",
" 1 \n",
" Rare (1e-6–1e-4) \n",
" 4665 \n",
" 0.346624 \n",
" 0.906936 \n",
" \n",
" \n",
" 2 \n",
" Low-freq (1e-4–1e-2) \n",
" 967 \n",
" 0.054809 \n",
" 0.821890 \n",
" \n",
" \n",
" \n",
"3 \n",
" Common (>1e-2) \n",
" 414 \n",
" 0.004831 \n",
" 0.947209 \n",
"