ebgoldstein commited on
Commit
9da561f
·
verified ·
1 Parent(s): 3373fee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -42
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import gradio as gr
2
  import numpy as np
3
  import tensorflow as tf
@@ -5,74 +6,67 @@ from skimage.io import imsave
5
  from skimage.transform import resize
6
  import matplotlib.pyplot as plt
7
 
8
- #from SegZoo
9
- def standardize(img):
10
- #standardization using adjusted standard deviation
11
 
 
 
 
12
  N = np.shape(img)[0] * np.shape(img)[1]
13
- s = np.maximum(np.std(img), 1.0/np.sqrt(N))
14
  m = np.mean(img)
15
  img = (img - m) / s
16
- del m, s, N
17
- #
18
- if np.ndim(img)==2:
19
- img = np.dstack((img,img,img))
20
  return img
21
 
22
- #load model
23
  filepath = './saved_model'
24
- model = tf.keras.models.load_model(filepath, compile = True)
25
- model.compile
26
 
27
- #segmentation
28
  def FRFsegment(input_img):
29
-
30
- dims=(512,512)
31
- w = input_img.shape[0]
32
- h = input_img.shape[1]
33
- print(w)
34
- print(h)
35
 
36
  img = standardize(input_img)
37
- img = resize(img, dims, preserve_range=True, clip=True)
38
- img = np.expand_dims(img,axis=0)
39
-
40
- est_label = model.predict(img)
41
-
42
- # # Test Time AUgmentation
43
- # est_label2 = np.flipud(model.predict((np.flipud(img)), batch_size=1))
44
- # est_label3 = np.fliplr(model.predict((np.fliplr(img)), batch_size=1))
45
- # est_label4 = np.flipud(np.fliplr(model.predict((np.flipud(np.fliplr(img))))))
46
 
47
- # #soft voting - sum the softmax scores to return the new TTA estimated softmax scores
48
- # pred = est_label + est_label2 + est_label3 + est_label4
49
- # est_label = pred
50
-
51
- mask = np.argmax(np.squeeze(est_label, axis=0),-1)
52
  pred = resize(mask, (w, h), preserve_range=True, clip=True)
53
 
54
  # Convert pred to uint8
55
  pred_uint8 = (pred / np.max(pred) * 255).astype(np.uint8)
56
 
57
  imsave("label.png", pred_uint8)
58
-
59
- #overlay plot
60
  plt.clf()
61
- plt.imshow(input_img,cmap='gray')
62
  plt.imshow(pred, alpha=0.4)
63
  plt.axis("off")
64
  plt.margins(x=0, y=0)
65
  plt.savefig("overlay.png", dpi=300, bbox_inches="tight")
66
-
67
- return plt, "label.png", "overlay.png"
68
-
69
- out1 = gr.outputs.File()
70
- out2 = gr.outputs.File()
71
 
 
72
 
 
73
  title = "Segment beach imagery taken from a tower in Duck, NC, USA"
74
  description = "This model segments beach imagery into 4 classes: vegetation, sand, coarse sand, and background (water + sky + buildings + people)"
75
- examples = [['examples/FRF_c1_snap_20191112160000.jpg'],['examples/FRF_c1_snap_20170101.jpg']]
 
 
 
 
 
 
 
 
 
76
 
 
77
 
78
- FRFSegapp = gr.Interface(FRFsegment, gr.inputs.Image(), ['plot',out1, out2], examples=examples, title = title, description = description).launch()
 
1
+ import os
2
  import gradio as gr
3
  import numpy as np
4
  import tensorflow as tf
 
6
  from skimage.transform import resize
7
  import matplotlib.pyplot as plt
8
 
9
+ # Suppress TensorFlow warnings
10
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
 
11
 
12
+ # Standardize function
13
+ def standardize(img):
14
+ # Standardization using adjusted standard deviation
15
  N = np.shape(img)[0] * np.shape(img)[1]
16
+ s = np.maximum(np.std(img), 1.0 / np.sqrt(N))
17
  m = np.mean(img)
18
  img = (img - m) / s
19
+ if np.ndim(img) == 2:
20
+ img = np.dstack((img, img, img))
 
 
21
  return img
22
 
23
+ # Load model
24
  filepath = './saved_model'
25
+ model = tf.keras.layers.TFSMLayer(filepath, call_endpoint='serving_default')
 
26
 
27
+ # Segmentation function
28
  def FRFsegment(input_img):
29
+ dims = (512, 512)
30
+ w, h = input_img.shape[:2]
 
 
 
 
31
 
32
  img = standardize(input_img)
33
+ img = resize(img, dims, preserve_range=True, clip=True)
34
+ img = np.expand_dims(img, axis=0)
35
+
36
+ est_label = model(img)
 
 
 
 
 
37
 
38
+ # Mask creation
39
+ mask = np.argmax(np.squeeze(est_label, axis=0), -1)
 
 
 
40
  pred = resize(mask, (w, h), preserve_range=True, clip=True)
41
 
42
  # Convert pred to uint8
43
  pred_uint8 = (pred / np.max(pred) * 255).astype(np.uint8)
44
 
45
  imsave("label.png", pred_uint8)
46
+
47
+ # Overlay plot
48
  plt.clf()
49
+ plt.imshow(input_img, cmap='gray')
50
  plt.imshow(pred, alpha=0.4)
51
  plt.axis("off")
52
  plt.margins(x=0, y=0)
53
  plt.savefig("overlay.png", dpi=300, bbox_inches="tight")
 
 
 
 
 
54
 
55
+ return "label.png", "overlay.png"
56
 
57
+ # Gradio Interface
58
  title = "Segment beach imagery taken from a tower in Duck, NC, USA"
59
  description = "This model segments beach imagery into 4 classes: vegetation, sand, coarse sand, and background (water + sky + buildings + people)"
60
+ examples = [['examples/FRF_c1_snap_20191112160000.jpg'], ['examples/FRF_c1_snap_20170101.jpg']]
61
+
62
+ FRFSegapp = gr.Interface(
63
+ fn=FRFsegment,
64
+ inputs=gr.Image(type="numpy"),
65
+ outputs=[gr.File(), gr.File()],
66
+ examples=examples,
67
+ title=title,
68
+ description=description
69
+ )
70
 
71
+ FRFSegapp.launch()
72