Spaces:
Runtime error
Runtime error
| import os | |
| import cv2 | |
| import torch | |
| from flask import Flask, request, jsonify, send_file, render_template_string | |
| from basicsr.archs.srvgg_arch import SRVGGNetCompact | |
| from gfpgan.utils import GFPGANer | |
| from realesrgan.utils import RealESRGANer | |
| import tempfile | |
| import uuid | |
app = Flask(__name__)


def download_weights():
    """Download every missing model-weight file into the working directory.

    Best-effort, like the original ``wget`` loop: a failed download is reported
    and skipped rather than aborting start-up. Each file is first fetched to a
    ``<name>.part`` file and only renamed into place once complete. A failed or
    interrupted download therefore leaves nothing behind. (``wget -O`` creates
    the target file up front, and that empty file would then pass the
    ``os.path.exists`` check forever.)
    """
    import urllib.request

    weights = {
        'realesr-general-x4v3.pth': 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth',
        'GFPGANv1.2.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.2.pth',
        'GFPGANv1.3.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth',
        'GFPGANv1.4.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth',
        'RestoreFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/RestoreFormer.pth',
        'CodeFormer.pth': 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.4/CodeFormer.pth',
    }
    for weight_file, url in weights.items():
        if os.path.exists(weight_file):
            continue
        partial_file = weight_file + '.part'
        try:
            urllib.request.urlretrieve(url, partial_file)
            # Atomic rename: the final name only ever refers to a complete file.
            os.replace(partial_file, weight_file)
        except OSError as error:  # URLError/HTTPError are OSError subclasses
            print(f'Failed to download {weight_file} from {url}: {error}')
            try:
                os.remove(partial_file)
            except FileNotFoundError:
                pass


# Fetch the weights BEFORE building any model. The upsampler below is given a
# local ``model_path``. NOTE(review): RealESRGANer presumably loads that file
# in its constructor, so building it before the download fails on a fresh
# checkout — confirm against the installed realesrgan version.
download_weights()

# Background upsampler shared by every face enhancer built in process_image.
model = SRVGGNetCompact(num_in_ch=3, num_out_ch=3, num_feat=64, num_conv=32, upscale=4, act_type='prelu')
model_path = 'realesr-general-x4v3.pth'
# Use half precision only when a CUDA device is available.
half = torch.cuda.is_available()
upsampler = RealESRGANer(scale=4, model_path=model_path, model=model, tile=0, tile_pad=10, pre_pad=0, half=half)

# Processed images are written into this directory.
os.makedirs('output', exist_ok=True)
# Model file and GFPGANer ``arch`` for every ``version`` accepted by process_image.
# NOTE(review): upstream gfpgan's GFPGANer does not appear to accept the
# 'CodeFormer' or 'realesr-general' archs (nor a ``weight`` kwarg on enhance);
# presumably a patched gfpgan is installed — confirm.
_FACE_ENHANCER_CONFIGS = {
    'v1.2': ('GFPGANv1.2.pth', 'clean'),
    'v1.3': ('GFPGANv1.3.pth', 'clean'),
    'v1.4': ('GFPGANv1.4.pth', 'clean'),
    'RestoreFormer': ('RestoreFormer.pth', 'RestoreFormer'),
    'CodeFormer': ('CodeFormer.pth', 'CodeFormer'),
    'RealESR-General-x4v3': ('realesr-general-x4v3.pth', 'realesr-general'),
}


def _load_image(img_path):
    """Read *img_path* with OpenCV and return ``(img, img_mode)``.

    ``img_mode`` is ``'RGBA'`` for 4-channel images and ``None`` otherwise.
    Single-channel (grayscale) images are converted to 3-channel BGR.

    Raises:
        ValueError: if OpenCV cannot read the file.
    """
    img = cv2.imread(img_path, cv2.IMREAD_UNCHANGED)
    if img is None:
        # cv2.imread reports unreadable/unsupported files by returning None.
        raise ValueError(f'Could not read image: {os.path.basename(str(img_path))}')
    if img.ndim == 3 and img.shape[2] == 4:
        return img, 'RGBA'
    if img.ndim == 2:
        img = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    return img, None


def _save_output(output, img_mode):
    """Write *output* under ``output/`` with a unique name and return its path.

    RGBA inputs are saved as PNG, so any alpha channel the enhancer kept
    survives. Everything else is saved as JPEG at quality 95.

    Raises:
        OSError: if OpenCV fails to write the file.
    """
    if img_mode == 'RGBA':
        extension, params = 'png', [int(cv2.IMWRITE_PNG_COMPRESSION), 9]
    else:
        extension, params = 'jpg', [int(cv2.IMWRITE_JPEG_QUALITY), 95]
    os.makedirs('output', exist_ok=True)
    output_path = os.path.join('output', f'output_{uuid.uuid4().hex}.{extension}')
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, output, params):
        raise OSError(f'Could not write output image: {output_path}')
    return output_path


def process_image(img_path, version, scale, weight=0.5):
    """Restore faces in an image, rescale it and save the result.

    Args:
        img_path: path of the input image file.
        version: a key of ``_FACE_ENHANCER_CONFIGS`` selecting the model.
        scale: final size relative to the (possibly pre-enlarged) input image.
            The enhancer upscales 2x, so ``scale == 2`` keeps its output as-is.
        weight: fidelity weight passed to ``enhance``; only used for
            ``version == 'CodeFormer'``.

    Returns:
        Path of the written image under ``output/``.

    Raises:
        Exception: any failure, re-raised with a ``"Processing error: ..."`` message.
    """
    try:
        # Validate before any expensive work. An unknown version used to
        # surface as an UnboundLocalError on ``face_enhancer``.
        if version not in _FACE_ENHANCER_CONFIGS:
            raise ValueError(f'Unsupported version: {version}')
        model_file, arch = _FACE_ENHANCER_CONFIGS[version]

        img, img_mode = _load_image(img_path)
        h, w = img.shape[0:2]
        if h < 300:
            # Small inputs are doubled before enhancement (presumably so faces
            # are large enough to be detected — TODO confirm).
            img = cv2.resize(img, (w * 2, h * 2), interpolation=cv2.INTER_LANCZOS4)

        # A fresh enhancer per call, as before; the module-level ``upsampler``
        # handles the non-face background.
        face_enhancer = GFPGANer(
            model_path=model_file, upscale=2, arch=arch, channel_multiplier=2, bg_upsampler=upsampler)

        extra_kwargs = {'weight': weight} if version == 'CodeFormer' else {}
        try:
            _, _, output = face_enhancer.enhance(
                img, has_aligned=False, only_center_face=False, paste_back=True, **extra_kwargs)
        except RuntimeError as error:
            print('Error', error)
            raise Exception(f"Enhancement error: {str(error)}") from error

        try:
            if scale != 2:
                # The enhancer already produced a 2x image, so resize *its*
                # output by scale / 2; that makes scale == 2 the no-op case.
                # (Using the input's dimensions made scale 4 equal to scale 2
                # and made scale 3 shrink the image.)
                out_h, out_w = output.shape[0:2]
                interpolation = cv2.INTER_AREA if scale < 2 else cv2.INTER_LANCZOS4
                output = cv2.resize(
                    output, (int(out_w * scale / 2), int(out_h * scale / 2)), interpolation=interpolation)
        except Exception as error:
            # Best-effort: on a bad scale, keep the enhancer's 2x output.
            print('wrong scale input.', error)

        return _save_output(output, img_mode)
    except Exception as error:
        print('Global exception', error)
        raise Exception(f"Processing error: {str(error)}") from error
# Static HTML/JS front end for the /api/restore endpoint.
# Fixes relative to the previous template:
#   * the 'version' select's change handler already calls updateApiUsage(),
#     so the second registration (which read the file and re-rendered twice)
#     was dropped;
#   * the generated example is plain text, so it is assigned via textContent
#     instead of innerHTML.
_INDEX_HTML = '''
<!DOCTYPE html>
<html>
<head>
    <title>Image Upscaling & Restoration API</title>
    <style>
        body { font-family: Arial, sans-serif; max-width: 800px; margin: 0 auto; padding: 20px; }
        .container { border: 1px solid #ddd; padding: 20px; border-radius: 5px; }
        .form-group { margin-bottom: 15px; }
        label { display: block; margin-bottom: 5px; }
        input, select { width: 100%; padding: 8px; box-sizing: border-box; }
        button { background-color: #4CAF50; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; }
        button:hover { background-color: #45a049; }
        #result { margin-top: 20px; }
        #preview { max-width: 100%; margin-top: 10px; }
        #apiUsage { background-color: #f5f5f5; padding: 15px; border-radius: 5px; margin-top: 20px; font-family: monospace; white-space: pre-wrap; }
        #apiUsage h3 { margin-top: 0; }
        #formDataPreview { max-height: 200px; overflow-y: auto; margin-bottom: 10px; }
        .code-block { background-color: #f8f8f8; padding: 10px; border-radius: 4px; border-left: 3px solid #4CAF50; }
        .comment { color: #666; font-style: italic; }
        .loader {
            width: 48px;
            height: 48px;
            border: 5px solid #4CAF50;
            border-bottom-color: transparent;
            border-radius: 50%;
            display: inline-block;
            box-sizing: border-box;
            animation: rotation 1s linear infinite;
            margin: 20px auto;
            display: none; /* 初期状態では非表示 */
        }
        @keyframes rotation {
            0% {
                transform: rotate(0deg);
            }
            100% {
                transform: rotate(360deg);
            }
        }
    </style>
</head>
<body>
    <h1>Image Upscaling & Restoration API</h1>
    <div class="container">
        <form id="uploadForm" enctype="multipart/form-data">
            <div class="form-group">
                <label for="file">Upload Image:</label>
                <input type="file" id="file" name="file" required>
            </div>
            <div class="form-group">
                <label for="version">Version:</label>
                <select id="version" name="version">
                    <option value="v1.2">GFPGANv1.2</option>
                    <option value="v1.3">GFPGANv1.3</option>
                    <option value="v1.4" selected>GFPGANv1.4</option>
                    <option value="RestoreFormer">RestoreFormer</option>
                    <option value="CodeFormer">CodeFormer</option>
                    <option value="RealESR-General-x4v3">RealESR-General-x4v3</option>
                </select>
            </div>
            <div class="form-group">
                <label for="scale">Rescaling factor:</label>
                <input type="number" id="scale" name="scale" value="2" step="0.1" min="1" max="4" required>
            </div>
            <div class="form-group" id="weightGroup" style="display: none;">
                <label for="weight">CodeFormer Weight (0-1):</label>
                <input type="number" id="weight" name="weight" value="0.5" step="0.1" min="0" max="1">
            </div>
            <button type="submit" id="submitButton">Process Image</button>
        </form>
        <div id="loading" class="loader"></div>
        <div id="result">
            <h3>Result:</h3>
            <div id="outputContainer" style="display: none;">
                <img id="preview" src="" alt="Processed Image">
                <a id="downloadLink" href="#" download>Download Image</a>
            </div>
        </div>
        <div id="apiUsage">
            <h3>API Usage:</h3>
            <div id="fetchCode" class="code-block">
                // JavaScript fetch code will appear here
            </div>
        </div>
    </div>
    <script>
        // CodeFormerが選択された時にweightパラメータを表示
        document.getElementById('version').addEventListener('change', function() {
            const weightGroup = document.getElementById('weightGroup');
            if (this.value === 'CodeFormer') {
                weightGroup.style.display = 'block';
            } else {
                weightGroup.style.display = 'none';
            }
            updateApiUsage();
        });
        // フォームの変更を監視してAPI使用例を更新
        function updateApiUsage() {
            const fileInput = document.getElementById('file');
            const version = document.getElementById('version').value;
            const scale = document.getElementById('scale').value;
            const weight = document.getElementById('weight').value;
            // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
            const baseUrl = window.location.origin;
            const apiUrl = baseUrl + '/api/restore';
            // ファイルのプレビュー用文字列を準備
            let filePreview = '"img-dataURL"';
            if (fileInput.files.length > 0) {
                const file = fileInput.files[0];
                const reader = new FileReader();
                reader.onload = function(e) {
                    const dataURL = e.target.result;
                    if (dataURL.length > 40) {
                        filePreview = `"${dataURL.substring(0, 20)}...${dataURL.substring(dataURL.length - 20)}"`;
                    } else {
                        filePreview = `"${dataURL}"`;
                    }
                    updateFetchCode(apiUrl, version, scale, weight, filePreview);
                };
                reader.readAsDataURL(file);
            } else {
                updateFetchCode(apiUrl, version, scale, weight, filePreview);
            }
        }
        function updateFetchCode(apiUrl, version, scale, weight, filePreview) {
            const fetchCodeDiv = document.getElementById('fetchCode');
            let code = `// JavaScript fetch example:
const formData = new FormData();
formData.append('file', ${filePreview});
formData.append('version', '${version}');
formData.append('scale', ${scale});`;
            if (version === 'CodeFormer') {
                code += `
formData.append('weight', ${weight});`;
            }
            code += `
fetch('${apiUrl}', {
  method: 'POST',
  body: formData
})
.then(response => {
  if (!response.ok) {
    return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
  }
  return response.blob();
})
.then(blob => {
  // Process the returned image blob
  const url = URL.createObjectURL(blob);
  console.log('Image processed successfully', url);
  // Example: document.getElementById('resultImage').src = url;
})
.catch(error => {
  console.error('Error:', error.message);
});`;
            fetchCodeDiv.textContent = code;
        }
        // フォーム要素の変更を監視
        document.getElementById('file').addEventListener('change', updateApiUsage);
        document.getElementById('scale').addEventListener('input', updateApiUsage);
        document.getElementById('weight').addEventListener('input', updateApiUsage);
        // 初期表示
        updateApiUsage();
        document.getElementById('uploadForm').addEventListener('submit', function(e) {
            e.preventDefault();
            // ボタンを無効化し、ローディングを表示
            const submitButton = document.getElementById('submitButton');
            const loadingElement = document.getElementById('loading');
            submitButton.disabled = true;
            loadingElement.style.display = 'block';
            const formData = new FormData();
            formData.append('file', document.getElementById('file').files[0]);
            formData.append('version', document.getElementById('version').value);
            formData.append('scale', document.getElementById('scale').value);
            // CodeFormerが選択されている場合はweightも追加
            if (document.getElementById('version').value === 'CodeFormer') {
                formData.append('weight', document.getElementById('weight').value);
            }
            // 現在のURLからベースURLを取得(パス、パラメータ、ハッシュを含めない)
            const baseUrl = window.location.origin;
            const apiUrl = baseUrl + '/api/restore';
            fetch(apiUrl, {
                method: 'POST',
                body: formData
            })
            .then(response => {
                if (!response.ok) {
                    return response.json().then(err => { throw new Error(err.error || 'Unknown error'); });
                }
                return response.blob();
            })
            .then(blob => {
                const url = URL.createObjectURL(blob);
                const preview = document.getElementById('preview');
                const downloadLink = document.getElementById('downloadLink');
                const outputContainer = document.getElementById('outputContainer');
                preview.src = url;
                downloadLink.href = url;
                downloadLink.download = 'restored_' + document.getElementById('file').files[0].name;
                outputContainer.style.display = 'block';
            })
            .catch(error => {
                alert('Error: ' + error.message);
            })
            .finally(() => {
                // 処理が終わったらローディングを非表示にし、ボタンを再有効化
                loadingElement.style.display = 'none';
                submitButton.disabled = false;
            });
        });
    </script>
</body>
</html>
'''


@app.route('/')
def index():
    """Serve the upload form that calls ``POST /api/restore``.

    Without this route registration the page was unreachable.
    """
    return render_template_string(_INDEX_HTML)
@app.route('/api/restore', methods=['POST'])
def api_restore():
    """Restore/upscale an uploaded image and return the processed image.

    Expects multipart form data with:
      - ``file``: the image to process (required);
      - ``version``: model key understood by ``process_image`` (default ``'v1.4'``);
      - ``scale``: final rescaling factor, a positive number (default ``2``);
      - ``weight``: CodeFormer fidelity weight, only read when
        ``version == 'CodeFormer'`` (default ``0.5``).

    Returns:
        The image bytes, with a MIME type matching the written file. On
        failure, a JSON ``{'error': ...}`` body with status 400 (bad request)
        or 500 (processing failure).
    """
    import io
    import math
    import mimetypes

    if 'file' not in request.files:
        return jsonify({'error': 'No file uploaded'}), 400
    file = request.files['file']
    if not file.filename:
        return jsonify({'error': 'No selected file'}), 400

    version = request.form.get('version', 'v1.4')
    # Parse inside a try so malformed numbers become a JSON 400 instead of an
    # unhandled ValueError (HTML 500).
    try:
        scale = float(request.form.get('scale', 2))
        weight = float(request.form.get('weight', 0.5)) if version == 'CodeFormer' else None
    except ValueError:
        return jsonify({'error': 'scale and weight must be numbers'}), 400
    if not math.isfinite(scale) or scale <= 0:
        return jsonify({'error': 'scale must be a positive number'}), 400

    # The client controls ``file.filename`` (it may contain "../"), so it is
    # never used as a path; only a plain alphanumeric extension is kept.
    suffix = os.path.splitext(file.filename)[1]
    if not suffix[1:].isalnum():
        suffix = ''

    try:
        # TemporaryDirectory removes the upload and anything else left in the
        # directory, even when processing fails.
        with tempfile.TemporaryDirectory() as temp_dir:
            input_path = os.path.join(temp_dir, 'input' + suffix)
            file.save(input_path)
            output_path = process_image(input_path, version, scale, weight)
        # Load the result and delete it before responding, so that 'output/'
        # does not grow with every request.
        try:
            with open(output_path, 'rb') as result_file:
                data = result_file.read()
        finally:
            try:
                os.remove(output_path)
            except OSError:
                pass
        mimetype = mimetypes.guess_type(output_path)[0] or 'image/jpeg'
        return send_file(io.BytesIO(data), mimetype=mimetype)
    except Exception as e:
        return jsonify({'error': str(e)}), 500
if __name__ == '__main__':
    # The Werkzeug debugger's interactive console allows arbitrary code
    # execution, so never enable it by default while listening on all
    # interfaces. Opt in explicitly with FLASK_DEBUG=1 for local development.
    app.run(host='0.0.0.0', port=7860, debug=os.environ.get('FLASK_DEBUG') == '1')