# imageup / app.py
# GitHub Actions
# Deploy to Hugging Face Spaces
# 9a22bb5
import os
import re
import cv2
import numpy as np
import random
import secrets
from flask import Flask, render_template, request, send_from_directory, jsonify, g
import uuid
import logging
import onnxruntime as ort
import time
import threading
from collections import defaultdict
# Configure basic logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = Flask(__name__)
# Uploaded originals and colorized outputs are both written under static/,
# which is why upload_file can return /static/... URLs for them.
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULTS_FOLDER'] = 'static/results'
# Directory that holds the ONNX model file (see MODEL_PATH).
app.config['MODEL_FOLDER'] = 'models'
# Limit upload size to 16MB to prevent DoS
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024
# Extensions accepted by allowed_file; the comparison there is done in lowercase.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}
# Limit image dimensions to 25MP to prevent memory exhaustion (DoS).
# process_image rejects any image whose width * height exceeds this value.
MAX_IMAGE_PIXELS = 25 * 1000 * 1000
def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The extension is the text after the last dot, compared case-insensitively.
    A name without any dot is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS
def is_valid_image_content(file_stream, ext):
    """Check the leading magic bytes of *file_stream* against the format named by *ext*.

    Reads the first 12 bytes, then rewinds the stream to offset 0 so the
    caller can still consume it from the start.

    Returns True only when the bytes match the signature for ``jpg``/``jpeg``,
    ``png`` or ``webp``. An empty stream or any other extension yields False.
    """
    magic = file_stream.read(12)
    file_stream.seek(0)  # rewind so later consumers see the whole file

    if not magic:
        return False

    if ext == 'png':
        return magic.startswith(b'\x89PNG\r\n\x1a\n')
    if ext == 'webp':
        # Bytes 4-7 are deliberately not checked; only the RIFF tag and the
        # WEBP tag at offset 8 are compared.
        return magic.startswith(b'RIFF') and magic[8:12] == b'WEBP'
    if ext in ('jpg', 'jpeg'):
        return magic.startswith(b'\xff\xd8\xff')
    return False
@app.errorhandler(413)
def request_entity_too_large(error):
    """Render HTTP 413 errors as a JSON error payload."""
    payload = {'error': 'File is too large. Maximum size is 16MB.'}
    return jsonify(payload), 413
# Ensure directories exist
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)
# Model paths
MODEL_PATH = os.path.join(app.config['MODEL_FOLDER'], 'ddcolor_modelserver.onnx')
# Inference session and its first input/output tensor names; all three are
# populated by load_model and read by process_image.
session = None
input_name_cached = None
output_name_cached = None
# model_lock serializes the lazy model load performed in upload_file.
model_lock = threading.Lock()
# rate_limit_lock guards every read and write of upload_counts.
rate_limit_lock = threading.Lock()
# cleanup_lock is held by cleanup_old_files for the duration of a sweep.
cleanup_lock = threading.Lock()
# Simple in-memory rate limiting: maps client IP -> list of upload timestamps
# (time.time() values). State is per process and is lost on restart.
upload_counts = defaultdict(list)
RATE_LIMIT_LIMIT = 10  # maximum uploads per IP within one window
RATE_LIMIT_WINDOW = 60 # seconds
def cleanup_old_files():
    """Remove files older than 12 hours from the upload and results folders.

    The whole sweep runs while holding ``cleanup_lock``. Entries named
    ``.gitkeep`` and anything that is not a regular file are skipped.
    Errors are logged rather than raised: a failure on one file moves on to
    the next file, and a failure while scanning a folder moves on to the
    next folder.
    """
    with cleanup_lock:
        # Anything last modified before this instant is considered stale.
        cutoff = time.time() - 12 * 3600
        targets = (app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER'])
        for folder in targets:
            try:
                with os.scandir(folder) as entries:
                    for entry in entries:
                        if entry.name != '.gitkeep' and entry.is_file():
                            try:
                                if entry.stat().st_mtime < cutoff:
                                    os.remove(entry.path)
                                    logger.info(f"Cleaned up old file: {entry.path}")
                            except Exception as e:
                                logger.error(f"Error cleaning up file {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning folder {folder}: {e}")
def is_rate_limited(ip):
    """Record an upload attempt for *ip* and report whether it is over the limit.

    Returns True when *ip* already has RATE_LIMIT_LIMIT timestamps inside the
    last RATE_LIMIT_WINDOW seconds; the rejected attempt is not recorded.
    Otherwise the current time is recorded and False is returned.

    Roughly 1% of accepted calls also prune IPs with no recent activity and,
    if no sweep is already holding ``cleanup_lock``, start a background
    thread running ``cleanup_old_files``.
    """
    now = time.time()
    with rate_limit_lock:
        # Keep only the timestamps that still fall inside the window.
        recent = [
            stamp for stamp in upload_counts.get(ip, ())
            if now - stamp < RATE_LIMIT_WINDOW
        ]
        upload_counts[ip] = recent

        if len(recent) >= RATE_LIMIT_LIMIT:
            return True

        # Same list object as stored above, so the record is updated too.
        recent.append(now)

        # Occasional housekeeping; the non-cryptographic RNG is sufficient here.
        if random.random() < 0.01:
            # 1. Drop IPs whose every timestamp has aged out of the window.
            stale_ips = [
                addr for addr, stamps in upload_counts.items()
                if not any(now - stamp < RATE_LIMIT_WINDOW for stamp in stamps)
            ]
            for addr in stale_ips:
                del upload_counts[addr]

            # 2. Sweep old files from disk without blocking this request.
            if not cleanup_lock.locked():
                threading.Thread(target=cleanup_old_files, daemon=True).start()

        return False
def load_model():
    """Create the ONNX Runtime session for the model at MODEL_PATH.

    On success the module globals ``session``, ``input_name_cached`` and
    ``output_name_cached`` are set and True is returned. Returns False (after
    logging) when the model file is missing or when any step of loading
    raises.
    """
    global session, input_name_cached, output_name_cached

    model_present = os.path.exists(MODEL_PATH)
    if not model_present:
        logger.warning(f"MISSING MODEL FILE: {MODEL_PATH}")
        return False

    try:
        logger.info("Loading DDColor ONNX model...")

        # Session options: memory-pattern and CPU arena reuse on, all graph
        # optimizations enabled.
        opts = ort.SessionOptions()
        opts.enable_mem_pattern = True
        opts.enable_cpu_mem_arena = True
        opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Only the CPU execution provider is requested.
        session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=opts,
            providers=['CPUExecutionProvider'],
        )

        # Remember the first input/output tensor names so process_image does
        # not have to look them up on every request.
        input_name_cached = session.get_inputs()[0].name
        output_name_cached = session.get_outputs()[0].name

        logger.info("Model loaded successfully.")
        return True
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
# Load model on startup. model_loaded is False when the load failed or the
# model file was missing; upload_file retries the load in that case.
model_loaded = load_model()
@app.before_request
def generate_nonce():
    """Store a fresh random nonce (32 hex characters) on ``g`` for this request."""
    nonce = secrets.token_hex(16)
    g.nonce = nonce
@app.route('/')
def index():
    """Render index.html, passing the per-request nonce to the template."""
    template_context = {'nonce': g.nonce}
    return render_template('index.html', **template_context)
@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept an image upload, colorize it, and return URLs for both images.

    Responses (all JSON):
      * 429 - the client IP exceeded the upload rate limit.
      * 500 - the model could not be loaded, or processing the image failed.
      * 400 - no ``file`` part, empty filename, disallowed extension, or file
        content that does not match the claimed extension.
      * 200 - ``original_url``, ``colorized_url`` and ``colorized_filename``.

    The upload is stored under a freshly generated UUID name, so the
    client-supplied filename is never used as a path. If processing fails the
    stored upload is removed again on a best-effort basis.
    """
    global model_loaded

    # Rate limit by IP.
    # Note: We use remote_addr because trusting X-Forwarded-For without
    # a configured proxy is a security risk (spoofing).
    client_ip = request.remote_addr
    if is_rate_limited(client_ip):
        return jsonify({'error': 'Rate limit exceeded. Please wait a minute.'}), 429

    # Lazy retry of the startup load; the flag is re-checked under the lock so
    # concurrent requests do not each load the model.
    if not model_loaded:
        with model_lock:
            if not model_loaded:
                if not load_model():
                    return jsonify({'error': 'Server model is not ready.'}), 500
                model_loaded = True

    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not (file and allowed_file(file.filename)):
        return jsonify({'error': 'File type not allowed. Supported: PNG, JPG, JPEG, WEBP'}), 400

    ext = file.filename.rsplit('.', 1)[1].lower()

    # Verify magic bytes so a renamed non-image cannot pass on extension alone.
    if not is_valid_image_content(file.stream, ext):
        return jsonify({'error': 'Invalid image content for the given extension.'}), 400

    unique_id = uuid.uuid4().hex
    filename = f"{unique_id}.{ext}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)
    file.save(filepath)

    try:
        colorized_filename = process_image(filepath, filename)
    except Exception as e:
        # logger.exception keeps the traceback, which a plain error() call drops.
        logger.exception("Error processing image: %s", e)
        # Best-effort cleanup: a failure here must not replace the JSON error
        # response below with an unhandled exception.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning(f"Could not remove failed upload {filepath}: {cleanup_error}")
        return jsonify({'error': 'An error occurred during image processing'}), 500

    return jsonify({
        'original_url': f"/static/uploads/{filename}",
        'colorized_url': f"/static/results/{colorized_filename}",
        'colorized_filename': colorized_filename
    })
def process_image(path, filename):
    """Colorize the image at *path* and write the result to the results folder.

    Args:
        path: Filesystem path of the uploaded image.
        filename: Name of the upload; the result is saved as
            ``colorized_<filename>`` and so keeps the same extension.

    Returns:
        The result file name (not a full path).

    Raises:
        ValueError: The image cannot be decoded, or width * height exceeds
            MAX_IMAGE_PIXELS.
        OSError: The result image could not be written.
    """
    # Load image with OpenCV (BGR format)
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise ValueError("Could not read image")

    orig_height, orig_width = img_bgr.shape[:2]

    # NOTE(review): this check runs after cv2.imread has already decoded the
    # whole image, so it bounds the work done below, not the decode itself.
    if orig_height * orig_width > MAX_IMAGE_PIXELS:
        raise ValueError(f"Image dimensions too large: {orig_width}x{orig_height}")

    # Preprocess in one call: resize to 512x512 (no crop), swap BGR->RGB,
    # scale by 1/255 and lay out as a [1, 3, H, W] blob.
    img_input = cv2.dnn.blobFromImage(img_bgr, 1.0/255.0, (512, 512), (0, 0, 0), swapRB=True, crop=False)

    # Run inference using the tensor names cached by load_model.
    outputs = session.run([output_name_cached], {input_name_cached: img_input})

    # Postprocess: take the first item of the batch, still channel-first.
    # Assumes the model emits RGB values in [0, 1] - TODO confirm against the model.
    output = outputs[0][0]

    # Scale to [0, 255] and clip in place to avoid extra temporaries.
    np.multiply(output, 255.0, out=output)
    np.clip(output, 0, 255, out=output)
    output_uint8 = output.astype(np.uint8)

    # Merging the planes in reverse order converts RGB->BGR and CHW->HWC at
    # once; doing it before the upscale keeps the array small.
    res_img_bgr_small = cv2.merge([output_uint8[2], output_uint8[1], output_uint8[0]])

    # Resize back to the original dimensions unless they already match.
    if orig_width == 512 and orig_height == 512:
        res_img_bgr = res_img_bgr_small
    else:
        res_img_bgr = cv2.resize(res_img_bgr_small, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR)

    result_filename = f"colorized_{filename}"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)

    # Encoder settings: JPEG quality 90, PNG compression level 1 (fast);
    # other formats use OpenCV's defaults.
    params = []
    lowered_name = result_filename.lower()
    if lowered_name.endswith(('.jpg', '.jpeg')):
        params = [cv2.IMWRITE_JPEG_QUALITY, 90]
    elif lowered_name.endswith('.png'):
        params = [cv2.IMWRITE_PNG_COMPRESSION, 1]

    # cv2.imwrite reports failure through its return value instead of raising;
    # without this check the caller would be given a name for a file that
    # was never written.
    if not cv2.imwrite(result_path, res_img_bgr, params):
        raise OSError(f"Could not write result image: {result_path}")

    return result_filename
@app.route('/download/<filename>')
def download_file(filename):
    """Send a colorized result as an attachment.

    Only names of the exact form ``colorized_<32 lowercase hex chars>.<ext>``
    with ext in png/jpg/jpeg/webp are accepted; anything else gets a JSON 400.
    """
    # fullmatch is used instead of match with '$' because '$' also matches
    # just before a trailing newline, which would let "name.png\n" through.
    if not re.fullmatch(r'colorized_[a-f0-9]{32}\.(png|jpg|jpeg|webp)', filename):
        return jsonify({'error': 'Invalid filename format'}), 400

    return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)
@app.after_request
def add_security_headers(response):
    """Attach the fixed security headers and a nonce-based CSP to *response*.

    The nonce is read from ``g.nonce``, which generate_nonce sets at the start
    of each request. Returns the same response object.
    """
    fixed_headers = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Referrer-Policy', 'strict-origin-when-cross-origin'),
        ('Permissions-Policy', 'camera=(), microphone=(), geolocation=(), usb=()'),
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        ('Server', ''),
    )
    for header_name, header_value in fixed_headers:
        response.headers[header_name] = header_value

    # Content Security Policy: everything is denied by default, then scripts,
    # styles, images, fonts and XHR are re-enabled for specific sources.
    # NOTE(review): browsers implementing CSP Level 2+ ignore 'unsafe-inline'
    # when a nonce is present in the same directive - confirm the styles
    # injected by the Tailwind CDN still apply.
    # NOTE(review): frame-ancestors 'none' is stricter than the
    # X-Frame-Options: SAMEORIGIN header set above - confirm which is intended.
    nonce_source = f"'nonce-{g.nonce}'"
    directives = [
        "default-src 'none'",
        f"script-src 'self' {nonce_source} cdn.tailwindcss.com",
        f"style-src 'self' 'unsafe-inline' {nonce_source} cdn.tailwindcss.com fonts.googleapis.com",
        "img-src 'self' data:",
        "font-src 'self' fonts.gstatic.com",
        "connect-src 'self'",
        "object-src 'none'",
        "base-uri 'none'",
        "form-action 'self'",
        "frame-ancestors 'none'",
    ]
    response.headers['Content-Security-Policy'] = "; ".join(directives) + ";"
    return response
if __name__ == "__main__":
    # The hosting platform (Hugging Face, Railway) supplies PORT; 7860 is the fallback.
    listen_port = int(os.environ.get("PORT", 7860))
    # Bind on all interfaces.
    app.run(host="0.0.0.0", port=listen_port)