init
Browse files- app.py +19 -5
- ugatit_test.py +6 -6
app.py
CHANGED
|
@@ -15,7 +15,7 @@ import numpy as np
|
|
| 15 |
import PIL.Image
|
| 16 |
|
| 17 |
from io import BytesIO
|
| 18 |
-
|
| 19 |
|
| 20 |
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
|
| 21 |
TITLE = 'taki0112/UGATIT'
|
|
@@ -26,6 +26,9 @@ ARTICLE = """
|
|
| 26 |
|
| 27 |
"""
|
| 28 |
|
|
|
|
|
|
|
|
|
|
| 29 |
def parse_args() -> argparse.Namespace:
|
| 30 |
parser = argparse.ArgumentParser()
|
| 31 |
parser.add_argument('--device', type=str, default='cpu')
|
|
@@ -41,13 +44,22 @@ def parse_args() -> argparse.Namespace:
|
|
| 41 |
return parser.parse_args()
|
| 42 |
|
| 43 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
def run(
|
| 46 |
-
image
|
|
|
|
| 47 |
) -> tuple[PIL.Image.Image]:
|
| 48 |
-
|
| 49 |
|
| 50 |
-
|
|
|
|
|
|
|
| 51 |
|
| 52 |
|
| 53 |
def main():
|
|
@@ -55,7 +67,9 @@ def main():
|
|
| 55 |
|
| 56 |
args = parse_args()
|
| 57 |
|
| 58 |
-
|
|
|
|
|
|
|
| 59 |
func = functools.update_wrapper(func, run)
|
| 60 |
|
| 61 |
|
|
|
|
| 15 |
import PIL.Image
|
| 16 |
|
| 17 |
from io import BytesIO
|
| 18 |
+
import ugatit_test
|
| 19 |
|
| 20 |
ORIGINAL_REPO_URL = 'https://github.com/taki0112/UGATIT'
|
| 21 |
TITLE = 'taki0112/UGATIT'
|
|
|
|
| 26 |
|
| 27 |
"""
|
| 28 |
|
| 29 |
+
|
| 30 |
+
MODEL_REPO = 'hylee/UGATIT_model'
|
| 31 |
+
|
| 32 |
def parse_args() -> argparse.Namespace:
|
| 33 |
parser = argparse.ArgumentParser()
|
| 34 |
parser.add_argument('--device', type=str, default='cpu')
|
|
|
|
| 44 |
return parser.parse_args()
|
| 45 |
|
| 46 |
|
| 47 |
+
def load_checkpoint():
|
| 48 |
+
checkpoint_path = huggingface_hub.hf_hub_download(MODEL_REPO,
|
| 49 |
+
'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing/checkpoint',
|
| 50 |
+
cache_dir='UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing')
|
| 51 |
+
print(checkpoint_path)
|
| 52 |
+
return 'UGATIT_selfie2anime_lsgan_4resblock_6dis_1_1_10_10_1000_sn_smoothing'
|
| 53 |
+
|
| 54 |
|
| 55 |
def run(
|
| 56 |
+
image,
|
| 57 |
+
checkpoint_dir: str,
|
| 58 |
) -> tuple[PIL.Image.Image]:
|
|
|
|
| 59 |
|
| 60 |
+
result = ugatit_test.main_test(image.name, checkpoint_dir)
|
| 61 |
+
|
| 62 |
+
return PIL.Image.open(result)
|
| 63 |
|
| 64 |
|
| 65 |
def main():
|
|
|
|
| 67 |
|
| 68 |
args = parse_args()
|
| 69 |
|
| 70 |
+
checkpoint_dir = load_checkpoint()
|
| 71 |
+
|
| 72 |
+
func = functools.partial(run, checkpoint_dir=checkpoint_dir)
|
| 73 |
func = functools.update_wrapper(func, run)
|
| 74 |
|
| 75 |
|
ugatit_test.py
CHANGED
|
@@ -8,7 +8,7 @@ from ugatit.utils import *
|
|
| 8 |
|
| 9 |
class UgatitTest:
|
| 10 |
|
| 11 |
-
def __init__(self, sess):
|
| 12 |
self.light = False
|
| 13 |
|
| 14 |
if self.light:
|
|
@@ -18,7 +18,7 @@ class UgatitTest:
|
|
| 18 |
|
| 19 |
self.sess = sess
|
| 20 |
self.phase = 'test'
|
| 21 |
-
self.checkpoint_dir =
|
| 22 |
self.result_dir = 'results'
|
| 23 |
self.log_dir = 'logs'
|
| 24 |
self.dataset_name = 'selfie2anime'
|
|
@@ -57,8 +57,8 @@ class UgatitTest:
|
|
| 57 |
self.img_size = 256
|
| 58 |
self.img_ch = 3
|
| 59 |
|
| 60 |
-
self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
|
| 61 |
-
check_folder(self.sample_dir)
|
| 62 |
|
| 63 |
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
| 64 |
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
|
@@ -350,12 +350,12 @@ class UgatitTest:
|
|
| 350 |
|
| 351 |
|
| 352 |
gan = None
|
| 353 |
-
def main_test(img_path):
|
| 354 |
# open session
|
| 355 |
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
| 356 |
global gan
|
| 357 |
if gan is None:
|
| 358 |
-
gan = UgatitTest(sess)
|
| 359 |
# build graph
|
| 360 |
gan.build_model()
|
| 361 |
# show network architecture
|
|
|
|
| 8 |
|
| 9 |
class UgatitTest:
|
| 10 |
|
| 11 |
+
def __init__(self, sess, checkpoint_dir):
|
| 12 |
self.light = False
|
| 13 |
|
| 14 |
if self.light:
|
|
|
|
| 18 |
|
| 19 |
self.sess = sess
|
| 20 |
self.phase = 'test'
|
| 21 |
+
self.checkpoint_dir = checkpoint_dir
|
| 22 |
self.result_dir = 'results'
|
| 23 |
self.log_dir = 'logs'
|
| 24 |
self.dataset_name = 'selfie2anime'
|
|
|
|
| 57 |
self.img_size = 256
|
| 58 |
self.img_ch = 3
|
| 59 |
|
| 60 |
+
#self.sample_dir = os.path.join('/home/hylee/cartoon/UGATIT/samples', self.model_dir)
|
| 61 |
+
#check_folder(self.sample_dir)
|
| 62 |
|
| 63 |
# self.trainA, self.trainB = prepare_data(dataset_name=self.dataset_name, size=self.img_size
|
| 64 |
self.trainA_dataset = glob('./dataset/{}/*.*'.format(self.dataset_name + '/trainA'))
|
|
|
|
| 350 |
|
| 351 |
|
| 352 |
gan = None
|
| 353 |
+
def main_test(img_path, checkpoint_dir):
|
| 354 |
# open session
|
| 355 |
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True))
|
| 356 |
global gan
|
| 357 |
if gan is None:
|
| 358 |
+
gan = UgatitTest(sess, checkpoint_dir)
|
| 359 |
# build graph
|
| 360 |
gan.build_model()
|
| 361 |
# show network architecture
|