|
|
import argparse |
|
|
import yaml |
|
|
import torch |
|
|
from torch.optim import AdamW |
|
|
from tqdm import tqdm |
|
|
from transformers import T5TokenizerFast |
|
|
|
|
|
|
|
|
from models.vision_t5 import VisionT5 |
|
|
from models.encoder_projection_t5 import ImageProjection |
|
|
import models.encoders as encoders |
|
|
|
|
|
from data.loaders import get_coco_dataloaders |
|
|
from src.inference import generate_caption |
|
|
from src.utils import save_experiment, filter_kwargs, build_model |
|
|
from torch.optim.lr_scheduler import CosineAnnealingLR |
|
|
|
|
|
import math |
|
|
|
|
|
|
|
|
|
|
|
def build_cosine_warmup_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Build a linear-warmup + cosine-decay learning-rate scheduler.

    The multiplier applied to each param group's base LR rises linearly from
    0 to 1 over ``num_warmup_steps`` scheduler steps, then follows a half
    cosine from 1 down to 0 over the remaining
    ``num_training_steps - num_warmup_steps`` steps. Steps beyond
    ``num_training_steps`` keep the multiplier at 0.

    Args:
        optimizer: The optimizer whose learning rates are scheduled.
        num_warmup_steps: Number of scheduler steps spent in linear warmup.
        num_training_steps: Total number of scheduler steps in the run.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` wrapping ``optimizer``.
    """

    def lr_lambda(step):
        if step < num_warmup_steps:
            # max(1, ...) guards against division by zero when warmup is 0.
            return float(step) / float(max(1, num_warmup_steps))

        decay_steps = max(1, num_training_steps - num_warmup_steps)
        progress = float(step - num_warmup_steps) / float(decay_steps)
        # Clamp so that stepping past ``num_training_steps`` holds the LR at
        # its minimum instead of letting the cosine swing back up.
        progress = min(1.0, progress)
        return 0.5 * (1.0 + math.cos(math.pi * progress))

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
|
|
|
|
|
|
|
|
|
|
|
def train_one_epoch(model, dataloader, optimizer, device, scaler, scheduler, max_grad_norm=1.0):
    """Run one epoch of mixed-precision training and return the mean batch loss.

    For every batch the caption token ids double as labels; positions equal
    to the model's pad token id are set to -100 so they are excluded from the
    loss (assumes the model uses the usual -100 ignore index — TODO confirm).
    The loss is scaled through ``scaler``, gradients are unscaled and clipped
    to ``max_grad_norm``, then the optimizer and ``scheduler`` are each
    stepped once per batch.

    Args:
        model: Captioning model called with ``pixel_values``, ``input_ids``,
            ``attention_mask`` and ``labels``; must return an object with a
            ``.loss`` attribute and expose ``t5.config.pad_token_id``.
        dataloader: Iterable of dict batches with ``"pixel_values"``,
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        optimizer: Optimizer over the model parameters.
        device: Device string or ``torch.device`` batches are moved to.
        scaler: ``GradScaler`` used for loss scaling.
        scheduler: LR scheduler stepped once per batch.
        max_grad_norm: Maximum global gradient norm used for clipping.

    Returns:
        The average of the per-batch losses, or ``nan`` if ``dataloader``
        yields no batches.
    """
    model.train()

    # fp16 autocast is only used on CUDA; on other devices it stays disabled,
    # which matches what the deprecated ``torch.cuda.amp.autocast`` did.
    use_amp = torch.device(device).type == "cuda"
    pad_token_id = model.t5.config.pad_token_id

    total_loss = 0.0
    num_batches = 0

    for batch in tqdm(dataloader, desc="Training"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        labels = input_ids.clone()
        labels[labels == pad_token_id] = -100

        optimizer.zero_grad()

        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss

        scaler.scale(loss).backward()

        # Gradients must be unscaled before clipping so the norm threshold
        # applies to the true gradient magnitudes.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)

        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        num_batches += 1

    # Counting batches (instead of len(dataloader)) avoids ZeroDivisionError
    # on an empty loader and works for loaders without __len__.
    return total_loss / num_batches if num_batches else float("nan")
|
|
|
|
|
|
|
|
|
|
|
@torch.no_grad()
def validate(model, tokenizer, dataloader, device, preview=False):
    """Compute the mean validation loss, optionally printing one caption preview.

    Labels are built from ``input_ids`` with positions equal to
    ``tokenizer.pad_token_id`` set to -100 so they are excluded from the loss
    (assumes the model uses the usual -100 ignore index — TODO confirm).
    When ``preview`` is true, the first image of the first batch is captioned
    with ``generate_caption`` and printed next to its decoded ground truth.

    Args:
        model: Captioning model called with ``pixel_values``, ``input_ids``,
            ``attention_mask`` and ``labels``; must return an object with a
            ``.loss`` attribute.
        tokenizer: Tokenizer providing ``pad_token_id`` and ``decode``.
        dataloader: Iterable of dict batches with ``"pixel_values"``,
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        device: Device string or ``torch.device`` batches are moved to.
        preview: Whether to print a prediction/ground-truth sample.

    Returns:
        The average of the per-batch losses, or ``nan`` if ``dataloader``
        yields no batches.
    """
    model.eval()
    pad_token_id = tokenizer.pad_token_id

    total_loss = 0.0
    num_batches = 0
    sample_img = None
    sample_gt = None

    for batch in tqdm(dataloader, desc="Validation"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        labels = input_ids.clone()
        labels[labels == pad_token_id] = -100

        outputs = model(
            pixel_values=pixel_values,
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )
        total_loss += outputs.loss.item()
        num_batches += 1

        # Keep only the very first example for the qualitative preview.
        if preview and sample_img is None:
            sample_img = pixel_values[0].detach().cpu()
            gt_ids = input_ids[0][input_ids[0] != pad_token_id]
            sample_gt = tokenizer.decode(gt_ids, skip_special_tokens=True)

    if preview and sample_img is not None:
        print("\n--- Validation Preview ---")
        pred = generate_caption(model, tokenizer, sample_img.unsqueeze(0), device=device)
        print("Prediction:", pred)
        print("Ground Truth:", sample_gt)
        print("--------------------------\n")

    # Counting batches (instead of len(dataloader)) avoids ZeroDivisionError
    # on an empty loader and works for loaders without __len__.
    return total_loss / num_batches if num_batches else float("nan")
|
|
|
|
|
|
|
|
|
|
|
def main(config):
    """Train the captioning model described by ``config``, saving the best checkpoint.

    Config keys read here:
        training.batch_size, training.lr, training.epochs
        training.preview_val   (optional, default False)
        training.warmup_ratio  (optional, default 0.05)
        model.image_size       (optional, default 224)
        paths.data_dir, paths.output_dir

    After every epoch the validation loss is computed; whenever it improves on
    the best seen so far, ``save_experiment`` is called with the model,
    tokenizer and config.
    """
    train_cfg = config["training"]
    device = "cuda" if torch.cuda.is_available() else "cpu"

    model, tokenizer = build_model(config)
    model.to(device)

    batch_size = train_cfg["batch_size"]
    image_size = config["model"].get("image_size", 224)
    train_loader, val_loader, _ = get_coco_dataloaders(
        batch_size=batch_size,
        data_dir=config["paths"]["data_dir"],
        image_size=image_size,
    )

    # PyYAML (YAML 1.1) loads exponent-only floats such as ``1e-4`` as
    # strings, which AdamW rejects; coerce explicitly.
    optimizer = AdamW(model.parameters(), lr=float(train_cfg["lr"]))
    # Loss scaling only applies to CUDA fp16. Disabling it explicitly off-CUDA
    # is what GradScaler would do anyway, minus the runtime warning.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    epochs = train_cfg["epochs"]
    num_training_steps = len(train_loader) * epochs
    warmup_ratio = float(train_cfg.get("warmup_ratio", 0.05))
    num_warmup_steps = int(warmup_ratio * num_training_steps)
    scheduler = build_cosine_warmup_scheduler(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )

    preview_val = train_cfg.get("preview_val", False)
    best_val = float("inf")
    best_epoch = -1

    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")

        train_loss = train_one_epoch(model, train_loader, optimizer, device, scaler, scheduler)
        print("Train Loss:", train_loss)

        val_loss = validate(model, tokenizer, val_loader, device, preview=preview_val)
        print("Val Loss:", val_loss)

        if val_loss < best_val:
            best_val = val_loss
            best_epoch = epoch

            save_experiment(
                model=model,
                tokenizer=tokenizer,
                config=config,
                save_dir=config["paths"]["output_dir"],
                notes=f"BEST checkpoint epoch={epoch}, val_loss={val_loss:.4f}",
            )
            print(f"[CHECKPOINT] Saved new BEST model at epoch {epoch}")

    if best_epoch > 0:
        print(f"Best val loss {best_val:.4f} at epoch {best_epoch}")
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Train the vision-T5 captioning model.")
    parser.add_argument("--config", type=str, required=True, help="Path to the YAML experiment config.")
    args = parser.parse_args()

    # safe_load (not yaml.load) so the config file cannot construct arbitrary
    # Python objects.
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)

    # An empty YAML file loads as None; report a clear CLI error instead of a
    # TypeError deep inside main().
    if not isinstance(config, dict):
        parser.error(f"config file {args.config!r} must contain a YAML mapping")

    main(config)
|
|
|