Cycle_gan_hw / app.py
DeniSSio's picture
Update app.py
c830076 verified
import streamlit as st
import torch
import torchvision.transforms as tr
from PIL import Image
import numpy as np
from model import CycleGAN, Discriminator, Generator
mean = [0.5, 0.5, 0.5]
std = [0.5, 0.5, 0.5]
hyperparams = dict(
crop_size=256,
)
def get_transforms(mean, std, crop_size=64):
val_transform = tr.Compose([
tr.Resize((crop_size, crop_size)),
tr.ToTensor(),
tr.Normalize(mean=mean, std=std)
])
def de_normalize(tensor):
denorm = tr.Normalize(
mean=[-m / s for m, s in zip(mean, std)],
std=[1 / s for s in std]
)
return denorm(tensor.clone()).clamp(0, 1)
return val_transform, de_normalize
val_transform, de_normalize = get_transforms(mean, std, **hyperparams)
@st.cache_resource
def load_model():
checkpoint = torch.load("cycle_gan_face.pt", map_location="cpu", weights_only = False)
model = CycleGAN(Discriminator, Generator)
model.load_state_dict(checkpoint['model_state_dict'])
model.eval()
return model
model = load_model()
# === Streamlit UI ===
st.title("Обработка изображений через PyTorch модель")
uploaded_file_1 = st.file_uploader("Загрузите изображение белого человека", type=["jpg", "jpeg", "png"], key="file1")
uploaded_file_2 = st.file_uploader("Загрузите изображение черного человека", type=["jpg", "jpeg", "png"], key="file2")
selected = st.radio("Выберите изображение для обработки", ["Первое", "Второе"])
if uploaded_file_1 and uploaded_file_2:
image1 = Image.open(uploaded_file_1).convert("RGB")
image2 = Image.open(uploaded_file_2).convert("RGB")
if selected == "Первое":
selected_image = image1
tensor = val_transform(selected_image).unsqueeze(0)
with torch.no_grad():
output = model.netG_B2A(tensor)
else:
selected_image = image2
tensor = val_transform(selected_image).unsqueeze(0)
with torch.no_grad():
output = model.netG_A2B(tensor)
result_image = de_normalize(output.squeeze(0)).permute(1, 2, 0).numpy()
result_image = (result_image * 255).astype(np.uint8)
col1, col2 = st.columns(2)
with col1:
st.image(selected_image, caption="Оригинал", use_column_width=True)
with col2:
st.image(result_image, caption="Результат модели", use_column_width=True)
else:
st.info("Пожалуйста, загрузите оба изображения.")