| import streamlit as st |
| import torch |
| from torchvision import transforms |
| from PIL import Image |
| from model import CycleGAN |
|
|
| st.set_page_config(page_title="CycleGAN Demo", layout="wide") |
| st.title("CycleGAN: Трансфер стилей") |
|
|
|
|
| @st.cache_resource |
| def load_model(): |
|
|
| device = torch.device("cpu") |
| model = CycleGAN() |
|
|
| checkpoint = torch.load("cyclegan_150_epochs.pt", map_location=device, weights_only=False) |
| model.load_state_dict(checkpoint['model_state_dict']) |
| model.eval() |
| return model, device |
|
|
| model, device = load_model() |
|
|
|
|
| IMG_SIZE = 128 |
| transform = transforms.Compose([ |
| transforms.Resize((IMG_SIZE, IMG_SIZE), Image.BICUBIC), |
| transforms.ToTensor(), |
| transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) |
| ]) |
|
|
| def de_normalize(tensor): |
| tensor = tensor.cpu().squeeze(0) |
| tensor = tensor * 0.5 + 0.5 |
| tensor = torch.clamp(tensor, 0, 1) |
| return tensor.permute(1, 2, 0).numpy() |
|
|
|
|
| col1, col2 = st.columns(2) |
|
|
| with col1: |
| st.header("Домен A ➡️ Домен B") |
| file_a = st.file_uploader("Загрузить фото A", type=["jpg", "png", "jpeg"], key="a") |
| if file_a: |
| img_a = Image.open(file_a).convert("RGB") |
| st.image(img_a, caption="Оригинал") |
| if st.button("Преобразовать", key="btn_a"): |
| with st.spinner("Генерация..."): |
| tensor = transform(img_a).unsqueeze(0) |
| with torch.no_grad(): |
| res = model.G_A2B(tensor) |
| st.image(de_normalize(res), caption="Результат") |
|
|
| with col2: |
| st.header("Домен B ➡️ Домен A") |
| file_b = st.file_uploader("Загрузить фото B", type=["jpg", "png", "jpeg"], key="b") |
| if file_b: |
| img_b = Image.open(file_b).convert("RGB") |
| st.image(img_b, caption="Оригинал") |
| if st.button("Преобразовать", key="btn_b"): |
| with st.spinner("Генерация..."): |
| tensor = transform(img_b).unsqueeze(0) |
| with torch.no_grad(): |
| res = model.G_B2A(tensor) |
| st.image(de_normalize(res), caption="Результат") |
|
|