ebgoldstein commited on
Commit
c705fd9
·
1 Parent(s): 5523ece

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -0
app.py ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ import matplotlib.pyplot as plt
5
+ from PIL import Image
6
+ import io
7
+
8
+ #from SegZoo
9
+ def standardize(img):
10
+ #standardization using adjusted standard deviation
11
+
12
+ N = np.shape(img)[0] * np.shape(img)[1]
13
+ s = np.maximum(np.std(img), 1.0/np.sqrt(N))
14
+ m = np.mean(img)
15
+ img = (img - m) / s
16
+ del m, s, N
17
+ #
18
+ if np.ndim(img)==2:
19
+ img = np.dstack((img,img,img))
20
+ return img
21
+
22
+ #load model
23
+ filepath = './model/FRF_jan22_remap'
24
+ model = tf.keras.models.load_model(filepath, compile = True)
25
+ model.compile
26
+
27
+ #segmentation
28
+ def FRFsegment(input_img):
29
+ #img = tf.keras.preprocessing.image.load_img(input_img,target_size = (512, 512))
30
+ #img = tf.keras.preprocessing.image.img_to_array(img)
31
+ img = standardize(input_img)
32
+ img = np.expand_dims(img,axis=0)
33
+
34
+ est_label = model.predict(img)
35
+
36
+
37
+ # est_label2 = np.flipud(model.predict((np.flipud(img)), batch_size=1))
38
+ # est_label3 = np.fliplr(model.predict((np.fliplr(img)), batch_size=1))
39
+ # est_label4 = np.flipud(np.fliplr(model.predict((np.flipud(np.fliplr(img))))))
40
+
41
+ # #soft voting - sum the softmax scores to return the new TTA estimated softmax scores
42
+ # pred = est_label + est_label2 + est_label3 + est_label4
43
+
44
+ pred = est_label
45
+
46
+ # print(pred.shape)
47
+ mask = np.argmax(np.squeeze(pred, axis=0),-1)
48
+ # print(np.amax(mask))
49
+ # print(np.amin(mask))
50
+
51
+ #overlay plot
52
+ p = plt.imshow(input_img,cmap='gray')
53
+ p = plt.imshow(mask, alpha=0.4)
54
+ p = plt.axis("off")
55
+ return plt
56
+
57
+ # #overlay plot to PIL
58
+ # p = plt.imshow(input_img,cmap='gray')
59
+ # p = plt.imshow(mask, alpha=0.6)
60
+ # p = plt.axis("off")
61
+ # buf = io.BytesIO()
62
+ # fig = plt.gcf()
63
+ # fig.savefig(buf)
64
+ # buf.seek(0)
65
+ # img = Image.open(buf)
66
+ # return img
67
+
68
+ # #PIL
69
+ # #img = Image.fromarray(np.uint8(mask*(255/5)))
70
+
71
+ # return img
72
+
73
+ #FRFSegapp = gr.Interface(FRFsegment, gr.inputs.Image(shape=(512, 512)), outputs=gr.outputs.Image('plot'))
74
+ FRFSegapp = gr.Interface(FRFsegment, gr.inputs.Image(shape=(512, 512)), "image")
75
+
76
+ FRFSegapp.launch()