import argparse
import math
import os
from typing import Tuple

import torch
from torch import Tensor
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

from .config import PathsConfig, TrainingConfig, ensure_dir, get_device, set_seed
from .dataset import create_dataloader, create_tokenizer
from .model import ImageCaptioningModel

def parse_args() -> argparse.Namespace:
    """
    Parse command-line arguments for training.
    """
    parser = argparse.ArgumentParser(description="Train EfficientNetB0 + GPT-2 image captioning model.")
    parser.add_argument("--data_root", type=str, default="/Users/ryan/Downloads/visuallyimpair", help="Root path to dataset.")
    parser.add_argument("--epochs", type=int, default=10, help="Number of training epochs.")
    parser.add_argument("--batch_size", type=int, default=16, help="Batch size.")
    parser.add_argument("--lr", type=float, default=5e-5, help="Learning rate.")
    parser.add_argument("--warmup_steps", type=int, default=500, help="Number of warmup steps.")
    parser.add_argument("--max_length", type=int, default=50, help="Maximum caption length.")
    parser.add_argument("--grad_accum_steps", type=int, default=1, help="Gradient accumulation steps.")
    parser.add_argument("--output_dir", type=str, default="checkpoints", help="Directory to save checkpoints.")
    parser.add_argument("--log_dir", type=str, default="runs", help="Directory for TensorBoard logs.")
    parser.add_argument("--patience", type=int, default=10, help="Early stopping patience (in epochs) based on validation loss.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    return parser.parse_args()
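
# Example invocation (a sketch only; the exact module path depends on how the
# project is laid out, since the relative imports above require running this file
# as a module -- "<package_name>" below is a placeholder, not the real package name):
#     python -m <package_name>.train --data_root /path/to/dataset --epochs 10 --batch_size 16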

def create_training_config_from_args(args: argparse.Namespace) -> TrainingConfig:
    """
    Create a TrainingConfig instance using command-line arguments.
    """
    cfg = TrainingConfig()
    cfg.learning_rate = args.lr
    cfg.batch_size = args.batch_size
    cfg.num_epochs = args.epochs
    cfg.warmup_steps = args.warmup_steps
    cfg.max_caption_length = args.max_length
    cfg.gradient_accumulation_steps = max(1, args.grad_accum_steps)
    cfg.output_dir = args.output_dir
    cfg.log_dir = args.log_dir
    cfg.patience = args.patience
    cfg.seed = args.seed
    return cfg
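
# Note: the effective batch size seen by the optimizer is
# batch_size * gradient_accumulation_steps (e.g. 16 * 4 = 64), because gradients
# are accumulated over that many batches before each optimizer step (see train_one_epoch).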

def validate_dataloader(
    train_loader,
    device: torch.device,
) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
    """
    Fetch a single batch from the DataLoader to validate dataset loading.

    Returns
    -------
    Tuple of (images, input_ids, attention_mask, labels).
    """
    try:
        batch = next(iter(train_loader))
    except StopIteration as exc:
        raise RuntimeError("Training DataLoader is empty. Check your dataset configuration.") from exc

    images = batch["image"].to(device)
    input_ids = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    labels = batch["labels"].to(device)

    print(f"[DATA] images batch shape: {images.shape}")
    print(f"[DATA] input_ids batch shape: {input_ids.shape}")
    print(f"[DATA] attention_mask batch shape: {attention_mask.shape}")
    print(f"[DATA] labels batch shape: {labels.shape}")
    return images, input_ids, attention_mask, labels

def train_one_epoch(
    model: ImageCaptioningModel,
    train_loader,
    optimizer: AdamW,
    scheduler,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    scaler: torch.cuda.amp.GradScaler,
    writer: SummaryWriter,
) -> float:
    """
    Train the model for a single epoch and return the average training loss.
    """
    model.train()
    running_loss = 0.0
    num_steps = 0
    grad_accum_steps = cfg.gradient_accumulation_steps

    progress = tqdm(train_loader, desc=f"Epoch {epoch} [train]", unit="batch")
    for step, batch in enumerate(progress):
        images = batch["image"].to(device)
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        # Mixed-precision forward pass; autocast is disabled on non-CUDA devices.
        with torch.cuda.amp.autocast(enabled=(device.type == "cuda" and cfg.mixed_precision)):
            outputs = model(
                images=images,
                captions=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss

        if loss is None:
            raise RuntimeError("Model did not return a loss during training.")

        # Divide the loss so gradients accumulated over grad_accum_steps batches
        # average to a single effective batch.
        loss = loss / grad_accum_steps
        scaler.scale(loss).backward()

        if (step + 1) % grad_accum_steps == 0:
            # Unscale before clipping so the gradient norm is measured in true units.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad(set_to_none=True)
            scheduler.step()

        running_loss += loss.item() * grad_accum_steps
        num_steps += 1
        avg_loss = running_loss / num_steps
        progress.set_postfix({"loss": f"{avg_loss:.4f}"})

    # Flush gradients left over when the number of batches is not divisible by
    # grad_accum_steps, so the final partial accumulation still updates the model.
    if num_steps % grad_accum_steps != 0:
        scaler.unscale_(optimizer)
        torch.nn.utils.clip_grad_norm_(model.parameters(), cfg.max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
        scheduler.step()

    epoch_loss = running_loss / max(1, num_steps)
    writer.add_scalar("Loss/train", epoch_loss, epoch)
    return epoch_loss
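
# Note: scheduler.step() is called once per optimizer update, not once per batch,
# so warmup_steps and the cosine decay are measured in optimizer steps. With gradient
# accumulation enabled, one optimizer step spans grad_accum_steps batches.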

def evaluate(
    model: ImageCaptioningModel,
    val_loader,
    device: torch.device,
    cfg: TrainingConfig,
    epoch: int,
    writer: SummaryWriter,
) -> float:
    """
    Evaluate the model on a validation split and return the average loss.
    """
    model.eval()
    running_loss = 0.0
    num_steps = 0

    with torch.no_grad():
        progress = tqdm(val_loader, desc=f"Epoch {epoch} [val]", unit="batch")
        for batch in progress:
            images = batch["image"].to(device)
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            outputs = model(
                images=images,
                captions=input_ids,
                attention_mask=attention_mask,
                labels=labels,
            )
            loss = outputs.loss
            if loss is None:
                raise RuntimeError("Model did not return a loss during validation.")

            running_loss += loss.item()
            num_steps += 1
            avg_loss = running_loss / num_steps
            progress.set_postfix({"val_loss": f"{avg_loss:.4f}"})

    val_loss = running_loss / max(1, num_steps)
    writer.add_scalar("Loss/val", val_loss, epoch)
    return val_loss
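
# Note: evaluate() reports teacher-forced cross-entropy loss on the validation split.
# Caption-quality metrics such as BLEU or CIDEr are not computed here; they would
# require running generation separately (e.g. via a generate method, if the model exposes one).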

def main() -> None:
    args = parse_args()

    # Configuration and setup
    paths_cfg = PathsConfig(data_root=args.data_root)
    training_cfg = create_training_config_from_args(args)
    ensure_dir(training_cfg.output_dir)
    ensure_dir(training_cfg.log_dir)
    set_seed(training_cfg.seed)
    device = get_device()

    # Data
    tokenizer = create_tokenizer()
    train_loader, tokenizer = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="train",
        tokenizer=tokenizer,
        shuffle=True,
    )
    val_loader, _ = create_dataloader(
        paths_cfg=paths_cfg,
        training_cfg=training_cfg,
        split="val",
        tokenizer=tokenizer,
        shuffle=False,
    )

    # Validate dataset loading
    validate_dataloader(train_loader, device)

    # Model and optimization
    model = ImageCaptioningModel(training_cfg=training_cfg)
    model.to(device)  # move the model to the same device the batches are sent to
    optimizer = AdamW(model.parameters(), lr=training_cfg.learning_rate)
    total_training_steps = math.ceil(
        len(train_loader) / max(1, training_cfg.gradient_accumulation_steps)
    ) * training_cfg.num_epochs
    scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=training_cfg.warmup_steps,
        num_training_steps=total_training_steps,
    )
    scaler = torch.cuda.amp.GradScaler(enabled=(device.type == "cuda" and training_cfg.mixed_precision))
    writer = SummaryWriter(log_dir=training_cfg.log_dir)

    best_val_loss = float("inf")
    epochs_without_improvement = 0

    try:
        for epoch in range(1, training_cfg.num_epochs + 1):
            train_loss = train_one_epoch(
                model=model,
                train_loader=train_loader,
                optimizer=optimizer,
                scheduler=scheduler,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                scaler=scaler,
                writer=writer,
            )
            val_loss = evaluate(
                model=model,
                val_loader=val_loader,
                device=device,
                cfg=training_cfg,
                epoch=epoch,
                writer=writer,
            )
            print(f"[EPOCH {epoch}] train_loss={train_loss:.4f} val_loss={val_loss:.4f}")

            # Checkpointing and early stopping on validation loss
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                epochs_without_improvement = 0
                best_path = os.path.join(training_cfg.output_dir, "best_model.pt")
                torch.save(model.state_dict(), best_path)
                print(f"[CHECKPOINT] Saved new best model to {best_path}")
            else:
                epochs_without_improvement += 1
                print(
                    f"[EARLY STOP] No improvement for {epochs_without_improvement} "
                    f"epoch(s) (patience={training_cfg.patience})."
                )
                if epochs_without_improvement >= training_cfg.patience:
                    print("Early stopping triggered.")
                    break
    except Exception as exc:  # noqa: BLE001
        print(f"[ERROR] Training failed with error: {exc}")
        raise
    finally:
        writer.close()


if __name__ == "__main__":
    main()
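
# Reloading the best checkpoint later (a minimal sketch, assuming the same TrainingConfig
# used during training; the checkpoint is a plain state_dict saved by torch.save above):
#
#     model = ImageCaptioningModel(training_cfg=TrainingConfig())
#     model.load_state_dict(torch.load("checkpoints/best_model.pt", map_location="cpu"))
#     model.eval()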