# Scraped page header (not part of the program):
# hash-map's picture · files · b466b8b verified · raw · history blame · 4.99 kB
import os
import random
import numpy as np
import tensorflow as tf
from pathlib import Path
from tqdm import tqdm
import cv2 # for reading images from disk
# -------------------------------------------------
# CONFIGURATION
# -------------------------------------------------
# Relative paths below are resolved against the current working directory.
GENERATOR_PATH = "generator_final.h5"   # trained generator, loaded with tf.keras
VIS_FOLDER = "data/train/visible"       # input visible-spectrum images
IR_FOLDER = "data/train/infrared"       # optional ground-truth IR, matched by filename
SAVE_DIR = "output/train_results"       # destination for the side-by-side PNGs
NUM_SAMPLES = 10                        # number of randomly chosen inputs to render
IMG_SIZE = (256, 256)                   # (width, height) passed to cv2.resize; match the model input
SEED = 42                               # fixes which files are sampled
# -------------------------------------------------
# 1. Load the generator
# -------------------------------------------------
# Loaded once at import time; generate_from_folder() reads the module-level
# name `generator`. compile=False means the model is not compiled after
# loading, which plain inference via model(...) does not require.
print("Loading generator from {} ...".format(GENERATOR_PATH))
generator = tf.keras.models.load_model(filepath=GENERATOR_PATH, compile=False)
print("Generator loaded successfully.")
# -------------------------------------------------
# 2. Helper: preprocess image (resize + normalize to [-1, 1])
# -------------------------------------------------
def load_and_preprocess_image(img_path, target_size=IMG_SIZE):
    """Read an image from disk and turn it into a single-item model batch.

    Args:
        img_path: path string handed to cv2.imread.
        target_size: size handed to cv2.resize, which OpenCV interprets as
            (width, height).

    Returns:
        float32 array of shape (1, H, W, C) in RGB channel order with values
        scaled from [0, 255] to [-1, 1]. C is 3 for cv2.imread's default
        colour mode.

    Raises:
        FileNotFoundError: cv2.imread returned None (missing or undecodable file).
    """
    bgr = cv2.imread(img_path)
    if bgr is None:
        raise FileNotFoundError(f"Image not found: {img_path}")
    # OpenCV decodes as BGR; the model is fed RGB.
    rgb = cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), target_size)
    scaled = rgb.astype(np.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return scaled[np.newaxis, ...]  # prepend the batch axis
# -------------------------------------------------
# 3. Helper: convert [-1,1] tensor → uint8 image
# -------------------------------------------------
def to_uint8(tensor):
    """Map values in [-1, 1] back to 8-bit pixel values in [0, 255].

    Inverse of the ``x / 127.5 - 1`` normalisation used when loading images.
    Values outside [-1, 1] are clipped first.

    Args:
        tensor: array-like of floats (anything ``np.asarray`` accepts).

    Returns:
        ``np.ndarray`` of dtype uint8 with the same shape as *tensor*.
    """
    scaled = (np.clip(np.asarray(tensor), -1.0, 1.0) + 1.0) * 127.5
    # Round to nearest rather than truncate: astype(uint8) floors, so a value
    # such as 127.9999 (float error from the forward mapping) would become 127
    # and the pixel would not survive a load -> to_uint8 round trip.
    # After clipping, scaled is already within [0, 255], so no second clip.
    return np.rint(scaled).astype(np.uint8)
# -------------------------------------------------
# 4. Main function
# -------------------------------------------------
# File extensions accepted as input images (compared case-insensitively).
_IMAGE_SUFFIXES = frozenset({".png", ".jpg", ".jpeg", ".bmp"})


def _list_images(folder):
    """Return the image files directly inside *folder*, sorted by path."""
    return [
        path
        for path in sorted(Path(folder).glob("*.*"))
        if path.is_file() and path.suffix.lower() in _IMAGE_SUFFIXES
    ]


def _to_rgb3(img):
    """Return *img* as (H, W, 3), replicating a single channel when needed.

    A generator that emits one IR channel would otherwise break both the
    horizontal concatenation with 3-channel images and cv2.COLOR_RGB2BGR.
    """
    if img.ndim == 2:
        img = img[:, :, np.newaxis]
    if img.shape[-1] == 1:
        img = np.repeat(img, 3, axis=-1)
    return img


def _load_ground_truth_ir(ir_path, size):
    """Read *ir_path* as an RGB uint8 image resized to *size* = (width, height).

    Returns None when the file does not exist or OpenCV cannot decode it
    (cv2.imread signals failure by returning None instead of raising).
    """
    if not ir_path.exists():
        return None
    bgr = cv2.imread(str(ir_path))
    if bgr is None:
        return None
    return cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), size)


def generate_from_folder(
    vis_folder, ir_folder=None, save_dir=SAVE_DIR, num_samples=NUM_SAMPLES, seed=SEED
):
    """Render side-by-side [ground-truth IR | generated IR] PNGs for random inputs.

    Picks up to *num_samples* images from *vis_folder*, runs each through the
    module-level ``generator`` and writes one PNG per input to *save_dir*.
    The generator is assumed to take (1, H, W, 3) in [-1, 1] and return
    (1, H, W, C) in [-1, 1] (TODO confirm for the trained model).

    Args:
        vis_folder: folder of visible-spectrum input images.
        ir_folder: optional folder of ground-truth IR images with the same
            filenames. When falsy, or a GT file is missing or unreadable,
            the left panel is black.
        save_dir: output folder; created if needed.
        num_samples: maximum number of images to render (clamped to
            [0, number of images found]).
        seed: seed for the private RNG that picks the sample.

    Raises:
        ValueError: *vis_folder* contains no supported images.
        FileNotFoundError: a visible image cannot be read.
        OSError: OpenCV fails to write an output PNG.
    """
    os.makedirs(save_dir, exist_ok=True)

    vis_paths = _list_images(vis_folder)
    if not vis_paths:
        raise ValueError(f"No images found in {vis_folder}")

    # A private RNG keeps the sample reproducible without reseeding the global
    # `random` module; Random(seed).sample picks the same files that
    # random.seed(seed); random.sample(...) would.
    rng = random.Random(seed)
    k = max(0, min(num_samples, len(vis_paths)))
    sample_paths = rng.sample(vis_paths, k)
    print(f"Generating {len(sample_paths)} random side-by-side images...")

    for idx, vis_path in enumerate(tqdm(sample_paths)):
        vis_tensor = load_and_preprocess_image(str(vis_path))
        pred_tensor = generator(vis_tensor, training=False)
        pred_img = _to_rgb3(to_uint8(pred_tensor[0].numpy()))
        height, width = pred_img.shape[:2]

        # Left panel: ground-truth IR when available, otherwise black.
        left = np.zeros_like(pred_img)
        if ir_folder:
            gt = _load_ground_truth_ir(Path(ir_folder) / vis_path.name, (width, height))
            if gt is None:
                # tqdm.write keeps the progress bar intact.
                tqdm.write(
                    f"Warning: IR not found or unreadable for {vis_path.name}, "
                    "using black placeholder."
                )
            else:
                left = gt
        row = np.concatenate([left, pred_img], axis=1)

        save_path = os.path.join(save_dir, f"sample_{idx:02d}_{vis_path.stem}.png")
        # cv2.imwrite reports failure via its return value, not an exception.
        if not cv2.imwrite(save_path, cv2.cvtColor(row, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write {save_path}")

    print(f"All {len(sample_paths)} images saved to {save_dir}")
# -------------------------------------------------
# 5. RUN
# -------------------------------------------------
if __name__ == "__main__":
    # Pair each prediction with the ground-truth IR image of the same filename.
    # To render [black | generated] rows instead, pass ir_folder=None.
    generate_from_folder(
        VIS_FOLDER,
        ir_folder=IR_FOLDER,
        num_samples=NUM_SAMPLES,
    )