File size: 2,726 Bytes
29c3d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4f6abaa
29c3d64
 
 
 
4f6abaa
 
 
 
 
 
29c3d64
 
 
 
 
 
 
 
 
 
 
 
 
c2fc8e6
29c3d64
 
 
 
 
 
 
 
 
 
 
4f6abaa
29c3d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import streamlit as st
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import load_generator
from io import BytesIO

# Page metadata and header. The emoji literals must be real UTF-8 characters:
# Streamlit interprets icon strings, and mis-decoded bytes are not an emoji.
st.set_page_config(page_title="Summer ↔ Winter CycleGAN", page_icon="🏔️", layout="centered")
st.title("🏔️ Summer ↔ Winter Translation")
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")

@st.cache_resource
def get_generators():
    """Load both CycleGAN generators and report the device they run on.

    Decorated with ``st.cache_resource`` so the checkpoints are loaded once
    and reused across Streamlit script reruns.

    Returns:
        tuple: ``(gen_a2b, gen_b2a, device)`` -- the generators that
        ``load_generator`` builds from ``gen_a2b_fp16.pth`` and
        ``gen_b2a_fp16.pth`` (loaded in that order), and the device string
        they were loaded onto, which is always ``"cpu"``.
    """
    target_device = "cpu"
    a2b, b2a = (
        load_generator(checkpoint, target_device)
        for checkpoint in ("gen_a2b_fp16.pth", "gen_b2a_fp16.pth")
    )
    return a2b, b2a, target_device

# Fetch the (cached) generators; the success icon must be a genuine emoji
# character, not its mis-decoded byte sequence.
gen_a2b, gen_b2a, device = get_generators()
st.success(f"Model loaded on **{device}**", icon="✅")

# Per-channel (R, G, B) normalization statistics. With mean = std = 0.5,
# Normalize maps pixel values from [0, 1] to [-1, 1]; to_pil undoes it.
MEAN = (0.5,) * 3
STD = (0.5,) * 3

# Generator input pipeline for a PIL image:
#   1. squash to a fixed 256x256 (a (h, w) tuple ignores aspect ratio),
#   2. convert to a float CHW tensor in [0, 1],
#   3. normalize with MEAN/STD.
to_tensor = T.Compose(
    [
        T.Resize(size=(256, 256)),
        T.ToTensor(),
        T.Normalize(mean=MEAN, std=STD),
    ]
)

def to_pil(tensor, original_size=None):
    """Convert a normalized image tensor back into a PIL image.

    Reverses ``T.Normalize(MEAN, STD)`` per channel, clamps to [0, 1], and
    converts the result with ``T.ToPILImage``.

    Args:
        tensor: Image tensor with a leading batch dimension of 1, presumably
            ``(1, 3, H, W)`` as produced by the generator for one image
            (the channel count must match ``len(MEAN)``). It is never modified.
        original_size: Optional ``(width, height)`` to resize the output to
            with Lanczos resampling; skipped when falsy.

    Returns:
        PIL.Image.Image: The de-normalized image.
    """
    # detach() so a grad-tracking tensor can still be converted to PIL.
    img = tensor.detach().squeeze(0).cpu().float()

    # Shape the stats as (C, 1, 1) so they broadcast over H and W.
    mean = torch.tensor(MEAN, dtype=img.dtype).view(-1, 1, 1)
    std = torch.tensor(STD, dtype=img.dtype).view(-1, 1, 1)

    # Out-of-place: when the input is already a float32 CPU tensor,
    # .cpu()/.float() return it unchanged and squeeze(0) is a view, so an
    # in-place update here would silently overwrite the caller's tensor.
    img = torch.clamp(img * std + mean, 0, 1)

    pil_img = T.ToPILImage()(img)

    if original_size:
        pil_img = pil_img.resize(original_size, Image.LANCZOS)
    return pil_img

# Map each user-facing direction label to the generator that performs it.
# An explicit lookup replaces parsing the label text (splitting on the arrow
# character), which breaks whenever the label's characters change.
_GENERATOR_BY_DIRECTION = {
    "☀️ Summer → ❄️ Winter": gen_a2b,
    "❄️ Winter → ☀️ Summer": gen_b2a,
}

direction = st.radio(
    "Translation direction",
    list(_GENERATOR_BY_DIRECTION),
    horizontal=True,
)

uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    try:
        # getvalue() returns the whole upload regardless of the stream
        # position; read() returns b"" if the buffer was already consumed.
        raw = uploaded.getvalue()
        input_img = Image.open(BytesIO(raw)).convert("RGB")
        original_size = input_img.size  # (width, height); output is resized back to it

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input")
            st.image(input_img, use_container_width=True)

        with st.spinner("Translating..."):
            # NOTE(review): assumes load_generator returns float32 weights;
            # if the "_fp16" checkpoints stay half precision, the input
            # tensor would need casting to the model dtype -- confirm.
            tensor = to_tensor(input_img).unsqueeze(0).to(device)
            generator = _GENERATOR_BY_DIRECTION[direction]
            with torch.no_grad():
                output_tensor = generator(tensor)
            output_img = to_pil(output_tensor, original_size=original_size)

        with col2:
            st.subheader("Output")
            st.image(output_img, use_container_width=True)

        buf = BytesIO()
        output_img.save(buf, format="PNG")
        st.download_button("⬇️ Download result", buf.getvalue(), "translated.png", "image/png")

    except Exception as error:
        # Top-level UI boundary: surface any failure (bad image, model error)
        # in the page, including the traceback, instead of crashing the app.
        st.error(f"Error: {error}")
        st.exception(error)

# Footer: model/dataset credits (separator uses a real middle dot "·").
st.markdown("---")
st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) · **Train / Test:** Yosemite / Alpine (Unsplash)")