Spaces:
Sleeping
Sleeping
File size: 4,238 Bytes
290f366 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | """
CaptionIQ — Attention-Based CNN-LSTM Caption Generation Model
Uses Bahdanau (additive) attention over spatial CNN features for
image-specific caption generation.
"""
import tensorflow as tf
from tensorflow.keras.models import Model
from tensorflow.keras.layers import (
Input, Dense, LSTM, Embedding, Dropout, Concatenate, Layer
)
from tensorflow.keras.optimizers import Adam
import os, sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import (
EMBED_DIM, LSTM_UNITS, DROPOUT_RATE, FEATURE_DIM,
FEATURE_LOCATIONS, ATTENTION_DIM, LEARNING_RATE,
)
class BahdanauAttention(Layer):
    """
    Additive (Bahdanau) attention that pools spatial image features.

    The layer is called on a pair ``[features, hidden]``:
        features: (batch, locations, feature_dim) spatial CNN features
        hidden:   (batch, lstm_units) query vector, e.g. an LSTM output

    Computation:
        score   = V(tanh(W1(features) + W2(hidden[:, None, :])))
        weights = softmax(score) taken over the locations axis
        context = sum over locations of (weights * features)

    The returned context vector has shape (batch, feature_dim). Because the
    query changes with the caption prefix, different image regions can be
    emphasised for each predicted word.

    Args:
        units: Width of the projection space shared by W1 and W2.
    """

    def __init__(self, units, **kwargs):
        super().__init__(**kwargs)
        self.units = units
        # Projection of every spatial location, projection of the query, and
        # the scalar scoring head (names kept stable for saved weights).
        self.W1 = Dense(units, name="att_features")
        self.W2 = Dense(units, name="att_hidden")
        self.V = Dense(1, name="att_score")

    def call(self, inputs):
        features, hidden = inputs
        # Insert a length-1 locations axis so the projected query broadcasts
        # against every spatial location.
        query = tf.expand_dims(hidden, axis=1)
        projected = self.W1(features) + self.W2(query)
        logits = self.V(tf.nn.tanh(projected))  # (batch, locations, 1)
        # Normalise across locations (axis 1), not the size-1 score axis.
        weights = tf.nn.softmax(logits, axis=1)
        weighted = weights * features
        return tf.reduce_sum(weighted, axis=1)  # (batch, feature_dim)

    def get_config(self):
        return {**super().get_config(), "units": self.units}
def build_model(vocab_size: int, max_length: int) -> Model:
    """
    Assemble and compile the attention-based CNN-LSTM captioning model.

    Data flow:
        image_input   (FEATURE_LOCATIONS, FEATURE_DIM) precomputed spatial features
        caption_input (max_length,) token ids
            -> Embedding(EMBED_DIM, mask_zero=True) -> Dropout(DROPOUT_RATE)
            -> LSTM(LSTM_UNITS)
        BahdanauAttention(ATTENTION_DIM): the LSTM output is the query over
            the spatial features, producing a context vector (FEATURE_DIM,)
        [context, LSTM output] -> Concatenate -> Dense(LSTM_UNITS, relu)
            -> Dropout(DROPOUT_RATE) -> Dense(vocab_size, softmax)

    The model is compiled with categorical cross-entropy and Adam at
    LEARNING_RATE, so training targets must be one-hot over the vocabulary.

    Args:
        vocab_size: Number of output classes; token id 0 is masked as padding
            by the embedding layer.
        max_length: Length of the (padded) caption input sequence.

    Returns:
        A compiled Keras Model named "CaptionIQ" with inputs
        ``[image_input, caption_input]``.
    """
    # --- inputs ---
    image_features = Input(
        shape=(FEATURE_LOCATIONS, FEATURE_DIM), name="image_input"
    )
    token_ids = Input(shape=(max_length,), name="caption_input")

    # --- language branch ---
    # mask_zero=True propagates a padding mask so the LSTM ignores id-0 steps.
    seq = Embedding(
        vocab_size, EMBED_DIM, mask_zero=True, name="caption_embedding"
    )(token_ids)
    seq = Dropout(DROPOUT_RATE, name="caption_dropout")(seq)
    sentence_state = LSTM(LSTM_UNITS, name="caption_lstm")(seq)

    # --- attend over image regions, queried by the sentence state ---
    attended = BahdanauAttention(ATTENTION_DIM, name="attention")(
        [image_features, sentence_state]
    )

    # --- decoder head ---
    head = Concatenate(name="merge")([attended, sentence_state])
    head = Dense(LSTM_UNITS, activation="relu", name="dense_relu")(head)
    head = Dropout(DROPOUT_RATE, name="dense_dropout")(head)
    next_word_probs = Dense(
        vocab_size, activation="softmax", name="output"
    )(head)

    # --- assemble and compile ---
    captioner = Model(
        inputs=[image_features, token_ids],
        outputs=next_word_probs,
        name="CaptionIQ",
    )
    optimizer = Adam(learning_rate=LEARNING_RATE)
    captioner.compile(loss="categorical_crossentropy", optimizer=optimizer)
    return captioner
def print_model_summary(vocab_size: int = 5000, max_length: int = 34):
    """
    Build the captioning model with the given sizes and print its summary.

    Args:
        vocab_size: Vocabulary size forwarded to ``build_model``.
        max_length: Caption length forwarded to ``build_model``.

    Returns:
        The compiled model that was summarised, so callers can reuse it.
    """
    summarised = build_model(vocab_size=vocab_size, max_length=max_length)
    summarised.summary()
    return summarised
if __name__ == "__main__":
    # Running this module directly prints the architecture at default sizes.
    print_model_summary(vocab_size=5000, max_length=34)
|