import streamlit as st import torch from torchvision import transforms from PIL import Image from model import CycleGAN st.set_page_config(page_title="CycleGAN Demo", layout="wide") st.title("CycleGAN: Трансфер стилей") @st.cache_resource def load_model(): device = torch.device("cpu") model = CycleGAN() checkpoint = torch.load("cyclegan_150_epochs.pt", map_location=device, weights_only=False) model.load_state_dict(checkpoint['model_state_dict']) model.eval() return model, device model, device = load_model() IMG_SIZE = 128 transform = transforms.Compose([ transforms.Resize((IMG_SIZE, IMG_SIZE), Image.BICUBIC), transforms.ToTensor(), transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) ]) def de_normalize(tensor): tensor = tensor.cpu().squeeze(0) tensor = tensor * 0.5 + 0.5 tensor = torch.clamp(tensor, 0, 1) return tensor.permute(1, 2, 0).numpy() col1, col2 = st.columns(2) with col1: st.header("Домен A ➡️ Домен B") file_a = st.file_uploader("Загрузить фото A", type=["jpg", "png", "jpeg"], key="a") if file_a: img_a = Image.open(file_a).convert("RGB") st.image(img_a, caption="Оригинал") if st.button("Преобразовать", key="btn_a"): with st.spinner("Генерация..."): tensor = transform(img_a).unsqueeze(0) with torch.no_grad(): res = model.G_A2B(tensor) # Перевод из A в B st.image(de_normalize(res), caption="Результат") with col2: st.header("Домен B ➡️ Домен A") file_b = st.file_uploader("Загрузить фото B", type=["jpg", "png", "jpeg"], key="b") if file_b: img_b = Image.open(file_b).convert("RGB") st.image(img_b, caption="Оригинал") if st.button("Преобразовать", key="btn_b"): with st.spinner("Генерация..."): tensor = transform(img_b).unsqueeze(0) with torch.no_grad(): res = model.G_B2A(tensor) # Перевод из B в A st.image(de_normalize(res), caption="Результат")