CaptionIQ / src /train.py
pavanpraneeth's picture
Upload folder using huggingface_hub
290f366 verified
Raw
History Blame Contribute Delete
10.2 kB
"""
CaptionIQ — Training Script
Data generator + training loop for VGG16 and VGG19 captioning models.
"""
import os
import sys
import argparse
import math
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import (
EPOCHS, BATCH_SIZE, MAX_LENGTH, FEATURE_DIM, FEATURE_LOCATIONS,
VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
VGG16_MODEL_FILE, VGG19_MODEL_FILE,
VGG16_LOSS_PLOT, VGG19_LOSS_PLOT,
CAPTIONS_FILE, TOKENIZER_FILE,
TRAIN_IMAGES_FILE, VAL_IMAGES_FILE,
MODELS_DIR, OUTPUTS_DIR,
EARLY_STOP_PATIENCE, LR_PATIENCE, LR_FACTOR,
)
from src.utils import load_captions, load_image_list, load_features, load_tokenizer, get_vocab_size
from src.model import build_model, BahdanauAttention
def data_generator(captions: dict, features: dict, tokenizer,
                   max_length: int, vocab_size: int, batch_size: int):
    """
    Endlessly yield batches of ((image_features, partial_sequence), next_word).

    Every caption is expanded into one sample per predicted token:
        input:  image features + first k tokens (post-padded to ``max_length``)
        target: one-hot encoding of token k+1

    Image ids are reshuffled at the start of every pass over the data. A batch
    is emitted as soon as ``batch_size`` samples are collected; whatever is
    left at the end of a pass is emitted as a smaller final batch.

    Args:
        captions: Mapping of image id -> iterable of caption strings.
        features: Mapping of image id -> feature array. Ids missing here are skipped.
        tokenizer: Object exposing ``texts_to_sequences``.
        max_length: Length the partial sequences are padded to.
        vocab_size: Number of classes used for the one-hot target.
        batch_size: Number of samples per full batch.

    Yields:
        ((img_features_batch, seq_batch), target_batch) as
        float32 / int32 / float32 numpy arrays.

    Raises:
        ValueError: If a complete pass produces no sample at all. Without this
            guard the outer loop would spin forever and never yield, hanging
            whoever is iterating the generator.
    """

    def _as_batch(imgs, seqs, targets):
        # Convert the accumulated Python lists into the arrays the model consumes.
        return (
            (np.array(imgs, dtype=np.float32),
             np.array(seqs, dtype=np.int32)),
            np.array(targets, dtype=np.float32),
        )

    image_ids = list(captions.keys())
    while True:
        np.random.shuffle(image_ids)
        img_batch, seq_batch, target_batch = [], [], []
        produced_any = False

        for img_id in image_ids:
            if img_id not in features:
                continue
            img_feature = features[img_id]

            for caption in captions[img_id]:
                # Encode the caption as a sequence of integers.
                seq = tokenizer.texts_to_sequences([caption])[0]

                # One (prefix -> next token) pair per position after the first.
                for i in range(1, len(seq)):
                    in_seq = pad_sequences([seq[:i]], maxlen=max_length, padding="post")[0]
                    out_word = to_categorical([seq[i]], num_classes=vocab_size)[0]

                    img_batch.append(img_feature)
                    seq_batch.append(in_seq)
                    target_batch.append(out_word)
                    produced_any = True

                    if len(img_batch) >= batch_size:
                        yield _as_batch(img_batch, seq_batch, target_batch)
                        img_batch, seq_batch, target_batch = [], [], []

        # Emit the remaining samples of this pass as a final, smaller batch.
        if img_batch:
            yield _as_batch(img_batch, seq_batch, target_batch)

        if not produced_any:
            raise ValueError(
                "data_generator produced no samples: no image id has both "
                "features and a caption of at least two tokens."
            )
def make_dataset(captions, features, tokenizer, max_length, vocab_size, batch_size):
    """Build a prefetching tf.data.Dataset that streams batches from ``data_generator``."""

    def batch_source():
        # Factory handed to tf.data; returns a fresh generator on each call.
        return data_generator(captions, features, tokenizer,
                              max_length, vocab_size, batch_size)

    # Leading dimension is None because the batch size varies (final partial batch).
    image_spec = tf.TensorSpec(shape=(None, FEATURE_LOCATIONS, FEATURE_DIM), dtype=tf.float32)
    sequence_spec = tf.TensorSpec(shape=(None, max_length), dtype=tf.int32)
    target_spec = tf.TensorSpec(shape=(None, vocab_size), dtype=tf.float32)

    batches = tf.data.Dataset.from_generator(
        batch_source,
        output_signature=((image_spec, sequence_spec), target_spec),
    )
    return batches.prefetch(tf.data.AUTOTUNE)
def count_steps(captions: dict, features: dict, tokenizer, batch_size: int) -> int:
    """
    Count how many batches one pass over the data produces.

    Mirrors the sample expansion done by ``data_generator``: a caption that
    tokenizes to ``n`` ids contributes ``n - 1`` samples, and captions whose
    image id has no entry in ``features`` are ignored.

    Args:
        captions: Mapping of image id -> iterable of caption strings.
        features: Mapping of image id -> feature array (only membership is checked).
        tokenizer: Object exposing ``texts_to_sequences``.
        batch_size: Number of samples per batch.

    Returns:
        ceil(total_samples / batch_size), but never less than 1.
    """
    total = 0
    for img_id, caps in captions.items():
        if img_id not in features:
            continue
        for caption in caps:
            seq = tokenizer.texts_to_sequences([caption])[0]
            # A caption that tokenizes to nothing yields zero samples; without
            # the clamp it would subtract one from the running total.
            total += max(0, len(seq) - 1)
    # Use ceil so the final partial batch of the pass is counted as a step.
    return max(1, math.ceil(total / batch_size))
def compute_max_length(captions: dict) -> int:
    """Return the largest whitespace-delimited word count over all captions (0 if there are none)."""
    word_counts = (
        len(text.split())
        for caption_group in captions.values()
        for text in caption_group
    )
    return max(word_counts, default=0)
def plot_loss(history: dict, filepath: str, title: str):
    """
    Draw the loss curves stored in ``history`` and write the figure to ``filepath``.

    The "loss" series is always drawn; "val_loss" is added when present.
    """
    plt.figure(figsize=(10, 6))

    # (history key, legend label, line colour) for each curve to draw.
    curves = [("loss", "Training Loss", "#4a90d9")]
    if "val_loss" in history:
        curves.append(("val_loss", "Validation Loss", "#e74c3c"))
    for key, legend_label, line_color in curves:
        plt.plot(history[key], label=legend_label, color=line_color, linewidth=2)

    plt.title(title, fontsize=14, fontweight="bold")
    plt.xlabel("Epoch", fontsize=12)
    plt.ylabel("Loss", fontsize=12)
    plt.legend(fontsize=11)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()

    plt.savefig(filepath, dpi=150)
    plt.close()
    print(f"Loss plot saved to: {filepath}")
def _backbone_files(backbone: str):
    """Return (features_file, model_file, loss_plot); any value other than "vgg16" selects VGG19."""
    if backbone == "vgg16":
        return VGG16_FEATURES_FILE, VGG16_MODEL_FILE, VGG16_LOSS_PLOT
    return VGG19_FEATURES_FILE, VGG19_MODEL_FILE, VGG19_LOSS_PLOT


def _load_or_build_model(model_file: str, vocab_size: int, max_length: int, resume: bool):
    """Load the checkpoint at ``model_file`` when resuming is possible, otherwise build a new model."""
    if resume and os.path.exists(model_file):
        print(f"\nLoading existing checkpoint to resume training: {model_file}")
        try:
            return load_model(
                model_file,
                custom_objects={"BahdanauAttention": BahdanauAttention},
            )
        except Exception as exc:
            # Best effort: an unreadable checkpoint must not block training.
            print(f" Warning: could not load existing model ({exc}). Rebuilding from scratch.")
            return build_model(vocab_size, max_length)

    print("\nBuilding model...")
    return build_model(vocab_size, max_length)


def _training_callbacks(model_file: str) -> list:
    """Create the checkpoint, LR-reduction and early-stopping callbacks, all monitoring val_loss."""
    return [
        ModelCheckpoint(
            model_file, monitor="val_loss", save_best_only=True, verbose=1
        ),
        ReduceLROnPlateau(
            monitor="val_loss", factor=LR_FACTOR, patience=LR_PATIENCE,
            min_lr=1e-6, verbose=1
        ),
        EarlyStopping(
            monitor="val_loss", patience=EARLY_STOP_PATIENCE,
            restore_best_weights=True, verbose=1
        ),
    ]


def train(backbone: str = "vgg16", epochs: int = None, resume: bool = True):
    """
    Train the captioning model with the specified backbone features.

    Args:
        backbone: "vgg16" or "vgg19" (any value other than "vgg16" uses the VGG19 files).
        epochs: Override for the default epoch count; a falsy value uses EPOCHS.
        resume: When True and a saved model exists at the backbone's model
            path, continue training from it instead of building a new model.

    Returns:
        Tuple of (trained model, history object returned by ``model.fit``).
    """
    epochs = epochs or EPOCHS

    # ── Select files based on backbone ──
    features_file, model_file, loss_plot = _backbone_files(backbone)

    print("=" * 60)
    print(f" CaptionIQ — Training with {backbone.upper()}")
    print("=" * 60)

    # ── Load data ──
    print("\nLoading data...")
    all_captions = load_captions(CAPTIONS_FILE)
    tokenizer = load_tokenizer(TOKENIZER_FILE)
    features = load_features(features_file)
    # Sets give constant-time membership for the split filters below; this
    # assumes load_image_list returns an iterable of hashable ids.
    train_images = set(load_image_list(TRAIN_IMAGES_FILE))
    val_images = set(load_image_list(VAL_IMAGES_FILE))

    vocab_size = get_vocab_size(tokenizer)
    print(f" Vocabulary size: {vocab_size}")

    # Filter captions for train/val splits
    train_captions = {k: v for k, v in all_captions.items() if k in train_images}
    val_captions = {k: v for k, v in all_captions.items() if k in val_images}
    print(f" Train images: {len(train_captions)}")
    print(f" Val images: {len(val_captions)}")

    # Compute max length from training data
    max_length = compute_max_length(train_captions)
    print(f" Max caption length: {max_length}")

    # ── Build or resume model ──
    model = _load_or_build_model(model_file, vocab_size, max_length, resume)
    model.summary()

    # ── Create tf.data.Dataset ──
    train_dataset = make_dataset(
        train_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )
    val_dataset = make_dataset(
        val_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )
    train_steps = count_steps(train_captions, features, tokenizer, BATCH_SIZE)
    val_steps = count_steps(val_captions, features, tokenizer, BATCH_SIZE)
    print(f" Train steps/epoch: {train_steps}")
    print(f" Val steps/epoch: {val_steps}")

    # ── Callbacks ──
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    callbacks = _training_callbacks(model_file)

    # ── Train ──
    # The datasets repeat forever, so the step counts define the epoch length.
    print(f"\nTraining for {epochs} epochs...")
    history = model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=epochs,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=callbacks,
        verbose=1,
    )

    # ── Save final model ──
    # NOTE(review): this writes to the same path ModelCheckpoint uses with
    # save_best_only=True. If training ends without the best weights being
    # restored into `model`, the best checkpoint is replaced by the last
    # epoch's weights — confirm that is intended.
    model.save(model_file)
    print(f"\nModel saved to: {model_file}")

    # ── Plot loss ──
    plot_loss(
        history.history, loss_plot,
        f"CaptionIQ Training Loss — {backbone.upper()}"
    )
    return model, history
def main():
    """Parse command-line options and train the selected backbone(s) in sequence."""
    parser = argparse.ArgumentParser(description="Train CaptionIQ model")
    parser.add_argument(
        "--backbone", type=str, default="vgg19",
        choices=["vgg16", "vgg19", "both"],
        help="Which backbone features to use (default: vgg19)"
    )
    parser.add_argument(
        "--epochs", type=int, default=None,
        help=f"Number of epochs (default: {EPOCHS})"
    )
    parser.add_argument(
        "--resume", action=argparse.BooleanOptionalAction, default=True,
        help="Resume from existing model checkpoint if available (default: enabled)"
    )
    args = parser.parse_args()

    # "both" expands to one training run per backbone, VGG16 first.
    backbones = ["vgg16", "vgg19"] if args.backbone == "both" else [args.backbone]
    for backbone in backbones:
        train(backbone, args.epochs, args.resume)

    print("\n✓ Training complete!")


# Run the CLI only when this file is executed as a script.
if __name__ == "__main__":
    main()