Spaces:
Sleeping
Sleeping
Commit ·
727349a
1
Parent(s): ffa182f
fix
Browse files- app.py +14 -7
- cycle_gan.py +45 -0
app.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch
|
|
| 3 |
from torchvision import transforms as tr
|
| 4 |
from PIL import Image
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
-
from cycle_gan import CycleGAN, create_model_and_optimizer
|
| 7 |
|
| 8 |
@st.cache_resource # кэширование
|
| 9 |
def load_model():
|
|
@@ -30,16 +30,23 @@ uploaded_file = st.file_uploader(f"Upload your {query}", type=["png", "jpg", "jp
|
|
| 30 |
if uploaded_file is not None:
|
| 31 |
image = Image.open(uploaded_file).convert("RGB")
|
| 32 |
image_size = image.size
|
| 33 |
-
image = image.resize((256,256))
|
| 34 |
-
image = tr.
|
| 35 |
-
|
| 36 |
model.eval()
|
| 37 |
with torch.no_grad():
|
| 38 |
if style == 'ColoredToSketch':
|
| 39 |
-
|
|
|
|
|
|
|
|
|
|
| 40 |
else:
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
resized = generation.resize(image_size, Image.Resampling.LANCZOS)
|
| 44 |
st.image(resized, caption="Your result!", use_column_width=True)
|
| 45 |
|
|
|
|
| 3 |
from torchvision import transforms as tr
|
| 4 |
from PIL import Image
|
| 5 |
from huggingface_hub import hf_hub_download
|
| 6 |
+
from cycle_gan import CycleGAN, create_model_and_optimizer, val_transform_a, val_transform_b, de_normalize_a, de_normalize_b
|
| 7 |
|
| 8 |
@st.cache_resource # кэширование
|
| 9 |
def load_model():
|
|
|
|
| 30 |
# Inference: read the uploaded image, run it through the selected CycleGAN
# generator, and display the stylized result at the original resolution.
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    image_size = image.size  # remember (width, height) to restore at the end
    # Square center-crop so the generator sees an undistorted image.
    image = tr.CenterCrop(min(image.size))(image)

    model.eval()
    with torch.no_grad():
        # NOTE(review): val_transform_a/b begin with tr.ToPILImage(), which
        # expects a tensor/ndarray input — confirm they accept a PIL image here.
        if style == 'ColoredToSketch':
            image = val_transform_a(image)
            image = image.to(dtype=torch.float, device='cpu')
            # GA maps domain A (colored) -> domain B (sketch);
            # add a batch dim for the model, drop it from the output.
            generation = model.GA(image.view(1, *image.shape)).detach().squeeze()
            generation = de_normalize_b(generation)
        else:
            image = val_transform_b(image)
            image = image.to(dtype=torch.float, device='cpu')
            # GB maps domain B (sketch) -> domain A (colored).
            generation = model.GB(image.unsqueeze(0)).squeeze().detach()
            generation = de_normalize_a(generation)

    # de_normalize_* returns a PIL image; scale back to the upload's size.
    resized = generation.resize(image_size, Image.Resampling.LANCZOS)
    st.image(resized, caption="Your result!", use_column_width=True)
|
| 52 |
|
cycle_gan.py
CHANGED
|
@@ -1,9 +1,54 @@
|
|
| 1 |
|
| 2 |
import torch
|
| 3 |
from torch import nn
|
|
|
|
|
|
|
| 4 |
import functools
|
| 5 |
import itertools
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
class Discriminator(nn.Module):
|
| 8 |
def __init__(self, c_in, c_out_init=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
| 9 |
super().__init__()
|
|
|
|
| 1 |
|
| 2 |
import torch
|
| 3 |
from torch import nn
|
| 4 |
+
from torchvision import transforms as tr
|
| 5 |
+
import numpy as np
|
| 6 |
import functools
|
| 7 |
import itertools
|
| 8 |
|
| 9 |
+
# Per-channel mean/std statistics for each image domain, fed to
# tr.Normalize / de_normalize via get_transforms below.
# Domain A: colored images (stats differ per RGB channel).
channel_mean_a, channel_std_a = np.array([0.87415955, 0.84342639 ,0.8385736 ]), np.array([0.21790985, 0.24519696, 0.24330734])
# Domain B: identical stats in all three channels — consistent with
# grayscale sketches replicated to RGB (presumably; confirm against dataset).
channel_mean_b, channel_std_b = np.array([0.95326045, 0.95326045, 0.95326045]), np.array([0.13081782, 0.13081782, 0.13081782])
|
| 11 |
+
|
| 12 |
+
def get_transforms(mean, std, crop_size):
    """Build train/val transform pipelines and a de-normalizer for one domain.

    Args:
        mean: per-channel mean (array-like of 3 floats) for tr.Normalize.
        std: per-channel std (array-like of 3 floats) for tr.Normalize.
        crop_size: side length of the random square crop used in training.

    Returns:
        Tuple ``(train_transform, val_transform, de_normalize)`` where the
        transforms map an ndarray/tensor to a normalized CHW float tensor and
        ``de_normalize`` inverts the normalization, returning a PIL image.
    """
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.RandomCrop(crop_size),
        tr.RandomHorizontalFlip(0.5),
        tr.RandomVerticalFlip(0.5),
        tr.Resize(256),
        tr.ToTensor(),
        tr.Normalize(mean, std),
    ])

    val_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize(256),
        tr.ToTensor(),
        tr.Normalize(mean, std),
    ])

    def de_normalize(img, normalized=True):
        """Invert tr.Normalize on a CHW tensor and convert to a PIL image.

        ``normalized`` is kept for backward compatibility and is unused.
        """
        img = img.cpu()
        # Use torch tensors in the image's dtype: multiplying a tensor by the
        # float64 numpy stats would silently promote the result to float64.
        std_t = torch.as_tensor(std, dtype=img.dtype).reshape(-1, 1, 1)
        mean_t = torch.as_tensor(mean, dtype=img.dtype).reshape(-1, 1, 1)
        # Clamp to [0, 1]: generator output can leave the valid range after
        # de-normalization, which would wrap around in ToPILImage's
        # float -> uint8 conversion and produce speckle artifacts.
        res = (img * std_t + mean_t).clamp(0.0, 1.0)
        return tr.ToPILImage()(res)

    return train_transform, val_transform, de_normalize
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
# Shared hyperparameters for both domains' transform pipelines.
hyperparams = {'crop_size': 224}

# Transforms and de-normalizers for domain A and domain B.
train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)
|
| 51 |
+
|
| 52 |
class Discriminator(nn.Module):
|
| 53 |
def __init__(self, c_in, c_out_init=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
| 54 |
super().__init__()
|