Update app.py
Browse files
app.py
CHANGED
|
@@ -190,6 +190,45 @@ class readDataset:
|
|
| 190 |
np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
|
| 191 |
np.array(test_sar), np.array(test_optic), np.array(test_masks)
|
| 192 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
# Streamlit App Title
|
| 194 |
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
|
| 195 |
|
|
@@ -224,7 +263,8 @@ if st.button("Run Inference"):
|
|
| 224 |
optic_images = dataset.normalizeImages(optic_images, 'i')
|
| 225 |
|
| 226 |
# Load model
|
| 227 |
-
model = load_model(model_path
|
|
|
|
| 228 |
|
| 229 |
# Predict
|
| 230 |
pred_masks = model.predict([optic_images, sar_images], verbose=0)
|
|
|
|
| 190 |
np.array(train_sar_aug), np.array(train_optic_aug), np.array(train_masks_aug),
|
| 191 |
np.array(test_sar), np.array(test_optic), np.array(test_masks)
|
| 192 |
)
|
| 193 |
+
|
| 194 |
+
@tf.keras.saving.register_keras_serializable()
|
| 195 |
+
def dice_score(y_true, y_pred, threshold=0.5, smooth=1.0):
|
| 196 |
+
#determine binary or multiclass segmentation
|
| 197 |
+
is_multiclass = y_true.shape[-1] > 1
|
| 198 |
+
|
| 199 |
+
if not is_multiclass:
|
| 200 |
+
# Binary segmentation
|
| 201 |
+
y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
|
| 202 |
+
y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
|
| 203 |
+
intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
|
| 204 |
+
score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
|
| 205 |
+
return score
|
| 206 |
+
else:
|
| 207 |
+
# Multiclass segmentation
|
| 208 |
+
num_classes = y_true.shape[-1]
|
| 209 |
+
score_per_class = []
|
| 210 |
+
|
| 211 |
+
for i in range(num_classes):
|
| 212 |
+
y_true_flat = tf.cast(tf.reshape(y_true, [-1]), dtype=tf.float32)
|
| 213 |
+
y_pred_flat = tf.cast(tf.reshape(y_pred >= threshold, [-1]), dtype=tf.float32)
|
| 214 |
+
intersection = tf.reduce_sum(y_true_flat * y_pred_flat)
|
| 215 |
+
score = (2. * intersection + smooth) / (tf.reduce_sum(y_true_flat) + tf.reduce_sum(y_pred_flat) + smooth)
|
| 216 |
+
score_per_class.append(score)
|
| 217 |
+
|
| 218 |
+
return tf.reduce_mean(score_per_class)
|
| 219 |
+
|
| 220 |
+
@tf.keras.saving.register_keras_serializable()
|
| 221 |
+
def dice_loss(y_true, y_pred):
|
| 222 |
+
dice = dice_score(y_true, y_pred)
|
| 223 |
+
loss = 1. - dice
|
| 224 |
+
return tf.cast(loss, dtype=tf.float32)
|
| 225 |
+
|
| 226 |
+
@tf.keras.saving.register_keras_serializable()
|
| 227 |
+
def cce_dice_loss(y_true, y_pred):
|
| 228 |
+
cce = tf.keras.losses.CategoricalCrossentropy()(y_true, y_pred)
|
| 229 |
+
dice = dice_loss(y_true, y_pred)
|
| 230 |
+
return tf.cast(cce, dtype=tf.float32) + dice
|
| 231 |
+
|
| 232 |
# Streamlit App Title
|
| 233 |
st.title("Satellite Mining Segmentation: SAR + Optic Image Inference")
|
| 234 |
|
|
|
|
| 263 |
optic_images = dataset.normalizeImages(optic_images, 'i')
|
| 264 |
|
| 265 |
# Load model
|
| 266 |
+
model = tf.keras.models.load_model(model_path,
|
| 267 |
+
custom_objects={"cce_dice_loss": cce_dice_loss, "dice_score": dice_score})
|
| 268 |
|
| 269 |
# Predict
|
| 270 |
pred_masks = model.predict([optic_images, sar_images], verbose=0)
|