haiderakt commited on
Commit
ff3e89b
·
1 Parent(s): d7858e9

Deploy Streamlit app to Hugging Face Space

Browse files
README.md CHANGED
@@ -1,20 +1,36 @@
 
 
 
 
1
  ---
2
- title: Colorization
3
- emoji: 🚀
4
- colorFrom: red
5
- colorTo: red
6
- sdk: docker
7
- app_port: 8501
8
- tags:
9
- - streamlit
10
- pinned: false
11
- short_description: A Streamlit app that colorizes black-and-white images
12
- license: apache-2.0
 
 
 
 
 
 
 
 
 
13
  ---
14
 
15
- # Welcome to Streamlit!
16
 
17
- Edit `/src/streamlit_app.py` to customize this app to your heart's desire. :heart:
 
 
 
18
 
19
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
20
- forums](https://discuss.streamlit.io).
 
1
+ # 🎨 Image Colorization & Post-Processing Tool
2
+
3
+ This project provides a **Streamlit-based web app** for automatic image **colorization** and **basic post-processing** (sharpening, blurring, undo, saving). It's built using a pretrained deep learning model from [Zhang et al. (2017)](https://richzhang.github.io/colorization/).
4
+
5
  ---
6
+
7
+ ## 🚀 Features
8
+
9
+ - ✅ Upload grayscale image
10
+ - 🎨 Auto-colorize using `siggraph17` model
11
+ - 🧪 Sharpen or blur the result
12
+ - ↩️ Undo changes (1-step)
13
+ - 💾 Save processed image
14
+
15
+ ---
16
+
17
+ ## 🛠️ Technologies Used
18
+
19
+ - Python
20
+ - Streamlit
21
+ - PyTorch
22
+ - OpenCV
23
+ - PIL (Pillow)
24
+ - `colorizers` (Zhang's pretrained models)
25
+
26
  ---
27
 
28
+ ## 📦 Installation
29
 
30
+ ```bash
31
+ # Clone the repository
32
+ git clone https://github.com/haiderakt/image-colorization.git
33
+ cd image-colorization
34
 
35
+ # Install dependencies
36
+ pip install -r requirements.txt
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import numpy as np
3
+ import torch
4
+ import streamlit as st
5
+ from PIL import Image
6
+ import colorizers
7
+
8
+ # Load pretrained colorization model
9
+ model = colorizers.siggraph17(pretrained=True).eval()
10
+
11
+ # Session state init
12
+ if 'processed_image' not in st.session_state:
13
+ st.session_state.processed_image = None
14
+ if 'original_image' not in st.session_state:
15
+ st.session_state.original_image = None
16
+ if 'history' not in st.session_state:
17
+ st.session_state.history = []
18
+
19
+ # Convert OpenCV image to PIL
20
+ def display_image_cv2(image):
21
+ rgb_img = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
22
+ return Image.fromarray(rgb_img)
23
+
24
+ # Colorization logic
25
+ def colouring_image(file, model):
26
+ img = cv2.imdecode(np.frombuffer(file.read(), np.uint8), cv2.IMREAD_GRAYSCALE)
27
+ original = cv2.cvtColor(cv2.resize(img, (256, 256)), cv2.COLOR_GRAY2BGR)
28
+
29
+ img = cv2.resize(img, (256, 256)) / 255.0 * 100
30
+ img_tensor = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).float()
31
+
32
+ with torch.no_grad():
33
+ ab = model(img_tensor).cpu().numpy()[0].transpose((1, 2, 0))
34
+
35
+ lab = np.concatenate((img[:, :, np.newaxis], ab), axis=2)
36
+ bgr = cv2.cvtColor(lab.astype(np.float32), cv2.COLOR_Lab2BGR)
37
+ bgr = np.clip(bgr * 255, 0, 255).astype(np.uint8)
38
+
39
+ return bgr, original
40
+
41
+ # UI Setup
42
+ st.set_page_config(page_title="Image Colorizer", layout="wide")
43
+ st.title("🎨 Image Colorization and Post-Processing Tool")
44
+
45
+ uploaded_file = st.file_uploader("Upload a grayscale image", type=["jpg", "jpeg", "png", "bmp"])
46
+
47
+ if uploaded_file:
48
+ colorized, original = colouring_image(uploaded_file, model)
49
+ st.session_state.processed_image = colorized.copy()
50
+ st.session_state.original_image = original
51
+ st.session_state.history = [colorized.copy()]
52
+
53
+ st.subheader("Preview:")
54
+ col1, col2 = st.columns(2)
55
+ with col1:
56
+ st.image(display_image_cv2(original), caption="Original Image", use_container_width=True)
57
+ with col2:
58
+ st.image(display_image_cv2(colorized), caption="Colorized Image", use_container_width=True)
59
+
60
+ st.markdown("---")
61
+
62
+ # Button row
63
+ colA, colB, colC, colD = st.columns(4)
64
+
65
+ with colA:
66
+ if st.button("🔪 Sharpen"):
67
+ kernel = np.array([[0, -1, 0], [-1, 5, -1], [0, -1, 0]])
68
+ sharpened = cv2.filter2D(st.session_state.processed_image, -1, kernel)
69
+ st.session_state.history.append(st.session_state.processed_image.copy())
70
+ st.session_state.processed_image = sharpened
71
+ st.image(display_image_cv2(sharpened), caption="Sharpened Image", use_container_width=True)
72
+
73
+ with colB:
74
+ if st.button("💧 Blur"):
75
+ blurred = cv2.GaussianBlur(st.session_state.processed_image, (15, 15), 0)
76
+ st.session_state.history.append(st.session_state.processed_image.copy())
77
+ st.session_state.processed_image = blurred
78
+ st.image(display_image_cv2(blurred), caption="Blurred Image", use_container_width=True)
79
+
80
+ with colC:
81
+ if st.button("↩️ Undo"):
82
+ if len(st.session_state.history) > 1:
83
+ st.session_state.history.pop()
84
+ st.session_state.processed_image = st.session_state.history[-1]
85
+ st.image(display_image_cv2(st.session_state.processed_image), caption="Undo Applied", use_container_width=True)
86
+ else:
87
+ st.warning("Nothing to undo.")
88
+
89
+ with colD:
90
+ if st.session_state.processed_image is not None:
91
+ buf = cv2.imencode(".jpg", st.session_state.processed_image)[1].tobytes()
92
+ st.download_button(label="💾 Save Image", data=buf, file_name="colorized.jpg", mime="image/jpeg")
colorizers/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+
2
+ from .base_color import *
3
+ from .eccv16 import *
4
+ from .siggraph17 import *
5
+ from .util import *
6
+
colorizers/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (287 Bytes). View file
 
colorizers/__pycache__/__init__.cpython-37.pyc ADDED
Binary file (285 Bytes). View file
 
colorizers/__pycache__/base_color.cpython-312.pyc ADDED
Binary file (1.56 kB). View file
 
colorizers/__pycache__/base_color.cpython-37.pyc ADDED
Binary file (1.24 kB). View file
 
colorizers/__pycache__/eccv16.cpython-312.pyc ADDED
Binary file (7.32 kB). View file
 
colorizers/__pycache__/eccv16.cpython-37.pyc ADDED
Binary file (3.26 kB). View file
 
colorizers/__pycache__/siggraph17.cpython-312.pyc ADDED
Binary file (10.5 kB). View file
 
colorizers/__pycache__/siggraph17.cpython-37.pyc ADDED
Binary file (4.36 kB). View file
 
colorizers/__pycache__/util.cpython-312.pyc ADDED
Binary file (2.73 kB). View file
 
colorizers/__pycache__/util.cpython-37.pyc ADDED
Binary file (1.71 kB). View file
 
colorizers/base_color.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from torch import nn
4
+
5
+ class BaseColor(nn.Module):
6
+ def __init__(self):
7
+ super(BaseColor, self).__init__()
8
+
9
+ self.l_cent = 50.
10
+ self.l_norm = 100.
11
+ self.ab_norm = 110.
12
+
13
+ def normalize_l(self, in_l):
14
+ return (in_l-self.l_cent)/self.l_norm
15
+
16
+ def unnormalize_l(self, in_l):
17
+ return in_l*self.l_norm + self.l_cent
18
+
19
+ def normalize_ab(self, in_ab):
20
+ return in_ab/self.ab_norm
21
+
22
+ def unnormalize_ab(self, in_ab):
23
+ return in_ab*self.ab_norm
24
+
colorizers/eccv16.py ADDED
@@ -0,0 +1,105 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import numpy as np
5
+ from IPython import embed
6
+
7
+ from .base_color import *
8
+
9
+ class ECCVGenerator(BaseColor):
10
+ def __init__(self, norm_layer=nn.BatchNorm2d):
11
+ super(ECCVGenerator, self).__init__()
12
+
13
+ model1=[nn.Conv2d(1, 64, kernel_size=3, stride=1, padding=1, bias=True),]
14
+ model1+=[nn.ReLU(True),]
15
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1, bias=True),]
16
+ model1+=[nn.ReLU(True),]
17
+ model1+=[norm_layer(64),]
18
+
19
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
20
+ model2+=[nn.ReLU(True),]
21
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1, bias=True),]
22
+ model2+=[nn.ReLU(True),]
23
+ model2+=[norm_layer(128),]
24
+
25
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
26
+ model3+=[nn.ReLU(True),]
27
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
28
+ model3+=[nn.ReLU(True),]
29
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1, bias=True),]
30
+ model3+=[nn.ReLU(True),]
31
+ model3+=[norm_layer(256),]
32
+
33
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
34
+ model4+=[nn.ReLU(True),]
35
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
36
+ model4+=[nn.ReLU(True),]
37
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
38
+ model4+=[nn.ReLU(True),]
39
+ model4+=[norm_layer(512),]
40
+
41
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
42
+ model5+=[nn.ReLU(True),]
43
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
44
+ model5+=[nn.ReLU(True),]
45
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
46
+ model5+=[nn.ReLU(True),]
47
+ model5+=[norm_layer(512),]
48
+
49
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
50
+ model6+=[nn.ReLU(True),]
51
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
52
+ model6+=[nn.ReLU(True),]
53
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
54
+ model6+=[nn.ReLU(True),]
55
+ model6+=[norm_layer(512),]
56
+
57
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
58
+ model7+=[nn.ReLU(True),]
59
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
60
+ model7+=[nn.ReLU(True),]
61
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
62
+ model7+=[nn.ReLU(True),]
63
+ model7+=[norm_layer(512),]
64
+
65
+ model8=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True),]
66
+ model8+=[nn.ReLU(True),]
67
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
68
+ model8+=[nn.ReLU(True),]
69
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
70
+ model8+=[nn.ReLU(True),]
71
+
72
+ model8+=[nn.Conv2d(256, 313, kernel_size=1, stride=1, padding=0, bias=True),]
73
+
74
+ self.model1 = nn.Sequential(*model1)
75
+ self.model2 = nn.Sequential(*model2)
76
+ self.model3 = nn.Sequential(*model3)
77
+ self.model4 = nn.Sequential(*model4)
78
+ self.model5 = nn.Sequential(*model5)
79
+ self.model6 = nn.Sequential(*model6)
80
+ self.model7 = nn.Sequential(*model7)
81
+ self.model8 = nn.Sequential(*model8)
82
+
83
+ self.softmax = nn.Softmax(dim=1)
84
+ self.model_out = nn.Conv2d(313, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=False)
85
+ self.upsample4 = nn.Upsample(scale_factor=4, mode='bilinear')
86
+
87
+ def forward(self, input_l):
88
+ conv1_2 = self.model1(self.normalize_l(input_l))
89
+ conv2_2 = self.model2(conv1_2)
90
+ conv3_3 = self.model3(conv2_2)
91
+ conv4_3 = self.model4(conv3_3)
92
+ conv5_3 = self.model5(conv4_3)
93
+ conv6_3 = self.model6(conv5_3)
94
+ conv7_3 = self.model7(conv6_3)
95
+ conv8_3 = self.model8(conv7_3)
96
+ out_reg = self.model_out(self.softmax(conv8_3))
97
+
98
+ return self.unnormalize_ab(self.upsample4(out_reg))
99
+
100
+ def eccv16(pretrained=True):
101
+ model = ECCVGenerator()
102
+ if(pretrained):
103
+ import torch.utils.model_zoo as model_zoo
104
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/colorization_release_v2-9b330a0b.pth',map_location='cpu',check_hash=True))
105
+ return model
colorizers/siggraph17.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .base_color import *
5
+
6
+ class SIGGRAPHGenerator(BaseColor):
7
+ def __init__(self, norm_layer=nn.BatchNorm2d, classes=529):
8
+ super(SIGGRAPHGenerator, self).__init__()
9
+
10
+ # Conv1
11
+ model1=[nn.Conv2d(4, 64, kernel_size=3, stride=1, padding=1, bias=True),]
12
+ model1+=[nn.ReLU(True),]
13
+ model1+=[nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1, bias=True),]
14
+ model1+=[nn.ReLU(True),]
15
+ model1+=[norm_layer(64),]
16
+ # add a subsampling operation
17
+
18
+ # Conv2
19
+ model2=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
20
+ model2+=[nn.ReLU(True),]
21
+ model2+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
22
+ model2+=[nn.ReLU(True),]
23
+ model2+=[norm_layer(128),]
24
+ # add a subsampling layer operation
25
+
26
+ # Conv3
27
+ model3=[nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1, bias=True),]
28
+ model3+=[nn.ReLU(True),]
29
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
30
+ model3+=[nn.ReLU(True),]
31
+ model3+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
32
+ model3+=[nn.ReLU(True),]
33
+ model3+=[norm_layer(256),]
34
+ # add a subsampling layer operation
35
+
36
+ # Conv4
37
+ model4=[nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1, bias=True),]
38
+ model4+=[nn.ReLU(True),]
39
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
40
+ model4+=[nn.ReLU(True),]
41
+ model4+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
42
+ model4+=[nn.ReLU(True),]
43
+ model4+=[norm_layer(512),]
44
+
45
+ # Conv5
46
+ model5=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
47
+ model5+=[nn.ReLU(True),]
48
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
49
+ model5+=[nn.ReLU(True),]
50
+ model5+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
51
+ model5+=[nn.ReLU(True),]
52
+ model5+=[norm_layer(512),]
53
+
54
+ # Conv6
55
+ model6=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
56
+ model6+=[nn.ReLU(True),]
57
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
58
+ model6+=[nn.ReLU(True),]
59
+ model6+=[nn.Conv2d(512, 512, kernel_size=3, dilation=2, stride=1, padding=2, bias=True),]
60
+ model6+=[nn.ReLU(True),]
61
+ model6+=[norm_layer(512),]
62
+
63
+ # Conv7
64
+ model7=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
65
+ model7+=[nn.ReLU(True),]
66
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
67
+ model7+=[nn.ReLU(True),]
68
+ model7+=[nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1, bias=True),]
69
+ model7+=[nn.ReLU(True),]
70
+ model7+=[norm_layer(512),]
71
+
72
+ # Conv7
73
+ model8up=[nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=True)]
74
+ model3short8=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
75
+
76
+ model8=[nn.ReLU(True),]
77
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
78
+ model8+=[nn.ReLU(True),]
79
+ model8+=[nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1, bias=True),]
80
+ model8+=[nn.ReLU(True),]
81
+ model8+=[norm_layer(256),]
82
+
83
+ # Conv9
84
+ model9up=[nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=True),]
85
+ model2short9=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
86
+ # add the two feature maps above
87
+
88
+ model9=[nn.ReLU(True),]
89
+ model9+=[nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1, bias=True),]
90
+ model9+=[nn.ReLU(True),]
91
+ model9+=[norm_layer(128),]
92
+
93
+ # Conv10
94
+ model10up=[nn.ConvTranspose2d(128, 128, kernel_size=4, stride=2, padding=1, bias=True),]
95
+ model1short10=[nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1, bias=True),]
96
+ # add the two feature maps above
97
+
98
+ model10=[nn.ReLU(True),]
99
+ model10+=[nn.Conv2d(128, 128, kernel_size=3, dilation=1, stride=1, padding=1, bias=True),]
100
+ model10+=[nn.LeakyReLU(negative_slope=.2),]
101
+
102
+ # classification output
103
+ model_class=[nn.Conv2d(256, classes, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
104
+
105
+ # regression output
106
+ model_out=[nn.Conv2d(128, 2, kernel_size=1, padding=0, dilation=1, stride=1, bias=True),]
107
+ model_out+=[nn.Tanh()]
108
+
109
+ self.model1 = nn.Sequential(*model1)
110
+ self.model2 = nn.Sequential(*model2)
111
+ self.model3 = nn.Sequential(*model3)
112
+ self.model4 = nn.Sequential(*model4)
113
+ self.model5 = nn.Sequential(*model5)
114
+ self.model6 = nn.Sequential(*model6)
115
+ self.model7 = nn.Sequential(*model7)
116
+ self.model8up = nn.Sequential(*model8up)
117
+ self.model8 = nn.Sequential(*model8)
118
+ self.model9up = nn.Sequential(*model9up)
119
+ self.model9 = nn.Sequential(*model9)
120
+ self.model10up = nn.Sequential(*model10up)
121
+ self.model10 = nn.Sequential(*model10)
122
+ self.model3short8 = nn.Sequential(*model3short8)
123
+ self.model2short9 = nn.Sequential(*model2short9)
124
+ self.model1short10 = nn.Sequential(*model1short10)
125
+
126
+ self.model_class = nn.Sequential(*model_class)
127
+ self.model_out = nn.Sequential(*model_out)
128
+
129
+ self.upsample4 = nn.Sequential(*[nn.Upsample(scale_factor=4, mode='bilinear'),])
130
+ self.softmax = nn.Sequential(*[nn.Softmax(dim=1),])
131
+
132
+ def forward(self, input_A, input_B=None, mask_B=None):
133
+ if(input_B is None):
134
+ input_B = torch.cat((input_A*0, input_A*0), dim=1)
135
+ if(mask_B is None):
136
+ mask_B = input_A*0
137
+
138
+ conv1_2 = self.model1(torch.cat((self.normalize_l(input_A),self.normalize_ab(input_B),mask_B),dim=1))
139
+ conv2_2 = self.model2(conv1_2[:,:,::2,::2])
140
+ conv3_3 = self.model3(conv2_2[:,:,::2,::2])
141
+ conv4_3 = self.model4(conv3_3[:,:,::2,::2])
142
+ conv5_3 = self.model5(conv4_3)
143
+ conv6_3 = self.model6(conv5_3)
144
+ conv7_3 = self.model7(conv6_3)
145
+
146
+ conv8_up = self.model8up(conv7_3) + self.model3short8(conv3_3)
147
+ conv8_3 = self.model8(conv8_up)
148
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
149
+ conv9_3 = self.model9(conv9_up)
150
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
151
+ conv10_2 = self.model10(conv10_up)
152
+ out_reg = self.model_out(conv10_2)
153
+
154
+ conv9_up = self.model9up(conv8_3) + self.model2short9(conv2_2)
155
+ conv9_3 = self.model9(conv9_up)
156
+ conv10_up = self.model10up(conv9_3) + self.model1short10(conv1_2)
157
+ conv10_2 = self.model10(conv10_up)
158
+ out_reg = self.model_out(conv10_2)
159
+
160
+ return self.unnormalize_ab(out_reg)
161
+
162
+ def siggraph17(pretrained=True):
163
+ model = SIGGRAPHGenerator()
164
+ if(pretrained):
165
+ import torch.utils.model_zoo as model_zoo
166
+ model.load_state_dict(model_zoo.load_url('https://colorizers.s3.us-east-2.amazonaws.com/siggraph17-df00044c.pth',map_location='cpu',check_hash=True))
167
+ return model
168
+
colorizers/util.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from PIL import Image
3
+ import numpy as np
4
+ from skimage import color
5
+ import torch
6
+ import torch.nn.functional as F
7
+ from IPython import embed
8
+
9
+ def load_img(img_path):
10
+ out_np = np.asarray(Image.open(img_path))
11
+ if(out_np.ndim==2):
12
+ out_np = np.tile(out_np[:,:,None],3)
13
+ return out_np
14
+
15
+ def resize_img(img, HW=(256,256), resample=3):
16
+ return np.asarray(Image.fromarray(img).resize((HW[1],HW[0]), resample=resample))
17
+
18
+ def preprocess_img(img_rgb_orig, HW=(256,256), resample=3):
19
+ # return original size L and resized L as torch Tensors
20
+ img_rgb_rs = resize_img(img_rgb_orig, HW=HW, resample=resample)
21
+
22
+ img_lab_orig = color.rgb2lab(img_rgb_orig)
23
+ img_lab_rs = color.rgb2lab(img_rgb_rs)
24
+
25
+ img_l_orig = img_lab_orig[:,:,0]
26
+ img_l_rs = img_lab_rs[:,:,0]
27
+
28
+ tens_orig_l = torch.Tensor(img_l_orig)[None,None,:,:]
29
+ tens_rs_l = torch.Tensor(img_l_rs)[None,None,:,:]
30
+
31
+ return (tens_orig_l, tens_rs_l)
32
+
33
+ def postprocess_tens(tens_orig_l, out_ab, mode='bilinear'):
34
+ # tens_orig_l 1 x 1 x H_orig x W_orig
35
+ # out_ab 1 x 2 x H x W
36
+
37
+ HW_orig = tens_orig_l.shape[2:]
38
+ HW = out_ab.shape[2:]
39
+
40
+ # call resize function if needed
41
+ if(HW_orig[0]!=HW[0] or HW_orig[1]!=HW[1]):
42
+ out_ab_orig = F.interpolate(out_ab, size=HW_orig, mode='bilinear')
43
+ else:
44
+ out_ab_orig = out_ab
45
+
46
+ out_lab_orig = torch.cat((tens_orig_l, out_ab_orig), dim=1)
47
+ return color.lab2rgb(out_lab_orig.data.cpu().numpy()[0,...].transpose((1,2,0)))
requirements.txt CHANGED
@@ -1,3 +1,7 @@
1
- altair
2
- pandas
 
 
 
 
3
  streamlit
 
1
+ torch
2
+ scikit-image
3
+ numpy
4
+ matplotlib
5
+ argparse
6
+ Pillow
7
  streamlit