Update app.py
Browse files
app.py
CHANGED
|
@@ -7,7 +7,7 @@ import keras.backend as K
|
|
| 7 |
|
| 8 |
from matplotlib import pyplot as plt
|
| 9 |
from PIL import Image
|
| 10 |
-
import keras
|
| 11 |
|
| 12 |
|
| 13 |
resized_shape = (768, 768, 3)
|
|
@@ -21,14 +21,13 @@ IMG_SCALING = (1, 1)
|
|
| 21 |
# gdown.download(url, output, quiet=False)
|
| 22 |
# return output
|
| 23 |
|
| 24 |
-
|
| 25 |
-
model_file = 'seg_unet_model.h5'
|
| 26 |
|
| 27 |
#Custom objects for model
|
| 28 |
|
| 29 |
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
| 30 |
-
targets =
|
| 31 |
-
inputs =
|
| 32 |
intersection = K.sum(targets * inputs)
|
| 33 |
dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
|
| 34 |
inputs = K.clip(inputs, eps, 1.0 - eps)
|
|
@@ -38,20 +37,20 @@ def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
|
| 38 |
return combo
|
| 39 |
|
| 40 |
def dice_coef(y_true, y_pred, smooth=1):
|
| 41 |
-
y_pred =
|
| 42 |
-
y_true =
|
| 43 |
intersection = K.sum(y_true * y_pred, axis=[1,2,3])
|
| 44 |
union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
|
| 45 |
return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
|
| 46 |
|
| 47 |
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
|
| 48 |
-
pt_1 =
|
| 49 |
-
pt_0 =
|
| 50 |
focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
|
| 51 |
return focal_loss_fixed
|
| 52 |
|
| 53 |
# Load the model
|
| 54 |
-
seg_model = keras.models.load_model(
|
| 55 |
|
| 56 |
# inputs = gr.inputs.Image(type="pil", label="Upload an image")
|
| 57 |
# image_output = gr.outputs.Image(type="pil", label="Output Image")
|
|
@@ -68,7 +67,7 @@ def gen_pred(img, model=seg_model):
|
|
| 68 |
img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
|
| 69 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 70 |
img = img/255
|
| 71 |
-
img =
|
| 72 |
pred = model.predict(img)
|
| 73 |
pred = np.squeeze(pred, axis=0)
|
| 74 |
fig = plt.figure(figsize=(3, 3))
|
|
|
|
| 7 |
|
| 8 |
from matplotlib import pyplot as plt
|
| 9 |
from PIL import Image
|
| 10 |
+
from tensorflow import keras
|
| 11 |
|
| 12 |
|
| 13 |
resized_shape = (768, 768, 3)
|
|
|
|
| 21 |
# gdown.download(url, output, quiet=False)
|
| 22 |
# return output
|
| 23 |
|
| 24 |
+
model_file = "./seg_unet_model.h5"
|
|
|
|
| 25 |
|
| 26 |
#Custom objects for model
|
| 27 |
|
| 28 |
def Combo_loss(y_true, y_pred, eps=1e-9, smooth=1):
|
| 29 |
+
targets = tf.dtypes.cast(K.flatten(y_true), tf.float32)
|
| 30 |
+
inputs = tf.dtypes.cast(K.flatten(y_pred), tf.float32)
|
| 31 |
intersection = K.sum(targets * inputs)
|
| 32 |
dice = (2. * intersection + smooth) / (K.sum(targets) + K.sum(inputs) + smooth)
|
| 33 |
inputs = K.clip(inputs, eps, 1.0 - eps)
|
|
|
|
| 37 |
return combo
|
| 38 |
|
| 39 |
def dice_coef(y_true, y_pred, smooth=1):
|
| 40 |
+
y_pred = tf.dtypes.cast(y_pred, tf.int32)
|
| 41 |
+
y_true = tf.dtypes.cast(y_true, tf.int32)
|
| 42 |
intersection = K.sum(y_true * y_pred, axis=[1,2,3])
|
| 43 |
union = K.sum(y_true, axis=[1,2,3]) + K.sum(y_pred, axis=[1,2,3])
|
| 44 |
return K.mean((2 * intersection + smooth) / (union + smooth), axis=0)
|
| 45 |
|
| 46 |
def focal_loss_fixed(y_true, y_pred, gamma=2.0, alpha=0.25):
|
| 47 |
+
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
|
| 48 |
+
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
|
| 49 |
focal_loss_fixed = -K.mean(alpha * K.pow(1. - pt_1, gamma) * K.log(pt_1+K.epsilon())) - K.mean((1 - alpha) * K.pow(pt_0, gamma) * K.log(1. - pt_0 + K.epsilon()))
|
| 50 |
return focal_loss_fixed
|
| 51 |
|
| 52 |
# Load the model
|
| 53 |
+
seg_model = keras.models.load_model('seg_unet_model.h5', custom_objects={'Combo_loss': Combo_loss, 'focal_loss_fixed': focal_loss_fixed, 'dice_coef': dice_coef})
|
| 54 |
|
| 55 |
# inputs = gr.inputs.Image(type="pil", label="Upload an image")
|
| 56 |
# image_output = gr.outputs.Image(type="pil", label="Output Image")
|
|
|
|
| 67 |
img = img[::IMG_SCALING[0], ::IMG_SCALING[1]]
|
| 68 |
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
| 69 |
img = img/255
|
| 70 |
+
img = tf.expand_dims(img, axis=0)
|
| 71 |
pred = model.predict(img)
|
| 72 |
pred = np.squeeze(pred, axis=0)
|
| 73 |
fig = plt.figure(figsize=(3, 3))
|