b2bomber commited on
Commit
a840e7e
·
verified ·
1 Parent(s): 1d58d2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -102
app.py CHANGED
@@ -1,104 +1,46 @@
1
- import os
2
- import cv2
3
- import numpy as np
4
- import urllib.request
5
- import gradio as gr
6
-
7
- # ------------------ Working Model URLs from Hugging Face ------------------
8
- color_proto_url = "https://huggingface.co/datasets/amitverma-ai/old-photo-files/resolve/main/colorization_deploy_v2.prototxt"
9
- color_model_url = "https://huggingface.co/datasets/amitverma-ai/old-photo-files/resolve/main/colorization_release_v2.caffemodel"
10
- color_pts_url = "https://huggingface.co/datasets/amitverma-ai/old-photo-files/resolve/main/pts_in_hull.npy"
11
-
12
-
13
- # ------------------ Utility Functions ------------------
14
-
15
- def download_if_missing(url, dest):
16
- if not os.path.exists(dest):
17
- print(f"Downloading {dest}...")
18
- urllib.request.urlretrieve(url, dest)
19
-
20
- # ------------------ Setup Models ------------------
21
-
22
- model_dir = "models/colorization"
23
- os.makedirs(model_dir, exist_ok=True)
24
-
25
- proto_path = os.path.join(model_dir, "colorization_deploy_v2.prototxt")
26
- model_path = os.path.join(model_dir, "colorization_release_v2.caffemodel")
27
- pts_path = os.path.join(model_dir, "pts_in_hull.npy")
28
-
29
- download_if_missing(color_proto_url, proto_path)
30
- download_if_missing(color_model_url, model_path)
31
- download_if_missing(color_pts_url, pts_path)
32
-
33
- # Load model
34
- net = cv2.dnn.readNetFromCaffe(proto_path, model_path)
35
- pts = np.load(pts_path)
36
-
37
- class8 = net.getLayerId("class8_ab")
38
- conv8 = net.getLayerId("conv8_313_rh")
39
- pts = pts.transpose().reshape(2, 313, 1, 1)
40
- net.getLayer(class8).blobs = [pts.astype(np.float32)]
41
- net.getLayer(conv8).blobs = [np.full([1, 313], 2.606, dtype="float32")]
42
-
43
- # ------------------ Main Processing Function ------------------
44
-
45
- def restore_old_photo(image, face_enhance=True, colorize=True, scratch_remove=True):
46
- try:
47
- original = image.copy()
48
-
49
- # Scratch removal using median blur (lightweight alternative)
50
- if scratch_remove:
51
- image = cv2.medianBlur(image, 3)
52
-
53
- # Face enhancement: simulate sharpening with unsharp masking
54
- if face_enhance:
55
- blur = cv2.GaussianBlur(image, (0, 0), 3)
56
- image = cv2.addWeighted(image, 1.5, blur, -0.5, 0)
57
-
58
- # Colorization
59
- if colorize:
60
- h, w = image.shape[:2]
61
- img_rgb = (image.astype("float32") / 255.0)
62
- img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2LAB)
63
- l_channel = img_lab[:, :, 0]
64
-
65
- net_input = cv2.resize(l_channel, (224, 224))
66
- net_input -= 50
67
- net.setInput(cv2.dnn.blobFromImage(net_input))
68
- ab_dec = net.forward()[0, :, :, :].transpose((1, 2, 0))
69
- ab_dec_us = cv2.resize(ab_dec, (w, h))
70
-
71
- lab_output = np.concatenate((l_channel[:, :, np.newaxis], ab_dec_us), axis=2)
72
- bgr_output = cv2.cvtColor(lab_output.astype("float32"), cv2.COLOR_LAB2BGR)
73
- bgr_output = np.clip(bgr_output * 255, 0, 255).astype("uint8")
74
- image = bgr_output
75
-
76
- return image
77
-
78
- except Exception as e:
79
- print(f"Error during restoration: {e}")
80
- return original
81
-
82
- # ------------------ Gradio UI ------------------
83
-
84
- with gr.Blocks(title="AI Old Photo Restorer") as demo:
85
- gr.Markdown("## 🧓🎨 AI Old Photo Restorer\nUpload old B/W or damaged photos and restore them with colorization, scratch removal, and face enhancement.")
86
-
87
- with gr.Row():
88
- with gr.Column():
89
- input_image = gr.Image(label="📷 Upload Old Photo", type="numpy")
90
- face_toggle = gr.Checkbox(label="👤 Face Enhancement", value=True)
91
- colorize_toggle = gr.Checkbox(label="🎨 Colorization", value=True)
92
- scratch_toggle = gr.Checkbox(label="🩹 Scratch Removal", value=True)
93
- run_button = gr.Button("✨ Restore Photo", variant="primary")
94
-
95
- with gr.Column():
96
- output_image = gr.Image(label="🧼 Restored Photo")
97
-
98
- run_button.click(
99
- fn=restore_old_photo,
100
- inputs=[input_image, face_toggle, colorize_toggle, scratch_toggle],
101
- outputs=output_image
102
- )
103
 
104
  demo.launch()
 
1
+ import os, urllib.request, cv2, numpy as np, gradio as gr
2
+
3
+ urls = {
4
+ "proto": "https://raw.githubusercontent.com/richzhang/colorization/master/colorization/models/colorization_deploy_v2.prototxt",
5
+ "model": "http://eecs.berkeley.edu/~rich.zhang/projects/2016_colorization/files/demo_v2/colorization_release_v2.caffemodel",
6
+ "pts": "https://raw.githubusercontent.com/richzhang/colorization/master/colorization/resources/pts_in_hull.npy"
7
+ }
8
+ os.makedirs("models", exist_ok=True)
9
+ for name, url in urls.items():
10
+ path = f"models/{name}"
11
+ if not os.path.exists(path):
12
+ print(f"Downloading {name}...")
13
+ urllib.request.urlretrieve(url, path)
14
+
15
+ net = cv2.dnn.readNetFromCaffe("models/proto", "models/model")
16
+ pts = np.load("models/pts")
17
+ for layer in ["class8_ab","conv8_313_rh"]:
18
+ net.getLayer(net.getLayerId(layer)).blobs = [pts.transpose().reshape(2,313,1,1).astype(np.float32)] if "class8" in layer else [np.full([1,313],2.606,dtype=np.float32)]
19
+
20
+ def colorize(img):
21
+ h, w = img.shape[:2]
22
+ L = cv2.cvtColor(img.astype("float32")/255, cv2.COLOR_BGR2LAB)[:,:,0]
23
+ net.setInput(cv2.dnn.blobFromImage(cv2.resize(L,(224,224))-50))
24
+ ab = net.forward()[0].transpose((1,2,0))
25
+ ab = cv2.resize(ab,(w,h))
26
+ lab = np.concatenate([L[:,:,None],ab],axis=2)
27
+ return cv2.cvtColor(np.clip(lab,0,255).astype("uint8"), cv2.COLOR_LAB2BGR)
28
+
29
+ def restore(img, face=True, scratch=True, color=True):
30
+ img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
31
+ if scratch:
32
+ img = cv2.medianBlur(img,3)
33
+ if face:
34
+ blur = cv2.GaussianBlur(img,(0,0),3)
35
+ img = cv2.addWeighted(img,1.5,blur,-0.5,0)
36
+ if color:
37
+ img = colorize(img)
38
+ return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
39
+
40
+ with gr.Blocks() as demo:
41
+ inp = gr.Image(type="numpy")
42
+ opts = [gr.Checkbox(label=l, value=True) for l in ["Face Enhancement","Scratch Removal","Colorization"]]
43
+ out = gr.Image()
44
+ gr.Button("Restore").click(restore, [inp]+opts, out)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
  demo.launch()