Spaces:
Sleeping
Sleeping
import tensorflow as tf
from tensorflow.keras.models import load_model
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image

# ------------------------------
# Load the generator model from Hugging Face Hub
# ------------------------------
# hf_hub_download fetches the saved generator from the Hub repo and returns a
# local file path to it.
# NOTE(review): consider pinning `revision=` to a specific commit so a push to
# the repo cannot silently change the model this script serves.
model_path = hf_hub_download(
    repo_id="saraayum/sar-colorization-generator",
    filename="sar_colorization_generator_final.keras",
)

# compile=False: the model is only used for forward passes here, so the
# optimizer/loss configuration does not need to be restored.
# SECURITY: safe_mode=False allows Keras to deserialize arbitrary Python code
# (e.g. Lambda layers) embedded in the .keras file, i.e. loading it can execute
# code. Only keep this for a repository you control and trust. Presumably it
# is required because the saved model contains a Lambda layer — confirm, and
# switch back to the default safe_mode=True if it is not needed.
generator = load_model(model_path, compile=False, safe_mode=False)
# ------------------------------
# Preprocess the SAR image
# ------------------------------
def preprocess_image(image_path, size=(256, 256)):
    """Load a SAR image from disk and turn it into a generator input batch.

    Args:
        image_path: Path to the image file. PNG and JPEG files are accepted.
        size: Target ``(height, width)`` passed to ``tf.image.resize``.
            Defaults to ``(256, 256)``, the size this module always used.

    Returns:
        A float32 tensor of shape ``(1, height, width, 1)`` with values
        scaled to ``[-1, 1]``.
    """
    raw = tf.io.read_file(image_path)
    # decode_image detects the format from the file contents, so JPEG inputs
    # work as well as PNG (decode_png rejects them). expand_animations=False
    # guarantees a 3-D (H, W, C) tensor, which tf.image.resize requires.
    # channels=1: SAR is treated as grayscale, so colour inputs are collapsed.
    img = tf.io.decode_image(raw, channels=1, expand_animations=False)
    img = tf.image.resize(img, size)
    img = tf.cast(img, tf.float32) / 127.5 - 1.0  # [0, 255] -> [-1, 1]
    return tf.expand_dims(img, 0)  # add batch dimension
# ------------------------------
# Postprocess the generator output
# ------------------------------
def postprocess_output(output_tensor):
    """Convert a batched generator output into a PIL image.

    Args:
        output_tensor: Generator output whose first element (``[0]``) is a
            single image with values in ``[-1, 1]``. Presumably shaped
            ``(1, H, W, 3)`` for this colorization model — confirm against
            the model.

    Returns:
        A ``PIL.Image.Image`` built from the first image in the batch.
    """
    pixels = (output_tensor[0] + 1.0) * 127.5  # [-1, 1] -> [0, 255]
    pixels = tf.clip_by_value(pixels, 0.0, 255.0)
    # Round to nearest before casting: a bare float->uint8 cast truncates,
    # which biases every pixel down by up to one grey level.
    pixels = tf.cast(tf.round(pixels), tf.uint8).numpy()
    # PIL cannot build an image from an (H, W, 1) array; drop a trailing
    # singleton channel so a single-channel output becomes a grayscale image.
    if pixels.ndim == 3 and pixels.shape[-1] == 1:
        pixels = pixels[..., 0]
    return Image.fromarray(pixels)
# ------------------------------
# Inference entry point
# ------------------------------
def predict(image_path, save_path="output.png"):
    """Colorize one SAR image file end to end.

    Reads ``image_path``, runs it through the module-level ``generator`` in
    inference mode (``training=False``), writes the result to ``save_path``
    and reports the destination on stdout.

    Args:
        image_path: Path of the SAR image to colorize.
        save_path: Destination file for the colorized result.

    Returns:
        The colorized image produced by ``postprocess_output``.
    """
    batch = preprocess_image(image_path)
    colorized = postprocess_output(generator(batch, training=False))
    colorized.save(save_path)
    print(f"Colorized image saved as: {save_path}")
    return colorized
# ------------------------------
# Script entry point
# ------------------------------
# This guard only fires when the file is executed directly; importing the
# module does not run the demo below, so it need not be commented out.
if __name__ == "__main__":
    demo_input, demo_output = "sample_sar.png", "predicted_colorized.png"
    predict(demo_input, save_path=demo_output)