# image2painting/train.py
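"""Training entry point for the Painter image-to-painting model.

Trains Painter on ImageNet with mixed-precision autocast and gradient
accumulation, walking the dataset in large chunks so that checkpoints,
validation passes, and CSV logging all happen at regular mid-epoch
intervals.
"""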
import csv
import os

import torch
from dotenv import load_dotenv
from torch.amp.autocast_mode import autocast
from torch.amp.grad_scaler import GradScaler
from torch.optim.lr_scheduler import (CosineAnnealingWarmRestarts, LinearLR,
                                      SequentialLR)
from torch.utils.data import DataLoader, Subset
from tqdm import tqdm

from checkpoint import CheckpointManager
from dataset import ImageNetDataset
from eval_in_training import eval_model
from pipeline import Painter


def train_model(
model: Painter,
optimizer: torch.optim.Optimizer,
scheduler,
batch_size: int,
accum_steps: int,
train_dataset: ImageNetDataset,
val_dataset: ImageNetDataset,
device: torch.device,
n_epochs: int,
dataset_chunk_size: int
):
model.to(device)
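    # NOTE: ckpt_mgr, output_dir and train_log_path are module-level names
    # defined in the __main__ block below.
    # GradScaler performs the dynamic loss scaling required when training
    # with float16 autocast.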
scaler = GradScaler()
start_epoch, start_iter = 0, 0
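    # Interpret the loaded checkpoint: (0, 0) means no checkpoint was found;
    # an iter equal to the last sample index means that epoch completed;
    # anything else resumes mid-epoch from the sample after the checkpoint.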
checkpoint_epoch, checkpoint_iter = ckpt_mgr.load(
model, scaler, optimizer, scheduler)
if checkpoint_epoch == 0 and checkpoint_iter == 0:
pass
elif checkpoint_iter == len(train_dataset)-1:
start_epoch = checkpoint_epoch + 1
start_iter = 0
else:
start_epoch = checkpoint_epoch
start_iter = checkpoint_iter + 1
print(
f"Begin training from epoch {start_epoch}, iter {start_iter}/{len(train_dataset)-1}")
end_epoch = start_epoch + n_epochs
try:
for epoch in range(start_epoch, end_epoch):
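            # Walk the training set in chunks of dataset_chunk_size samples.
            # Each chunk gets its own DataLoader, validation pass, checkpoint,
            # and log entry, so progress survives mid-epoch interruptions.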
index = start_iter
while index < len(train_dataset):
indices = list(range(index, min(
index + dataset_chunk_size, len(train_dataset))))
print(f"Training indices: {indices[0]} - {indices[-1]}")
partial_train_dataset = Subset(train_dataset, indices)
train_dataloader = DataLoader(
partial_train_dataset,
batch_size=batch_size,
shuffle=True, # only shuffle the training portion
num_workers=min(4, batch_size),
)
val_dataloader = DataLoader(
val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=min(4, batch_size),
)
model.train()
print(f"Learning rate: {scheduler.get_last_lr()}")
optimizer.zero_grad()
train_bar = tqdm(
train_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Train]", ncols=0)
                # Per-chunk running loss sums, weighted by batch size.
                loss_metric = {
                    'train': {'total': 0.0, 'mse': 0.0},
                    'val': {'total': 0.0, 'mse': 0.0},
                }
shard_start = indices[0]
shard_size = len(indices)
shard_end_exclusive = shard_start + shard_size
total_train_samples = 0
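                # Absolute sample positions are tracked so that the final,
                # possibly short accumulation window of a shard still
                # triggers an optimizer step.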
for batch_i, imgs in enumerate(train_bar, start=0):
batch_n = imgs.size(0)
batch_start = shard_start + batch_i * batch_size
batch_end_exclusive = batch_start + batch_n
imgs = imgs.to(device, non_blocking=True)
with autocast(device_type=str(device)):
out = model(target_img=imgs, train=True)
mse_loss = out['mse_loss']
total_loss = mse_loss
                    loss_metric['train']['total'] += total_loss.item() * batch_n
                    loss_metric['train']['mse'] += mse_loss.item() * batch_n
                    total_train_samples += batch_n
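                    # Gradient accumulation: the loss is divided by
                    # accum_steps so the summed gradients match one large
                    # batch; the optimizer steps every accum_steps batches
                    # or at the end of the shard, and gradients are unscaled
                    # before clipping so max_norm applies to their true
                    # magnitudes.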
loss_to_backward = total_loss / accum_steps
scaler.scale(loss_to_backward).backward()
is_accum_step = ((batch_i + 1) % accum_steps == 0)
is_last_batch_in_shard = (
batch_end_exclusive >= shard_end_exclusive)
if is_accum_step or is_last_batch_in_shard:
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(
model.parameters(), max_norm=1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()
scheduler.step()
train_bar.set_postfix({
'loss': f"{total_loss.item():.4f}",
'mse': f"{mse_loss.item():.4f}",
})
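                    # Every 10000 batches (including the first), write
                    # evaluation outputs for the current step to output_dir,
                    # then release cached GPU memory.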
                    if batch_i % 10000 == 0:
model.eval()
eval_model(model, val_dataloader, epoch=epoch,
step=batch_start, output_dir=output_dir)
torch.cuda.empty_cache()
model.train()
if batch_i % 500 == 0:
torch.cuda.empty_cache()
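                # The chunk is finished: checkpoint with the index of the
                # last sample actually seen, so a restart resumes from the
                # sample after it.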
last_sample_idx = shard_start + total_train_samples - 1
ckpt_mgr.save(model, scaler, optimizer,
scheduler, epoch, last_sample_idx)
                avg_train_metric = {k: v / total_train_samples
                                    for k, v in loss_metric['train'].items()}
print(avg_train_metric)
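                # Full validation pass for this chunk, under no_grad and
                # autocast to bound memory use.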
model.eval()
total_val_samples = 0
with torch.no_grad(), autocast(device_type=str(device)):
val_bar = tqdm(
val_dataloader, desc=f"Epoch {epoch}/{end_epoch} [Val]", ncols=0)
for imgs in val_bar:
batch_n = imgs.size(0)
imgs = imgs.to(device, non_blocking=True)
                        out = model(target_img=imgs)
                        mse_loss = out['mse_loss']
                        total_loss = mse_loss
                        loss_metric['val']['total'] += total_loss.item() * batch_n
                        loss_metric['val']['mse'] += mse_loss.item() * batch_n
total_val_samples += batch_n
                avg_val_metric = {k: v / total_val_samples
                                  for k, v in loss_metric['val'].items()}
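                # Append this chunk's averaged metrics to the CSV training
                # log, writing the header only when the file is first created.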
write_header = not os.path.exists(train_log_path)
with open(train_log_path, mode="a", newline="") as csvfile:
writer = csv.DictWriter(csvfile, fieldnames=[
"epoch", "iter",
"train_total_loss", "train_mse_loss",
"val_total_loss", "val_mse_loss"
])
if write_header:
writer.writeheader()
writer.writerow({
"epoch": epoch,
"iter": indices[-1],
"train_total_loss": avg_train_metric["total"],
"train_mse_loss": avg_train_metric["mse"],
"val_total_loss": avg_val_metric["total"],
"val_mse_loss": avg_val_metric["mse"],
                    })
                # Move on to the next chunk of the dataset.
                index = shard_end_exclusive
            # Only the resumed epoch starts mid-dataset; every later epoch
            # begins at sample 0.
            start_iter = 0
except Exception:
        checkpoint_dir = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "checkpoints")
os.makedirs(checkpoint_dir, exist_ok=True)
torch.save({"model": model.state_dict()},
os.path.join(checkpoint_dir, "ERROR_SAVE_CHECKPOINT.pth"))
        raise


if __name__ == '__main__':
load_dotenv() # take environment variables from .env
dataset_dir = os.getenv("IMAGENET_DIR")
print(f"IMAGENET_DIR: {dataset_dir}")
if dataset_dir is None:
raise ValueError("Please set IMAGENET_DIR in the .env file.")
    train_dataset_dir = os.path.join(dataset_dir, 'ILSVRC/Data/CLS-LOC/train/')
    val_dataset_dir = os.path.join(dataset_dir, 'ILSVRC/Data/CLS-LOC/val/')
working_dir = os.path.dirname(os.path.abspath(__file__))
print(f"Working dir: {working_dir}")
    output_dir = os.path.join(working_dir, 'test_outputs')
    train_log_path = os.path.join(working_dir, 'train_log.csv')
ckpt_mgr = CheckpointManager()
model = Painter()
train_dataset = ImageNetDataset(
image_dir=train_dataset_dir, resize_to_size=model.vit_input_img_size)
val_dataset = ImageNetDataset(
image_dir=val_dataset_dir, resize_to_size=model.vit_input_img_size)
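    # Two parameter groups: the ViT feature extractor (presumably
    # pre-trained) is fine-tuned at a 10x lower learning rate than the
    # stroke transformer.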
optimizer = torch.optim.AdamW([
{'params': model.feature_extractor.vit.parameters(), 'lr': 1e-5},
{'params': model.stroke_transformer.parameters(), 'lr': 1e-4},
], weight_decay=1e-2, amsgrad=True)
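    # LR schedule: linear warmup from 0.5x the base LR over the first
    # warmup_iters optimizer steps, then cosine annealing with warm
    # restarts whose period doubles after each restart (T_mult=2).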
warmup_iters = 500000
warmup_scheduler = LinearLR(
optimizer,
start_factor=0.5,
total_iters=warmup_iters
)
cosine_scheduler = CosineAnnealingWarmRestarts(
optimizer,
T_0=500000,
T_mult=2,
eta_min=1e-5
)
scheduler = SequentialLR(
optimizer,
schedulers=[warmup_scheduler, cosine_scheduler],
milestones=[warmup_iters]
)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_model(model, optimizer, scheduler, batch_size=2, accum_steps=16, train_dataset=train_dataset,
val_dataset=val_dataset, device=device, n_epochs=10, dataset_chunk_size=450000)