Spaces:
Sleeping
Sleeping
File size: 4,398 Bytes
"""
CaptionIQ — Image Feature Extraction

Extract spatial feature maps from the `block5_pool` layer of VGG16 and/or VGG19
and save them as pickle files for training.
"""
import os
import pickle
import argparse
import numpy as np
from tqdm import tqdm
from tensorflow.keras.applications.vgg16 import VGG16, preprocess_input as vgg16_preprocess
from tensorflow.keras.applications.vgg19 import VGG19, preprocess_input as vgg19_preprocess
from tensorflow.keras.preprocessing.image import load_img, img_to_array
from tensorflow.keras.models import Model
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.config import (
FLICKR_IMAGES_DIR, IMAGE_SIZE,
VGG16_FEATURES_FILE, VGG19_FEATURES_FILE,
)
def build_feature_extractor(backbone: str = "vgg16") -> tuple:
    """
    Build a feature extractor from a pre-trained VGG model.

    The returned model maps a batch of preprocessed images to the output of
    the `block5_pool` layer, i.e. a spatial feature map (7x7x512 when
    IMAGE_SIZE is 224).

    Args:
        backbone: "vgg16" or "vgg19"

    Returns:
        (model, preprocess_fn) tuple, where `preprocess_fn` is the matching
        Keras `preprocess_input` function for the chosen backbone.

    Raises:
        ValueError: if `backbone` is neither "vgg16" nor "vgg19".
    """
    # Validate up front so a bad name fails fast, before any weights load.
    if backbone not in ("vgg16", "vgg19"):
        raise ValueError(f"Unknown backbone: {backbone}. Use 'vgg16' or 'vgg19'.")

    if backbone == "vgg16":
        build_fn, preprocess_fn = VGG16, vgg16_preprocess
    else:
        build_fn, preprocess_fn = VGG19, vgg19_preprocess

    # include_top=False skips the dense classifier head (fc1/fc2/predictions),
    # which is never used here: block5_pool is the last convolutional-stage
    # layer, so its output is unchanged while far fewer weights are loaded.
    # input_shape ties the network to the size extract_features() resizes to.
    # NOTE(review): the (H, W, 3) tuple assumes the default channels_last
    # image data format — confirm if the Keras config overrides it.
    base_model = build_fn(
        weights="imagenet",
        include_top=False,
        input_shape=(IMAGE_SIZE, IMAGE_SIZE, 3),
    )

    # Use block5_pool for spatial features (7x7x512) instead of fc2 (4096)
    model = Model(
        inputs=base_model.input,
        outputs=base_model.get_layer("block5_pool").output,
    )

    print(f"\n{backbone.upper()} feature extractor loaded")
    print(f" Output shape: {model.output_shape} (spatial features)")
    return model, preprocess_fn
def extract_features(model, preprocess_fn, images_dir: str, batch_size: int = 32) -> dict:
    """
    Extract features for all images in a directory.

    Files are selected by extension (.jpg, .jpeg, .png, case-insensitive),
    resized to IMAGE_SIZE x IMAGE_SIZE, and pushed through `model` in batches.
    Each resulting feature map of shape (h, w, c) is flattened to (h * w, c).

    Args:
        model: object with a Keras-style `predict(batch, ...)` method.
        preprocess_fn: callable applied to the stacked image batch before
            prediction (e.g. the backbone's `preprocess_input`).
        images_dir: directory containing the image files.
        batch_size: number of images sent to `model.predict` at once.

    Returns:
        dict mapping filename → numpy array of shape (h * w, c)
        — (49, 512) for a 7x7x512 `block5_pool` output.
        Images that fail to load or process are skipped with a warning.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
        OSError: if `images_dir` cannot be listed.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    # os.listdir returns entries in arbitrary order; sort so the processing
    # order (and the key order of the saved dict) is reproducible.
    image_files = sorted(
        f for f in os.listdir(images_dir)
        if f.lower().endswith((".jpg", ".jpeg", ".png"))
    )
    print(f"Extracting features for {len(image_files)} images...")

    features = {}
    if image_files:  # nothing to show a progress bar for otherwise
        with tqdm(total=len(image_files), desc="Extracting") as progress:
            for start in range(0, len(image_files), batch_size):
                chunk = image_files[start:start + batch_size]

                # Load images one by one so a single unreadable file only
                # drops that file, not the whole batch.
                names, arrays = [], []
                for fname in chunk:
                    filepath = os.path.join(images_dir, fname)
                    try:
                        image = load_img(filepath, target_size=(IMAGE_SIZE, IMAGE_SIZE))
                        arrays.append(img_to_array(image))
                        names.append(fname)
                    except Exception as e:  # deliberate best-effort: warn and continue
                        print(f" Warning: Failed to process {fname}: {e}")

                if arrays:
                    try:
                        # One predict call per batch instead of one per image.
                        batch = preprocess_fn(np.stack(arrays))
                        feature_maps = model.predict(batch, batch_size=len(names), verbose=0)
                    except Exception as e:  # keep going, as with load failures
                        for fname in names:
                            print(f" Warning: Failed to process {fname}: {e}")
                    else:
                        for fname, feature in zip(names, feature_maps):
                            # Flatten the spatial grid: (h, w, c) -> (h * w, c)
                            h, w, c = feature.shape
                            features[fname] = feature.reshape(h * w, c)

                progress.update(len(chunk))

    print(f"Extracted features for {len(features)} images")
    return features
def save_features(features: dict, filepath: str):
    """
    Save features dict to pickle file.

    Args:
        features: dict to serialize (filename → feature array).
        filepath: destination path of the pickle file; missing parent
            directories are created.

    Raises:
        OSError: if the directory or file cannot be created or written.
    """
    # Create the destination folder first: this runs at the very end of a long
    # extraction, so a missing directory must not throw the results away.
    parent_dir = os.path.dirname(filepath)
    if parent_dir:  # empty when filepath has no directory component
        os.makedirs(parent_dir, exist_ok=True)

    with open(filepath, "wb") as f:
        pickle.dump(features, f)

    size_mb = os.path.getsize(filepath) / (1024 * 1024)
    print(f"Features saved to: {filepath} ({size_mb:.1f} MB)")
def main():
    """
    Command-line entry point: extract and save VGG16 and/or VGG19 features.

    Reads `--backbone` ("vgg16", "vgg19" or "both"; default "vgg19") from the
    command line, extracts features for every image in FLICKR_IMAGES_DIR with
    each selected backbone, and writes them to the matching features file.
    Prints an error and returns early if the image directory does not exist.
    """
    cli = argparse.ArgumentParser(description="Extract VGG features from images")
    cli.add_argument(
        "--backbone", type=str, default="vgg19",
        choices=["vgg16", "vgg19", "both"],
        help="Which backbone to use for extraction (default: vgg19)"
    )
    selection = cli.parse_args().backbone

    # Guard clause: nothing to do without the dataset on disk.
    if not os.path.exists(FLICKR_IMAGES_DIR):
        print(f"Error: Image directory not found: {FLICKR_IMAGES_DIR}")
        print("Please run preprocess.py first to download the dataset.")
        return

    if selection == "both":
        pending = ["vgg16", "vgg19"]
    else:
        pending = [selection]

    banner = "=" * 60
    for name in pending:
        print("\n" + banner)
        print(f" Extracting {name.upper()} features")
        print(banner)

        extractor, preprocess = build_feature_extractor(name)
        extracted = extract_features(extractor, preprocess, FLICKR_IMAGES_DIR)

        # Each backbone writes to its own pickle file.
        if name == "vgg16":
            destination = VGG16_FEATURES_FILE
        else:
            destination = VGG19_FEATURES_FILE
        save_features(extracted, destination)

    print("\n✓ Feature extraction complete!")
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
|