b2bomber commited on
Commit
3d48cb6
·
verified ·
1 Parent(s): 02a0656

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +80 -105
app.py CHANGED
@@ -1,128 +1,103 @@
1
- import gradio as gr
2
  import cv2
3
  import numpy as np
4
- import os
5
  import urllib.request
6
- import subprocess
7
- import psutil
8
- import tempfile
9
 
10
- # ------------------ Config ------------------
11
- MAX_DIM = 1024 # Resize large images to avoid OOM
12
- device = 'cpu'
 
13
 
14
- # ------------------ Model URLs ------------------
15
- color_proto_url = "https://github.com/richzhang/colorization/raw/caffe/models/colorization_deploy_v2.prototxt"
16
- color_model_url = "https://github.com/richzhang/colorization/raw/caffe/models/colorization_release_v2.caffemodel"
17
- color_pts_url = "https://github.com/richzhang/colorization/raw/caffe/resources/pts_in_hull.npy"
18
 
19
- # ------------------ Download Colorization Models ------------------
20
  def download_if_missing(url, dest):
21
  if not os.path.exists(dest):
22
  print(f"Downloading {dest}...")
23
  urllib.request.urlretrieve(url, dest)
24
 
25
- color_dir = "models/colorization"
26
- os.makedirs(color_dir, exist_ok=True)
27
- download_if_missing(color_proto_url, f"{color_dir}/colorization_deploy_v2.prototxt")
28
- download_if_missing(color_model_url, f"{color_dir}/colorization_release_v2.caffemodel")
29
- download_if_missing(color_pts_url, f"{color_dir}/pts_in_hull.npy")
30
-
31
- # ------------------ Load Colorization Net ------------------
32
- net = cv2.dnn.readNetFromCaffe(
33
- f"{color_dir}/colorization_deploy_v2.prototxt",
34
- f"{color_dir}/colorization_release_v2.caffemodel"
35
- )
36
- pts_in_hull = np.load(f"{color_dir}/pts_in_hull.npy")
 
 
 
 
 
37
  class8 = net.getLayerId("class8_ab")
38
  conv8 = net.getLayerId("conv8_313_rh")
39
- net.getLayer(class8).blobs = [pts_in_hull.transpose().reshape(2, 313, 1, 1)]
40
- net.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype=np.float32)]
41
-
42
- # ------------------ Optional GFPGAN lazy load ------------------
43
- gfpganer = None
44
-
45
- def enhance_face(image_np):
46
- global gfpganer
47
- if gfpganer is None:
48
- from gfpgan import GFPGANer
49
- gfpganer = GFPGANer(
50
- model_path=None, upscale=1, arch='clean', channel_multiplier=2,
51
- bg_upsampler=None, device=device
52
- )
53
- cropped_faces, _, restored_img = gfpganer.enhance(image_np, has_aligned=False, only_center_face=False, paste_back=True)
54
- return restored_img
55
-
56
- # ------------------ Scratch Removal (OpenCV-based) ------------------
57
- def remove_scratches(img):
58
- gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
59
- scratches = cv2.Laplacian(gray, cv2.CV_64F)
60
- scratches = cv2.convertScaleAbs(scratches)
61
- _, mask = cv2.threshold(scratches, 30, 255, cv2.THRESH_BINARY)
62
- mask = cv2.dilate(mask, None, iterations=2)
63
- inpainted = cv2.inpaint(img, mask, 3, cv2.INPAINT_TELEA)
64
- return inpainted
65
-
66
- # ------------------ Colorization ------------------
67
- def colorize_image(img):
68
- h, w = img.shape[:2]
69
- img_rs = cv2.resize(img, (224, 224))
70
- lab = cv2.cvtColor(img_rs, cv2.COLOR_BGR2LAB)
71
- l = lab[:, :, 0]
72
- l -= 50
73
-
74
- blob = cv2.dnn.blobFromImage(l)
75
- net.setInput(blob)
76
- ab = net.forward()[0, :, :, :].transpose((1, 2, 0))
77
- ab = cv2.resize(ab, (w, h))
78
-
79
- L = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)[:, :, 0]
80
- lab_out = np.concatenate((L[:, :, np.newaxis], ab), axis=2)
81
- bgr_out = cv2.cvtColor(lab_out.astype(np.uint8), cv2.COLOR_LAB2BGR)
82
- return bgr_out
83
-
84
- # ------------------ Main Pipeline ------------------
85
- def restore_photo(image, enhance_face_flag, scratch_flag, color_flag):
86
- print("Memory used (MB):", psutil.Process().memory_info().rss / 1024 / 1024)
87
-
88
- img = image
89
- img_np = np.array(img)
90
-
91
- # Resize if needed
92
- h, w = img_np.shape[:2]
93
- if max(h, w) > MAX_DIM:
94
- scaling = MAX_DIM / max(h, w)
95
- img_np = cv2.resize(img_np, (int(w * scaling), int(h * scaling)))
96
-
97
- # Step-by-step processing
98
- if scratch_flag:
99
- img_np = remove_scratches(img_np)
100
- if enhance_face_flag:
101
- img_np = enhance_face(img_np)
102
- if color_flag:
103
- img_np = colorize_image(img_np)
104
-
105
- return img_np
106
 
107
  # ------------------ Gradio UI ------------------
108
- with gr.Blocks(theme=gr.themes.Soft()) as demo:
109
- gr.Markdown("## 🧙‍♂️ AI Old Photo Restorer\nRestore, enhance, and colorize your old photos with AI on CPU!")
 
110
 
111
  with gr.Row():
112
  with gr.Column():
113
- input_image = gr.Image(label="Upload Old Photo", type="numpy")
114
- enhance_face = gr.Checkbox(label="🧑‍🦰 Enhance Faces", value=True)
115
- remove_scratch = gr.Checkbox(label="🧽 Remove Scratches", value=True)
116
- colorize = gr.Checkbox(label="🎨 Colorize Image", value=True)
117
- run_btn = gr.Button("🚀 Restore")
 
118
  with gr.Column():
119
- output_image = gr.Image(label="Restored Photo")
120
 
121
- run_btn.click(
122
- fn=restore_photo,
123
- inputs=[input_image, enhance_face, remove_scratch, colorize],
124
  outputs=output_image
125
  )
126
 
127
- if __name__ == "__main__":
128
- demo.launch()
 
1
+ import os
2
  import cv2
3
  import numpy as np
 
4
  import urllib.request
5
+ import gradio as gr
 
 
6
 
7
+ # ------------------ Working Model URLs from Hugging Face ------------------
8
+ color_proto_url = "https://huggingface.co/akhaliq/old-photo-restoration/resolve/main/colorization_deploy_v2.prototxt"
9
+ color_model_url = "https://huggingface.co/akhaliq/old-photo-restoration/resolve/main/colorization_release_v2.caffemodel"
10
+ color_pts_url = "https://huggingface.co/akhaliq/old-photo-restoration/resolve/main/pts_in_hull.npy"
11
 
12
+ # ------------------ Utility Functions ------------------
 
 
 
13
 
 
14
  def download_if_missing(url, dest):
15
  if not os.path.exists(dest):
16
  print(f"Downloading {dest}...")
17
  urllib.request.urlretrieve(url, dest)
18
 
19
+ # ------------------ Setup Models ------------------
20
+
21
+ model_dir = "models/colorization"
22
+ os.makedirs(model_dir, exist_ok=True)
23
+
24
+ proto_path = os.path.join(model_dir, "colorization_deploy_v2.prototxt")
25
+ model_path = os.path.join(model_dir, "colorization_release_v2.caffemodel")
26
+ pts_path = os.path.join(model_dir, "pts_in_hull.npy")
27
+
28
+ download_if_missing(color_proto_url, proto_path)
29
+ download_if_missing(color_model_url, model_path)
30
+ download_if_missing(color_pts_url, pts_path)
31
+
32
+ # Load model
33
+ net = cv2.dnn.readNetFromCaffe(proto_path, model_path)
34
+ pts = np.load(pts_path)
35
+
36
  class8 = net.getLayerId("class8_ab")
37
  conv8 = net.getLayerId("conv8_313_rh")
38
+ pts = pts.transpose().reshape(2, 313, 1, 1)
39
+ net.getLayer(class8).blobs = [pts.astype(np.float32)]
40
+ net.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype="float32")]
41
+
42
+ # ------------------ Main Processing Function ------------------
43
+
44
+ def restore_old_photo(image, face_enhance=True, colorize=True, scratch_remove=True):
45
+ try:
46
+ original = image.copy()
47
+
48
+ # Scratch removal using median blur (lightweight alternative)
49
+ if scratch_remove:
50
+ image = cv2.medianBlur(image, 3)
51
+
52
+ # Face enhancement: simulate sharpening with unsharp masking
53
+ if face_enhance:
54
+ blur = cv2.GaussianBlur(image, (0, 0), 3)
55
+ image = cv2.addWeighted(image, 1.5, blur, -0.5, 0)
56
+
57
+ # Colorization
58
+ if colorize:
59
+ h, w = image.shape[:2]
60
+ img_rgb = (image.astype("float32") / 255.0)
61
+ img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2LAB)
62
+ l_channel = img_lab[:, :, 0]
63
+
64
+ net_input = cv2.resize(l_channel, (224, 224))
65
+ net_input -= 50
66
+ net.setInput(cv2.dnn.blobFromImage(net_input))
67
+ ab_dec = net.forward()[0, :, :, :].transpose((1, 2, 0))
68
+ ab_dec_us = cv2.resize(ab_dec, (w, h))
69
+
70
+ lab_output = np.concatenate((l_channel[:, :, np.newaxis], ab_dec_us), axis=2)
71
+ bgr_output = cv2.cvtColor(lab_output.astype("float32"), cv2.COLOR_LAB2BGR)
72
+ bgr_output = np.clip(bgr_output * 255, 0, 255).astype("uint8")
73
+ image = bgr_output
74
+
75
+ return image
76
+
77
+ except Exception as e:
78
+ print(f"Error during restoration: {e}")
79
+ return original
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
 
81
  # ------------------ Gradio UI ------------------
82
+
83
+ with gr.Blocks(title="AI Old Photo Restorer") as demo:
84
+ gr.Markdown("## 🧓🎨 AI Old Photo Restorer\nUpload old B/W or damaged photos and restore them with colorization, scratch removal, and face enhancement.")
85
 
86
  with gr.Row():
87
  with gr.Column():
88
+ input_image = gr.Image(label="📷 Upload Old Photo", type="numpy")
89
+ face_toggle = gr.Checkbox(label="👤 Face Enhancement", value=True)
90
+ colorize_toggle = gr.Checkbox(label="🎨 Colorization", value=True)
91
+ scratch_toggle = gr.Checkbox(label="🩹 Scratch Removal", value=True)
92
+ run_button = gr.Button(" Restore Photo", variant="primary")
93
+
94
  with gr.Column():
95
+ output_image = gr.Image(label="🧼 Restored Photo")
96
 
97
+ run_button.click(
98
+ fn=restore_old_photo,
99
+ inputs=[input_image, face_toggle, colorize_toggle, scratch_toggle],
100
  outputs=output_image
101
  )
102
 
103
+ demo.launch()