"""Training entry point for the image-captioning model.

Orchestrates training/validation epochs, vocabulary build/load, checkpoint
resumption, LR scheduling on validation BLEU-4, and periodic checkpointing.
"""

import os
import time
import math
import pickle
import gc  # For memory optimization between epochs

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence

from .model import ImageCaptioningModel  # Import the model
from .data_preprocessing import COCODataset, COCOVocabulary  # Import data handling classes
from .evaluation import calculate_bleu_scores_detailed  # Import evaluation metric
from .utils import get_logger, get_train_transform, get_eval_transform  # Import utilities

logger = get_logger(__name__)


def _safe_perplexity(loss_value):
    """Return exp(loss_value), clamping to +inf instead of raising.

    math.exp raises OverflowError for arguments above ~709, and a plain
    `loss < float('inf')` guard does not prevent that, so catch it explicitly.
    """
    try:
        return math.exp(loss_value)
    except OverflowError:
        return float('inf')


def train_epoch(model, train_loader, criterion, optimizer, device, epoch, config):
    """
    Performs a single training epoch.

    Args:
        model (nn.Module): The image captioning model.
        train_loader (DataLoader): DataLoader for training data.
        criterion (nn.Module): Loss function.
        optimizer (torch.optim.Optimizer): Optimizer.
        device (torch.device): Device to run training on (cpu/cuda).
        epoch (int): Current epoch number (0-indexed).
        config (dict): Configuration dictionary.

    Returns:
        float: Average training loss for the epoch.
    """
    model.train()  # Set model to training mode
    running_loss = 0.0
    start_time = time.time()
    total_batches = len(train_loader)

    for i, (images, captions, lengths, _) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        lengths = lengths.to(device)

        # Forward pass.
        # scores: (batch_size, max_decode_length_in_batch, vocab_size)
        # caps_sorted: (batch_size, max_padded_length_from_dataset)
        # decode_lengths: list of actual lengths for current batch (after sorting)
        scores, caps_sorted, decode_lengths, _, _ = model(images, captions, lengths)

        # Pack scores to remove padding and ensure correct length for loss
        # calculation; this matches the dynamic lengths of the sequences.
        scores_packed = pack_padded_sequence(scores, decode_lengths, batch_first=True).data

        # The target sequence is `caption[1:]` because the model predicts the
        # word at `t+1` given `caption[t]` (drops the start-of-sequence token).
        targets = caps_sorted[:, 1:]
        targets_packed = pack_padded_sequence(targets, decode_lengths, batch_first=True).data

        loss = criterion(scores_packed, targets_packed)

        # Backward pass and optimize.
        optimizer.zero_grad()  # Clear gradients from previous step
        loss.backward()  # Compute gradients
        # Gradient clipping prevents exploding gradients, common in RNNs.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.get('grad_clip', 5.0))
        optimizer.step()  # Update model parameters

        running_loss += loss.item()

        # Log training progress periodically.
        if (i + 1) % config.get('log_step', 100) == 0:
            current_loss = loss.item()
            perplexity = _safe_perplexity(current_loss)
            logger.info(f"Epoch [{epoch+1}/{config['num_epochs']}], Step [{i+1}/{total_batches}], "
                        f"Loss: {current_loss:.4f}, Perplexity: {perplexity:.4f}")

    epoch_loss = running_loss / total_batches
    epoch_time = time.time() - start_time
    logger.info(f"Epoch {epoch+1} Training finished. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f}s")
    return epoch_loss


def validate_epoch(model, val_loader, criterion, vocabulary, device, config):
    """
    Performs a single validation epoch.

    Generates captions for a subset of the validation set to calculate BLEU scores.

    Args:
        model (nn.Module): The image captioning model.
        val_loader (DataLoader): DataLoader for validation data.
        criterion (nn.Module): Loss function (used for validation loss).
        vocabulary (COCOVocabulary): Vocabulary object, used for converting indices to words.
        device (torch.device): Device to run validation on (cpu/cuda).
        config (dict): Configuration dictionary.

    Returns:
        tuple: (Average validation loss, list of generated captions, list of reference captions)
    """
    model.eval()  # Set model to evaluation mode
    val_running_loss = 0.0
    val_generated_captions = []
    val_reference_captions = []

    # `val_inference_batches` controls how many batches to run inference on;
    # None means all batches. Hoisted out of the loop (it is loop-invariant).
    val_inference_batches_limit = config.get('val_inference_batches')

    with torch.no_grad():  # Disable gradient calculations for validation
        total_batches = len(val_loader)

        for i, (images, captions, lengths, _) in enumerate(val_loader):
            images = images.to(device)
            val_captions_for_loss = captions.to(device)
            val_lengths_for_loss = lengths.to(device)

            # Forward pass for loss calculation (mirrors the training path).
            val_scores, val_caps_sorted, val_decode_lengths, _, _ = model(
                images, val_captions_for_loss, val_lengths_for_loss)
            val_scores_packed = pack_padded_sequence(
                val_scores, val_decode_lengths, batch_first=True).data
            val_targets = val_caps_sorted[:, 1:]  # Drop start-of-sequence token
            val_targets_packed = pack_padded_sequence(
                val_targets, val_decode_lengths, batch_first=True).data

            val_loss = criterion(val_scores_packed, val_targets_packed)
            val_running_loss += val_loss.item()

            # Generate captions with beam search for a limited number of batches.
            if val_inference_batches_limit is None or i < val_inference_batches_limit:
                for j in range(images.size(0)):
                    image_tensor_single = images[j]  # Single image tensor (C, H, W)
                    generated_caption = model.generate_caption(
                        image_tensor_single,
                        vocabulary,
                        device,
                        beam_size=config.get('val_beam_size', 3),  # Beam search for validation
                        max_length=config.get('max_caption_length', 20)
                    )
                    # Convert reference caption indices back to a string for metrics.
                    reference_caption_str = vocabulary.indices_to_caption(captions[j].cpu().numpy())
                    val_generated_captions.append(generated_caption)
                    val_reference_captions.append(reference_caption_str)

    val_avg_loss = val_running_loss / total_batches
    perplexity = _safe_perplexity(val_avg_loss)
    logger.info(f"Validation Avg Loss: {val_avg_loss:.4f}, Perplexity: {perplexity:.4f}")
    return val_avg_loss, val_generated_captions, val_reference_captions


def train_model(config):
    """
    Main training function. Orchestrates training and validation epochs.

    Args:
        config (dict): Configuration dictionary containing all training parameters.

    Returns:
        tuple: (Trained model, optimizer, scheduler, vocabulary)
    """
    logger.info("Starting training process...")

    # Set device (CUDA if available, else CPU).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load data paths from configuration.
    data_folder = config['data_folder']
    train_image_folder = config['train_image_folder']
    val_image_folder = config['val_image_folder']
    train_caption_file = config['train_caption_file']
    val_caption_file = config['val_caption_file']

    # Fail fast if caption files are missing.
    if not os.path.exists(train_caption_file):
        raise FileNotFoundError(f"Training caption file not found: {train_caption_file}")
    if not os.path.exists(val_caption_file):
        raise FileNotFoundError(f"Validation caption file not found: {val_caption_file}")

    # Image transformations for training and validation.
    train_transform = get_train_transform()
    val_transform = get_eval_transform()  # Use eval transform for validation images

    # ======================== VOCABULARY HANDLING ========================
    # Try to load a pre-saved vocabulary from the output directory first.
    VOCABULARY_FILE_PATH = os.path.join(config['output_dir'], 'vocabulary.pkl')
    vocabulary = None

    if os.path.exists(VOCABULARY_FILE_PATH):
        try:
            with open(VOCABULARY_FILE_PATH, 'rb') as f:
                vocabulary = pickle.load(f)
            logger.info(f"Loaded vocabulary from {VOCABULARY_FILE_PATH}")
        except Exception as e:
            logger.warning(f"Could not load vocabulary from {VOCABULARY_FILE_PATH}: {e}. Will attempt to build new vocabulary.")
            vocabulary = None  # Ensure it's None if loading fails
    else:
        logger.info(f"Vocabulary file not found at {VOCABULARY_FILE_PATH}. Will build new vocabulary.")

    # If the vocabulary could not be loaded, build a new one.
    if vocabulary is None:
        logger.info("Building new vocabulary from training dataset...")
        # Temporary dataset only for vocabulary construction; no image
        # transforms are needed for that.
        temp_train_dataset_for_vocab = COCODataset(
            image_dir=os.path.join(data_folder, train_image_folder),  # Image dir still needed for dataset init
            caption_file=train_caption_file,
            subset_size=config.get('vocab_subset_size'),  # Optional subset for vocab building
            transform=None,
            vocabulary=None  # Explicitly build a new vocabulary
        )
        vocabulary = temp_train_dataset_for_vocab.vocabulary
        del temp_train_dataset_for_vocab  # Free up memory
        gc.collect()  # Force garbage collection
        logger.info("New vocabulary built.")

        # Persist the newly built vocabulary for future runs.
        try:
            os.makedirs(os.path.dirname(VOCABULARY_FILE_PATH), exist_ok=True)
            with open(VOCABULARY_FILE_PATH, 'wb') as f:
                pickle.dump(vocabulary, f)
            logger.info(f"Saved newly built vocabulary to {VOCABULARY_FILE_PATH}")
        except Exception as e:
            logger.error(f"Error saving newly built vocabulary to {VOCABULARY_FILE_PATH}: {e}")
    # ===========================================================================

    # Create datasets for training and validation using the determined vocabulary.
    train_dataset = COCODataset(
        image_dir=os.path.join(data_folder, train_image_folder),
        caption_file=train_caption_file,
        vocabulary=vocabulary,
        max_caption_length=config.get('max_caption_length', 20),
        subset_size=config.get('train_subset_size'),
        transform=train_transform
    )
    val_dataset = COCODataset(
        image_dir=os.path.join(data_folder, val_image_folder),
        caption_file=val_caption_file,
        vocabulary=vocabulary,  # Same vocabulary as training
        max_caption_length=config.get('max_caption_length', 20),
        subset_size=config.get('val_subset_size'),
        transform=val_transform
    )

    # Create data loaders.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.get('batch_size', 64),
        shuffle=True,  # Shuffle training data
        num_workers=config.get('num_workers', 2),
        pin_memory=True  # Pin memory for faster data transfer to GPU
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.get('batch_size', 64),
        shuffle=False,  # Do not shuffle validation data
        num_workers=config.get('num_workers', 2),
        pin_memory=True
    )

    logger.info(f"Training dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")

    # Initialize model.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=config.get('embed_dim', 256),
        attention_dim=config.get('attention_dim', 256),
        decoder_dim=config.get('decoder_dim', 256),
        dropout=config.get('dropout', 0.5),
        fine_tune_encoder=config.get('fine_tune_encoder', True),
        max_caption_length=config.get('max_caption_length', 20)  # For model's generate_caption
    ).to(device)

    # Loss function: CrossEntropyLoss ignores the padding token in targets.
    # NOTE(review): the key '' looks like a stripped special-token name
    # (likely '<pad>') — confirm against COCOVocabulary's token naming.
    criterion = nn.CrossEntropyLoss(ignore_index=vocabulary.word2idx['']).to(device)

    # Separate parameter groups for encoder and decoder so that different
    # learning rates can be used when fine-tuning the encoder.
    encoder_params = list(model.encoder.parameters())
    decoder_params = list(model.decoder.parameters())
    optimizer = optim.Adam([
        {'params': encoder_params,
         'lr': config.get('encoder_learning_rate', 1e-5) if config.get('fine_tune_encoder', True) else 0.0},
        {'params': decoder_params, 'lr': config.get('learning_rate', 4e-4)}
    ])

    # LR scheduler: reduces learning rate when validation BLEU-4 stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',  # Monitor a metric we want to maximize (BLEU-4)
        factor=config.get('lr_reduce_factor', 0.5),  # LR reduction factor
        patience=config.get('lr_patience', 5),  # Epochs with no improvement before reducing
        verbose=True,  # Print messages when LR is updated
        min_lr=1e-7  # Minimum learning rate
    )

    # ======================== RESUMPTION LOGIC ========================
    start_epoch = 0
    # Best score starts at 0.0 ('max' mode), so any improvement is recorded.
    best_val_score = 0.0
    output_dir = config['output_dir']
    models_dir = config['models_dir']

    # Find the latest checkpoint: best_model_bleu*.pth first, then model_epoch_*.pth.
    latest_checkpoint_path = None
    saved_models = [f for f in os.listdir(models_dir)
                    if f.startswith('best_model_bleu') and f.endswith('.pth')]
    if not saved_models:
        saved_models = [f for f in os.listdir(output_dir)
                        if f.startswith('model_epoch_') and f.endswith('.pth')]

    if saved_models:
        if 'best_model_bleu' in saved_models[0]:
            # Sort by the BLEU score embedded in the filename.
            latest_checkpoint_name = max(
                saved_models, key=lambda f: float(f.split('bleu')[1].replace('.pth', '')))
        else:
            # For 'model_epoch_X.pth', sort by epoch number.
            latest_checkpoint_name = sorted(
                saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))[-1]

        # Determine the full path of the latest checkpoint.
        if latest_checkpoint_name.startswith('best_model_bleu'):
            latest_checkpoint_path = os.path.join(models_dir, latest_checkpoint_name)
        else:
            latest_checkpoint_path = os.path.join(output_dir, latest_checkpoint_name)

        logger.info(f"Attempting to resume training from checkpoint: {latest_checkpoint_path}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Restore scheduler state if present (important for correct LR adjustment).
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            else:
                logger.warning("Scheduler state not found in checkpoint. Scheduler will restart its state.")
            start_epoch = checkpoint['epoch']
            # Safely get best_val_score, default to 0.0 if not found.
            best_val_score = checkpoint.get('best_val_score', 0.0)
            logger.info(f"Resumed training from epoch {start_epoch}. Best validation score so far: {best_val_score:.4f}")
        except Exception as e:
            logger.error(f"Could not load checkpoint from {latest_checkpoint_path}: {e}. Starting training from scratch.")
            # Reset state if loading fails.
            start_epoch = 0
            best_val_score = 0.0
    else:
        logger.info("No checkpoint found. Starting training from scratch.")
    # ===========================================================================

    # Training loop.
    num_epochs = config.get('num_epochs', 10)
    for epoch in range(start_epoch, num_epochs):  # Start from 'start_epoch' when resuming
        # Train for one epoch.
        epoch_train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch, config)

        # Validate after each training epoch.
        val_avg_loss, val_generated_captions, val_reference_captions = validate_epoch(
            model, val_loader, criterion, vocabulary, device, config
        )

        # Compute BLEU scores on validation set for tracking and scheduler stepping.
        if val_generated_captions and val_reference_captions:
            val_bleu_scores = calculate_bleu_scores_detailed(val_reference_captions, val_generated_captions)
            current_val_score_for_scheduler = val_bleu_scores['BLEU-4']  # BLEU-4 drives the scheduler
            logger.info(f"Epoch {epoch+1} Validation BLEU-4: {current_val_score_for_scheduler:.4f}")

            # Reduce LR if BLEU-4 does not improve for 'patience' epochs.
            scheduler.step(current_val_score_for_scheduler)

            # Save the best model based on validation BLEU-4.
            if current_val_score_for_scheduler > best_val_score:
                best_val_score = current_val_score_for_scheduler
                model_path = os.path.join(models_dir, f"best_model_bleu{best_val_score:.4f}.pth")
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),  # IMPORTANT: Save scheduler state!
                    'loss': epoch_train_loss,
                    'vocabulary': vocabulary,
                    'config': config,  # Save config for easy loading later
                    'best_val_score': best_val_score  # Save the best score achieved
                }, model_path)
                logger.info(f"Saved best model checkpoint to {model_path}")
        else:
            logger.warning("No captions generated during validation for metric calculation. Scheduler stepped with 0.0.")
            scheduler.step(0.0)  # Step with a low value if no metrics

        # Save checkpoint periodically so training can resume even without a new best.
        if (epoch + 1) % config.get('save_interval', 5) == 0:
            model_path_periodic = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),  # IMPORTANT: Save scheduler state!
                'loss': epoch_train_loss,
                'vocabulary': vocabulary,
                'config': config,
                'best_val_score': best_val_score  # Also save current best score here
            }, model_path_periodic)
            logger.info(f"Saved periodic model checkpoint to {model_path_periodic}")

        # ======================== MEMORY OPTIMIZATION AFTER EACH EPOCH ========================
        logger.info("Performing memory optimization after epoch...")
        # Clear PyTorch's CUDA cache (if using GPU).
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            logger.info("CUDA cache emptied.")
        # Force Python's garbage collector to free unreferenced objects.
        gc.collect()
        logger.info("Python garbage collector run.")
        # ======================================================================================

    logger.info("Training complete.")
    return model, optimizer, scheduler, vocabulary  # For potential further use


if __name__ == '__main__':
    # When `train.py` is run directly, it initiates the training process.
    from config import TRAINING_CONFIG, update_config_with_latest_model, _MODELS_DIR, _OUTPUT_DIR

    # Ensure the config looks for the latest model in the 'models' dir.
    # Helps when pre-trained models are copied into 'models' for initial load;
    # when starting from scratch it still defaults to 0.0000.
    update_config_with_latest_model(TRAINING_CONFIG)

    logger.info("Starting model training process...")
    try:
        trained_model, optimizer, scheduler, vocabulary = train_model(TRAINING_CONFIG)
        logger.info("Model Training Complete!")

        # Save the final model explicitly so the last epoch's weights are kept
        # even if they are not the best.
        final_model_path = os.path.join(_MODELS_DIR, f"final_model_epoch_{TRAINING_CONFIG['num_epochs']}.pth")
        torch.save({
            'epoch': TRAINING_CONFIG['num_epochs'],
            'model_state_dict': trained_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'vocabulary': vocabulary,
            'config': TRAINING_CONFIG,
            'best_val_score': 0  # Placeholder, retrieve from scheduler if needed
        }, final_model_path)
        logger.info(f"Saved final model checkpoint to {final_model_path}")
    except FileNotFoundError as e:
        logger.error(f"Critical data file missing: {e}")
        logger.error("Please ensure the COCO dataset and annotation files are correctly placed as described in README.md.")
    except Exception as e:
        logger.critical(f"An unhandled error occurred during training: {e}", exc_info=True)
        # exc_info=True prints the full traceback