saraayum commited on
Commit
e13598f
·
verified ·
1 Parent(s): 8c3bde0

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -19
inference.py CHANGED
@@ -5,21 +5,13 @@ import numpy as np
5
  from PIL import Image
6
 
7
  IMG_SIZE = 256
8
-
9
- # ------------------------------
10
- # Load the generator model from Hugging Face Hub
11
- # ------------------------------
12
  model_path = hf_hub_download(
13
- repo_id="saraayum/cgan4colsar", # your Hugging Face repo
14
  filename="generator_epoch_180.keras"
15
  )
16
 
17
- # Safe model loading
18
  generator = load_model(model_path, compile=False, safe_mode=False)
19
 
20
- # ------------------------------
21
- # Preprocess the SAR image (normalize [0, 1])
22
- # ------------------------------
23
  def preprocess_image(image_path):
24
  img = tf.io.read_file(image_path)
25
  img = tf.image.decode_png(img, channels=1) # SAR = grayscale
@@ -28,9 +20,6 @@ def preprocess_image(image_path):
28
  img = tf.expand_dims(img, 0) # Add batch dimension
29
  return img
30
 
31
- # ------------------------------
32
- # Postprocess the generator output
33
- # ------------------------------
34
  def postprocess_output(output_tensor):
35
  # Convert output from [0,1] to [0,255]
36
  output_tensor = output_tensor[0] * 255.0
@@ -38,19 +27,13 @@ def postprocess_output(output_tensor):
38
  output_image = tf.cast(output_tensor, tf.uint8)
39
  return Image.fromarray(output_image.numpy())
40
 
41
- # ------------------------------
42
- # Predict function for inference
43
- # ------------------------------
44
  def predict(image_path, save_path="output.png"):
45
  sar_input = preprocess_image(image_path)
46
  gen_output = generator(sar_input, training=False)
47
  output_image = postprocess_output(gen_output)
48
  output_image.save(save_path)
49
- print(f"Colorized image saved as: {save_path}")
50
  return output_image
51
 
52
- # ------------------------------
53
- # Example usage (comment this on HF)
54
- # ------------------------------
55
  if __name__ == "__main__":
56
  predict("sample_sar.png", "predicted_colorized.png")
 
5
  from PIL import Image
6
 
7
  IMG_SIZE = 256
 
 
 
 
8
  model_path = hf_hub_download(
9
+ repo_id="saraayum/cgan4colsar",
10
  filename="generator_epoch_180.keras"
11
  )
12
 
 
13
  generator = load_model(model_path, compile=False, safe_mode=False)
14
 
 
 
 
15
  def preprocess_image(image_path):
16
  img = tf.io.read_file(image_path)
17
  img = tf.image.decode_png(img, channels=1) # SAR = grayscale
 
20
  img = tf.expand_dims(img, 0) # Add batch dimension
21
  return img
22
 
 
 
 
23
  def postprocess_output(output_tensor):
24
  # Convert output from [0,1] to [0,255]
25
  output_tensor = output_tensor[0] * 255.0
 
27
  output_image = tf.cast(output_tensor, tf.uint8)
28
  return Image.fromarray(output_image.numpy())
29
 
 
 
 
30
  def predict(image_path, save_path="output.png"):
31
  sar_input = preprocess_image(image_path)
32
  gen_output = generator(sar_input, training=False)
33
  output_image = postprocess_output(gen_output)
34
  output_image.save(save_path)
35
+ print(f"Colorized image saved as: {save_path}")
36
  return output_image
37
 
 
 
 
38
  if __name__ == "__main__":
39
  predict("sample_sar.png", "predicted_colorized.png")