OneMore1 committed (verified)
Commit 538668e · Parent(s): 9fee185

Upload 12 files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+pipeline.png filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,87 @@

# SLIM-BRAIN: A DATA- AND TRAINING-EFFICIENT FOUNDATION MODEL FOR FMRI DATA ANALYSIS

<div align="center">

[![arXiv](https://img.shields.io/badge/arXiv-2512.21881-b31b1b.svg?style=flat-square)](https://www.arxiv.org/abs/2512.21881)
[![GitHub](https://img.shields.io/badge/GitHub-Repository-181717?style=flat-square&logo=github)](https://github.com/OneMore1/SLIM-Brain2026)
[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-blue)](https://huggingface.co/OneMore1/Slim-Brain)

</div>

This repository contains the official implementation of SLIM-Brain, a two-stage, selective-compute pipeline for voxel-level fMRI representation learning. A lightweight global branch ranks temporal windows by informativeness; a high-capacity 4D Hiera–JEPA encoder then processes only the top-ranked windows, concentrating compute on brain voxels and substantially reducing memory usage.

<p align="center">
  <img src="pipeline.png" width="800" alt="SLIM-Brain pipeline">
</p>

---

## Installation

The environment requires Python 3.13 and a CUDA-enabled PyTorch build for GPU acceleration:

```bash
conda create -n hiera-jepa python=3.13.5
conda activate hiera-jepa

# Install dependencies
pip install -r requirements.txt
```

## Project Structure

The codebase is organized into modular components for easy navigation and extension:

```
hiera-jepa/
├── configs/          # YAML configuration files for training and model parameters
├── checkpoints/      # Saved model weights and training checkpoints
├── hiera/            # Hierarchical Vision Transformer backbone implementation
├── scripts/          # Bash launch scripts (e.g., finetune.sh)
├── finetune.py       # Downstream task training and feature extraction script
└── requirements.txt  # Python package dependencies
```

## Downstream evaluation

1. Ensure your pre-training data follows this structure:

```
data_root/
├── ABIDE_train/
├── ABIDE_val/
├── HCP_val/
└── HCP_train/
    ├── 0010001/                          # Subject ID
    └── 0010002/
        ├── 0010002_run-1_0000-0199_1.npz # Data chunk 1
        ├── 0010002_run-1_0000-0199_2.npz # Data chunk 2
```
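The chunk filenames in the tree above encode the subject ID, run index, frame range, and chunk index. A minimal stdlib sketch of parsing them (the naming pattern is assumed from the example filenames, not documented by the repo):

```python
import re

# Hypothetical parser for chunk filenames of the form
# <subject>_run-<r>_<start>-<end>_<chunk>.npz (pattern inferred from the tree above).
PATTERN = re.compile(
    r"(?P<subject>\d+)_run-(?P<run>\d+)_(?P<start>\d{4})-(?P<end>\d{4})_(?P<chunk>\d+)\.npz"
)

def parse_chunk_name(filename: str) -> dict:
    """Split a chunk filename into its fields; numeric fields become ints,
    the subject ID stays a string to preserve leading zeros."""
    m = PATTERN.fullmatch(filename)
    if m is None:
        raise ValueError(f"Unrecognized chunk filename: {filename}")
    return {k: (v if k == "subject" else int(v)) for k, v in m.groupdict().items()}

info = parse_chunk_name("0010002_run-1_0000-0199_1.npz")
print(info)  # subject '0010002', frames 0-199, chunk 1
```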
2. Point the downstream config at your dataset using the following fields:

```yaml
task:
  csv: "/path/to/data_csv"

data:
  data_root: /path/to/data_root
  datasets: ["HCP"]
  mode: "directory"
```
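Once parsed (e.g., with PyYAML's `yaml.safe_load`), a config like the one above is just a nested dict. A small sketch of defensive access, using an inline dict in place of the file (paths are placeholders, and the key names follow the fragment above):

```python
# Inline stand-in for the parsed YAML above (yaml.safe_load would yield this dict).
config = {
    "task": {"csv": "/path/to/data_csv"},
    "data": {"data_root": "/path/to/data_root", "datasets": ["HCP"], "mode": "directory"},
}

csv_path = config["task"]["csv"]                    # required key: fail loudly if absent
datasets = config["data"].get("datasets", ["HCP"])  # optional key with a default
mode = config["data"].get("mode", "directory")

assert mode == "directory", f"unsupported mode: {mode}"
print(csv_path, datasets, mode)
```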
3. Start downstream training:

```bash
# Run downstream training
sh scripts/finetune.sh
```

#### Model Checkpoints

Pre-trained model weights are provided in the checkpoints directory: `./checkpoints/best_model.pth`
checkpoints/__init__.py ADDED
@@ -0,0 +1 @@
+
configs/finetune_config.yaml ADDED
@@ -0,0 +1,106 @@

```yaml
experiment:
  name: "finetune_classification"
  output_dir: "./output/hiera_finetune"
  seed: 44
  resume: null                 # Path to checkpoint to resume from
  pretrained_checkpoint: "/vePFS-0x0d/home/yewh/Hiera_MAE/checkpoint/checkpoint_epoch_39.pth"

# --- Task Settings ---
task:
  task_type: "regression"      # "classification" or "regression"
  num_classes: 1               # Number of classes for classification (e.g., 2 for binary classification)
  mean: 33.9289
  std: 21.5580
  csv: "/vePFS-0x0d/home/yewh/data_csv/RBC-NKI.csv"  # CSV with a Subject column plus a label column (Gender/age_group for classification, age for regression)

# --- Data Settings ---
data:
  data_root: "/vePFS-0x0d/fmri-data/WAR_NPYZ"
  datasets: ["HCP"]
  train_split_suffixes: ["train_40"]
  val_split_suffixes: ["val_40"]
  test_split_suffixes: ["test_40"]
  input_seq_len: 40            # Temporal length to crop (T dimension)

  # Data dimensions (D, H, W, T) -> will be permuted to (T, D, H, W)
  spatial_dims: [96, 96, 96]   # D, H, W

  # DataLoader settings
  batch_size: 2                # Per-GPU batch size (can be larger than pretraining)
  num_workers: 4
  pin_memory: true
  prefetch_factor: 2

# --- Model Settings ---
model:
  # Input configuration
  input_size: [40, 96, 96, 96] # [T, D, H, W]
  in_chans: 1

  # Patch embedding configuration
  patch_kernel: [1, 4, 4, 4]   # [T, D, H, W]
  patch_stride: [1, 4, 4, 4]   # [T, D, H, W]
  patch_padding: [0, 0, 0, 0]  # [T, D, H, W]

  # Hiera architecture
  embed_dim: 64
  num_heads: 1
  stages: [2, 3, 16, 3]
  q_pool: 2
  q_stride: [2, 2, 2, 2]       # Stride for q_pool [T, D, H, W]
  mask_unit_size: [8, 8, 8, 8] # Mask unit size [T, D, H, W]
  mlp_ratio: 4.0
  mask_unit_attn: [true, true, false, false]

# --- Training Settings ---
training:
  # Optimization
  optimizer: "adamw"
  learning_rate: 1.0e-4        # Lower learning rate for fine-tuning
  head_lr: 5.0e-4              # Higher learning rate for the classification head
  layer_decay: 0.75
  weight_decay: 0.05
  betas: [0.9, 0.99]

  # Learning rate schedule
  lr_scheduler: "cosine"
  warmup_epochs: 2             # Warmup epochs
  min_lr: 1.0e-6               # Minimum learning rate at the end of the schedule

  # Weight freezing
  freeze_encoder: true         # Set to true to freeze the entire encoder and only train the head

  # Training duration
  epochs: 200

  # Gradient settings
  clip_grad: 1.0               # Gradient clipping value, null to disable
  accum_iter: 8                # Gradient accumulation steps (usually 1 for fine-tuning)

  # Mixed precision
  use_amp: true                # Use automatic mixed precision

# --- Distributed Training Settings ---
distributed:
  backend: "nccl"
  init_method: "env://"
  world_size: -1               # Will be set automatically
  rank: -1                     # Will be set automatically
  dist_url: "env://"

# --- Logging Settings ---
logging:
  print_freq: 40               # Print frequency (iterations)
  log_freq: 40                 # Log frequency (iterations)
  save_freq: 5                 # Checkpoint save frequency (epochs)

  # Weights & Biases
  use_wandb: false
  wandb_project: "hiera_fmri_finetune"
  wandb_entity: null           # Your wandb username/team

# --- Validation Settings ---
validation:
  val_freq: 1                  # Validation frequency (epochs)
  save_best: true              # Save best model based on validation metric
```
data/downstream_dataset.py ADDED
@@ -0,0 +1,150 @@

```python
import os
import re
from typing import List, Literal, Tuple

import numpy as np
import pandas as pd
import torch

from .pretrain_dataset import fMRIDataset


class fMRITaskDataset(fMRIDataset):

    def __init__(
        self,
        data_root: str,
        datasets: List[str],
        split_suffixes: List[str],
        crop_length: int,
        label_csv_path: str,
        task_type: Literal['classification', 'regression'] = 'classification',
        downstream=True,
    ):
        super().__init__(data_root, datasets, split_suffixes, crop_length, downstream)

        self.task_type = task_type
        self.labels_map = self._load_and_process_labels(label_csv_path)

        # Keep only files whose subject ID has a label in the CSV
        initial_file_count = len(self.file_paths)
        self.file_paths = [
            path for path in self.file_paths
            if self._extract_subject_id(path) in self.labels_map
        ]

        if len(self.file_paths) < initial_file_count:
            print(f"Warning: Dropped {initial_file_count - len(self.file_paths)} files due to missing labels in CSV.")

        print(f"Task Dataset ready for {self.task_type}. Usable files: {len(self.file_paths)}")

    def _extract_subject_id(self, file_path: str) -> str:
        # First six-digit run in the basename; leading zeros are stripped to
        # match the CSV 'Subject' values
        match = re.search(r'(\d{6})', os.path.basename(file_path))
        if match:
            return match.group(1).lstrip('0')
        return ""

    def _load_and_process_labels(self, csv_path: str) -> dict:
        if not os.path.exists(csv_path):
            raise FileNotFoundError(f"Label CSV file not found at: {csv_path}")

        print(f"Loading labels from {csv_path}...")
        df = pd.read_csv(csv_path)

        df['Subject'] = df['Subject'].astype(str)
        df.dropna(subset=['Subject'], inplace=True)

        labels_map = {}

        if self.task_type == 'classification':
            label_col = None
            if 'Gender' in df.columns:
                label_col = 'Gender'
            elif 'gender' in df.columns:
                label_col = 'gender'
            elif 'age_group' in df.columns:
                label_col = 'age_group'

            if label_col is None:
                raise ValueError("CSV must contain a 'Gender', 'gender' or 'age_group' column for classification.")

            print(f"Using column '{label_col}' as label.")

            sex_mapping = {'F': 0, 'M': 1, 'f': 0, 'm': 1}

            if df[label_col].dtype == object and df[label_col].astype(str).iloc[0].upper() in ['F', 'M']:
                print(f"Encoding {label_col} (F/M) to integers (0/1)...")
                df = df[df[label_col].isin(sex_mapping.keys())]
                df[label_col] = df[label_col].map(sex_mapping)
            else:
                df[label_col] = pd.to_numeric(df[label_col], errors='coerce').astype(int)

            for _, row in df.iterrows():
                subject_id = row['Subject']
                labels_map[subject_id] = torch.tensor(row[label_col], dtype=torch.long)

        elif self.task_type == 'regression':
            label_col = 'age'
            if label_col not in df.columns:
                raise ValueError(f"Regression task requires '{label_col}' column.")
            df[label_col] = pd.to_numeric(df[label_col], errors='coerce')
            df.dropna(subset=[label_col], inplace=True)

            for _, row in df.iterrows():
                subject_id = row['Subject']
                labels_map[subject_id] = torch.tensor(row[label_col], dtype=torch.float32).view(1)

        else:
            raise ValueError(f"Unsupported task_type: {self.task_type}")

        print(f"Successfully loaded {len(labels_map)} subjects' labels.")
        return labels_map

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        # Retry with a random index on any load failure, so training is robust
        # to occasional corrupt or unlabeled files
        retries = 0
        max_retries = 100
        while retries < max_retries:
            try:
                data_tensor = super().__getitem__(idx)

                if data_tensor is None:
                    raise ValueError(f"Failed to load data at index {idx} (super returned None)")

                file_path = self.file_paths[idx]
                subject_id = self._extract_subject_id(file_path)

                data_tensor = data_tensor.unsqueeze(0)  # add channel dimension

                if subject_id in self.labels_map:
                    return data_tensor, self.labels_map[subject_id]

                raise KeyError(f"Label not found for subject ID: {subject_id}")

            except Exception:
                idx = np.random.randint(0, len(self))
                retries += 1

        raise RuntimeError(f"Failed to load any valid data after {max_retries} retries.")
```
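The `_extract_subject_id` logic (first six-digit run in the basename, leading zeros stripped) can be checked in isolation. Note one caveat: with seven-digit IDs like those shown in the README's directory tree, a six-digit pattern captures only the first six digits, so the pattern width must match the dataset's actual ID width:

```python
import os
import re

def extract_subject_id(file_path: str) -> str:
    """Stand-alone mirror of the dataset's ID extraction: first 6-digit run in
    the basename, with leading zeros stripped to match CSV 'Subject' values."""
    match = re.search(r"(\d{6})", os.path.basename(file_path))
    if match:
        return match.group(1).lstrip("0")
    return ""

print(extract_subject_id("123456_run-1.npz"))
# With a 7-digit ID, only the first six digits are captured:
print(extract_subject_id("HCP_train/0010002/0010002_run-1_0000-0199_1.npz"))
```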
data/pretrain_dataset.py ADDED
@@ -0,0 +1,61 @@

```python
import glob
import os

import numpy as np
import torch
from torch.utils.data import Dataset


class fMRIDataset(Dataset):
    def __init__(self, data_root, datasets, split_suffixes, crop_length=40, downstream=False):
        self.file_paths = []
        self.crop_length = crop_length
        self.downstream = downstream
        for dataset_name in datasets:
            for suffix in split_suffixes:
                folder_name = f"{dataset_name}_{suffix}"
                folder_path = os.path.join(data_root, folder_name)
                if not os.path.exists(folder_path):
                    print(f"Warning: Folder not found: {folder_path}")
                    continue

                for root, dirs, files in os.walk(folder_path):
                    npz_files = glob.glob(os.path.join(root, "*.npz"))
                    if len(npz_files) > 1:
                        # Keep only the first chunk per directory
                        npz_files = sorted(npz_files)[:1]
                    self.file_paths.extend(npz_files)

        print(f"Dataset loaded. Total files found: {len(self.file_paths)}")

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        file_path = self.file_paths[idx]
        try:
            with np.load(file_path) as data_file:
                key = list(data_file.keys())[0]
                fmri_data = data_file[key].astype(np.float32)
        except Exception as e:
            print(f"Error loading file {file_path}: {e}")
            return None

        # Random temporal crop of crop_length frames (shorter runs are
        # returned truncated, unpadded)
        total_time_frames = fmri_data.shape[-1]
        if total_time_frames > self.crop_length:
            start_idx = np.random.randint(0, total_time_frames - self.crop_length + 1)
            cropped_data = fmri_data[..., start_idx:start_idx + self.crop_length]
        else:
            cropped_data = fmri_data[..., :self.crop_length]

        data_tensor = torch.from_numpy(cropped_data)

        # (D, H, W, T) -> (T, D, H, W)
        data_tensor = data_tensor.permute(3, 0, 1, 2)

        return data_tensor
```
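The random temporal crop above can be illustrated on a plain list standing in for the time axis (a stdlib sketch; the real code slices a 4D numpy array along its last axis):

```python
import random

def random_crop(frames, crop_length=40, rng=random.Random(0)):
    """Pick a random contiguous window of crop_length frames, mirroring the
    start_idx/end_idx logic above; shorter inputs are returned unpadded."""
    total = len(frames)
    if total > crop_length:
        # random.Random.randint has an inclusive upper bound, so this matches
        # np.random.randint(0, total - crop_length + 1)
        start = rng.randint(0, total - crop_length)
        return frames[start:start + crop_length]
    return frames[:crop_length]

timeline = list(range(200))  # stands in for a 200-frame run
window = random_crop(timeline)
print(len(window), window[0])
```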
finetune.py ADDED
@@ -0,0 +1,586 @@

```python
import argparse
import datetime
import os
import sys
from pathlib import Path

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import yaml
from sklearn.metrics import f1_score
from torch.cuda.amp import GradScaler, autocast
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'hiera'))

from hiera.hiera_mae import HieraClassifier
from data.downstream_dataset import fMRITaskDataset
from utils.utils import (MetricLogger, load_config, log_to_file, count_parameters,
                         save_checkpoint, load_checkpoint, LabelScaler)
from utils.optim import create_optimizer, create_lr_scheduler
from utils.ddp import setup_distributed, set_seed, cleanup_distributed


def create_model(config):
    """Create a Hiera classifier model from the config."""
    task_config = config['task']
    exp_config = config['experiment']

    model_config = config['model']
    pretrained_checkpoint_path = exp_config.get('pretrained_checkpoint', None)

    # Prefer the architecture recorded alongside the pretrained checkpoint
    if pretrained_checkpoint_path:
        pretrain_config_path = Path(pretrained_checkpoint_path).parent.parent / 'config.yaml'
        if os.path.exists(pretrain_config_path):
            print(f"Loading model architecture from pretrained config: {pretrain_config_path}")
            pretrain_config = load_config(pretrain_config_path)
            model_config = pretrain_config['model']
        else:
            print(f"Warning: Pretrained config not found at {pretrain_config_path}. Using finetune config for model architecture.")

    model = HieraClassifier(
        num_classes=task_config['num_classes'],
        task_type=task_config['task_type'],
        input_size=tuple(model_config['input_size']),
        in_chans=model_config['in_chans'],
        patch_kernel=tuple(model_config['patch_kernel']),
        patch_stride=tuple(model_config['patch_stride']),
        patch_padding=tuple(model_config['patch_padding']),
        q_stride=tuple(model_config['q_stride']),
        mask_unit_size=tuple(model_config['mask_unit_size']),
        embed_dim=model_config['embed_dim'],
        num_heads=model_config['num_heads'],
        stages=tuple(model_config['stages']),
        q_pool=model_config['q_pool'],
        mlp_ratio=model_config['mlp_ratio'],
    )

    # Load pretrained weights if specified
    if pretrained_checkpoint_path:
        if os.path.exists(pretrained_checkpoint_path):
            model.load_pretrained_mae(pretrained_checkpoint_path)
        else:
            print(f"Warning: Pretrained checkpoint not found at {pretrained_checkpoint_path}. Model is randomly initialized.")
    else:
        print("Warning: No pretrained checkpoint specified. Model is randomly initialized.")

    return model


def create_dataloaders(config, is_distributed, rank, world_size):
    """Create train, validation, and test dataloaders."""
    data_config = config['data']
    task_config = config['task']

    train_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config['train_split_suffixes'],
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )

    val_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config['val_split_suffixes'],
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )

    test_dataset = fMRITaskDataset(
        data_root=data_config['data_root'],
        datasets=data_config['datasets'],
        split_suffixes=data_config.get('test_split_suffixes', ['test']),
        crop_length=data_config['input_seq_len'],
        label_csv_path=task_config['csv'],
        task_type=task_config['task_type']
    )

    # Create samplers
    if is_distributed:
        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True,
            seed=config['experiment']['seed']
        )
        val_sampler = DistributedSampler(val_dataset, num_replicas=world_size, rank=rank, shuffle=False)
        test_sampler = DistributedSampler(test_dataset, num_replicas=world_size, rank=rank, shuffle=False)
    else:
        train_sampler = None
        val_sampler = None
        test_sampler = None

    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=data_config['batch_size'],
        sampler=train_sampler,
        shuffle=(train_sampler is None),
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=data_config['batch_size'],
        sampler=val_sampler,
        shuffle=False,
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=False
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=data_config['batch_size'],
        sampler=test_sampler,
        shuffle=False,
        num_workers=data_config['num_workers'],
        pin_memory=data_config['pin_memory'],
        prefetch_factor=data_config.get('prefetch_factor', 2),
        drop_last=False
    )

    return train_loader, val_loader, test_loader, train_sampler


def train_one_epoch(model, train_loader, criterion, optimizer, scheduler, scaler, epoch, config,
                    rank, world_size, label_scaler=None, log_file=None):
    """Train for one epoch."""
    model.train()

    metric_logger = MetricLogger(delimiter="  ")
    header = f'Epoch: [{epoch}]'

    train_config = config['training']
    log_config = config['logging']
    task_config = config['task']

    accum_iter = train_config['accum_iter']
    use_amp = train_config['use_amp']
    clip_grad = train_config.get('clip_grad', None)

    optimizer.zero_grad()

    for data_iter_step, (samples, labels) in enumerate(metric_logger.log_every(train_loader, log_config['print_freq'], header)):
        # Adjust learning rate once per effective (accumulated) step
        if data_iter_step % accum_iter == 0:
            scheduler.step()

        # Move data to GPU
        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        # Forward pass with mixed precision
        with autocast(enabled=use_amp):
            outputs = model(samples)

            # Calculate loss based on task type
            if task_config['task_type'] == 'classification':
                if labels.dim() > 1:
                    labels = labels.squeeze()

                loss = criterion(outputs, labels)
                # Calculate accuracy
                _, predicted = outputs.max(1)
                correct = predicted.eq(labels).sum().item()
                accuracy = correct / labels.size(0)
            else:  # regression
                if label_scaler is not None:
                    target_for_loss = label_scaler.transform(labels)
                else:
                    target_for_loss = labels
                loss = criterion(outputs.squeeze(), target_for_loss.squeeze())
                accuracy = 0.0  # Not applicable for regression

            loss = loss / accum_iter

        # Backward pass
        if use_amp:
            scaler.scale(loss).backward()

            if (data_iter_step + 1) % accum_iter == 0:
                if clip_grad is not None:
                    scaler.unscale_(optimizer)
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
        else:
            loss.backward()

            if (data_iter_step + 1) % accum_iter == 0:
                if clip_grad is not None:
                    nn.utils.clip_grad_norm_(model.parameters(), clip_grad)
                optimizer.step()
                optimizer.zero_grad()

        # Abort on divergence
        loss_value = loss.item() * accum_iter
        if not np.isfinite(loss_value):
            print(f"Loss is {loss_value}, stopping training")
            sys.exit(1)

        metric_logger.update(loss=loss_value)
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        if task_config['task_type'] == 'classification':
            metric_logger.update(acc=accuracy)

    # Gather stats from all processes
    metric_logger.synchronize_between_processes()
    print(f"Averaged stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


@torch.no_grad()
def evaluate(model, data_loader, criterion, config, rank, epoch=None, label_scaler=None, mode='val'):
    model.eval()
    metric_logger = MetricLogger(delimiter="  ")
    header = f'{mode.capitalize()} Epoch: [{epoch}]' if epoch is not None else f'{mode.capitalize()}:'

    task_type = config['task']['task_type']

    all_preds, all_targets = [], []

    for samples, labels in metric_logger.log_every(data_loader, 50, header):
        samples = samples.cuda(rank, non_blocking=True)
        labels = labels.cuda(rank, non_blocking=True)

        outputs = model(samples)

        if task_type == 'classification':
            labels = labels.squeeze().long() if labels.dim() > 1 else labels.long()
            loss = criterion(outputs, labels)

            preds = outputs.argmax(1)
            acc = (preds == labels).float().mean().item()
            metric_logger.update(loss=loss.item(), acc=acc)

            all_preds.append(preds.cpu())
            all_targets.append(labels.cpu())

        else:
            # Fall back to raw labels when no scaler is configured
            if label_scaler is not None:
                target_norm = label_scaler.transform(labels)
            else:
                target_norm = labels
            loss = criterion(outputs.view(-1), target_norm.view(-1))

            metric_logger.update(loss=loss.item())
            all_preds.append(outputs.detach().cpu().view(-1))
            all_targets.append(target_norm.detach().cpu().view(-1))

    if len(all_preds) > 0:
        all_preds = torch.cat(all_preds)
        all_targets = torch.cat(all_targets)

        if task_type == 'classification':
            f1 = f1_score(all_targets.numpy(), all_preds.numpy(), average='weighted')
            metric_logger.update(f1=f1)
        else:
            mse = torch.mean((all_preds - all_targets) ** 2).item()
            mae = torch.mean(torch.abs(all_preds - all_targets)).item()

            ss_res = torch.sum((all_targets - all_preds) ** 2)
            ss_tot = torch.sum((all_targets - all_targets.mean()) ** 2)
            r2 = (1 - ss_res / (ss_tot + 1e-8)).item()

            vx = all_preds - all_preds.mean()
            vy = all_targets - all_targets.mean()
            corr = (torch.sum(vx * vy) / (torch.sqrt(torch.sum(vx**2)) * torch.sqrt(torch.sum(vy**2)) + 1e-8)).item()

            metric_logger.update(mse=mse, mae=mae, r2=r2, corr=corr)

    metric_logger.synchronize_between_processes()

    if rank == 0:
        print(f"[{mode.upper()}] Global stats: {metric_logger}")

    return {k: meter.global_avg for k, meter in metric_logger.meters.items()}


def main():
    """Main fine-tuning function."""
    # Parse arguments
    parser = argparse.ArgumentParser(description='Hiera MAE 4D fMRI Downstream Fine-tuning')
    parser.add_argument('--config', type=str, default='configs/finetune_config.yaml',
                        help='Path to config file')
    parser.add_argument('--resume', type=str, default=None,
                        help='Path to checkpoint to resume from')
    parser.add_argument('--output_dir', type=str, default=None,
                        help='Output directory (overrides config)')
    args = parser.parse_args()

    # Load config
    config = load_config(args.config)

    # Override config with command-line arguments
    if args.resume is not None:
        config['experiment']['resume'] = args.resume
    if args.output_dir is not None:
        config['experiment']['output_dir'] = args.output_dir

    # Setup distributed training
    is_distributed, rank, world_size, gpu = setup_distributed()

    # Set random seed
    set_seed(config['experiment']['seed'], rank)

    # Create output directories (rank 0 only)
    if rank == 0:
        output_dir = Path(config['experiment']['output_dir'])
        checkpoint_dir = output_dir / 'checkpoints'
        log_dir = output_dir / 'logs'

        output_dir.mkdir(parents=True, exist_ok=True)
        checkpoint_dir.mkdir(parents=True, exist_ok=True)
        log_dir.mkdir(parents=True, exist_ok=True)

        # Save config
        with open(output_dir / 'config.yaml', 'w') as f:
            yaml.dump(config, f, default_flow_style=False)

        # Setup text log file
        log_file = output_dir / 'training_log.txt'
        with open(log_file, 'w') as f:
            f.write(f"Fine-tuning started at {datetime.datetime.now()}\n")
            f.write("=" * 80 + "\n")
            f.write(f"Config: {args.config}\n")
            f.write(f"Output directory: {config['experiment']['output_dir']}\n")
            f.write(f"Task type: {config['task']['task_type']}\n")
            f.write("=" * 80 + "\n\n")
    else:
        checkpoint_dir = None
        log_file = None

    if is_distributed:
        dist.barrier()

    model = create_model(config)
    model = model.cuda(gpu)

    if rank == 0:
        print("\nAnalyzing model architecture...")
        count_parameters(model, verbose=True)

    if is_distributed:
        model = DDP(model, device_ids=[gpu], find_unused_parameters=True)

    model_without_ddp = model.module if is_distributed else model

    if rank == 0:
        print("Creating dataloaders...")
    train_loader, val_loader, test_loader, train_sampler = create_dataloaders(
        config, is_distributed, rank, world_size
    )

    # For regression, labels are standardized with the mean/std from the config
    label_scaler = None
    if config['task']['task_type'] == 'regression':
        mean_val = config['task']['mean']
        scale_val = config['task']['std']
        if rank == 0:
            print(f"Label standardization. Mean: {mean_val:.4f}, Std: {scale_val:.4f}")

        norm_mean = torch.tensor(mean_val, device=gpu, dtype=torch.float32)
        norm_std = torch.tensor(scale_val, device=gpu, dtype=torch.float32)

        if is_distributed:
            dist.broadcast(norm_mean, src=0)
            dist.broadcast(norm_std, src=0)

        label_scaler = LabelScaler(norm_mean, norm_std)

    if rank == 0:
        print(f"Training samples: {len(train_loader.dataset)}")
        print(f"Validation samples: {len(val_loader.dataset)}")
        print(f"Test samples: {len(test_loader.dataset)}")
        print(f"Batches per epoch: {len(train_loader)}")

    # Create loss criterion
    task_config = config['task']
    if task_config['task_type'] == 'classification':
        criterion = nn.CrossEntropyLoss(label_smoothing=0.0)
    else:  # regression
        criterion = nn.MSELoss()

    # Optionally freeze the encoder
    if config['training'].get('freeze_encoder', False):
        if rank == 0:
            print("Freezing encoder weights. Only the head will be trained.")
        for name, param in model_without_ddp.named_parameters():
            if 'head' not in name:
                param.requires_grad = False

        # Log which parameters are trainable
        if rank == 0:
            print("Trainable parameters:")
            for name, param in model_without_ddp.named_parameters():
                if param.requires_grad:
                    print(name)

    # Create optimizer and scheduler
    optimizer = create_optimizer(model_without_ddp, config)
    scheduler = create_lr_scheduler(optimizer, config, len(train_loader))

    # Create gradient scaler for mixed precision
    scaler = GradScaler() if config['training']['use_amp'] else None

    # Load checkpoint if resuming
    start_epoch = 0
    best_metric = 0.0           # For classification: accuracy
    best_loss = float('inf')    # For regression: loss

    if config['experiment'].get('resume', None) is not None:
        start_epoch, best_metric, best_loss = load_checkpoint(
            config['experiment']['resume'],
            model_without_ddp,
            optimizer,
            scheduler,
            scaler
        )
        print(f"Resumed from epoch {start_epoch}. Best metric: {best_metric:.4f}, Best loss: {best_loss:.4f}")
    else:
        # Initialize best_metric for a new run based on the task
        if config['task']['task_type'] == 'classification':
            best_metric = 0.0           # Accuracy starts at 0
        else:  # regression
            best_metric = float('inf')  # Lower loss is better

    # Training loop
    if rank == 0:
        print("Starting fine-tuning...")
        print(f"Training from epoch {start_epoch} to {config['training']['epochs']}")

    for epoch in range(start_epoch, config['training']['epochs']):
        if is_distributed and train_sampler is not None:
            train_sampler.set_epoch(epoch)

        # Train for one epoch
        train_stats = train_one_epoch(
            model, train_loader, criterion, optimizer, scheduler, scaler,
            epoch, config, rank, world_size, label_scaler, log_file
        )

        # Log training stats
        if rank == 0:
            log_msg = f"Epoch {epoch} Training - "
            log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in train_stats.items()])
            print(log_msg)
            log_to_file(log_file, log_msg)

        # Validate
        if epoch % config['validation']['val_freq'] == 0 or epoch == config['training']['epochs'] - 1:
            val_stats = evaluate(
                model, val_loader, criterion, config, rank, epoch, label_scaler, 'val'
            )
            test_stats = evaluate(model, test_loader, criterion, config, rank, epoch, label_scaler, 'test')

            # Log validation stats
            if rank == 0:
                log_msg = f"Epoch {epoch} Validation - "
                log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in val_stats.items()])
                print(log_msg)
```
501
+ log_to_file(log_file, log_msg)
502
+
503
+ log_msg = f"Epoch {epoch} Test - "
504
+ log_msg += " | ".join([f"{k}: {v:.4f}" for k, v in test_stats.items()])
505
+ print(log_msg)
506
+ log_to_file(log_file, log_msg)
507
+
508
+ # Determine best model based on task type
509
+ if rank == 0:
510
+ if task_config['task_type'] == 'classification':
511
+ # For classification, higher accuracy is better
512
+ current_metric = val_stats.get('acc', 0.0)
513
+ is_best = current_metric > best_metric
514
+ if is_best:
515
+ best_metric = current_metric
516
+ best_loss = val_stats['loss']
517
+ else:
518
+ # For regression, lower loss is better
519
+ is_best = val_stats['loss'] < best_loss
520
+ if is_best:
521
+ best_loss = val_stats['loss']
522
+ best_metric = -best_loss # Store negative loss as metric
523
+
524
+ checkpoint_state = {
525
+ 'epoch': epoch + 1,
526
+ 'model_state_dict': model_without_ddp.state_dict(),
527
+ 'optimizer_state_dict': optimizer.state_dict(),
528
+ 'scheduler_state_dict': scheduler.state_dict(),
529
+ 'best_metric': best_metric,
530
+ 'best_loss': best_loss,
531
+ 'config': config,
532
+ 'train_stats': train_stats,
533
+ 'val_stats': val_stats,
534
+ }
535
+
536
+ if scaler is not None:
537
+ checkpoint_state['scaler_state_dict'] = scaler.state_dict()
538
+
539
+ save_checkpoint(
540
+ checkpoint_state,
541
+ is_best,
542
+ checkpoint_dir,
543
+ filename=f'checkpoint_epoch_{epoch}.pth'
544
+ )
545
+
546
+ checkpoint_msg = f"Checkpoint saved at epoch {epoch}"
547
+ print(checkpoint_msg)
548
+ log_to_file(log_file, checkpoint_msg)
549
+
550
+ if is_best:
551
+ if task_config['task_type'] == 'classification':
552
+ best_msg = f"New best validation accuracy: {best_metric:.4f}"
553
+ else:
554
+ best_msg = f"New best validation loss: {best_loss:.4f}"
555
+ print(best_msg)
556
+ log_to_file(log_file, best_msg)
557
+
558
+ # Save periodic checkpoint
559
+ if rank == 0 and (epoch + 1) % config['logging']['save_freq'] == 0:
560
+ checkpoint_state = {
561
+ 'epoch': epoch + 1,
562
+ 'model_state_dict': model_without_ddp.state_dict(),
563
+ 'optimizer_state_dict': optimizer.state_dict(),
564
+ 'scheduler_state_dict': scheduler.state_dict(),
565
+ 'best_metric': best_metric,
566
+ 'best_loss': best_loss,
567
+ 'config': config,
568
+ }
569
+
570
+ if scaler is not None:
571
+ checkpoint_state['scaler_state_dict'] = scaler.state_dict()
572
+
573
+ save_checkpoint(
574
+ checkpoint_state,
575
+ False,
576
+ checkpoint_dir,
577
+ filename=f'checkpoint_epoch_{epoch}.pth'
578
+ )
579
+
580
+
581
+ # Cleanup
582
+ cleanup_distributed()
583
+
584
+
585
+ if __name__ == '__main__':
586
+ main()
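For regression tasks, the entry point above standardizes labels with the mean/std taken from the config before the MSE loss is computed. A minimal standalone sketch of that round-trip, using plain floats and hypothetical stats (the `inverse_transform` helper is an assumption for illustration; the arithmetic is identical on torch tensors):

```python
class LabelScaler:
    """Mirror of the LabelScaler used during fine-tuning (plain-float version)."""
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def transform(self, labels):
        # Standardize: (y - mean) / std
        return [(y - self.mean) / self.std for y in labels]

    def inverse_transform(self, labels):
        # Undo the standardization to report metrics in original units
        # (assumed helper, not shown in this diff)
        return [z * self.std + self.mean for z in labels]

scaler = LabelScaler(mean=40.0, std=15.0)  # hypothetical stats, e.g. age in years
z = scaler.transform([25.0, 40.0, 70.0])
print(z)                            # [-1.0, 0.0, 2.0]
print(scaler.inverse_transform(z))  # [25.0, 40.0, 70.0]
```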
pipeline.png ADDED

Git LFS Details

  • SHA256: ee03c254d2b6c55dcf7639f9090e6b95d69489d7e5bfba9aa9ee686cc42c794e
  • Pointer size: 131 Bytes
  • Size of remote file: 632 kB
requirements.txt ADDED
@@ -0,0 +1,84 @@
+ attrs==25.3.0
+ certifi==2025.8.3
+ charset-normalizer==3.4.3
+ click==8.2.1
+ cloudpickle==3.1.1
+ contourpy==1.3.3
+ cycler==0.12.1
+ einops==0.8.1
+ et_xmlfile==2.0.0
+ filelock==3.18.0
+ fonttools==4.59.0
+ fsspec==2025.7.0
+ future==1.0.0
+ h5py==3.14.0
+ hf-xet==1.1.8
+ huggingface-hub==0.34.4
+ hyperopt==0.2.7
+ idna==3.10
+ Jinja2==3.1.6
+ joblib==1.5.1
+ jsonschema==4.25.0
+ jsonschema-specifications==2025.4.1
+ kiwisolver==1.4.8
+ lightning-utilities==0.15.0
+ MarkupSafe==3.0.2
+ matplotlib==3.10.3
+ mpmath==1.3.0
+ msgpack==1.1.1
+ networkx==3.5
+ nibabel==5.3.2
+ numpy==2.3.2
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-nccl-cu12==2.26.2
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
+ openpyxl==3.1.5
+ packaging==25.0
+ pandas==2.3.1
+ pillow==11.3.0
+ protobuf==6.32.0
+ psutil==7.0.0
+ py4j==0.10.9.9
+ pyaml==25.7.0
+ pyarrow==21.0.0
+ pyparsing==3.2.3
+ python-dateutil==2.9.0.post0
+ pytz==2025.2
+ PyYAML  # the original pip freeze pinned a local conda build path; any recent release works
+ pyzstd==0.17.0
+ ray==2.48.0
+ referencing==0.36.2
+ requests==2.32.4
+ rpds-py==0.27.0
+ safetensors==0.6.2
+ scikit-learn==1.7.1
+ scikit-optimize==0.10.2
+ scipy==1.16.1
+ seaborn==0.13.2
+ setuptools==78.1.1
+ six==1.17.0
+ sympy==1.14.0
+ threadpoolctl==3.6.0
+ timm==1.0.19
+ torch==2.7.1
+ torchaudio==2.7.1
+ torchmetrics==1.8.0
+ torchsummary==1.5.1
+ torchvision==0.22.1
+ tqdm==4.67.1
+ triton==3.3.1
+ typing_extensions==4.14.1
+ tzdata==2025.2
+ urllib3==2.5.0
+ wheel==0.45.1
scripts/finetune.sh ADDED
@@ -0,0 +1,47 @@
+ #!/bin/bash
+
+ # Set environment variables
+ export CUDA_VISIBLE_DEVICES=3
+ export OMP_NUM_THREADS=1
+ export MKL_NUM_THREADS=1
+
+ # Configuration
+ CONFIG_FILE="/vePFS-0x0d/home/yewh/Hiera_MAE/configs/finetune_config.yaml"
+ NUM_GPUS=1  # Number of GPUs to use (must match CUDA_VISIBLE_DEVICES above)
+ MASTER_PORT=29503
+
+ # Optional: Output directory
+ OUTPUT_DIR="/vePFS-0x0d/home/yewh/Hiera_MAE/output/downstream/nki/age-lp3"
+
+ # Optional: Resume from checkpoint
+ # RESUME_CHECKPOINT="output/hiera_finetune/checkpoints/checkpoint_epoch_10.pth"
+
+ echo "Starting DDP fine-tuning with $NUM_GPUS GPUs..."
+ echo "Config: $CONFIG_FILE"
+ echo "Output directory: $OUTPUT_DIR"
+
+ # Launch training with torchrun (recommended for PyTorch >= 1.10)
+ if [ -z "$RESUME_CHECKPOINT" ]; then
+     # Start from scratch (or from a pretrained MAE checkpoint)
+     torchrun \
+         --standalone \
+         --nnodes=1 \
+         --nproc_per_node="$NUM_GPUS" \
+         --master_port="$MASTER_PORT" \
+         /vePFS-0x0d/home/yewh/Hiera_MAE/finetune.py \
+         --config "$CONFIG_FILE" \
+         --output_dir "$OUTPUT_DIR"
+ else
+     # Resume from checkpoint
+     torchrun \
+         --standalone \
+         --nnodes=1 \
+         --nproc_per_node="$NUM_GPUS" \
+         --master_port="$MASTER_PORT" \
+         /vePFS-0x0d/home/yewh/Hiera_MAE/finetune.py \
+         --config "$CONFIG_FILE" \
+         --output_dir "$OUTPUT_DIR" \
+         --resume "$RESUME_CHECKPOINT"
+ fi
+
+ echo "Fine-tuning completed!"
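The `-z` test in the script selects between a fresh run and a resumed one depending on whether `RESUME_CHECKPOINT` is set. The branching can be checked in isolation, with the `torchrun` invocation stubbed out by an echo (`launch` and the checkpoint path here are illustrative):

```shell
#!/bin/sh
# Same branch structure as scripts/finetune.sh, launch command stubbed out
launch() {
    if [ -z "$RESUME_CHECKPOINT" ]; then
        echo "fresh run"
    else
        echo "resume from $RESUME_CHECKPOINT"
    fi
}

unset RESUME_CHECKPOINT
launch    # prints: fresh run
RESUME_CHECKPOINT="checkpoints/epoch_10.pth"
launch    # prints: resume from checkpoints/epoch_10.pth
```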
utils/ddp.py ADDED
@@ -0,0 +1,46 @@
+ import os
+ import torch
+ import datetime
+ import numpy as np
+ import torch.distributed as dist
+
+ def setup_distributed():
+     """Initialize distributed training"""
+     if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
+         rank = int(os.environ['RANK'])
+         world_size = int(os.environ['WORLD_SIZE'])
+         gpu = int(os.environ['LOCAL_RANK'])
+     elif 'SLURM_PROCID' in os.environ:
+         rank = int(os.environ['SLURM_PROCID'])
+         gpu = rank % torch.cuda.device_count()
+         world_size = int(os.environ['SLURM_NTASKS'])
+     else:
+         print('Not using distributed mode')
+         return False, 0, 1, 0
+
+     torch.cuda.set_device(gpu)
+     dist.init_process_group(
+         backend='nccl',
+         init_method='env://',
+         world_size=world_size,
+         rank=rank,
+         timeout=datetime.timedelta(minutes=30)
+     )
+     dist.barrier()
+     return True, rank, world_size, gpu
+
+
+ def cleanup_distributed():
+     """Clean up distributed training"""
+     if dist.is_initialized():
+         dist.destroy_process_group()
+
+
+ def set_seed(seed, rank=0):
+     """Set random seeds for reproducibility (offset by process rank)"""
+     seed = seed + rank
+     torch.manual_seed(seed)
+     np.random.seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed(seed)
+         torch.cuda.manual_seed_all(seed)
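`set_seed` offsets the base seed by the process rank, so each DDP worker draws an independent but reproducible random stream. A small sketch of the same scheme using the stdlib RNG so it runs without torch (torch/numpy behave analogously):

```python
import random

def seeded_draw(seed, rank=0):
    # Same rank-offset idea as set_seed in utils/ddp.py, stdlib version
    random.seed(seed + rank)
    return random.random()

r0 = seeded_draw(42, rank=0)
r1 = seeded_draw(42, rank=1)
again = seeded_draw(42, rank=0)
print(r0 != r1)    # True: different ranks get different streams
print(r0 == again) # True: the same seed and rank reproduce the stream
```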
utils/optim.py ADDED
@@ -0,0 +1,90 @@
+ import torch
+ import numpy as np
+
+ def create_optimizer(model, config):
+     train_config = config['training']
+     base_lr = train_config['learning_rate']
+     weight_decay = train_config['weight_decay']
+
+     layer_decay = train_config.get('layer_decay', 0.8)
+
+     # Number of blocks, used to compute each layer's depth.
+     # Assumes model is a HieraClassifier whose encoder blocks live in self.blocks
+     num_layers = len(model.blocks) + 1  # +1 accounts for patch_embed
+
+     parameter_groups = []
+
+     # 1. Head group (the classification head usually gets the full base_lr)
+     head_lr = train_config.get('head_lr', base_lr)
+     parameter_groups.append({
+         "params": [p for n, p in model.named_parameters() if "head" in n],
+         "lr": head_lr,
+         "weight_decay": weight_decay
+     })
+
+     # 2. Encoder blocks (layer-wise learning-rate decay)
+     for i, block in enumerate(model.blocks):
+         # Deeper blocks (closer to the head) get higher learning rates:
+         # the last block (i = num_layers - 2) is scaled close to 1.0,
+         # the first block (i = 0) is scaled by layer_decay^(num_layers - 1)
+         scale = layer_decay ** (num_layers - i - 1)
+
+         parameter_groups.append({
+             "params": block.parameters(),
+             "lr": base_lr * scale,
+             "weight_decay": weight_decay
+         })
+
+     # 3. Patch embedding and other early layers (lowest learning rate)
+     earliest_params = []
+     for n, p in model.named_parameters():
+         if "patch_embed" in n or "encoder_norm" in n:
+             earliest_params.append(p)
+
+     if earliest_params:
+         parameter_groups.append({
+             "params": earliest_params,
+             "lr": base_lr * (layer_decay ** num_layers),
+             "weight_decay": weight_decay
+         })
+
+     if train_config['optimizer'].lower() == 'adamw':
+         optimizer = torch.optim.AdamW(
+             parameter_groups,
+             betas=tuple(train_config['betas']),
+             weight_decay=train_config['weight_decay']
+         )
+     elif train_config['optimizer'].lower() == 'sgd':
+         optimizer = torch.optim.SGD(
+             parameter_groups,
+             momentum=train_config.get('momentum', 0.9),
+             weight_decay=train_config['weight_decay']
+         )
+     else:
+         raise ValueError(f"Unsupported optimizer: {train_config['optimizer']}")
+
+     return optimizer
+
+
+ def create_lr_scheduler(optimizer, config, steps_per_epoch):
+     """Create learning rate scheduler"""
+     train_config = config['training']
+     total_steps = train_config['epochs'] * steps_per_epoch
+     warmup_steps = train_config['warmup_epochs'] * steps_per_epoch
+
+     if train_config['lr_scheduler'].lower() == 'cosine':
+         def lr_lambda(current_step):
+             if current_step < warmup_steps:
+                 # Linear warmup
+                 return float(current_step) / float(max(1, warmup_steps))
+             else:
+                 # Cosine annealing, floored at min_lr / learning_rate
+                 progress = float(current_step - warmup_steps) / float(max(1, total_steps - warmup_steps))
+                 return max(train_config['min_lr'] / train_config['learning_rate'],
+                            0.5 * (1.0 + np.cos(np.pi * progress)))
+
+         scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
+     else:
+         raise ValueError(f"Unsupported scheduler: {train_config['lr_scheduler']}")
+
+     return scheduler
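The cosine schedule above returns a multiplier on each group's base learning rate: linear ramp during warmup, then a cosine decay floored at `min_lr / learning_rate`. A numeric check of the same lambda with hypothetical config values (10 warmup steps out of 100 total, `min_lr=1e-6`, `learning_rate=1e-4`):

```python
import math

def lr_lambda(current_step, warmup_steps=10, total_steps=100,
              min_lr=1e-6, learning_rate=1e-4):
    # Same shape as create_lr_scheduler's cosine schedule (assumed config values)
    if current_step < warmup_steps:
        return current_step / max(1, warmup_steps)
    progress = (current_step - warmup_steps) / max(1, total_steps - warmup_steps)
    return max(min_lr / learning_rate, 0.5 * (1.0 + math.cos(math.pi * progress)))

print(lr_lambda(0))    # 0.0  (start of warmup)
print(lr_lambda(5))    # 0.5  (halfway through warmup)
print(lr_lambda(10))   # 1.0  (warmup done, cosine starts at the peak)
print(lr_lambda(100))  # ≈ 0.01 (floored at min_lr / learning_rate)
```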
utils/utils.py ADDED
@@ -0,0 +1,237 @@
+ import torch
+ import datetime
+ import time
+ import torch.distributed as dist
+ import yaml
+ import os
+
+ class MetricLogger:
+     """Metric logger for training"""
+     def __init__(self, delimiter="\t"):
+         self.meters = {}
+         self.delimiter = delimiter
+
+     def update(self, **kwargs):
+         for k, v in kwargs.items():
+             if isinstance(v, torch.Tensor):
+                 v = v.item()
+             if k not in self.meters:
+                 self.meters[k] = SmoothedValue()
+             self.meters[k].update(v)
+
+     def __str__(self):
+         loss_str = []
+         for name, meter in self.meters.items():
+             loss_str.append(f"{name}: {meter}")
+         return self.delimiter.join(loss_str)
+
+     def synchronize_between_processes(self):
+         for meter in self.meters.values():
+             meter.synchronize_between_processes()
+
+     def log_every(self, iterable, print_freq, header=None):
+         i = 0
+         if not header:
+             header = ''
+         start_time = time.time()
+         end = time.time()
+         iter_time = SmoothedValue(fmt='{avg:.4f}')
+         data_time = SmoothedValue(fmt='{avg:.4f}')
+         space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
+         log_msg = [
+             header,
+             '[{0' + space_fmt + '}/{1}]',
+             'eta: {eta}',
+             '{meters}',
+             'time: {time}',
+             'data: {data}'
+         ]
+         log_msg = self.delimiter.join(log_msg)
+         for obj in iterable:
+             data_time.update(time.time() - end)
+             yield obj
+             iter_time.update(time.time() - end)
+             if i % print_freq == 0 or i == len(iterable) - 1:
+                 eta_seconds = iter_time.global_avg * (len(iterable) - i)
+                 eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
+                 # Print from rank 0 only; guard dist.get_rank() for non-distributed runs
+                 if not dist.is_available() or not dist.is_initialized() or dist.get_rank() == 0:
+                     print(log_msg.format(
+                         i, len(iterable), eta=eta_string,
+                         meters=str(self),
+                         time=str(iter_time), data=str(data_time)))
+             i += 1
+             end = time.time()
+         total_time = time.time() - start_time
+         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+         print(f'{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)')
+
+
+ class SmoothedValue:
+     """Track a series of values and provide access to smoothed values"""
+     def __init__(self, window_size=20, fmt=None):
+         if fmt is None:
+             fmt = "{median:.4f} ({global_avg:.4f})"
+         self.deque = []
+         self.total = 0.0
+         self.count = 0
+         self.fmt = fmt
+         self.window_size = window_size
+
+     def update(self, value, n=1):
+         self.deque.append(value)
+         if len(self.deque) > self.window_size:
+             self.deque.pop(0)
+         self.count += n
+         self.total += value * n
+
+     def synchronize_between_processes(self):
+         """Synchronize count and total across all processes"""
+         if not dist.is_available() or not dist.is_initialized():
+             return
+         t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
+         dist.barrier()
+         dist.all_reduce(t)
+         t = t.tolist()
+         self.count = int(t[0])
+         self.total = t[1]
+
+     @property
+     def median(self):
+         d = sorted(self.deque)
+         n = len(d)
+         if n == 0:
+             return 0
+         if n % 2 == 0:
+             return (d[n // 2 - 1] + d[n // 2]) / 2
+         return d[n // 2]
+
+     @property
+     def avg(self):
+         if len(self.deque) == 0:
+             return 0
+         return sum(self.deque) / len(self.deque)
+
+     @property
+     def global_avg(self):
+         if self.count == 0:
+             return 0
+         return self.total / self.count
+
+     def __str__(self):
+         return self.fmt.format(
+             median=self.median,
+             avg=self.avg,
+             global_avg=self.global_avg,
+             max=max(self.deque) if len(self.deque) > 0 else 0,
+             value=self.deque[-1] if len(self.deque) > 0 else 0
+         )
+
+
+ def load_config(config_path):
+     """Load configuration from a YAML file"""
+     with open(config_path, 'r') as f:
+         config = yaml.safe_load(f)
+     return config
+
+
+ def log_to_file(log_file, message):
+     """Append a timestamped message to the log file"""
+     if log_file is not None:
+         with open(log_file, 'a') as f:
+             timestamp = datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')
+             f.write(f"[{timestamp}] {message}\n")
+             f.flush()
+
+
+ def count_parameters(model, verbose=True):
+     """Count model parameters"""
+     def count_params(module):
+         return sum(p.numel() for p in module.parameters() if p.requires_grad)
+
+     def format_number(num):
+         if num >= 1e9:
+             return f"{num/1e9:.2f}B"
+         elif num >= 1e6:
+             return f"{num/1e6:.2f}M"
+         elif num >= 1e3:
+             return f"{num/1e3:.2f}K"
+         else:
+             return str(num)
+
+     # If wrapped in DDP, unwrap the original model
+     if hasattr(model, 'module'):
+         model = model.module
+
+     total_params = count_params(model)
+
+     if verbose:
+         print("\n" + "="*80)
+         print("Model Parameter Statistics")
+         print("="*80)
+
+         # Count encoder parameters
+         encoder_params = 0
+         for name in ['patch_embed', 'blocks', 'encoder_norm']:
+             if hasattr(model, name):
+                 module = getattr(model, name)
+                 params = count_params(module)
+                 encoder_params += params
+                 print(f"{name:.<35} {params:>15,} ({format_number(params):>8})")
+
+         # Count head parameters
+         if hasattr(model, 'head'):
+             head_params = count_params(model.head)
+             print(f"{'Classification/Regression Head':.<35} {head_params:>15,} ({format_number(head_params):>8})")
+
+         print("\n" + "="*80)
+         print(f"{'Encoder Parameters':.<35} {encoder_params:>15,} ({format_number(encoder_params):>8})")
+         print(f"{'TOTAL TRAINABLE PARAMETERS':.<35} {total_params:>15,} ({format_number(total_params):>8})")
+         print("="*80 + "\n")
+
+     return total_params
+
+
+ def save_checkpoint(state, is_best, checkpoint_dir, filename='checkpoint.pth'):
+     """Save checkpoint"""
+     checkpoint_path = os.path.join(checkpoint_dir, filename)
+     torch.save(state, checkpoint_path)
+     if is_best:
+         best_path = os.path.join(checkpoint_dir, 'checkpoint_best.pth')
+         torch.save(state, best_path)
+
+
+ def load_checkpoint(checkpoint_path, model, optimizer, scheduler, scaler=None):
+     """Load checkpoint"""
+     if not os.path.isfile(checkpoint_path):
+         print(f"No checkpoint found at '{checkpoint_path}'")
+         return 0, 0.0, float('inf')
+
+     print(f"Loading checkpoint '{checkpoint_path}'")
+     checkpoint = torch.load(checkpoint_path, map_location='cpu')
+
+     start_epoch = checkpoint['epoch']
+     best_metric = checkpoint.get('best_metric', 0.0)
+     best_loss = checkpoint.get('best_loss', float('inf'))
+
+     model.load_state_dict(checkpoint['model_state_dict'])
+     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+     scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
+
+     if scaler is not None and 'scaler_state_dict' in checkpoint:
+         scaler.load_state_dict(checkpoint['scaler_state_dict'])
+
+     print(f"Loaded checkpoint from epoch {start_epoch}")
+     return start_epoch, best_metric, best_loss
+
+
+ class LabelScaler:
+     def __init__(self, mean, std):
+         self.mean = mean
+         self.std = std
+
+     def transform(self, labels):
+         """Standardize: (y - mean) / std"""
+         return (labels - self.mean) / self.std
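`SmoothedValue` reports a median over a sliding window (robust to per-step spikes) alongside a running global average. A stripped-down, torch-free copy of the same logic (distributed sync omitted) shows the difference on a short sequence:

```python
# Minimal re-declaration of SmoothedValue's windowing logic for illustration
class SmoothedValue:
    def __init__(self, window_size=3):
        self.deque, self.total, self.count = [], 0.0, 0
        self.window_size = window_size

    def update(self, value, n=1):
        self.deque.append(value)
        if len(self.deque) > self.window_size:
            self.deque.pop(0)  # keep only the last window_size values
        self.count += n
        self.total += value * n

    @property
    def median(self):
        d = sorted(self.deque)
        n = len(d)
        return (d[n // 2 - 1] + d[n // 2]) / 2 if n % 2 == 0 else d[n // 2]

    @property
    def global_avg(self):
        return self.total / self.count

m = SmoothedValue(window_size=3)
for v in [1.0, 2.0, 3.0, 10.0]:
    m.update(v)
print(m.median)      # 3.0: median over the last 3 values [2.0, 3.0, 10.0]
print(m.global_avg)  # 4.0: average over all 4 values
```

The windowed median shrugs off the 10.0 outlier that pulls the global average up, which is why the default `fmt` prints both.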