Upload 5 files
Browse fileschanged to 9 blocks
- app.py +77 -77
- gen_a2b_fp16.pth +2 -2
- gen_b2a_fp16.pth +2 -2
- model.py +2 -2
app.py
CHANGED
|
@@ -1,77 +1,77 @@
|
|
| 1 |
-
import streamlit as st
|
| 2 |
-
import torch
|
| 3 |
-
import numpy as np
|
| 4 |
-
from PIL import Image
|
| 5 |
-
import torchvision.transforms as T
|
| 6 |
-
from model import load_generator
|
| 7 |
-
from io import BytesIO
|
| 8 |
-
|
| 9 |
-
st.set_page_config(page_title="Summer β Winter CycleGAN", page_icon="ποΈ", layout="centered")
|
| 10 |
-
st.title("ποΈ Summer β Winter Translation")
|
| 11 |
-
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
|
| 12 |
-
|
| 13 |
-
@st.cache_resource
|
| 14 |
-
def get_generators():
|
| 15 |
-
device = "cpu"
|
| 16 |
-
gen_a2b = load_generator("gen_a2b_fp16.pth", device)
|
| 17 |
-
gen_b2a = load_generator("gen_b2a_fp16.pth", device)
|
| 18 |
-
return gen_a2b, gen_b2a, device
|
| 19 |
-
|
| 20 |
-
gen_a2b, gen_b2a, device = get_generators()
|
| 21 |
-
st.success(f"Model loaded on **{device}**", icon="β
")
|
| 22 |
-
|
| 23 |
-
MEAN = (0.5, 0.5, 0.5)
|
| 24 |
-
STD = (0.5, 0.5, 0.5)
|
| 25 |
-
|
| 26 |
-
to_tensor = T.Compose([
|
| 27 |
-
T.Resize((256, 256)),
|
| 28 |
-
T.ToTensor(),
|
| 29 |
-
T.Normalize(MEAN, STD),
|
| 30 |
-
])
|
| 31 |
-
|
| 32 |
-
def to_pil(tensor):
|
| 33 |
-
img = tensor.squeeze(0).cpu().float()
|
| 34 |
-
for i, (m, s) in enumerate(zip(MEAN, STD)):
|
| 35 |
-
img[i] = img[i] * s + m
|
| 36 |
-
img = torch.clamp(img, 0, 1)
|
| 37 |
-
return T.ToPILImage()(img)
|
| 38 |
-
|
| 39 |
-
direction = st.radio(
|
| 40 |
-
"Translation direction",
|
| 41 |
-
["βοΈ Summer β βοΈ Winter", "βοΈ Winter β βοΈ Summer"],
|
| 42 |
-
horizontal=True,
|
| 43 |
-
)
|
| 44 |
-
|
| 45 |
-
uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])
|
| 46 |
-
|
| 47 |
-
if uploaded is not None:
|
| 48 |
-
try:
|
| 49 |
-
raw = uploaded.read()
|
| 50 |
-
input_img = Image.open(BytesIO(raw)).convert("RGB")
|
| 51 |
-
|
| 52 |
-
col1, col2 = st.columns(2)
|
| 53 |
-
with col1:
|
| 54 |
-
st.subheader("Input")
|
| 55 |
-
st.image(input_img, use_container_width=True)
|
| 56 |
-
|
| 57 |
-
with st.spinner("Translating..."):
|
| 58 |
-
tensor = to_tensor(input_img).unsqueeze(0).to(device)
|
| 59 |
-
generator = gen_a2b if "Summer" in direction.split("β")[0] else gen_b2a
|
| 60 |
-
with torch.no_grad():
|
| 61 |
-
output_tensor = generator(tensor)
|
| 62 |
-
output_img = to_pil(output_tensor)
|
| 63 |
-
|
| 64 |
-
with col2:
|
| 65 |
-
st.subheader("Output")
|
| 66 |
-
st.image(output_img, use_container_width=True)
|
| 67 |
-
|
| 68 |
-
buf = BytesIO()
|
| 69 |
-
output_img.save(buf, format="PNG")
|
| 70 |
-
st.download_button("β¬οΈ Download result", buf.getvalue(), "translated.png", "image/png")
|
| 71 |
-
|
| 72 |
-
except Exception as error:
|
| 73 |
-
st.error(f"Error: {error}")
|
| 74 |
-
st.exception(error)
|
| 75 |
-
|
| 76 |
-
st.markdown("---")
|
| 77 |
-
st.markdown("**Model:** CycleGAN ResNet-
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import torch
|
| 3 |
+
import numpy as np
|
| 4 |
+
from PIL import Image
|
| 5 |
+
import torchvision.transforms as T
|
| 6 |
+
from model import load_generator
|
| 7 |
+
from io import BytesIO
|
| 8 |
+
|
| 9 |
+
st.set_page_config(page_title="Summer β Winter CycleGAN", page_icon="ποΈ", layout="centered")
|
| 10 |
+
st.title("ποΈ Summer β Winter Translation")
|
| 11 |
+
st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
|
| 12 |
+
|
| 13 |
+
@st.cache_resource
|
| 14 |
+
def get_generators():
|
| 15 |
+
device = "cpu"
|
| 16 |
+
gen_a2b = load_generator("gen_a2b_fp16.pth", device)
|
| 17 |
+
gen_b2a = load_generator("gen_b2a_fp16.pth", device)
|
| 18 |
+
return gen_a2b, gen_b2a, device
|
| 19 |
+
|
| 20 |
+
gen_a2b, gen_b2a, device = get_generators()
|
| 21 |
+
st.success(f"Model loaded on **{device}**", icon="β
")
|
| 22 |
+
|
| 23 |
+
MEAN = (0.5, 0.5, 0.5)
|
| 24 |
+
STD = (0.5, 0.5, 0.5)
|
| 25 |
+
|
| 26 |
+
to_tensor = T.Compose([
|
| 27 |
+
T.Resize((256, 256)),
|
| 28 |
+
T.ToTensor(),
|
| 29 |
+
T.Normalize(MEAN, STD),
|
| 30 |
+
])
|
| 31 |
+
|
| 32 |
+
def to_pil(tensor):
|
| 33 |
+
img = tensor.squeeze(0).cpu().float()
|
| 34 |
+
for i, (m, s) in enumerate(zip(MEAN, STD)):
|
| 35 |
+
img[i] = img[i] * s + m
|
| 36 |
+
img = torch.clamp(img, 0, 1)
|
| 37 |
+
return T.ToPILImage()(img)
|
| 38 |
+
|
| 39 |
+
direction = st.radio(
|
| 40 |
+
"Translation direction",
|
| 41 |
+
["βοΈ Summer β βοΈ Winter", "βοΈ Winter β βοΈ Summer"],
|
| 42 |
+
horizontal=True,
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])
|
| 46 |
+
|
| 47 |
+
if uploaded is not None:
|
| 48 |
+
try:
|
| 49 |
+
raw = uploaded.read()
|
| 50 |
+
input_img = Image.open(BytesIO(raw)).convert("RGB")
|
| 51 |
+
|
| 52 |
+
col1, col2 = st.columns(2)
|
| 53 |
+
with col1:
|
| 54 |
+
st.subheader("Input")
|
| 55 |
+
st.image(input_img, use_container_width=True)
|
| 56 |
+
|
| 57 |
+
with st.spinner("Translating..."):
|
| 58 |
+
tensor = to_tensor(input_img).unsqueeze(0).to(device)
|
| 59 |
+
generator = gen_a2b if "Summer" in direction.split("β")[0] else gen_b2a
|
| 60 |
+
with torch.no_grad():
|
| 61 |
+
output_tensor = generator(tensor)
|
| 62 |
+
output_img = to_pil(output_tensor)
|
| 63 |
+
|
| 64 |
+
with col2:
|
| 65 |
+
st.subheader("Output")
|
| 66 |
+
st.image(output_img, use_container_width=True)
|
| 67 |
+
|
| 68 |
+
buf = BytesIO()
|
| 69 |
+
output_img.save(buf, format="PNG")
|
| 70 |
+
st.download_button("β¬οΈ Download result", buf.getvalue(), "translated.png", "image/png")
|
| 71 |
+
|
| 72 |
+
except Exception as error:
|
| 73 |
+
st.error(f"Error: {error}")
|
| 74 |
+
st.exception(error)
|
| 75 |
+
|
| 76 |
+
st.markdown("---")
|
| 77 |
+
st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) Β· **Train / Test:** Yosemite / Alpine (Unsplash)")
|
gen_a2b_fp16.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:da5c03b451d035d13e78a7de49b0b41dee5d6c06c60b8d522f9d7652a376e64c
|
| 3 |
+
size 22771407
|
gen_b2a_fp16.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:68096562fc25dcecf71a6dd7f2f6299a4f263941eae27af31acd6bd3275d9b2b
|
| 3 |
+
size 22771407
|
model.py
CHANGED
|
@@ -18,7 +18,7 @@ class ResidualBlock(nn.Module):
|
|
| 18 |
return x + self.block(x)
|
| 19 |
|
| 20 |
class ResNetGenerator(nn.Module):
|
| 21 |
-
def __init__(self, in_channels=3, out_channels=3, n_filters=64, n_res_blocks=
|
| 22 |
super().__init__()
|
| 23 |
model = [
|
| 24 |
nn.ReflectionPad2d(3),
|
|
@@ -66,4 +66,4 @@ def load_generator(path, device="cpu"):
|
|
| 66 |
state_dict = {k: v.float() for k, v in state_dict.items()}
|
| 67 |
gen.load_state_dict(state_dict)
|
| 68 |
gen.to(device).eval()
|
| 69 |
-
return gen
|
|
|
|
| 18 |
return x + self.block(x)
|
| 19 |
|
| 20 |
class ResNetGenerator(nn.Module):
|
| 21 |
+
def __init__(self, in_channels=3, out_channels=3, n_filters=64, n_res_blocks=9):
|
| 22 |
super().__init__()
|
| 23 |
model = [
|
| 24 |
nn.ReflectionPad2d(3),
|
|
|
|
| 66 |
state_dict = {k: v.float() for k, v in state_dict.items()}
|
| 67 |
gen.load_state_dict(state_dict)
|
| 68 |
gen.to(device).eval()
|
| 69 |
+
return gen
|