# galaxy-simplifier / inference.py
# Uploaded by erukude ("Upload folder using huggingface_hub"), commit 177588a (verified).
import os
import numpy as np
from keras.models import load_model
from keras.preprocessing.image import load_img, img_to_array
from matplotlib import pyplot
######################
# Configuration
######################
# Placeholder locations: replace with real files/directories before running.
RESNET_PATH = "path_to_resnet50_model.h5"  # ResNet50 galaxy classifier
CGAN_PATH = "path_to_cgan_model.h5"  # first cGAN (skeletonization stage)
POST_CGAN_PATH = "path_to_postprocess_cgan_model.h5"  # second cGAN (post-processing stage)
DATA_PATH = "path_to_test_dir"  # directory scanned for input images
OUTPUT_PATH = "path_to_output_dir"  # directory that receives generated images

# (height, width) passed to load_img as target_size when reading images.
HEIGHT = WIDTH = 256
TARGET_SIZE = (HEIGHT, WIDTH)
BATCH_SIZE = 32  # NOTE(review): not referenced in the visible part of this script

# Make sure the output directory exists; an existing one is left untouched.
os.makedirs(OUTPUT_PATH, exist_ok=True)
# Load the three networks used by the pipeline, announcing each one as soon
# as it is available (classifier first, then the two cGAN stages).
def _load_and_report(model_path, announcement):
    """Load the Keras model stored at *model_path*, then print *announcement*."""
    loaded = load_model(model_path)
    print(announcement)
    return loaded


resnet_model = _load_and_report(RESNET_PATH, "Resnet50 loaded successfully!")
cgan_model = _load_and_report(CGAN_PATH, "cGAN loaded successfully!")
post_cgan_model = _load_and_report(POST_CGAN_PATH, "Post-processing cGAN loaded successfully!")
######################
# Image loading and preprocessing
######################
def load_and_preprocess(img_path: str, model: str = "resnet") -> np.ndarray:
    """
    Desc:
        Read an image file, resize it to TARGET_SIZE and scale its pixel values
        for the network that will consume it.
    Args:
        img_path (str): Location of the image file on disk.
        model (str): Target network. "resnet" maps pixels to [0, 1]; any other
            value (the script uses "cgan") maps them to [-1, 1].
    Returns:
        np.ndarray: The scaled image with a leading batch axis of size 1.
    """
    pixels = img_to_array(load_img(img_path, target_size=TARGET_SIZE))
    batch = pixels[np.newaxis, ...]  # add the batch dimension
    if model != "resnet":
        # Zero-centred scaling for the cGANs: 0 -> -1.0, 255 -> 1.0.
        return (batch - 127.5) / 127.5
    return batch / 255.0
######################
# Saving generated output
######################
def plot_generated_image(gen_image: np.ndarray, filename: str) -> None:
    """
    Save a generated image to disk after rescaling it from [-1, 1] to [0, 1].

    Only the first image of the batch is written. After rescaling, values are
    clipped to [0, 1] because ``pyplot.imsave`` rejects floating-point RGB(A)
    data outside that range. A single-channel image (trailing channel axis of
    size 1) is saved as grayscale because ``imsave`` does not accept an
    (H, W, 1) array.

    Args:
        gen_image (np.ndarray): Generator output, expected shape (1, H, W, C)
            with values in [-1, 1].
        filename (str): File name to create inside OUTPUT_PATH, including the
            extension (e.g. "image.png"); imsave picks the format from it.
    Returns:
        None
    """
    # Scale from [-1, 1] to [0, 1]. Clip so small overshoots in the model
    # output cannot make imsave raise.
    image = np.clip((np.asarray(gen_image)[0] + 1) / 2.0, 0.0, 1.0)
    output_filename = os.path.join(OUTPUT_PATH, filename)
    if image.ndim == 3 and image.shape[-1] == 1:
        # (H, W, 1) -> (H, W); map intensities with a fixed gray scale.
        pyplot.imsave(output_filename, image[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        pyplot.imsave(output_filename, image)
# Decision rule for forwarding an image to the cGAN pipeline: the ResNet50's
# arg-max class must be the spiral class (index 1) AND that class's score
# must exceed the threshold.
SPIRAL_CLASS_INDEX = 1
SPIRAL_CONFIDENCE_THRESHOLD = 0.65
IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

all_ctr = 0  # image files examined
spiral_ctr = 0  # images accepted as spirals and written to OUTPUT_PATH
# === Loop through images ===
# os.listdir returns entries in arbitrary order; sorting makes the processing
# order (and any partial output after a failure) reproducible.
for filename in sorted(os.listdir(DATA_PATH)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue
    img_path = os.path.join(DATA_PATH, filename)
    # A sub-directory named like an image (e.g. "foo.png") would otherwise be
    # handed to load_img and abort the whole run.
    if not os.path.isfile(img_path):
        continue
    all_ctr += 1

    # Step 1: Classify with ResNet50
    resnet_input = load_and_preprocess(img_path, model="resnet")
    resnet_preds = resnet_model.predict(resnet_input, verbose=0)
    scores = resnet_preds[0]  # predictions for the single image in the batch
    predicted_class = np.argmax(scores)
    # `and` short-circuits, so scores[SPIRAL_CLASS_INDEX] is only read when
    # the spiral class won the arg-max.
    if not (predicted_class == SPIRAL_CLASS_INDEX
            and scores[SPIRAL_CLASS_INDEX] > SPIRAL_CONFIDENCE_THRESHOLD):
        continue

    # Step 2: Process with first cGAN (skeletonization)
    cgan_input = load_and_preprocess(img_path, model="cgan")
    cgan_output = cgan_model.predict(cgan_input, verbose=0)
    # Step 3: Post-process with second cGAN (smoothing/connecting lines);
    # it takes the first cGAN's output directly as input.
    post_output = post_cgan_model.predict(cgan_output, verbose=0)
    # Step 4: Save final post-processed output under the input's file name
    plot_generated_image(post_output, filename)
    spiral_ctr += 1

print(f"Found '{spiral_ctr}' spiral galaxies in '{all_ctr}' images.")