igor-saprygin committed on
Commit
727349a
·
1 Parent(s): ffa182f
Files changed (2) hide show
  1. app.py +14 -7
  2. cycle_gan.py +45 -0
app.py CHANGED
@@ -3,7 +3,7 @@ import torch
3
  from torchvision import transforms as tr
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
- from cycle_gan import CycleGAN, create_model_and_optimizer
7
 
8
  @st.cache_resource # кэширование
9
  def load_model():
@@ -30,16 +30,23 @@ uploaded_file = st.file_uploader(f"Upload your {query}", type=["png", "jpg", "jp
30
  if uploaded_file is not None:
31
  image = Image.open(uploaded_file).convert("RGB")
32
  image_size = image.size
33
- image = image.resize((256,256))
34
- image = tr.ToTensor()(image)
35
- image = image.to(dtype = torch.float, device='cpu')
36
  model.eval()
37
  with torch.no_grad():
38
  if style == 'ColoredToSketch':
39
- generation = model.GA(image.view(1, *image.shape)).squeeze()
 
 
 
40
  else:
41
- generation = model.GB(image.view(1, *image.shape)).squeeze()
42
- generation = tr.ToPILImage()(generation)
 
 
 
 
43
  resized = generation.resize(image_size, Image.Resampling.LANCZOS)
44
  st.image(resized, caption="Your result!", use_column_width=True)
45
 
 
3
  from torchvision import transforms as tr
4
  from PIL import Image
5
  from huggingface_hub import hf_hub_download
6
+ from cycle_gan import CycleGAN, create_model_and_optimizer, val_transform_a, val_transform_b, de_normalize_a, de_normalize_b
7
 
8
  @st.cache_resource # кэширование
9
  def load_model():
 
30
# Run inference on the uploaded image and display the stylized result.
if uploaded_file is not None:
    image = Image.open(uploaded_file).convert("RGB")
    image_size = image.size  # remember the original size to restore it after generation
    # Square-crop to the shorter side so the model's square input is not distorted.
    image = tr.CenterCrop(min(image.size))(image)

    model.eval()
    with torch.no_grad():
        # Pick the generator and the matching (de)normalization pair for the
        # chosen direction: GA maps colored->sketch, GB maps sketch->colored.
        if style == 'ColoredToSketch':
            transform, generator, de_normalize = val_transform_a, model.GA, de_normalize_b
        else:
            transform, generator, de_normalize = val_transform_b, model.GB, de_normalize_a

        # NOTE(review): val_transform_* begins with ToPILImage while `image` is
        # already a PIL image — confirm the installed torchvision accepts that.
        tensor = transform(image).to(dtype=torch.float, device='cpu')
        generation = generator(tensor.unsqueeze(0)).squeeze(0).detach()
        generation = de_normalize(generation)  # back to a PIL image

    resized = generation.resize(image_size, Image.Resampling.LANCZOS)
    st.image(resized, caption="Your result!", use_column_width=True)
52
 
cycle_gan.py CHANGED
@@ -1,9 +1,54 @@
1
 
2
  import torch
3
  from torch import nn
 
 
4
  import functools
5
  import itertools
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  class Discriminator(nn.Module):
8
  def __init__(self, c_in, c_out_init=64, n_layers=3, norm_layer=nn.BatchNorm2d):
9
  super().__init__()
 
1
 
2
  import torch
3
  from torch import nn
4
+ from torchvision import transforms as tr
5
+ import numpy as np
6
  import functools
7
  import itertools
8
 
9
# Per-channel mean/std statistics of the two image domains (RGB order),
# used by Normalize on the way in and by de_normalize on the way out.
channel_mean_a = np.array([0.87415955, 0.84342639, 0.8385736])
channel_std_a = np.array([0.21790985, 0.24519696, 0.24330734])
channel_mean_b = np.array([0.95326045, 0.95326045, 0.95326045])
channel_std_b = np.array([0.13081782, 0.13081782, 0.13081782])
11
+
12
def get_transforms(mean, std, crop_size):
    """Build the (train, val, de-normalize) triple for one image domain.

    Parameters
    ----------
    mean, std : array-like of 3 floats
        Per-channel normalization statistics of the domain.
    crop_size : int
        Side length of the square random crop used for training augmentation.

    Returns
    -------
    tuple
        ``(train_transform, val_transform, de_normalize)`` where the first two
        map an image to a normalized CHW float tensor and ``de_normalize``
        maps such a tensor back to a PIL image.
    """
    train_transform = tr.Compose([
        tr.ToPILImage(),
        tr.RandomCrop(crop_size),
        tr.RandomHorizontalFlip(0.5),
        tr.RandomVerticalFlip(0.5),
        tr.Resize(256),
        tr.ToTensor(),
        tr.Normalize(mean, std),
    ])

    val_transform = tr.Compose([
        tr.ToPILImage(),
        tr.Resize(256),
        tr.ToTensor(),
        tr.Normalize(mean, std),
    ])

    def de_normalize(img, normalized=True):
        """Undo Normalize(mean, std) on a CHW tensor and return a PIL image.

        ``normalized`` is kept for interface compatibility; when False the
        tensor is assumed to already be in [0, 1] and is converted as-is.
        """
        img = img.detach().cpu()
        if normalized:
            # Cast the numpy stats to the tensor's dtype so float64 arrays
            # don't silently promote the whole tensor to float64.
            s = torch.as_tensor(std, dtype=img.dtype).reshape(-1, 1, 1)
            m = torch.as_tensor(mean, dtype=img.dtype).reshape(-1, 1, 1)
            img = img * s + m
        # Generator output can slightly over/undershoot the valid range;
        # clamp before the uint8 conversion inside ToPILImage.
        return tr.ToPILImage()(img.clamp(0.0, 1.0))

    return train_transform, val_transform, de_normalize
42
+
43
+
44
# Hyperparameters shared by both domains' augmentation pipelines.
hyperparams = {'crop_size': 224}

# Transform triples for domain A (colored) and domain B (sketch).
train_transform_a, val_transform_a, de_normalize_a = get_transforms(channel_mean_a, channel_std_a, **hyperparams)
train_transform_b, val_transform_b, de_normalize_b = get_transforms(channel_mean_b, channel_std_b, **hyperparams)
51
+
52
  class Discriminator(nn.Module):
53
  def __init__(self, c_in, c_out_init=64, n_layers=3, norm_layer=nn.BatchNorm2d):
54
  super().__init__()