"""Streamlit front end for a summer <-> winter CycleGAN landscape translator.

The script loads two generators (A->B and B->A) once per server process,
lets the user upload a photo and pick a direction, runs the chosen generator
on a 256x256 normalized copy of the photo, and shows/downloads the result
resized back to the upload's original dimensions.
"""
from io import BytesIO

import numpy as np
import streamlit as st
import torch
import torchvision.transforms as T
from PIL import Image

from model import load_generator

st.set_page_config(
    page_title="Summer ↔ Winter CycleGAN",
    page_icon="🏔️",
    layout="centered",
)
st.title("🏔️ Summer ↔ Winter Translation")
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")


@st.cache_resource
def get_generators():
    """Load both generators on the CPU and return them with the device name.

    Decorated with ``st.cache_resource`` so the checkpoints are loaded once
    and reused across Streamlit reruns instead of on every interaction.

    Returns:
        Tuple ``(gen_a2b, gen_b2a, device)``. ``gen_a2b`` is the generator
        used below for the summer -> winter direction, ``gen_b2a`` for
        winter -> summer; ``device`` is the string ``"cpu"``.
    """
    device = "cpu"
    # NOTE(review): the file names suggest half-precision checkpoints; what
    # dtype the returned modules run in depends on `load_generator` - confirm.
    gen_a2b = load_generator("gen_a2b_fp16.pth", device)
    gen_b2a = load_generator("gen_b2a_fp16.pth", device)
    return gen_a2b, gen_b2a, device


gen_a2b, gen_b2a, device = get_generators()
st.success(f"Model loaded on **{device}**", icon="✅")

# Per-channel statistics used to map [0, 1] pixels to [-1, 1] and back.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)

# Preprocessing: fixed 256x256 input, tensor conversion, then normalization.
to_tensor = T.Compose([
    T.Resize((256, 256)),
    T.ToTensor(),
    T.Normalize(MEAN, STD),
])


def to_pil(tensor, original_size=None):
    """Convert a normalized generator output back into a PIL image.

    The input tensor is left untouched: denormalization is done out of place
    (the previous per-channel in-place loop could write into the caller's
    tensor, because ``squeeze``/``cpu``/``float`` return views or the same
    object for a float32 CPU tensor).

    Args:
        tensor: Image tensor normalized with ``MEAN``/``STD``; a leading
            batch dimension of size 1 is removed if present.
        original_size: Optional ``(width, height)``. When truthy, the result
            is resized to it with Lanczos resampling.

    Returns:
        The denormalized image as a ``PIL.Image.Image``.
    """
    # detach() keeps this usable even when called outside torch.no_grad();
    # float() upcasts half-precision outputs before the arithmetic.
    img = tensor.detach().squeeze(0).cpu().float()
    # Shape the statistics as (C, 1, 1) so they broadcast over H and W.
    mean = torch.tensor(MEAN, dtype=img.dtype).view(-1, 1, 1)
    std = torch.tensor(STD, dtype=img.dtype).view(-1, 1, 1)
    img = (img * std + mean).clamp(0, 1)
    pil_img = T.ToPILImage()(img)
    if original_size:
        pil_img = pil_img.resize(original_size, Image.LANCZOS)
    return pil_img


# Radio labels are named so the generator choice below is an exact comparison
# rather than a substring search inside the display text.
SUMMER_TO_WINTER = "☀️ Summer → ❄️ Winter"
WINTER_TO_SUMMER = "❄️ Winter → ☀️ Summer"

direction = st.radio(
    "Translation direction",
    [SUMMER_TO_WINTER, WINTER_TO_SUMMER],
    horizontal=True,
)

uploaded = st.file_uploader(
    "Upload landscape photo (JPG/PNG)",
    type=["jpg", "jpeg", "png"],
)

if uploaded is not None:
    # Top-level UI boundary: any failure (bad image data, model error) is
    # reported in the page instead of crashing the script run.
    try:
        raw = uploaded.read()
        input_img = Image.open(BytesIO(raw)).convert("RGB")
        # Remembered so the 256x256 output can be scaled back afterwards.
        original_size = input_img.size

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input")
            st.image(input_img, use_container_width=True)

        with st.spinner("Translating..."):
            tensor = to_tensor(input_img).unsqueeze(0).to(device)
            generator = gen_a2b if direction == SUMMER_TO_WINTER else gen_b2a
            with torch.no_grad():
                output_tensor = generator(tensor)
            output_img = to_pil(output_tensor, original_size=original_size)

        with col2:
            st.subheader("Output")
            st.image(output_img, use_container_width=True)

        # Encode the result in memory for the download button.
        buf = BytesIO()
        output_img.save(buf, format="PNG")
        st.download_button(
            "⬇️ Download result",
            buf.getvalue(),
            "translated.png",
            "image/png",
        )
    except Exception as error:
        st.error(f"Error: {error}")
        st.exception(error)

st.markdown("---")
st.markdown(
    "**Model:** CycleGAN ResNet-9 blocks (64 channels) · "
    "**Train / Test:** Yosemite / Alpine (Unsplash)"
)