astro189 commited on
Commit
d0d04c3
·
1 Parent(s): 32b46ce

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +123 -0
app.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ import tensorflow_hub as hub
3
+ import matplotlib.pyplot as plt
4
+ import os
5
+ from PIL import Image
6
+ import numpy as np
7
+ import math
8
+ import functools
9
+ from matplotlib import gridspec
10
+ import gradio
11
+ import pickle
12
+
13
+ def load_img(img):
14
+ # max_dim = 256
15
+ # img = tf.image.convert_image_dtype(img, tf.float32)
16
+
17
+ # shape = tf.cast(np.shape(img)[:-1], tf.float32)
18
+ # long_dim = max(shape)
19
+ # scale = max_dim / long_dim
20
+ # new_shape = tf.cast(shape * scale, tf.int32)
21
+ # img=tf.convert_to_tensor(img)
22
+ # img = tf.image.convert_image_dtype(img, tf.float32)
23
+ # img = tf.image.resize(img, (256,256))
24
+ # img = img[tf.newaxis, :]
25
+ max_dim = 256
26
+ img = tf.image.convert_image_dtype(img, tf.float32)
27
+
28
+ shape = tf.cast(np.shape(img)[:-1], tf.float32)
29
+ long_dim = max(shape)
30
+ scale = max_dim / long_dim
31
+
32
+ new_shape = tf.cast(shape * scale, tf.int32)
33
+
34
+ img = tf.image.resize(img, new_shape)
35
+ img = img[tf.newaxis, :]
36
+ return img
37
+
38
+ def crop_center(image):
39
+ """Returns a cropped square image."""
40
+ shape = image.shape
41
+ new_shape = min(shape[1], shape[2])
42
+ offset_y = max(shape[1] - shape[2], 0) // 2
43
+ offset_x = max(shape[2] - shape[1], 0) // 2
44
+ image=tf.image.crop_to_bounding_box(
45
+ image, offset_y, offset_x, new_shape, new_shape)
46
+ return image
47
+
48
+ @functools.lru_cache(maxsize=None)
49
+ def load_image(img, image_size=(256, 256), preserve_aspect_ratio=True):
50
+ """Loads and preprocesses images."""
51
+ # Cache image file locally.
52
+ # image_path = tf.keras.utils.get_file(os.path.basename(image_url)[-128:], image_url)
53
+ # Load and convert to float32 numpy array, add batch dimension, and normalize to range [0, 1].
54
+ # img = tf.io.decode_image(
55
+ # tf.io.read_file(image_url),
56
+ # channels=3, dtype=tf.float32)[tf.newaxis, ...]
57
+ max_dim = 256
58
+ img = tf.image.convert_image_dtype(img, tf.float32)
59
+
60
+ shape = tf.cast(np.shape(img)[:-1], tf.float32)
61
+ long_dim = max(shape)
62
+ scale = max_dim / long_dim
63
+ new_shape = tf.cast(shape * scale, tf.int32)
64
+
65
+ #img = crop_center(img)
66
+ img = tf.image.resize(img, new_shape, preserve_aspect_ratio=True)
67
+ img = img[tf.newaxis, :]
68
+ return img
69
+
70
+ def show_n(images, titles=('',)):
71
+ n = len(images)
72
+ image_sizes = [image.shape[1] for image in images]
73
+ w = (image_sizes[0] * 6) // 320
74
+ plt.figure(figsize=(w * n, w))
75
+ gs = gridspec.GridSpec(1, n, width_ratios=image_sizes)
76
+ for i in range(n):
77
+ plt.subplot(gs[i])
78
+ plt.imshow(images[i][0], aspect='equal')
79
+ plt.axis('off')
80
+ plt.title(titles[i] if len(titles) > i else '')
81
+ plt.show()
82
+
83
+
84
+
85
+ def load_content_style_img(style_image,content_image):
86
+ style_image=np.array(style_image)
87
+ content_image=np.array(content_image)
88
+ width,height=content_image.shape[1],content_image.shape[0]
89
+ content_image = load_img(content_image)
90
+ style_image = load_img(style_image)
91
+ #content_image = crop_center(content_image)
92
+ content_image = tf.image.resize(content_image, (width,height), preserve_aspect_ratio=True)
93
+ style_image = crop_center(style_image)
94
+ style_image = tf.image.resize(style_image, (256,256), preserve_aspect_ratio=True)
95
+ style_image = tf.nn.avg_pool(style_image, ksize=[3,3], strides=[1,1], padding='SAME')
96
+ return style_image,content_image
97
+
98
+ # style_image,content_image=load_content_style_img(style,content)
99
+
100
+ # display([content_image, style_image])
101
+ #show_n([content_image, style_image], ['Content image', 'Style image'])
102
+
103
+ hub_handle = 'https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2'
104
+ hub_module = hub.load(hub_handle)
105
+
106
+ def tensor_to_image(tensor):
107
+ tensor = tensor*255
108
+ tensor = np.array(tensor, dtype=np.uint8)
109
+ if np.ndim(tensor)>3:
110
+ assert tensor.shape[0] == 1
111
+ tensor = tensor[0]
112
+ return Image.fromarray(tensor)
113
+
114
+ stylized_image=0
115
+ def train(style,content):
116
+ style_image,content_image=load_content_style_img(style,content)
117
+ outputs = hub_module(tf.constant(content_image), tf.constant(style_image))
118
+ stylized_image = outputs[0]
119
+ stylized_image=tensor_to_image(stylized_image)
120
+ return stylized_image
121
+
122
+ gr=gradio.Interface(fn=train, inputs=['image','image'], outputs='image')
123
+ gr.launch(share=True)