| import csv |
| import h5py |
| import torch |
| import torch.nn as nn |
| import random |
| import numpy as np |
| import os |
| import shutil |
| import pandas as pd |
| from torchvision import transforms |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler, Subset, random_split |
| import torch.optim as optim |
| import time |
| from tqdm import tqdm |
| from torch.optim import lr_scheduler |
| from transformers import ViTFeatureExtractor, AutoImageProcessor, ViTMAEConfig, ViTMAEModel, ViTMAEForPreTraining |
| from torchvision.datasets import ImageFolder |
| import lightning.pytorch as pl |
| from lightning.pytorch import Trainer |
| from lightning.pytorch.callbacks import ModelCheckpoint, LearningRateMonitor, RichProgressBar |
| from lightning.pytorch.loggers import TensorBoardLogger |
| from lightning.pytorch.callbacks import RichProgressBar |
| from lightning.pytorch.callbacks import TQDMProgressBar |
| from lightning.pytorch.utilities import rank_zero_only |
|
|
# Number of visible CUDA devices; used below for the Trainer's `devices` count.
DEVICE_NUM = torch.cuda.device_count()
# NOTE(review): setting CUDA_VISIBLE_DEVICES *after* torch.cuda.device_count()
# likely has no effect in this process (the variable is read when CUDA is
# initialized), and listing every device index is a no-op anyway — confirm intent.
os.environ['CUDA_VISIBLE_DEVICES'] = ','.join([str(i) for i in range(DEVICE_NUM)])


# Global experiment configuration.
SEED = 42                      # passed to pl.seed_everything() before training
DATA_DIR = "../../0.data/pretrain_nucleus_image_all_16M.hdf5"  # HDF5 file with an 'images' dataset
BATCH_SIZE = 400 *2            # per-device training batch size (validation uses 3x)
NUM_EPOCHS = 70
LEARNINGRATE = 0.0001          # base AdamW learning rate (reached after warmup)
PROJECT_NAME = 'Nuspire_Pretraining_V5'
|
|
# Augmentation + normalization pipeline for single-channel nucleus images.
transform = transforms.Compose([
    transforms.Grayscale(),  # force 1 channel (matches num_channels=1 in the MAE config)
    # Random crop covering 56.25%-100% of the source area, resized to the 112x112 model input.
    transforms.RandomResizedCrop((112, 112), scale=(0.5625, 1.0), ratio=(0.75, 1.33)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomVerticalFlip(p=0.5),
    transforms.ToTensor(),
    # Single-channel mean/std — presumably precomputed over the training corpus; verify.
    transforms.Normalize(mean=[0.21869252622127533], std=[0.1809280514717102])
])
|
|
# ViT-MAE architecture: a ViT-Base-sized encoder (768 hidden / 12 layers / 12 heads)
# over 112x112 single-channel images with 8x8 patches (196 patches total), paired
# with a lighter 512-wide, 8-layer decoder as in the MAE design.
configuration = ViTMAEConfig(
    hidden_size=768,
    num_hidden_layers=12,
    num_attention_heads=12,
    intermediate_size=3072,
    hidden_act="gelu",
    hidden_dropout_prob=0.0,
    attention_probs_dropout_prob=0.0,
    initializer_range=0.02,
    layer_norm_eps=1e-12,
    image_size=112,   # must match the RandomResizedCrop output size above
    patch_size=8,
    num_channels=1,   # grayscale input (Grayscale() in the transform)
    qkv_bias=True,
    decoder_num_attention_heads=16,
    decoder_hidden_size=512,
    decoder_num_hidden_layers=8,
    decoder_intermediate_size=1024,
    mask_ratio=0.75,  # fraction of patches masked during pretraining
    norm_pix_loss=False
)
|
|
class HDF5Dataset(Dataset):
    """Map-style dataset backed by an HDF5 file containing an 'images' array.

    The h5py handle is opened lazily in each consuming process rather than in
    __init__: h5py file objects are not safe to share across the processes a
    DataLoader forks when num_workers > 0, and a handle inherited through fork
    can return corrupted reads. __init__ only records the dataset length.

    Args:
        hdf5_path: path to the HDF5 file; must contain an 'images' dataset.
        transform: optional torchvision-style callable applied to a PIL image.
    """

    def __init__(self, hdf5_path, transform=None):
        self.hdf5_path = hdf5_path
        self.transform = transform
        self.hdf5_file = None   # opened on first __getitem__, once per process
        self.images = None
        # Open briefly just to cache the length, then close so no live handle
        # is carried across the fork into DataLoader worker processes.
        with h5py.File(hdf5_path, 'r') as f:
            self._length = len(f['images'])

    def _ensure_open(self):
        # Lazily (re)open the file in the current process. The raw-data chunk
        # cache is sized generously (10 GiB) for this large read-only dataset.
        if self.hdf5_file is None:
            self.hdf5_file = h5py.File(self.hdf5_path, 'r', rdcc_nbytes=10*1024**3, rdcc_w0=0.0, rdcc_nslots=10007)
            self.images = self.hdf5_file['images']

    def __len__(self):
        return self._length

    def __getitem__(self, idx):
        if self.images is None:
            self._ensure_open()
        img = self.images[idx]

        if self.transform:
            img = Image.fromarray(img)
            img = self.transform(img)

        return img

    def __del__(self):
        # Guard: __init__ may have failed before hdf5_file was assigned, and
        # close() can raise during interpreter shutdown.
        try:
            if self.hdf5_file is not None:
                self.hdf5_file.close()
        except Exception:
            pass
|
|
class NucleusDataModule(pl.LightningDataModule):
    """LightningDataModule wrapping one dataset with an 80/20 train/val split.

    Args:
        dataset: map-style Dataset yielding one image tensor per index.
        batch_size: per-device training batch size (validation uses 3x).
    """

    def __init__(self, dataset, batch_size):
        super().__init__()
        self.dataset = dataset
        self.batch_size = batch_size

    def setup(self, stage=None):
        # 80/20 random split; reproducible because pl.seed_everything() runs
        # before fit() and every DDP rank uses the same seed.
        train_size = int(0.8 * len(self.dataset))
        test_size = len(self.dataset) - train_size
        self.train_dataset, self.test_dataset = random_split(self.dataset, [train_size, test_size])

    def train_dataloader(self):
        # Fix: shuffle=True was missing — DataLoader defaults to sequential
        # order, so every epoch iterated the samples identically.
        # persistent_workers avoids respawning 16 workers at each epoch start.
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True,
                          num_workers=16, pin_memory=True, prefetch_factor=5,
                          persistent_workers=True)

    def val_dataloader(self):
        # Larger batches are fine for the no-grad validation pass.
        return DataLoader(self.test_dataset, batch_size=self.batch_size * 3,
                          num_workers=16, pin_memory=True, prefetch_factor=5,
                          persistent_workers=True)
|
|
class ViTMAEPreTraining(pl.LightningModule):
    """LightningModule that pretrains a ViT-MAE from scratch with the given config."""

    def __init__(self, configuration):
        super().__init__()
        self.model = ViTMAEForPreTraining(configuration)
        self.save_hyperparameters()

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        # Lightning already moves the batch to self.device; no manual .to() needed.
        outputs = self.model(batch)
        loss = outputs.loss  # MAE reconstruction loss over the masked patches
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def validation_step(self, batch, batch_idx):
        outputs = self.model(batch)
        loss = outputs.loss
        self.log('val_loss', loss, on_step=False, on_epoch=True, prog_bar=True, logger=True, sync_dist=True)
        return loss

    def configure_optimizers(self):
        optimizer = torch.optim.AdamW(self.model.parameters(), lr=LEARNINGRATE)
        warmup_epochs = 10
        # Linear warmup. Uses (epoch + 1) so the factor is 1/warmup_epochs in
        # epoch 0 — the original `epoch / warmup_epochs` evaluated to 0 there,
        # making the entire first epoch train with lr = 0.
        warmup_factor = lambda epoch: min(1.0, (epoch + 1) / warmup_epochs)
        scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(optimizer, warmup_factor)
        # After warmup, halve the learning rate every 20 epochs.
        scheduler_regular = torch.optim.lr_scheduler.StepLR(optimizer, 20, gamma=0.5)
        scheduler = {
            'scheduler': torch.optim.lr_scheduler.SequentialLR(
                optimizer, schedulers=[scheduler_warmup, scheduler_regular],
                milestones=[warmup_epochs]),
            'interval': 'epoch',
            'frequency': 1
        }
        return [optimizer], [scheduler]
|
|
class EpochLoggingCallback(pl.Callback):
    """Writes per-epoch train/val losses to one combined TensorBoard scalar group."""

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        train_loss = metrics.get('train_loss')
        val_loss = metrics.get('val_loss')
        # Both metrics exist only once a full train+val epoch has run
        # (e.g. the initial sanity-check validation has no train_loss yet).
        if train_loss is None or val_loss is None:
            return
        trainer.logger.experiment.add_scalars(
            "Epoch/Loss",
            {'Train Loss': train_loss, 'Validation Loss': val_loss},
            trainer.current_epoch,
        )
|
|
class SaveEpochModelCallback(pl.Callback):
    """Exports the wrapped HF model in `save_pretrained` format after each validation epoch."""

    @rank_zero_only
    def on_validation_epoch_end(self, trainer, pl_module):
        # Mirror the checkpoint directory layout: <ckpt_dir>/epoch<N>/
        out_dir = f'{trainer.checkpoint_callback.dirpath}/epoch{trainer.current_epoch}'
        pl_module.model.save_pretrained(out_dir)
|
|
# Seed python/numpy/torch (and, with workers=True, DataLoader workers) BEFORE
# any stochastic component is built — weight init, the train/val random_split,
# and the augmentation pipeline all depend on it. The original called this
# last, after the dataset and callbacks were already constructed.
pl.seed_everything(SEED, workers=True)


dataset = HDF5Dataset(hdf5_path=DATA_DIR, transform=transform)


data_module = NucleusDataModule(dataset, BATCH_SIZE)


# Callbacks and logging.
epoch_logging_callback = EpochLoggingCallback()


save_epoch_model_callback = SaveEpochModelCallback()


progress_bar = RichProgressBar()


logger = TensorBoardLogger(save_dir=f'./{PROJECT_NAME}_outputs', name="tensorboard")


# Keep the three checkpoints with the lowest validation loss.
best_model_callback = ModelCheckpoint(
    dirpath=f'./{PROJECT_NAME}_outputs/model',
    filename='{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
    monitor='val_loss'
)


lr_monitor = LearningRateMonitor(logging_interval='epoch')


# Multi-GPU DDP training across all visible devices.
trainer = Trainer(
    max_epochs=NUM_EPOCHS,
    devices=DEVICE_NUM,
    accelerator='gpu',
    strategy='ddp',
    logger=logger,
    callbacks=[lr_monitor,
               progress_bar,
               epoch_logging_callback,
               save_epoch_model_callback,
               best_model_callback]
)


model = ViTMAEPreTraining(configuration)
trainer.fit(model, data_module)
|
|
|
|
|
|