# coco-demo / src/train.py
# Uploaded by evanec ("Upload 12 files"), commit 1809762 (verified).
import argparse
import yaml
import torch
from torch.optim import AdamW
from tqdm import tqdm
from transformers import T5TokenizerFast
from models.vision_t5 import VisionT5
from models.encoder_projection_t5 import ImageProjection
import models.encoders as encoders
from data.loaders import get_coco_dataloaders
from src.inference import generate_caption
from src.utils import save_experiment, filter_kwargs, build_model
from torch.optim.lr_scheduler import CosineAnnealingLR
import math
def build_cosine_warmup_scheduler(optimizer, num_warmup_steps, num_training_steps):
    """Create a LambdaLR schedule: linear warmup followed by a half-cosine decay.

    The multiplier applied to each param group's base LR is
    ``step / num_warmup_steps`` while ``step < num_warmup_steps``, then
    ``0.5 * (1 + cos(pi * progress))``, where ``progress`` runs from 0 to 1
    over the remaining ``num_training_steps - num_warmup_steps`` steps.
    ``progress`` is capped at 1, so stepping past ``num_training_steps`` keeps
    the multiplier at 0 instead of letting the cosine climb back up.

    Args:
        optimizer: Optimizer whose learning rates are scheduled.
        num_warmup_steps: Number of linear-warmup steps (0 disables warmup).
        num_training_steps: Total number of scheduler steps planned.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` meant to be stepped once per
        training step.
    """

    def lr_lambda(step):
        if step < num_warmup_steps:
            # Linear ramp 0 -> 1; max(1, ...) guards against division by zero.
            return float(step) / float(max(1, num_warmup_steps))
        decay_steps = max(1, num_training_steps - num_warmup_steps)
        # Cap at 1.0: beyond the planned horizon the LR stays at its minimum.
        progress = min(1.0, float(step - num_warmup_steps) / float(decay_steps))
        return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay 1 -> 0

    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
def train_one_epoch(model, dataloader, optimizer, device, scaler, scheduler):
    """Run one training epoch with mixed precision and return the mean batch loss.

    Args:
        model: Captioning model, called with ``pixel_values``, ``input_ids``,
            ``attention_mask`` and ``labels`` keyword arguments and expected to
            return an object with a ``.loss`` attribute.
            ``model.t5.config.pad_token_id`` is used to mask padding out of
            the labels.
        dataloader: Iterable of batch dicts holding ``"pixel_values"``,
            ``"input_ids"`` and ``"attention_mask"`` tensors.
        optimizer: Optimizer over the model parameters.
        device: Target device (string or ``torch.device``) for batch tensors.
        scaler: ``GradScaler`` used for loss scaling. A disabled scaler makes
            this a plain full-precision backward pass.
        scheduler: LR scheduler, stepped once per batch.

    Returns:
        The average training loss over the batches processed.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    model.train()
    # Autocast only on CUDA. The deprecated torch.cuda.amp.autocast() merely
    # warned and disabled itself on CPU; CPU autocast would instead switch to
    # bfloat16, which is not what this loop intends.
    use_amp = torch.device(device).type == "cuda"
    pad_token_id = model.t5.config.pad_token_id
    running_loss = 0.0
    num_batches = 0
    for batch in tqdm(dataloader, desc="Training"):
        pixel_values = batch["pixel_values"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        # Teacher-forcing labels; -100 is the index HF ignores in the loss.
        labels = input_ids.clone()
        labels[labels == pad_token_id] = -100

        optimizer.zero_grad()
        with torch.autocast(device_type="cuda", enabled=use_amp):
            outputs = model(
                pixel_values=pixel_values,
                input_ids=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss

        scaler.scale(loss).backward()
        # Unscale before clipping so the norm is measured on the true gradients.
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        running_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_one_epoch: dataloader yielded no batches")
    return running_loss / num_batches
# Validation
@torch.no_grad()
def validate(model, tokenizer, dataloader, device, preview=False):
    """Return the mean teacher-forced loss over ``dataloader`` (no gradients).

    Padding positions (``tokenizer.pad_token_id``) are replaced with -100 in
    the labels so they are ignored by the loss. When ``preview`` is true, the
    first example of the first batch is captioned with ``generate_caption``
    and printed next to its decoded ground-truth caption.

    Args:
        model: Captioning model returning an object with a ``.loss`` attribute.
        tokenizer: Tokenizer providing ``pad_token_id`` and ``decode``.
        dataloader: Iterable of batch dicts with ``"pixel_values"``,
            ``"input_ids"`` and ``"attention_mask"`` tensors; must support ``len``.
        device: Device the batch tensors are moved to.
        preview: Whether to print one prediction/ground-truth pair.

    Returns:
        Sum of per-batch losses divided by ``len(dataloader)``.
    """
    model.eval()
    total = 0.0
    preview_image = None
    preview_caption = None

    for batch in tqdm(dataloader, desc="Validation"):
        images = batch["pixel_values"].to(device)
        token_ids = batch["input_ids"].to(device)
        mask = batch["attention_mask"].to(device)

        targets = token_ids.clone()
        targets[targets == tokenizer.pad_token_id] = -100

        result = model(
            pixel_values=images,
            input_ids=token_ids,
            attention_mask=mask,
            labels=targets,
        )
        total += result.loss.item()

        if not preview or preview_image is not None:
            continue
        # Remember the very first example for the qualitative preview.
        preview_image = images[0].detach().cpu()
        first_ids = token_ids[0]
        preview_caption = tokenizer.decode(
            first_ids[first_ids != tokenizer.pad_token_id],
            skip_special_tokens=True,
        )

    # preview_image is only ever set when preview is true.
    if preview_image is not None:
        print("\n--- Validation Preview ---")
        prediction = generate_caption(
            model, tokenizer, preview_image.unsqueeze(0), device=device
        )
        print("Prediction:", prediction)
        print("Ground Truth:", preview_caption)
        print("--------------------------\n")

    return total / len(dataloader)
def main(config):
    """Train the captioning model described by ``config``, checkpointing the best epoch.

    Config keys read here:
        model.image_size (optional, default 224)
        training.batch_size, training.lr, training.epochs
        training.preview_val (optional, default False)
        training.warmup_ratio (optional, default 0.05) - fraction of total
            steps spent in linear LR warmup
        paths.data_dir, paths.output_dir

    A checkpoint is written via ``save_experiment`` every time the validation
    loss improves on the best value seen so far.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Model + tokenizer
    model, tokenizer = build_model(config)
    model.to(device)

    # Data
    train_cfg = config["training"]
    batch_size = train_cfg["batch_size"]
    image_size = config["model"].get("image_size", 224)
    train_loader, val_loader, _ = get_coco_dataloaders(
        batch_size=batch_size,
        data_dir=config["paths"]["data_dir"],
        image_size=image_size,
    )

    optimizer = AdamW(model.parameters(), lr=train_cfg["lr"])
    # Loss scaling only matters for CUDA fp16 autocast. Enabling it explicitly
    # only on CUDA avoids the "CUDA is not available" warning on CPU, where the
    # scaler would disable itself anyway.
    scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

    epochs = train_cfg["epochs"]
    num_training_steps = len(train_loader) * epochs
    num_warmup_steps = int(train_cfg.get("warmup_ratio", 0.05) * num_training_steps)
    scheduler = build_cosine_warmup_scheduler(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
    )
    preview_val = train_cfg.get("preview_val", False)

    best_val = float("inf")
    best_epoch = -1

    # Train loop
    for epoch in range(1, epochs + 1):
        print(f"\nEpoch {epoch}/{epochs}")
        train_loss = train_one_epoch(model, train_loader, optimizer, device, scaler, scheduler)
        print("Train Loss:", train_loss)

        val_loss = validate(model, tokenizer, val_loader, device, preview=preview_val)
        print("Val Loss:", val_loss)

        if val_loss < best_val:
            best_val = val_loss
            best_epoch = epoch
            save_experiment(
                model=model,
                tokenizer=tokenizer,
                config=config,
                save_dir=config["paths"]["output_dir"],
                notes=f"BEST checkpoint epoch={epoch}, val_loss={val_loss:.4f}",
            )
            print(f"[CHECKPOINT] Saved new BEST model at epoch {epoch}")

    if best_epoch == -1:
        print("No checkpoint saved: validation loss never improved.")
    else:
        print(f"Best epoch: {best_epoch} (val_loss={best_val:.4f})")
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, required=True, help="Path to the YAML training config.")
    args = parser.parse_args()
    # Explicit UTF-8 so the config is read the same way on every platform
    # (the default text encoding is locale-dependent, e.g. cp1252 on Windows).
    with open(args.config, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f)
    main(config)