File size: 12,826 Bytes
9a22bb5 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 | import os
import re
import cv2
import numpy as np
import random
import secrets
from flask import Flask, render_template, request, send_from_directory, jsonify, g
import uuid
import logging
import onnxruntime as ort
import time
import threading
from collections import defaultdict
# Process-wide logging at INFO level, plus this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.config.update(
    UPLOAD_FOLDER='static/uploads',
    RESULTS_FOLDER='static/results',
    MODEL_FOLDER='models',
    # Cap request bodies at 16 MB (oversized requests get the 413 handler) to limit DoS exposure.
    MAX_CONTENT_LENGTH=16 * 1024 * 1024,
)

# Extensions accepted by /upload; allowed_file compares them lower-cased.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}
# Largest accepted image in pixels (25 megapixels), to bound memory use.
MAX_IMAGE_PIXELS = 25_000_000
def allowed_file(filename):
    """Return True if *filename* ends in an extension listed in ALLOWED_EXTENSIONS.

    Only the text after the last dot is considered, case-insensitively.
    Names containing no dot at all are rejected.
    """
    _, dot, extension = filename.rpartition('.')
    return dot == '.' and extension.lower() in ALLOWED_EXTENSIONS
def is_valid_image_content(file_stream, ext):
    """Return True when the stream's leading bytes match the signature for *ext*.

    Reads at most 12 bytes, then rewinds the stream to position 0 so the
    caller can still save the complete file. Supported extensions are
    'jpg'/'jpeg', 'png' and 'webp'. Any other extension, or an empty
    stream, yields False.
    """
    header = file_stream.read(12)
    file_stream.seek(0)  # rewind: the caller saves the stream afterwards
    if not header:
        return False

    jpeg_signature = b'\xff\xd8\xff'
    signature_checks = {
        'jpg': lambda head: head.startswith(jpeg_signature),
        'jpeg': lambda head: head.startswith(jpeg_signature),
        'png': lambda head: head.startswith(b'\x89PNG\r\n\x1a\n'),
        # RIFF container whose form type (bytes 8..11) is WEBP.
        'webp': lambda head: head[:4] == b'RIFF' and head[8:12] == b'WEBP',
    }
    check = signature_checks.get(ext)
    return check is not None and check(header)
@app.errorhandler(413)
def request_entity_too_large(error):
    """Return a JSON body for HTTP 413 instead of Flask's default HTML page.

    Registered for 413 responses, e.g. when a request exceeds
    MAX_CONTENT_LENGTH.
    """
    payload = {'error': 'File is too large. Maximum size is 16MB.'}
    return jsonify(payload), 413
# Create the upload/result folders up front so uploads and results can be written into them.
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)

# ONNX model location and inference state; load_model() fills these in.
MODEL_PATH = os.path.join(app.config['MODEL_FOLDER'], 'ddcolor_modelserver.onnx')
session = input_name_cached = output_name_cached = None

# Simple in-memory (per-process) rate limiting: at most RATE_LIMIT_LIMIT
# uploads per IP within RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_LIMIT = 10
RATE_LIMIT_WINDOW = 60  # seconds
upload_counts = defaultdict(list)  # ip -> time.time() stamps of accepted uploads

# Locks serializing model loading, rate-limit bookkeeping, and disk cleanup.
model_lock = threading.Lock()
rate_limit_lock = threading.Lock()
cleanup_lock = threading.Lock()
def cleanup_old_files():
    """Delete files in the upload and result folders older than 12 hours.

    Bounds disk usage. Holds ``cleanup_lock`` for the whole sweep, which
    serializes sweeps between threads of this process (is_rate_limited starts
    this function in a daemon thread). Errors are logged rather than raised,
    since nothing waits on this thread. ``.gitkeep`` and non-regular files
    are skipped.
    """
    with cleanup_lock:
        max_age = 12 * 3600  # 12 hours, in seconds
        cutoff = time.time() - max_age
        folders = (app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER'])
        for folder in folders:
            try:
                # os.scandir avoids a separate stat call for is_file() on most platforms.
                with os.scandir(folder) as entries:
                    for entry in entries:
                        if entry.name == '.gitkeep' or not entry.is_file():
                            continue
                        try:
                            if entry.stat().st_mtime < cutoff:
                                os.remove(entry.path)
                                logger.info("Cleaned up old file: %s", entry.path)
                        except FileNotFoundError:
                            # Already gone: cleanup_lock only serializes threads in this
                            # process, so another worker's sweep may have removed it first.
                            continue
                        except OSError as e:
                            logger.error("Error cleaning up file %s: %s", entry.path, e)
            except Exception as e:
                # Boundary of a background sweep: log and move on to the next folder.
                logger.error("Error scanning folder %s: %s", folder, e)
def is_rate_limited(ip):
    """Record an upload attempt from *ip* and report whether it must be rejected.

    Sliding-window log: ``upload_counts[ip]`` keeps the timestamps of the
    IP's accepted uploads. If RATE_LIMIT_LIMIT or more of them fall within
    the last RATE_LIMIT_WINDOW seconds, the attempt is rejected and not
    recorded.

    On about 1% of accepted attempts, housekeeping also runs:
    - IPs with no recent timestamps are dropped from the table.
    - A background disk sweep (cleanup_old_files) is started, unless
      cleanup_lock is currently held.

    Returns:
        True if the caller should reject the request, otherwise False.
    """
    now = time.time()
    with rate_limit_lock:
        recent = [
            stamp for stamp in upload_counts.get(ip, ())
            if now - stamp < RATE_LIMIT_WINDOW
        ]
        if len(recent) >= RATE_LIMIT_LIMIT:
            upload_counts[ip] = recent
            return True

        recent.append(now)
        upload_counts[ip] = recent

        # Housekeeping is sampled; the non-cryptographic `random` module is enough here.
        if random.random() >= 0.01:
            return False

        stale_ips = [
            addr for addr, stamps in upload_counts.items()
            if not any(now - stamp < RATE_LIMIT_WINDOW for stamp in stamps)
        ]
        for addr in stale_ips:
            del upload_counts[addr]

        if not cleanup_lock.locked():
            threading.Thread(target=cleanup_old_files, daemon=True).start()
    return False
def load_model():
    """Create the ONNX Runtime session for the DDColor model.

    On success, the module-level ``session``, ``input_name_cached`` and
    ``output_name_cached`` are set together, and True is returned. Returns
    False (after logging) when MODEL_PATH does not exist or the session
    cannot be built. In that case the existing globals are left unchanged,
    so a failed load never leaves a half-initialized session behind.
    """
    global session, input_name_cached, output_name_cached
    if not os.path.exists(MODEL_PATH):
        logger.warning("MISSING MODEL FILE: %s", MODEL_PATH)
        return False
    try:
        logger.info("Loading DDColor ONNX model...")
        sess_options = ort.SessionOptions()
        sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
        # Memory-pattern optimization and the CPU arena reduce per-run allocation work.
        sess_options.enable_mem_pattern = True
        sess_options.enable_cpu_mem_arena = True
        # CPU provider only, for consistent behavior on CPU-only hosts (e.g. HF Spaces).
        new_session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=sess_options,
            providers=['CPUExecutionProvider'],
        )
        # Cache the first input/output names so process_image needn't look them up per request.
        input_name = new_session.get_inputs()[0].name
        output_name = new_session.get_outputs()[0].name
    except Exception as e:
        # logger.exception keeps the traceback, which is essential to diagnose load failures.
        logger.exception("Error loading model: %s", e)
        return False
    # Publish only after every step succeeded.
    session, input_name_cached, output_name_cached = new_session, input_name, output_name
    logger.info("Model loaded successfully.")
    return True


# Load the model once at import time. If this fails, upload_file retries the
# load (under model_lock) on the next upload request.
model_loaded = load_model()
@app.before_request
def generate_nonce():
    """Store a fresh random CSP nonce on ``g`` at the start of every request.

    The value is 32 hex characters (16 random bytes from `secrets`). index()
    passes it to the template, and add_security_headers places it in the
    Content-Security-Policy header.
    """
    fresh_nonce = secrets.token_hex(16)
    g.nonce = fresh_nonce
@app.route('/')
def index():
    """Render the landing page, handing the per-request CSP nonce to the template."""
    nonce = g.nonce
    return render_template('index.html', nonce=nonce)
@app.route('/upload', methods=['POST'])
def upload_file():
    """Validate an uploaded image, colorize it, and return JSON with result URLs.

    JSON responses:
        200 -- ``original_url``, ``colorized_url`` and ``colorized_filename``.
        400 -- one of:
               - no file part or an empty filename;
               - a disallowed extension;
               - content whose magic bytes do not match the extension.
        429 -- the client IP exceeded the upload rate limit.
        500 -- the model could not be loaded, or saving/processing failed.
    """
    global model_loaded

    # Rate limit by IP. remote_addr is used because trusting X-Forwarded-For
    # without a configured proxy would let clients spoof their address.
    client_ip = request.remote_addr
    if is_rate_limited(client_ip):
        return jsonify({'error': 'Rate limit exceeded. Please wait a minute.'}), 429

    # Retry the model load if the import-time attempt failed. The flag is
    # re-checked under model_lock so concurrent requests load it only once.
    if not model_loaded:
        with model_lock:
            if not model_loaded:
                if not load_model():
                    return jsonify({'error': 'Server model is not ready.'}), 500
                model_loaded = True

    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400
    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'File type not allowed. Supported: PNG, JPG, JPEG, WEBP'}), 400

    ext = file.filename.rsplit('.', 1)[1].lower()
    # Verify magic bytes so a renamed file cannot slip past the extension check.
    if not is_valid_image_content(file.stream, ext):
        return jsonify({'error': 'Invalid image content for the given extension.'}), 400

    # Store under a random name; only the validated extension of the
    # client-supplied name is kept.
    filename = f"{uuid.uuid4().hex}.{ext}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    try:
        file.save(filepath)
        colorized_filename = process_image(filepath, filename)
    except Exception as e:
        logger.exception("Error processing image: %s", e)
        # Remove the (possibly partial) upload. A cleanup failure must not
        # replace the JSON error below with an unhandled exception.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass  # save() failed before the file was created
        except OSError as cleanup_error:
            logger.warning("Could not remove failed upload %s: %s", filepath, cleanup_error)
        return jsonify({'error': 'An error occurred during image processing'}), 500

    return jsonify({
        'original_url': f"/static/uploads/{filename}",
        'colorized_url': f"/static/results/{colorized_filename}",
        'colorized_filename': colorized_filename,
    })
def process_image(path, filename):
    """Colorize the image at *path* and save it to RESULTS_FOLDER.

    Expects load_model() to have succeeded (upload_file ensures this before
    calling), since it uses the module-level ``session`` and cached I/O names.

    Args:
        path: filesystem path of the stored upload.
        filename: the upload's stored name. The result is written as
            ``colorized_<filename>``, so it keeps the same extension and
            therefore the same image encoder.

    Returns:
        The result's filename, relative to RESULTS_FOLDER.

    Raises:
        ValueError: the image cannot be decoded or exceeds MAX_IMAGE_PIXELS.
        OSError: OpenCV reported that the result could not be written.
    """
    # OpenCV decodes to a 3-channel BGR array.
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise ValueError("Could not read image")

    orig_height, orig_width = img_bgr.shape[:2]
    # NOTE(review): cv2.imread has already decoded the full image by this point,
    # so this check does not stop a decompression bomb from being allocated.
    # OpenCV's own OPENCV_IO_MAX_IMAGE_PIXELS limit (believed to default to
    # 2**30 and to be read when cv2 loads — verify) is the only pre-decode guard.
    if orig_height * orig_width > MAX_IMAGE_PIXELS:
        raise ValueError(f"Image dimensions too large: {orig_width}x{orig_height}")

    # One call does all the preprocessing: resize to 512x512, BGR->RGB,
    # scale to [0, 1], and lay out as NCHW [1, 3, 512, 512].
    img_input = cv2.dnn.blobFromImage(
        img_bgr, 1.0 / 255.0, (512, 512), (0, 0, 0), swapRB=True, crop=False
    )
    outputs = session.run([output_name_cached], {input_name_cached: img_input})

    # First batch item of the first output, treated as channel-first RGB in [0, 1].
    output = outputs[0][0]
    # Scale and clip in place, then convert once to uint8.
    np.multiply(output, 255.0, out=output)
    np.clip(output, 0, 255, out=output)
    output_uint8 = output.astype(np.uint8)
    # cv2.merge does RGB->BGR and CHW->HWC in one step, at 512x512 before upscaling.
    res_img_bgr_small = cv2.merge([output_uint8[2], output_uint8[1], output_uint8[0]])

    # Back to the original size. INTER_LINEAR is used over INTER_CUBIC for speed.
    if (orig_width, orig_height) == (512, 512):
        res_img_bgr = res_img_bgr_small
    else:
        res_img_bgr = cv2.resize(
            res_img_bgr_small, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR
        )

    result_filename = f"colorized_{filename}"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)

    # JPEG quality 90 and PNG compression level 1 favour encoding speed and size.
    lowered = result_filename.lower()
    if lowered.endswith(('.jpg', '.jpeg')):
        params = [cv2.IMWRITE_JPEG_QUALITY, 90]
    elif lowered.endswith('.png'):
        params = [cv2.IMWRITE_PNG_COMPRESSION, 1]
    else:
        params = []

    # imwrite signals failure via its return value rather than raising. Without
    # this check the caller would hand out a URL to a file that does not exist.
    if not cv2.imwrite(result_path, res_img_bgr, params):
        raise OSError(f"Could not write colorized image to {result_path}")
    return result_filename
# Names produced by process_image for a uuid4().hex upload: colorized_<32 hex>.<ext>.
_RESULT_FILENAME_RE = re.compile(r'colorized_[a-f0-9]{32}\.(?:png|jpg|jpeg|webp)')


@app.route('/download/<filename>')
def download_file(filename):
    """Send a colorized result from RESULTS_FOLDER as an attachment.

    Only names matching the generated pattern are served, which rules out
    path traversal and access to anything other than result files.
    Returns 400 JSON for any other name.
    """
    # fullmatch anchors at the true end of the string. re.match with '$'
    # would also accept a trailing newline (e.g. a URL ending in %0A).
    if not _RESULT_FILENAME_RE.fullmatch(filename):
        return jsonify({'error': 'Invalid filename format'}), 400
    return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)
@app.after_request
def add_security_headers(response):
    """Attach hardening headers, including a nonce-based CSP, to every response.

    Reads ``g.nonce``, which generate_nonce sets before each request.
    Returns the same response object.
    """
    response.headers['X-Content-Type-Options'] = 'nosniff'
    # NOTE(review): browsers that support CSP frame-ancestors ignore this header,
    # and the CSP below uses frame-ancestors 'none' (stricter than SAMEORIGIN).
    response.headers['X-Frame-Options'] = 'SAMEORIGIN'
    # The legacy XSS auditor is gone from modern browsers and could itself be
    # abused in older ones. OWASP recommends disabling it explicitly and
    # relying on the CSP below.
    response.headers['X-XSS-Protection'] = '0'
    response.headers['Referrer-Policy'] = 'strict-origin-when-cross-origin'
    response.headers['Permissions-Policy'] = 'camera=(), microphone=(), geolocation=(), usb=()'
    response.headers['Strict-Transport-Security'] = 'max-age=31536000; includeSubDomains'
    response.headers['Server'] = ''

    # Content Security Policy: deny by default, then allow self plus the named CDNs.
    # Hardening directives: object-src 'none', base-uri 'none', form-action 'self',
    # frame-ancestors 'none'. The per-request nonce whitelists legitimate inline
    # <script>/<style> tags.
    # 'unsafe-inline' is kept in style-src for styles injected by the Tailwind Play CDN.
    # NOTE(review): per the CSP spec, nonce-aware browsers ignore 'unsafe-inline'
    # when a nonce is present in the same directive, so un-nonced injected styles
    # may still be blocked. Confirm in a browser; dropping the nonce from style-src
    # would make 'unsafe-inline' take effect.
    csp = (
        "default-src 'none'; "
        f"script-src 'self' 'nonce-{g.nonce}' cdn.tailwindcss.com; "
        f"style-src 'self' 'unsafe-inline' 'nonce-{g.nonce}' cdn.tailwindcss.com fonts.googleapis.com; "
        "img-src 'self' data:; "
        "font-src 'self' fonts.gstatic.com; "
        "connect-src 'self'; "
        "object-src 'none'; "
        "base-uri 'none'; "
        "form-action 'self'; "
        "frame-ancestors 'none';"
    )
    response.headers['Content-Security-Policy'] = csp
    return response
if __name__ == "__main__":
    # Serve on all interfaces. Hosts such as Hugging Face Spaces or Railway
    # supply PORT; otherwise fall back to 7860.
    app.run(host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
|