# File size: 12,826 Bytes
# 9a22bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
import os
import re
import cv2
import numpy as np
import random
import secrets
from flask import Flask, render_template, request, send_from_directory, jsonify, g
import uuid
import logging
import onnxruntime as ort
import time
import threading
from collections import defaultdict

# Configure basic logging
# `logger` is the module-level logger used by every function in this file.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Flask application plus the folders it works with. The folder values are
# relative paths; uploads and results live under static/ and are referenced
# by the /static/... URLs returned from the upload endpoint.
app = Flask(__name__)
app.config['UPLOAD_FOLDER'] = 'static/uploads'
app.config['RESULTS_FOLDER'] = 'static/results'
app.config['MODEL_FOLDER'] = 'models'
# Limit upload size to 16MB to prevent DoS
# (oversized requests are answered by the 413 error handler defined in this file)
app.config['MAX_CONTENT_LENGTH'] = 16 * 1024 * 1024

# Extensions accepted by allowed_file(); compared in lowercase.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg', 'webp'}
# Limit image dimensions to 25MP to prevent memory exhaustion (DoS)
# Compared against height * width in process_image().
MAX_IMAGE_PIXELS = 25 * 1000 * 1000

def allowed_file(filename):
    """Return True when *filename* ends in an extension from ALLOWED_EXTENSIONS.

    The extension is whatever follows the last dot, compared in lowercase.
    A name that contains no dot at all is rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[1]
    return extension.lower() in ALLOWED_EXTENSIONS

def is_valid_image_content(file_stream, ext):
    """Check that a stream's magic bytes agree with its claimed image extension.

    Reads the first 12 bytes of *file_stream* and then rewinds the stream to
    offset 0. Only 'jpg', 'jpeg', 'png' and 'webp' can validate; an empty
    stream or any other extension returns False.
    """
    magic = file_stream.read(12)
    # Rewind so whoever consumes the stream next starts at the beginning
    file_stream.seek(0)

    if not magic:
        return False

    if ext == 'png':
        return magic.startswith(b'\x89PNG\r\n\x1a\n')
    if ext == 'webp':
        # RIFF container whose bytes 8-11 carry the WEBP tag
        return magic.startswith(b'RIFF') and magic[8:12] == b'WEBP'
    if ext in ('jpg', 'jpeg'):
        return magic.startswith(b'\xff\xd8\xff')

    return False

@app.errorhandler(413)
def request_entity_too_large(error):
    """Answer HTTP 413 with a JSON error body instead of the default error page."""
    payload = {'error': 'File is too large. Maximum size is 16MB.'}
    return jsonify(payload), 413

# Ensure directories exist
# (done at import time; exist_ok makes repeated starts harmless)
os.makedirs(app.config['UPLOAD_FOLDER'], exist_ok=True)
os.makedirs(app.config['RESULTS_FOLDER'], exist_ok=True)

# Model paths
MODEL_PATH = os.path.join(app.config['MODEL_FOLDER'], 'ddcolor_modelserver.onnx')

# ONNX inference session and its cached tensor names. All three are None until
# load_model() succeeds and assigns them.
session = None
input_name_cached = None
output_name_cached = None
# model_lock:      taken by upload_file() around the lazy model (re)load
# rate_limit_lock: taken by is_rate_limited() around every access to upload_counts
# cleanup_lock:    held by cleanup_old_files() for the duration of a cleanup pass
model_lock = threading.Lock()
rate_limit_lock = threading.Lock()
cleanup_lock = threading.Lock()

# Simple in-memory rate limiting
# upload_counts maps a client IP to the time.time() stamps of its recent uploads.
upload_counts = defaultdict(list)
# At most RATE_LIMIT_LIMIT uploads per IP within any RATE_LIMIT_WINDOW seconds.
RATE_LIMIT_LIMIT = 10
RATE_LIMIT_WINDOW = 60 # seconds

def cleanup_old_files():
    """Remove stale files from the upload and results folders.

    A file is stale when its modification time is more than 12 hours in the
    past. Entries named '.gitkeep' and anything that is not a regular file are
    left alone. Failures are logged and never propagate: a file that cannot be
    removed is skipped, and a folder that cannot be scanned does not stop the
    other folder from being processed. The whole pass runs under cleanup_lock.
    """
    with cleanup_lock:
        # Anything last modified before this instant is 12+ hours old
        cutoff = time.time() - 12 * 3600

        for folder in (app.config['UPLOAD_FOLDER'], app.config['RESULTS_FOLDER']):
            try:
                # os.scandir yields entries lazily and is closed by the with-block
                with os.scandir(folder) as entries:
                    for entry in entries:
                        if entry.name == '.gitkeep':
                            continue
                        if not entry.is_file():
                            continue
                        try:
                            is_stale = entry.stat().st_mtime < cutoff
                            if is_stale:
                                os.remove(entry.path)
                                logger.info(f"Cleaned up old file: {entry.path}")
                        except Exception as e:
                            logger.error(f"Error cleaning up file {entry.path}: {e}")
            except Exception as e:
                logger.error(f"Error scanning folder {folder}: {e}")

def is_rate_limited(ip):
    """Record an upload attempt for *ip* and report whether it is over the limit.

    Returns True when the IP already has RATE_LIMIT_LIMIT uploads inside the
    last RATE_LIMIT_WINDOW seconds (the attempt is then not recorded).
    Otherwise the attempt is recorded and False is returned.

    On roughly 1% of the recorded attempts this also drops IPs that have no
    upload left inside the window and, unless a cleanup pass is already
    running, starts cleanup_old_files() on a daemon thread.
    """
    now = time.time()

    with rate_limit_lock:
        # Keep only this IP's uploads that still fall inside the window
        recent = [t for t in upload_counts.get(ip, []) if now - t < RATE_LIMIT_WINDOW]
        upload_counts[ip] = recent

        if len(recent) >= RATE_LIMIT_LIMIT:
            return True

        # Record this attempt; `recent` is the very list stored above
        recent.append(now)

        # Occasional housekeeping, picked with the cheap non-cryptographic RNG
        if random.random() < 0.01:
            # 1. Forget IPs with nothing left inside the window
            stale_ips = [
                addr for addr, stamps in upload_counts.items()
                if not any(now - t < RATE_LIMIT_WINDOW for t in stamps)
            ]
            for addr in stale_ips:
                del upload_counts[addr]

            # 2. Purge old files from disk without blocking this request
            if not cleanup_lock.locked():
                threading.Thread(target=cleanup_old_files, daemon=True).start()

    return False

def load_model():
    """Create the ONNX Runtime session for the model at MODEL_PATH.

    On success the module globals `session`, `input_name_cached` and
    `output_name_cached` are set and True is returned. Returns False, after
    logging, when the model file does not exist or when anything raises while
    the session is being created.
    """
    global session, input_name_cached, output_name_cached

    if not os.path.exists(MODEL_PATH):
        logger.warning(f"MISSING MODEL FILE: {MODEL_PATH}")
        return False

    try:
        logger.info("Loading DDColor ONNX model...")

        # Session tuning: memory arena + memory pattern reuse, and every graph
        # optimization ONNX Runtime offers
        options = ort.SessionOptions()
        options.enable_cpu_mem_arena = True
        options.enable_mem_pattern = True
        options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

        # Use CPU provider for consistency on HF Spaces
        session = ort.InferenceSession(
            MODEL_PATH,
            sess_options=options,
            providers=['CPUExecutionProvider'],
        )

        # Look the tensor names up once; process_image() reuses them on every call
        first_input = session.get_inputs()[0]
        input_name_cached = first_input.name
        first_output = session.get_outputs()[0]
        output_name_cached = first_output.name
    except Exception as e:
        logger.error(f"Error loading model: {e}")
        return False
    else:
        logger.info("Model loaded successfully.")
        return True

# Load model on startup
# model_loaded stays False when loading fails; upload_file() then retries the
# load under model_lock on the next upload request.
model_loaded = load_model()

@app.before_request
def generate_nonce():
    """Store a fresh random nonce (32 hex characters) on flask.g for this request."""
    nonce = secrets.token_hex(16)
    g.nonce = nonce

@app.route('/')
def index():
    """Render index.html, passing the per-request nonce to the template."""
    page = render_template('index.html', nonce=g.nonce)
    return page

@app.route('/upload', methods=['POST'])
def upload_file():
    """Accept an image upload, colorize it, and return JSON with both image URLs.

    Responses (all JSON):
        200: ``original_url``, ``colorized_url`` and ``colorized_filename``.
        400: no ``file`` part, empty filename, extension not allowed, or file
             content whose magic bytes do not match the extension.
        429: the client IP exceeded the upload rate limit.
        500: the model could not be loaded, or saving/processing the image failed.
    """
    global model_loaded

    # Rate limit by IP
    # Note: We use remote_addr because trusting X-Forwarded-For without
    # a configured proxy is a security risk (spoofing).
    client_ip = request.remote_addr
    if is_rate_limited(client_ip):
        return jsonify({'error': 'Rate limit exceeded. Please wait a minute.'}), 429

    # Lazy retry of the model load when the startup attempt failed
    if not model_loaded:
        with model_lock:
            # Check again inside the lock
            if not model_loaded:
                if not load_model():
                    return jsonify({'error': 'Server model is not ready.'}), 500
                model_loaded = True

    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    if not file or not allowed_file(file.filename):
        return jsonify({'error': 'File type not allowed. Supported: PNG, JPG, JPEG, WEBP'}), 400

    ext = file.filename.rsplit('.', 1)[1].lower()

    # Security enhancement: Verify magic bytes to prevent spoofed extensions
    if not is_valid_image_content(file.stream, ext):
        return jsonify({'error': 'Invalid image content for the given extension.'}), 400

    # The client-supplied name is never used on disk; only its extension survives
    unique_id = uuid.uuid4().hex
    filename = f"{unique_id}.{ext}"
    filepath = os.path.join(app.config['UPLOAD_FOLDER'], filename)

    try:
        # Saving is inside the try so a failed write is reported as JSON too
        # and any partially written file is removed below
        file.save(filepath)
        colorized_filename = process_image(filepath, filename)
    except Exception as e:
        # logger.exception keeps the traceback, which logger.error dropped
        logger.exception("Error processing image: %s", e)
        # Clean up the uploaded file if processing fails. The file may already
        # be gone (never written, or removed by the background cleanup), and a
        # cleanup failure must not replace the JSON error response.
        try:
            os.remove(filepath)
        except FileNotFoundError:
            pass
        except OSError as cleanup_error:
            logger.warning("Could not remove failed upload %s: %s", filepath, cleanup_error)
        return jsonify({'error': 'An error occurred during image processing'}), 500

    return jsonify({
        'original_url': f"/static/uploads/{filename}",
        'colorized_url': f"/static/results/{colorized_filename}",
        'colorized_filename': colorized_filename
    })

def process_image(path, filename):
    """Colorize the image stored at *path* and write the result to the results folder.

    Args:
        path: Location of the uploaded image on disk.
        filename: Base name of the upload; the result is stored as
            ``colorized_<filename>``, so the output format follows its extension.

    Returns:
        The result file name (not the full path).

    Raises:
        ValueError: The image cannot be decoded, or it has more than
            MAX_IMAGE_PIXELS pixels.
        OSError: OpenCV reported that the result image was not written.
    """
    # Load image with OpenCV (BGR format)
    img_bgr = cv2.imread(path)
    if img_bgr is None:
        raise ValueError("Could not read image")

    orig_height, orig_width = img_bgr.shape[:2]

    # Check if image dimensions exceed the maximum allowed pixels
    # NOTE(review): the image is already fully decoded by cv2.imread at this
    # point, so this check bounds the inference/resize work, not decode memory.
    if orig_height * orig_width > MAX_IMAGE_PIXELS:
        raise ValueError(f"Image dimensions too large: {orig_width}x{orig_height}")

    # Preprocess: Use cv2.dnn.blobFromImage to resize to 512x512, convert to RGB,
    # normalize to [0, 1], and change to [1, 3, H, W] format in one step.
    img_input = cv2.dnn.blobFromImage(img_bgr, 1.0/255.0, (512, 512), (0, 0, 0), swapRB=True, crop=False)

    # Run inference using the input/output names cached by load_model()
    outputs = session.run([output_name_cached], {input_name_cached: img_input})

    # Postprocess: first item of the batch, channel-first (CHW)
    output = outputs[0][0]

    # Scale to 0..255 and clip in place (out=output avoids temporaries) while
    # the data is still channel-first.
    np.multiply(output, 255.0, out=output)
    np.clip(output, 0, 255, out=output)
    output_uint8 = output.astype(np.uint8)

    # Convert RGB to BGR and CHW to HWC in one step with cv2.merge.
    # Doing this before upscaling works on the small model-sized image.
    res_img_bgr_small = cv2.merge([output_uint8[2], output_uint8[1], output_uint8[0]])

    # Resize back to original dimensions (skipped when the upload already was 512x512).
    if orig_width == 512 and orig_height == 512:
        res_img_bgr = res_img_bgr_small
    else:
        res_img_bgr = cv2.resize(res_img_bgr_small, (orig_width, orig_height), interpolation=cv2.INTER_LINEAR)

    result_filename = f"colorized_{filename}"
    result_path = os.path.join(app.config['RESULTS_FOLDER'], result_filename)

    # Optimization: Use lower JPEG quality and faster PNG compression to speed up encoding
    # and reduce file size.
    params = []
    if result_filename.lower().endswith(('.jpg', '.jpeg')):
        params = [cv2.IMWRITE_JPEG_QUALITY, 90]
    elif result_filename.lower().endswith('.png'):
        params = [cv2.IMWRITE_PNG_COMPRESSION, 1]

    # cv2.imwrite signals failure through its return value. Without this check
    # the caller would hand out a URL for a file that was never written.
    if not cv2.imwrite(result_path, res_img_bgr, params):
        raise OSError(f"Could not write result image: {result_path}")

    return result_filename

@app.route('/download/<filename>')
def download_file(filename):
    """Send a colorized result as an attachment.

    Only names of the exact form ``colorized_<32 lowercase hex chars>.<ext>``
    with ext in png/jpg/jpeg/webp are accepted; anything else gets a JSON 400.
    """
    # Strict filename validation to prevent path traversal and unauthorized access.
    # re.fullmatch is used instead of re.match with '$' because '$' also matches
    # just before a trailing newline, which let "<valid name>\n" pass the check.
    if not re.fullmatch(r'colorized_[a-f0-9]{32}\.(png|jpg|jpeg|webp)', filename):
        return jsonify({'error': 'Invalid filename format'}), 400
    return send_from_directory(app.config['RESULTS_FOLDER'], filename, as_attachment=True)

@app.after_request
def add_security_headers(response):
    """Attach the fixed security headers and a nonce-based CSP to *response*.

    Reads the nonce that generate_nonce() stored on flask.g and returns the
    same response object it was given.
    """
    fixed_headers = (
        ('X-Content-Type-Options', 'nosniff'),
        ('X-Frame-Options', 'SAMEORIGIN'),
        ('X-XSS-Protection', '1; mode=block'),
        ('Referrer-Policy', 'strict-origin-when-cross-origin'),
        ('Permissions-Policy', 'camera=(), microphone=(), geolocation=(), usb=()'),
        ('Strict-Transport-Security', 'max-age=31536000; includeSubDomains'),
        ('Server', ''),
    )
    for header_name, header_value in fixed_headers:
        response.headers[header_name] = header_value

    # Content Security Policy: everything is denied by default, then each
    # resource type is opened up for the sources listed below.
    # The nonce lets inline scripts/styles that carry it through.
    # 'unsafe-inline' is kept for style-src because the Tailwind Play CDN
    # dynamically injects styles.
    nonce_source = f"'nonce-{g.nonce}'"
    directives = (
        "default-src 'none'",
        f"script-src 'self' {nonce_source} cdn.tailwindcss.com",
        f"style-src 'self' 'unsafe-inline' {nonce_source} cdn.tailwindcss.com fonts.googleapis.com",
        "img-src 'self' data:",
        "font-src 'self' fonts.gstatic.com",
        "connect-src 'self'",
        "object-src 'none'",
        "base-uri 'none'",
        "form-action 'self'",
        "frame-ancestors 'none'",
    )
    # Directives are separated by "; " and the policy ends with a bare ";"
    response.headers['Content-Security-Policy'] = "; ".join(directives) + ";"
    return response

# Runs only when this file is executed directly, not when it is imported.
if __name__ == "__main__":
    # Use the port Hugging Face or Railway provides, or default to 7860
    port = int(os.environ.get("PORT", 7860))
    # 0.0.0.0 listens on all network interfaces, not just localhost
    app.run(host="0.0.0.0", port=port)