File size: 3,629 Bytes
177588a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import numpy as np
from   keras.models              import load_model
from   keras.preprocessing.image import load_img, img_to_array
from   matplotlib                import pyplot


# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Saved Keras models used by the pipeline, in the order they are applied.
RESNET_PATH    = "path_to_resnet50_model.h5"          # classifier (class 1 = spiral galaxy)
CGAN_PATH      = "path_to_cgan_model.h5"              # stage-1 cGAN (skeletonization)
POST_CGAN_PATH = "path_to_postprocess_cgan_model.h5"  # stage-2 cGAN (smoothing/connecting lines)

# Input images are read from DATA_PATH; final outputs are written to OUTPUT_PATH.
DATA_PATH   = "path_to_test_dir"
OUTPUT_PATH = "path_to_output_dir"

# Every image is resized to this (height, width) before reaching any model.
HEIGHT = WIDTH = 256
TARGET_SIZE = (HEIGHT, WIDTH)
BATCH_SIZE = 32  # NOTE(review): not referenced by the visible script

# Create the output directory up front so later saves cannot fail on a missing folder.
os.makedirs(OUTPUT_PATH, exist_ok=True)


# ---------------------------------------------------------------------------
# Model loading (executed at import time; each load is confirmed on stdout)
# ---------------------------------------------------------------------------
resnet_model = load_model(RESNET_PATH)
print("Resnet50 loaded successfully!")

cgan_model = load_model(CGAN_PATH)
print("cGAN loaded successfully!")

post_cgan_model = load_model(POST_CGAN_PATH)
print("Post-processing cGAN loaded successfully!")


######################
# Image loading / normalization
######################
def load_and_preprocess(img_path: str, model: str = "resnet") -> np.ndarray:
    """

    Desc:

        Read an image file, resize it to TARGET_SIZE and return it as a
        one-image batch normalized for the requested network.

    Args:

        img_path (str): Location of the image on disk.

        model (str): Which network the batch is destined for.

                     "resnet" divides pixel values by 255 (i.e. [0, 1] for
                     8-bit images); any other value (e.g. "cgan",
                     "post_cgan") maps them to [-1, 1].

    Returns:

        np.ndarray: The image with a leading batch axis of length 1.

    """
    pixels = img_to_array(load_img(img_path, target_size=TARGET_SIZE))
    batch = pixels[np.newaxis, ...]  # add the batch axis expected by model.predict

    if model != "resnet":
        # cGAN-family inputs are centred on zero: 0 -> -1, 255 -> +1.
        return (batch - 127.5) / 127.5
    return batch / 255.0
    

######################
# Output writing
######################
def plot_generated_image(gen_image: np.ndarray, filename: str) -> None:
    """

    Desc:

        Rescale a generated image from [-1, 1] to [0, 1] and save it to
        OUTPUT_PATH under the given filename.

    Args:

        gen_image (np.ndarray): Generated batch, expected shape (1, H, W, C);
                                only the first image of the batch is saved.
                                C may be 1 (saved as grayscale), 3 or 4.

        filename (str): The filename to save the image as (including extension, e.g., "image.png").

    Returns:

        None

    Raises:

        ValueError: Propagated from matplotlib if the image shape is not
                    2-D, RGB or RGBA after the single-channel squeeze.

    """
    # Map [-1, 1] -> [0, 1], then clip: matplotlib raises ValueError for float
    # RGB data outside [0, 1], and generator outputs can overshoot slightly.
    image = np.clip((np.asarray(gen_image)[0] + 1.0) / 2.0, 0.0, 1.0)

    output_filename = os.path.join(OUTPUT_PATH, filename)

    if image.ndim == 3 and image.shape[-1] == 1:
        # imsave only accepts (H, W), (H, W, 3) or (H, W, 4); drop the channel
        # axis and pin the grayscale range so values are not auto-rescaled.
        pyplot.imsave(output_filename, image[..., 0], cmap="gray", vmin=0.0, vmax=1.0)
    else:
        pyplot.imsave(output_filename, image)


# Decision rule for forwarding an image to the two cGAN stages.
SPIRAL_CLASS_INDEX    = 1     # ResNet output index treated as "spiral galaxy"
SPIRAL_MIN_CONFIDENCE = 0.65  # ResNet probability for that class must exceed this
IMAGE_EXTENSIONS      = ('.jpg', '.jpeg', '.png')

all_ctr    = 0  # image files examined
spiral_ctr = 0  # images classified as spiral and run through both cGANs

# === Loop through images ===
# sorted(): os.listdir order is arbitrary, so sort for a reproducible run.
for filename in sorted(os.listdir(DATA_PATH)):
    img_path = os.path.join(DATA_PATH, filename)

    # Skip non-images, and directories that merely carry an image extension.
    if not filename.lower().endswith(IMAGE_EXTENSIONS) or not os.path.isfile(img_path):
        continue
    all_ctr += 1

    # Step 1: Classify with ResNet50
    resnet_input = load_and_preprocess(img_path, model="resnet")
    resnet_preds = resnet_model.predict(resnet_input, verbose=0)

    predicted_class = np.argmax(resnet_preds, axis=1)[0]
    spiral_prob     = resnet_preds[0][SPIRAL_CLASS_INDEX]

    # Only confidently-classified spiral galaxies go on to the cGAN stages.
    if predicted_class != SPIRAL_CLASS_INDEX or spiral_prob <= SPIRAL_MIN_CONFIDENCE:
        continue

    # Step 2: Process with first cGAN (skeletonization)
    cgan_input  = load_and_preprocess(img_path, model="cgan")
    cgan_output = cgan_model.predict(cgan_input, verbose=0)

    # Step 3: Post-process with second cGAN (smoothing/connecting lines).
    # The stage-1 output is fed straight in, i.e. in the stage-1 output range.
    post_output = post_cgan_model.predict(cgan_output, verbose=0)

    # Step 4: Save final post-processed output
    plot_generated_image(post_output, filename)
    spiral_ctr += 1

# Summary is reported once, after every image has been processed.
print(f"Found '{spiral_ctr}' spiral galaxies in '{all_ctr}' images.")