Spaces:
Sleeping
Sleeping
| import os | |
| from dotenv import load_dotenv | |
| import csv | |
| import torch | |
| from torch.utils.data import DataLoader, Subset | |
| from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingWarmRestarts | |
| from tqdm import tqdm | |
| from torch.amp.grad_scaler import GradScaler | |
| from torch.amp.autocast_mode import autocast | |
| from pipeline import Painter | |
| from dataset import ImageNetDataset | |
| from eval_in_training import eval_model | |
| from checkpoint import CheckpointManager | |
def train_model(
    model: Painter,
    optimizer: torch.optim.Optimizer,
    scheduler,
    batch_size: int,
    accum_steps: int,
    train_dataset: ImageNetDataset,
    val_dataset: ImageNetDataset,
    device: torch.device,
    n_epochs: int,
    dataset_chunk_size: int
):
    """Train ``model`` shard by shard with AMP and gradient accumulation.

    The training set is walked in contiguous shards of at most
    ``dataset_chunk_size`` samples. After every shard a checkpoint is saved,
    the whole validation set is evaluated, and one row is appended to the CSV
    training log.

    This function reads three module-level names that must exist before it is
    called: ``ckpt_mgr`` (checkpoint load/save), ``output_dir`` (passed to
    ``eval_model``) and ``train_log_path`` (CSV log file).

    Args:
        model: Network to train; its output must expose ``out['mse_loss']``.
        optimizer: Optimizer updating ``model``'s parameters.
        scheduler: Learning-rate scheduler exposing ``step()`` and
            ``get_last_lr()``.
        batch_size: Batch size used for both data loaders.
        accum_steps: Number of batches whose gradients are accumulated before
            an optimizer update.
        train_dataset: Training samples, indexed ``0 .. len - 1``.
        val_dataset: Validation samples.
        device: Device to train on.
        n_epochs: Number of epochs to run, counted from the resume epoch.
        dataset_chunk_size: Maximum number of training samples per shard.

    Raises:
        ValueError: If ``dataset_chunk_size`` is not positive.
        Exception: Any error raised during training is re-raised after the
            model weights have been written to
            ``checkpoints/ERROR_SAVE_CHECKPOINT.pth`` next to this file.
    """
    if dataset_chunk_size < 1:
        # A non-positive shard size would yield empty shards and never advance.
        raise ValueError("dataset_chunk_size must be a positive integer.")

    model.to(device)
    scaler = GradScaler()
    # autocast wants the bare device type ("cuda"/"cpu"); str(device) yields
    # e.g. "cuda:0" for an indexed device, which autocast does not accept.
    amp_device_type = torch.device(device).type
    n_train = len(train_dataset)

    # Work out where to resume from the (epoch, last sample index) pair
    # returned by the checkpoint manager.
    start_epoch, start_iter = 0, 0
    checkpoint_epoch, checkpoint_iter = ckpt_mgr.load(
        model, scaler, optimizer, scheduler)
    if checkpoint_epoch == 0 and checkpoint_iter == 0:
        pass  # (0, 0) is treated as "start from scratch"
    elif checkpoint_iter >= n_train - 1:
        # The checkpointed epoch was finished: continue with the next one.
        start_epoch = checkpoint_epoch + 1
        start_iter = 0
    else:
        start_epoch = checkpoint_epoch
        start_iter = checkpoint_iter + 1
    print(
        f"Begin training from epoch {start_epoch}, iter {start_iter}/{n_train-1}")
    end_epoch = start_epoch + n_epochs

    try:
        # The validation loader does not depend on the shard; build it once.
        val_dataloader = DataLoader(
            val_dataset,
            batch_size=batch_size,
            shuffle=False,
            num_workers=min(4, batch_size),
        )

        for epoch in range(start_epoch, end_epoch):
            # Only the resumed epoch starts mid-dataset; every later epoch
            # must cover the training set from the beginning.
            index = start_iter if epoch == start_epoch else 0

            while index < n_train:
                shard_start = index
                shard_end_exclusive = min(index + dataset_chunk_size, n_train)
                indices = range(shard_start, shard_end_exclusive)
                print(f"Training indices: {indices[0]} - {indices[-1]}")

                train_dataloader = DataLoader(
                    Subset(train_dataset, indices),
                    batch_size=batch_size,
                    shuffle=True,  # only shuffle the training portion
                    num_workers=min(4, batch_size),
                )

                model.train()
                print(f"Learning rate: {scheduler.get_last_lr()}")
                optimizer.zero_grad()
                train_bar = tqdm(
                    train_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Train]", ncols=0)

                # Sample-weighted loss sums for this shard.
                loss_metric = {
                    'train': {'total': 0.0, 'mse': 0.0},
                    'val': {'total': 0.0, 'mse': 0.0},
                }
                total_train_samples = 0

                for batch_i, imgs in enumerate(train_bar, start=0):
                    batch_n = imgs.size(0)
                    # Positions are counts within the shard (the loader is
                    # shuffled, so they are not dataset indices of the batch).
                    batch_start = shard_start + batch_i * batch_size
                    batch_end_exclusive = batch_start + batch_n
                    imgs = imgs.to(device, non_blocking=True)

                    with autocast(device_type=amp_device_type):
                        out = model(target_img=imgs, train=True)
                        mse_loss = out['mse_loss']
                        total_loss = mse_loss

                    # Read each scalar once; every .item() forces a sync.
                    total_value = total_loss.item()
                    mse_value = mse_loss.item()
                    loss_metric['train']['total'] += total_value * batch_n
                    loss_metric['train']['mse'] += mse_value * batch_n
                    total_train_samples += batch_n

                    # Scale so the accumulated gradient is an average.
                    loss_to_backward = total_loss / accum_steps
                    scaler.scale(loss_to_backward).backward()

                    is_accum_step = ((batch_i + 1) % accum_steps == 0)
                    # Flush a partial accumulation window at the shard end so
                    # no gradients leak into the next shard.
                    is_last_batch_in_shard = (
                        batch_end_exclusive >= shard_end_exclusive)
                    if is_accum_step or is_last_batch_in_shard:
                        # Unscale first so clipping sees true gradient norms.
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(
                            model.parameters(), max_norm=1.0)
                        scaler.step(optimizer)
                        scaler.update()
                        optimizer.zero_grad()
                        # NOTE(review): the original indentation was lost; the
                        # scheduler is advanced once per optimizer update here.
                        # Confirm the scheduler's iteration counts are meant in
                        # optimizer updates rather than in batches.
                        scheduler.step()

                    train_bar.set_postfix({
                        'loss': f"{total_value:.4f}",
                        'mse': f"{mse_value:.4f}",
                    })

                    # Periodic in-training evaluation (includes batch 0).
                    if batch_i % 10000 == 0:
                        model.eval()
                        eval_model(model, val_dataloader, epoch=epoch,
                                   step=batch_start, output_dir=output_dir)
                        torch.cuda.empty_cache()
                        model.train()
                    if batch_i % 500 == 0:
                        torch.cuda.empty_cache()

                # NOTE(review): checkpointing is placed once per shard (nesting
                # reconstructed from lost indentation) — confirm.
                last_sample_idx = shard_start + total_train_samples - 1
                ckpt_mgr.save(model, scaler, optimizer,
                              scheduler, epoch, last_sample_idx)

                avg_train_metric = {k: v / total_train_samples for k,
                                    v in loss_metric['train'].items()}
                print(avg_train_metric)

                # Full validation pass after the shard.
                model.eval()
                total_val_samples = 0
                with torch.no_grad(), autocast(device_type=amp_device_type):
                    val_bar = tqdm(
                        val_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Val]", ncols=0)
                    for imgs in val_bar:
                        batch_n = imgs.size(0)
                        imgs = imgs.to(device, non_blocking=True)
                        out = model(imgs)
                        mse_loss = out['mse_loss']
                        total_loss = mse_loss
                        loss_metric['val']['total'] += total_loss.item() * \
                            batch_n
                        loss_metric['val']['mse'] += mse_loss.item() * batch_n
                        total_val_samples += batch_n
                avg_val_metric = {k: v / total_val_samples for k,
                                  v in loss_metric['val'].items()}

                # Append one row per shard; write the header only for a new file.
                write_header = not os.path.exists(train_log_path)
                with open(train_log_path, mode="a", newline="") as csvfile:
                    writer = csv.DictWriter(csvfile, fieldnames=[
                        "epoch", "iter",
                        "train_total_loss", "train_mse_loss",
                        "val_total_loss", "val_mse_loss"
                    ])
                    if write_header:
                        writer.writeheader()
                    writer.writerow({
                        "epoch": epoch,
                        "iter": indices[-1],
                        "train_total_loss": avg_train_metric["total"],
                        "train_mse_loss": avg_train_metric["mse"],
                        "val_total_loss": avg_val_metric["total"],
                        "val_mse_loss": avg_val_metric["mse"],
                    })

                # Advance to the next shard; without this the loop would
                # retrain the same shard forever.
                index = shard_end_exclusive
    except Exception:
        # Best-effort rescue of the weights before propagating the error.
        checkpoint_dir = os.path.dirname(
            os.path.abspath(__file__))+"/checkpoints"
        os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save({"model": model.state_dict()},
                   os.path.join(checkpoint_dir, "ERROR_SAVE_CHECKPOINT.pth"))
        raise
if __name__ == '__main__':
    # Script entry point: resolve paths from the environment, build the model,
    # datasets, optimizer and LR schedule, then launch training.
    load_dotenv()  # pull variables from the .env file into the environment

    dataset_dir = os.environ.get("IMAGENET_DIR")
    print(f"IMAGENET_DIR: {dataset_dir}")
    if dataset_dir is None:
        raise ValueError("Please set IMAGENET_DIR in the .env file.")
    train_dataset_dir = dataset_dir + '/ILSVRC/Data/CLS-LOC/train/'
    val_dataset_dir = dataset_dir + '/ILSVRC/Data/CLS-LOC/val/'

    script_path = os.path.abspath(__file__)
    working_dir = os.path.dirname(script_path)
    print(f"Working dir: {working_dir}")

    # These three names are read as module-level globals by train_model,
    # so they must stay defined at this level under these exact names.
    output_dir = working_dir + '/test_outputs'
    train_log_path = working_dir + '/train_log.csv'
    ckpt_mgr = CheckpointManager()

    model = Painter()
    input_size = model.vit_input_img_size
    train_dataset = ImageNetDataset(
        image_dir=train_dataset_dir,
        resize_to_size=input_size,
    )
    val_dataset = ImageNetDataset(
        image_dir=val_dataset_dir,
        resize_to_size=input_size,
    )

    # Two parameter groups, each with its own base learning rate.
    param_groups = [
        {'params': model.feature_extractor.vit.parameters(), 'lr': 1e-5},
        {'params': model.stroke_transformer.parameters(), 'lr': 1e-4},
    ]
    optimizer = torch.optim.AdamW(
        param_groups,
        weight_decay=1e-2,
        amsgrad=True,
    )

    # Linear warmup from half the base LR, then cosine annealing with warm
    # restarts; SequentialLR switches between them at the milestone.
    warmup_iters = 500000
    warmup_scheduler = LinearLR(
        optimizer, start_factor=0.5, total_iters=warmup_iters)
    cosine_scheduler = CosineAnnealingWarmRestarts(
        optimizer, T_0=500000, T_mult=2, eta_min=1e-5)
    scheduler = SequentialLR(
        optimizer,
        schedulers=[warmup_scheduler, cosine_scheduler],
        milestones=[warmup_iters],
    )

    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    train_model(
        model=model,
        optimizer=optimizer,
        scheduler=scheduler,
        batch_size=2,
        accum_steps=16,
        train_dataset=train_dataset,
        val_dataset=val_dataset,
        device=device,
        n_epochs=10,
        dataset_chunk_size=450000,
    )