RebeccaNissan26 commited on
Commit
3880b63
·
1 Parent(s): 9266a5b

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +57 -34
README.md CHANGED
@@ -5,62 +5,85 @@ library_name: keras
5
  ## Model description
6
 
7
  This model identifies contrails in satellite images. It takes pre-processed .npy files (images) as its inputs, and returns a "mask" image showing only the contrails overlayed on the same area.
8
- We used a TransUNet model architecture ...
9
 
10
- ## Intended uses & limitations
11
 
12
- Note, this is in progress -
13
 
14
- We hope that data scientists and researchers focused on reducing contrails (towards the goal of reducing global warming) will use this model to improve their work.
15
- There are current efforts underway to develop software that re-routes planes to avoid contrails. Researchers are building models to predict contrails based on atmospheric conditions and other factors, but they need a way to validate those predictions.
16
- That's where we come in. Let's say you have a model that suggests there should be contrails in an image (based on the time/location the picture was taken).
17
- Our model can find the contrails (or lack thereof) in your image without a human labeler, allowing you to validate whether your predictions were correct.
18
- This becomes valuable at scale, when you need to validate your model on many images - contrail detection is a tough task for humans and machines alike!
19
 
 
 
 
20
 
21
  ## How to Get Started with the Model
22
 
23
  Use the code below to get started with the model.
24
 
25
  ```
 
 
 
 
 
 
 
26
  from huggingface_hub import notebook_login
27
- notebook_login()
 
 
28
 
29
- from huggingface_hub import from_pretrained_keras
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects={'dice_loss_plus_5focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'IOU score' : sm.metrics.IOUScore(threshold=0.9, name="IOU score"), 'Dice Coeficient' : sm.metrics.FScore(threshold=0.6, name="Dice Coeficient")}, compile=False)
32
- ```
 
 
 
 
 
 
33
 
 
 
 
 
 
 
 
 
 
 
 
 
34
 
35
  ## Training and evaluation data
36
 
37
  (will add more info here)
38
  OpenContrails dataset [here](https://arxiv.org/abs/2304.02122)
 
39
 
40
  ## Training procedure
41
 
42
 
43
 
44
- ### Training hyperparameters
45
-
46
- The following hyperparameters were used during training:
47
-
48
- | Hyperparameters | Value |
49
- | :-- | :-- |
50
- | name | RMSprop |
51
- | weight_decay | None |
52
- | clipnorm | None |
53
- | global_clipnorm | None |
54
- | clipvalue | None |
55
- | use_ema | False |
56
- | ema_momentum | 0.99 |
57
- | ema_overwrite_frequency | 100 |
58
- | jit_compile | True |
59
- | is_legacy_optimizer | False |
60
- | learning_rate | 0.0010000000474974513 |
61
- | rho | 0.9 |
62
- | momentum | 0.0 |
63
- | epsilon | 1e-07 |
64
- | centered | False |
65
- | training_precision | float32 |
66
-
 
5
  ## Model description
6
 
7
  This model identifies contrails in satellite images. It takes pre-processed .npy files (images) as its inputs, and returns a "mask" image showing only the contrails overlayed on the same area.
8
+ We used a UNet model architecture ...
9
 
 
10
 
 
11
 
12
+ ## Intended uses & limitations
 
 
 
 
13
 
14
+ We hope that data scientists and researchers focused on reducing contrails (towards the goal of reducing global warming) will use this model to improve their work.
15
+ There are current efforts underway to develop software that re-routes planes to avoid contrails.
16
+ Researchers are building models to predict contrails based on atmospheric conditions and other factors, but they need a way to validate those predictions.
17
 
18
  ## How to Get Started with the Model
19
 
20
  Use the code below to get started with the model.
21
 
22
  ```
23
+ #Required imports and Huggingface authentication
24
+ import os
25
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"
26
+ os.environ["SM_FRAMEWORK"] = "tf.keras"
27
+ import segmentation_models as sm
28
+ import tensorflow as tf
29
+ from huggingface_hub import from_pretrained_keras
30
  from huggingface_hub import notebook_login
31
+ from PIL import Image
32
+ import numpy as np
33
+ import matplotlib.pyplot as plt
34
 
 
35
 
36
+ weights = [0.5,0.5] # hyper parameter
37
+
38
+ dice_loss = sm.losses.DiceLoss(class_weights = weights)
39
+ focal_loss = sm.losses.CategoricalFocalLoss()
40
+ TOTAL_LOSS_FACTOR = 5
41
+ total_loss = dice_loss + (TOTAL_LOSS_FACTOR * focal_loss)
42
+
43
+ def jaccard_coef(y_true, y_pred):
44
+ """
45
+ Defines custom jaccard coefficient metric
46
+ """
47
+
48
+ y_true_flatten = K.flatten(y_true)
49
+ y_pred_flatten = K.flatten(y_pred)
50
+ intersection = K.sum(y_true_flatten * y_pred_flatten)
51
+ final_coef_value = (intersection + 1.0) / (K.sum(y_true_flatten) + K.sum(y_pred_flatten) - intersection + 1.0)
52
+ return final_coef_value
53
+
54
+ metrics = [tf.keras.metrics.MeanIoU(num_classes=2, sparse_y_true= False, sparse_y_pred=False, name="Mean IOU")]
55
+ notebook_login()
56
+
57
+ # Load model from Huggingface Hub
58
  model = from_pretrained_keras("MIDSCapstoneTeam/ContrailSentinel", custom_objects={'dice_loss_plus_5focal_loss': total_loss, 'jaccard_coef': jaccard_coef, 'IOU score' : sm.metrics.IOUScore(threshold=0.9, name="IOU score"), 'Dice Coeficient' : sm.metrics.FScore(threshold=0.6, name="Dice Coeficient")}, compile=False)
59
+ model.compile(metrics=metrics)
60
+
61
+ # Inference -- User needs to specify the image path where label and ash images are stored
62
+ label = np.load({Image path} + 'human_pixel_masks.npy')
63
+ ash_image = np.load({Image path} + 'ash_image.npy')[...,4]
64
+ y_pred = model.predict(ash_image.reshape(1,256, 256, 3))
65
+ prediction = np.argmax(y_pred[0], axis=2).reshape(256,256,1)
66
 
67
+ fig, ax = plt.subplots(1, 2, figsize=(9, 5))
68
+ fig.tight_layout(pad=5.0)
69
+ ax[1].set_title("Contrail prediction")
70
+ ax[1].imshow(ash_image)
71
+ ax[1].imshow(prediction)
72
+ ax[1].axis('off')
73
+
74
+ ax[0].set_title("False colored satellite image")
75
+ ax[0].imshow(ash_image)
76
+ ax[0].axis('off')
77
+
78
+ ```
79
 
80
  ## Training and evaluation data
81
 
82
  (will add more info here)
83
  OpenContrails dataset [here](https://arxiv.org/abs/2304.02122)
84
+ Can get images
85
 
86
  ## Training procedure
87
 
88
 
89