# SARImageColourisation / inference.py
# Author: saraayum — last commit b6f4361 ("Update inference.py", verified)
import tensorflow as tf
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
# ------------------------------
# Load the generator model from Hugging Face Hub
# ------------------------------
# Hugging Face Hub location of the trained generator weights.
MODEL_REPO_ID = "saraayum/sar-colorization-generator"
MODEL_FILENAME = "sar_colorization_generator_final.keras"

# Download the .keras file (or reuse a locally cached copy) and get its local path.
model_path = hf_hub_download(
    repo_id=MODEL_REPO_ID,
    filename=MODEL_FILENAME,
)
# compile=False: this script only runs inference (see predict), so optimizer and
# loss state are not needed.
# SECURITY NOTE(review): safe_mode=False allows Keras to deserialize arbitrary
# Python code (e.g. Lambda layers) stored inside the .keras file. Combined with a
# download from a remote repo that is not pinned to a revision, anyone able to
# push to that repo can run code on this machine as soon as this module is
# imported. Keep it only if the saved model really requires it (presumably it
# contains a Lambda/custom layer — confirm), and consider pinning the download
# with hf_hub_download(..., revision="<commit sha>").
generator = load_model(model_path, compile=False, safe_mode=False)
# ------------------------------
# Preprocess the SAR image
# ------------------------------
def preprocess_image(image_path):
    """Load a SAR image from disk and turn it into a one-image generator batch.

    The file is decoded as a single (grayscale) channel, resized to 256x256,
    and rescaled from the 8-bit range [0, 255] to [-1, 1].

    Args:
        image_path: Path of the image file to read. The encoded format is
            detected from the file contents (BMP, GIF, JPEG or PNG); for an
            animated GIF only the first frame is used.

    Returns:
        A float32 tensor of shape (1, 256, 256, 1).
    """
    raw_bytes = tf.io.read_file(image_path)
    # decode_image dispatches on the detected format rather than assuming PNG.
    # expand_animations=False guarantees a 3-D (H, W, C) result (GIFs would
    # otherwise decode to 4-D), which tf.image.resize below requires.
    img = tf.io.decode_image(raw_bytes, channels=1, expand_animations=False)  # SAR = grayscale, uint8
    img = tf.image.resize(img, [256, 256])  # returns float32
    img = tf.cast(img, tf.float32) / 127.5 - 1.0  # Normalize [0, 255] -> [-1, 1]
    return tf.expand_dims(img, 0)  # Add batch dimension -> (1, 256, 256, 1)
# ------------------------------
# Postprocess the generator output
# ------------------------------
def postprocess_output(output_tensor):
    """Convert the first sample of a generator output batch into a PIL image.

    Args:
        output_tensor: Generator output whose first axis is the batch axis;
            only element 0 is converted. Values are mapped from [-1, 1] to
            [0, 255] (this assumes the generator emits [-1, 1], e.g. a tanh
            output layer — TODO confirm); anything outside is clipped.

    Returns:
        A ``PIL.Image.Image`` built from the resulting uint8 pixel array.
        A single-channel output becomes a 2-D grayscale image.
    """
    pixels = (output_tensor[0] + 1.0) * 127.5  # Convert back to [0, 255]
    pixels = tf.clip_by_value(pixels, 0.0, 255.0)
    # Round before casting: a bare float->uint8 cast truncates (254.9 -> 254),
    # which biases every pixel downward.
    pixels = tf.cast(tf.round(pixels), tf.uint8)
    array = pixels.numpy()
    # Image.fromarray rejects (H, W, 1) uint8 arrays, so drop a trailing
    # singleton channel and let PIL build a grayscale image instead.
    if array.ndim == 3 and array.shape[-1] == 1:
        array = array[..., 0]
    return Image.fromarray(array)
# ------------------------------
# Predict function for inference
# ------------------------------
def predict(image_path, save_path="output.png"):
    """Colorize one SAR image with the loaded generator and save the result.

    Args:
        image_path: Path of the input SAR image.
        save_path: Where the colorized image is written; the file extension
            selects the format used by PIL's ``Image.save``.

    Returns:
        The colorized ``PIL.Image.Image`` that was saved.
    """
    # Preprocess -> run the generator in inference mode -> convert back to an image.
    colorized = postprocess_output(
        generator(preprocess_image(image_path), training=False)
    )
    colorized.save(save_path)
    print(f"Colorized image saved as: {save_path}")
    return colorized
# ------------------------------
# Example usage (runs only when this file is executed directly, not when it is imported)
# ------------------------------
if __name__ == "__main__":
    # Demo: colorize a sample SAR image that is expected in the working directory.
    predict(image_path="sample_sar.png", save_path="predicted_colorized.png")