File size: 2,726 Bytes
29c3d64 4f6abaa 29c3d64 4f6abaa 29c3d64 c2fc8e6 29c3d64 4f6abaa 29c3d64 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 | import streamlit as st
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as T
from model import load_generator
from io import BytesIO
# Page-level configuration, issued before any other Streamlit element is drawn.
# (The title/icon strings are the UTF-8 emoji/arrow characters; earlier copies
# of this file had them garbled by a wrong-codec round trip.)
st.set_page_config(page_title="Summer ↔ Winter CycleGAN", page_icon="🏔️", layout="centered")
st.title("🏔️ Summer ↔ Winter Translation")
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
@st.cache_resource
def get_generators():
    """Build both translation generators and report the device they live on.

    Decorated with ``st.cache_resource`` so the checkpoints are loaded once and
    reused across Streamlit reruns instead of being reloaded on every
    interaction.

    Returns:
        ``(gen_a2b, gen_b2a, device)``: the generators loaded from
        ``gen_a2b_fp16.pth`` and ``gen_b2a_fp16.pth`` (in that order) via
        ``load_generator``, and the device string they were loaded onto.
    """
    # Inference is pinned to the CPU.
    device = "cpu"
    checkpoints = ("gen_a2b_fp16.pth", "gen_b2a_fp16.pth")
    forward_gen, backward_gen = (load_generator(ckpt, device) for ckpt in checkpoints)
    return forward_gen, backward_gen, device
# Load (or fetch from the resource cache) both generators and tell the user
# where inference will run. The icon must be a single valid emoji, otherwise
# Streamlit rejects it.
gen_a2b, gen_b2a, device = get_generators()
st.success(f"Model loaded on **{device}**", icon="✅")
# Per-channel normalisation statistics shared by preprocessing (to_tensor) and
# postprocessing (to_pil). With ToTensor's [0, 1] output, (x - 0.5) / 0.5
# maps pixels into [-1, 1].
MEAN = (0.5,) * 3
STD = (0.5,) * 3

# PIL image -> normalised float tensor, resized to exactly 256x256
# (aspect ratio is not preserved; to_pil can resize back afterwards).
to_tensor = T.Compose(
    [T.Resize((256, 256)), T.ToTensor(), T.Normalize(MEAN, STD)]
)
def to_pil(tensor, original_size=None):
    """Convert a normalised image tensor back into a PIL image.

    Reverses ``T.Normalize(MEAN, STD)`` (as applied by ``to_tensor``), clamps
    the values to the displayable [0, 1] range and optionally resizes.

    Args:
        tensor: Generator output. Presumably shaped ``(1, C, H, W)`` with
            ``C == len(MEAN)``; a leading dimension of size 1 is squeezed
            away. (Shape assumption — TODO confirm against the generator.)
        original_size: Optional ``(width, height)`` in the order returned by
            ``PIL.Image.Image.size``; when given, the result is resized to it
            with Lanczos resampling.

    Returns:
        A ``PIL.Image.Image``.

    The input tensor is never modified.
    """
    img = tensor.detach().squeeze(0).cpu().float()
    # Build a new tensor rather than assigning into ``img[i]``: when ``tensor``
    # is already a CPU float32 tensor, ``.cpu()``/``.float()`` return the same
    # storage (and ``squeeze`` is a view), so in-place writes would corrupt
    # the caller's tensor.
    mean = torch.tensor(MEAN, dtype=img.dtype).view(-1, 1, 1)
    std = torch.tensor(STD, dtype=img.dtype).view(-1, 1, 1)
    img = torch.clamp(img * std + mean, 0, 1)
    pil_img = T.ToPILImage()(img)
    if original_size:
        pil_img = pil_img.resize(original_size, Image.LANCZOS)
    return pil_img
# Radio labels for the two translation directions. DIRECTIONS[0]
# (summer -> winter) is served by gen_a2b, DIRECTIONS[1] by gen_b2a.
DIRECTIONS = ["☀️ Summer → ❄️ Winter", "❄️ Winter → ☀️ Summer"]

direction = st.radio(
    "Translation direction",
    DIRECTIONS,
    horizontal=True,
)
uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])

if uploaded is not None:
    try:
        # getvalue() returns the whole upload regardless of the current read
        # position, unlike read().
        raw = uploaded.getvalue()
        input_img = Image.open(BytesIO(raw)).convert("RGB")
        original_size = input_img.size  # (width, height), used to undo the 256x256 resize

        col1, col2 = st.columns(2)
        with col1:
            st.subheader("Input")
            st.image(input_img, use_container_width=True)

        with st.spinner("Translating..."):
            tensor = to_tensor(input_img).unsqueeze(0).to(device)
            # Compare against the option itself instead of parsing the label
            # text, so the choice does not depend on arrow/emoji characters.
            generator = gen_a2b if direction == DIRECTIONS[0] else gen_b2a
            with torch.no_grad():
                output_tensor = generator(tensor)
            output_img = to_pil(output_tensor, original_size=original_size)

        with col2:
            st.subheader("Output")
            st.image(output_img, use_container_width=True)
            buf = BytesIO()
            output_img.save(buf, format="PNG")
            st.download_button(
                "⬇️ Download result", buf.getvalue(), "translated.png", "image/png"
            )
    except Exception as error:
        # UI boundary: surface any failure (unreadable upload, model error)
        # on the page instead of crashing the script run.
        st.error(f"Error: {error}")
        st.exception(error)
# Footer with model and dataset details (separator uses a real middle dot).
st.markdown("---")
st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) · **Train / Test:** Yosemite / Alpine (Unsplash)")