| import os |
| import re |
| import cv2 |
| import numpy as np |
| import random |
| import secrets |
| from flask import Flask, render_template, request, send_from_directory, jsonify, g |
| import uuid |
| import logging |
| import onnxruntime as ort |
| import time |
| import threading |
| from collections import defaultdict |
|
|
| |
# Application-wide logging: INFO and above go to the root handler.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


app = Flask(__name__)
# Storage folders (relative to the working directory) and the request-size cap.
# Oversized bodies surface as HTTP 413, handled by request_entity_too_large.
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    RESULTS_FOLDER='static/results',
    MODEL_FOLDER='models',
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,  # 16 MiB
)


# Extensions accepted by /upload; allowed_file compares them case-insensitively.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}

# Largest decoded image (width * height) that process_image will colorize.
MAX_IMAGE_PIXELS = 25_000_000
|
|
def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot counts, compared case-insensitively;
    a name with no dot at all is rejected.
    """
    _, dot, suffix = filename.rpartition('.')
    if not dot:
        return False
    return suffix.lower() in ALLOWED_EXTENSIONS
|
|
def is_valid_image_content(file_stream, ext):
    """Check that the upload's magic bytes agree with its claimed extension.

    Reads the first 12 bytes of *file_stream*, then rewinds it to offset 0 so
    the caller can still save the complete file. Returns False for an empty
    stream or for an extension with no known signature.
    """
    signature = file_stream.read(12)
    file_stream.seek(0)  # rewind so the full upload can be saved afterwards

    if not signature:
        return False

    if ext == 'png':
        # Fixed 8-byte PNG file signature.
        return signature.startswith(b'\x89PNG\r\n\x1a\n')
    if ext == 'webp':
        # RIFF container whose form type (bytes 8-11) is "WEBP".
        return signature.startswith(b'RIFF') and signature[8:12] == b'WEBP'
    if ext in ('jpg', 'jpeg'):
        # JPEG start-of-image marker followed by the first byte of the next marker.
        return signature.startswith(b'\xff\xd8\xff')
    return False
|
|
@app.errorhandler(413)
def request_entity_too_large(error):
    """Answer oversized uploads with a JSON error body and status 413."""
    body = jsonify({'error': 'File is too large. Maximum size is 16MB.'})
    return body, 413
|
|
| |
# Create the upload and result directories up front (no-op if they exist).
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)

# Location of the DDColor ONNX model file.
MODEL_PATH = os.path.join(app.config['MODEL_FOLDER'], 'ddcolor_modelserver.onnx')

# Inference state filled in by load_model(); None until a load succeeds.
session = input_name_cached = output_name_cached = None

# model_lock serialises the lazy model load in upload_file, rate_limit_lock
# guards upload_counts, and cleanup_lock keeps cleanup_old_files runs in this
# process from overlapping.
model_lock = threading.Lock()
rate_limit_lock = threading.Lock()
cleanup_lock = threading.Lock()

# Sliding-window rate limit: at most RATE_LIMIT_LIMIT accepted uploads per
# client IP within RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_LIMIT = 10
RATE_LIMIT_WINDOW = 60
# client IP -> time.time() stamps of that client's accepted uploads.
upload_counts = defaultdict(list)
|
|
def cleanup_old_files(max_age_seconds=12 * 3600):
    """Delete stale files from the upload and result folders.

    Regular files other than ``.gitkeep`` whose modification time is older
    than *max_age_seconds* (12 hours by default) are removed, so past uploads
    and results do not exhaust the disk. The sweep holds ``cleanup_lock``,
    so overlapping calls from threads of this process run one after another.

    This is best-effort: per-file and per-folder errors are logged, not raised.
    """
    with cleanup_lock:
        cutoff = time.time() - max_age_seconds
        for folder in (app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER']):
            try:
                with os.scandir(folder) as entries:
                    for entry in entries:
                        if entry.name == '.gitkeep' or not entry.is_file():
                            continue
                        try:
                            if entry.stat().st_mtime < cutoff:
                                os.remove(entry.path)
                                logger.info(f"Cleaned up old file: {entry.path}")
                        except FileNotFoundError:
                            # The file vanished after scandir listed it, e.g. removed
                            # by a sweep in another process (cleanup_lock is a
                            # threading.Lock, so it only covers this process).
                            pass
                        except Exception as e:
                            logger.error(f"Error cleaning up file {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning folder {folder}: {e}")
|
|
def is_rate_limited(ip):
    """Record an upload attempt for *ip* and report whether it must be refused.

    Returns True without recording the attempt when *ip* already has
    RATE_LIMIT_LIMIT accepted uploads inside the last RATE_LIMIT_WINDOW
    seconds. Otherwise it records the current time and returns False.

    On roughly 1% of accepted attempts it also does some housekeeping:
    it forgets IPs with no recent uploads and, unless a cleanup seems to be
    running already, starts cleanup_old_files on a daemon thread.
    """
    current = time.time()

    with rate_limit_lock:
        recent = [
            stamp for stamp in upload_counts.get(ip, ())
            if current - stamp < RATE_LIMIT_WINDOW
        ]

        if len(recent) >= RATE_LIMIT_LIMIT:
            upload_counts[ip] = recent
            return True

        recent.append(current)
        upload_counts[ip] = recent

        # Occasional housekeeping, amortised over ~1 in 100 accepted requests.
        if random.random() < 0.01:
            stale = [
                key for key, stamps in upload_counts.items()
                if not any(current - stamp < RATE_LIMIT_WINDOW for stamp in stamps)
            ]
            for key in stale:
                del upload_counts[key]

            # locked() is only a hint: if a sweep starts in between, the new
            # thread simply blocks on cleanup_lock until it finishes.
            if not cleanup_lock.locked():
                threading.Thread(target=cleanup_old_files, daemon=True).start()

    return False
|
|
def load_model():
    """Load the DDColor ONNX model into the module-level inference globals.

    On success, sets ``session``, ``input_name_cached`` and
    ``output_name_cached`` and returns True. Otherwise returns False and
    leaves those globals untouched. A missing model file is logged as a
    warning; a failure inside ONNX Runtime is logged with its traceback.
    """
    global session, input_name_cached, output_name_cached

    if not os.path.exists(MODEL_PATH):
        logger.warning(f"MISSING MODEL FILE: {MODEL_PATH}")
        return False

    try:
        logger.info("Loading DDColor ONNX model...")

        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        sess_options.enable_mem_pattern = True
        sess_options.enable_cpu_mem_arena = True

        new_session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=sess_options,
            providers=['CPUExecutionProvider'],
        )
        # Only the first input and first output are used by process_image.
        new_input_name = new_session.get_inputs()[0].name
        new_output_name = new_session.get_outputs()[0].name
    except Exception as e:
        logger.exception(f"Error loading model: {e}")
        return False

    # Publish only a fully initialised session, so a failed load can never
    # leave `session` set while the cached I/O names are stale or missing.
    session = new_session
    input_name_cached = new_input_name
    output_name_cached = new_output_name
    logger.info("Model loaded successfully.")
    return True


# Eager load at import time. While this is False, upload_file retries the
# load under model_lock.
model_loaded = load_model()
|
|
@app.before_request
def generate_nonce():
    """Store a fresh random CSP nonce (32 hex characters) on ``g`` for this request."""
    token = secrets.token_hex(16)
    g.nonce = token
|
|
@app.route('/')
def index():
    """Render the main page, passing this request's CSP nonce to the template."""
    page_nonce = g.nonce
    return render_template('index.html', nonce=page_nonce)
|
|
@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept one image upload, colorize it, and return JSON with both URLs.

    Responses:
        200 -- ``original_url``, ``colorized_url`` and ``colorized_filename``.
        400 -- missing file part, empty filename, disallowed extension, or
               content whose magic bytes do not match the extension.
        429 -- the client IP exceeded the upload rate limit.
        500 -- the model could not be loaded, or saving/processing the image
               failed. In that case the stored upload is removed.
    """
    global model_loaded

    # NOTE(review): remote_addr is the direct peer address. Behind a reverse
    # proxy, all clients may share one rate-limit bucket; confirm the
    # deployment (werkzeug's ProxyFix may be needed).
    client_ip = request.remote_addr
    if is_rate_limited(client_ip):
        return jsonify({'error': 'Rate limit exceeded. Please wait a minute.'}), 429

    # Retry the model load if the import-time attempt failed. The second check
    # inside the lock stops concurrent requests from loading it twice.
    if not model_loaded:
        with model_lock:
            if not model_loaded:
                if not load_model():
                    return jsonify({'error': 'Server model is not ready.'}), 500
                model_loaded = True

    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'File type not allowed. Supported: PNG, JPG, JPEG, WEBP'}), 400

    ext = file.filename.rsplit('.', 1)[1].lower()
    # Reject content whose magic bytes don't match the claimed extension.
    if not is_valid_image_content(file.stream, ext):
        return jsonify({'error': 'Invalid image content for the given extension.'}), 400

    # Store under a random name so the client-supplied filename never reaches
    # the filesystem.
    filename = f"{uuid.uuid4().hex}.{ext}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    # Saving sits inside the try too, so a failed or partial write still gets
    # a JSON error response and its file removed.
    try:
        file.save(filepath)
        colorized_filename = process_image(filepath, filename)
    except Exception as e:
        logger.exception(f"Error processing image: {e}")
        try:
            if os.path.exists(filepath):
                os.remove(filepath)
        except OSError as cleanup_error:
            logger.error(f"Error removing failed upload {filepath}: {cleanup_error}")
        return jsonify({'error': 'An error occurred during image processing'}), 500

    return jsonify({
        'original_url': f"/static/uploads/{filename}",
        'colorized_url': f"/static/results/{colorized_filename}",
        'colorized_filename': colorized_filename,
    })
|
|
def process_image(path, filename):
    """Colorize the image at *path* and save the result in RESULTS_FOLDER.

    The image is resized to the 512x512 model input and run through the ONNX
    session. The output is then scaled back to the original dimensions.

    Args:
        path: Path of the uploaded image on disk.
        filename: Basename of the upload. The result is saved as
            ``colorized_<filename>``, so it keeps the same extension/format.

    Returns:
        The result's basename, ``colorized_<filename>``.

    Raises:
        ValueError: The image could not be decoded, or it exceeds
            MAX_IMAGE_PIXELS.
        OSError: The result image could not be written.
        Errors from OpenCV or ONNX Runtime propagate unchanged.
    """
    model_size = 512  # side length of the square model input

    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise ValueError("Could not read image")

    orig_height, orig_width = img_bgr.shape[:2]

    # NOTE(review): this check runs only after cv2.imread has decoded the whole
    # image, so it limits inference/resize work but not decode memory.
    if orig_height * orig_width > MAX_IMAGE_PIXELS:
        raise ValueError(f"Image dimensions too large: {orig_width}x{orig_height}")

    # NCHW float32 blob: scaled to [0, 1], resized to model_size x model_size
    # without preserving the aspect ratio, with channels swapped BGR -> RGB.
    img_input = cv2.dnn.blobFromImage(
        img_bgr, 1.0 / 255.0, (model_size, model_size), (0, 0, 0),
        swapRB=True, crop=False,
    )
    outputs = session.run([output_name_cached], {input_name_cached: img_input})

    # First batch item; assumed to be 3 x H x W RGB in [0, 1] -- TODO confirm
    # against the exported model.
    output = outputs[0][0]
    np.multiply(output, 255.0, out=output)
    np.clip(output, 0, 255, out=output)
    output_uint8 = output.astype(np.uint8)

    # Reverse the channel planes (RGB -> BGR) and merge them into an HWC image.
    res_img_bgr_small = cv2.merge([output_uint8[2], output_uint8[1], output_uint8[0]])

    # Resize only when the model output doesn't already match the original size.
    if res_img_bgr_small.shape[:2] == (orig_height, orig_width):
        res_img_bgr = res_img_bgr_small
    else:
        res_img_bgr = cv2.resize(
            res_img_bgr_small, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR
        )

    result_filename = f"colorized_{filename}"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)

    lower_name = result_filename.lower()
    if lower_name.endswith(('.jpg', '.jpeg')):
        params = [cv2.IMWRITE_JPEG_QUALITY, 90]
    elif lower_name.endswith('.png'):
        params = [cv2.IMWRITE_PNG_COMPRESSION, 1]  # low compression = faster writes
    else:
        params = []

    # cv2.imwrite signals a failed write through its boolean return value, so
    # check it rather than returning the name of a file that doesn't exist.
    if not cv2.imwrite(result_path, res_img_bgr, params):
        raise OSError(f"Could not write result image: {result_path}")

    return result_filename
|
|
@app.route('/download/<filename>')
def download_file(filename):
    """Send a colorized result as a file attachment.

    Only names of the exact shape that process_image produces for uploads
    (``colorized_<32 lowercase hex>.<png|jpg|jpeg|webp>``) are accepted.
    Anything else gets a 400 before the filesystem is touched.
    """
    # fullmatch rather than match with '$': '$' also matches just before a
    # trailing newline, so "colorized_<hex>.png\n" used to pass validation.
    if not re.fullmatch(r'colorized_[a-f0-9]{32}\.(png|jpg|jpeg|webp)', filename):
        return jsonify({'error': 'Invalid filename format'}), 400
    return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)
|
|
@app.after_request
def add_security_headers(response):
    """Attach browser security headers, including a nonce-based CSP, to *response*.

    The CSP embeds the nonce that generate_nonce stored on ``g`` for this
    request. index passes that same nonce to the template.
    """
    headers = response.headers
    headers['X-Content-Type-Options'] = 'nosniff'
    # DENY matches the CSP's frame-ancestors 'none' below. Browsers without
    # frame-ancestors support rely on this header alone.
    headers['X-Frame-Options'] = 'DENY'
    # "0" disables the legacy XSS auditor, per current MDN/OWASP guidance: its
    # blocking mode could itself be abused. The CSP below is the real defence.
    headers['X-XSS-Protection'] = '0'
    headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
    headers['Permissions-Policy'] = 'camera=(), microphone=(), geolocation=(), usb=()'
    headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
    # Overwrite the Server header with an empty value (presumably so the
    # server stack is not advertised).
    headers['Server'] = ''

    # NOTE(review): in CSP Level 2+ browsers, a nonce in style-src makes them
    # ignore 'unsafe-inline', so un-nonced inline styles are blocked there.
    # 'unsafe-inline' then only matters for older browsers. Confirm this is
    # intended, e.g. for styles injected by the Tailwind CDN script.
    csp = (
        "default-src 'none'; "
        f"script-src 'self' 'nonce-{g.nonce}' cdn.tailwindcss.com; "
        f"style-src 'self' 'unsafe-inline' 'nonce-{g.nonce}' cdn.tailwindcss.com fonts.googleapis.com; "
        "img-src 'self' data:; "
        "font-src 'self' fonts.gstatic.com; "
        "connect-src 'self'; "
        "object-src 'none'; "
        "base-uri 'none'; "
        "form-action 'self'; "
        "frame-ancestors 'none';"
    )
    headers['Content-Security-Policy'] = csp
    return response
|
|
if __name__ == "__main__":
    # Direct-run entry point: listen on all interfaces, on $PORT (default 7860).
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|
|