b2bomber commited on
Commit
3a27096
·
verified ·
1 Parent(s): 7bc1095

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -78
app.py CHANGED
@@ -1,81 +1,128 @@
1
- import os, urllib.request
2
  import gradio as gr
3
- import numpy as np
4
  import cv2
5
- from PIL import Image
6
- import torch
7
- from gfpgan import GFPGANer
8
-
9
- # Download models if needed
10
- def download_models():
11
- os.makedirs("models", exist_ok=True)
12
- files = {
13
- "models/colorization_deploy_v2.prototxt":
14
- "https://raw.githubusercontent.com/richzhang/colorization/master/models/colorization_deploy_v2.prototxt",
15
- "models/colorization_release_v2.caffemodel":
16
- "http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel",
17
- "models/pts_in_hull.npy":
18
- "https://raw.githubusercontent.com/junyanz/interactive-deep-colorization/master/data/color_bins/pts_in_hull.npy"
19
- }
20
- for path, url in files.items():
21
- if not os.path.exists(path):
22
- print(f"🔽 Downloading {os.path.basename(path)} …")
23
- urllib.request.urlretrieve(url, path)
24
-
25
- download_models()
26
-
27
- # Load GFPGAN for face enhancement
28
- gfpganer = GFPGANer(model_path=None, upscale=2, arch='clean', channel_multiplier=2,
29
- bg_upsampler=None, device='cpu')
30
-
31
- # Setup colorization network
32
- net = cv2.dnn.readNetFromCaffe("models/colorization_deploy_v2.prototxt",
33
- "models/colorization_release_v2.caffemodel")
34
- pts = np.load("models/pts_in_hull.npy")
35
- pts = pts.transpose().reshape(2, 313, 1, 1).astype("float32")
36
- net.getLayer(net.getLayerId('class8_ab')).blobs = [pts]
37
- net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full((1,313),2.606,dtype="float32")]
38
-
39
- def detect_scratches(img):
40
- gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
41
- _, mask = cv2.threshold(gray, 235, 255, cv2.THRESH_BINARY)
42
- return cv2.dilate(mask, np.ones((3,3), np.uint8), iterations=1)
43
-
44
- def restore_photo(img, enhance_face, remove_scratch, colorize):
45
- img_np = np.array(img.convert("RGB"))
46
- img_bgr = cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR)
47
-
48
- if remove_scratch:
49
- mask = detect_scratches(img_np)
50
- img_bgr = cv2.inpaint(img_bgr, mask, 3, cv2.INPAINT_TELEA)
51
-
52
- if enhance_face:
53
- _, img_bgr = gfpganer.enhance(img_bgr, has_aligned=False,
54
- only_center_face=False, paste_back=True)
55
-
56
- img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
57
-
58
- if colorize:
59
- lab = cv2.cvtColor(img_rgb.astype("float32")/255.0, cv2.COLOR_RGB2Lab)
60
- L = lab[:,:,0]
61
- L_rs = cv2.resize(L, (224,224)) - 50
62
- net.setInput(cv2.dnn.blobFromImage(L_rs))
63
- ab = net.forward()[0].transpose((1,2,0))
64
- ab_us = cv2.resize(ab, (img_rgb.shape[1], img_rgb.shape[0]))
65
- lab_out = np.concatenate([L[:,:,None], ab_us], axis=2)
66
- img_rgb = np.clip(cv2.cvtColor(lab_out, cv2.COLOR_Lab2RGB)*255, 0, 255).astype("uint8")
67
-
68
- return Image.fromarray(img_rgb)
69
-
70
- with gr.Blocks(title="AI Photo Restorer") as demo:
71
- gr.Markdown("### 🧓✨ AI Old Photo Restorer")
72
- inp = gr.Image(label="Upload Photo", type="pil")
73
- fe = gr.Checkbox("Enhance Faces", value=True)
74
- sr = gr.Checkbox("Remove Scratches", value=True)
75
- cz = gr.Checkbox("Colorize Image", value=True)
76
- btn = gr.Button("Restore")
77
- out = gr.Image(label="Restored Photo")
78
- btn.click(restore_photo, inputs=[inp, fe, sr, cz], outputs=out)
79
-
80
- if __name__=="__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  demo.launch()
 
 
1
  import gradio as gr
 
2
  import cv2
3
+ import numpy as np
4
+ import os
5
+ import urllib.request
6
+ import subprocess
7
+ import psutil
8
+ import tempfile
9
+
10
+ # ------------------ Config ------------------
11
+ MAX_DIM = 1024 # Resize large images to avoid OOM
12
+ device = 'cpu'
13
+
14
+ # ------------------ Model URLs ------------------
15
+ color_proto_url = "https://github.com/richzhang/colorization/raw/caffe/models/colorization_deploy_v2.prototxt"
16
+ color_model_url = "https://github.com/richzhang/colorization/raw/caffe/models/colorization_release_v2.caffemodel"
17
+ color_pts_url = "https://github.com/richzhang/colorization/raw/caffe/resources/pts_in_hull.npy"
18
+
19
+ # ------------------ Download Colorization Models ------------------
20
+ def download_if_missing(url, dest):
21
+ if not os.path.exists(dest):
22
+ print(f"Downloading {dest}...")
23
+ urllib.request.urlretrieve(url, dest)
24
+
25
+ color_dir = "models/colorization"
26
+ os.makedirs(color_dir, exist_ok=True)
27
+ download_if_missing(color_proto_url, f"{color_dir}/colorization_deploy_v2.prototxt")
28
+ download_if_missing(color_model_url, f"{color_dir}/colorization_release_v2.caffemodel")
29
+ download_if_missing(color_pts_url, f"{color_dir}/pts_in_hull.npy")
30
+
31
+ # ------------------ Load Colorization Net ------------------
32
+ net = cv2.dnn.readNetFromCaffe(
33
+ f"{color_dir}/colorization_deploy_v2.prototxt",
34
+ f"{color_dir}/colorization_release_v2.caffemodel"
35
+ )
36
+ pts_in_hull = np.load(f"{color_dir}/pts_in_hull.npy")
37
+ class8 = net.getLayerId("class8_ab")
38
+ conv8 = net.getLayerId("conv8_313_rh")
39
+ net.getLayer(class8).blobs = [pts_in_hull.transpose().reshape(2, 313, 1, 1)]
40
+ net.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype=np.float32)]
41
+
42
+ # ------------------ Optional GFPGAN lazy load ------------------
43
+ gfpganer = None
44
+
45
+ def enhance_face(image_np):
46
+ global gfpganer
47
+ if gfpganer is None:
48
+ from gfpgan import GFPGANer
49
+ gfpganer = GFPGANer(
50
+ model_path=None, upscale=1, arch='clean', channel_multiplier=2,
51
+ bg_upsampler=None, device=device
52
+ )
53
+ cropped_faces, _, restored_img = gfpganer.enhance(image_np, has_aligned=False, only_center_face=False, paste_back=True)
54
+ return restored_img
55
+
56
+ # ------------------ Scratch Removal (OpenCV-based) ------------------
57
+ def remove_scratches(img):
58
+ gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
+ scratches = cv2.Laplacian(gray, cv2.CV_64F)
60
+ scratches = cv2.convertScaleAbs(scratches)
61
+ _, mask = cv2.threshold(scratches, 30, 255, cv2.THRESH_BINARY)
62
+ mask = cv2.dilate(mask, None, iterations=2)
63
+ inpainted = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
64
+ return inpainted
65
+
66
+ # ------------------ Colorization ------------------
67
+ def colorize_image(img):
68
+ h, w = img.shape[:2]
69
+ img_rs = cv2.resize(img, (224, 224))
70
+ lab = cv2.cvtColor(img_rs, cv2.COLOR_BGR2LAB)
71
+ l = lab[:, :, 0]
72
+ l -= 50
73
+
74
+ blob = cv2.dnn.blobFromImage(l)
75
+ net.setInput(blob)
76
+ ab = net.forward()[0, :, :, :].transpose((1, 2, 0))
77
+ ab = cv2.resize(ab, (w, h))
78
+
79
+ L = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)[:, :, 0]
80
+ lab_out = np.concatenate((L[:, :, np.newaxis], ab), axis=2)
81
+ bgr_out = cv2.cvtColor(lab_out.astype(np.uint8), cv2.COLOR_LAB2BGR)
82
+ return bgr_out
83
+
84
+ # ------------------ Main Pipeline ------------------
85
+ def restore_photo(image, enhance_face_flag, scratch_flag, color_flag):
86
+ print("Memory used (MB):", psutil.Process().memory_info().rss / 1024 / 1024)
87
+
88
+ img = image
89
+ img_np = np.array(img)
90
+
91
+ # Resize if needed
92
+ h, w = img_np.shape[:2]
93
+ if max(h, w) > MAX_DIM:
94
+ scaling = MAX_DIM / max(h, w)
95
+ img_np = cv2.resize(img_np, (int(w * scaling), int(h * scaling)))
96
+
97
+ # Step-by-step processing
98
+ if scratch_flag:
99
+ img_np = remove_scratches(img_np)
100
+ if enhance_face_flag:
101
+ img_np = enhance_face(img_np)
102
+ if color_flag:
103
+ img_np = colorize_image(img_np)
104
+
105
+ return img_np
106
+
107
+ # ------------------ Gradio UI ------------------
108
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
109
+ gr.Markdown("## 🧙‍♂️ AI Old Photo Restorer\nRestore, enhance, and colorize your old photos with AI on CPU!")
110
+
111
+ with gr.Row():
112
+ with gr.Column():
113
+ input_image = gr.Image(label="Upload Old Photo", type="numpy")
114
+ enhance_face = gr.Checkbox(label="🧑‍🦰 Enhance Faces", value=True)
115
+ remove_scratch = gr.Checkbox(label="🧽 Remove Scratches", value=True)
116
+ colorize = gr.Checkbox(label="🎨 Colorize Image", value=True)
117
+ run_btn = gr.Button("🚀 Restore")
118
+ with gr.Column():
119
+ output_image = gr.Image(label="Restored Photo")
120
+
121
+ run_btn.click(
122
+ fn=restore_photo,
123
+ inputs=[input_image, enhance_face, remove_scratch, colorize],
124
+ outputs=output_image
125
+ )
126
+
127
+ if __name__ == "__main__":
128
  demo.launch()