RebeccaNissan26 commited on
Commit
406bd4d
·
1 Parent(s): 3ac6e01

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +56 -0
README.md CHANGED
@@ -22,6 +22,62 @@ from huggingface_hub import from_pretrained_keras
22
  model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects={'dice_loss_plus_5focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'IOU score' : sm.metrics.IOUScore(threshold=0.9, name="IOU score"), 'Dice Coeficient' : sm.metrics.FScore(threshold=0.6, name="Dice Coeficient")}, compile=False)
23
  }}
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
  ## Training and evaluation data
27
 
 
22
  model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects={'dice_loss_plus_5focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'IOU score' : sm.metrics.IOUScore(threshold=0.9, name="IOU score"), 'Dice Coeficient' : sm.metrics.FScore(threshold=0.6, name="Dice Coeficient")}, compile=False)
23
  }}
24
 
25
+ {{
26
+ #Required imports and Huggingface authentication
27
+ import os
28
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
29
+ os.environ["SM_FRAMEWORK"] = "tf.keras"
30
+ import segmentation_models as sm
31
+ import tensorflow as tf
32
+ from huggingface_hub import from_pretrained_keras
33
+ from huggingface_hub import notebook_login
34
+ from PIL import Image
35
+ import numpy as np
36
+ import matplotlib.pyplot as plt
37
+
38
+
39
+ weights = [0.5,0.5] # hyper parameter
40
+
41
+ dice_loss = sm.losses.DiceLoss(class_weights = weights)
42
+ focal_loss = sm.losses.CategoricalFocalLoss()
43
+ TOTAL_LOSS_FACTOR = 5
44
+ total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)
45
+
46
+ def jaccard_coef(y_true, y_pred):
47
+ """
48
+ Defines custom jaccard coefficient metric
49
+ """
50
+
51
+ y_true_flatten = K.flatten(y_true)
52
+ y_pred_flatten = K.flatten(y_pred)
53
+ intersection = K.sum(y_true_flatten * y_pred_flatten)
54
+ final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
55
+ return final_coef_value
56
+
57
+ metrics = [tf.keras.metrics.MeanIoU(num_classes=2, sparse_y_true= False, sparse_y_pred=False, name="Mean IOU")]
58
+ notebook_login()
59
+
60
+ #Load model from Huggingface Hub
61
+ model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects={'dice_loss_plus_5focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'IOU score' : sm.metrics.IOUScore(threshold=0.9, name="IOU score"), 'Dice Coeficient' : sm.metrics.FScore(threshold=0.6, name="Dice Coeficient")}, compile=False)
62
+ model.compile(metrics=metrics)
63
+
64
+ #Inference -- User needs to specify the image path where label and ash images are stored
65
+ label = np.load({Image path} + 'human_pixel_masks.npy')
66
+ ash_image = np.load({Image path} + 'ash_image.npy')[...,4]
67
+ y_pred = model.predict(ash_image.reshape(1,256, 256, 3))
68
+ prediction = np.argmax(y_pred[0], axis=2).reshape(256,256,1)
69
+
70
+ fig, ax = plt.subplots(1, 2, figsize=(9, 5))
71
+ fig.tight_layout(pad=5.0)
72
+ ax[1].set_title("Contrail prediction")
73
+ ax[1].imshow(ash_image)
74
+ ax[1].imshow(prediction)
75
+ ax[1].axis('off')
76
+
77
+ ax[0].set_title("False colored satellite image")
78
+ ax[0].imshow(ash_image)
79
+ ax[0].axis('off')}}
80
+
81
 
82
  ## Training and evaluation data
83