ahadhassan commited on
Commit
29e9dd7
·
verified ·
1 Parent(s): fc5324e

Update ndvi_predictor.py

Browse files
Files changed (1) hide show
  1. ndvi_predictor.py +58 -10
ndvi_predictor.py CHANGED
@@ -2,7 +2,6 @@
2
  import os
3
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
4
  os.environ["SM_FRAMEWORK"] = "tf.keras"
5
-
6
  import segmentation_models as sm
7
  import tensorflow as tf
8
  import numpy as np
@@ -12,25 +11,45 @@ from PIL import Image
12
  import io
13
 
14
  def load_model(model_path):
 
15
  return tf.keras.models.load_model(model_path, compile=False)
16
 
17
  def normalize_rgb(rgb):
 
18
  rgb_norm = rgb.copy().astype(np.float32)
 
 
 
 
 
19
  for b in range(3):
20
  band = rgb_norm[:, :, b]
21
  min_val, max_val = np.percentile(band, [1, 99])
22
  if min_val < max_val:
23
  rgb_norm[:, :, b] = np.clip((band - min_val) / (max_val - min_val), 0, 1)
 
24
  return rgb_norm
25
 
26
  def predict_ndvi(model, rgb_np):
 
 
 
 
 
 
 
 
 
 
27
  height, width = rgb_np.shape[:2]
28
  tile_size = 512
29
  stride = int(tile_size * 0.7)
30
-
 
31
  ndvi_pred = np.zeros((height, width), dtype=np.float32)
32
  weight_map = np.zeros((height, width), dtype=np.float32)
33
-
 
34
  if height < tile_size or width < tile_size:
35
  pad_height = max(0, tile_size - height)
36
  pad_width = max(0, tile_size - width)
@@ -39,38 +58,67 @@ def predict_ndvi(model, rgb_np):
39
  else:
40
  rgb_padded = rgb_np
41
  height_padded, width_padded = height, width
42
-
 
43
  for i in range(0, height_padded - tile_size + 1, stride):
44
  for j in range(0, width_padded - tile_size + 1, stride):
 
45
  tile = rgb_padded[i:i+tile_size, j:j+tile_size, :]
 
 
46
  y, x = np.mgrid[0:tile_size, 0:tile_size]
47
  weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
48
  weights = np.clip(weights, 0, 50) / 50
 
 
49
  tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]
 
 
50
  valid_height = min(tile_size, height - i)
51
  valid_width = min(tile_size, width - j)
52
- ndvi_pred[i:i+valid_height, j:j+valid_width] += tile_pred[:valid_height, :valid_width] * weights[:valid_height, :valid_width]
 
 
 
 
53
  weight_map[i:i+valid_height, j:j+valid_width] += weights[:valid_height, :valid_width]
54
-
 
55
  mask = weight_map > 0
56
  ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]
 
57
  return ndvi_pred
58
 
59
  def create_visualization(rgb, ndvi):
 
 
 
 
 
 
 
 
 
 
60
  fig, axes = plt.subplots(1, 2, figsize=(12, 6))
 
 
61
  rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
62
  axes[0].imshow(rgb_disp)
63
  axes[0].set_title("RGB Input")
64
  axes[0].axis("off")
65
-
 
66
  im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
67
  axes[1].set_title("Predicted NDVI")
68
  axes[1].axis("off")
69
  fig.colorbar(im, ax=axes[1])
70
-
 
71
  buf = io.BytesIO()
72
  plt.tight_layout()
73
- plt.savefig(buf, format="png")
74
  plt.close(fig)
75
  buf.seek(0)
76
- return buf
 
 
2
  import os
3
  os.environ["TF_ENABLE_ONEDNN_OPTS"] = "0"
4
  os.environ["SM_FRAMEWORK"] = "tf.keras"
 
5
  import segmentation_models as sm
6
  import tensorflow as tf
7
  import numpy as np
 
11
  import io
12
 
13
  def load_model(model_path):
14
+ """Load NDVI prediction model"""
15
  return tf.keras.models.load_model(model_path, compile=False)
16
 
17
  def normalize_rgb(rgb):
18
+ """Normalize RGB image to [0, 1] range using percentile normalization"""
19
  rgb_norm = rgb.copy().astype(np.float32)
20
+
21
+ # Handle different input ranges
22
+ if rgb.max() > 1:
23
+ rgb_norm = rgb_norm / 255.0
24
+
25
  for b in range(3):
26
  band = rgb_norm[:, :, b]
27
  min_val, max_val = np.percentile(band, [1, 99])
28
  if min_val < max_val:
29
  rgb_norm[:, :, b] = np.clip((band - min_val) / (max_val - min_val), 0, 1)
30
+
31
  return rgb_norm
32
 
33
  def predict_ndvi(model, rgb_np):
34
+ """
35
+ Predict NDVI from RGB image using tiled approach for large images
36
+
37
+ Args:
38
+ model: Loaded NDVI prediction model
39
+ rgb_np: RGB image as numpy array (H, W, 3) normalized to [0, 1]
40
+
41
+ Returns:
42
+ ndvi_pred: Predicted NDVI as numpy array (H, W) in range [-1, 1]
43
+ """
44
  height, width = rgb_np.shape[:2]
45
  tile_size = 512
46
  stride = int(tile_size * 0.7)
47
+
48
+ # Initialize output arrays
49
  ndvi_pred = np.zeros((height, width), dtype=np.float32)
50
  weight_map = np.zeros((height, width), dtype=np.float32)
51
+
52
+ # Handle small images by padding
53
  if height < tile_size or width < tile_size:
54
  pad_height = max(0, tile_size - height)
55
  pad_width = max(0, tile_size - width)
 
58
  else:
59
  rgb_padded = rgb_np
60
  height_padded, width_padded = height, width
61
+
62
+ # Process image tiles
63
  for i in range(0, height_padded - tile_size + 1, stride):
64
  for j in range(0, width_padded - tile_size + 1, stride):
65
+ # Extract tile
66
  tile = rgb_padded[i:i+tile_size, j:j+tile_size, :]
67
+
68
+ # Create distance-based weights for blending
69
  y, x = np.mgrid[0:tile_size, 0:tile_size]
70
  weights = np.minimum(np.minimum(x, tile_size - x - 1), np.minimum(y, tile_size - y - 1))
71
  weights = np.clip(weights, 0, 50) / 50
72
+
73
+ # Predict NDVI for tile
74
  tile_pred = model.predict(np.expand_dims(tile, axis=0), verbose=0)[0, :, :, 0]
75
+
76
+ # Determine valid region (handle edge cases)
77
  valid_height = min(tile_size, height - i)
78
  valid_width = min(tile_size, width - j)
79
+
80
+ # Accumulate weighted predictions
81
+ ndvi_pred[i:i+valid_height, j:j+valid_width] += (
82
+ tile_pred[:valid_height, :valid_width] * weights[:valid_height, :valid_width]
83
+ )
84
  weight_map[i:i+valid_height, j:j+valid_width] += weights[:valid_height, :valid_width]
85
+
86
+ # Normalize by weights
87
  mask = weight_map > 0
88
  ndvi_pred[mask] = ndvi_pred[mask] / weight_map[mask]
89
+
90
  return ndvi_pred
91
 
92
  def create_visualization(rgb, ndvi):
93
+ """
94
+ Create visualization of RGB input and predicted NDVI
95
+
96
+ Args:
97
+ rgb: RGB image array
98
+ ndvi: NDVI prediction array
99
+
100
+ Returns:
101
+ buf: BytesIO buffer containing the visualization as PNG
102
+ """
103
  fig, axes = plt.subplots(1, 2, figsize=(12, 6))
104
+
105
+ # Display RGB image
106
  rgb_disp = np.clip(rgb / 255 if rgb.max() > 1 else rgb, 0, 1)
107
  axes[0].imshow(rgb_disp)
108
  axes[0].set_title("RGB Input")
109
  axes[0].axis("off")
110
+
111
+ # Display NDVI with color map
112
  im = axes[1].imshow(ndvi, cmap='RdYlGn', vmin=-1, vmax=1)
113
  axes[1].set_title("Predicted NDVI")
114
  axes[1].axis("off")
115
  fig.colorbar(im, ax=axes[1])
116
+
117
+ # Save to buffer
118
  buf = io.BytesIO()
119
  plt.tight_layout()
120
+ plt.savefig(buf, format="png", dpi=150, bbox_inches='tight')
121
  plt.close(fig)
122
  buf.seek(0)
123
+
124
+ return buf