File size: 4,398 Bytes
290f366
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
"""
CaptionIQ — Image Feature Extraction
Extract spatial feature maps from VGG16 and VGG19 (`block5_pool`).
Save features as pickle files for training.
"""

import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model

import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import (
    FLICKR_IMAGES_DIR, IMAGE_SIZE,
    VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
)


def build_feature_extractor(backbone: str = "vgg16") -> tuple:
    """
    Build a feature extractor from a pre-trained VGG model.
    Outputs `block5_pool` spatial features (7x7x512 for 224x224 inputs).

    Only the convolutional base is instantiated (`include_top=False`). The
    fully connected classifier layers are never used by this extractor, so
    loading them would only add download time and memory without changing
    the `block5_pool` activations.

    Args:
        backbone: "vgg16" or "vgg19"

    Returns:
        (model, preprocess_fn) tuple

    Raises:
        ValueError: if `backbone` is neither "vgg16" nor "vgg19".
    """
    # Validate before touching any model class so a bad name fails fast,
    # without triggering a weights download.
    if backbone == "vgg16":
        model_cls, preprocess_fn = VGG16, vgg16_preprocess
    elif backbone == "vgg19":
        model_cls, preprocess_fn = VGG19, vgg19_preprocess
    else:
        raise ValueError(f"Unknown backbone: {backbone}. Use 'vgg16' or 'vgg19'.")

    # Without the dense head the network has no fixed input size, so pin it
    # to the size the images are loaded at to keep a concrete output shape.
    base_model = model_cls(
        weights="imagenet",
        include_top=False,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    )

    # Use block5_pool for spatial features (7x7x512) instead of fc2 (4096)
    model = Model(
        inputs=base_model.input,
        outputs=base_model.get_layer("block5_pool").output
    )
    print(f"\n{backbone.upper()} feature extractor loaded")
    print(f"  Output shape: {model.output_shape}  (spatial features)")
    return model, preprocess_fn


def extract_features(model, preprocess_fn, images_dir: str) -> dict:
    """
    Run every image found in `images_dir` through `model` and collect the
    flattened spatial feature maps.

    Files are selected by extension (.jpg, .jpeg, .png, case-insensitive).
    An image that fails at any stage (loading, preprocessing, prediction,
    reshaping) is reported with a warning and left out of the result.

    Args:
        model: object exposing `predict(batch, verbose=0)`
        preprocess_fn: callable applied to the (1, H, W, C) image batch
        images_dir: directory containing the image files

    Returns:
        dict mapping filename → numpy array of shape (h * w, c),
        e.g. (49, 512) for a 7x7x512 feature map
    """
    supported_exts = (".jpg", ".jpeg", ".png")
    image_files = []
    for entry in os.listdir(images_dir):
        if entry.lower().endswith(supported_exts):
            image_files.append(entry)

    print(f"Extracting features for {len(image_files)} images...")

    features = {}
    target_size = (IMAGE_SIZE, IMAGE_SIZE)
    for fname in tqdm(image_files, desc="Extracting"):
        image_path = os.path.join(images_dir, fname)
        try:
            # Image file → preprocessed batch holding a single sample
            pixels = img_to_array(load_img(image_path, target_size=target_size))
            batch = preprocess_fn(np.expand_dims(pixels, axis=0))

            # Drop the batch axis, then collapse the spatial grid into rows
            fmap = model.predict(batch, verbose=0)[0]
            height, width, channels = fmap.shape
            features[fname] = fmap.reshape(height * width, channels)
        except Exception as exc:
            print(f"  Warning: Failed to process {fname}: {exc}")

    print(f"Extracted features for {len(features)} images")
    return features


def save_features(features: dict, filepath: str):
    """
    Save features dict to pickle file.

    The destination directory is created when it does not exist yet, so a
    finished extraction run is not lost to a missing output folder.

    Args:
        features: dict to serialize
        filepath: path of the pickle file to write (overwritten if present)
    """
    parent_dir = os.path.dirname(filepath)
    if parent_dir:  # empty when filepath is a bare filename in the cwd
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, "wb") as f:
        pickle.dump(features, f)
    size_mb = os.path.getsize(filepath) / (1024 * 1024)
    print(f"Features saved to: {filepath} ({size_mb:.1f} MB)")


def main():
    """
    Extract features using VGG16 and/or VGG19.

    Parses the `--backbone` command-line option ("vgg16", "vgg19" or
    "both"; default "vgg19"), then for each selected backbone builds the
    extractor, processes every image in FLICKR_IMAGES_DIR and writes the
    resulting features to that backbone's features file. Prints an error
    and returns without extracting when the image directory is missing.
    """
    parser = argparse.ArgumentParser(description="Extract VGG features from images")
    parser.add_argument(
        "--backbone", type=str, default="vgg19",
        choices=["vgg16", "vgg19", "both"],
        help="Which backbone to use for extraction (default: vgg19)"
    )
    args = parser.parse_args()

    # isdir rather than exists: a regular file at this path would pass an
    # existence check and then crash inside os.listdir during extraction.
    if not os.path.isdir(FLICKR_IMAGES_DIR):
        print(f"Error: Image directory not found: {FLICKR_IMAGES_DIR}")
        print("Please run preprocess.py first to download the dataset.")
        return

    # Explicit backbone → output file mapping
    output_files = {
        "vgg16": VGG16_FEATURES_FILE,
        "vgg19": VGG19_FEATURES_FILE,
    }
    backbones = list(output_files) if args.backbone == "both" else [args.backbone]

    for backbone in backbones:
        print("\n" + "=" * 60)
        print(f"  Extracting {backbone.upper()} features")
        print("=" * 60)

        model, preprocess_fn = build_feature_extractor(backbone)
        features = extract_features(model, preprocess_fn, FLICKR_IMAGES_DIR)
        save_features(features, output_files[backbone])

    print("\n✓ Feature extraction complete!")


if __name__ == "__main__":
    main()