# CycleGAN / app.py
# Author: Roman79
# Commit message: "Upload app.py"
# Commit: c2fc8e6 (verified)
from io import BytesIO

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image
from PIL import ImageOps

from model import load_generator
# Page chrome. set_page_config must be the first Streamlit command the script runs.
# The emoji below were previously mis-encoded (UTF-8 bytes decoded as cp1253),
# which produced an invalid page icon and a garbled title.
st.set_page_config(page_title="Summer ↔ Winter CycleGAN", page_icon="🏔️", layout="centered")
st.title("🏔️ Summer ↔ Winter Translation")
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
@st.cache_resource
def get_generators():
    """Load both CycleGAN generators onto the CPU and cache them.

    ``st.cache_resource`` keeps the returned objects alive across script
    reruns, so the checkpoint files are only loaded the first time.

    Returns:
        tuple: ``(gen_a2b, gen_b2a, device)``. The app uses ``gen_a2b`` for
        summer -> winter and ``gen_b2a`` for winter -> summer. ``device`` is
        the device string both were loaded onto.
    """
    target_device = "cpu"
    checkpoints = ("gen_a2b_fp16.pth", "gen_b2a_fp16.pth")
    # Loaded in order: a2b first, then b2a.
    forward_gen, backward_gen = [load_generator(ckpt, target_device) for ckpt in checkpoints]
    return forward_gen, backward_gen, target_device
# Fetch the (cached) generators and tell the user where they run.
gen_a2b, gen_b2a, device = get_generators()
# `icon` must be a single emoji; the previous mis-encoded value was rejected
# by Streamlit's icon validation.
st.success(f"Model loaded on **{device}**", icon="✅")
# Per-channel normalisation constants. With mean = std = 0.5, Normalize maps
# the [0, 1] output of ToTensor onto [-1, 1]; to_pil undoes this mapping.
MEAN = (0.5,) * 3
STD = (0.5,) * 3

# Preprocessing applied to every uploaded image before it reaches a generator:
# resize to a fixed 256x256 (the aspect ratio is not preserved), convert to a
# float tensor in [0, 1], then normalise with MEAN/STD.
to_tensor = T.Compose(
    [
        T.Resize(size=(256, 256)),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
)
def to_pil(tensor, original_size=None):
    """Convert a MEAN/STD-normalised image tensor back into a PIL image.

    Args:
        tensor: Image tensor shaped (C, H, W) or (1, C, H, W), normalised with
            the module-level MEAN/STD. The app passes a single-image batch
            straight from a generator.
        original_size: Optional ``(width, height)``, as returned by PIL's
            ``Image.size``. When given, the result is resized to it with
            LANCZOS resampling.

    Returns:
        PIL.Image.Image: The de-normalised image, clamped to [0, 1] before
        conversion.
    """
    # detach() so ToPILImage also works on tensors that track gradients.
    img = tensor.detach().squeeze(0).cpu().float()
    mean = torch.tensor(MEAN, dtype=img.dtype).view(-1, 1, 1)
    std = torch.tensor(STD, dtype=img.dtype).view(-1, 1, 1)
    # Out-of-place inverse of Normalize. The previous in-place per-channel
    # update could write into the caller's tensor, because
    # squeeze/cpu/float return views or the same tensor when no conversion
    # is needed.
    img = torch.clamp(img * std + mean, 0, 1)
    pil_img = T.ToPILImage()(img)
    if original_size is not None:
        pil_img = pil_img.resize(original_size, Image.LANCZOS)
    return pil_img
# Radio labels. The selected label is compared against these constants to
# pick a generator, instead of parsing the label text.
SUMMER_TO_WINTER = "☀️ Summer → ❄️ Winter"
WINTER_TO_SUMMER = "❄️ Winter → ☀️ Summer"

direction = st.radio(
    "Translation direction",
    [SUMMER_TO_WINTER, WINTER_TO_SUMMER],
    horizontal=True,
)

uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    try:
        # getvalue() returns the full upload regardless of the current stream
        # position; read() only returns what lies after it.
        raw = uploaded.getvalue()
        input_img = Image.open(BytesIO(raw))
        # PIL does not apply the EXIF orientation tag (common on phone photos)
        # on open; apply it so the photo is translated and displayed upright.
        input_img = ImageOps.exif_transpose(input_img).convert("RGB")
        # (width, height), used to resize the 256x256 output back.
        original_size = input_img.size

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input")
            st.image(input_img, use_container_width=True)

        with st.spinner("Translating..."):
            # Add a batch dimension of 1 for the generator.
            tensor = to_tensor(input_img).unsqueeze(0).to(device)
            # gen_a2b handles summer -> winter, gen_b2a handles winter -> summer.
            generator = gen_a2b if direction == SUMMER_TO_WINTER else gen_b2a
            with torch.no_grad():
                output_tensor = generator(tensor)
            output_img = to_pil(output_tensor, original_size=original_size)

        with col2:
            st.subheader("Output")
            st.image(output_img, use_container_width=True)
            buf = BytesIO()
            output_img.save(buf, format="PNG")
            st.download_button("⬇️ Download result", buf.getvalue(), "translated.png", "image/png")
    except Exception as error:  # top-level UI boundary: surface any failure to the user
        st.error(f"Error: {error}")
        st.exception(error)
# Footer with model and dataset details. The separator was previously
# mis-encoded as "Β·".
st.markdown("---")
st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) · **Train / Test:** Yosemite / Alpine (Unsplash)")