Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| import torchvision.transforms as tr | |
| from PIL import Image | |
| import numpy as np | |
| from model import CycleGAN, Discriminator, Generator | |
| mean = [0.5, 0.5, 0.5] | |
| std = [0.5, 0.5, 0.5] | |
| hyperparams = dict( | |
| crop_size=256, | |
| ) | |
| def get_transforms(mean, std, crop_size=64): | |
| val_transform = tr.Compose([ | |
| tr.Resize((crop_size, crop_size)), | |
| tr.ToTensor(), | |
| tr.Normalize(mean=mean, std=std) | |
| ]) | |
| def de_normalize(tensor): | |
| denorm = tr.Normalize( | |
| mean=[-m / s for m, s in zip(mean, std)], | |
| std=[1 / s for s in std] | |
| ) | |
| return denorm(tensor.clone()).clamp(0, 1) | |
| return val_transform, de_normalize | |
| val_transform, de_normalize = get_transforms(mean, std, **hyperparams) | |
| def load_model(): | |
| checkpoint = torch.load("cycle_gan_face.pt", map_location="cpu", weights_only = False) | |
| model = CycleGAN(Discriminator, Generator) | |
| model.load_state_dict(checkpoint['model_state_dict']) | |
| model.eval() | |
| return model | |
| model = load_model() | |
| # === Streamlit UI === | |
| st.title("Обработка изображений через PyTorch модель") | |
| uploaded_file_1 = st.file_uploader("Загрузите изображение белого человека", type=["jpg", "jpeg", "png"], key="file1") | |
| uploaded_file_2 = st.file_uploader("Загрузите изображение черного человека", type=["jpg", "jpeg", "png"], key="file2") | |
| selected = st.radio("Выберите изображение для обработки", ["Первое", "Второе"]) | |
| if uploaded_file_1 and uploaded_file_2: | |
| image1 = Image.open(uploaded_file_1).convert("RGB") | |
| image2 = Image.open(uploaded_file_2).convert("RGB") | |
| if selected == "Первое": | |
| selected_image = image1 | |
| tensor = val_transform(selected_image).unsqueeze(0) | |
| with torch.no_grad(): | |
| output = model.netG_B2A(tensor) | |
| else: | |
| selected_image = image2 | |
| tensor = val_transform(selected_image).unsqueeze(0) | |
| with torch.no_grad(): | |
| output = model.netG_A2B(tensor) | |
| result_image = de_normalize(output.squeeze(0)).permute(1, 2, 0).numpy() | |
| result_image = (result_image * 255).astype(np.uint8) | |
| col1, col2 = st.columns(2) | |
| with col1: | |
| st.image(selected_image, caption="Оригинал", use_column_width=True) | |
| with col2: | |
| st.image(result_image, caption="Результат модели", use_column_width=True) | |
| else: | |
| st.info("Пожалуйста, загрузите оба изображения.") |