Spaces:
Sleeping
Sleeping
| from flask import Flask, request, send_file, jsonify | |
| from flask_cors import CORS | |
| import numpy as np | |
| from PIL import Image | |
| import io | |
| import torch | |
| from basicsr.archs.rrdbnet_arch import RRDBNet | |
| from realesrgan import RealESRGANer | |
| import time | |
| import os | |
app = Flask(__name__)
# Enable cross-origin requests on every route, using flask_cors defaults.
CORS(app)

# Module-wide RealESRGANer instance. It stays None until initialize_enhancer()
# assigns a model; the request handlers check it before doing any work.
upsampler = None
| # Initialize Real-ESRGAN upsampler at startup | |
| def initialize_enhancer(): | |
| global upsampler | |
| try: | |
| print("Starting model initialization...") | |
| # Configuration for RealESRGAN x4v3 | |
| model = RRDBNet( | |
| num_in_ch=3, | |
| num_out_ch=3, | |
| num_feat=64, | |
| num_block=6, # Critical parameter for x4v3 model | |
| num_grow_ch=32, | |
| scale=4 | |
| ) | |
| # Force CPU usage for Hugging Face compatibility | |
| device = torch.device('cpu') | |
| # Check if model weights file exists | |
| weights_path = 'weights/realesr-general-x4v3.pth' | |
| if not os.path.exists(weights_path): | |
| print(f"Model weights not found at {weights_path}") | |
| # Create the directory if it doesn't exist | |
| os.makedirs('weights', exist_ok=True) | |
| # You'd normally download weights here, but for this example | |
| # we'll just use a placeholder | |
| print("ERROR: Model weights file not found!") | |
| return False | |
| # Initialize the upsampler | |
| upsampler = RealESRGANer( | |
| scale=4, | |
| model_path=weights_path, | |
| model=model, | |
| tile=0, # Set to 0 for small images, increase for large images | |
| tile_pad=10, | |
| pre_pad=0, | |
| half=False, # CPU doesn't support half precision | |
| device=device | |
| ) | |
| # Run a small test to ensure everything is loaded | |
| test_img = np.zeros((64, 64, 3), dtype=np.uint8) | |
| upsampler.enhance(test_img, outscale=4, alpha_upsampler='realesrgan') | |
| print("Model initialization completed successfully") | |
| return True | |
| except Exception as e: | |
| error_msg = f"Model initialization failed: {str(e)}" | |
| print(error_msg) | |
| return False | |
# Set to True only when the start-up load below fails. health_check uses it to
# report 'failed' rather than 'initializing' when no upsampler is available.
init_attempted = False

# Eager model load at import time; this call blocks until loading finishes.
# A failure does not stop the server, but enhancement will not work.
if not initialize_enhancer():
    print("ERROR: Model failed to initialize. Server will continue running but enhancement won't work.")
    # Keep serving, but remember the failed attempt so health_check reports it.
    init_attempted = True
@app.route('/enhance', methods=['POST'])
def enhance_image():
    """Upscale an uploaded image 4x with Real-ESRGAN and return it as JPEG.

    Expects a multipart upload with the image in the ``file`` form field.
    Images wider or taller than 2000 pixels are rejected.

    Returns:
        The enhanced image as ``image/jpeg`` on success. Otherwise a JSON
        ``{'error': ...}`` body with status 400 (missing, empty, invalid or
        oversized upload) or 500 (model not initialized, processing failure).
    """
    # Imported here because the file-level PIL import only brings in ``Image``.
    from PIL import UnidentifiedImageError

    if upsampler is None:
        return jsonify({'error': 'Enhancement model is not initialized.'}), 500

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if file.filename == '':
        return jsonify({'error': 'Empty file submitted'}), 400

    try:
        # Image.open only parses the header, so the size can be checked
        # before the pixel data is decoded into memory.
        img = Image.open(file.stream)
        w, h = img.size
        if h > 2000 or w > 2000:
            return jsonify({'error': 'Image too large. Maximum size is 2000x2000 pixels'}), 400

        rgb = np.asarray(img.convert('RGB'))

        # RealESRGANer.enhance takes OpenCV-style BGR input and returns BGR.
        # Flip the channel axis on the way in and again on the way out.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])
        output_bgr, _ = upsampler.enhance(
            bgr,
            outscale=4,  # 4x super-resolution
            alpha_upsampler='realesrgan'
        )
        output_rgb = np.ascontiguousarray(output_bgr[:, :, ::-1])

        img_byte_arr = io.BytesIO()
        Image.fromarray(output_rgb).save(img_byte_arr, format='JPEG', quality=95)
        img_byte_arr.seek(0)
        return send_file(img_byte_arr, mimetype='image/jpeg')
    except (UnidentifiedImageError, Image.DecompressionBombError) as e:
        # The upload is not a decodable (or safe-to-decode) image: a client error.
        return jsonify({'error': f'Invalid image: {e}'}), 400
    except Exception as e:
        return jsonify({'error': f'Processing error: {str(e)}'}), 500
@app.route('/health', methods=['GET'])
def health_check():
    """Report whether the enhancement model is usable.

    Returns JSON ``{'status': ..., 'timestamp': ...}``. ``status`` is:

    * ``'ready'`` when ``upsampler`` is set;
    * ``'failed'`` when initialization was attempted without success;
    * ``'initializing'`` otherwise.

    ``timestamp`` is the current ``time.time()`` value.
    """
    # Read-only access to module globals: no ``global`` statement needed.
    if upsampler is not None:
        status = 'ready'
    elif init_attempted:
        status = 'failed'
    else:
        status = 'initializing'

    return jsonify({
        'status': status,
        'timestamp': time.time()
    })
@app.route('/')
def home():
    """Return a short JSON description of the API.

    Lists the available endpoints. ``status`` is ``'ready'`` when the model
    has been loaded (``upsampler`` is set) and ``'not ready'`` otherwise.
    """
    return jsonify({
        'message': 'Image Enhancement API',
        'endpoints': {
            'POST /enhance': 'Process images (4x upscale)',
            'GET /health': 'Service status check'
        },
        'status': 'ready' if upsampler is not None else 'not ready'
    })
if __name__ == '__main__':
    # Start the Flask development server on all interfaces. The port can be
    # overridden with the PORT environment variable; it defaults to 5000.
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 5000)))