| from flask import Flask, request, jsonify, send_file, render_template_string |
| import os |
| import cv2 |
| from rembg import new_session, remove |
| from rembg.sessions import sessions_class |
| import base64 |
| import uuid |
| from flask_cors import CORS |
|
|
# Flask application object; flask_cors is applied with its default settings
# to every route of this app.
app = Flask(__name__)
CORS(app)

# Ask every rembg session class to fetch its model files at import time,
# i.e. before the server starts handling requests.
for session_class in sessions_class:
    session_class.download_models()
|
|
def process_image(file_path, mask, model, x, y):
    """Remove the background from the image stored at ``file_path``.

    The image is decoded with OpenCV as colour (``cv2.IMREAD_COLOR``, so an
    alpha channel is dropped), re-encoded to PNG in memory and handed to
    ``rembg.remove`` using a session created for ``model``.

    Args:
        file_path: Path of the image file to read.
        mask: ``"Mask only"`` to return only the mask; any other value
            returns the background-removed image.
        model: rembg model name passed to ``new_session``.
        x, y: Point-prompt coordinates. When both are not ``None`` they are
            sent as a single ``sam_prompt`` point with label 1; otherwise no
            prompt is sent.

    Returns:
        Path of a newly written PNG file in the current working directory.
        The caller owns this file and is responsible for deleting it.

    Raises:
        ValueError: If the input cannot be decoded/encoded as an image, or if
            the ``"sam"`` model is requested without both coordinates.
        Any exception from ``new_session``/``remove`` propagates unchanged;
        in every error case no output file is left on disk.
    """
    image = cv2.imread(file_path, cv2.IMREAD_COLOR)
    if image is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise ValueError(f"Could not read image file: {file_path}")

    # Encode in memory rather than via a temporary file, so there is nothing
    # to clean up if processing fails later on.
    ok, encoded = cv2.imencode('.png', image)
    if not ok:
        raise ValueError("Failed to encode image as PNG")

    # Only attach a point prompt when a full (x, y) pair was supplied.
    extra_kwargs = {}
    if x is not None and y is not None:
        extra_kwargs['sam_prompt'] = [{"type": "point", "data": [x, y], "label": 1}]
    elif model == 'sam':
        raise ValueError("The 'sam' model requires both x and y coordinates")

    session = new_session(model)
    output = remove(
        encoded.tobytes(),
        session=session,
        only_mask=(mask == "Mask only"),
        **extra_kwargs,
    )

    # Create the output file only after processing succeeded, so a failure
    # never leaves an empty or partial file behind.
    output_path = f"temp_output_{uuid.uuid4().hex}.png"
    with open(output_path, 'wb') as out_file:
        out_file.write(output)
    return output_path
|
|
@app.route('/api/process', methods=['POST'])
def api_process():
    """Run background removal on an uploaded image and return the PNG result.

    Form fields:
        file: The image upload (required).
        mask: ``"Mask only"`` to return only the mask; defaults to
            ``"Default"`` (cut-out image).
        model: rembg model name; defaults to ``"isnet-general-use"``.
        x, y: Optional point coordinates (used for the SAM model). They are
            only forwarded when both parse as finite numbers.

    Returns:
        The processed image as ``image/png`` on success; a JSON
        ``{"error": ...}`` body with status 400 when no file was uploaded,
        or with status 500 when processing fails.
    """
    import io
    import math

    upload = request.files.get('file')
    # Browsers submit an empty filename when no file was chosen.
    if upload is None or upload.filename == '':
        return jsonify({'error': 'No file uploaded'}), 400

    mask = request.form.get('mask', 'Default')
    model = request.form.get('model', 'isnet-general-use')
    x = request.form.get('x', None)
    y = request.form.get('y', None)

    try:
        x = float(x) if x is not None else None
        y = float(y) if y is not None else None
    except (TypeError, ValueError):
        x = None
        y = None
    # A point prompt needs a complete, finite pair; drop anything else rather
    # than forwarding None/NaN/inf to the model.
    if x is None or y is None or not (math.isfinite(x) and math.isfinite(y)):
        x = None
        y = None

    temp_input = f"temp_{uuid.uuid4().hex}.png"
    output_path = None
    try:
        upload.save(temp_input)
        output_path = process_image(temp_input, mask, model, x, y)
        # Load the result into memory so the temporary output file can be
        # deleted below; send_file on the path would leave it on disk.
        with open(output_path, 'rb') as result_file:
            png_bytes = result_file.read()
    except Exception as e:  # request boundary: report the failure to the client
        app.logger.exception('Image processing failed')
        return jsonify({'error': str(e)}), 500
    finally:
        for path in (temp_input, output_path):
            if path and os.path.exists(path):
                os.remove(path)

    return send_file(io.BytesIO(png_bytes), mimetype='image/png')
|
|
# Single-page demo UI served at '/' by index(). It is rendered through
# render_template_string (Jinja), so it must not contain Jinja delimiters
# such as "{{", "{%" or "{#".
HTML_TEMPLATE = """
<!DOCTYPE html>
<html>
<head>
    <meta charset="utf-8">
    <title>RemBG API</title>
    <style>
        body {
            font-family: Arial, sans-serif;
            max-width: 800px;
            margin: 0 auto;
            padding: 20px;
        }
        .container {
            display: flex;
            flex-direction: column;
            gap: 20px;
        }
        .row {
            display: flex;
            gap: 20px;
        }
        .column {
            flex: 1;
        }
        img {
            max-width: 100%;
            height: auto;
            border: 1px solid #ddd;
        }
        .form-group {
            margin-bottom: 15px;
        }
        label {
            display: block;
            margin-bottom: 5px;
            font-weight: bold;
        }
        select, input, button {
            width: 100%;
            padding: 8px;
            box-sizing: border-box;
        }
        button {
            background-color: #4CAF50;
            color: white;
            border: none;
            cursor: pointer;
            padding: 10px;
        }
        button:hover {
            background-color: #45a049;
        }
        button:disabled {
            background-color: #9e9e9e;
            cursor: wait;
        }
        #fetch-code {
            width: 100%;
            height: 150px;
            font-family: monospace;
            padding: 10px;
            box-sizing: border-box;
            background-color: #f5f5f5;
            border: 1px solid #ddd;
        }
        .coords-input {
            display: none;
        }
    </style>
</head>
<body>
    <h1>RemBG API</h1>
    <p>Upload an image to process with RemBG. Select options and click "Process Image".</p>

    <div class="container">
        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="file">Input Image:</label>
                    <input type="file" id="file" accept="image/*">
                </div>
                <img id="input-image" src="" alt="Input image will appear here">
            </div>
            <div class="column">
                <div class="form-group">
                    <label>Output Image:</label>
                    <img id="output-image" src="" alt="Output image will appear here">
                </div>
            </div>
        </div>

        <div class="row">
            <div class="column">
                <div class="form-group">
                    <label for="mask">Output Type:</label>
                    <select id="mask">
                        <option value="Default">Default</option>
                        <option value="Mask only">Mask only</option>
                    </select>
                </div>
            </div>
            <div class="column">
                <div class="form-group">
                    <label for="model">Model Selection:</label>
                    <select id="model">
                        <option value="u2net">u2net</option>
                        <option value="u2netp">u2netp</option>
                        <option value="u2net_human_seg">u2net_human_seg</option>
                        <option value="u2net_cloth_seg">u2net_cloth_seg</option>
                        <option value="silueta">silueta</option>
                        <option value="isnet-general-use" selected>isnet-general-use</option>
                        <option value="isnet-anime">isnet-anime</option>
                        <option value="sam">sam</option>
                        <option value="birefnet-general">birefnet-general</option>
                        <option value="birefnet-general-lite">birefnet-general-lite</option>
                        <option value="birefnet-portrait">birefnet-portrait</option>
                        <option value="birefnet-dis">birefnet-dis</option>
                        <option value="birefnet-hrsod">birefnet-hrsod</option>
                        <option value="birefnet-cod">birefnet-cod</option>
                        <option value="birefnet-massive">birefnet-massive</option>
                    </select>
                </div>
            </div>
        </div>

        <div id="coords-section" style="display: none;">
            <h3>SAM Model Coordinates</h3>
            <p>Click on the image to set coordinates (for SAM model only)</p>
            <div class="row">
                <div class="column">
                    <div class="form-group">
                        <label for="x">X Coordinate:</label>
                        <input type="number" id="x" class="coords-input">
                    </div>
                </div>
                <div class="column">
                    <div class="form-group">
                        <label for="y">Y Coordinate:</label>
                        <input type="number" id="y" class="coords-input">
                    </div>
                </div>
            </div>
        </div>

        <button id="process-btn">Process Image</button>

        <div class="form-group">
            <label for="fetch-code">Fetch Code:</label>
            <textarea id="fetch-code" readonly></textarea>
        </div>
    </div>

    <script>
        const fileInput = document.getElementById('file');
        const inputImage = document.getElementById('input-image');
        const outputImage = document.getElementById('output-image');
        const maskSelect = document.getElementById('mask');
        const modelSelect = document.getElementById('model');
        const xInput = document.getElementById('x');
        const yInput = document.getElementById('y');
        const coordsSection = document.getElementById('coords-section');
        const processBtn = document.getElementById('process-btn');
        const fetchCodeTextarea = document.getElementById('fetch-code');
        // Object URL of the currently displayed result, revoked on replacement.
        let outputObjectUrl = null;

        // 画像プレビュー
        fileInput.addEventListener('change', function(e) {
            const file = e.target.files[0];
            // Coordinates picked on a previous image do not apply to a new one.
            xInput.value = '';
            yInput.value = '';
            if (file) {
                const reader = new FileReader();
                reader.onload = function(event) {
                    inputImage.src = event.target.result;
                    updateFetchCode();
                };
                reader.readAsDataURL(file);
            } else {
                updateFetchCode();
            }
        });

        // モデル選択でSAMの場合は座標入力表示
        modelSelect.addEventListener('change', function() {
            const isSam = modelSelect.value === 'sam';
            coordsSection.style.display = isSam ? 'block' : 'none';
            document.querySelectorAll('.coords-input').forEach(el => {
                el.style.display = isSam ? 'block' : 'none';
            });
            updateFetchCode();
        });

        // 画像クリックで座標取得 (SAMモデルのみ)
        inputImage.addEventListener('click', function(e) {
            if (modelSelect.value !== 'sam' || !inputImage.naturalWidth) {
                return;
            }
            // The preview is scaled by CSS (max-width: 100%), so convert the
            // click from displayed pixels to the image's own pixel grid, which
            // is the coordinate space the server sends to the model.
            const scaleX = inputImage.naturalWidth / inputImage.clientWidth;
            const scaleY = inputImage.naturalHeight / inputImage.clientHeight;
            const x = Math.round(e.offsetX * scaleX);
            const y = Math.round(e.offsetY * scaleY);

            xInput.value = Math.min(Math.max(x, 0), inputImage.naturalWidth - 1);
            yInput.value = Math.min(Math.max(y, 0), inputImage.naturalHeight - 1);
            updateFetchCode();
        });

        // その他の入力変更時
        [maskSelect, xInput, yInput].forEach(el => {
            el.addEventListener('change', updateFetchCode);
        });

        // 画像処理
        processBtn.addEventListener('click', async function() {
            if (!fileInput.files || fileInput.files.length === 0) {
                alert('Please select an image file');
                return;
            }

            const formData = new FormData();
            formData.append('file', fileInput.files[0]);
            formData.append('mask', maskSelect.value);
            formData.append('model', modelSelect.value);

            if (modelSelect.value === 'sam' && xInput.value && yInput.value) {
                formData.append('x', xInput.value);
                formData.append('y', yInput.value);
            }

            // Prevent duplicate submissions while a request is in flight.
            processBtn.disabled = true;
            try {
                const response = await fetch('/api/process', {
                    method: 'POST',
                    body: formData
                });

                if (!response.ok) {
                    // A non-JSON error body (e.g. an HTML error page) would make
                    // response.json() throw and hide the real failure.
                    let message = 'Failed to process image';
                    try {
                        const error = await response.json();
                        message = error.error || message;
                    } catch (parseError) {
                        message += ' (HTTP ' + response.status + ')';
                    }
                    throw new Error(message);
                }

                const blob = await response.blob();
                if (outputObjectUrl) {
                    URL.revokeObjectURL(outputObjectUrl);
                }
                outputObjectUrl = URL.createObjectURL(blob);
                outputImage.src = outputObjectUrl;
            } catch (error) {
                alert('Error: ' + error.message);
                console.error(error);
            } finally {
                processBtn.disabled = false;
            }
        });

        // Fetchコード生成
        function updateFetchCode() {
            const file = fileInput.files && fileInput.files[0];
            if (!file) {
                fetchCodeTextarea.value = '// Select an image first';
                return;
            }

            const mask = maskSelect.value;
            const model = modelSelect.value;
            const x = xInput.value;
            const y = yInput.value;

            let code = `const formData = new FormData();\n`;
            code += `formData.append('file', fileInput.files[0]);\n`;
            code += `formData.append('mask', '${mask}');\n`;
            code += `formData.append('model', '${model}');\n`;

            if (model === 'sam' && x && y) {
                code += `formData.append('x', '${x}');\n`;
                code += `formData.append('y', '${y}');\n`;
            }

            code += `\n`;
            // Use the page's own origin so the snippet also works over HTTPS.
            code += `fetch('${window.location.origin}/api/process', {\n`;
            code += `  method: 'POST',\n`;
            code += `  body: formData\n`;
            code += `})\n`;
            code += `.then(response => {\n`;
            code += `  if (!response.ok) {\n`;
            code += `    return response.json().then(err => { throw new Error(err.error); });\n`;
            code += `  }\n`;
            code += `  return response.blob();\n`;
            code += `})\n`;
            code += `.then(blob => {\n`;
            code += `  // Handle the processed image blob\n`;
            code += `  const imgUrl = URL.createObjectURL(blob);\n`;
            code += `  document.getElementById('output-image').src = imgUrl;\n`;
            code += `})\n`;
            code += `.catch(error => {\n`;
            code += `  console.error('Error:', error);\n`;
            code += `  alert('Error: ' + error.message);\n`;
            code += `});`;

            fetchCodeTextarea.value = code;
        }

        // 初期化
        updateFetchCode();
    </script>
</body>
</html>
"""
|
|
@app.route('/')
def index():
    """Serve the interactive demo page built from ``HTML_TEMPLATE``."""
    page_html = render_template_string(HTML_TEMPLATE)
    return page_html
|
|
if __name__ == '__main__':
    # Script entry point: start Flask's built-in server on port 7860,
    # listening on all network interfaces.
    app.run(port=7860, host='0.0.0.0')