"""Inference script for a SAR image colorization generator.

Downloads a trained Keras generator from the Hugging Face Hub, converts a
SAR image into the model's input format, runs the generator, and converts
the output back into a PIL image.
"""

from typing import Optional

import numpy as np
import tensorflow as tf
from huggingface_hub import hf_hub_download
from PIL import Image
from tensorflow.keras.models import load_model

# ------------------------------
# Configuration
# ------------------------------
# Hugging Face Hub location of the trained generator. Point these at your own
# repository if you publish a different model.
REPO_ID = "saraayum/sar-colorization-generator"
MODEL_FILENAME = "sar_colorization_generator_final.keras"

# (height, width) the input image is resized to before it is fed to the model.
IMAGE_SIZE = (256, 256)

# ------------------------------
# Load the generator model from the Hugging Face Hub
# ------------------------------
model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME)

# compile=False: only inference is performed, so optimizer/loss state is unused.
# SECURITY: safe_mode=False allows Keras to deserialize Lambda layers, which can
# execute arbitrary Python code stored in the .keras archive. Only load models
# from a repository you trust. If the saved model contains no Lambda layers,
# this flag can likely be dropped (NOTE(review): confirm against the model).
generator = load_model(model_path, compile=False, safe_mode=False)


# ------------------------------
# Preprocess the SAR image
# ------------------------------
def preprocess_image(image_path: str, size=IMAGE_SIZE) -> tf.Tensor:
    """Load an image file and convert it into a generator input batch.

    The image is decoded as a single grayscale channel, resized to ``size``,
    scaled from [0, 255] to [-1, 1], and given a leading batch axis.

    Args:
        image_path: Path to the input image (PNG, JPEG, BMP or GIF).
        size: Target (height, width); defaults to ``IMAGE_SIZE``.

    Returns:
        A float32 tensor of shape (1, height, width, 1).
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image accepts more formats than decode_png. expand_animations=False
    # guarantees a rank-3 (H, W, C) tensor even for GIFs, which resize() needs.
    image = tf.io.decode_image(raw_bytes, channels=1, expand_animations=False)
    image = tf.image.resize(image, size)
    image = tf.cast(image, tf.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return tf.expand_dims(image, 0)  # add batch dimension


# ------------------------------
# Postprocess the generator output
# ------------------------------
def postprocess_output(output_tensor) -> Image.Image:
    """Convert the first item of a generator output batch into a PIL image.

    Args:
        output_tensor: Generator output whose values are expected in [-1, 1]
            (the inverse of the input scaling); batch axis first.

    Returns:
        An 8-bit PIL image. A single-channel output yields a grayscale image.
    """
    pixels = (output_tensor[0] + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    pixels = tf.clip_by_value(pixels, 0.0, 255.0)
    # Round instead of truncating so e.g. 254.9 becomes 255 rather than 254.
    pixels = tf.cast(tf.round(pixels), tf.uint8).numpy()
    # PIL cannot build an image from an (H, W, 1) array; drop the singleton
    # channel so it is treated as a grayscale (H, W) image.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)


# ------------------------------
# Predict function for inference
# ------------------------------
def predict(image_path: str, save_path: Optional[str] = "output.png") -> Image.Image:
    """Colorize a SAR image with the generator and optionally save the result.

    Args:
        image_path: Path to the input SAR image.
        save_path: Output file path; PIL infers the format from its extension.
            Pass ``None`` to skip writing a file.

    Returns:
        The colorized image as a PIL image.
    """
    sar_input = preprocess_image(image_path)
    gen_output = generator(sar_input, training=False)
    output_image = postprocess_output(gen_output)
    if save_path is not None:
        output_image.save(save_path)
        print(f"Colorized image saved as: {save_path}")
    return output_image


# ------------------------------
# Example usage (runs only when executed as a script, not on import)
# ------------------------------
if __name__ == "__main__":
    predict("sample_sar.png", "predicted_colorized.png")