Spaces:
Sleeping
Sleeping
| """ | |
| CaptionIQ — Training Script | |
| Data generator + training loop for VGG16 and VGG19 captioning models. | |
| """ | |
| import os | |
| import sys | |
| import argparse | |
| import math | |
| import numpy as np | |
| import tensorflow as tf | |
| import matplotlib | |
| matplotlib.use("Agg") | |
| import matplotlib.pyplot as plt | |
| from tensorflow.keras.models import load_model | |
| from tensorflow.keras.utils import to_categorical | |
| from tensorflow.keras.preprocessing.sequence import pad_sequences | |
| from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| from src.config import ( | |
| EPOCHS, BATCH_SIZE, MAX_LENGTH, FEATURE_DIM, FEATURE_LOCATIONS, | |
| VGG16_FEATURES_FILE, VGG19_FEATURES_FILE, | |
| VGG16_MODEL_FILE, VGG19_MODEL_FILE, | |
| VGG16_LOSS_PLOT, VGG19_LOSS_PLOT, | |
| CAPTIONS_FILE, TOKENIZER_FILE, | |
| TRAIN_IMAGES_FILE, VAL_IMAGES_FILE, | |
| MODELS_DIR, OUTPUTS_DIR, | |
| EARLY_STOP_PATIENCE, LR_PATIENCE, LR_FACTOR, | |
| ) | |
| from src.utils import load_captions, load_image_list, load_features, load_tokenizer, get_vocab_size | |
| from src.model import build_model, BahdanauAttention | |
def data_generator(captions: dict, features: dict, tokenizer,
                   max_length: int, vocab_size: int, batch_size: int):
    """
    Generator that yields batches of ((image_features, partial_sequence), next_word).

    For each image-caption pair, creates multiple training samples:
        Input: image features + first k words → Target: (k+1)th word

    The generator never terminates on its own: each pass reshuffles the image
    ids, emits full batches of ``batch_size`` samples, finishes the pass with
    one smaller batch holding the leftover samples, and then starts over.

    Args:
        captions: Mapping of image id to a list of caption strings.
        features: Mapping of image id to its feature array; image ids that
            are missing here are skipped.
        tokenizer: Object exposing ``texts_to_sequences`` (presumably the
            fitted Keras tokenizer -- confirm against the caller).
        max_length: Length every partial sequence is padded to.
        vocab_size: Number of classes of the one-hot encoded target word.
        batch_size: Number of samples in a full batch.

    Yields:
        ((img_features_batch, seq_batch), target_batch)

    Raises:
        ValueError: On first iteration, if the inputs cannot produce a single
            training sample. Without this check the endless loop would spin
            forever without ever yielding, hanging the consumer.
    """
    image_ids = list(captions)

    # Tokenization does not change between epochs, so do it once up front
    # instead of re-encoding every caption on every pass.
    encoded = {
        img_id: [tokenizer.texts_to_sequences([caption])[0]
                 for caption in captions[img_id]]
        for img_id in image_ids
        if img_id in features
    }

    # A caption contributes samples only when it has at least two tokens.
    if not any(len(seq) > 1 for seqs in encoded.values() for seq in seqs):
        raise ValueError(
            "data_generator has no training samples: no caption with at least "
            "two tokens belongs to an image present in `features`"
        )

    while True:
        np.random.shuffle(image_ids)
        img_batch, seq_batch, target_batch = [], [], []
        for img_id in image_ids:
            if img_id not in encoded:
                continue
            img_feature = features[img_id]
            for seq in encoded[img_id]:
                # Create input-output pairs for each subsequence
                for i in range(1, len(seq)):
                    # Input: partial sequence up to position i
                    in_seq = pad_sequences([seq[:i]], maxlen=max_length, padding="post")[0]
                    # Output: next word (one-hot)
                    out_word = to_categorical([seq[i]], num_classes=vocab_size)[0]
                    img_batch.append(img_feature)
                    seq_batch.append(in_seq)
                    target_batch.append(out_word)
                    if len(img_batch) >= batch_size:
                        yield (
                            (np.array(img_batch, dtype=np.float32),
                             np.array(seq_batch, dtype=np.int32)),
                            np.array(target_batch, dtype=np.float32)
                        )
                        img_batch, seq_batch, target_batch = [], [], []
        # Yield remaining samples of this pass as a final, smaller batch
        if img_batch:
            yield (
                (np.array(img_batch, dtype=np.float32),
                 np.array(seq_batch, dtype=np.int32)),
                np.array(target_batch, dtype=np.float32)
            )
def make_dataset(captions, features, tokenizer, max_length, vocab_size, batch_size):
    """Build a prefetching tf.data.Dataset on top of the Python batch generator."""

    def batch_source():
        # from_generator takes a zero-argument callable, so bind the arguments here.
        return data_generator(captions, features, tokenizer,
                              max_length, vocab_size, batch_size)

    # The leading (batch) dimension is left unspecified in every spec.
    image_spec = tf.TensorSpec(shape=(None, FEATURE_LOCATIONS, FEATURE_DIM), dtype=tf.float32)
    sequence_spec = tf.TensorSpec(shape=(None, max_length), dtype=tf.int32)
    target_spec = tf.TensorSpec(shape=(None, vocab_size), dtype=tf.float32)

    batches = tf.data.Dataset.from_generator(
        batch_source,
        output_signature=((image_spec, sequence_spec), target_spec),
    )
    return batches.prefetch(tf.data.AUTOTUNE)
def count_steps(captions: dict, features: dict, tokenizer, batch_size: int) -> int:
    """
    Count the number of batches needed to see every training sample once.

    Args:
        captions: Mapping of image id to a list of caption strings.
        features: Mapping of image id to features; ids missing here are not counted.
        tokenizer: Object exposing ``texts_to_sequences``.
        batch_size: Number of samples per batch.

    Returns:
        The step count, rounded up so the final partial batch is included,
        and never less than 1.
    """
    total = 0
    for img_id, image_captions in captions.items():
        if img_id not in features:
            continue
        for caption in image_captions:
            seq = tokenizer.texts_to_sequences([caption])[0]
            # A caption of n tokens gives n - 1 next-word samples. Clamp at
            # zero: a caption that tokenizes to nothing gives no samples and
            # must not subtract one from the running total.
            total += max(0, len(seq) - 1)
    # Use ceil so every sample is seen each epoch (including final partial batch).
    return max(1, math.ceil(total / batch_size))
def compute_max_length(captions: dict) -> int:
    """Return the largest whitespace-separated word count of any caption (0 if there are none)."""
    word_counts = (
        len(text.split())
        for caption_list in captions.values()
        for text in caption_list
    )
    return max(word_counts, default=0)
def plot_loss(history: dict, filepath: str, title: str):
    """
    Plot and save the training loss curve.

    Args:
        history: Mapping with a "loss" series and, optionally, a "val_loss"
            series (one value per epoch).
        filepath: Destination of the saved image.
        title: Title drawn above the plot.
    """
    fig = plt.figure(figsize=(10, 6))
    try:
        plt.plot(history["loss"], label="Training Loss", color="#4a90d9", linewidth=2)
        if "val_loss" in history:
            plt.plot(history["val_loss"], label="Validation Loss", color="#e74c3c", linewidth=2)
        plt.title(title, fontsize=14, fontweight="bold")
        plt.xlabel("Epoch", fontsize=12)
        plt.ylabel("Loss", fontsize=12)
        plt.legend(fontsize=11)
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(filepath, dpi=150)
    finally:
        # Close the figure even when plotting or saving fails, so it does not
        # stay registered with pyplot.
        plt.close(fig)
    print(f"Loss plot saved to: {filepath}")
def train(backbone: str = "vgg16", epochs: int = None, resume: bool = True):
    """
    Train the captioning model with specified backbone features.

    Args:
        backbone: "vgg16" selects the VGG16 feature, model and plot files;
            any other value selects the VGG19 ones.
        epochs: Override default epoch count; None or 0 falls back to EPOCHS.
        resume: When True and the model file already exists, continue
            training from it; if loading fails a new model is built instead.

    Returns:
        (model, history): the trained model and the object returned by
        ``model.fit``.
    """
    epochs = epochs or EPOCHS

    # ── Select files based on backbone ──
    if backbone == "vgg16":
        features_file = VGG16_FEATURES_FILE
        model_file = VGG16_MODEL_FILE
        loss_plot = VGG16_LOSS_PLOT
    else:
        features_file = VGG19_FEATURES_FILE
        model_file = VGG19_MODEL_FILE
        loss_plot = VGG19_LOSS_PLOT

    print("=" * 60)
    print(f" CaptionIQ — Training with {backbone.upper()}")
    print("=" * 60)

    # ── Load data ──
    print("\nLoading data...")
    all_captions = load_captions(CAPTIONS_FILE)
    tokenizer = load_tokenizer(TOKENIZER_FILE)
    features = load_features(features_file)
    train_images = load_image_list(TRAIN_IMAGES_FILE)
    val_images = load_image_list(VAL_IMAGES_FILE)
    vocab_size = get_vocab_size(tokenizer)
    print(f" Vocabulary size: {vocab_size}")

    # Filter captions for train/val splits. Membership is tested once per
    # captioned image, so build the id sets once for O(1) lookups.
    train_ids = set(train_images)
    val_ids = set(val_images)
    train_captions = {k: v for k, v in all_captions.items() if k in train_ids}
    val_captions = {k: v for k, v in all_captions.items() if k in val_ids}
    print(f" Train images: {len(train_captions)}")
    print(f" Val images: {len(val_captions)}")

    # Compute max length from training data
    max_length = compute_max_length(train_captions)
    print(f" Max caption length: {max_length}")

    # ── Build or resume model ──
    if resume and os.path.exists(model_file):
        print(f"\nLoading existing checkpoint to resume training: {model_file}")
        try:
            model = load_model(
                model_file,
                custom_objects={"BahdanauAttention": BahdanauAttention}
            )
        except Exception as exc:
            # Deliberate best effort: an unreadable checkpoint must not block training.
            print(f" Warning: could not load existing model ({exc}). Rebuilding from scratch.")
            model = build_model(vocab_size, max_length)
    else:
        print("\nBuilding model...")
        model = build_model(vocab_size, max_length)
    model.summary()

    # ── Create tf.data.Dataset ──
    train_dataset = make_dataset(
        train_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )
    val_dataset = make_dataset(
        val_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )
    train_steps = count_steps(train_captions, features, tokenizer, BATCH_SIZE)
    val_steps = count_steps(val_captions, features, tokenizer, BATCH_SIZE)
    print(f" Train steps/epoch: {train_steps}")
    print(f" Val steps/epoch: {val_steps}")

    # ── Callbacks ──
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    callbacks = [
        ModelCheckpoint(
            model_file, monitor="val_loss", save_best_only=True, verbose=1
        ),
        ReduceLROnPlateau(
            monitor="val_loss", factor=LR_FACTOR, patience=LR_PATIENCE,
            min_lr=1e-6, verbose=1
        ),
        EarlyStopping(
            monitor="val_loss", patience=EARLY_STOP_PATIENCE,
            restore_best_weights=True, verbose=1
        ),
    ]

    # ── Train ──
    print(f"\nTraining for {epochs} epochs...")
    history = model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=epochs,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=callbacks,
        verbose=1,
    )

    # ── Save final model ──
    # NOTE(review): this writes to the same path as ModelCheckpoint, replacing
    # the best-val_loss checkpoint with the model's current weights. Whether
    # those are the best weights depends on EarlyStopping's restore behaviour
    # in the installed Keras version -- confirm.
    model.save(model_file)
    print(f"\nModel saved to: {model_file}")

    # ── Plot loss ──
    plot_loss(
        history.history, loss_plot,
        f"CaptionIQ Training Loss — {backbone.upper()}"
    )
    return model, history
def main(argv=None):
    """
    Parse the command-line options and train the requested backbone(s).

    Args:
        argv: Optional list of argument strings. None (the default) makes
            argparse read the process command line, so a plain ``main()``
            call behaves as before; passing a list allows programmatic use.
    """
    parser = argparse.ArgumentParser(description="Train CaptionIQ model")
    parser.add_argument(
        "--backbone", type=str, default="vgg19",
        choices=["vgg16", "vgg19", "both"],
        help="Which backbone features to use (default: vgg19)"
    )
    parser.add_argument(
        "--epochs", type=int, default=None,
        help=f"Number of epochs (default: {EPOCHS})"
    )
    parser.add_argument(
        "--resume", action=argparse.BooleanOptionalAction, default=True,
        help="Resume from existing model checkpoint if available (default: enabled)"
    )
    args = parser.parse_args(argv)

    # "both" trains the two backbones one after the other.
    backbones = ["vgg16", "vgg19"] if args.backbone == "both" else [args.backbone]
    for backbone in backbones:
        train(backbone, args.epochs, args.resume)

    # NOTE(review): the leading glyph was garbled in the source; restored as a
    # check mark, which is presumably what was intended.
    print("\n✓ Training complete!")


if __name__ == "__main__":
    main()