File size: 2,248 Bytes
2f8fc6b
 
 
 
55be43d
2f8fc6b
 
55be43d
 
2f8fc6b
 
 
55be43d
2f8fc6b
 
55be43d
35a9c6d
2f8fc6b
 
 
 
 
 
55be43d
2f8fc6b
 
 
 
 
 
 
 
 
 
 
 
 
55be43d
2f8fc6b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import torch
from torchvision import transforms
from PIL import Image
from model import CycleGAN 

st.set_page_config(page_title="CycleGAN Demo", layout="wide")
st.title("CycleGAN: Трансфер стилей")


@st.cache_resource
def load_model():

    device = torch.device("cpu") 
    model = CycleGAN()

    checkpoint = torch.load("cyclegan_150_epochs.pt", map_location=device, weights_only=False)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()
    return model, device

model, device = load_model()


IMG_SIZE = 128 
transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE), Image.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

def de_normalize(tensor):
    tensor = tensor.cpu().squeeze(0)
    tensor = tensor * 0.5 + 0.5
    tensor = torch.clamp(tensor, 0, 1)
    return tensor.permute(1, 2, 0).numpy()


col1, col2 = st.columns(2)

with col1:
    st.header("Домен A ➡️ Домен B")
    file_a = st.file_uploader("Загрузить фото A", type=["jpg", "png", "jpeg"], key="a")
    if file_a:
        img_a = Image.open(file_a).convert("RGB")
        st.image(img_a, caption="Оригинал")
        if st.button("Преобразовать", key="btn_a"):
            with st.spinner("Генерация..."):
                tensor = transform(img_a).unsqueeze(0)
                with torch.no_grad():
                    res = model.G_A2B(tensor) # Перевод из A в B
                st.image(de_normalize(res), caption="Результат")

with col2:
    st.header("Домен B ➡️ Домен A")
    file_b = st.file_uploader("Загрузить фото B", type=["jpg", "png", "jpeg"], key="b")
    if file_b:
        img_b = Image.open(file_b).convert("RGB")
        st.image(img_b, caption="Оригинал")
        if st.button("Преобразовать", key="btn_b"):
            with st.spinner("Генерация..."):
                tensor = transform(img_b).unsqueeze(0)
                with torch.no_grad():
                    res = model.G_B2A(tensor) # Перевод из B в A
                st.image(de_normalize(res), caption="Результат")