RebeccaNissan26 commited on
Commit
a5a9935
·
1 Parent(s): f2276de

running app

Browse files
Files changed (1) hide show
  1. app.py +42 -16
app.py CHANGED
@@ -41,22 +41,48 @@ model = load_model('unet_contrails_model.keras', custom_objects={'dice_loss_plus
41
  model.compile(metrics=metrics)
42
 
43
 
44
- label = np.load('images/human_pixel_masks.npy')
45
- ash_image = np.load('images/ash_image.npy')[...,4]
46
- y_pred = model.predict(ash_image.reshape(1,256, 256, 3))
47
- prediction = np.argmax(y_pred[0], axis=2).reshape(256,256,1)
48
- intersection = label & prediction
49
- false_negative = label - intersection
50
- false_possitive = prediction - intersection
51
- color_prediction = np.stack([false_negative*.7, intersection*.7, false_possitive*.7], axis=2).reshape(256,256,3)
52
- print(intersection)
53
-
54
-
55
- def greet(name):
56
- return "Hello " + name + "!!"
57
-
58
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
59
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
 
61
 
62
  # first we need to load the model from somewhere - probably by using model.load on a keras file (which could be saved in our huggingface space repo)
 
41
  model.compile(metrics=metrics)
42
 
43
 
44
+
45
+
46
+
47
+ # def greet(name):
48
+ # return "Hello " + name + "!!"
49
+
50
+ # iface = gr.Interface(fn=greet, inputs="text", outputs="text")
51
+ # iface.launch()
52
+
53
def predict(ash_image, model=model):
    """Run the contrail-segmentation model on one saved ash image.

    Parameters
    ----------
    ash_image : str or path-like
        Path to a ``.npy`` ash-image array. Frame index 4 of the last
        axis is selected before inference; this assumes the last axis
        holds multiple time frames and that the selected slice is
        256x256x3 -- TODO confirm against the data pipeline.
    model : keras model, optional
        Defaults to the module-level model loaded above.

    Returns
    -------
    tuple
        ``(image, seg_info)`` in the shape ``gr.AnnotatedImage``
        expects: the 256x256x3 image array and a list of
        ``(mask, label)`` pairs.
    """
    image = np.load(ash_image)[..., 4]
    y_pred = model.predict(image.reshape(1, 256, 256, 3))
    # Collapse per-pixel class probabilities to a class-index mask.
    prediction = np.argmax(y_pred[0], axis=2).reshape(256, 256, 1)
    seg_info = [(prediction, 'contrails')]
    return (image, seg_info)
65
+
66
+
67
+
68
if __name__ == "__main__":
    # Color used by gr.AnnotatedImage for the single segmentation class.
    class2hexcolor = {"contrails": "#007fff"}

    with gr.Blocks(title="Contrail Predictions") as demo:
        gr.Markdown("""<h1><center>Predict Contrails in Satellite Images</center></h1>""")
        with gr.Row():
            # predict() loads its input with np.load, so it needs a file
            # path; type="pil" would hand it a decoded PIL.Image that
            # np.load cannot read.
            img_input = gr.Image(type="filepath", height=256, width=256, label="Input image")
            img_output = gr.AnnotatedImage(label="Predictions", height=256, width=256,
                                           color_map=class2hexcolor)

        section_btn = gr.Button("Generate Predictions")
        section_btn.click(partial(predict, model=model), img_input, img_output)

        # Offer one random example. predict() expects .npy arrays, so
        # glob for those (not rendered .png files); guard against an
        # empty directory, where np.random.choice would raise.
        image_paths = glob(os.path.join(os.getcwd(), "images", "*.npy"))
        examples = list(np.random.choice(image_paths, size=1, replace=False)) if image_paths else []
        gr.Examples(examples=examples, inputs=img_input, outputs=img_output)

    demo.launch()
85
+
86
 
87
 
88
  # First we need to load the model, e.g. with keras.models.load_model() on a .keras file saved in this Hugging Face Space repo.