ombui commited on
Commit
b45b065
·
1 Parent(s): c39a2f0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -1,6 +1,3 @@
1
- import os
2
- os.environ["SM_FRAMEWORK"] = "tf.keras"
3
-
4
  import os
5
  import cv2
6
  from PIL import Image
@@ -27,7 +24,9 @@ dice_loss = sm.losses.DiceLoss(class_weights = weights)
27
  focal_loss = sm.losses.CategoricalFocalLoss()
28
  total_loss = dice_loss + (1 * focal_loss)
29
 
30
- satellite_model = load_model('C:/Users/sa/Desktop/Model_Training/satellite_segmentation_full.h5', custom_objects={'dice_loss_plus_1focal_loss': total_loss})
 
 
31
 
32
 
33
  def process_input_image(image_source):
@@ -44,8 +43,8 @@ my_app = gr.Blocks()
44
 
45
  with my_app:
46
  gr.Markdown("Statellite Image Segmentation Application UI with Gradio")
47
- with gr.Tabs():
48
- with gr.TabItem ("Select your image"):
49
  with gr.Row():
50
  with gr.Column():
51
  img_source = gr.Image(label="Please select source Image", shape=(256, 256))
@@ -54,14 +53,15 @@ with my_app:
54
  output_label = gr.Label(label="Image Info")
55
  img_output = gr.Image(label="Image Output")
56
  source_image_loader.click(
57
- process_input_image,
58
- [
59
- img_source
60
- ],
61
- [
62
- output_label,
63
- img_output
64
- ]
65
  )
66
 
 
67
  my_app.launch(debug=True)
 
 
 
 
1
  import os
2
  import cv2
3
  from PIL import Image
 
24
  focal_loss = sm.losses.CategoricalFocalLoss()
25
  total_loss = dice_loss + (1 * focal_loss)
26
 
27
+ satellite_model = load_model('C:/Users/sa/Desktop/Model_Training/satellite_segmentation_full.h5',
28
+ custom_objects=({'dice_loss_plus_1focal_loss': total_loss,
29
+ 'jaccard_coef': jaccard_coef}))
30
 
31
 
32
  def process_input_image(image_source):
 
43
 
44
  with my_app:
45
  gr.Markdown("Statellite Image Segmentation Application UI with Gradio")
46
+ with gr.Tabs():
47
+ with gr.TabItem("Select your image"):
48
  with gr.Row():
49
  with gr.Column():
50
  img_source = gr.Image(label="Please select source Image", shape=(256, 256))
 
53
  output_label = gr.Label(label="Image Info")
54
  img_output = gr.Image(label="Image Output")
55
  source_image_loader.click(
56
+ process_input_image,
57
+ [
58
+ img_source
59
+ ],
60
+ [
61
+ output_label,
62
+ img_output
63
+ ]
64
  )
65
 
66
+
67
  my_app.launch(debug=True)