1plus1equal3 commited on
Commit
1ff6b95
·
1 Parent(s): 5cd8f70

Initial commit

Browse files
Files changed (2) hide show
  1. app.py +94 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Import all necessary libs
2
+ import tensorflow as tf
3
+ import keras
4
+ import numpy as np
5
+ import cv2
6
+ import gradio as gr
7
+ import os
8
+ import subprocess
9
+ import gdown
10
+
11
+ # Download model
12
+ # Pix2pix model
13
+ model_url = 'https://drive.google.com/drive/folders/1jOxiKyf8n7fwNZfgeZyUrNJ90LEmIL3S?usp=sharing'
14
+ os.makedirs('/content/pix2pix', exist_ok=True)
15
+ subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/pix2pix', '--folder'], check=True)
16
+
17
+ # WD-Net
18
+ model_url = 'https://drive.google.com/file/d/1M8EOE4Ej8oS4_0BHCEwExxu5CMFS5HQZ/view?usp=sharing'
19
+ os.makedirs('/content/WD-Net', exist_ok=True)
20
+ subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/WD-Net/model.zip'], check=True)
21
+ subprocess.run(['unzip', '/content/WD-Net/model.zip', '-d', '/content/WD-Net'], check=True)
22
+
23
+ # MS-UNet
24
+ model_url = 'https://drive.google.com/file/d/1-0_bEWTItkILbCJQ4ViEBGg0zJPaIcC1/view?usp=sharing'
25
+ os.makedirs('/content/MS-UNet', exist_ok=True)
26
+ subprocess.run(['gdown', '--fuzzy', model_url, '-O', '/content/MS-UNet/unet.keras'], check=True)
27
+
28
+ # Load model
29
+ # Load Pix2Pix
30
+ pix2pix_path = '/model/wt_generator_best.keras'
31
+ Pix2Pix = keras.saving.load_model(pix2pix_path)
32
+
33
+ # Load MS-UNet
34
+ unet_path = '/model/unet.keras'
35
+ MS_UNet = keras.saving.load_model(unet_path)
36
+
37
+ # Load WD-Net
38
+ @keras.saving.register_keras_serializable(package="Clip")
39
+ class Clip(keras.layers.Layer):
40
+ def __init__(self, **kwargs):
41
+ super().__init__(**kwargs)
42
+ def call(self, input):
43
+ return tf.clip_by_value(input, 0, 1)
44
+
45
+ gen_path = '/model/generator_epoch:10.keras'
46
+ WD_Net = keras.saving.load_model(gen_path)
47
+
48
+ # Define infer function
49
+ def infer(img, model='WD-Net'):
50
+ # Image original shape
51
+ org_shape = img.shape
52
+ # Choose model
53
+ if model == 'WD-Net':
54
+ generator = WD_Net
55
+ img = tf.image.resize(img, [256, 256], method='area')
56
+ # Normalize image and return
57
+ img = tf.cast(img, tf.float32) / 255.
58
+ img = tf.expand_dims(img, axis=0)
59
+ rm_wt = generator.predict(img, verbose=0)
60
+ rm_wt = rm_wt['I'][0]
61
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
62
+ out_img = (rm_wt * 255).astype(np.uint8)
63
+ elif model == 'MS-UNet':
64
+ generator = MS_UNet
65
+ img = tf.image.resize(img, [256, 256], method='area')
66
+ # Normalize image and return
67
+ img = (tf.cast(img, tf.float32) - 127.5) / 127.5
68
+ img = tf.expand_dims(img, axis=0)
69
+ rm_wt = generator.predict(img, verbose=0)
70
+ rm_wt = rm_wt[0]
71
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
72
+ out_img = ((rm_wt + 1) / 2 * 255).astype(np.uint8)
73
+ elif model == 'Pix2Pix':
74
+ generator = Pix2Pix
75
+ img = tf.image.resize(img, [256, 256], method='area')
76
+ # Normalize image and return
77
+ img = tf.cast(img, tf.float32) / 255.
78
+ img = tf.expand_dims(img, axis=0)
79
+ rm_wt = generator.predict(img, verbose=0)
80
+ rm_wt = rm_wt[0]
81
+ rm_wt = cv2.resize(rm_wt, (org_shape[1], org_shape[0]))
82
+ out_img = (rm_wt * 255).astype(np.uint8)
83
+ return out_img
84
+
85
+ # Main gradio code
86
+ model_list = ['WD-Net', 'MS-UNet', 'Pix2Pix']
87
+
88
+ demo = gr.Interface(
89
+ fn=infer,
90
+ inputs=[gr.Image(), gr.Dropdown(model_list)],
91
+ outputs=gr.Image(),
92
+ )
93
+
94
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gdown
2
+ tensorflow
3
+ matplotlib
4
+ opencv-python
5
+ numpy
6
+ pandas
7
+ pillow