File size: 10,158 Bytes
290f366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
"""
CaptionIQ — Training Script
Data generator + training loop for VGG16 and VGG19 captioning models.
"""

import os
import sys
import argparse
import math
import numpy as np
import tensorflow as tf
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from tensorflow.keras.models import load_model
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.preprocessing.sequence import pad_sequences
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping, ReduceLROnPlateau

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import (
    EPOCHS, BATCH_SIZE, MAX_LENGTH, FEATURE_DIM, FEATURE_LOCATIONS,
    VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
    VGG16_MODEL_FILE, VGG19_MODEL_FILE,
    VGG16_LOSS_PLOT, VGG19_LOSS_PLOT,
    CAPTIONS_FILE, TOKENIZER_FILE,
    TRAIN_IMAGES_FILE, VAL_IMAGES_FILE,
    MODELS_DIR, OUTPUTS_DIR,
    EARLY_STOP_PATIENCE, LR_PATIENCE, LR_FACTOR,
)
from src.utils import load_captions, load_image_list, load_features, load_tokenizer, get_vocab_size
from src.model import build_model, BahdanauAttention


def data_generator(captions: dict, features: dict, tokenizer,
                   max_length: int, vocab_size: int, batch_size: int):
    """
    Generator that yields batches of ((image_features, partial_sequence), next_word).

    For each image-caption pair, creates multiple training samples:
        Input: image features + first k words -> Target: (k+1)th word

    Captions are tokenized once, before the first batch, instead of on every
    pass. Images with no entry in ``features`` are skipped. The generator is
    infinite: each pass shuffles the image order and ends with a possibly
    smaller final batch, so one pass yields ``ceil(n_samples / batch_size)``
    batches.

    Partial sequences are post-padded with 0 up to ``max_length``. A prefix
    longer than ``max_length`` keeps only its last ``max_length`` tokens (the
    same result as ``pad_sequences(..., padding="post")`` with its default
    ``truncating="pre"``).

    Yields:
        ((img_features_batch, seq_batch), target_batch) where ``seq_batch`` is
        int32 with shape (B, max_length) and ``target_batch`` is a float32
        one-hot array with shape (B, vocab_size).

    Raises:
        ValueError: if no image/caption pair produces a training sample
            (otherwise the loop would spin forever without yielding).
    """
    # Pre-tokenize once; keep only captions with at least one (prefix, next) pair.
    encoded = {}
    for img_id, caps in captions.items():
        if img_id not in features:
            continue
        seqs = [
            seq for seq in tokenizer.texts_to_sequences(list(caps))
            if len(seq) > 1
        ]
        if seqs:
            encoded[img_id] = seqs

    if not encoded:
        raise ValueError(
            "data_generator: no training samples - check that caption image ids "
            "match the feature keys and that captions tokenize to >= 2 tokens"
        )

    image_ids = list(encoded)

    def _pack(imgs, seqs, targets):
        # Convert accumulated Python lists into the dtypes the dataset signature expects.
        return (
            (np.array(imgs, dtype=np.float32), np.array(seqs, dtype=np.int32)),
            np.array(targets, dtype=np.float32),
        )

    while True:
        np.random.shuffle(image_ids)
        img_batch, seq_batch, target_batch = [], [], []

        for img_id in image_ids:
            img_feature = features[img_id]
            for seq in encoded[img_id]:
                for i in range(1, len(seq)):
                    # Input: last <= max_length tokens of seq[:i], post-padded with 0.
                    prefix = seq[max(0, i - max_length):i]
                    in_seq = np.zeros(max_length, dtype=np.int32)
                    in_seq[:len(prefix)] = prefix
                    # Target: one-hot encoding of the next token.
                    out_word = np.zeros(vocab_size, dtype=np.float32)
                    out_word[seq[i]] = 1.0

                    img_batch.append(img_feature)
                    seq_batch.append(in_seq)
                    target_batch.append(out_word)

                    if len(img_batch) >= batch_size:
                        yield _pack(img_batch, seq_batch, target_batch)
                        img_batch, seq_batch, target_batch = [], [], []

        # Flush the final partial batch of this pass.
        if img_batch:
            yield _pack(img_batch, seq_batch, target_batch)


def make_dataset(captions, features, tokenizer, max_length, vocab_size, batch_size):
    """
    Build a prefetching tf.data.Dataset on top of ``data_generator``.

    The output signature declares an image-feature tensor of shape
    (batch, FEATURE_LOCATIONS, FEATURE_DIM), an int32 token-sequence tensor of
    shape (batch, max_length), and a float32 target tensor of shape
    (batch, vocab_size).
    """
    def _generator_factory():
        # from_generator needs a callable that returns a fresh generator.
        return data_generator(captions, features, tokenizer,
                              max_length, vocab_size, batch_size)

    image_spec = tf.TensorSpec(shape=(None, FEATURE_LOCATIONS, FEATURE_DIM), dtype=tf.float32)
    sequence_spec = tf.TensorSpec(shape=(None, max_length), dtype=tf.int32)
    target_spec = tf.TensorSpec(shape=(None, vocab_size), dtype=tf.float32)

    dataset = tf.data.Dataset.from_generator(
        _generator_factory,
        output_signature=((image_spec, sequence_spec), target_spec),
    )
    return dataset.prefetch(tf.data.AUTOTUNE)


def count_steps(captions: dict, features: dict, tokenizer, batch_size: int) -> int:
    """
    Count training steps (batches) per epoch.

    Each caption of an image present in ``features`` contributes one sample per
    (prefix, next-token) pair, i.e. ``len(seq) - 1`` samples. A caption that
    tokenizes to zero or one token contributes none: it is clamped at 0, not -1,
    so empty or all-OOV captions no longer reduce the count.

    Returns:
        ``ceil(total_samples / batch_size)``, and at least 1.
    """
    total = 0
    for img_id, caps in captions.items():
        if img_id not in features:
            continue
        for seq in tokenizer.texts_to_sequences(list(caps)):
            total += max(len(seq) - 1, 0)
    # Use ceil so every sample is seen each epoch (including final partial batch).
    return max(1, math.ceil(total / batch_size))


def compute_max_length(captions: dict) -> int:
    """
    Return the largest whitespace-separated word count across all captions.

    ``captions`` maps each image id to an iterable of caption strings.
    Returns 0 when there are no captions at all.
    """
    word_counts = (
        len(text.split())
        for caption_list in captions.values()
        for text in caption_list
    )
    return max(word_counts, default=0)


def plot_loss(history: dict, filepath: str, title: str):
    """
    Plot the training loss curve (plus validation loss if present) and save it.

    Args:
        history: Mapping with a "loss" series and optionally a "val_loss" series
            (``train`` passes ``history.history`` from ``model.fit``).
        filepath: Destination image path passed to ``savefig``.
        title: Figure title.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        ax.plot(history["loss"], label="Training Loss", color="#4a90d9", linewidth=2)
        if "val_loss" in history:
            ax.plot(history["val_loss"], label="Validation Loss", color="#e74c3c", linewidth=2)
        ax.set_title(title, fontsize=14, fontweight="bold")
        ax.set_xlabel("Epoch", fontsize=12)
        ax.set_ylabel("Loss", fontsize=12)
        ax.legend(fontsize=11)
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(filepath, dpi=150)
    finally:
        # Always release the figure, even if plotting or saving raises;
        # otherwise repeated calls accumulate open figures.
        plt.close(fig)
    print(f"Loss plot saved to: {filepath}")


def _backbone_paths(backbone: str):
    """Map a backbone name to its (features_file, model_file, loss_plot) paths."""
    if backbone == "vgg16":
        return VGG16_FEATURES_FILE, VGG16_MODEL_FILE, VGG16_LOSS_PLOT
    if backbone == "vgg19":
        return VGG19_FEATURES_FILE, VGG19_MODEL_FILE, VGG19_LOSS_PLOT
    raise ValueError(f"Unknown backbone {backbone!r}; expected 'vgg16' or 'vgg19'")


def _load_or_build_model(model_file: str, vocab_size: int, max_length: int, resume: bool):
    """Resume from ``model_file`` when allowed and loadable; otherwise build a fresh model."""
    if resume and os.path.exists(model_file):
        print(f"\nLoading existing checkpoint to resume training: {model_file}")
        try:
            return load_model(
                model_file,
                custom_objects={"BahdanauAttention": BahdanauAttention}
            )
        except Exception as exc:
            # Deliberate best-effort resume: an unreadable checkpoint falls back
            # to a freshly built model rather than aborting the run.
            print(f"  Warning: could not load existing model ({exc}). Rebuilding from scratch.")
            return build_model(vocab_size, max_length)
    print("\nBuilding model...")
    return build_model(vocab_size, max_length)


def _make_callbacks(model_file: str) -> list:
    """Checkpoint on best val_loss, reduce LR on plateau, and stop early."""
    return [
        ModelCheckpoint(
            model_file, monitor="val_loss", save_best_only=True, verbose=1
        ),
        ReduceLROnPlateau(
            monitor="val_loss", factor=LR_FACTOR, patience=LR_PATIENCE,
            min_lr=1e-6, verbose=1
        ),
        EarlyStopping(
            monitor="val_loss", patience=EARLY_STOP_PATIENCE,
            restore_best_weights=True, verbose=1
        ),
    ]


def train(backbone: str = "vgg16", epochs: int = None, resume: bool = True):
    """
    Train the captioning model with specified backbone features.

    Args:
        backbone: "vgg16" or "vgg19"; any other value raises ValueError.
        epochs: Override default epoch count; None (or 0) falls back to EPOCHS.
        resume: If True and the backbone's model file exists, continue training
            from it (falling back to a fresh model if it cannot be loaded).

    Returns:
        (model, history), where ``history`` is the object returned by ``model.fit``.

    Raises:
        ValueError: if ``backbone`` is not a known backbone name.
    """
    epochs = epochs or EPOCHS

    # ── Select files based on backbone ──
    features_file, model_file, loss_plot = _backbone_paths(backbone)

    print("=" * 60)
    print(f"  CaptionIQ — Training with {backbone.upper()}")
    print("=" * 60)

    # ── Load data ──
    print("\nLoading data...")
    all_captions = load_captions(CAPTIONS_FILE)
    tokenizer = load_tokenizer(TOKENIZER_FILE)
    features = load_features(features_file)
    # Sets give O(1) membership for the split filtering below.
    train_images = set(load_image_list(TRAIN_IMAGES_FILE))
    val_images = set(load_image_list(VAL_IMAGES_FILE))

    vocab_size = get_vocab_size(tokenizer)
    print(f"  Vocabulary size: {vocab_size}")

    # Filter captions for train/val splits
    train_captions = {k: v for k, v in all_captions.items() if k in train_images}
    val_captions = {k: v for k, v in all_captions.items() if k in val_images}
    print(f"  Train images: {len(train_captions)}")
    print(f"  Val images:   {len(val_captions)}")

    # Compute max length from training data
    max_length = compute_max_length(train_captions)
    print(f"  Max caption length: {max_length}")

    # ── Build or resume model ──
    model = _load_or_build_model(model_file, vocab_size, max_length, resume)
    model.summary()

    # ── Create tf.data.Dataset ──
    train_dataset = make_dataset(
        train_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )
    val_dataset = make_dataset(
        val_captions, features, tokenizer, max_length, vocab_size, BATCH_SIZE
    )

    train_steps = count_steps(train_captions, features, tokenizer, BATCH_SIZE)
    val_steps = count_steps(val_captions, features, tokenizer, BATCH_SIZE)
    print(f"  Train steps/epoch: {train_steps}")
    print(f"  Val steps/epoch:   {val_steps}")

    # ── Callbacks ──
    os.makedirs(MODELS_DIR, exist_ok=True)
    os.makedirs(OUTPUTS_DIR, exist_ok=True)
    callbacks = _make_callbacks(model_file)

    # ── Train ──
    print(f"\nTraining for {epochs} epochs...")
    history = model.fit(
        train_dataset,
        steps_per_epoch=train_steps,
        epochs=epochs,
        validation_data=val_dataset,
        validation_steps=val_steps,
        callbacks=callbacks,
        verbose=1,
    )

    # ── Save final model ──
    # NOTE(review): ModelCheckpoint(save_best_only=True) has already written the
    # best-val_loss model to model_file; this save overwrites it with the current
    # weights. Those match the best weights only if EarlyStopping restored them,
    # which depends on whether stopping triggered and on the Keras version — confirm.
    model.save(model_file)
    print(f"\nModel saved to: {model_file}")

    # ── Plot loss ──
    plot_loss(
        history.history, loss_plot,
        f"CaptionIQ Training Loss — {backbone.upper()}"
    )

    return model, history


def main():
    """Parse command-line options and train the selected backbone(s) in order."""
    parser = argparse.ArgumentParser(description="Train CaptionIQ model")
    parser.add_argument(
        "--backbone", type=str, default="vgg19",
        choices=["vgg16", "vgg19", "both"],
        help="Which backbone features to use (default: vgg19)"
    )
    parser.add_argument(
        "--epochs", type=int, default=None,
        help=f"Number of epochs (default: {EPOCHS})"
    )
    parser.add_argument(
        "--resume", action=argparse.BooleanOptionalAction, default=True,
        help="Resume from existing model checkpoint if available (default: enabled)"
    )
    args = parser.parse_args()

    # "both" trains VGG16 first, then VGG19.
    backbones = ["vgg16", "vgg19"] if args.backbone == "both" else [args.backbone]

    for backbone in backbones:
        train(backbone, args.epochs, args.resume)

    print("\n✓ Training complete!")


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()