# Training script for the attention-based image captioning model:
# epoch-level train/validate loops, checkpointing with resumption, and
# vocabulary build/load handling.
| import os | |
| import time | |
| import math | |
| import pickle | |
| import gc # For memory optimization | |
| import torch | |
| import torch.nn as nn | |
| import torch.optim as optim | |
| from torch.utils.data import DataLoader | |
| from torch.nn.utils.rnn import pack_padded_sequence | |
| from .model import ImageCaptioningModel # Import the model | |
| from .data_preprocessing import COCODataset, COCOVocabulary # Import data handling classes | |
| from .evaluation import calculate_bleu_scores_detailed # Import evaluation metric | |
| from .utils import get_logger, get_train_transform, get_eval_transform # Import utilities | |
| logger = get_logger(__name__) | |
def train_epoch(model, train_loader, criterion, optimizer, device, epoch, config):
    """
    Run a single training epoch over `train_loader`.

    Args:
        model (nn.Module): The image captioning model; its forward must return
            (scores, sorted_captions, decode_lengths, ...).
        train_loader (DataLoader): Yields (images, captions, lengths, _) batches.
        criterion (nn.Module): Loss function (e.g. CrossEntropyLoss).
        optimizer (torch.optim.Optimizer): Optimizer updating model parameters.
        device (torch.device): Device to run training on (cpu/cuda).
        epoch (int): Current epoch number (0-indexed; logged 1-indexed).
        config (dict): Uses 'grad_clip', 'log_step', 'num_epochs'.

    Returns:
        float: Average training loss over all batches of the epoch
            (0.0 if the loader is empty).
    """
    model.train()  # Enable dropout / batch-norm training behavior
    running_loss = 0.0
    start_time = time.time()
    total_batches = len(train_loader)
    for i, (images, captions, lengths, _) in enumerate(train_loader):
        images = images.to(device)
        captions = captions.to(device)
        lengths = lengths.to(device)

        # Forward pass.
        # scores: (batch_size, max_decode_length_in_batch, vocab_size)
        # caps_sorted: captions re-sorted by length inside the model
        # decode_lengths: actual per-sample decode lengths (after sorting)
        scores, caps_sorted, decode_lengths, _, _ = model(images, captions, lengths)

        # Pack predictions and targets so padded positions are excluded from
        # the loss. Targets drop the <START> token because the model predicts
        # the word at t+1 given the word at t.
        scores_packed = pack_padded_sequence(scores, decode_lengths, batch_first=True).data
        targets = caps_sorted[:, 1:]  # Remove <START> token from targets
        targets_packed = pack_padded_sequence(targets, decode_lengths, batch_first=True).data
        loss = criterion(scores_packed, targets_packed)

        # Backward pass and parameter update.
        optimizer.zero_grad()  # Clear gradients from the previous step
        loss.backward()
        # Clip gradients to avoid exploding gradients, common in RNN decoders.
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.get('grad_clip', 5.0))
        optimizer.step()

        running_loss += loss.item()

        # Periodic progress logging.
        if (i + 1) % config.get('log_step', 100) == 0:
            current_loss = loss.item()
            # FIX: math.exp raises OverflowError for losses above ~709; the
            # previous `current_loss < float('inf')` guard did not prevent it.
            try:
                perplexity = math.exp(current_loss)
            except OverflowError:
                perplexity = float('inf')
            logger.info(f"Epoch [{epoch+1}/{config['num_epochs']}], Step [{i+1}/{total_batches}], "
                        f"Loss: {current_loss:.4f}, Perplexity: {perplexity:.4f}")

    # Guard against an empty loader instead of dividing by zero.
    epoch_loss = running_loss / total_batches if total_batches else 0.0
    epoch_time = time.time() - start_time
    logger.info(f"Epoch {epoch+1} Training finished. Avg Loss: {epoch_loss:.4f}, Time: {epoch_time:.2f}s")
    return epoch_loss
def validate_epoch(model, val_loader, criterion, vocabulary, device, config):
    """
    Run a single validation epoch.

    Computes the validation loss over all batches and, for a configurable
    number of leading batches, generates captions with beam search so the
    caller can compute BLEU scores.

    Args:
        model (nn.Module): The image captioning model.
        val_loader (DataLoader): Yields (images, captions, lengths, _) batches.
        criterion (nn.Module): Loss function used for the validation loss.
        vocabulary (COCOVocabulary): Converts caption indices back to words.
        device (torch.device): Device to run validation on (cpu/cuda).
        config (dict): Uses 'val_inference_batches', 'val_beam_size',
            'max_caption_length'.

    Returns:
        tuple: (average validation loss, generated captions, reference captions)
    """
    model.eval()  # Disable dropout / switch batch-norm to eval statistics
    val_running_loss = 0.0
    val_generated_captions = []
    val_reference_captions = []
    # Loop-invariant: number of leading batches to run (slow) beam-search
    # inference on; None means "all batches". Hoisted out of the batch loop.
    val_inference_batches_limit = config.get('val_inference_batches')
    with torch.no_grad():  # No gradients needed during validation
        total_batches = len(val_loader)
        for i, (images, captions, lengths, _) in enumerate(val_loader):
            images = images.to(device)
            val_captions_for_loss = captions.to(device)
            val_lengths_for_loss = lengths.to(device)

            # Teacher-forced forward pass for the loss (mirrors training).
            val_scores, val_caps_sorted, val_decode_lengths, _, _ = model(images, val_captions_for_loss, val_lengths_for_loss)
            val_scores_packed = pack_padded_sequence(val_scores, val_decode_lengths, batch_first=True).data
            val_targets = val_caps_sorted[:, 1:]  # Remove <START>
            val_targets_packed = pack_padded_sequence(val_targets, val_decode_lengths, batch_first=True).data
            val_loss = criterion(val_scores_packed, val_targets_packed)
            val_running_loss += val_loss.item()

            # Beam-search caption generation on a limited number of batches.
            if val_inference_batches_limit is None or i < val_inference_batches_limit:
                for j in range(images.size(0)):
                    image_tensor_single = images[j]  # Single image tensor (C, H, W)
                    generated_caption = model.generate_caption(
                        image_tensor_single, vocabulary, device,
                        beam_size=config.get('val_beam_size', 3),  # Beam search for validation
                        max_length=config.get('max_caption_length', 20)
                    )
                    # Reference caption as a string for BLEU computation.
                    reference_caption_str = vocabulary.indices_to_caption(captions[j].cpu().numpy())
                    val_generated_captions.append(generated_caption)
                    val_reference_captions.append(reference_caption_str)

    # Guard against an empty loader instead of dividing by zero.
    val_avg_loss = val_running_loss / total_batches if total_batches else 0.0
    # FIX: math.exp raises OverflowError for large losses; the previous
    # `val_avg_loss < float('inf')` guard did not prevent it.
    try:
        perplexity = math.exp(val_avg_loss)
    except OverflowError:
        perplexity = float('inf')
    logger.info(f"Validation Avg Loss: {val_avg_loss:.4f}, Perplexity: {perplexity:.4f}")
    return val_avg_loss, val_generated_captions, val_reference_captions
def train_model(config):
    """
    Main training entry point: prepares data, vocabulary, model, optimizer and
    scheduler, then runs the train/validate loop with checkpointing and
    automatic resumption from the latest saved checkpoint.

    Args:
        config (dict): Configuration dictionary containing all training
            parameters (data paths, hyperparameters, save/log intervals).

    Returns:
        tuple: (trained model, optimizer, scheduler, vocabulary)

    Raises:
        FileNotFoundError: If a caption annotation file is missing.
    """
    logger.info("Starting training process...")

    # Select device: CUDA when available, otherwise CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Resolve data paths from configuration.
    data_folder = config['data_folder']
    train_image_folder = config['train_image_folder']
    val_image_folder = config['val_image_folder']
    train_caption_file = config['train_caption_file']
    val_caption_file = config['val_caption_file']

    # Fail fast if the annotation files are missing.
    if not os.path.exists(train_caption_file):
        raise FileNotFoundError(f"Training caption file not found: {train_caption_file}")
    if not os.path.exists(val_caption_file):
        raise FileNotFoundError(f"Validation caption file not found: {val_caption_file}")

    # Image transforms: augmenting transform for training, deterministic
    # eval transform for validation images.
    train_transform = get_train_transform()
    val_transform = get_eval_transform()

    # ======================== VOCABULARY HANDLING ========================
    # Try to load a previously saved vocabulary from the output directory;
    # build (and persist) a new one if no usable file exists.
    VOCABULARY_FILE_PATH = os.path.join(config['output_dir'], 'vocabulary.pkl')
    vocabulary = None
    if os.path.exists(VOCABULARY_FILE_PATH):
        try:
            with open(VOCABULARY_FILE_PATH, 'rb') as f:
                vocabulary = pickle.load(f)
            logger.info(f"Loaded vocabulary from {VOCABULARY_FILE_PATH}")
        except Exception as e:
            logger.warning(f"Could not load vocabulary from {VOCABULARY_FILE_PATH}: {e}. Will attempt to build new vocabulary.")
            vocabulary = None  # Ensure it's None if loading fails
    else:
        logger.info(f"Vocabulary file not found at {VOCABULARY_FILE_PATH}. Will build new vocabulary.")

    if vocabulary is None:
        logger.info("Building new vocabulary from training dataset...")
        # Temporary dataset used only to build the vocabulary; no image
        # transforms are needed for that.
        temp_train_dataset_for_vocab = COCODataset(
            image_dir=os.path.join(data_folder, train_image_folder),  # Still required by dataset init
            caption_file=train_caption_file,
            subset_size=config.get('vocab_subset_size'),  # Optional subset for faster vocab building
            transform=None,
            vocabulary=None  # Force building a fresh vocabulary
        )
        vocabulary = temp_train_dataset_for_vocab.vocabulary
        del temp_train_dataset_for_vocab  # Free memory held by the temp dataset
        gc.collect()
        logger.info("New vocabulary built.")
        # Persist the new vocabulary so future runs can load it.
        try:
            os.makedirs(os.path.dirname(VOCABULARY_FILE_PATH), exist_ok=True)
            with open(VOCABULARY_FILE_PATH, 'wb') as f:
                pickle.dump(vocabulary, f)
            logger.info(f"Saved newly built vocabulary to {VOCABULARY_FILE_PATH}")
        except Exception as e:
            logger.error(f"Error saving newly built vocabulary to {VOCABULARY_FILE_PATH}: {e}")
    # =====================================================================

    # Both datasets share the single vocabulary built/loaded above.
    train_dataset = COCODataset(
        image_dir=os.path.join(data_folder, train_image_folder),
        caption_file=train_caption_file,
        vocabulary=vocabulary,
        max_caption_length=config.get('max_caption_length', 20),
        subset_size=config.get('train_subset_size'),
        transform=train_transform
    )
    val_dataset = COCODataset(
        image_dir=os.path.join(data_folder, val_image_folder),
        caption_file=val_caption_file,
        vocabulary=vocabulary,
        max_caption_length=config.get('max_caption_length', 20),
        subset_size=config.get('val_subset_size'),
        transform=val_transform
    )

    # Data loaders: shuffle only the training data.
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.get('batch_size', 64),
        shuffle=True,
        num_workers=config.get('num_workers', 2),
        pin_memory=True  # Faster host-to-GPU transfer
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.get('batch_size', 64),
        shuffle=False,
        num_workers=config.get('num_workers', 2),
        pin_memory=True
    )
    logger.info(f"Training dataset size: {len(train_dataset)}")
    logger.info(f"Validation dataset size: {len(val_dataset)}")

    # Initialize the model and move it to the selected device.
    model = ImageCaptioningModel(
        vocab_size=vocabulary.vocab_size,
        embed_dim=config.get('embed_dim', 256),
        attention_dim=config.get('attention_dim', 256),
        decoder_dim=config.get('decoder_dim', 256),
        dropout=config.get('dropout', 0.5),
        fine_tune_encoder=config.get('fine_tune_encoder', True),
        max_caption_length=config.get('max_caption_length', 20)  # Used by generate_caption
    ).to(device)

    # CrossEntropyLoss ignores the <PAD> token in target labels.
    criterion = nn.CrossEntropyLoss(ignore_index=vocabulary.word2idx['<PAD>']).to(device)

    # Separate parameter groups so encoder and decoder can use different
    # learning rates; encoder lr is 0.0 when it is frozen.
    encoder_params = list(model.encoder.parameters())
    decoder_params = list(model.decoder.parameters())
    optimizer = optim.Adam([
        {'params': encoder_params, 'lr': config.get('encoder_learning_rate', 1e-5) if config.get('fine_tune_encoder', True) else 0.0},
        {'params': decoder_params, 'lr': config.get('learning_rate', 4e-4)}
    ])

    # Reduce the learning rate when validation BLEU-4 stops improving.
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode='max',  # BLEU-4 is maximized
        factor=config.get('lr_reduce_factor', 0.5),
        patience=config.get('lr_patience', 5),
        verbose=True,
        min_lr=1e-7
    )

    # ======================== RESUMPTION LOGIC ========================
    start_epoch = 0
    # Start at 0.0 (mode='max') so any real score counts as an improvement.
    best_val_score = 0.0
    output_dir = config['output_dir']
    models_dir = config['models_dir']
    # FIX: ensure checkpoint directories exist before listing them;
    # os.listdir raises FileNotFoundError on a fresh run otherwise, and
    # torch.save below also requires the directories to exist.
    os.makedirs(models_dir, exist_ok=True)
    os.makedirs(output_dir, exist_ok=True)

    # Prefer best_model_bleu*.pth checkpoints, then periodic model_epoch_*.pth.
    latest_checkpoint_path = None
    saved_models = [f for f in os.listdir(models_dir) if f.startswith('best_model_bleu') and f.endswith('.pth')]
    if not saved_models:
        saved_models = [f for f in os.listdir(output_dir) if f.startswith('model_epoch_') and f.endswith('.pth')]
    if saved_models:
        if 'best_model_bleu' in saved_models[0]:
            # Pick the checkpoint with the highest BLEU score encoded in its name.
            latest_checkpoint_name = max(saved_models, key=lambda f: float(f.split('bleu')[1].replace('.pth', '')))
        else:
            # Pick the checkpoint with the highest epoch number.
            latest_checkpoint_name = sorted(saved_models, key=lambda x: int(x.split('_')[-1].split('.')[0]))[-1]
        # Resolve the directory the chosen checkpoint lives in.
        if latest_checkpoint_name.startswith('best_model_bleu'):
            latest_checkpoint_path = os.path.join(models_dir, latest_checkpoint_name)
        else:
            latest_checkpoint_path = os.path.join(output_dir, latest_checkpoint_name)
        logger.info(f"Attempting to resume training from checkpoint: {latest_checkpoint_path}")
        try:
            checkpoint = torch.load(latest_checkpoint_path, map_location=device)
            model.load_state_dict(checkpoint['model_state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            # Restore scheduler state when present (needed for correct LR decay).
            if 'scheduler_state_dict' in checkpoint:
                scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
            else:
                logger.warning("Scheduler state not found in checkpoint. Scheduler will restart its state.")
            start_epoch = checkpoint['epoch']
            # Default to 0.0 when the checkpoint predates this field.
            best_val_score = checkpoint.get('best_val_score', 0.0)
            logger.info(f"Resumed training from epoch {start_epoch}. Best validation score so far: {best_val_score:.4f}")
        except Exception as e:
            logger.error(f"Could not load checkpoint from {latest_checkpoint_path}: {e}. Starting training from scratch.")
            start_epoch = 0
            best_val_score = 0.0
    else:
        logger.info("No checkpoint found. Starting training from scratch.")
    # ==================================================================

    # Training loop, starting at `start_epoch` when resuming.
    num_epochs = config.get('num_epochs', 10)
    for epoch in range(start_epoch, num_epochs):
        epoch_train_loss = train_epoch(model, train_loader, criterion, optimizer, device, epoch, config)
        val_avg_loss, val_generated_captions, val_reference_captions = validate_epoch(
            model, val_loader, criterion, vocabulary, device, config
        )

        # BLEU scores drive both the scheduler and best-model selection.
        if val_generated_captions and val_reference_captions:
            val_bleu_scores = calculate_bleu_scores_detailed(val_reference_captions, val_generated_captions)
            current_val_score_for_scheduler = val_bleu_scores['BLEU-4']
            logger.info(f"Epoch {epoch+1} Validation BLEU-4: {current_val_score_for_scheduler:.4f}")
            # ReduceLROnPlateau lowers the LR after `patience` epochs without
            # BLEU-4 improvement.
            scheduler.step(current_val_score_for_scheduler)

            # Save the best model (by validation BLEU-4) into models_dir.
            if current_val_score_for_scheduler > best_val_score:
                best_val_score = current_val_score_for_scheduler
                model_path = os.path.join(models_dir, f"best_model_bleu{best_val_score:.4f}.pth")
                torch.save({
                    'epoch': epoch + 1,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'scheduler_state_dict': scheduler.state_dict(),  # Needed for exact LR resumption
                    'loss': epoch_train_loss,
                    'vocabulary': vocabulary,
                    'config': config,  # Saved so the checkpoint is self-describing
                    'best_val_score': best_val_score
                }, model_path)
                logger.info(f"Saved best model checkpoint to {model_path}")
        else:
            logger.warning("No captions generated during validation for metric calculation. Scheduler stepped with 0.0.")
            scheduler.step(0.0)

        # Periodic checkpoint to output_dir, useful for resuming even when it
        # is not the best model so far.
        if (epoch + 1) % config.get('save_interval', 5) == 0:
            model_path_periodic = os.path.join(output_dir, f"model_epoch_{epoch+1}.pth")
            torch.save({
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),  # Needed for exact LR resumption
                'loss': epoch_train_loss,
                'vocabulary': vocabulary,
                'config': config,
                'best_val_score': best_val_score
            }, model_path_periodic)
            logger.info(f"Saved periodic model checkpoint to {model_path_periodic}")

        # ================= MEMORY OPTIMIZATION AFTER EACH EPOCH =================
        logger.info("Performing memory optimization after epoch...")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()  # Release cached CUDA allocations
            logger.info("CUDA cache emptied.")
        gc.collect()  # Reclaim unreferenced Python objects
        logger.info("Python garbage collector run.")
        # ========================================================================

    logger.info("Training complete.")
    return model, optimizer, scheduler, vocabulary
if __name__ == '__main__':
    # Running `train.py` directly initiates the full training process.
    from config import TRAINING_CONFIG, update_config_with_latest_model, _MODELS_DIR, _OUTPUT_DIR
    # Point the config at the latest model in the 'models' directory. This
    # helps when pre-trained weights were copied there for an initial load;
    # when starting from scratch it leaves the defaults untouched.
    update_config_with_latest_model(TRAINING_CONFIG)
    logger.info("Starting model training process...")
    try:
        trained_model, optimizer, scheduler, vocabulary = train_model(TRAINING_CONFIG)
        logger.info("Model Training Complete!")
        # Explicitly save the last-epoch model as well, since the "best"
        # checkpoint saved during training may be from an earlier epoch.
        final_model_path = os.path.join(_MODELS_DIR, f"final_model_epoch_{TRAINING_CONFIG['num_epochs']}.pth")
        torch.save({
            'epoch': TRAINING_CONFIG['num_epochs'],
            'model_state_dict': trained_model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'vocabulary': vocabulary,
            'config': TRAINING_CONFIG,
            'best_val_score': 0  # Placeholder, retrieve from scheduler if needed
        }, final_model_path)
        logger.info(f"Saved final model checkpoint to {final_model_path}")
    except FileNotFoundError as e:
        logger.error(f"Critical data file missing: {e}")
        logger.error("Please ensure the COCO dataset and annotation files are correctly placed as described in README.md.")
    except Exception as e:
        # exc_info=True logs the full traceback for post-mortem debugging.
        logger.critical(f"An unhandled error occurred during training: {e}", exc_info=True)