import logging
import os
import random
import re
import secrets
import threading
import time
import uuid
from collections import defaultdict

# Limit image dimensions to 25MP to prevent memory exhaustion (DoS).
MAX_IMAGE_PIXELS = 25 * 1000 * 1000

# OpenCV reads OPENCV_IO_MAX_IMAGE_PIXELS when the library is loaded, so it has
# to be set before `import cv2` (and only takes effect if nothing else in the
# process imported cv2 first). With it, cv2.imread refuses oversized images from
# their header instead of fully decoding a small-on-disk "decompression bomb"
# before the explicit MAX_IMAGE_PIXELS check in process_image() can run.
# setdefault() lets operators override the value through the environment.
os.environ.setdefault('OPENCV_IO_MAX_IMAGE_PIXELS', str(MAX_IMAGE_PIXELS))

import cv2  # noqa: E402
import numpy as np  # noqa: E402
import onnxruntime as ort  # noqa: E402
from flask import Flask, render_template, request, send_from_directory, jsonify, g  # noqa: E402

# Configure basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULTS_FOLDER'] = 'static/results'
app.config['MODEL_FOLDER'] = 'models'
# Limit upload size to 16MB to prevent DoS
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}

# Downloadable result files must look exactly like what process_image() writes:
# colorized_<32-char hex UUID>.<extension>. Used with fullmatch() so nothing
# (not even a trailing newline) may follow the extension.
RESULT_FILENAME_RE = re.compile(r'colorized_[a-f0-9]{32}\.(?:png|jpg|jpeg|webp)')


def allowed_file(filename):
    """Return True if *filename* has an extension in ALLOWED_EXTENSIONS (case-insensitive)."""
    return '.' in filename and \
        filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS


def is_valid_image_content(file_stream, ext):
    """Verify that the file content matches the claimed image extension using magic bytes.

    Reads the first 12 bytes of *file_stream* and rewinds it to position 0 so
    the caller can still save the full stream afterwards.

    Returns:
        True if the header matches the signature for *ext* ('jpg'/'jpeg',
        'png' or 'webp'); False for an empty stream or any other extension.
    """
    header = file_stream.read(12)
    file_stream.seek(0)  # Reset stream position after reading
    if not header:
        return False

    if ext in ['jpg', 'jpeg']:
        return header.startswith(b'\xff\xd8\xff')
    elif ext == 'png':
        return header.startswith(b'\x89PNG\r\n\x1a\n')
    elif ext == 'webp':
        return header.startswith(b'RIFF') and header[8:12] == b'WEBP'
    return False


@app.errorhandler(413)
def request_entity_too_large(error):
    """Return a JSON error when the upload exceeds MAX_CONTENT_LENGTH."""
    return jsonify({'error': 'File is too large. Maximum size is 16MB.'}), 413


# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)

# Model paths
MODEL_PATH = os.path.join(app.config['MODEL_FOLDER'], 'ddcolor_modelserver.onnx')
session = None
input_name_cached = None
output_name_cached = None
model_lock = threading.Lock()
rate_limit_lock = threading.Lock()
cleanup_lock = threading.Lock()

# Simple in-memory rate limiting: ip -> list of upload timestamps (time.time()).
upload_counts = defaultdict(list)
RATE_LIMIT_LIMIT = 10
RATE_LIMIT_WINDOW = 60  # seconds


def cleanup_old_files():
    """Delete files older than 12 hours from the uploads and results folders.

    Prevents disk exhaustion. `.gitkeep` and non-file entries are skipped.
    Errors are logged per file / per folder and never propagated, because this
    runs as a best-effort background task.
    """
    with cleanup_lock:
        now = time.time()
        max_age = 12 * 3600  # 12 hours
        folders = [app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER']]

        for folder in folders:
            try:
                # Use os.scandir for better performance during directory traversal
                with os.scandir(folder) as it:
                    for entry in it:
                        if entry.name == '.gitkeep' or not entry.is_file():
                            continue
                        try:
                            # entry.stat() is often cached on modern OSs during scandir
                            if entry.stat().st_mtime < now - max_age:
                                os.remove(entry.path)
                                logger.info("Cleaned up old file: %s", entry.path)
                        except Exception as e:
                            logger.error("Error cleaning up file %s: %s", entry.path, e)
            except Exception as e:
                logger.error("Error scanning folder %s: %s", folder, e)


def is_rate_limited(ip):
    """Return True if *ip* has already made RATE_LIMIT_LIMIT uploads in the window.

    When the request is allowed, its timestamp is recorded. On roughly 1% of
    allowed requests this also prunes expired IPs from memory and starts a
    background cleanup_old_files() thread (if one is not already running).
    """
    now = time.time()
    with rate_limit_lock:
        # Get and filter timestamps for this IP
        timestamps = upload_counts.get(ip, [])
        if timestamps:
            timestamps = [t for t in timestamps if now - t < RATE_LIMIT_WINDOW]

        if len(timestamps) >= RATE_LIMIT_LIMIT:
            upload_counts[ip] = timestamps
            return True

        # Add current timestamp and update the record
        timestamps.append(now)
        upload_counts[ip] = timestamps

        # Periodically (1% of requests) clean up to prevent memory and disk exhaustion.
        # The standard random module is fine here: this is not security-sensitive.
        if random.random() < 0.01:
            # 1. Clean up rate limiter memory
            expired_ips = [
                k for k, v in upload_counts.items()
                if not [t for t in v if now - t < RATE_LIMIT_WINDOW]
            ]
            for e_ip in expired_ips:
                del upload_counts[e_ip]

            # 2. Clean up old files from disk (in background)
            if not cleanup_lock.locked():
                threading.Thread(target=cleanup_old_files, daemon=True).start()

        return False


def load_model():
    """Load the ONNX model into the module-level `session` and cache its I/O names.

    Returns:
        True on success; False if the model file is missing or loading fails
        (the error is logged).
    """
    global session, input_name_cached, output_name_cached
    if not os.path.exists(MODEL_PATH):
        logger.warning("MISSING MODEL FILE: %s", MODEL_PATH)
        return False

    try:
        logger.info("Loading DDColor ONNX model...")

        # Configure session options for performance
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Enable memory pattern optimization and CPU memory arena for faster inference and lower fragmentation
        sess_options.enable_mem_pattern = True
        sess_options.enable_cpu_mem_arena = True

        # Use CPU provider for consistency on HF Spaces
        session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=sess_options,
            providers=['CPUExecutionProvider']
        )
        # Cache input and output names to avoid redundant lookups during processing
        input_name_cached = session.get_inputs()[0].name
        output_name_cached = session.get_outputs()[0].name

        logger.info("Model loaded successfully.")
        return True
    except Exception:
        logger.exception("Error loading model")
        return False


# Load model on startup
model_loaded = load_model()


@app.before_request
def generate_nonce():
    """Create a per-request CSP nonce, used by templates and add_security_headers()."""
    g.nonce = secrets.token_hex(16)


@app.route('/')
def index():
    """Render the main page, passing the CSP nonce to the template."""
    return render_template('index.html', nonce=g.nonce)


@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept an image upload, colorize it and return JSON with the result URLs.

    Responses:
        200 with original_url / colorized_url / colorized_filename on success;
        400 for a missing, disallowed or spoofed file; 429 when rate limited;
        500 if the model is unavailable or processing fails.
    """
    global model_loaded

    # Rate limit by IP
    # Note: We use remote_addr because trusting X-Forwarded-For without
    # a configured proxy is a security risk (spoofing).
    client_ip = request.remote_addr
    if is_rate_limited(client_ip):
        return jsonify({'error': 'Rate limit exceeded. Please wait a minute.'}), 429

    if not model_loaded:
        with model_lock:
            # Check again inside the lock so only one request loads the model
            if not model_loaded:
                if not load_model():
                    return jsonify({'error': 'Server model is not ready.'}), 500
                model_loaded = True

    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if file and allowed_file(file.filename):
        ext = file.filename.rsplit('.', 1)[1].lower()

        # Security enhancement: Verify magic bytes to prevent spoofed extensions
        if not is_valid_image_content(file.stream, ext):
            return jsonify({'error': 'Invalid image content for the given extension.'}), 400

        # Never trust the client filename; store under a random name.
        unique_id = uuid.uuid4().hex
        filename = f"{unique_id}.{ext}"
        filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
        file.save(filepath)

        try:
            colorized_filename = process_image(filepath, filename)
            return jsonify({
                'original_url': f"/static/uploads/{filename}",
                'colorized_url': f"/static/results/{colorized_filename}",
                'colorized_filename': colorized_filename
            })
        except Exception:
            logger.exception("Error processing image")
            # Clean up the uploaded file if processing fails
            if os.path.exists(filepath):
                os.remove(filepath)
            return jsonify({'error': 'An error occurred during image processing'}), 500
    else:
        return jsonify({'error': 'File type not allowed. Supported: PNG, JPG, JPEG, WEBP'}), 400


def process_image(path, filename):
    """Colorize the image at *path* and write it to RESULTS_FOLDER.

    Args:
        path: Filesystem path of the uploaded image.
        filename: Stored upload name; the result is saved as ``colorized_<filename>``.

    Returns:
        The result filename (not the full path).

    Raises:
        ValueError: If the image cannot be read or exceeds MAX_IMAGE_PIXELS.
        OSError: If the result image could not be written.
    """
    # Load image with OpenCV (BGR format)
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise ValueError("Could not read image")

    orig_height, orig_width = img_bgr.shape[:2]

    # Explicit limit check (OpenCV's own decode-time limit may have been raised via the environment).
    if orig_height * orig_width > MAX_IMAGE_PIXELS:
        raise ValueError(f"Image dimensions too large: {orig_width}x{orig_height}")

    # Preprocess in one call: resize to 512x512, swap BGR->RGB, scale to [0, 1]
    # and lay out as [1, 3, H, W]. The original author measured this at ~5x
    # faster than a separate resize + blob conversion for large images.
    img_input = cv2.dnn.blobFromImage(img_bgr, 1.0 / 255.0, (512, 512), (0, 0, 0), swapRB=True, crop=False)

    # Run inference using cached input/output names
    outputs = session.run([output_name_cached], {input_name_cached: img_input})

    # Postprocess: [1, 3, H, W] -> [3, H, W] (channel-first, RGB)
    output = outputs[0][0]

    # Scale and clip in place while still channel-first (CHW). The `out=`
    # argument avoids allocating temporaries.
    np.multiply(output, 255.0, out=output)
    np.clip(output, 0, 255, out=output)
    output_uint8 = output.astype(np.uint8)

    # cv2.merge does RGB->BGR and CHW->HWC in one step; doing it before the
    # upscale keeps the work on the small 512x512 image.
    res_img_bgr_small = cv2.merge([output_uint8[2], output_uint8[1], output_uint8[0]])

    # Resize back to original dimensions. INTER_LINEAR is used over INTER_CUBIC
    # for speed.
    if orig_width == 512 and orig_height == 512:
        res_img_bgr = res_img_bgr_small
    else:
        res_img_bgr = cv2.resize(res_img_bgr_small, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR)

    result_filename = f"colorized_{filename}"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)

    # Lower JPEG quality and faster PNG compression speed up encoding and
    # reduce file size.
    params = []
    if result_filename.lower().endswith(('.jpg', '.jpeg')):
        params = [cv2.IMWRITE_JPEG_QUALITY, 90]
    elif result_filename.lower().endswith('.png'):
        params = [cv2.IMWRITE_PNG_COMPRESSION, 1]

    # cv2.imwrite signals failure (e.g. missing encoder, unwritable dir) by
    # returning False rather than raising; don't hand out a URL to a missing file.
    if not cv2.imwrite(result_path, res_img_bgr, params):
        raise OSError(f"Failed to write result image: {result_path}")
    return result_filename


@app.route('/download/<filename>')
def download_file(filename):
    """Send a colorized result file as an attachment.

    Only names that exactly match RESULT_FILENAME_RE are served, which rules
    out path traversal and access to anything but generated results.
    """
    if not RESULT_FILENAME_RE.fullmatch(filename):
        return jsonify({'error': 'Invalid filename format'}), 400

    return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)


@app.after_request
def add_security_headers(response):
    """Attach hardening headers, including a nonce-based CSP, to every response."""
    # g.nonce is normally set by generate_nonce(); fall back to a fresh one so a
    # missing nonce can never turn a response into an AttributeError.
    nonce = getattr(g, 'nonce', None) or secrets.token_hex(16)

    response.headers['X-Content-Type-Options'] = 'nosniff'
    response.headers['X-Frame-Options'] = 'SAMEORIGIN'
    response.headers['X-XSS-Protection'] = '1; mode=block'
    response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
    response.headers['Permissions-Policy'] = 'camera=(), microphone=(), geolocation=(), usb=()'
    response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
    response.headers['Server'] = ''

    # Content Security Policy: restrict resources to trusted sources.
    # object-src 'none', base-uri 'none', form-action 'self' and frame-ancestors 'none' are added for hardening.
    # The nonce in script-src and style-src allows legitimate inline scripts/styles while preventing XSS.
    # Note: 'unsafe-inline' is kept for style-src because the Tailwind Play CDN dynamically injects styles.
    csp = (
        "default-src 'none'; "
        f"script-src 'self' 'nonce-{nonce}' cdn.tailwindcss.com; "
        f"style-src 'self' 'unsafe-inline' 'nonce-{nonce}' cdn.tailwindcss.com fonts.googleapis.com; "
        "img-src 'self' data:; "
        "font-src 'self' fonts.gstatic.com; "
        "connect-src 'self'; "
        "object-src 'none'; "
        "base-uri 'none'; "
        "form-action 'self'; "
        "frame-ancestors 'none';"
    )
    response.headers['Content-Security-Policy'] = csp
    return response


if __name__ == "__main__":
    # Use the port Hugging Face or Railway provides, or default to 7860
    port = int(os.environ.get("PORT", 7860))
    app.run(host="0.0.0.0", port=port)