Roman79 commited on
Commit
29c3d64
Β·
verified Β·
1 Parent(s): 8481d5d

Upload 5 files

Browse files

changed to 9 blocks

Files changed (4) hide show
  1. app.py +77 -77
  2. gen_a2b_fp16.pth +2 -2
  3. gen_b2a_fp16.pth +2 -2
  4. model.py +2 -2
app.py CHANGED
@@ -1,77 +1,77 @@
1
- import streamlit as st
2
- import torch
3
- import numpy as np
4
- from PIL import Image
5
- import torchvision.transforms as T
6
- from model import load_generator
7
- from io import BytesIO
8
-
9
- st.set_page_config(page_title="Summer ↔ Winter CycleGAN", page_icon="πŸ”οΈ", layout="centered")
10
- st.title("πŸ”οΈ Summer ↔ Winter Translation")
11
- st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
12
-
13
- @st.cache_resource
14
- def get_generators():
15
- device = "cpu"
16
- gen_a2b = load_generator("gen_a2b_fp16.pth", device)
17
- gen_b2a = load_generator("gen_b2a_fp16.pth", device)
18
- return gen_a2b, gen_b2a, device
19
-
20
- gen_a2b, gen_b2a, device = get_generators()
21
- st.success(f"Model loaded on **{device}**", icon="βœ…")
22
-
23
- MEAN = (0.5, 0.5, 0.5)
24
- STD = (0.5, 0.5, 0.5)
25
-
26
- to_tensor = T.Compose([
27
- T.Resize((256, 256)),
28
- T.ToTensor(),
29
- T.Normalize(MEAN, STD),
30
- ])
31
-
32
- def to_pil(tensor):
33
- img = tensor.squeeze(0).cpu().float()
34
- for i, (m, s) in enumerate(zip(MEAN, STD)):
35
- img[i] = img[i] * s + m
36
- img = torch.clamp(img, 0, 1)
37
- return T.ToPILImage()(img)
38
-
39
- direction = st.radio(
40
- "Translation direction",
41
- ["β˜€οΈ Summer β†’ ❄️ Winter", "❄️ Winter β†’ β˜€οΈ Summer"],
42
- horizontal=True,
43
- )
44
-
45
- uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])
46
-
47
- if uploaded is not None:
48
- try:
49
- raw = uploaded.read()
50
- input_img = Image.open(BytesIO(raw)).convert("RGB")
51
-
52
- col1, col2 = st.columns(2)
53
- with col1:
54
- st.subheader("Input")
55
- st.image(input_img, use_container_width=True)
56
-
57
- with st.spinner("Translating..."):
58
- tensor = to_tensor(input_img).unsqueeze(0).to(device)
59
- generator = gen_a2b if "Summer" in direction.split("β†’")[0] else gen_b2a
60
- with torch.no_grad():
61
- output_tensor = generator(tensor)
62
- output_img = to_pil(output_tensor)
63
-
64
- with col2:
65
- st.subheader("Output")
66
- st.image(output_img, use_container_width=True)
67
-
68
- buf = BytesIO()
69
- output_img.save(buf, format="PNG")
70
- st.download_button("⬇️ Download result", buf.getvalue(), "translated.png", "image/png")
71
-
72
- except Exception as error:
73
- st.error(f"Error: {error}")
74
- st.exception(error)
75
-
76
- st.markdown("---")
77
- st.markdown("**Model:** CycleGAN ResNet-6 Β· **Train:** Yosemite + Alpine (Unsplash)")
 
1
+ import streamlit as st
2
+ import torch
3
+ import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as T
6
+ from model import load_generator
7
+ from io import BytesIO
8
+
9
+ st.set_page_config(page_title="Summer ↔ Winter CycleGAN", page_icon="πŸ”οΈ", layout="centered")
10
+ st.title("πŸ”οΈ Summer ↔ Winter Translation")
11
+ st.markdown("Upload a landscape photo and convert it between **summer** and **winter**.")
12
+
13
+ @st.cache_resource
14
+ def get_generators():
15
+ device = "cpu"
16
+ gen_a2b = load_generator("gen_a2b_fp16.pth", device)
17
+ gen_b2a = load_generator("gen_b2a_fp16.pth", device)
18
+ return gen_a2b, gen_b2a, device
19
+
20
+ gen_a2b, gen_b2a, device = get_generators()
21
+ st.success(f"Model loaded on **{device}**", icon="βœ…")
22
+
23
+ MEAN = (0.5, 0.5, 0.5)
24
+ STD = (0.5, 0.5, 0.5)
25
+
26
+ to_tensor = T.Compose([
27
+ T.Resize((256, 256)),
28
+ T.ToTensor(),
29
+ T.Normalize(MEAN, STD),
30
+ ])
31
+
32
+ def to_pil(tensor):
33
+ img = tensor.squeeze(0).cpu().float()
34
+ for i, (m, s) in enumerate(zip(MEAN, STD)):
35
+ img[i] = img[i] * s + m
36
+ img = torch.clamp(img, 0, 1)
37
+ return T.ToPILImage()(img)
38
+
39
+ direction = st.radio(
40
+ "Translation direction",
41
+ ["β˜€οΈ Summer β†’ ❄️ Winter", "❄️ Winter β†’ β˜€οΈ Summer"],
42
+ horizontal=True,
43
+ )
44
+
45
+ uploaded = st.file_uploader("Upload landscape photo (JPG/PNG)", type=["jpg", "jpeg", "png"])
46
+
47
+ if uploaded is not None:
48
+ try:
49
+ raw = uploaded.read()
50
+ input_img = Image.open(BytesIO(raw)).convert("RGB")
51
+
52
+ col1, col2 = st.columns(2)
53
+ with col1:
54
+ st.subheader("Input")
55
+ st.image(input_img, use_container_width=True)
56
+
57
+ with st.spinner("Translating..."):
58
+ tensor = to_tensor(input_img).unsqueeze(0).to(device)
59
+ generator = gen_a2b if "Summer" in direction.split("β†’")[0] else gen_b2a
60
+ with torch.no_grad():
61
+ output_tensor = generator(tensor)
62
+ output_img = to_pil(output_tensor)
63
+
64
+ with col2:
65
+ st.subheader("Output")
66
+ st.image(output_img, use_container_width=True)
67
+
68
+ buf = BytesIO()
69
+ output_img.save(buf, format="PNG")
70
+ st.download_button("⬇️ Download result", buf.getvalue(), "translated.png", "image/png")
71
+
72
+ except Exception as error:
73
+ st.error(f"Error: {error}")
74
+ st.exception(error)
75
+
76
+ st.markdown("---")
77
+ st.markdown("**Model:** CycleGAN ResNet-9 blocks (64 channels) Β· **Train / Test:** Yosemite / Alpine (Unsplash)")
gen_a2b_fp16.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:34cd9e43b8d797b91fe295fb6cb81569d88eb1246a2838e23681395e85035fc0
3
- size 15686583
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:da5c03b451d035d13e78a7de49b0b41dee5d6c06c60b8d522f9d7652a376e64c
3
+ size 22771407
gen_b2a_fp16.pth CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:2ef6e8fa18dac09a8df1452e82616c64f77d476c472df794e082ed381abd6f00
3
- size 15686583
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:68096562fc25dcecf71a6dd7f2f6299a4f263941eae27af31acd6bd3275d9b2b
3
+ size 22771407
model.py CHANGED
@@ -18,7 +18,7 @@ class ResidualBlock(nn.Module):
18
  return x + self.block(x)
19
 
20
  class ResNetGenerator(nn.Module):
21
- def __init__(self, in_channels=3, out_channels=3, n_filters=64, n_res_blocks=6):
22
  super().__init__()
23
  model = [
24
  nn.ReflectionPad2d(3),
@@ -66,4 +66,4 @@ def load_generator(path, device="cpu"):
66
  state_dict = {k: v.float() for k, v in state_dict.items()}
67
  gen.load_state_dict(state_dict)
68
  gen.to(device).eval()
69
- return gen
 
18
  return x + self.block(x)
19
 
20
  class ResNetGenerator(nn.Module):
21
+ def __init__(self, in_channels=3, out_channels=3, n_filters=64, n_res_blocks=9):
22
  super().__init__()
23
  model = [
24
  nn.ReflectionPad2d(3),
 
66
  state_dict = {k: v.float() for k, v in state_dict.items()}
67
  gen.load_state_dict(state_dict)
68
  gen.to(device).eval()
69
+ return gen