Spaces:
Build error
Build error
Sophie98 commited on
Commit ·
ab7b996
1
Parent(s): cfa3c46
fix errors
Browse files- styleTransfer.py +10 -37
styleTransfer.py
CHANGED
|
@@ -11,8 +11,6 @@ from collections import OrderedDict
|
|
| 11 |
import tensorflow_hub as hub
|
| 12 |
import tensorflow as tf
|
| 13 |
|
| 14 |
-
from torchvision.utils import save_image
|
| 15 |
-
|
| 16 |
############################################# TRANSFORMER ############################################
|
| 17 |
|
| 18 |
vgg_path = 'vgg_normalised.pth'
|
|
@@ -20,16 +18,6 @@ decoder_path = 'decoder_iter_160000.pth'
|
|
| 20 |
Trans_path = 'transformer_iter_160000.pth'
|
| 21 |
embedding_path = 'embedding_iter_160000.pth'
|
| 22 |
|
| 23 |
-
def test_transform(size, crop):
|
| 24 |
-
transform_list = []
|
| 25 |
-
|
| 26 |
-
if size != 0:
|
| 27 |
-
transform_list.append(transforms.Resize(size))
|
| 28 |
-
if crop:
|
| 29 |
-
transform_list.append(transforms.CenterCrop(size))
|
| 30 |
-
transform_list.append(transforms.ToTensor())
|
| 31 |
-
transform = transforms.Compose(transform_list)
|
| 32 |
-
return transform
|
| 33 |
def style_transform(h,w):
|
| 34 |
k = (h,w)
|
| 35 |
size = int(np.max(k))
|
|
@@ -48,7 +36,6 @@ def content_transform():
|
|
| 48 |
# Advanced options
|
| 49 |
content_size=640
|
| 50 |
style_size=640
|
| 51 |
-
crop='store_true'
|
| 52 |
|
| 53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 54 |
|
|
@@ -66,42 +53,28 @@ vgg.eval()
|
|
| 66 |
|
| 67 |
new_state_dict = OrderedDict()
|
| 68 |
state_dict = torch.load(decoder_path)
|
| 69 |
-
|
| 70 |
-
#namekey = k[7:] # remove `module.`
|
| 71 |
-
namekey = k
|
| 72 |
-
new_state_dict[namekey] = v
|
| 73 |
-
decoder.load_state_dict(new_state_dict)
|
| 74 |
|
| 75 |
new_state_dict = OrderedDict()
|
| 76 |
state_dict = torch.load(Trans_path)
|
| 77 |
-
|
| 78 |
-
#namekey = k[7:] # remove `module.`
|
| 79 |
-
namekey = k
|
| 80 |
-
new_state_dict[namekey] = v
|
| 81 |
-
Trans.load_state_dict(new_state_dict)
|
| 82 |
|
| 83 |
new_state_dict = OrderedDict()
|
| 84 |
state_dict = torch.load(embedding_path)
|
| 85 |
-
|
| 86 |
-
#namekey = k[7:] # remove `module.`
|
| 87 |
-
namekey = k
|
| 88 |
-
new_state_dict[namekey] = v
|
| 89 |
-
embedding.load_state_dict(new_state_dict)
|
| 90 |
|
| 91 |
network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
|
| 92 |
network.eval()
|
| 93 |
|
|
|
|
|
|
|
|
|
|
| 94 |
def StyleTransformer(content_img: Image, style_img: Image):
|
| 95 |
|
| 96 |
network.to(device)
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
content_tf1 = content_transform()
|
| 101 |
-
content = content_tf1(content_img.convert("RGB"))
|
| 102 |
-
h,w,c=np.shape(content)
|
| 103 |
-
style_tf1 = style_transform(h,w)
|
| 104 |
-
style = style_tf1(style_img.convert("RGB"))
|
| 105 |
style = style.to(device).unsqueeze(0)
|
| 106 |
content = content.to(device).unsqueeze(0)
|
| 107 |
|
|
@@ -128,4 +101,4 @@ def StyleGAN(content_image, style_image):
|
|
| 128 |
def create_styledSofa(sofa:Image, style:Image):
|
| 129 |
#styled_sofa = StyleGAN(sofa,style)
|
| 130 |
styled_sofa = StyleTransformer(sofa,style)
|
| 131 |
-
return styled_sofa
|
|
|
|
| 11 |
import tensorflow_hub as hub
|
| 12 |
import tensorflow as tf
|
| 13 |
|
|
|
|
|
|
|
| 14 |
############################################# TRANSFORMER ############################################
|
| 15 |
|
| 16 |
vgg_path = 'vgg_normalised.pth'
|
|
|
|
| 18 |
Trans_path = 'transformer_iter_160000.pth'
|
| 19 |
embedding_path = 'embedding_iter_160000.pth'
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def style_transform(h,w):
|
| 22 |
k = (h,w)
|
| 23 |
size = int(np.max(k))
|
|
|
|
| 36 |
# Advanced options
|
| 37 |
content_size=640
|
| 38 |
style_size=640
|
|
|
|
| 39 |
|
| 40 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 41 |
|
|
|
|
| 53 |
|
| 54 |
new_state_dict = OrderedDict()
|
| 55 |
state_dict = torch.load(decoder_path)
|
| 56 |
+
decoder.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
|
| 58 |
new_state_dict = OrderedDict()
|
| 59 |
state_dict = torch.load(Trans_path)
|
| 60 |
+
Trans.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
|
| 62 |
new_state_dict = OrderedDict()
|
| 63 |
state_dict = torch.load(embedding_path)
|
| 64 |
+
embedding.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
|
| 67 |
network.eval()
|
| 68 |
|
| 69 |
+
content_tf = content_transform()
|
| 70 |
+
style_tf = style_transform(style_size,style_size)
|
| 71 |
+
|
| 72 |
def StyleTransformer(content_img: Image, style_img: Image):
|
| 73 |
|
| 74 |
network.to(device)
|
| 75 |
+
|
| 76 |
+
content = content_tf(content_img.convert("RGB"))
|
| 77 |
+
style = style_tf(style_img.convert("RGB"))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
style = style.to(device).unsqueeze(0)
|
| 79 |
content = content.to(device).unsqueeze(0)
|
| 80 |
|
|
|
|
| 101 |
def create_styledSofa(sofa:Image, style:Image):
|
| 102 |
#styled_sofa = StyleGAN(sofa,style)
|
| 103 |
styled_sofa = StyleTransformer(sofa,style)
|
| 104 |
+
return styled_sofa
|