Spaces:
Sleeping
Sleeping
Initial commit
Browse files- .DS_Store +0 -0
- app.py +223 -0
- checkpoints/checkpoint_epoch_015_20250808_154437.pt +3 -0
- examples/Places365_test_00000287.jpg +0 -0
- examples/Places365_test_00000314.jpg +0 -0
- model.py +236 -0
- requirements.txt +11 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
app.py
ADDED
|
@@ -0,0 +1,223 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# app.py — Gradio-native metrics, clean UI, CUDA/CPU only
|
| 2 |
+
|
| 3 |
+
import os, math, cv2, base64
|
| 4 |
+
import torch, numpy as np, gradio as gr
|
| 5 |
+
from PIL import Image
|
| 6 |
+
|
| 7 |
+
# Optional (fine if missing)
|
| 8 |
+
try:
|
| 9 |
+
import kornia.color as kc
|
| 10 |
+
except Exception:
|
| 11 |
+
kc = None
|
| 12 |
+
|
| 13 |
+
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
|
| 14 |
+
from skimage.metrics import structural_similarity as ssim_metric
|
| 15 |
+
|
| 16 |
+
# ---------------- Device & Model (no MPS) ----------------
|
| 17 |
+
# Prefer CUDA when a GPU is visible, otherwise run on CPU (MPS is never selected).
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

from model import ViTUNetColorizer

CKPT = "checkpoints/checkpoint_epoch_015_20250808_154437.pt"


def _load_colorizer(ckpt_path):
    """Build the colorizer and load its weights, or return None if the file is absent.

    The checkpoint may either be a bare state dict or a dict that stores the
    generator weights under the ``"generator_state_dict"`` key.
    """
    if not os.path.exists(ckpt_path):
        return None
    net = ViTUNetColorizer(vit_model_name="vit_tiny_patch16_224").to(device)
    checkpoint = torch.load(ckpt_path, map_location=device)
    weights = checkpoint.get("generator_state_dict", checkpoint)
    net.load_state_dict(weights)
    net.eval()
    return net


# ``model`` stays None when no checkpoint is available; callers check for that.
model = _load_colorizer(CKPT)
|
| 29 |
+
|
| 30 |
+
# ---------------- Utils ----------------
|
| 31 |
+
def is_grayscale(img: "Image.Image") -> bool:
    """Report whether *img* carries no colour information.

    The image is converted with ``np.array``. A 2-D array or a single-channel
    3-D array counts as grayscale; a 3-channel array counts as grayscale only
    when all three channels are (numerically) equal. Any other layout is
    treated as colour.
    """
    pixels = np.array(img)
    if pixels.ndim == 2:
        return True
    if pixels.ndim != 3:
        return False

    channels = pixels.shape[2]
    if channels == 1:
        return True
    if channels == 3:
        red, green, blue = pixels[..., 0], pixels[..., 1], pixels[..., 2]
        return np.allclose(red, green) and np.allclose(green, blue)
    return False
|
| 38 |
+
|
| 39 |
+
def to_L(rgb_np: np.ndarray):
    """Convert an RGB image into the normalised lightness tensor fed to the model.

    Args:
        rgb_np: H x W x 3 RGB array with values on the 0..255 scale.

    Returns:
        Float tensor of shape (1, 1, H, W) on ``device`` holding CIELAB L / 100,
        i.e. lightness in [0, 1] as ViTUNetColorizer expects.
    """
    rgb01 = rgb_np.astype(np.float32) / 255.0

    if kc is None:
        # OpenCV fallback. The previous code divided a 0..255 gray image by 100,
        # producing values up to 2.55 and disagreeing with the kornia branch.
        # NOTE(review): for float32 input in [0, 1] OpenCV documents L in 0..100,
        # which makes this branch match kornia's rgb_to_lab scale.
        lab = cv2.cvtColor(rgb01, cv2.COLOR_RGB2LAB)
        lightness = lab[..., 0] / 100.0
        return torch.from_numpy(lightness).unsqueeze(0).unsqueeze(0).float().to(device)

    t = torch.from_numpy(rgb01).permute(2, 0, 1).unsqueeze(0).to(device)
    with torch.no_grad():
        return kc.rgb_to_lab(t)[:, 0:1] / 100.0
|
| 48 |
+
|
| 49 |
+
def lab_to_rgb(L, ab):
    """Turn a normalised L tensor plus predicted ab tensor into an RGB uint8 image.

    ``L`` is rescaled by 100 and ``ab`` is clamped to [-1, 1] and rescaled by
    110 before the Lab -> RGB conversion. Only the first batch element is
    returned, as an H x W x 3 uint8 array. kornia is used when available,
    OpenCV otherwise.
    """
    lab = torch.cat([L * 100.0, torch.clamp(ab, -1, 1) * 110.0], dim=1)

    if kc is not None:
        with torch.no_grad():
            rgb = kc.lab_to_rgb(lab)
        rgb01 = torch.clamp(rgb, 0, 1)[0].permute(1, 2, 0).cpu().numpy()
        return (rgb01 * 255).astype(np.uint8)

    # OpenCV path: work on a float32 H x W x 3 array clipped to the Lab ranges.
    lab_np = lab[0].permute(1, 2, 0).cpu().numpy()
    lab_np = np.clip(lab_np, [0, -128, -128], [100, 127, 127]).astype(np.float32)
    rgb_np = cv2.cvtColor(lab_np, cv2.COLOR_LAB2RGB)
    return (np.clip(rgb_np, 0, 1) * 255).astype(np.uint8)
|
| 59 |
+
|
| 60 |
+
def pad_to_multiple(img_np, m=16):
    """Zero-pad an image on the bottom/right so both sides are multiples of *m*.

    Returns:
        ``(padded, (h, w))`` where ``(h, w)`` is the size of the input, so the
        caller can crop the padding away again.
    """
    height, width = img_np.shape[:2]
    target_h = math.ceil(height / m) * m
    target_w = math.ceil(width / m) * m

    # Only the two leading (spatial) axes grow; a trailing channel axis is untouched.
    pad_spec = [(0, target_h - height), (0, target_w - width)]
    pad_spec += [(0, 0)] * (img_np.ndim - 2)
    padded = np.pad(img_np, pad_spec, mode="constant", constant_values=0)
    return padded, (height, width)
|
| 64 |
+
|
| 65 |
+
def compute_metrics(pred, gt):
    """Compare a predicted image against a reference image.

    Both arrays are scaled from the 0..255 range to [0, 1] before comparison.
    # assumes pred and gt are H x W x C arrays of identical shape — TODO confirm

    Returns:
        ``(mae, psnr, ssim)`` rounded to 4, 2 and 4 decimals. ``ssim`` is NaN
        when the image is too small for any valid SSIM window.
    """
    p = pred.astype(np.float32) / 255.0
    g = gt.astype(np.float32) / 255.0
    mae = float(np.mean(np.abs(p - g)))
    psnr = float(psnr_metric(g, p, data_range=1.0))

    # SSIM needs an odd window that fits inside the image. A fixed size of 7
    # fails on images with a side shorter than 7 px, so shrink it to fit.
    min_side = min(g.shape[:2])
    largest_odd = min_side if min_side % 2 else min_side - 1
    win = min(7, largest_odd)

    if win < 3:
        ssim = float("nan")
    else:
        try:
            ssim = float(ssim_metric(g, p, channel_axis=2, data_range=1.0, win_size=win))
        except TypeError:
            # Older scikit-image releases spell the channel argument differently.
            ssim = float(ssim_metric(g, p, multichannel=True, data_range=1.0, win_size=win))
    return round(mae, 4), round(psnr, 2), round(ssim, 4)
|
| 74 |
+
|
| 75 |
+
# ---------------- Inference ----------------
|
| 76 |
+
def _encode_data_uri(rgb_img, ext, mime):
    """Encode an RGB uint8 array with OpenCV and wrap it as a base64 data URI."""
    _, buf = cv2.imencode(ext, cv2.cvtColor(rgb_img, cv2.COLOR_RGB2BGR))
    return f"data:{mime};base64," + base64.b64encode(buf).decode()


def infer(image: "Image.Image", want_metrics: bool, sizing_mode: str, show_L: bool):
    """Colorize *image* and build everything the UI displays.

    Args:
        image: Uploaded PIL image, or None when nothing was provided.
        want_metrics: Compute MAE/PSNR/SSIM of the result against the input.
        sizing_mode: ``"Pad to keep size"`` pads to a multiple of 16 and crops
            back afterwards; any other value resizes to 256x256 and back.
        show_L: Append a preview of the L channel to the extra HTML.

    Returns:
        A 7-tuple ``(original, colorized, mae, psnr, ssim, compare_html,
        extra_html)``. The images and metrics are None when there is no input
        or no model; the metrics are also None when not requested.
    """
    if image is None:
        return None, None, None, None, None, "", ""
    if model is None:
        return None, None, None, None, None, "", "<div>Checkpoint not found in /checkpoints.</div>"

    pil = image.convert("RGB")
    rgb = np.array(pil)
    w, h = pil.size
    was_color = not is_grayscale(pil)
    pad_mode = sizing_mode == "Pad to keep size"

    if pad_mode:
        proc, (oh, ow) = pad_to_multiple(rgb, 16)
    else:
        proc = cv2.resize(rgb, (256, 256), interpolation=cv2.INTER_CUBIC)

    L = to_L(proc)
    with torch.no_grad():
        ab = model(L)
    out = lab_to_rgb(L, ab)

    # Undo the sizing step so the result matches the input resolution.
    if pad_mode:
        out = np.ascontiguousarray(out[:oh, :ow])
    else:
        out = cv2.resize(out, (w, h), interpolation=cv2.INTER_CUBIC)

    # Metrics (Gradio-native numbers)
    mae = psnr = ssim = None
    if want_metrics:
        mae, psnr, ssim = compute_metrics(out, rgb)

    # Optional L preview
    extra_html = ""
    if show_L:
        L01 = np.clip(L[0, 0].detach().cpu().numpy(), 0, 1)
        L_vis = cv2.cvtColor((L01 * 255).astype(np.uint8), cv2.COLOR_GRAY2RGB)
        L_b64 = _encode_data_uri(L_vis, ".png", "image/png")
        extra_html += f"<div><b>L-channel</b><br/><img style='max-height:140px;border-radius:12px' src='{L_b64}'/></div>"

    # Subtle notice only if needed
    if was_color:
        extra_html += "<div style='opacity:.8;margin-top:8px'>We used a grayscale version of your image for colorization.</div>"

    # Compare slider (HTML only; easy to remove if you want 100% Gradio).
    # The colorized image is stacked on the original at full size and revealed
    # with clip-path. The earlier markup put a width:100% image inside a
    # 50%-wide wrapper, which shrank the overlay instead of cropping it.
    so = _encode_data_uri(rgb, ".jpg", "image/jpeg")
    sc = _encode_data_uri(out, ".jpg", "image/jpeg")
    compare = f"""
    <div style="position:relative;max-width:500px;margin:auto;border-radius:14px;overflow:hidden;box-shadow:0 8px 20px rgba(0,0,0,.2)">
      <img src="{so}" style="width:100%;display:block"/>
      <img id="cmpTop" src="{sc}" style="position:absolute;top:0;left:0;width:100%;height:100%;display:block;clip-path:inset(0 50% 0 0)"/>
      <input id="cmpRange" type="range" min="0" max="100" value="50"
        oninput="document.getElementById('cmpTop').style.clipPath='inset(0 '+(100-this.value)+'% 0 0)';"
        style="position:absolute;left:0;right:0;bottom:8px;width:60%;margin:auto"/>
    </div>
    """

    return Image.fromarray(rgb), Image.fromarray(out), mae, psnr, ssim, compare, extra_html
|
| 140 |
+
|
| 141 |
+
# ---------------- Theme (fallback-safe) ----------------
|
| 142 |
+
def make_theme():
    """Build the app theme, falling back to the stock Soft theme on any failure.

    Returns:
        A ``gr.themes.Soft`` instance: indigo/gray hues, the Inter font and
        large radius / medium spacing presets when those helpers are available,
        otherwise the default Soft theme.
    """
    try:
        from gradio.themes.utils import colors, fonts, sizes

        # NOTE(review): radius_size / spacing_size are theme constructor
        # arguments. They used to be passed to .set(), which takes CSS-variable
        # overrides; the resulting error was swallowed below, so the customised
        # theme was never applied. Confirm against the Gradio theming docs.
        return gr.themes.Soft(
            primary_hue=colors.indigo,
            neutral_hue=colors.gray,
            font=fonts.GoogleFont("Inter"),
            radius_size=sizes.radius_lg,
            spacing_size=sizes.spacing_md,
        )
    except Exception:
        # Deliberate best-effort: theming must never prevent the app from starting.
        return gr.themes.Soft()


THEME = make_theme()
|
| 154 |
+
|
| 155 |
+
# ---------------- UI ----------------
|
| 156 |
+
_EXAMPLE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".webp", ".tif", ".tiff")


def _list_example_images(folder="examples"):
    """Return the sorted paths of image files located directly inside *folder*.

    Non-image entries (for instance ``.DS_Store``) are skipped so they are not
    offered as examples; a missing folder yields an empty list.
    """
    if not os.path.isdir(folder):
        return []
    return [
        os.path.join(folder, name)
        for name in sorted(os.listdir(folder))
        if name.lower().endswith(_EXAMPLE_EXTENSIONS)
    ]


# NOTE(review): the row/column nesting below is reconstructed from the statement
# order; confirm it matches the intended layout.
with gr.Blocks(theme=THEME, title="Neural Colorizer") as demo:
    gr.Markdown("# 🎨 Neural Colorizer")

    with gr.Row():
        # Left column: input image, options, actions and examples.
        with gr.Column(scale=5):
            img_in = gr.Image(
                label="Upload grayscale or color image",
                type="pil",
                image_mode="RGB",
                height=320,
                sources=["upload", "clipboard"],
            )
            with gr.Row():
                sizing = gr.Radio(
                    ["Resize to 256", "Pad to keep size"],
                    value="Resize to 256",
                    label="Sizing",
                )
                show_L = gr.Checkbox(label="Show L-channel", value=False)
                show_m = gr.Checkbox(label="Show metrics", value=True)
            with gr.Row():
                run = gr.Button("Colorize")
                clr = gr.Button("Clear")

            examples = gr.Examples(
                examples=_list_example_images("examples"),
                inputs=img_in,
                examples_per_page=8,
                label=None,
            )

        # Right column: images, metrics and the HTML extras.
        with gr.Column(scale=7):
            with gr.Row():
                orig = gr.Image(label="Original", interactive=False, height=300, show_download_button=True)
                out = gr.Image(label="Result", interactive=False, height=300, show_download_button=True)

            # Pure Gradio metric fields
            with gr.Row():
                mae_box = gr.Number(label="MAE", interactive=False, precision=4)
                psnr_box = gr.Number(label="PSNR (dB)", interactive=False, precision=2)
                ssim_box = gr.Number(label="SSIM", interactive=False, precision=4)

            gr.Markdown("**Compare**")
            compare = gr.HTML()
            extras = gr.HTML()

    def _go(image, want_metrics, sizing_mode, show_L):
        """Run inference and blank the metric fields when metrics are disabled."""
        o, c, mae, psnr, ssim, cmp_html, extra = infer(image, want_metrics, sizing_mode, show_L)
        if not want_metrics:
            mae = psnr = ssim = None
        return o, c, mae, psnr, ssim, cmp_html, extra

    run.click(
        _go,
        inputs=[img_in, show_m, sizing, show_L],
        outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras],
    )

    def _clear():
        """Reset every output component (the uploaded input is left untouched)."""
        return None, None, None, None, None, "", ""

    clr.click(_clear, inputs=None, outputs=[orig, out, mae_box, psnr_box, ssim_box, compare, extras])
|
| 217 |
+
|
| 218 |
+
if __name__ == "__main__":
    import inspect

    # No queue, no API panel. Only pass show_api when this Gradio version
    # accepts it. Wrapping the blocking launch() call in try/except TypeError
    # would also catch unrelated TypeErrors and start the app a second time.
    launch_kwargs = {}
    if "show_api" in inspect.signature(demo.launch).parameters:
        launch_kwargs["show_api"] = False
    demo.launch(**launch_kwargs)
|
checkpoints/checkpoint_epoch_015_20250808_154437.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:160e9bdb21f474f0da5dd059866966e0fe74b3ae4008307f5e9b1e245b3019c1
|
| 3 |
+
size 84569969
|
examples/Places365_test_00000287.jpg
ADDED
|
examples/Places365_test_00000314.jpg
ADDED
|
model.py
ADDED
|
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import timm
|
| 5 |
+
import json
|
| 6 |
+
from torch.nn.utils import spectral_norm
|
| 7 |
+
from torchinfo import summary
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class EncoderBlock(nn.Module):
    """Two 3x3 conv + GroupNorm + LeakyReLU stages followed by 2x2 max pooling."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in -> out channels, second keeps out -> out.
        for source_channels in (in_channels, out_channels):
            stages += [
                nn.Conv2d(source_channels, out_channels, kernel_size=3, padding=1),
                nn.GroupNorm(8, out_channels),
                nn.LeakyReLU(0.01, inplace=True),
            ]
        # Attribute names and layer order are kept so state_dict keys are unchanged.
        self.conv_block = nn.Sequential(*stages)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

    def forward(self, x):
        """Return ``(pooled, features)``: the downsampled map and the pre-pool skip map."""
        features = self.conv_block(x)
        return self.pool(features), features
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DecoderBlock(nn.Module):
    """Upsample by 2, gate the skip connection with attention, then fuse with two convs."""

    def __init__(self, in_channels, skip_channels, out_channels):
        super().__init__()
        up_channels = in_channels // 2
        self.up = nn.ConvTranspose2d(in_channels, up_channels, kernel_size=2, stride=2)
        self.ag = AttentionGate(F_g=up_channels, F_l=skip_channels, F_int=in_channels // 4)

        # The fusing convs see the upsampled map concatenated with the gated skip.
        merged_channels = up_channels + skip_channels
        self.conv_block = nn.Sequential(
            nn.Conv2d(merged_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.GroupNorm(8, out_channels),
            nn.LeakyReLU(0.01, inplace=True),
        )

    def forward(self, x, skip):
        """Decode *x* one level up using the matching encoder feature map *skip*."""
        upsampled = self.up(x)
        gated_skip = self.ag(upsampled, skip)
        merged = torch.cat([upsampled, gated_skip], dim=1)
        return self.conv_block(merged)
|
| 52 |
+
|
| 53 |
+
class AttentionGate(nn.Module):
    """Scale a skip feature map by a single-channel sigmoid attention map.

    ``F_g`` and ``F_l`` are the channel counts of the gating signal and the
    skip features; both are projected to ``F_int`` channels before being summed.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def project(channels):
            # 1x1 conv into the shared intermediate width, then GroupNorm.
            return nn.Sequential(
                nn.Conv2d(channels, F_int, kernel_size=1, stride=1, padding=0, bias=True),
                nn.GroupNorm(8, F_int),
            )

        self.W_g = project(F_g)
        self.W_x = project(F_l)
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.GroupNorm(1, 1),
            nn.Sigmoid(),
        )
        self.relu = nn.LeakyReLU(0.01, inplace=True)

    def forward(self, g, x):
        """Return *x* multiplied by the attention map computed from *g* and *x*."""
        combined = self.relu(self.W_g(g) + self.W_x(x))
        attention = self.psi(combined)
        return x * attention
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class ViTUNetColorizer(nn.Module):
    """U-Net style colorizer whose bottleneck is fused with ViT patch features.

    The network takes a single-channel input and produces a 2-channel map
    passed through ``tanh``, at the spatial size of the input.

    Args:
        vit_model_name: timm model name used for the ViT branch.
        freeze_vit_epochs: Epoch from which ``set_epoch`` unfreezes the ViT.
        pretrained: Whether timm should load pretrained ViT weights. Pass
            False when a full checkpoint is loaded right afterwards, to skip
            the download.
    """

    def __init__(self, vit_model_name="vit_tiny_patch16_224", freeze_vit_epochs=10, pretrained=True):
        super().__init__()

        self.vit = timm.create_model(vit_model_name, pretrained=pretrained, num_classes=0)
        self.vit_embed_dim = self.vit.embed_dim
        self.vit.head = nn.Identity()

        self.enc1 = EncoderBlock(1, 16)
        self.enc2 = EncoderBlock(16, 32)
        self.enc3 = EncoderBlock(32, 64)
        self.enc4 = EncoderBlock(64, 128)

        # Pools the CNN bottleneck to 14x14 so it can be concatenated with the
        # ViT token grid produced by extract_vit_features.
        self.bottleneck_processor = nn.Sequential(
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.AdaptiveAvgPool2d((14, 14)),
        )

        self.fusion_layer = nn.Sequential(
            nn.Conv2d(128 + self.vit_embed_dim, 128, kernel_size=1),  # type: ignore
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.01, inplace=True),
        )

        self.dec4 = DecoderBlock(128, 64, 64)
        self.dec3 = DecoderBlock(64, 32, 32)
        self.dec2 = DecoderBlock(32, 16, 16)

        self.final_conv = nn.Sequential(
            nn.Conv2d(16, 8, kernel_size=3, padding=1),
            nn.GroupNorm(8, 8),
            nn.LeakyReLU(0.01, inplace=True),
            nn.Conv2d(8, 2, kernel_size=1),
            nn.Tanh(),
        )

        self.freeze_vit_epochs = freeze_vit_epochs
        self.current_epoch = 0

    def extract_vit_features(self, x):
        """Run the ViT on *x* and return its patch tokens as a (B, C, 14, 14) map.

        The single-channel input is repeated to three channels and resized to
        224x224 when needed.
        # assumes a patch-16 / 224-px ViT whose pos_embed has a leading class
        # token (the 14x14 grid and the [:, 1:] slice are hard-coded) — TODO confirm
        """
        B = x.shape[0]
        x_3ch = x.repeat(1, 3, 1, 1)

        # Check both spatial dims. Testing only the width let inputs with
        # W == 224 but H != 224 skip the resize and break the 14x14 reshape.
        if tuple(x_3ch.shape[-2:]) != (224, 224):
            x_3ch = F.interpolate(
                x_3ch, size=(224, 224), mode="bicubic", align_corners=False
            )

        x_vit = self.vit.patch_embed(x_3ch)  # type: ignore
        if hasattr(self.vit, "pos_embed") and self.vit.pos_embed is not None:
            # Skip the first positional entry: only patch tokens are present here.
            x_vit = x_vit + self.vit.pos_embed[:, 1:, :]  # type: ignore
        x_vit = self.vit.pos_drop(x_vit)  # type: ignore

        for block in self.vit.blocks:  # type: ignore
            x_vit = block(x_vit)

        x_vit = self.vit.norm(x_vit)  # type: ignore
        x_vit = x_vit.transpose(1, 2).reshape(B, self.vit_embed_dim, 14, 14)

        return x_vit

    def forward(self, x):
        """Map a single-channel input to the 2-channel tanh output."""
        x1, skip1 = self.enc1(x)
        x2, skip2 = self.enc2(x1)
        x3, skip3 = self.enc3(x2)
        x4, skip4 = self.enc4(x3)

        bottleneck = self.bottleneck_processor(x4)
        vit_features = self.extract_vit_features(x)
        fused = torch.cat([bottleneck, vit_features], dim=1)
        fused = self.fusion_layer(fused)

        # Bring the fused 14x14 map to the resolution of x3 before decoding.
        fused = F.interpolate(fused, size=x3.shape[2:], mode="bilinear", align_corners=False)

        d4 = self.dec4(fused, skip3)
        d3 = self.dec3(d4, skip2)
        d2 = self.dec2(d3, skip1)

        out = self.final_conv(d2)

        return out

    def set_epoch(self, epoch):
        """Record the epoch and (un)freeze the ViT: trainable once epoch >= freeze_vit_epochs."""
        self.current_epoch = epoch
        requires_grad = epoch >= self.freeze_vit_epochs
        for param in self.vit.parameters():
            param.requires_grad = requires_grad

    def get_param_groups(self, lr_decoder=1e-4, lr_vit=1e-5):
        """Return optimizer param groups: non-ViT parameters first, ViT parameters second."""
        vit_params = []
        decoder_params = []
        for name, param in self.named_parameters():
            # Match the ``vit`` submodule exactly rather than any name containing "vit".
            if name.startswith("vit."):
                vit_params.append(param)
            else:
                decoder_params.append(param)
        return [
            {"params": decoder_params, "lr": lr_decoder},
            {"params": vit_params, "lr": lr_vit},
        ]
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class PatchDiscriminator(nn.Module):
    """Patch-level discriminator over a concatenated (L, ab) image.

    Three stride-2 spectrally-normalised conv blocks are followed by a final
    spectrally-normalised conv that outputs one score per patch.
    """

    def __init__(self, in_channels=3, n_filters=64):
        super().__init__()

        def discriminator_block(in_filters, out_filters, stride=2):
            return [
                spectral_norm(
                    nn.Conv2d(
                        in_filters, out_filters, kernel_size=4, stride=stride, padding=1
                    )
                ),
                nn.LeakyReLU(0.01, inplace=True),
            ]

        self.model = nn.Sequential(
            *discriminator_block(in_channels, n_filters),
            *discriminator_block(n_filters, n_filters * 2),
            *discriminator_block(n_filters * 2, n_filters * 4),
            spectral_norm(nn.Conv2d(n_filters * 4, 1, kernel_size=4, padding=1)),
        )
        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise conv weights from N(0, 0.02) and conv biases to zero."""
        if isinstance(m, nn.Conv2d):
            # NOTE(review): with torch.nn.utils.spectral_norm the trainable
            # tensor is ``weight_orig`` and ``weight`` is recomputed from it,
            # so initialising ``weight`` alone has no lasting effect.
            weight = getattr(m, "weight_orig", m.weight)
            nn.init.normal_(weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, L, ab):
        """Score the image formed by concatenating *L* and *ab* along the channel axis."""
        img_input = torch.cat((L, ab), dim=1)
        return self.model(img_input)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
if __name__ == "__main__":
    # Print layer summaries for both networks at the configured resolution.
    resolution = 256
    try:
        with open("hyperparameters.json", "r") as f:
            resolution = json.load(f).get("resolution", 256)
    except FileNotFoundError:
        print("Using default resolution: 256x256")

    lightness_shape = (1, 1, resolution, resolution)
    chroma_shape = (1, 2, resolution, resolution)

    generator = ViTUNetColorizer()
    summary(generator, input_size=lightness_shape)

    # The discriminator takes the L and ab tensors as two separate inputs.
    discriminator = PatchDiscriminator()
    summary(discriminator, input_size=[lightness_shape, chroma_shape])
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
gradio
|
| 2 |
+
torch
|
| 3 |
+
torchvision
|
| 4 |
+
torchinfo
|
| 5 |
+
numpy
|
| 6 |
+
opencv-python-headless
|
| 7 |
+
Pillow
|
| 8 |
+
scikit-image
|
| 9 |
+
kornia
|
| 10 |
+
matplotlib
|
| 11 |
+
timm
|