# Spaces status: Sleeping
| import os | |
| import cv2 | |
| import numpy as np | |
| from flask import Flask, request, render_template, url_for | |
| from werkzeug.utils import secure_filename | |
# --- Flask application ---
app = Flask(__name__, static_folder='static')

# --- Directory Setup ---
# Every path is anchored at the directory that contains this file.
BASE_DIR = os.path.abspath(os.path.dirname(__file__))
STATIC_FOLDER = os.path.join(BASE_DIR, 'static')
UPLOAD_FOLDER = os.path.join(BASE_DIR, 'static', 'uploads')  # files received from users
OUTPUT_FOLDER = os.path.join(BASE_DIR, 'static', 'outputs')  # colorized results
MODEL_FOLDER = os.path.join(BASE_DIR, 'model')

# Make sure both writable folders exist before any request arrives.
os.makedirs(UPLOAD_FOLDER, exist_ok=True)
os.makedirs(OUTPUT_FOLDER, exist_ok=True)

# --- Model paths ---
# The three files the colorization network needs, all inside MODEL_FOLDER.
protoPath, modelPath, hullPath = (
    os.path.join(MODEL_FOLDER, model_file)
    for model_file in (
        'colorization_deploy_v2.prototxt',
        'colorization_release_v2.caffemodel',
        'pts_in_hull.npy',
    )
)
# --- Load model ---
# Module-level state read by colorize_image_local():
#   net                          -- the OpenCV DNN network, or None when loading failed
#   COLORIZATION_MODEL_AVAILABLE -- True only after every loading step succeeded
COLORIZATION_MODEL_AVAILABLE = False
net = None


def _require_file(path, label):
    """Raise FileNotFoundError if *path* does not exist.

    Used instead of ``assert`` because assertions are removed when Python
    runs with ``-O``, which would silently skip the check.
    """
    if not os.path.exists(path):
        raise FileNotFoundError(f"Missing {label} file: {path}")


try:
    _require_file(protoPath, "proto")
    _require_file(modelPath, "model")
    _require_file(hullPath, "hull")

    print("Loading colorization model...")
    net = cv2.dnn.readNetFromCaffe(protoPath, modelPath)

    pts_in_hull = np.load(hullPath)
    if pts_in_hull.shape != (313, 2):
        raise ValueError(f"pts_in_hull shape invalid: {pts_in_hull.shape}")

    # Reshape the 313 two-component points to (2, 313, 1, 1) float32 and
    # install them as the blobs of the 'class8_ab' layer.
    pts = pts_in_hull.transpose().reshape(2, 313, 1, 1).astype(np.float32)
    net.getLayer(net.getLayerId('class8_ab')).blobs = [pts]
    # The 'conv8_313_rh' layer receives a constant (1, 313) blob filled with 2.606.
    net.getLayer(net.getLayerId('conv8_313_rh')).blobs = [np.full([1, 313], 2.606, dtype=np.float32)]

    COLORIZATION_MODEL_AVAILABLE = True
    print("Colorization model loaded successfully.")
except Exception as e:
    # Best-effort startup: the app still runs, but colorization requests will
    # be refused. Drop any partially configured network so that `net` and the
    # availability flag never disagree.
    net = None
    print(f"Failed to load colorization model: {e}")
MAX_DIMENSION = 2048  # longest allowed image side, in pixels


def resize_img(img, max_dimension=MAX_DIMENSION):
    """Shrink *img* so that its longest side is at most *max_dimension*.

    Args:
        img: Image array whose first two axes are (height, width).
        max_dimension: Upper bound for the longest side, in pixels.
            Defaults to ``MAX_DIMENSION``.

    Returns:
        The very same array when it already fits; otherwise a downscaled
        copy that keeps the aspect ratio.
    """
    h, w = img.shape[:2]
    longest = max(h, w)
    if longest <= max_dimension:
        return img

    scale = max_dimension / longest
    # int() truncates, so an extremely elongated image could end up with a
    # zero-sized side, which cv2.resize rejects; keep at least one pixel.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))
    # This branch always shrinks the image; INTER_AREA is the interpolation
    # OpenCV recommends for decimation.
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
def adjust_brightness_contrast(img, brightness=0, contrast=20):
    """Apply a linear brightness/contrast adjustment and return a uint8 image.

    Every value is mapped to ``value * (1 + contrast / 100) + brightness``,
    rounded to the nearest integer and clamped to ``[0, 255]``.

    Args:
        img: Image array (any shape) with values on a 0-255 scale.
        brightness: Offset added after scaling; may be negative.
        contrast: Gain in percent; 0 leaves the scale unchanged.

    Returns:
        A new ``uint8`` array with the same shape as *img*.
    """
    alpha = 1 + contrast / 100.0
    # cv2.convertScaleAbs would take the absolute value of the result, so a
    # negative brightness turned near-black pixels brighter (0 -> 10 for
    # brightness=-10). Clamping at zero keeps dark pixels dark.
    scaled = np.asarray(img, dtype=np.float32) * alpha + brightness
    return np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
def colorize_image_local(input_path, output_path):
    """Colorize the image at *input_path* and write the result to *output_path*.

    Args:
        input_path: Path of the image to read.
        output_path: Path of the image to write; the format is chosen by
            OpenCV from the file extension.

    Returns:
        True when the colorized image was written, False when the model is
        unavailable, the input cannot be read, the forward pass fails, or
        the output cannot be written. Failures are reported via ``print``.
    """
    if not COLORIZATION_MODEL_AVAILABLE:
        print("Colorization model not available.")
        return False

    img = cv2.imread(input_path)
    if img is None:
        print(f"Failed to read image: {input_path}")
        return False

    img = resize_img(img)
    h, w = img.shape[:2]

    # cv2.imread yields BGR data; scale to float32 in [0, 1] before the Lab
    # conversion, then keep only the lightness channel.
    img_float = img.astype("float32") / 255.0
    lab = cv2.cvtColor(img_float, cv2.COLOR_BGR2LAB)
    lightness = lab[:, :, 0]

    # The network is fed a 224x224 lightness plane shifted down by 50.
    net_input = cv2.resize(lightness, (224, 224)) - 50  # mean-centering

    try:
        net.setInput(cv2.dnn.blobFromImage(net_input))
        # Drop the batch axis and move channels last: (2, H', W') -> (H', W', 2).
        ab = net.forward()[0].transpose(1, 2, 0)
    except Exception as e:
        print(f"DNN forward pass failed: {e}")
        return False

    # Upscale the predicted a/b channels to the working resolution and
    # recombine them with the full-resolution lightness.
    ab = cv2.resize(ab, (w, h))
    out_lab = np.concatenate((lightness[:, :, np.newaxis], ab), axis=2)
    out_bgr = cv2.cvtColor(out_lab, cv2.COLOR_Lab2BGR)
    out_bgr = np.clip(out_bgr * 255, 0, 255).astype("uint8")
    out_bgr = adjust_brightness_contrast(out_bgr, brightness=-10, contrast=15)

    # cv2.imwrite signals failure either by returning False or by raising
    # cv2.error (e.g. for an extension it has no encoder for). Report both
    # instead of claiming success for a file that was never written.
    try:
        written = cv2.imwrite(output_path, out_bgr)
    except cv2.error as e:
        print(f"Failed to write image: {output_path} ({e})")
        return False
    if not written:
        print(f"Failed to write image: {output_path}")
        return False
    return True
# Without a route decorator the view is never registered and the app serves
# no page. The body branches on request.method == 'POST', so both GET and
# POST must be accepted.
@app.route('/', methods=['GET', 'POST'])
def index():
    """Render the upload page and, on POST, colorize the uploaded image.

    GET renders ``index.html`` with no image URLs. POST expects a file in
    the ``image`` form field, stores it in ``UPLOAD_FOLDER``, writes the
    colorized version to ``OUTPUT_FOLDER`` and renders the page with the
    URLs of both images.

    Returns:
        The rendered template, or a ``(message, 400)`` tuple when no usable
        file was uploaded or colorization failed.
    """
    original_url = enhanced_url = None

    if request.method == 'POST':
        file = request.files.get('image')
        if not file or not file.filename:
            return "No file uploaded", 400

        filename = secure_filename(file.filename)
        # secure_filename() may strip the name down to an empty string;
        # saving would then target the uploads directory itself and fail.
        if not filename:
            return "Invalid file name", 400

        orig_path = os.path.join(UPLOAD_FOLDER, filename)
        file.save(orig_path)

        output_name = f"output_{filename}"
        output_path = os.path.join(OUTPUT_FOLDER, output_name)
        if not colorize_image_local(orig_path, output_path):
            return "Colorization failed", 400

        original_url = url_for('static', filename=f'uploads/{filename}')
        enhanced_url = url_for('static', filename=f'outputs/{output_name}')

    return render_template('index.html', original_url=original_url, enhanced_url=enhanced_url)
if __name__ == '__main__':
    # Script entry point: listen on every interface. The port is read from
    # the PORT environment variable and falls back to 7860 when it is unset.
    port = int(os.getenv("PORT", 7860))
    app.run(debug=False, port=port, host='0.0.0.0')