| import streamlit as st
|
| import torch
|
| import numpy as np
|
| from PIL import Image
|
| import torchvision.transforms as T
|
| from model import load_generator
|
| from io import BytesIO
|
|
|
# Page chrome: browser-tab title and icon, the on-page heading, and a one-line
# usage hint.
# The arrow and emoji literals were mojibake (UTF-8 bytes decoded with a
# single-byte code page) and are restored here.
# NOTE(review): the page icon could only be recovered as "some emoji followed
# by a variation selector"; a snow-capped mountain is assumed - confirm.
st.set_page_config(
    page_title="Summer ↔ Winter CycleGAN",
    page_icon="🏔️",
    layout="centered",
)
st.title("🏔️ Summer ↔ Winter Translation")
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
|
|
|
@st.cache_resource
def get_generators():
    """Load both CycleGAN generators and report the device they were put on.

    The function is wrapped in ``st.cache_resource`` so that Streamlit can
    hand back the same loaded objects on later script reruns instead of
    reading the checkpoints again.

    Returns:
        A ``(gen_a2b, gen_b2a, device)`` tuple: the A->B generator, the B->A
        generator, and the device string passed to ``load_generator``.
    """
    device = "cpu"
    checkpoint_files = ("gen_a2b_fp16.pth", "gen_b2a_fp16.pth")
    # Loaded in order: A->B first, then B->A.
    gen_a2b, gen_b2a = (load_generator(path, device) for path in checkpoint_files)
    return gen_a2b, gen_b2a, device
|
|
|
# Fetch both generators (served from Streamlit's resource cache after the
# first run) and tell the user which device they were loaded on.
gen_a2b, gen_b2a, device = get_generators()

# The icon literal had been corrupted: one byte of the check-mark emoji was
# read as a line break, splitting the string across two lines (a syntax
# error) and leaving a character that is not a valid icon.
st.success(f"Model loaded on **{device}**", icon="✅")
|
|
|
# Per-channel normalisation statistics. With mean 0.5 and std 0.5 the
# [0, 1] range produced by ToTensor is mapped onto [-1, 1].
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

# Preprocessing applied to every uploaded image before inference:
# fixed 256x256 resize -> float tensor in [0, 1] -> MEAN/STD normalisation.
to_tensor = T.Compose(
    [
        T.Resize(size=(256, 256)),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
)
|
|
|
def to_pil(tensor, original_size=None):
    """Convert a MEAN/STD-normalised tensor back into a PIL image.

    Args:
        tensor: Image tensor normalised with ``MEAN``/``STD``. A leading
            dimension of size 1 is squeezed away. Presumably shaped
            ``(1, 3, H, W)`` as produced by the generators - confirm.
        original_size: Optional ``(width, height)`` to resize the result to
            (the caller passes ``PIL.Image.size``). Falsy values skip the
            resize.

    Returns:
        A ``PIL.Image.Image`` with values de-normalised and clamped to
        ``[0, 1]`` before conversion.

    The input tensor is never modified.
    """
    # detach() lets this work on tensors that still track gradients;
    # float() covers half-precision outputs.
    img = tensor.detach().squeeze(0).cpu().float()

    # Undo the normalisation out of place. The previous per-channel
    # `img[i] = ...` assignment wrote through to the caller's tensor whenever
    # it was already a CPU float32 tensor, because squeeze/cpu/float then
    # return a view or the very same object rather than a copy.
    mean = torch.tensor(MEAN, dtype=img.dtype).view(-1, 1, 1)
    std = torch.tensor(STD, dtype=img.dtype).view(-1, 1, 1)
    img = (img * std + mean).clamp(0, 1)

    pil_img = T.ToPILImage()(img)

    if original_size:
        # Restore the upload's original dimensions (inference runs at 256x256).
        pil_img = pil_img.resize(original_size, Image.LANCZOS)
    return pil_img
|
|
|
# Direction labels are named constants so the generator choice below can
# compare against them directly. The previous code parsed the label with
# `direction.split(<arrow>)[0]`; because the label itself began with that
# same (corrupted) character, the first piece was always empty and the
# Winter -> Summer generator was used for both options.
# NOTE(review): emoji and arrow are restored from mojibake; sun/snowflake and
# a plain right arrow are assumed - confirm against the original.
SUMMER_TO_WINTER = "☀️ Summer → ❄️ Winter"
WINTER_TO_SUMMER = "❄️ Winter → ☀️ Summer"

direction = st.radio(
    "Translation direction",
    [SUMMER_TO_WINTER, WINTER_TO_SUMMER],
    horizontal=True,
)

uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    # Broad catch is deliberate: this is the top-level UI boundary, and any
    # failure (bad image, model error) should be shown rather than crash the app.
    try:
        # getvalue() returns the whole payload regardless of the stream
        # position, unlike read().
        raw = uploaded.getvalue()
        input_img = Image.open(BytesIO(raw)).convert("RGB")
        original_size = input_img.size  # (width, height)

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input")
            st.image(input_img, use_container_width=True)

        with st.spinner("Translating..."):
            # Add a batch dimension and move to the generators' device.
            tensor = to_tensor(input_img).unsqueeze(0).to(device)
            # gen_a2b handles Summer -> Winter, gen_b2a the reverse.
            generator = gen_a2b if direction == SUMMER_TO_WINTER else gen_b2a
            # Inference only: no autograd graph needed.
            with torch.no_grad():
                output_tensor = generator(tensor)
            output_img = to_pil(output_tensor, original_size=original_size)

        with col2:
            st.subheader("Output")
            st.image(output_img, use_container_width=True)

        # Encode the result as PNG in memory for the download button.
        buf = BytesIO()
        output_img.save(buf, format="PNG")
        st.download_button("⬇️ Download result", buf.getvalue(), "translated.png", "image/png")

    except Exception as error:
        st.error(f"Error: {error}")
        st.exception(error)
|
|
|
# Footer: a horizontal rule followed by a one-line summary of the model and
# its datasets. The separator is a middle dot, restored from mojibake.
st.markdown("---")
st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) · **Train / Test:** Yosemite / Alpine (Unsplash)")