DeniSSio commited on
Commit
10f64e7
·
verified ·
1 Parent(s): 6948f31

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +83 -0
  2. model.py +111 -0
  3. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchvision.transforms as tr
4
+ from PIL import Image
5
+ import numpy as np
6
+ from model import CycleGAN, Discriminator, Generator
7
+
8
+ # === Настройки ===
9
+ mean = [0.5, 0.5, 0.5]
10
+ std = [0.5, 0.5, 0.5]
11
+
12
+ hyperparams = dict(
13
+ crop_size=256,
14
+ )
15
+
16
+ def get_transforms(mean, std, crop_size=64):
17
+
18
+ val_transform = tr.Compose([
19
+ tr.Resize((crop_size, crop_size)),
20
+ tr.ToTensor(),
21
+ tr.Normalize(mean=mean, std=std)
22
+ ])
23
+
24
+ def de_normalize(tensor):
25
+ denorm = tr.Normalize(
26
+ mean=[-m / s for m, s in zip(mean, std)],
27
+ std=[1 / s for s in std]
28
+ )
29
+ return denorm(tensor.clone()).clamp(0, 1)
30
+
31
+ return val_transform, de_normalize
32
+
33
+
34
+ val_transform, de_normalize = get_transforms(mean, std, **hyperparams)
35
+
36
+
37
+ # === Загрузка модели ===
38
+ @st.cache_resource
39
+ def load_model():
40
+ checkpoint = torch.load("cycle_gan_face.pt", map_location="cpu", weights_only = False)
41
+ model = CycleGAN(Discriminator, Generator)
42
+ model.load_state_dict(checkpoint['model_state_dict'])
43
+ model.eval()
44
+ return model
45
+
46
+
47
+ model = load_model()
48
+
49
+ # === Streamlit UI ===
50
+ st.title("Обработка изображений через PyTorch модель")
51
+
52
+ uploaded_file_1 = st.file_uploader("Загрузите изображение белого человека", type=["jpg", "jpeg", "png"], key="file1")
53
+ uploaded_file_2 = st.file_uploader("Загрузите изображение черного человека", type=["jpg", "jpeg", "png"], key="file2")
54
+
55
+ selected = st.radio("Выберите изображение для обработки", ["Первое", "Второе"])
56
+
57
+ if uploaded_file_1 and uploaded_file_2:
58
+ image1 = Image.open(uploaded_file_1).convert("RGB")
59
+ image2 = Image.open(uploaded_file_2).convert("RGB")
60
+
61
+ if selected == "Первое":
62
+ selected_image = image1
63
+ st.image(selected_image, caption="Выбранное изображение", use_column_width=True)
64
+ tensor = val_transform(selected_image).unsqueeze(0) # B x C x H x W
65
+
66
+ with torch.no_grad():
67
+ output = model.netG_A2B(tensor)
68
+ else:
69
+ selected_image = image2
70
+ st.image(selected_image, caption="Выбранное изображение", use_column_width=True)
71
+ tensor = val_transform(selected_image).unsqueeze(0) # B x C x H x W
72
+
73
+ with torch.no_grad():
74
+ output = model.netG_B2A(tensor)
75
+
76
+
77
+ # Де-нормализуем и показываем результат
78
+ result_image = de_normalize(output.squeeze(0)).permute(1, 2, 0).numpy()
79
+ result_image = (result_image * 255).astype(np.uint8)
80
+
81
+ st.image(result_image, caption="Результат модели", use_column_width=True)
82
+ else:
83
+ st.info("Пожалуйста, загрузите оба изображения.")
model.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class Discriminator(nn.Module):
7
+ def __init__(self, dropout_prob=0.3):
8
+ super(Discriminator, self).__init__()
9
+
10
+ self.main = nn.Sequential(
11
+ nn.Conv2d(3, 64, 4, stride=2, padding=1),
12
+ nn.LeakyReLU(0.2, inplace=True),
13
+ nn.Dropout2d(p=dropout_prob),
14
+
15
+ nn.Conv2d(64, 128, 4, stride=2, padding=1),
16
+ nn.InstanceNorm2d(128),
17
+ nn.LeakyReLU(0.2, inplace=True),
18
+ nn.Dropout2d(p=dropout_prob),
19
+
20
+ nn.Conv2d(128, 256, 4, stride=2, padding=1),
21
+ nn.InstanceNorm2d(256),
22
+ nn.LeakyReLU(0.2, inplace=True),
23
+ nn.Dropout2d(p=dropout_prob),
24
+
25
+ nn.Conv2d(256, 512, 4, padding=1),
26
+ nn.InstanceNorm2d(512),
27
+ nn.LeakyReLU(0.2, inplace=True),
28
+ nn.Dropout2d(p=dropout_prob),
29
+
30
+ nn.Conv2d(512, 1, 4, padding=1),
31
+ )
32
+
33
+ def forward(self, x):
34
+ x = self.main(x)
35
+ x = F.avg_pool2d(x, x.size()[2:])
36
+ x = torch.flatten(x, 1)
37
+ x = torch.sigmoid(x)
38
+ return x
39
+
40
+
41
+ class Generator(nn.Module):
42
+ def __init__(self):
43
+ super(Generator, self).__init__()
44
+ self.main = nn.Sequential(
45
+ # Initial convolution block
46
+ nn.ReflectionPad2d(3),
47
+ nn.Conv2d(3, 64, 7),
48
+ nn.InstanceNorm2d(64),
49
+ nn.ReLU(inplace=True),
50
+
51
+ # Downsampling
52
+ nn.Conv2d(64, 128, 3, stride=2, padding=1),
53
+ nn.InstanceNorm2d(128),
54
+ nn.ReLU(inplace=True),
55
+ nn.Conv2d(128, 256, 3, stride=2, padding=1),
56
+ nn.InstanceNorm2d(256),
57
+ nn.ReLU(inplace=True),
58
+
59
+ # Residual blocks
60
+ ResidualBlock(256),
61
+ ResidualBlock(256),
62
+ ResidualBlock(256),
63
+ ResidualBlock(256),
64
+ ResidualBlock(256),
65
+ ResidualBlock(256),
66
+ ResidualBlock(256),
67
+ ResidualBlock(256),
68
+ ResidualBlock(256),
69
+
70
+ # Upsampling
71
+ nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
72
+ nn.InstanceNorm2d(128),
73
+ nn.ReLU(inplace=True),
74
+ nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
75
+ nn.InstanceNorm2d(64),
76
+ nn.ReLU(inplace=True),
77
+
78
+ # Output layer
79
+ nn.ReflectionPad2d(3),
80
+ nn.Conv2d(64, 3, 7),
81
+ nn.Tanh()
82
+ )
83
+
84
+ def forward(self, x):
85
+ return self.main(x)
86
+
87
+
88
+ class ResidualBlock(nn.Module):
89
+ def __init__(self, in_channels):
90
+ super(ResidualBlock, self).__init__()
91
+
92
+ self.res = nn.Sequential(nn.ReflectionPad2d(1),
93
+ nn.Conv2d(in_channels, in_channels, 3),
94
+ nn.InstanceNorm2d(in_channels),
95
+ nn.ReLU(inplace=True),
96
+ nn.ReflectionPad2d(1),
97
+ nn.Conv2d(in_channels, in_channels, 3),
98
+ nn.InstanceNorm2d(in_channels))
99
+
100
+ def forward(self, x):
101
+ return x + self.res(x)
102
+
103
+
104
+ class CycleGAN(nn.Module):
105
+ def __init__(self, descriminator, generator):
106
+ super(CycleGAN, self).__init__()
107
+
108
+ self.netG_A2B = generator()
109
+ self.netG_B2A = generator()
110
+ self.netD_A = descriminator()
111
+ self.netD_B = descriminator()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ torch==2.6.0
2
+ torchvision==0.21.0
3
+ streamlit==1.44.1