File size: 4,192 Bytes
f804ab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3979116
 
 
f804ab9
3979116
 
 
 
f804ab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3979116
 
 
 
 
f804ab9
 
 
 
 
 
3979116
f804ab9
 
 
 
 
3979116
f804ab9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3979116
f804ab9
 
3979116
 
 
 
 
 
 
 
 
f804ab9
3979116
 
 
 
f804ab9
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import threading
import uuid

import cv2
import numpy as np
from flask import Flask, request, render_template, url_for
from werkzeug.utils import secure_filename

app = Flask(__name__, static_folder='static')
# Cap request bodies (i.e. uploads) at 16 MiB; Flask rejects larger requests
# with 413 Request Entity Too Large instead of accepting unbounded uploads.
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

# --- Directory Setup ---
# Uploads and colorized outputs live under static/ so they can be served back
# through Flask's 'static' endpoint.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
STATIC_FOLDER = os.path.join(BASE_DIR, 'static')
UPLOAD_FOLDER = os.path.join(STATIC_FOLDER, 'uploads')
OUTPUT_FOLDER = os.path.join(STATIC_FOLDER, 'outputs')
MODEL_FOLDER = os.path.join(BASE_DIR, 'model')

os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# --- Model paths ---
protoPath = os.path.join(MODEL_FOLDER, 'colorization_deploy_v2.prototxt')
modelPath = os.path.join(MODEL_FOLDER, 'colorization_release_v2.caffemodel')
hullPath = os.path.join(MODEL_FOLDER, 'pts_in_hull.npy')

# Set to True only after the network and its extra layer blobs are fully set
# up; colorize_image_local() refuses to run while this is False.
COLORIZATION_MODEL_AVAILABLE = False
net = None

# --- Load model ---
# Loading is best-effort: any failure is printed and the app still starts,
# just with colorization disabled.
try:
    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    for _label, _path in (("proto", protoPath), ("model", modelPath), ("hull", hullPath)):
        if not os.path.isfile(_path):
            raise FileNotFoundError(f"Missing {_label} file: {_path}")

    print("Loading colorization model...")
    net = cv2.dnn.readNetFromCaffe(protoPath, modelPath)
    pts_in_hull = np.load(hullPath)

    if pts_in_hull.shape != (313, 2):
        raise ValueError(f"pts_in_hull shape invalid: {pts_in_hull.shape}")

    # Install the (313, 2) points as the class8_ab layer's weights, reshaped to
    # a (2, 313, 1, 1) float32 blob, and give conv8_313_rh a constant 2.606
    # blob (same values as OpenCV's samples/dnn/colorization.py).
    pts = pts_in_hull.transpose().reshape(2, 313, 1, 1).astype(np.float32)
    net.getLayer(net.getLayerId('class8_ab')).blobs = [pts]
    net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, dtype=np.float32)]
    COLORIZATION_MODEL_AVAILABLE = True
    print("Colorization model loaded successfully.")
except Exception as e:
    # Don't keep a half-configured network around after a partial failure.
    net = None
    print(f"Failed to load colorization model: {e}")

# Longest image side, in pixels, allowed into the colorization pipeline.
MAX_DIMENSION = 2048


def resize_img(img, max_dimension=None):
    """Downscale *img* so that its longest side is at most ``max_dimension``.

    Args:
        img: image array whose first two axes are height and width.
        max_dimension: limit for the longest side in pixels. ``None`` (the
            default) uses the module-level ``MAX_DIMENSION``, read at call time.

    Returns:
        *img* itself when it is already within the limit; otherwise a resized
        copy with the aspect ratio preserved and every side at least 1 pixel.
    """
    limit = MAX_DIMENSION if max_dimension is None else max_dimension
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest <= limit:
        return img

    scale = limit / longest
    # Very thin images could otherwise round a side down to 0 px, which
    # cv2.resize rejects. INTER_AREA is OpenCV's recommended filter for shrinking.
    new_size = (max(1, int(w * scale)), max(1, int(h * scale)))
    return cv2.resize(img, new_size, interpolation=cv2.INTER_AREA)

def adjust_brightness_contrast(img, brightness=0, contrast=20):
    """Apply a linear brightness/contrast adjustment and return an 8-bit image.

    Every pixel ``p`` becomes ``|p * (1 + contrast / 100) + brightness|``,
    converted to 8-bit by ``cv2.convertScaleAbs``.

    Args:
        img: image array to adjust.
        brightness: offset added after scaling.
        contrast: gain expressed as a percentage (20 -> multiply by 1.2).
    """
    gain = 1.0 + contrast / 100.0
    return cv2.convertScaleAbs(img, alpha=gain, beta=brightness)

# The module-level `net` is shared by every request, and app.run() serves
# requests on multiple threads by default; serialize setInput()/forward() so
# one request cannot read another request's output.
_NET_LOCK = threading.Lock()


def colorize_image_local(input_path, output_path):
    """Colorize the image at *input_path* and write the result to *output_path*.

    Pipeline: read as BGR, downscale via ``resize_img``, convert to Lab, feed
    the L channel (resized to 224x224 and shifted by -50) to the network to
    predict the ab channels, upscale ab back to the image size, recombine with
    the full-resolution L channel, convert back to BGR, apply a small
    brightness/contrast tweak, and save.

    Returns:
        True if the colorized image was written; False if the model is not
        loaded, the input cannot be read, the forward pass fails, or the
        output cannot be written.
    """
    if not COLORIZATION_MODEL_AVAILABLE:
        print("Colorization model not available.")
        return False

    img = cv2.imread(input_path)
    if img is None:
        print(f"Failed to read image: {input_path}")
        return False

    img = resize_img(img)
    h, w = img.shape[:2]

    # Float input in [0, 1] makes cvtColor produce L in [0, 100].
    img_float = img.astype("float32") / 255.0
    lab = cv2.cvtColor(img_float, cv2.COLOR_BGR2LAB)
    L = lab[:, :, 0]
    L_resized = cv2.resize(L, (224, 224)) - 50  # mean-centering

    try:
        with _NET_LOCK:
            net.setInput(cv2.dnn.blobFromImage(L_resized))
            ab = net.forward()[0].transpose(1, 2, 0)
    except Exception as e:
        print(f"DNN forward pass failed: {e}")
        return False

    ab = cv2.resize(ab, (w, h))
    out_lab = np.concatenate((L[:, :, np.newaxis], ab), axis=2)
    out_bgr = cv2.cvtColor(out_lab, cv2.COLOR_Lab2BGR)
    out_bgr = np.clip(out_bgr * 255, 0, 255).astype("uint8")
    out_bgr = adjust_brightness_contrast(out_bgr, brightness=-10, contrast=15)

    # imwrite signals failure through its return value (and may raise for
    # unsupported formats); only report success if the file was written.
    try:
        written = cv2.imwrite(output_path, out_bgr)
    except cv2.error as e:
        print(f"Failed to write output image: {output_path} ({e})")
        return False
    if not written:
        print(f"Failed to write output image: {output_path}")
        return False
    return True

# Upload extensions accepted by index(). Uploaded files are served back from
# the static folder, so only image types may be stored there; otherwise a
# user could host arbitrary HTML/JS on this origin.
ALLOWED_EXTENSIONS = frozenset({'.png', '.jpg', '.jpeg', '.bmp', '.tif', '.tiff', '.webp'})


@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page; on POST, colorize the uploaded image.

    A POST expects a file in the ``image`` form field. The upload is saved
    under UPLOAD_FOLDER with a unique name, colorized into OUTPUT_FOLDER, and
    both are linked from the rendered ``index.html``.

    Error responses:
        400: no file, unsupported extension, or colorization failed.
        503: the colorization model is not loaded.
    """
    original_url = enhanced_url = None
    if request.method == 'POST':
        file = request.files.get('image')
        if not file or not file.filename:
            return "No file uploaded", 400

        stem, ext = os.path.splitext(file.filename)
        ext = ext.lower()
        if ext not in ALLOWED_EXTENSIONS:
            return "Unsupported file type", 400

        if not COLORIZATION_MODEL_AVAILABLE:
            return "Colorization model not available", 503

        # secure_filename() can reduce a name to '' (e.g. non-ASCII-only
        # names). The random prefix keeps concurrent uploads with the same
        # name from overwriting each other.
        safe_stem = (secure_filename(stem) or "image")[:64]
        filename = f"{uuid.uuid4().hex}_{safe_stem}{ext}"
        orig_path = os.path.join(UPLOAD_FOLDER, filename)
        file.save(orig_path)

        output_name = f"output_{filename}"
        output_path = os.path.join(OUTPUT_FOLDER, output_name)

        if not colorize_image_local(orig_path, output_path):
            # Best-effort cleanup so failed uploads don't accumulate on disk.
            try:
                os.remove(orig_path)
            except OSError:
                pass
            return "Colorization failed", 400

        original_url = url_for('static', filename=f'uploads/{filename}')
        enhanced_url = url_for('static', filename=f'outputs/{output_name}')

    return render_template('index.html', original_url=original_url, enhanced_url=enhanced_url)

if __name__ == '__main__':
    # Serve on all interfaces; the port comes from the PORT environment
    # variable, falling back to 7860 when it is unset.
    app.run(
        host='0.0.0.0',
        port=int(os.environ.get("PORT", 7860)),
        debug=False,
    )