{ "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "p-fvcK3rZHoK", "outputId": "452c5e35-056d-4bdb-d921-a634d510638b" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Mounted at /content/drive\n" ] } ], "source": [ "from google.colab import drive\n", "drive.mount('/content/drive',force_remount=True)\n", "\n", "root = \"/content/drive/MyDrive/SPRSound/SPRSound-main\"\n", "# Set device\n", "train_mode=True\n", "test_mode=False\n", "split_data=False" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "nhv5GGqEZtqY", "outputId": "a6b22e3c-f3c3-480a-d07d-836a8b30fc0d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Cloning into 'hear'...\n", "remote: Enumerating objects: 216, done.\u001b[K\n", "remote: Counting objects: 100% (87/87), done.\u001b[K\n", "remote: Compressing objects: 100% (40/40), done.\u001b[K\n", "remote: Total 216 (delta 72), reused 47 (delta 47), pack-reused 129 (from 1)\u001b[K\n", "Receiving objects: 100% (216/216), 62.06 MiB | 19.19 MiB/s, done.\n", "Resolving deltas: 100% (128/128), done.\n" ] } ], "source": [ "import os\n", "if not os.path.exists('/content/hear'):\n", " !git clone https://github.com/Google-Health/hear" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "iX_h8XZMZNhg" }, "outputs": [], "source": [ "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import Dataset,DataLoader,WeightedRandomSampler\n", "from transformers import AutoModel\n", "import math\n", "import os\n", "import sys\n", "import json\n", "import pandas as pd\n", "import numpy as np\n", "from tqdm import tqdm\n", "from pathlib import Path\n", "import torchaudio\n", "import importlib\n", "from typing import Optional, Dict, List\n", "from collections import Counter\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import (\n", " accuracy_score,\n", " f1_score,\n", " precision_score,\n", " recall_score,\n", " confusion_matrix,\n", " classification_report,\n", " roc_auc_score\n", ")\n", "from sklearn.preprocessing import label_binarize\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "from torch.nn.utils.rnn import pack_padded_sequence\n", "import regex as re\n", "import random" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "q4H9-8vsoZwF" }, "outputs": [], "source": [ "###### train test split - one time run ######\n", "\n", "def collect_all_samples(root_dir: str, consolidate_labels: bool = True) -> pd.DataFrame:\n", " \"\"\"\n", " Collect all samples from all BioCAS years (2022-2025) into a single DataFrame.\n", "\n", " Args:\n", " root_dir: Path to SPRSound dataset root\n", " consolidate_labels: Whether to consolidate rare labels into main classes\n", "\n", " Returns:\n", " DataFrame with columns: wav_path, json_path, filename, year, original_split, event_types, label, original_label\n", " \"\"\"\n", " root_dir = Path(root_dir)\n", " all_samples = []\n", "\n", " # Define all available datasets\n", " datasets = [\n", " ('2022', 'train', None),\n", " ('2022', 'test', 'inter'),\n", " ('2022', 'test', 'intra'),\n", " ('2023', 'test', None),\n", " ('2024', 'test', None),\n", " ('2025', 'test', None),\n", " ]\n", "\n", " print(\"=\"*80)\n", " print(\"COLLECTING ALL SAMPLES FROM ALL YEARS\")\n", " print(\"=\"*80)\n", "\n", " for year, split, test_type in datasets:\n", " year_dir = root_dir / f\"BioCAS{year}\"\n", "\n", " # Set paths\n", " if split == 'train':\n", " wav_dir = year_dir / f\"train{year}_wav\"\n", " json_dir = year_dir / f\"train{year}_json\"\n", " split_name = f\"{year}_train\"\n", " else: # test\n", " wav_dir = year_dir / f\"test{year}_wav\"\n", "\n", " if year == '2022' and test_type:\n", " if test_type == 'inter':\n", " json_dir = year_dir / f\"test{year}_json\" / \"inter_test_json\"\n", " split_name = f\"{year}_test_inter\"\n", " else: # intra\n", " json_dir = year_dir / f\"test{year}_json\" / \"intra_test_json\"\n", " split_name = f\"{year}_test_intra\"\n", " else:\n", " json_dir = year_dir / f\"test{year}_json\"\n", " split_name = f\"{year}_test\"\n", "\n", " # Check if directories exist\n", " if not wav_dir.exists() or not json_dir.exists():\n", " print(f\"Skipping {split_name}: directories not found\")\n", " continue\n", "\n", " # Load all JSON files\n", " json_files = sorted(json_dir.glob(\"*.json\"))\n", " print(f\"\\nProcessing {split_name}: {len(json_files)} files\")\n", "\n", " for json_path in json_files:\n", " with open(json_path, 'r') as f:\n", " annotation = json.load(f)\n", "\n", " # Get corresponding WAV filename\n", " wav_filename = json_path.stem + '.wav'\n", " wav_path = wav_dir / wav_filename\n", "\n", " if not wav_path.exists():\n", " print(f\"Warning: WAV file not found: {wav_path}\")\n", " continue\n", "\n", " # Extract event types\n", " events = annotation.get('event_annotation', [])\n", " event_types = [event.get('type', '') for event in events]\n", "\n", " # Parse original label (detailed)\n", " original_label = _parse_label_detailed(event_types)\n", "\n", " # Parse consolidated label\n", " if consolidate_labels:\n", " label = _consolidate_label(original_label)\n", " else:\n", " label = original_label\n", "\n", " sample = {\n", " 'wav_path': str(wav_path),\n", " 'json_path': str(json_path),\n", " 'filename': wav_filename,\n", " 'year': year,\n", " 'original_split': split_name,\n", " 'event_types': '|'.join(event_types), # Store as pipe-separated string\n", " 'original_label': original_label,\n", " 'label': label\n", " }\n", "\n", " all_samples.append(sample)\n", "\n", " print(f\" Collected {len(json_files)} samples from {split_name}\")\n", "\n", " # Create DataFrame\n", " df = pd.DataFrame(all_samples)\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(f\"TOTAL SAMPLES COLLECTED: {len(df)}\")\n", " print(f\"{'='*80}\")\n", "\n", " # Print statistics\n", " print(\"\\nSamples per year:\")\n", " print(df['year'].value_counts().sort_index())\n", "\n", " print(\"\\nSamples per original split:\")\n", " print(df['original_split'].value_counts())\n", "\n", " if consolidate_labels:\n", " print(\"\\nOriginal label distribution (before consolidation):\")\n", " print(df['original_label'].value_counts())\n", "\n", " print(\"\\nConsolidated label distribution:\")\n", " print(df['label'].value_counts())\n", "\n", " # Show mapping\n", " print(\"\\nLabel consolidation mapping:\")\n", " mapping = df.groupby('original_label')['label'].first().to_dict()\n", " for orig, consol in sorted(mapping.items()):\n", " if orig != consol:\n", " count = (df['original_label'] == orig).sum()\n", " print(f\" {orig} -> {consol} ({count} samples)\")\n", " else:\n", " print(\"\\nLabel distribution:\")\n", " print(df['label'].value_counts())\n", "\n", " return df\n", "\n", "\n", "def _parse_label_detailed(event_types: List[str]) -> str:\n", " \"\"\"Parse detailed label from event types (preserves all combinations)\"\"\"\n", " if not event_types:\n", " return 'normal'\n", "\n", " # Get unique event types\n", " unique_events = list(set(event_types))\n", "\n", " # Remove 'Normal' from the list\n", " non_normal_events = [e for e in unique_events if e.lower() != 'normal']\n", "\n", " # If no abnormal events, label as normal\n", " if not non_normal_events:\n", " return 'normal'\n", "\n", " # Sort and join non-normal events\n", " sorted_events = sorted([e.lower() for e in non_normal_events])\n", "\n", " return '+'.join(sorted_events)\n", "\n", "\n", "def _consolidate_label(original_label: str) -> str:\n", "\n", " label_lower = original_label.lower()\n", "\n", " # Normal case\n", " if label_lower == 'normal':\n", " return 'normal'\n", " return \"abnormal\"\n", "\n", "def extract_patient_id(filename: str) -> str:\n", " \"\"\"\n", " Extract patient ID from filename.\n", "\n", " Filename pattern: PATIENTID_X.X_X_pX_XXXX.wav\n", " Example: 41055397_3.0_0_p3_10805.wav -> patient_id = 41055397\n", "\n", " Args:\n", " filename: Audio filename\n", "\n", " Returns:\n", " Patient ID string, or None if pattern doesn't match\n", " \"\"\"\n", " match = re.match(r'(\\d+)_', filename)\n", " if match:\n", " return match.group(1)\n", " return None\n", "\n", "\n", "def create_patient_level_splits(\n", " df: pd.DataFrame,\n", " train_ratio: float = 0.7,\n", " val_ratio: float = 0.15,\n", " test_ratio: float = 0.15,\n", " random_state: int = 42,\n", " min_samples_per_class: int = 10\n", ") -> tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n", " \"\"\"\n", " Create patient-level stratified splits to prevent data leakage.\n", "\n", " CRITICAL: This ensures NO patient appears in multiple splits.\n", " Each patient's recordings all go into the same split (train/val/test).\n", "\n", " Args:\n", " df: DataFrame with all samples\n", " train_ratio: Proportion for training (default: 0.7)\n", " val_ratio: Proportion for validation (default: 0.15)\n", " test_ratio: Proportion for test (default: 0.15)\n", " random_state: Random seed for reproducibility\n", " min_samples_per_class: Minimum samples for stratification\n", "\n", " Returns:\n", " train_df, val_df, test_df with NO patient overlap\n", " \"\"\"\n", " assert abs(train_ratio + val_ratio + test_ratio - 1.0) < 1e-6, \"Ratios must sum to 1.0\"\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"CREATING PATIENT-LEVEL STRATIFIED SPLITS (NO LEAKAGE)\")\n", " print(f\"{'='*80}\")\n", " print(f\"Train ratio: {train_ratio:.2f}\")\n", " print(f\"Val ratio: {val_ratio:.2f}\")\n", " print(f\"Test ratio: {test_ratio:.2f}\")\n", " print(f\"Random state: {random_state}\")\n", "\n", " # Extract patient IDs\n", " df = df.copy()\n", " df['patient_id'] = df['filename'].apply(extract_patient_id)\n", "\n", " # Remove samples without patient ID\n", " samples_without_id = df['patient_id'].isna().sum()\n", " if samples_without_id > 0:\n", " print(f\"\\n⚠️ Warning: {samples_without_id} samples without patient ID will be excluded\")\n", " df = df[df['patient_id'].notna()].copy()\n", "\n", " print(f\"\\nDataset statistics:\")\n", " print(f\" Total samples: {len(df)}\")\n", " print(f\" Unique patients: {df['patient_id'].nunique()}\")\n", "\n", " # Group by patient and determine each patient's characteristics\n", " patient_data = []\n", "\n", " for patient_id in df['patient_id'].unique():\n", " patient_samples = df[df['patient_id'] == patient_id]\n", "\n", " # Get the majority label for this patient (for stratification)\n", " label_counts = patient_samples['label'].value_counts()\n", " majority_label = label_counts.index[0]\n", "\n", " # Also track if patient has multiple labels\n", " has_multiple_labels = len(patient_samples['label'].unique()) > 1\n", "\n", " patient_data.append({\n", " 'patient_id': patient_id,\n", " 'majority_label': majority_label,\n", " 'num_samples': len(patient_samples),\n", " 'unique_labels': patient_samples['label'].nunique(),\n", " 'all_labels': list(patient_samples['label'].unique())\n", " })\n", "\n", " patient_df = pd.DataFrame(patient_data)\n", "\n", " print(f\"\\nPatient-level statistics:\")\n", " print(f\" Total patients: {len(patient_df)}\")\n", " print(f\" Patients with multiple labels: {(patient_df['unique_labels'] > 1).sum()}\")\n", " print(f\" Average samples per patient: {patient_df['num_samples'].mean():.2f}\")\n", " print(f\" Median samples per patient: {patient_df['num_samples'].median():.0f}\")\n", "\n", " print(f\"\\nMajority label distribution across patients:\")\n", " print(patient_df['majority_label'].value_counts())\n", "\n", " # Handle rare classes by grouping for stratification\n", " label_counts = patient_df['majority_label'].value_counts()\n", " rare_labels = label_counts[label_counts < min_samples_per_class].index.tolist()\n", "\n", " def group_labels_for_stratification(label):\n", " \"\"\"Group rare labels for better stratification\"\"\"\n", " if label in rare_labels:\n", " return 'rare'\n", " return label\n", "\n", " patient_df['stratify_label'] = patient_df['majority_label'].apply(\n", " group_labels_for_stratification\n", " )\n", "\n", " if rare_labels:\n", " print(f\"\\n⚠️ Rare labels grouped for stratification:\")\n", " for label in rare_labels:\n", " count = (patient_df['majority_label'] == label).sum()\n", " print(f\" {label}: {count} patients\")\n", "\n", " # PATIENT-LEVEL SPLIT (not sample-level!)\n", " # Step 1: Split patients into train and temp (val+test)\n", " train_patients, temp_patients = train_test_split(\n", " patient_df['patient_id'].values,\n", " test_size=(val_ratio + test_ratio),\n", " random_state=random_state,\n", " stratify=patient_df['stratify_label'].values\n", " )\n", "\n", " # Step 2: Split temp patients into val and test\n", " temp_patient_df = patient_df[patient_df['patient_id'].isin(temp_patients)]\n", "\n", " # Adjust val ratio for the temp split\n", " val_ratio_adjusted = val_ratio / (val_ratio + test_ratio)\n", "\n", " val_patients, test_patients = train_test_split(\n", " temp_patients,\n", " test_size=(1 - val_ratio_adjusted),\n", " random_state=random_state,\n", " stratify=temp_patient_df['stratify_label'].values\n", " )\n", "\n", " # Convert to sets for fast lookup\n", " train_patient_set = set(train_patients)\n", " val_patient_set = set(val_patients)\n", " test_patient_set = set(test_patients)\n", "\n", " # Assign samples to splits based on patient ID\n", " def assign_split(patient_id):\n", " if patient_id in train_patient_set:\n", " return 'train'\n", " elif patient_id in val_patient_set:\n", " return 'val'\n", " elif patient_id in test_patient_set:\n", " return 'test'\n", " else:\n", " return 'unknown'\n", "\n", " df['split'] = df['patient_id'].apply(assign_split)\n", "\n", " # Create split DataFrames\n", " train_df = df[df['split'] == 'train'].copy()\n", " val_df = df[df['split'] == 'val'].copy()\n", " test_df = df[df['split'] == 'test'].copy()\n", "\n", " # CRITICAL VERIFICATION: Check for patient leakage\n", " print(f\"\\n{'='*80}\")\n", " print(\"LEAKAGE VERIFICATION\")\n", " print(f\"{'='*80}\")\n", "\n", " train_pts = set(train_df['patient_id'].unique())\n", " val_pts = set(val_df['patient_id'].unique())\n", " test_pts = set(test_df['patient_id'].unique())\n", "\n", " overlap_train_val = train_pts & val_pts\n", " overlap_train_test = train_pts & test_pts\n", " overlap_val_test = val_pts & test_pts\n", "\n", " print(f\"\\nPatient distribution:\")\n", " print(f\" Train: {len(train_pts)} patients ({len(train_pts)/len(patient_df)*100:.1f}%)\")\n", " print(f\" Val: {len(val_pts)} patients ({len(val_pts)/len(patient_df)*100:.1f}%)\")\n", " print(f\" Test: {len(test_pts)} patients ({len(test_pts)/len(patient_df)*100:.1f}%)\")\n", "\n", " print(f\"\\nLeakage check:\")\n", " if len(overlap_train_val) == 0:\n", " print(f\" ✓ Train-Val overlap: 0 patients (GOOD)\")\n", " else:\n", " print(f\" ✗ Train-Val overlap: {len(overlap_train_val)} patients (DATA LEAKAGE!)\")\n", "\n", " if len(overlap_train_test) == 0:\n", " print(f\" ✓ Train-Test overlap: 0 patients (GOOD)\")\n", " else:\n", " print(f\" ✗ Train-Test overlap: {len(overlap_train_test)} patients (DATA LEAKAGE!)\")\n", "\n", " if len(overlap_val_test) == 0:\n", " print(f\" ✓ Val-Test overlap: 0 patients (GOOD)\")\n", " else:\n", " print(f\" ✗ Val-Test overlap: {len(overlap_val_test)} patients (DATA LEAKAGE!)\")\n", "\n", " # Print split statistics\n", " print(f\"\\n{'='*80}\")\n", " print(\"SPLIT STATISTICS\")\n", " print(f\"{'='*80}\")\n", " print(f\"\\nSample distribution:\")\n", " print(f\" Train: {len(train_df)} samples ({len(train_df)/len(df)*100:.1f}%)\")\n", " print(f\" Val: {len(val_df)} samples ({len(val_df)/len(df)*100:.1f}%)\")\n", " print(f\" Test: {len(test_df)} samples ({len(test_df)/len(df)*100:.1f}%)\")\n", "\n", " # Label distribution per split\n", " print(f\"\\nLabel distribution per split:\")\n", " print(\"-\" * 80)\n", "\n", " all_labels = sorted(df['label'].unique())\n", " split_stats = []\n", "\n", " for label in all_labels:\n", " train_count = (train_df['label'] == label).sum()\n", " val_count = (val_df['label'] == label).sum()\n", " test_count = (test_df['label'] == label).sum()\n", " total_count = train_count + val_count + test_count\n", "\n", " if total_count > 0:\n", " split_stats.append({\n", " 'label': label,\n", " 'train': train_count,\n", " 'val': val_count,\n", " 'test': test_count,\n", " 'total': total_count\n", " })\n", "\n", " split_stats_df = pd.DataFrame(split_stats)\n", " print(split_stats_df.to_string(index=False))\n", "\n", " return train_df, val_df, test_df\n", "\n", "\n", "def save_splits_to_csv(\n", " train_df: pd.DataFrame,\n", " val_df: pd.DataFrame,\n", " test_df: pd.DataFrame,\n", " output_dir: str\n", "):\n", " \"\"\"\n", " Save train/val/test splits to CSV files.\n", "\n", " Args:\n", " train_df, val_df, test_df: DataFrames for each split\n", " output_dir: Directory to save CSV files\n", " \"\"\"\n", " output_dir = Path(output_dir)\n", " output_dir.mkdir(parents=True, exist_ok=True)\n", "\n", " train_path = output_dir / 'train.csv'\n", " val_path = output_dir / 'val.csv'\n", " test_path = output_dir / 'test.csv'\n", "\n", " train_df.to_csv(train_path, index=False)\n", " val_df.to_csv(val_path, index=False)\n", " test_df.to_csv(test_path, index=False)\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"SAVED SPLITS TO CSV\")\n", " print(f\"{'='*80}\")\n", " print(f\"Train: {train_path} ({len(train_df)} samples)\")\n", " print(f\"Val: {val_path} ({len(val_df)} samples)\")\n", " print(f\"Test: {test_path} ({len(test_df)} samples)\")\n", "\n", " # Also save combined file with split column\n", " combined_df = pd.concat([train_df, val_df, test_df], ignore_index=True)\n", " combined_path = output_dir / 'all_splits.csv'\n", " combined_df.to_csv(combined_path, index=False)\n", " print(f\"Combined: {combined_path} ({len(combined_df)} samples)\")\n", "\n", " # Save metadata\n", " metadata = {\n", " 'total_samples': len(combined_df),\n", " 'train_samples': len(train_df),\n", " 'val_samples': len(val_df),\n", " 'test_samples': len(test_df),\n", " 'train_patients': int(train_df['patient_id'].nunique()),\n", " 'val_patients': int(val_df['patient_id'].nunique()),\n", " 'test_patients': int(test_df['patient_id'].nunique()),\n", " 'num_classes': len(combined_df['label'].unique()),\n", " 'classes': sorted(combined_df['label'].unique()),\n", " 'class_distribution': combined_df['label'].value_counts().to_dict(),\n", " 'note': 'Patient-level split: NO patient appears in multiple splits'\n", " }\n", "\n", " metadata_path = output_dir / 'metadata.json'\n", " with open(metadata_path, 'w') as f:\n", " json.dump(metadata, f, indent=2)\n", " print(f\"Metadata: {metadata_path}\")\n", "\n", "\n", "def create_and_save_splits(\n", " root_dir: str,\n", " output_dir: str,\n", " train_ratio: float = 0.7,\n", " val_ratio: float = 0.15,\n", " test_ratio: float = 0.15,\n", " random_state: int = 42,\n", " consolidate_labels: bool = True,\n", " min_samples_per_class: int = 10\n", "):\n", " \"\"\"\n", " Main function to collect all data and create patient-level splits.\n", "\n", " IMPORTANT: This creates patient-level splits to prevent data leakage.\n", " No patient will appear in multiple splits (train/val/test).\n", "\n", " Args:\n", " root_dir: Path to SPRSound dataset root\n", " output_dir: Directory to save CSV splits\n", " train_ratio: Proportion for training (default: 0.7)\n", " val_ratio: Proportion for validation (default: 0.15)\n", " test_ratio: Proportion for test (default: 0.15)\n", " random_state: Random seed for reproducibility\n", " consolidate_labels: Whether to consolidate rare labels\n", " min_samples_per_class: Minimum samples for stratification\n", " \"\"\"\n", " # Collect all samples\n", " all_df = collect_all_samples(root_dir, consolidate_labels=consolidate_labels)\n", "\n", " # Create PATIENT-LEVEL stratified splits (FIXED!)\n", " train_df, val_df, test_df = create_patient_level_splits(\n", " all_df,\n", " train_ratio=train_ratio,\n", " val_ratio=val_ratio,\n", " test_ratio=test_ratio,\n", " random_state=random_state,\n", " min_samples_per_class=min_samples_per_class\n", " )\n", "\n", " # Save to CSV\n", " save_splits_to_csv(train_df, val_df, test_df, output_dir)\n", "\n", " return train_df, val_df, test_df\n", "\n", "if __name__ == \"__main__\":\n", " if split_data:\n", " create_and_save_splits(root,root)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "9eYXVJ_blQDs" }, "outputs": [], "source": [ "\n", "# -----------------------\n", "# utils\n", "# -----------------------\n", "def _rand_uniform(a, b):\n", " return a + (b - a) * random.random()\n", "\n", "\n", "def _rms(x: torch.Tensor):\n", " return torch.sqrt(torch.mean(x * x) + 1e-8)\n", "\n", "\n", "# -----------------------\n", "# FAST room effect (no convolution)\n", "# -----------------------\n", "def _apply_fast_echo(wav: torch.Tensor, sr: int):\n", " \"\"\"\n", " Very cheap reverb-like effect using a few delayed taps.\n", " O(N) time, no conv1d.\n", " \"\"\"\n", " y = wav.clone()\n", " n = y.numel()\n", "\n", " for _ in range(random.randint(2, 4)):\n", " delay = int(_rand_uniform(0.01, 0.08) * sr) # 10–80 ms\n", " if delay <= 0 or delay >= n:\n", " continue\n", " gain = _rand_uniform(0.05, 0.25)\n", " y[delay:] += gain * wav[:-delay]\n", "\n", " # mild damping\n", " y = 0.7 * y + 0.3 * torch.tanh(2.0 * y)\n", " return y\n", "\n", "\n", "# -----------------------\n", "# FAST colored noise (no filters)\n", "# -----------------------\n", "def _colored_noise_fast(noise: torch.Tensor):\n", " \"\"\"\n", " Brown-ish noise via cumulative sum + optional high-pass differencing.\n", " Pure tensor ops.\n", " \"\"\"\n", " n = torch.cumsum(noise, dim=0)\n", " n = n / (n.std().clamp_min(1e-6))\n", "\n", " if random.random() < 0.5:\n", " n = torch.cat([n[:1], n[1:] - 0.98 * n[:-1]], dim=0)\n", "\n", " return n\n", "\n", "\n", "def _add_noise_snr_fast(wav: torch.Tensor, snr_db_range=(3, 25)):\n", " snr_db = _rand_uniform(*snr_db_range)\n", " noise = torch.randn_like(wav)\n", " noise = _colored_noise_fast(noise)\n", "\n", " sig_rms = _rms(wav)\n", " noise_rms = _rms(noise)\n", " target_noise_rms = sig_rms / (10 ** (snr_db / 20))\n", " noise = noise * (target_noise_rms / (noise_rms + 1e-8))\n", "\n", " return wav + noise\n", "\n", "\n", "# -----------------------\n", "# FAST phone band-limiting EQ (2 biquads max)\n", "# -----------------------\n", "def _phone_bandlimit_fast(wav: torch.Tensor, sr: int):\n", " hp = random.choice([120, 150, 200])\n", " lp = random.choice([4000, 6000, 8000])\n", "\n", " y = torchaudio.functional.highpass_biquad(wav, sr, hp)\n", " y = torchaudio.functional.lowpass_biquad(y, sr, lp)\n", " return y\n", "\n", "\n", "# -----------------------\n", "# AGC-like soft compression\n", "# -----------------------\n", "def _soft_agc(wav: torch.Tensor):\n", " gain_db = _rand_uniform(-6, 12)\n", " gain = 10 ** (gain_db / 20)\n", " y = wav * gain\n", "\n", " drive = _rand_uniform(1.5, 3.5)\n", " y = torch.tanh(drive * y)\n", "\n", " return y\n", "\n", "\n", "# -----------------------\n", "# Misc cheap ops\n", "# -----------------------\n", "def _random_gain(wav: torch.Tensor, db_range=(-18, 8)):\n", " g = 10 ** (_rand_uniform(*db_range) / 20)\n", " return wav * g\n", "\n", "\n", "def _random_time_shift(wav: torch.Tensor, sr: int, max_s=0.08):\n", " max_shift = int(max_s * sr)\n", " if max_shift <= 0:\n", " return wav\n", " shift = random.randint(-max_shift, max_shift)\n", " return torch.roll(wav, shifts=shift)\n", "\n", "\n", "def _random_clipping(wav: torch.Tensor, p=0.15):\n", " if random.random() > p:\n", " return wav\n", " clip = _rand_uniform(0.3, 0.9)\n", " return torch.clamp(wav, -clip, clip)\n", "\n", "\n", "# -----------------------\n", "# MAIN AUGMENT CLASS\n", "# -----------------------\n", "class PhoneLikeAugment:\n", " \"\"\"\n", " Fast phone-mic simulation for stethoscope → phone domain shift.\n", " Designed for DataLoader safety + speed.\n", " \"\"\"\n", "\n", " def __init__(self, sr=16000, p=0.5):\n", " self.sr = sr\n", " self.p = p\n", "\n", " def __call__(self, wav: torch.Tensor):\n", " if random.random() > self.p:\n", " return wav\n", "\n", " y = wav\n", "\n", " # small timing jitter\n", " if random.random() < 0.4:\n", " y = _random_time_shift(y, self.sr)\n", "\n", " # distance / level\n", " y = _random_gain(y)\n", "\n", " # phone band-limiting\n", " if random.random() < 0.9:\n", " y = _phone_bandlimit_fast(y, self.sr)\n", "\n", " # cheap room effect\n", " if random.random() < 0.15:\n", " y = _apply_fast_echo(y, self.sr)\n", "\n", " # noise\n", " if random.random() < 0.8:\n", " y = _add_noise_snr_fast(y)\n", "\n", " # AGC\n", " if random.random() < 0.7:\n", " y = _soft_agc(y)\n", "\n", " # rare clipping\n", " y = _random_clipping(y)\n", "\n", " # final normalization\n", " peak = y.abs().max().clamp_min(1e-6)\n", " y = y / peak * _rand_uniform(0.3, 1.0)\n", "\n", " return y" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "7teh4Wk1Z75R" }, "outputs": [], "source": [ "##########################DataSet###########################\n", "class SPRSoundDataset(Dataset):\n", " \"\"\"\n", " Dataset class for SPRSound that loads from pre-split CSV files.\n", "\n", " HeAR expects 16kHz audio in 2-second chunks (32,000 samples).\n", " For recordings longer than 2s, we chunk into non-overlapping 2s windows,\n", " run preprocess_audio on each chunk, and stack them so the model can\n", " aggregate (mean-pool) the per-chunk embeddings.\n", "\n", " For recordings shorter than 2s, preprocess_audio zero-pads internally.\n", "\n", " Args:\n", " csv_path: Path to CSV file (train.csv, val.csv, or test.csv)\n", " target_sr: Target sample rate — MUST be 16000 to match HeAR\n", " max_duration: Maximum audio duration in seconds to keep (default: 10)\n", " apply_hear_preprocess: Whether to apply HEAR preprocessing (default: True)\n", " \"\"\"\n", "\n", " # HeAR's fixed contract: 2 seconds at 16 kHz\n", " HEAR_SR = 16000\n", " HEAR_CHUNK_SAMPLES = 32000 # 2s * 16kHz\n", "\n", " def __init__(\n", " self,\n", " csv_path: str,\n", " target_sr: int = 16000,\n", " max_duration: float = 10.0,\n", " apply_hear_preprocess: bool = True,\n", " class_to_idx=None,\n", " is_train=False\n", " ):\n", " self.csv_path = Path(csv_path)\n", " self.target_sr = target_sr\n", " self.max_duration = max_duration\n", " self.max_samples = int(target_sr * max_duration)\n", " self.apply_hear_preprocess = apply_hear_preprocess\n", "\n", " self.augment = PhoneLikeAugment(sr=self.target_sr, p=0.5) if is_train else None\n", " self._resamplers = {}\n", "\n", " self.class_to_idx = dict(class_to_idx)\n", " self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}\n", "\n", " assert self.target_sr == self.HEAR_SR, (\n", " f\"target_sr must be {self.HEAR_SR} to match HeAR preprocessing. \"\n", " f\"Got {self.target_sr}. Resampling from native SR to 16kHz is \"\n", " f\"handled automatically in _load_audio().\"\n", " )\n", "\n", " # Import HEAR preprocessing if needed\n", " if apply_hear_preprocess:\n", " audio_utils = importlib.import_module(\n", " \"hear.python.data_processing.audio_utils\"\n", " )\n", " self.preprocess_audio = audio_utils.preprocess_audio\n", "\n", " # Load CSV\n", " self.df = pd.read_csv(csv_path)\n", " self.df[\"label\"] = self.df[\"label\"].apply(\n", " lambda x: \"normal\" if x == \"normal\" else \"abnormal\"\n", " )\n", " \"\"\"\n", " # Build class mapping from all unique labels\n", " all_labels = sorted(self.df['label'].unique())\n", " self.class_to_idx = {label: idx for idx, label in enumerate(all_labels)}\n", " self.idx_to_class = {v: k for k, v in self.class_to_idx.items()}\n", " \"\"\"\n", " split_name = self.csv_path.stem\n", " print(f\"Loaded {len(self.df)} samples from {csv_path}\")\n", " print(f\"Split: {split_name}\")\n", " print(f\"Target SR: {self.target_sr} Hz (HeAR native)\")\n", " print(f\"Number of classes: {len(self.class_to_idx)}\")\n", " self._print_statistics()\n", "\n", " def _print_statistics(self):\n", " \"\"\"Print dataset statistics\"\"\"\n", " label_counts = self.df['label'].value_counts()\n", "\n", " print(f\"\\nClass distribution:\")\n", " for label in sorted(label_counts.index):\n", " count = label_counts[label]\n", " percentage = count / len(self.df) * 100\n", " class_idx = self.class_to_idx.get(label, '?')\n", " print(f\" [{class_idx:2d}] {label:20s}: {count:5d} ({percentage:5.2f}%)\")\n", "\n", " def _load_audio(self, audio_path: str):\n", " \"\"\"\n", " Load audio, convert to mono, resample to 16 kHz, and truncate\n", " to max_duration. Does NOT pad to a fixed length — chunking\n", " handles variable lengths.\n", " \"\"\"\n", " waveform, sr = torchaudio.load(audio_path)\n", "\n", " # Convert to mono if stereo\n", " if waveform.shape[0] > 1:\n", " waveform = torch.mean(waveform, dim=0, keepdim=True)\n", "\n", " # Resample to 16 kHz (from 8 kHz or whatever native SR)\n", " if sr != self.target_sr:\n", " if sr not in self._resamplers:\n", " self._resamplers[sr] = torchaudio.transforms.Resample(sr, self.target_sr)\n", " waveform = self._resamplers[sr](waveform)\n", "\n", "\n", " # Remove channel dimension → [samples]\n", " waveform = waveform.squeeze(0)\n", "\n", " # Truncate to max_duration (but don't pad — chunking handles short clips)\n", " if waveform.shape[0] > self.max_samples:\n", " waveform = waveform[:self.max_samples]\n", "\n", " return waveform\n", "\n", " def _chunk_and_preprocess(self, waveform: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Split waveform into non-overlapping 2-second chunks, run HeAR's\n", " preprocess_audio on each, and stack the resulting spectrograms.\n", "\n", " For audio shorter than 2s, preprocess_audio handles zero-padding.\n", " The last chunk is zero-padded if it doesn't fill a full 2s window.\n", "\n", " Returns:\n", " Tensor of shape [num_chunks, 1, 192, 128]\n", " \"\"\"\n", " chunk_size = self.HEAR_CHUNK_SAMPLES\n", " total_samples = waveform.shape[0]\n", "\n", " # Calculate number of chunks (ceiling division — last chunk gets padded)\n", " num_chunks = max(1, math.ceil(total_samples / chunk_size))\n", "\n", " chunks = []\n", " for i in range(num_chunks):\n", " start = i * chunk_size\n", " end = min(start + chunk_size, total_samples)\n", " chunk = waveform[start:end]\n", "\n", " # Zero-pad last chunk if needed (preprocess_audio also pads, but\n", " # let's be explicit and give it exactly 32000 samples)\n", " if chunk.shape[0] < chunk_size:\n", "\n", " chunk = torch.nn.functional.pad(\n", " chunk, (0, chunk_size - chunk.shape[0])\n", " )\n", "\n", "\n", " chunks.append(chunk)\n", "\n", " # Stack chunks into a batch: [num_chunks, 32000]\n", "\n", " chunk_batch = torch.stack(chunks, dim=0)\n", "\n", " # Run HeAR preprocessing on the whole batch at once\n", " # preprocess_audio expects [batch, 32000] → returns [batch, 1, 192, 128]\n", "\n", " spectrograms = self.preprocess_audio(chunk_batch)\n", " spectrograms = spectrograms.clamp(max=3.0)\n", "\n", " return spectrograms\n", "\n", " def __len__(self):\n", " return len(self.df)\n", "\n", " def __getitem__(self, idx):\n", " row = self.df.iloc[idx]\n", "\n", " waveform = self._load_audio(row['wav_path'])\n", "\n", " if self.augment is not None:\n", " waveform = self.augment(waveform)\n", "\n", " if self.apply_hear_preprocess:\n", " # [T, 1, 192, 128]\n", " chunk_spectrograms = self._chunk_and_preprocess(waveform)\n", "\n", " features = chunk_spectrograms\n", " length = chunk_spectrograms.shape[0]\n", " else:\n", "\n", " features = waveform.unsqueeze(0) # [1, samples]\n", " length = 1\n", "\n", " label = self.class_to_idx[row['label']]\n", " event_types = row['event_types'].split('|') if pd.notna(row['event_types']) and row['event_types'] else []\n", "\n", " return {\n", " 'features': features, # [T, 1, 192, 128]\n", " 'length': length, # int\n", " 'label': label,\n", " 'filename': row['filename'],\n", " 'event_types': event_types,\n", " 'year': row['year'],\n", " 'original_split': row['original_split'],\n", " 'original_label': row.get('original_label', row['label'])\n", " }\n", "\n", "\n", " def get_class_weights(self, method='inverse'):\n", " \"\"\"Calculate class weights for handling imbalanced data\"\"\"\n", " import torch\n", "\n", " label_counts = Counter(self.df['label'])\n", "\n", " total = len(self.df)\n", " num_classes = len(self.class_to_idx)\n", "\n", " if method == 'inverse':\n", " weights = {}\n", " for label, count in label_counts.items():\n", " weights[self.class_to_idx[label]] = total / (num_classes * count)\n", "\n", " elif method == 'effective':\n", " beta = 0.9999\n", " weights = {}\n", " for label, count in label_counts.items():\n", " effective_num = 1.0 - (beta ** count)\n", " weights[self.class_to_idx[label]] = (1.0 - beta) / effective_num\n", "\n", " else:\n", " raise ValueError(f\"Unknown method: {method}\")\n", "\n", " # Convert to tensor\n", " weight_tensor = torch.zeros(num_classes)\n", " for idx, weight in weights.items():\n", " weight_tensor[idx] = weight\n", "\n", " return weight_tensor\n", "def collate_respiratory_batch(batch):\n", " # batch[i]['features'] is [T, 1, 192, 128]\n", " lengths = torch.tensor([item['length'] for item in batch], dtype=torch.long)\n", " max_len = int(lengths.max().item())\n", "\n", " # infer feature shape\n", " _, C, H, W = batch[0]['features'].shape\n", "\n", " padded = torch.zeros(len(batch), max_len, C, H, W, dtype=batch[0]['features'].dtype)\n", " for i, item in enumerate(batch):\n", " T = item['features'].shape[0]\n", " padded[i, :T] = item['features']\n", "\n", " labels = torch.tensor([item['label'] for item in batch], dtype=torch.long)\n", " filenames = [item['filename'] for item in batch]\n", " event_types = [item['event_types'] for item in batch]\n", " years = [item['year'] for item in batch]\n", " original_splits = [item['original_split'] for item in batch]\n", " original_labels = [item['original_label'] for item in batch]\n", "\n", " return {\n", " 'features': padded, # [B, T_max, 1, 192, 128]\n", " 'lengths': lengths, # [B]\n", " 'label': labels,\n", " 'filename': filenames,\n", " 'event_types': event_types,\n", " 'year': years,\n", " 'original_split': original_splits,\n", " 'original_label': original_labels\n", " }\n", "def visualize_spectrograms_by_class(dataset, class_names, num_samples_per_class=2, save_path=None):\n", " \"\"\"\n", " Visualize spectrogram samples from each class.\n", "\n", " Args:\n", " dataset: The dataset object (SPRSoundDataset)\n", " class_names: List of class names\n", " num_samples_per_class: Number of samples to visualize per class (default: 2)\n", " save_path: Optional path to save the visualization\n", " \"\"\"\n", " # Collect indices for each class\n", " class_indices = {i: [] for i in range(len(class_names))}\n", "\n", " print(\"Collecting samples from each class...\")\n", " for idx in range(len(dataset)):\n", " label = dataset.df.iloc[idx]['label']\n", " label=dataset.class_to_idx[label]\n", " if len(class_indices[label]) < num_samples_per_class:\n", " class_indices[label].append(idx)\n", "\n", " # Stop if we have enough samples for all classes\n", " if all(len(indices) >= num_samples_per_class for indices in class_indices.values()):\n", " break\n", "\n", " # Check if we have enough samples for each class\n", " for class_idx, indices in class_indices.items():\n", " if len(indices) < num_samples_per_class:\n", " print(f\"Warning: Only found {len(indices)} samples for class '{class_names[class_idx]}'\")\n", "\n", " # Create the visualization\n", " num_classes = len(class_names)\n", " fig, axes = plt.subplots(num_classes, num_samples_per_class,\n", " figsize=(5*num_samples_per_class, 4*num_classes))\n", "\n", " # Handle case where we only have 1 sample per class\n", " if num_samples_per_class == 1:\n", " axes = axes.reshape(-1, 1)\n", "\n", " print(\"\\nGenerating spectrograms...\")\n", " for class_idx in range(num_classes):\n", " indices = class_indices[class_idx]\n", "\n", " for sample_idx, data_idx in enumerate(indices[:num_samples_per_class]):\n", " # Get the data sample\n", " sample = dataset[data_idx]\n", " spectrogram = sample['features'] # Shape: [time, freq] or [freq, time]\n", " filename = sample['filename']\n", "\n", " # Convert to numpy if tensor\n", " if torch.is_tensor(spectrogram):\n", " spec_np = spectrogram.cpu().numpy()\n", " else:\n", " spec_np = spectrogram\n", " # Handle 4D tensor: [T, 1, 192, 128]\n", " if len(spec_np.shape) == 4:\n", " spec_np = spec_np[0, 0, :, :] # Shape: [192, 128] (freq, time)\n", " elif len(spec_np.shape) == 3:\n", " spec_np = spec_np[0, :, :] # Shape: [192, 128]\n", " # Plot the spectrogram\n", " ax = axes[class_idx, sample_idx]\n", "\n", " # Transpose if needed for correct orientation (frequency on y-axis)\n", " if spec_np.shape[0] > spec_np.shape[1]:\n", " spec_np = spec_np.T\n", "\n", " im = ax.imshow(spec_np, aspect='auto', origin='lower', cmap='viridis')\n", "\n", " # Set title with class name and filename\n", " if sample_idx == 0:\n", " ax.set_ylabel(f\"{class_names[class_idx]}\\n({filename})\",\n", " fontsize=10, fontweight='bold')\n", " else:\n", " ax.set_ylabel(filename, fontsize=8)\n", "\n", " ax.set_xlabel('Time')\n", " if sample_idx == 0:\n", " ax.set_ylabel(f\"{class_names[class_idx]}\\nFrequency\", fontsize=10)\n", " else:\n", " ax.set_ylabel('Frequency', fontsize=8)\n", "\n", " ax.set_title(f\"Sample {sample_idx+1}\", fontsize=9)\n", "\n", " # Add colorbar\n", " plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)\n", "\n", " plt.suptitle('Spectrogram Samples by Class', fontsize=14, fontweight='bold', y=0.995)\n", " plt.tight_layout()\n", "\n", " if save_path:\n", " os.makedirs(os.path.dirname(save_path), exist_ok=True)\n", " plt.savefig(save_path, dpi=150, bbox_inches='tight')\n", " print(f\"\\nVisualization saved to: {save_path}\")\n", "\n", " plt.show()\n", "\n", " return fig\n", "\n", "def quick_visualize_spectrograms(trainer):\n", " \"\"\"\n", " Quick function to visualize spectrograms from the trainer object.\n", "\n", " Args:\n", " trainer: RespiratoryTrainer object (already initialized)\n", " \"\"\"\n", " import matplotlib.pyplot as plt\n", " import os\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(\"VISUALIZING SPECTROGRAMS\")\n", " print(\"=\"*80)\n", "\n", " # Create output directory\n", " output_dir = os.path.join(trainer.configs[\"logdir\"], \"figures\")\n", " os.makedirs(output_dir, exist_ok=True)\n", "\n", " # Visualize training spectrograms\n", " print(\"\\nVisualizing TRAINING set...\")\n", " train_save_path = os.path.join(output_dir, \"spectrograms_train.png\")\n", " visualize_spectrograms_by_class(\n", " dataset=trainer.train_dataset,\n", " class_names=trainer.class_names,\n", " num_samples_per_class=2,\n", " save_path=train_save_path\n", " )\n", "\n", " # Visualize validation spectrograms\n", " print(\"\\nVisualizing VALIDATION set...\")\n", " val_save_path = os.path.join(output_dir, \"spectrograms_val.png\")\n", " visualize_spectrograms_by_class(\n", " dataset=trainer.val_dataset,\n", " class_names=trainer.class_names,\n", " num_samples_per_class=2,\n", " save_path=val_save_path\n", " )\n", "\n", " print(\"\\n\" + \"=\"*80)\n", " print(\"VISUALIZATION COMPLETE!\")\n", " print(\"=\"*80)\n", " print(f\"Training spectrograms: {train_save_path}\")\n", " print(f\"Validation spectrograms: {val_save_path}\")\n", "\n", "\n", "def create_dataloaders_from_csv(\n", " csv_dir: str,\n", " batch_size: int = 16,\n", " num_workers: int = 2\n", "):\n", " \"\"\"\n", " Create dataloaders from CSV files.\n", "\n", " Args:\n", " csv_dir: Directory containing train.csv, val.csv, test.csv\n", " batch_size: Batch size\n", " num_workers: Number of data loading workers\n", " \"\"\"\n", "\n", " csv_dir = Path(csv_dir)\n", "\n", " CANONICAL_CLASSES = [\"normal\", \"abnormal\"]\n", " CLASS_TO_IDX = {c: i for i, c in enumerate(CANONICAL_CLASSES)}\n", " IDX_TO_CLASS = {i: c for c, i in CLASS_TO_IDX.items()}\n", "\n", "\n", " # Create datasets\n", " train_dataset = SPRSoundDataset(\n", " csv_path=csv_dir / 'train.csv',\n", " target_sr=16000 ,\n", " max_duration=10.0,\n", " apply_hear_preprocess=True,\n", " class_to_idx=CLASS_TO_IDX,\n", " is_train=True\n", " )\n", "\n", " val_dataset = SPRSoundDataset(\n", " csv_path=csv_dir / 'val.csv',\n", " target_sr=16000 ,\n", " max_duration=10.0,\n", " apply_hear_preprocess=True,\n", " class_to_idx=CLASS_TO_IDX\n", " )\n", "\n", " test_dataset = SPRSoundDataset(\n", " csv_path=csv_dir / 'test.csv',\n", " target_sr=16000 ,\n", " max_duration=10.0,\n", " apply_hear_preprocess=True,\n", " class_to_idx=CLASS_TO_IDX\n", " )\n", " print(\"\\n=== CLASS MAP CHECK ===\")\n", " print(\"Train:\", train_dataset.class_to_idx)\n", " print(\"Val: \", val_dataset.class_to_idx)\n", " print(\"Test: \", test_dataset.class_to_idx)\n", "\n", " # Strict equality checks\n", " print(\"train == val ?\", train_dataset.class_to_idx == val_dataset.class_to_idx)\n", " print(\"train == test?\", train_dataset.class_to_idx == test_dataset.class_to_idx)\n", "\n", " class_weights = train_dataset.get_class_weights()\n", " label_indices = [train_dataset.class_to_idx[label] for label in train_dataset.df['label']]\n", " sample_weights = [class_weights[idx].item() for idx in label_indices]\n", " sampler = WeightedRandomSampler(\n", " weights=sample_weights,\n", " num_samples=len(sample_weights),\n", " replacement=True\n", " )\n", " # Create dataloaders\n", " train_loader = DataLoader(\n", " train_dataset,\n", " batch_size=batch_size,\n", " sampler=sampler,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " persistent_workers=True if num_workers > 0 else False,\n", " collate_fn=collate_respiratory_batch\n", " )\n", "\n", " val_loader = DataLoader(\n", " val_dataset,\n", " batch_size=batch_size,\n", " shuffle=False,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " persistent_workers=True if num_workers > 0 else False,\n", " collate_fn=collate_respiratory_batch\n", " )\n", "\n", " test_loader = DataLoader(\n", " test_dataset,\n", " batch_size=batch_size,\n", " shuffle=False,\n", " num_workers=num_workers,\n", " pin_memory=True,\n", " persistent_workers=True if num_workers > 0 else False,\n", " collate_fn=collate_respiratory_batch\n", " )\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"DATALOADERS CREATED\")\n", " print(f\"{'='*80}\")\n", " print(f\"Train batches: {len(train_loader)}\")\n", " print(f\"Val batches: {len(val_loader)}\")\n", " print(f\"Test batches: {len(test_loader)}\")\n", "\n", " return train_loader, val_loader, test_loader, class_weights, train_dataset\n", "\n", "# train_loader, val_loader, test_loader, class_weights, train_dataset=create_dataloaders_from_csv(root)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MTXCmo5vquAU" }, "outputs": [], "source": [ "####### Model #######\n", "class GatedAttentionPool(nn.Module):\n", " def __init__(self, dim: int, attn_hidden: int = 128, dropout: float = 0.2):\n", " super().__init__()\n", " self.V = nn.Sequential(\n", " nn.LayerNorm(dim),\n", " nn.Dropout(dropout),\n", " nn.Linear(dim, attn_hidden)\n", " )\n", " self.U = nn.Sequential(\n", " nn.LayerNorm(dim),\n", " nn.Dropout(dropout),\n", " nn.Linear(dim, attn_hidden)\n", " )\n", " self.w = nn.Linear(attn_hidden, 1)\n", "\n", " def forward(self, x, lengths):\n", " # x: [B,T,D]\n", " Vx = self.V(x) # [B,T,H]\n", " Ux = self.U(x) # [B,T,H]\n", " scores = self.w(Vx * Ux).squeeze(-1) # [B,T]\n", "\n", " # Create mask\n", " idxs = torch.arange(x.size(1), device=x.device).unsqueeze(0)\n", " mask = idxs >= lengths.unsqueeze(1)\n", " scores = scores.masked_fill(mask, float(\"-inf\"))\n", "\n", " # Attention weights\n", " attn = torch.softmax(scores, dim=1)\n", "\n", " # Weighted pooling\n", " pooled = torch.sum(x * attn.unsqueeze(-1), dim=1)\n", "\n", " return pooled, attn\n", "\n", "class LoRALinear(nn.Module):\n", " \"\"\"\n", " Wraps an existing nn.Linear layer with LoRA:\n", " y = xW^T + b + scale * x(BA)^T\n", " \"\"\"\n", " def __init__(self, base_linear: nn.Linear, r: int = 8, alpha: int = 16, dropout: float = 0.0):\n", " super().__init__()\n", " assert isinstance(base_linear, nn.Linear)\n", " self.base = base_linear\n", " self.in_features = base_linear.in_features\n", " self.out_features = base_linear.out_features\n", "\n", " self.r = r\n", " self.alpha = alpha\n", " self.scale = alpha / r if r > 0 else 1.0\n", " self.lora_dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()\n", "\n", " # Freeze base weights\n", " self.base.weight.requires_grad = False\n", " if self.base.bias is not None:\n", " self.base.bias.requires_grad = False\n", "\n", " # LoRA params\n", " if r > 0:\n", " self.lora_A = nn.Parameter(torch.zeros(r, self.in_features))\n", " self.lora_B = nn.Parameter(torch.zeros(self.out_features, r))\n", "\n", " # Init: A small random, B zeros => start as no-op\n", " nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))\n", " nn.init.zeros_(self.lora_B)\n", " else:\n", " self.register_parameter(\"lora_A\", None)\n", " self.register_parameter(\"lora_B\", None)\n", "\n", " def forward(self, x):\n", " y = self.base(x)\n", " if self.r > 0:\n", " x_d = self.lora_dropout(x)\n", " # (x @ A^T) @ B^T\n", " lora = (x_d @ self.lora_A.t()) @ self.lora_B.t()\n", " y = y + self.scale * lora\n", " return y\n", "\n", " @staticmethod\n", " def apply_lora_to_hear_vit(\n", " hear_model: nn.Module,\n", " r: int = 8,\n", " alpha: int = 16,\n", " dropout: float = 0.05,\n", " last_n_blocks: int = 2,\n", " target_modules: list = None # ['query', 'key', 'value'] or ['query', 'value']\n", " ):\n", " \"\"\"\n", " Applies LoRA to HeAR model (Hugging Face ViT architecture).\n", "\n", " HeAR Structure:\n", " - hear_model.encoder.layer[i] = transformer blocks\n", " - hear_model.encoder.layer[i].attention.attention.query/key/value = Linear layers\n", "\n", " Args:\n", " hear_model: HeAR model from transformers\n", " r: LoRA rank\n", " alpha: LoRA alpha (scaling)\n", " dropout: LoRA dropout\n", " last_n_blocks: Apply LoRA to last N blocks\n", " target_modules: Which attention projections to apply LoRA to\n", " ['query', 'key', 'value'] or subset like ['query', 'value']\n", " \"\"\"\n", " if target_modules is None:\n", " target_modules = ['query','key', 'value'] # Common choice, skip 'key'\n", "\n", " # Freeze everything first\n", " for p in hear_model.parameters():\n", " p.requires_grad = False\n", "\n", " # Access transformer blocks\n", " # HuggingFace ViT: model.encoder.layer is ModuleList of ViTLayer\n", " if not hasattr(hear_model, 'encoder'):\n", " raise ValueError(\"Expected HeAR model with .encoder attribute\")\n", "\n", " layers = hear_model.encoder.layer\n", " start = max(0, len(layers) - last_n_blocks)\n", "\n", " patched = 0\n", " for i in range(start, len(layers)):\n", " layer = layers[i]\n", "\n", " # HuggingFace ViT attention structure:\n", " # layer.attention.attention.query/key/value\n", " if hasattr(layer, 'attention') and hasattr(layer.attention, 'attention'):\n", " attn = layer.attention.attention\n", "\n", " # Apply LoRA to specified modules\n", " if 'query' in target_modules and isinstance(attn.query, nn.Linear):\n", " attn.query = LoRALinear(attn.query, r=r, alpha=alpha, dropout=dropout)\n", " patched += 1\n", "\n", " if 'key' in target_modules and isinstance(attn.key, nn.Linear):\n", " attn.key = LoRALinear(attn.key, r=r, alpha=alpha, dropout=dropout)\n", " patched += 1\n", "\n", " if 'value' in target_modules and isinstance(attn.value, nn.Linear):\n", " attn.value = LoRALinear(attn.value, r=r, alpha=alpha, dropout=dropout)\n", " patched += 1\n", "\n", " # Make LayerNorm trainable in those blocks\n", " for name, m in layer.named_modules():\n", " if isinstance(m, nn.LayerNorm):\n", " for p in m.parameters():\n", " p.requires_grad = True\n", "\n", " # Make final LayerNorm trainable\n", " if hasattr(hear_model, 'layernorm') and isinstance(hear_model.layernorm, nn.LayerNorm):\n", " for p in hear_model.layernorm.parameters():\n", " p.requires_grad = True\n", "\n", " print(f\"✓ LoRA applied to {patched} attention projections in last {last_n_blocks} blocks\")\n", " print(f\" Target modules: {target_modules}\")\n", " print(f\" LoRA rank: {r}, alpha: {alpha}, dropout: {dropout}\")\n", "\n", " return hear_model\n", "\n", "class AdaptiveRespiratoryModel(nn.Module):\n", " def __init__(\n", " self,\n", " num_classes: int = 2,\n", " dropout: float = 0.4,\n", " use_lora: bool = True,\n", " lora_r: int = 8,\n", " lora_alpha: int = 16,\n", " lora_dropout: float = 0.05,\n", " lora_last_n_blocks: int = 2,\n", " rnn_hidden: int = 512,\n", " rnn_layers: int = 2,\n", " ):\n", " super().__init__()\n", "\n", " # Load HeAR backbone\n", " self.hear = AutoModel.from_pretrained(\n", " \"google/hear-pytorch\",\n", " trust_remote_code=True\n", " )\n", "\n", " # Apply LoRA if requested\n", " if use_lora:\n", " print(\"Applying LoRA to HeAR backbone...\")\n", " self.hear = LoRALinear.apply_lora_to_hear_vit(\n", " self.hear,\n", " r=lora_r,\n", " alpha=lora_alpha,\n", " dropout=lora_dropout,\n", " last_n_blocks=lora_last_n_blocks,\n", " target_modules=['query', 'value']\n", " )\n", " else:\n", " # Freeze entire backbone\n", " for p in self.hear.parameters():\n", " p.requires_grad = False\n", "\n", " self.feature_dim = 512\n", "\n", " self.gate = GatedAttentionPool(\n", " dim=self.feature_dim, # Bidirectional doubles the output\n", " attn_hidden=512,\n", " dropout=dropout\n", " )\n", " classifier_input_dim = self.feature_dim\n", " self.classifier = nn.Sequential(\n", " nn.Linear(classifier_input_dim, classifier_input_dim // 2),\n", " nn.GELU(),\n", " nn.Dropout(dropout),\n", " nn.Linear(classifier_input_dim // 2, num_classes)\n", " )\n", "\n", " self._print_trainable_params()\n", "\n", " def _print_trainable_params(self):\n", " \"\"\"Print trainable parameter statistics\"\"\"\n", " trainable = sum(p.numel() for p in self.parameters() if p.requires_grad)\n", " total = sum(p.numel() for p in self.parameters())\n", "\n", " print(f\"\\n{'='*60}\")\n", " print(f\"Model Parameter Summary:\")\n", " print(f\" Total parameters: {total:,}\")\n", " print(f\" Trainable parameters: {trainable:,}\")\n", " print(f\" Frozen parameters: {total - trainable:,}\")\n", " print(f\" Trainable %: {100 * trainable / total:.2f}%\")\n", " print(f\"{'='*60}\\n\")\n", "\n", " def forward(self, spectrogram_seq, lengths):\n", " \"\"\" Args:\n", " spectrogram_seq: [B, T, 1, 192, 128]\n", " lengths: [B] true lengths (num chunks per file)\n", " \"\"\"\n", " B, T, C, H, W = spectrogram_seq.shape\n", "\n", " # Flatten to run HeAR in one pass\n", " x = spectrogram_seq.view(B * T, C, H, W)\n", " outputs = self.hear(x, return_dict=True)\n", " emb = outputs.pooler_output # [B*T, 512]\n", "\n", " # Reshape back to sequence: [B, T, 512]\n", " emb_seq = emb.view(B, T, -1)\n", "\n", " # Apply gated attention pooling\n", " pooled_emb, attn = self.gate(emb_seq, lengths) # [B, rnn_hidden*2], [B, T]\n", "\n", " # Classification\n", " logits = self.classifier(pooled_emb)\n", "\n", " return logits, attn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "NZeFi5xED5pl" }, "outputs": [], "source": [ "######## Loss ########\n", "class FocalLoss(nn.Module):\n", " \"\"\"\n", " Multi-class Focal Loss.\n", "\n", " Args:\n", " gamma (float): focusing parameter (>= 0). Typical values: 1.0–2.0\n", " alpha (Tensor or None): class weights, shape [num_classes]\n", " reduction (str): 'mean', 'sum', or 'none'\n", " eps (float): numerical stability\n", " \"\"\"\n", " def __init__(\n", " self,\n", " gamma: float = 2.0,\n", " alpha: torch.Tensor | None = None,\n", " reduction: str = \"mean\",\n", " eps: float = 1e-8,\n", " ):\n", " super().__init__()\n", " self.gamma = gamma\n", " self.reduction = reduction\n", " self.eps = eps\n", "\n", " if alpha is not None:\n", " self.register_buffer(\"alpha\", alpha)\n", " else:\n", " self.alpha = None\n", "\n", " def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:\n", " \"\"\"\n", " Args:\n", " logits: [B, C] raw model outputs\n", " targets: [B] class indices\n", " \"\"\"\n", " # Log-softmax for numerical stability\n", " log_probs = F.log_softmax(logits, dim=1)\n", " probs = log_probs.exp()\n", "\n", " # Select the probabilities of the true classes\n", " targets = targets.view(-1, 1)\n", " log_pt = log_probs.gather(1, targets).squeeze(1)\n", " pt = probs.gather(1, targets).squeeze(1)\n", "\n", " # Focal term\n", " focal_term = (1.0 - pt).clamp(min=0.0) ** self.gamma\n", "\n", " loss = -focal_term * log_pt\n", "\n", " # Apply class weights if provided\n", " if self.alpha is not None:\n", " alpha_t = self.alpha.gather(0, targets.squeeze(1))\n", " loss = alpha_t * loss\n", "\n", " if self.reduction == \"mean\":\n", " return loss.mean()\n", " elif self.reduction == \"sum\":\n", " return loss.sum()\n", " else:\n", " return loss\n", "import numpy as np\n", "\n", "def find_best_thresholds_one_vs_rest(probs: np.ndarray,\n", " y_true: np.ndarray,\n", " num_classes: int,\n", " grid_size: int = 101,\n", " min_thr: float = 0.05,\n", " max_thr: float = 0.95,Pmin =0.3):\n", " \"\"\"\n", " probs: [N, C] softmax probabilities\n", " y_true: [N] int labels in 0..C-1\n", " returns thresholds: [C] float thresholds maximizing F1 per class (one-vs-rest).\n", " \"\"\"\n", " thresholds = np.full(num_classes, 0.5, dtype=np.float32)\n", " grid = np.linspace(min_thr, max_thr, grid_size)\n", "\n", " for c in range(num_classes):\n", " y_c = (y_true == c).astype(np.int32) # one-vs-rest ground truth\n", " p_c = probs[:, c]\n", "\n", " best_f1 = -1.0\n", " best_t = 0.5\n", "\n", " for t in grid:\n", " pred_c = (p_c >= t).astype(np.int32)\n", "\n", " tp = np.sum((pred_c == 1) & (y_c == 1))\n", " fp = np.sum((pred_c == 1) & (y_c == 0))\n", " fn = np.sum((pred_c == 0) & (y_c == 1))\n", "\n", " precision = tp / (tp + fp + 1e-12)\n", " recall = tp / (tp + fn + 1e-12)\n", " f1 = (2 * precision * recall) / (precision + recall + 1e-12)\n", "\n", " if (precision >= Pmin and (f1 > best_f1 or (abs(f1-best_f1) < 1e-6 and t > best_t))):\n", " best_f1 = f1\n", " best_t = float(t)\n", "\n", "\n", " thresholds[c] = best_t\n", "\n", " return thresholds\n", "\n", "\n", "def predict_with_thresholds(probs: np.ndarray,\n", " thresholds: np.ndarray,\n", " abstain: bool = False,\n", " abstain_label: int | None = None):\n", " \"\"\"\n", " Multi-class decision rule using per-class thresholds.\n", "\n", " Strategy:\n", " - Compute margin = probs[:, c] - thresholds[c]\n", " - If abstain=False:\n", " pick argmax margin (always returns a class)\n", " - If abstain=True:\n", " if all margins < 0: return abstain_label (must be provided)\n", " else pick argmax margin among margins\n", " \"\"\"\n", " margins = probs - thresholds.reshape(1, -1) # [N, C]\n", " best_c = np.argmax(margins, axis=1)\n", " if not abstain:\n", " return best_c\n", "\n", " if abstain_label is None:\n", " raise ValueError(\"abstain_label must be set when abstain=True\")\n", "\n", " best_margin = np.max(margins, axis=1)\n", " preds = best_c.copy()\n", " preds[best_margin < 0.0] = abstain_label\n", " return preds\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "background_save": true, "base_uri": "https://localhost:8080/", "height": 1000, "referenced_widgets": [ "d6b8c80f70e24e33b6c2104d82484004", "c2fae8d5cbe54bbcbb93d5676f3b4db2", "8a9f0e5b71474967a27332869f7a1755", "b9f2f16c668b48efa2d019254d709dc3", "4cfb4d1f8ea54afa86ff3ce32a4889db", "db97babd15c84f3f97822363e857de50", "1370324eb86f47e785b4a1e04896e6b1", "16ec7b6cf9884aa2a1a77f5dfcb65d7c", "817e5330a42748878c92829bbde7a9cd", "2fd3c35342ff4fd0b72fbed909bb6cda", "8a806bdafb6046b18ce57d80ad413040", "a75cd80cecb54ec089a8de2fc59070a6", "b9cd2cccec9140319ff16bdcc40a884b", "aef1edfd1bf94319879409a909597dcd", "bb31de74d6774fdb9bb15b24917e9b4d", "e248e7c88b624c369ae18dee0e7d4c3d", "2cf91b6aad244705a9721b8704290d4b", "e1dc8e90ba904e439af9eaa3f27aff8f", "4d3b0629e80a4449870a784b5b62786b", "8f6b4022643d4a2b85cf0ff7ca1d39ae", "a195f061dde54d0b84bf4bc7af03663f", "3c95f65fafad4071ad9a16fdfd2ddab7", "e9a5eaf20a8a49e495e90fad044b26fb", "3e76226dc5f244d09c6ae03b074a505f", "05c3aa1ba9b84307bb8c6d4139fd4265", "e9895fad9318457a933fe6ce5e36cb5b", "024be6229952428ab2864b3eb55dc117", "47ead54cc29e4bb79e520462e69116d3", "2f2fa263a30e4ed1b80c06d41b13a29b", "066987f87c5b40758d7175504e83566d", "a61d354d50a3444388629ce15de0d1c2", "c55f046386b7485188c5a23110e93630", "48a0e4b5fa85417a92f38b44b162849c" ] }, "id": "M2X9bqUbw0lJ", "outputId": "5cf6dea9-eeee-4017-b4e3-9e4a00a40488" }, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "d6b8c80f70e24e33b6c2104d82484004", "version_major": 2, "version_minor": 0 }, "text/plain": [ "config.json: 0%| | 0.00/533 [00:00 self.best_val_f1\n", " if is_best:\n", " self.best_val_f1 = val_metrics['f1_weighted']\n", " self.best_thresholds = thresholds\n", "\n", "\n", " self.save_checkpoint(epoch, val_metrics['f1_weighted'],self.best_thresholds , is_best=is_best)\n", "\n", " # Save history\n", " history_path = os.path.join(self.configs[\"logdir\"], \"history.json\")\n", " with open(history_path, \"w\") as f:\n", " json.dump(self.history, f, indent=2)\n", "\n", " # Plot metrics every N epochs\n", " if (epoch + 1) % 5 == 0:\n", " self.plot_training_metrics(epoch + 1)\n", " self.plot_confusion_matrix(\n", " val_metrics['labels'],\n", " val_metrics['preds'],\n", " epoch + 1\n", " )\n", "\n", " # Early stopping check\n", " if self.early_stopping(val_metrics[\"f1_weighted\"]):\n", " print(f\"\\n{'='*80}\")\n", " print(f\"Early stopping triggered at epoch {epoch + 1}\")\n", " print(f\"Best validation F1: {self.best_val_f1:.4f}\")\n", " print(f\"{'='*80}\\n\")\n", " break\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"TRAINING COMPLETED\")\n", " print(f\"Best validation F1: {self.best_val_f1:.4f}\")\n", " print(f\"{'='*80}\\n\")\n", "\n", " def calculate_auc_scores(self, labels, probs, average='macro'):\n", " \"\"\"\n", " Calculate AUC scores for multi-class classification.\n", " Handles missing classes gracefully.\n", "\n", " Args:\n", " labels: True labels (1D array)\n", " probs: Predicted probabilities (2D array: samples x classes)\n", " average: 'macro', 'micro', or 'weighted'\n", "\n", " Returns:\n", " AUC score(s)\n", " \"\"\"\n", " try:\n", " # For binary case\n", " if len(self.class_names) == 2:\n", " auc_score = roc_auc_score(labels, probs[:, 1])\n", " return auc_score\n", "\n", " # Check if we have at least 2 classes\n", " present_classes = np.unique(labels)\n", " if len(present_classes) < 2:\n", " print(f\"Warning: Only {len(present_classes)} class present. Need at least 2.\")\n", " return 0.0\n", "\n", " # Binarize labels for one-vs-rest\n", " labels_bin = label_binarize(labels, classes=range(len(self.class_names)))\n", "\n", " # Calculate per-class AUC for present classes only\n", " auc_scores = []\n", " class_counts = []\n", "\n", " for i in present_classes:\n", " try:\n", " # Check if this class has both positive and negative samples\n", " if labels_bin[:, i].sum() > 0 and labels_bin[:, i].sum() < len(labels_bin):\n", " class_auc = roc_auc_score(labels_bin[:, i], probs[:, i])\n", " auc_scores.append(class_auc)\n", " class_counts.append((labels == i).sum())\n", " except Exception as e:\n", " print(f\"Warning: Could not calculate AUC for class {i} ({self.class_names[i]}): {e}\")\n", " continue\n", "\n", " if not auc_scores:\n", " print(\"Warning: Could not calculate AUC for any class\")\n", " return 0.0\n", "\n", " # Apply averaging\n", " if average == 'macro':\n", " # Unweighted mean of per-class AUCs\n", " return np.mean(auc_scores)\n", " elif average == 'weighted':\n", " # Weighted by class frequency\n", " weights = np.array(class_counts) / sum(class_counts)\n", " return np.average(auc_scores, weights=weights)\n", " elif average == 'micro':\n", " # Micro-averaging: aggregate all classes then calculate AUC\n", " # Flatten the binary labels and probabilities for all present classes\n", " all_labels = []\n", " all_probs = []\n", " for i in present_classes:\n", " all_labels.extend(labels_bin[:, i])\n", " all_probs.extend(probs[:, i])\n", " return roc_auc_score(all_labels, all_probs)\n", " else:\n", " return np.mean(auc_scores)\n", "\n", " except Exception as e:\n", " print(f\"Warning: Could not calculate AUC - {e}\")\n", " import traceback\n", " traceback.print_exc()\n", " return 0.0\n", "\n", " def calculate_per_class_auc(self, labels, probs):\n", " \"\"\"\n", " Calculate AUC for each class separately.\n", "\n", " Returns:\n", " Dictionary mapping class names to AUC scores\n", " \"\"\"\n", " per_class_auc = {}\n", "\n", " # Binarize labels for one-vs-rest\n", " labels_bin = label_binarize(labels, classes=range(len(self.class_names)))\n", "\n", " for i, class_name in enumerate(self.class_names):\n", " try:\n", " if labels_bin.shape[1] == 1: # Binary case\n", " auc_score = roc_auc_score(labels, probs[:, 1])\n", " else:\n", " auc_score = roc_auc_score(labels_bin[:, i], probs[:, i])\n", " per_class_auc[class_name] = auc_score\n", " except Exception as e:\n", " print(f\"Warning: Could not calculate AUC for {class_name} - {e}\")\n", " per_class_auc[class_name] = 0.0\n", "\n", " return per_class_auc\n", "\n", " def plot_training_metrics(self, epoch):\n", " \"\"\"Plot training metrics\"\"\"\n", " save_dir = os.path.join(self.configs[\"logdir\"], \"figures\", f\"epoch_{epoch}\")\n", " os.makedirs(save_dir, exist_ok=True)\n", "\n", " epochs = list(range(1, len(self.history[\"train_loss\"]) + 1))\n", "\n", " # Plot 1: Loss curves\n", " plt.figure(figsize=(10, 6))\n", " plt.plot(epochs, self.history[\"train_loss\"], label=\"Train Loss\", linewidth=2.5, color='blue')\n", " plt.plot(epochs, self.history[\"val_loss\"], label=\"Val Loss\", linewidth=2.5, color='orange')\n", " plt.xlabel(\"Epoch\", fontsize=12)\n", " plt.ylabel(\"Loss\", fontsize=12)\n", " plt.title(\"Training & Validation Loss\", fontsize=14, fontweight='bold')\n", " plt.legend(fontsize=11)\n", " plt.grid(True, linestyle='--', alpha=0.6)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(save_dir, \"loss_curves.png\"), dpi=150)\n", " plt.close()\n", "\n", " # Plot 2: Accuracy curves\n", " plt.figure(figsize=(10, 6))\n", " plt.plot(epochs, self.history[\"train_acc\"], label=\"Train Accuracy\", linewidth=2.5, color='blue')\n", " plt.plot(epochs, self.history[\"val_acc\"], label=\"Val Accuracy\", linewidth=2.5, color='orange')\n", " plt.xlabel(\"Epoch\", fontsize=12)\n", " plt.ylabel(\"Accuracy\", fontsize=12)\n", " plt.title(\"Training & Validation Accuracy\", fontsize=14, fontweight='bold')\n", " plt.legend(fontsize=11)\n", " plt.grid(True, linestyle='--', alpha=0.6)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(save_dir, \"accuracy_curves.png\"), dpi=150)\n", " plt.close()\n", "\n", " # Plot 3: F1 Score curves\n", " plt.figure(figsize=(10, 6))\n", " plt.plot(epochs, self.history[\"train_f1\"], label=\"Train F1\", linewidth=2.5, color='blue')\n", " plt.plot(epochs, self.history[\"val_f1\"], label=\"Val F1\", linewidth=2.5, color='orange')\n", " plt.xlabel(\"Epoch\", fontsize=12)\n", " plt.ylabel(\"F1 Score\", fontsize=12)\n", " plt.title(\"Training & Validation F1 Score\", fontsize=14, fontweight='bold')\n", " plt.legend(fontsize=11)\n", " plt.grid(True, linestyle='--', alpha=0.6)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(save_dir, \"f1_curves.png\"), dpi=150)\n", " plt.close()\n", "\n", " # Plot 4: Learning rate\n", " plt.figure(figsize=(10, 6))\n", " plt.plot(epochs, self.history[\"lr\"], linewidth=2.5, color='green')\n", " plt.xlabel(\"Epoch\", fontsize=12)\n", " plt.ylabel(\"Learning Rate\", fontsize=12)\n", " plt.title(\"Learning Rate Schedule\", fontsize=14, fontweight='bold')\n", " plt.yscale('log')\n", " plt.grid(True, linestyle='--', alpha=0.6)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(save_dir, \"learning_rate.png\"), dpi=150)\n", " plt.close()\n", "\n", " def plot_confusion_matrix(self, labels, preds, epoch):\n", " \"\"\"Plot confusion matrix\"\"\"\n", " save_dir = os.path.join(self.configs[\"logdir\"], \"figures\", f\"epoch_{epoch}\")\n", " os.makedirs(save_dir, exist_ok=True)\n", "\n", " # Compute confusion matrix\n", " cm = confusion_matrix(labels, preds)\n", "\n", " # Normalize\n", " cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", "\n", " # Plot\n", " plt.figure(figsize=(12, 10))\n", " sns.heatmap(\n", " cm_normalized,\n", " annot=True,\n", " fmt='.2f',\n", " cmap='Blues',\n", " xticklabels=self.class_names,\n", " yticklabels=self.class_names,\n", " cbar_kws={'label': 'Normalized Count'}\n", " )\n", " plt.xlabel(\"Predicted\", fontsize=12)\n", " plt.ylabel(\"True\", fontsize=12)\n", " plt.title(\"Confusion Matrix (Normalized)\", fontsize=14, fontweight='bold')\n", " plt.xticks(rotation=45, ha='right')\n", " plt.yticks(rotation=0)\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(save_dir, \"confusion_matrix.png\"), dpi=150)\n", " plt.close()\n", "\n", " # Save classification report\n", " report = classification_report(\n", " labels,\n", " preds,\n", " target_names=self.class_names,\n", " digits=4\n", " )\n", " report_path = os.path.join(save_dir, \"classification_report.txt\")\n", " with open(report_path, 'w') as f:\n", " f.write(report)\n", "\n", " def test(self, model_path: Optional[str] = None):\n", " \"\"\"Test the model\"\"\"\n", " if model_path is None:\n", " model_path = self.configs[\"resume\"]\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"TESTING MODEL\")\n", " print(f\"{'='*80}\\n\")\n", " print(f\"Loading model from: {model_path}\")\n", "\n", " checkpoint = torch.load(model_path, map_location=self.configs[\"device\"],weights_only =False)\n", " self.model.load_state_dict(checkpoint[\"model\"], strict=False)\n", " self.model.eval()\n", "\n", " all_preds = []\n", " all_labels = []\n", " all_probs = []\n", " all_filenames = []\n", "\n", " with torch.no_grad():\n", " for data in tqdm(self.testloader, desc=\"Testing\", file=self.testtee):\n", " features = data['features'].to(self.configs[\"device\"], non_blocking=True)\n", " lengths = data['lengths'].to(self.configs[\"device\"], non_blocking=True) # [B]\n", " labels = data['label'].to(self.configs[\"device\"], non_blocking=True)\n", "\n", " if self.scaler is not None:\n", " with torch.amp.autocast('cuda'):\n", " logits,_ = self.model(features,lengths)\n", " else:\n", " logits,_ = self.model(features,lengths)\n", "\n", " probs = torch.softmax(logits.float(), dim=1)\n", " #preds = torch.argmax(probs, dim=1)\n", " preds = predict_with_thresholds(\n", " probs.cpu(),\n", " np.array(self.best_thresholds, dtype=np.float32)\n", " )\n", "\n", "\n", " all_preds.extend(preds.detach().cpu().numpy())\n", " all_labels.extend(labels.detach().cpu().numpy())\n", " all_probs.extend(probs.detach().cpu().numpy())\n", " all_filenames.extend(data['filename'])\n", "\n", " # Calculate metrics\n", " accuracy = accuracy_score(all_labels, all_preds)\n", " f1_weighted = f1_score(all_labels, all_preds, average='weighted')\n", " f1_macro = f1_score(all_labels, all_preds, average='macro')\n", " precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)\n", " recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)\n", " all_probs = np.array(all_probs)\n", " auc_macro = self.calculate_auc_scores(all_labels, all_probs, average='macro')\n", " auc_micro = self.calculate_auc_scores(all_labels, all_probs, average='micro')\n", " per_class_auc = self.calculate_per_class_auc(all_labels, all_probs)\n", "\n", " print(f\"\\n{'='*80}\")\n", " print(\"TEST RESULTS\")\n", " print(f\"{'='*80}\")\n", " print(f\"Accuracy: {accuracy:.4f}\")\n", " print(f\"F1 (weighted): {f1_weighted:.4f}\")\n", " print(f\"F1 (macro): {f1_macro:.4f}\")\n", " print(f\"Precision: {precision:.4f}\")\n", " print(f\"Recall: {recall:.4f}\")\n", " print(f\"AUC (macro): {auc_macro:.4f}\")\n", " print(f\"AUC (micro): {auc_micro:.4f}\")\n", " print(f\"Per-class AUC: {per_class_auc}\")\n", " print(f\"{'='*80}\\n\")\n", "\n", " # Detailed classification report\n", " report = classification_report(\n", " all_labels,\n", " all_preds,\n", " target_names=self.class_names,\n", " digits=4\n", " )\n", " print(\"\\nClassification Report:\")\n", " print(report)\n", "\n", " # Save results\n", " test_results_dir = os.path.join(self.configs[\"logdir\"], \"test_results\")\n", " os.makedirs(test_results_dir, exist_ok=True)\n", "\n", " # Save classification report\n", " with open(os.path.join(test_results_dir, \"classification_report.txt\"), 'w') as f:\n", " f.write(report)\n", "\n", " # Save confusion matrix\n", " cm = confusion_matrix(all_labels, all_preds)\n", " cm_normalized = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]\n", "\n", " plt.figure(figsize=(12, 10))\n", " sns.heatmap(\n", " cm_normalized,\n", " annot=True,\n", " fmt='.2f',\n", " cmap='Blues',\n", " xticklabels=self.class_names,\n", " yticklabels=self.class_names\n", " )\n", " plt.xlabel(\"Predicted\")\n", " plt.ylabel(\"True\")\n", " plt.title(\"Test Set Confusion Matrix (Normalized)\")\n", " plt.tight_layout()\n", " plt.savefig(os.path.join(test_results_dir, \"confusion_matrix.png\"), dpi=150)\n", " plt.close()\n", "\n", " # Save predictions\n", " results_df = pd.DataFrame({\n", " 'filename': all_filenames,\n", " 'true_label': [self.class_names[l] for l in all_labels],\n", " 'pred_label': [self.class_names[p] for p in all_preds],\n", " 'correct': [l == p for l, p in zip(all_labels, all_preds)]\n", " })\n", "\n", " # Add probabilities for each class\n", " for i, class_name in enumerate(self.class_names):\n", " results_df[f'prob_{class_name}'] = [probs[i] for probs in all_probs]\n", "\n", " results_df.to_csv(os.path.join(test_results_dir, \"predictions.csv\"), index=False)\n", "\n", " # Save metrics\n", " metrics = {\n", " 'accuracy': float(accuracy),\n", " 'f1_weighted': float(f1_weighted),\n", " 'f1_macro': float(f1_macro),\n", " 'precision': float(precision),\n", " 'recall': float(recall)\n", " }\n", "\n", " with open(os.path.join(test_results_dir, \"metrics.json\"), 'w') as f:\n", " json.dump(metrics, f, indent=2)\n", "\n", " return metrics\n", "\n", "\n", "class EarlyStopping:\n", " \"\"\"Early stopping to prevent overfitting\"\"\"\n", " def __init__(self, patience=5, min_delta=0.001):\n", " self.patience = patience\n", " self.min_delta = min_delta\n", " self.counter = 0\n", " self.best_loss = None\n", " self.early_stop = False\n", "\n", " def __call__(self, val_loss):\n", " if self.best_loss is None:\n", " self.best_loss = val_loss\n", " elif val_loss > self.best_loss - self.min_delta:\n", " self.counter += 1\n", " print(f\"EarlyStopping counter: {self.counter}/{self.patience}\")\n", " if self.counter >= self.patience:\n", " self.early_stop = True\n", " return True\n", " else:\n", " self.best_loss = val_loss\n", " self.counter = 0\n", " return False\n", "\n", "\n", "class TeeFile:\n", " \"\"\"File-like object that writes to multiple streams\"\"\"\n", " def __init__(self, *file_objects_or_paths):\n", " self.files = []\n", " self.opened_files = []\n", "\n", " for item in file_objects_or_paths:\n", " if isinstance(item, str):\n", " f = open(item, 'a', buffering=1)\n", " self.files.append(f)\n", " self.opened_files.append(f)\n", " else:\n", " self.files.append(item)\n", "\n", " def write(self, data):\n", " for f in self.files:\n", " try:\n", " f.write(data)\n", " f.flush()\n", " except Exception as e:\n", " print(f\"Warning: Could not write to {f}: {e}\", file=sys.stderr)\n", "\n", " def flush(self):\n", " for f in self.files:\n", " try:\n", " f.flush()\n", " except:\n", " pass\n", "\n", " def isatty(self):\n", " return any(getattr(f, \"isatty\", lambda: False)() for f in self.files)\n", "\n", " def close(self):\n", " for f in self.opened_files:\n", " try:\n", " f.close()\n", " except:\n", " pass\n", " self.opened_files.clear()\n", "\n", " def __del__(self):\n", " self.close()\n", "\n", " def __enter__(self):\n", " return self\n", "\n", " def __exit__(self, exc_type, exc_val, exc_tb):\n", " self.close()\n", " return False\n", "\n", "\n", "# Example usage\n", "if __name__ == '__main__':\n", " # Configuration\n", " root = \"/content/drive/MyDrive/SPRSound/SPRSound-main\"\n", "\n", " config = {\n", " # Paths\n", " \"csv_dir\": root,\n", " \"logdir\": os.path.join(root, \"logs\"),\n", " \"resume\": os.path.join(root, \"checkpoints\", \"best_model.pth\"),\n", " \"dirsToMake\": [\n", " os.path.join(root, \"checkpoints\"),\n", " os.path.join(root, \"logs\"),\n", " os.path.join(root, \"logs\", \"figures\")\n", " ],\n", "\n", " # Model\n", " \"num_classes\": 2, # normal, abnormal\n", " \"dropout\": 0.4,\n", " \"use_lora\": True,\n", " \"lora_r\": 16,\n", " \"lora_alpha\": 16,\n", " \"lora_dropout\": 0.3,\n", " \"lora_last_n_blocks\": 6,\n", "\n", " # Training\n", " \"lr\": 5e-5,\n", " \"weight_decay\": 0.2,\n", " \"warmup\": 10,\n", " \"num_epochs\": 100,\n", " \"batch_size\": 96,\n", " \"accumulation\": 1,\n", " \"use_amp\": False,\n", " \"num_workers\": 2,\n", "\n", " # Early stopping\n", " \"early_stopping_patience\": 20,\n", " \"early_stopping_min_delta\": 0.001,\n", "\n", "\n", " # Device\n", " \"device\": torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\"),\n", " }\n", "\n", " # Create trainer\n", " trainer = RespiratoryTrainer(config)\n", " if train_mode:\n", " # Train\n", " trainer.train()\n", " elif test_mode:\n", " # Test\n", " trainer.test()" ] } ], "metadata": { "accelerator": "GPU", "colab": { "gpuType": "A100", "machine_shape": "hm", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "024be6229952428ab2864b3eb55dc117": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "05c3aa1ba9b84307bb8c6d4139fd4265": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_066987f87c5b40758d7175504e83566d", "max": 392, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_a61d354d50a3444388629ce15de0d1c2", "value": 392 } }, "066987f87c5b40758d7175504e83566d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "1370324eb86f47e785b4a1e04896e6b1": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "16ec7b6cf9884aa2a1a77f5dfcb65d7c": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2cf91b6aad244705a9721b8704290d4b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "2f2fa263a30e4ed1b80c06d41b13a29b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "2fd3c35342ff4fd0b72fbed909bb6cda": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "3c95f65fafad4071ad9a16fdfd2ddab7": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "3e76226dc5f244d09c6ae03b074a505f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_47ead54cc29e4bb79e520462e69116d3", "placeholder": "​", "style": "IPY_MODEL_2f2fa263a30e4ed1b80c06d41b13a29b", "value": "Loading weights: 100%" } }, "47ead54cc29e4bb79e520462e69116d3": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "48a0e4b5fa85417a92f38b44b162849c": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "4cfb4d1f8ea54afa86ff3ce32a4889db": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "4d3b0629e80a4449870a784b5b62786b": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "817e5330a42748878c92829bbde7a9cd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "8a806bdafb6046b18ce57d80ad413040": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "8a9f0e5b71474967a27332869f7a1755": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_16ec7b6cf9884aa2a1a77f5dfcb65d7c", "max": 533, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_817e5330a42748878c92829bbde7a9cd", "value": 533 } }, "8f6b4022643d4a2b85cf0ff7ca1d39ae": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a195f061dde54d0b84bf4bc7af03663f": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "a61d354d50a3444388629ce15de0d1c2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "ProgressStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "ProgressStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "bar_color": null, "description_width": "" } }, "a75cd80cecb54ec089a8de2fc59070a6": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_b9cd2cccec9140319ff16bdcc40a884b", "IPY_MODEL_aef1edfd1bf94319879409a909597dcd", "IPY_MODEL_bb31de74d6774fdb9bb15b24917e9b4d" ], "layout": "IPY_MODEL_e248e7c88b624c369ae18dee0e7d4c3d" } }, "aef1edfd1bf94319879409a909597dcd": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "FloatProgressModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "FloatProgressModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "ProgressView", "bar_style": "success", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_4d3b0629e80a4449870a784b5b62786b", "max": 1212947234, "min": 0, "orientation": "horizontal", "style": "IPY_MODEL_8f6b4022643d4a2b85cf0ff7ca1d39ae", "value": 1212947234 } }, "b9cd2cccec9140319ff16bdcc40a884b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2cf91b6aad244705a9721b8704290d4b", "placeholder": "​", "style": "IPY_MODEL_e1dc8e90ba904e439af9eaa3f27aff8f", "value": "pytorch_model.bin: 100%" } }, "b9f2f16c668b48efa2d019254d709dc3": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_2fd3c35342ff4fd0b72fbed909bb6cda", "placeholder": "​", "style": "IPY_MODEL_8a806bdafb6046b18ce57d80ad413040", "value": " 533/533 [00:00<00:00, 65.4kB/s]" } }, "bb31de74d6774fdb9bb15b24917e9b4d": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_a195f061dde54d0b84bf4bc7af03663f", "placeholder": "​", "style": "IPY_MODEL_3c95f65fafad4071ad9a16fdfd2ddab7", "value": " 1.21G/1.21G [00:03<00:00, 709MB/s]" } }, "c2fae8d5cbe54bbcbb93d5676f3b4db2": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_db97babd15c84f3f97822363e857de50", "placeholder": "​", "style": "IPY_MODEL_1370324eb86f47e785b4a1e04896e6b1", "value": "config.json: 100%" } }, "c55f046386b7485188c5a23110e93630": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "d6b8c80f70e24e33b6c2104d82484004": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_c2fae8d5cbe54bbcbb93d5676f3b4db2", "IPY_MODEL_8a9f0e5b71474967a27332869f7a1755", "IPY_MODEL_b9f2f16c668b48efa2d019254d709dc3" ], "layout": "IPY_MODEL_4cfb4d1f8ea54afa86ff3ce32a4889db" } }, "db97babd15c84f3f97822363e857de50": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e1dc8e90ba904e439af9eaa3f27aff8f": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "DescriptionStyleModel", "state": { "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "DescriptionStyleModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "StyleView", "description_width": "" } }, "e248e7c88b624c369ae18dee0e7d4c3d": { "model_module": "@jupyter-widgets/base", "model_module_version": "1.2.0", "model_name": "LayoutModel", "state": { "_model_module": "@jupyter-widgets/base", "_model_module_version": "1.2.0", "_model_name": "LayoutModel", "_view_count": null, "_view_module": "@jupyter-widgets/base", "_view_module_version": "1.2.0", "_view_name": "LayoutView", "align_content": null, "align_items": null, "align_self": null, "border": null, "bottom": null, "display": null, "flex": null, "flex_flow": null, "grid_area": null, "grid_auto_columns": null, "grid_auto_flow": null, "grid_auto_rows": null, "grid_column": null, "grid_gap": null, "grid_row": null, "grid_template_areas": null, "grid_template_columns": null, "grid_template_rows": null, "height": null, "justify_content": null, "justify_items": null, "left": null, "margin": null, "max_height": null, "max_width": null, "min_height": null, "min_width": null, "object_fit": null, "object_position": null, "order": null, "overflow": null, "overflow_x": null, "overflow_y": null, "padding": null, "right": null, "top": null, "visibility": null, "width": null } }, "e9895fad9318457a933fe6ce5e36cb5b": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HTMLModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HTMLModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HTMLView", "description": "", "description_tooltip": null, "layout": "IPY_MODEL_c55f046386b7485188c5a23110e93630", "placeholder": "​", "style": "IPY_MODEL_48a0e4b5fa85417a92f38b44b162849c", "value": " 392/392 [00:00<00:00, 1510.87it/s, Materializing param=pooler.dense.weight]" } }, "e9a5eaf20a8a49e495e90fad044b26fb": { "model_module": "@jupyter-widgets/controls", "model_module_version": "1.5.0", "model_name": "HBoxModel", "state": { "_dom_classes": [], "_model_module": "@jupyter-widgets/controls", "_model_module_version": "1.5.0", "_model_name": "HBoxModel", "_view_count": null, "_view_module": "@jupyter-widgets/controls", "_view_module_version": "1.5.0", "_view_name": "HBoxView", "box_style": "", "children": [ "IPY_MODEL_3e76226dc5f244d09c6ae03b074a505f", "IPY_MODEL_05c3aa1ba9b84307bb8c6d4139fd4265", "IPY_MODEL_e9895fad9318457a933fe6ce5e36cb5b" ], "layout": "IPY_MODEL_024be6229952428ab2864b3eb55dc117" } } } } }, "nbformat": 4, "nbformat_minor": 0 }