import csv
import os

import torch
from dotenv import load_dotenv
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.optim.lr_scheduler import (CosineAnnealingWarmRestarts, LinearLR,
                                      SequentialLR)
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from checkpoint import CheckpointManager
from dataset import ImageNetDataset
from eval_in_training import eval_model
from pipeline import Painter


def train_model(
    model: Painter,
    optimizer: torch.optim.Optimizer,
    scheduler,
    batch_size: int,
    accum_steps: int,
    train_dataset: ImageNetDataset,
    val_dataset: ImageNetDataset,
    device: torch.device,
    n_epochs: int,
    dataset_chunk_size: int,
):
    model.to(device)
    scaler = GradScaler()

    # Resume from the latest checkpoint, if any. ckpt_mgr, output_dir and
    # train_log_path are module-level names defined under __main__ below.
    start_epoch, start_iter = 0, 0
    checkpoint_epoch, checkpoint_iter = ckpt_mgr.load(
        model, scaler, optimizer, scheduler)
    if checkpoint_epoch == 0 and checkpoint_iter == 0:
        pass  # fresh run
    elif checkpoint_iter == len(train_dataset) - 1:
        # The checkpointed epoch finished; start the next one from the top.
        start_epoch = checkpoint_epoch + 1
        start_iter = 0
    else:
        # Resume mid-epoch, right after the last checkpointed sample.
        start_epoch = checkpoint_epoch
        start_iter = checkpoint_iter + 1
    print(
        f"Begin training from epoch {start_epoch}, iter {start_iter}/{len(train_dataset) - 1}")
    end_epoch = start_epoch + n_epochs

    try:
        for epoch in range(start_epoch, end_epoch):
            # Walk through the epoch in shards of dataset_chunk_size samples,
            # checkpointing, validating and logging after each shard.
            index = start_iter
            while index < len(train_dataset):
                indices = list(range(index, min(
                    index + dataset_chunk_size, len(train_dataset))))
                print(f"Training indices: {indices[0]} - {indices[-1]}")
                partial_train_dataset = Subset(train_dataset, indices)
                train_dataloader = DataLoader(
                    partial_train_dataset,
                    batch_size=batch_size,
                    shuffle=True,  # only shuffle within the current shard
                    num_workers=min(4, batch_size),
                )
                val_dataloader = DataLoader(
                    val_dataset,
                    batch_size=batch_size,
                    shuffle=False,
                    num_workers=min(4, batch_size),
                )

                model.train()
                print(f"Learning rate: {scheduler.get_last_lr()}")
                optimizer.zero_grad()
                train_bar = tqdm(
                    train_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Train]", ncols=0)
                loss_metric = {
                    'train': {'total': 0.0, 'mse': 0.0},
                    'val': {'total': 0.0, 'mse': 0.0},
                }
                shard_start = indices[0]
                shard_size = len(indices)
                shard_end_exclusive = shard_start + shard_size
                total_train_samples = 0

                for batch_i, imgs in enumerate(train_bar):
                    batch_n = imgs.size(0)
                    batch_start = shard_start + batch_i * batch_size
                    batch_end_exclusive = batch_start + batch_n
                    imgs = imgs.to(device, non_blocking=True)

                    with autocast(device_type=device.type):
                        out = model(target_img=imgs, train=True)
                        mse_loss = out['mse_loss']
                        total_loss = mse_loss

                    loss_metric['train']['total'] += total_loss.item() * batch_n
                    loss_metric['train']['mse'] += mse_loss.item() * batch_n
                    total_train_samples += batch_n

                    # Gradient accumulation: scale the loss down so the summed
                    # gradients match one batch of batch_size * accum_steps.
                    loss_to_backward = total_loss / accum_steps
                    scaler.scale(loss_to_backward).backward()

                    is_accum_step = ((batch_i + 1) % accum_steps == 0)
                    is_last_batch_in_shard = (
                        batch_end_exclusive >= shard_end_exclusive)
                    if is_accum_step or is_last_batch_in_shard:
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        scheduler.step()

                    train_bar.set_postfix({
                        'loss': f"{total_loss.item():.4f}",
                        'mse': f"{mse_loss.item():.4f}",
                    })

                    if batch_i % 10000 == 0:  # includes batch 0
                        model.eval()
                        eval_model(model, val_dataloader, epoch=epoch,
                                   step=batch_start, output_dir=output_dir)
                        torch.cuda.empty_cache()
                        model.train()
                    if batch_i % 500 == 0:
                        torch.cuda.empty_cache()

                # Checkpoint at the last sample trained on in this shard.
                last_sample_idx = shard_start + total_train_samples - 1
                ckpt_mgr.save(model, scaler, optimizer, scheduler, epoch,
                              last_sample_idx)

                avg_train_metric = {k: v / total_train_samples
                                    for k, v in loss_metric['train'].items()}
                print(avg_train_metric)

                model.eval()
                total_val_samples = 0
                with torch.no_grad(), autocast(device_type=device.type):
                    val_bar = tqdm(
                        val_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Val]", ncols=0)
                    for imgs in val_bar:
                        batch_n = imgs.size(0)
                        imgs = imgs.to(device, non_blocking=True)
                        out = model(target_img=imgs)
                        mse_loss = out['mse_loss']
                        total_loss = mse_loss
                        loss_metric['val']['total'] += total_loss.item() * batch_n
                        loss_metric['val']['mse'] += mse_loss.item() * batch_n
                        total_val_samples += batch_n
                avg_val_metric = {k: v / total_val_samples
                                  for k, v in loss_metric['val'].items()}

                # Append shard-level metrics to the CSV training log.
                write_header = not os.path.exists(train_log_path)
                with open(train_log_path, mode="a", newline="") as csvfile:
                    writer = csv.DictWriter(csvfile, fieldnames=[
                        "epoch", "iter",
                        "train_total_loss", "train_mse_loss",
                        "val_total_loss", "val_mse_loss",
                    ])
                    if write_header:
                        writer.writeheader()
                    writer.writerow({
                        "epoch": epoch,
                        "iter": indices[-1],
                        "train_total_loss": avg_train_metric["total"],
                        "train_mse_loss": avg_train_metric["mse"],
                        "val_total_loss": avg_val_metric["total"],
                        "val_mse_loss": avg_val_metric["mse"],
                    })

                # Advance to the next shard.
                index = shard_end_exclusive
            # Only the resumed epoch starts mid-dataset; later epochs start at 0.
            start_iter = 0
    except Exception:
        # Best-effort save of the model weights before re-raising.
        checkpoint_dir = os.path.dirname(
            os.path.abspath(__file__)) + "/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({"model": model.state_dict()},
                   os.path.join(checkpoint_dir, "ERROR_SAVE_CHECKPOINT.pth"))
        raise
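# ---------------------------------------------------------------------------
# For reference, a minimal sketch of the CheckpointManager interface that
# train_model() relies on: load() restores every stateful object in place and
# returns the (epoch, iter) recorded in the latest checkpoint, or (0, 0) when
# none exists; save() persists everything needed to resume mid-epoch. The
# real implementation lives in checkpoint.py and may differ; the class name
# and the "checkpoints/latest.pth" path are illustrative assumptions only.
# ---------------------------------------------------------------------------
class CheckpointManagerSketch:
    def __init__(self, path: str = "checkpoints/latest.pth"):
        self.path = path

    def load(self, model, scaler, optimizer, scheduler):
        if not os.path.exists(self.path):
            return 0, 0  # no checkpoint yet: fresh run
        state = torch.load(self.path, map_location="cpu")
        model.load_state_dict(state["model"])
        scaler.load_state_dict(state["scaler"])
        optimizer.load_state_dict(state["optimizer"])
        scheduler.load_state_dict(state["scheduler"])
        return state["epoch"], state["iter"]

    def save(self, model, scaler, optimizer, scheduler, epoch, iter_):
        os.makedirs(os.path.dirname(self.path) or ".", exist_ok=True)
        torch.save({
            "model": model.state_dict(),
            "scaler": scaler.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
            "iter": iter_,
        }, self.path)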
if __name__ == '__main__':
    load_dotenv()  # take environment variables from .env
    dataset_dir = os.getenv("IMAGENET_DIR")
    print(f"IMAGENET_DIR: {dataset_dir}")
    if dataset_dir is None:
        raise ValueError("Please set IMAGENET_DIR in the .env file.")
    train_dataset_dir = dataset_dir + '/ILSVRC/Data/CLS-LOC/train/'
    val_dataset_dir = dataset_dir + '/ILSVRC/Data/CLS-LOC/val/'

    working_dir = os.path.dirname(os.path.abspath(__file__))
    print(f"Working dir: {working_dir}")
    output_dir = working_dir + '/test_outputs'
    train_log_path = working_dir + '/train_log.csv'

    ckpt_mgr = CheckpointManager()
    model = Painter()
    train_dataset = ImageNetDataset(
        image_dir=train_dataset_dir, resize_to_size=model.vit_input_img_size)
    val_dataset = ImageNetDataset(
        image_dir=val_dataset_dir, resize_to_size=model.vit_input_img_size)

    # Lower LR for the pretrained ViT backbone, higher for the stroke head.
    optimizer = torch.optim.AdamW([
        {'params': model.feature_extractor.vit.parameters(), 'lr': 1e-5},
        {'params': model.stroke_transformer.parameters(), 'lr': 1e-4},
    ], weight_decay=1e-2, amsgrad=True)

    # Linear warmup, then cosine annealing with warm restarts.
    warmup_iters = 500000
    warmup_scheduler = LinearLR(
        optimizer,
        start_factor=0.5,
        total_iters=warmup_iters,
    )
    cosine_scheduler = CosineAnnealingWarmRestarts(
        optimizer,
        T_0=500000,
        T_mult=2,
        eta_min=1e-5,
    )
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_iters],
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_model(model, optimizer, scheduler,
                batch_size=2, accum_steps=16,
                train_dataset=train_dataset, val_dataset=val_dataset,
                device=device, n_epochs=10,
                dataset_chunk_size=450000)
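# ---------------------------------------------------------------------------
# For intuition, a self-contained sketch of the schedule assembled above: a
# LinearLR warmup handed off to CosineAnnealingWarmRestarts by SequentialLR.
# It runs on a throwaway parameter with shortened horizons (50/100 steps in
# place of the 500k used in training) so the LR curve can be inspected
# quickly; the numbers here are illustrative, not training settings.
# ---------------------------------------------------------------------------
def _schedule_shape_sketch(warmup=50, t0=100, steps=400):
    dummy = torch.nn.Parameter(torch.zeros(1))
    opt = torch.optim.AdamW([dummy], lr=1e-4)
    sched = SequentialLR(
        opt,
        schedulers=[
            LinearLR(opt, start_factor=0.5, total_iters=warmup),
            CosineAnnealingWarmRestarts(opt, T_0=t0, T_mult=2, eta_min=1e-5),
        ],
        milestones=[warmup],
    )
    lrs = []
    for _ in range(steps):
        opt.step()  # a real training step would call backward() first
        sched.step()
        lrs.append(sched.get_last_lr()[0])
    # Ramps 5e-5 -> 1e-4 over `warmup` steps, then cosine-decays toward
    # eta_min with restart periods t0, 2*t0, ...
    return lrs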