soiz1 commited on
Commit
075fc85
·
verified ·
1 Parent(s): 9ce2677

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +345 -86
app.py CHANGED
@@ -1,109 +1,368 @@
1
- import gradio as gr
2
  import os
3
  import cv2
4
  from rembg import new_session, remove
5
  from rembg.sessions import sessions_class
 
 
6
 
7
- def inference(file, mask, model, x, y):
8
- im = cv2.imread(file, cv2.IMREAD_COLOR)
9
- input_path = "input.png"
10
- output_path = "output.png"
 
 
 
 
 
 
11
  cv2.imwrite(input_path, im)
12
 
13
  with open(input_path, 'rb') as i:
14
  with open(output_path, 'wb') as o:
15
- input = i.read()
16
  session = new_session(model)
17
 
18
  output = remove(
19
- input,
20
  session=session,
21
- **{ "sam_prompt": [{"type": "point", "data": [x, y], "label": 1}] },
22
  only_mask=(mask == "Mask only")
23
  )
24
  o.write(output)
25
 
 
 
 
 
26
  return output_path
27
 
28
- title = "RemBG"
29
- description = "Gradio demo for **[RemBG](https://github.com/danielgatis/rembg)**. To use it, simply upload your image, select a model, click Process, and wait."
30
- badge = """
31
- <div style="position: fixed; left: 50%; text-align: center;">
32
- <a href="https://github.com/danielgatis/rembg" target="_blank" style="text-decoration: none;">
33
- <img src="https://img.shields.io/badge/RemBG-Github-blue" alt="RemBG Github" />
34
- </a>
35
- </div>
36
- """
37
- def get_coords(evt: gr.SelectData) -> tuple:
38
- return evt.index[0], evt.index[1]
39
-
40
- def show_coords(model: str):
41
- visible = model == "sam"
42
- return gr.update(visible=visible), gr.update(visible=visible), gr.update(visible=visible)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
- for session in sessions_class:
45
- session.download_models()
46
-
47
- with gr.Blocks() as app:
48
- gr.Markdown(f"# {title}")
49
- gr.Markdown(description)
50
-
51
- with gr.Row():
52
- inputs = gr.Image(type="filepath", label="Input Image")
53
- outputs = gr.Image(type="filepath", label="Output Image")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
- with gr.Row():
56
- mask_option = gr.Radio(
57
- ["Default", "Mask only"],
58
- value="Default",
59
- label="Output Type"
60
- )
61
- model_selector = gr.Dropdown(
62
- [
63
- "u2net",
64
- "u2netp",
65
- "u2net_human_seg",
66
- "u2net_cloth_seg",
67
- "silueta",
68
- "isnet-general-use",
69
- "isnet-anime",
70
- "sam",
71
- "birefnet-general",
72
- "birefnet-general-lite",
73
- "birefnet-portrait",
74
- "birefnet-dis",
75
- "birefnet-hrsod",
76
- "birefnet-cod",
77
- "birefnet-massive"
78
- ],
79
- value="isnet-general-use",
80
- label="Model Selection"
81
- )
82
-
83
- extra = gr.Markdown("## Click on the image to capture coordinates (for SAM model)", visible=False)
84
-
85
- x = gr.Number(label="Mouse X Coordinate", visible=False)
86
- y = gr.Number(label="Mouse Y Coordinate", visible=False)
87
-
88
- model_selector.change(show_coords, inputs=model_selector, outputs=[x, y, extra])
89
- inputs.select(get_coords, None, [x, y])
90
-
91
-
92
- gr.Button("Process Image").click(
93
- inference,
94
- inputs=[inputs, mask_option, model_selector, x, y],
95
- outputs=outputs
96
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- gr.Examples(
99
- examples=[
100
- ["lion.png", "Default", "u2net", None, None],
101
- ["girl.jpg", "Default", "u2net", None, None],
102
- ["anime-girl.jpg", "Default", "isnet-anime", None, None]
103
- ],
104
- inputs=[inputs, mask_option, model_selector, x, y],
105
- outputs=outputs
106
- )
107
- gr.HTML(badge)
108
 
109
- app.launch(share=True)
 
 
1
+ from flask import Flask, request, jsonify, send_file, render_template_string
2
  import os
3
  import cv2
4
  from rembg import new_session, remove
5
  from rembg.sessions import sessions_class
6
+ import base64
7
+ import uuid
8
 
9
+ app = Flask(__name__)
10
+
11
+ # セッションの初期化
12
+ for session in sessions_class:
13
+ session.download_models()
14
+
15
+ def process_image(file_path, mask, model, x, y):
16
+ im = cv2.imread(file_path, cv2.IMREAD_COLOR)
17
+ input_path = f"temp_input_{uuid.uuid4().hex}.png"
18
+ output_path = f"temp_output_{uuid.uuid4().hex}.png"
19
  cv2.imwrite(input_path, im)
20
 
21
  with open(input_path, 'rb') as i:
22
  with open(output_path, 'wb') as o:
23
+ input_data = i.read()
24
  session = new_session(model)
25
 
26
  output = remove(
27
+ input_data,
28
  session=session,
29
+ **{"sam_prompt": [{"type": "point", "data": [x, y], "label": 1}]},
30
  only_mask=(mask == "Mask only")
31
  )
32
  o.write(output)
33
 
34
+ # 一時ファイルを削除
35
+ if os.path.exists(input_path):
36
+ os.remove(input_path)
37
+
38
  return output_path
39
 
40
+ @app.route('/api/process', methods=['POST'])
41
+ def api_process():
42
+ if 'file' not in request.files:
43
+ return jsonify({'error': 'No file uploaded'}), 400
44
+
45
+ file = request.files['file']
46
+ mask = request.form.get('mask', 'Default')
47
+ model = request.form.get('model', 'isnet-general-use')
48
+ x = request.form.get('x', None)
49
+ y = request.form.get('y', None)
50
+
51
+ try:
52
+ x = float(x) if x is not None else None
53
+ y = float(y) if y is not None else None
54
+ except (TypeError, ValueError):
55
+ x = None
56
+ y = None
57
+
58
+ # 一時ファイルに保存
59
+ temp_input = f"temp_{uuid.uuid4().hex}.png"
60
+ file.save(temp_input)
61
+
62
+ try:
63
+ output_path = process_image(temp_input, mask, model, x, y)
64
+ return send_file(output_path, mimetype='image/png')
65
+ except Exception as e:
66
+ return jsonify({'error': str(e)}), 500
67
+ finally:
68
+ # 一時ファイルを削除
69
+ if os.path.exists(temp_input):
70
+ os.remove(temp_input)
71
 
72
+ HTML_TEMPLATE = """
73
+ <!DOCTYPE html>
74
+ <html>
75
+ <head>
76
+ <title>RemBG API</title>
77
+ <style>
78
+ body {
79
+ font-family: Arial, sans-serif;
80
+ max-width: 800px;
81
+ margin: 0 auto;
82
+ padding: 20px;
83
+ }
84
+ .container {
85
+ display: flex;
86
+ flex-direction: column;
87
+ gap: 20px;
88
+ }
89
+ .row {
90
+ display: flex;
91
+ gap: 20px;
92
+ }
93
+ .column {
94
+ flex: 1;
95
+ }
96
+ img {
97
+ max-width: 100%;
98
+ height: auto;
99
+ border: 1px solid #ddd;
100
+ }
101
+ .form-group {
102
+ margin-bottom: 15px;
103
+ }
104
+ label {
105
+ display: block;
106
+ margin-bottom: 5px;
107
+ font-weight: bold;
108
+ }
109
+ select, input, button {
110
+ width: 100%;
111
+ padding: 8px;
112
+ box-sizing: border-box;
113
+ }
114
+ button {
115
+ background-color: #4CAF50;
116
+ color: white;
117
+ border: none;
118
+ cursor: pointer;
119
+ padding: 10px;
120
+ }
121
+ button:hover {
122
+ background-color: #45a049;
123
+ }
124
+ #fetch-code {
125
+ width: 100%;
126
+ height: 150px;
127
+ font-family: monospace;
128
+ padding: 10px;
129
+ box-sizing: border-box;
130
+ background-color: #f5f5f5;
131
+ border: 1px solid #ddd;
132
+ }
133
+ .coords-input {
134
+ display: none;
135
+ }
136
+ </style>
137
+ </head>
138
+ <body>
139
+ <h1>RemBG API</h1>
140
+ <p>Upload an image to process with RemBG. Select options and click "Process Image".</p>
141
+
142
+ <div class="container">
143
+ <div class="row">
144
+ <div class="column">
145
+ <div class="form-group">
146
+ <label for="file">Input Image:</label>
147
+ <input type="file" id="file" accept="image/*">
148
+ </div>
149
+ <img id="input-image" src="" alt="Input image will appear here">
150
+ </div>
151
+ <div class="column">
152
+ <div class="form-group">
153
+ <label>Output Image:</label>
154
+ <img id="output-image" src="" alt="Output image will appear here">
155
+ </div>
156
+ </div>
157
+ </div>
158
 
159
+ <div class="row">
160
+ <div class="column">
161
+ <div class="form-group">
162
+ <label for="mask">Output Type:</label>
163
+ <select id="mask">
164
+ <option value="Default">Default</option>
165
+ <option value="Mask only">Mask only</option>
166
+ </select>
167
+ </div>
168
+ </div>
169
+ <div class="column">
170
+ <div class="form-group">
171
+ <label for="model">Model Selection:</label>
172
+ <select id="model">
173
+ <option value="u2net">u2net</option>
174
+ <option value="u2netp">u2netp</option>
175
+ <option value="u2net_human_seg">u2net_human_seg</option>
176
+ <option value="u2net_cloth_seg">u2net_cloth_seg</option>
177
+ <option value="silueta">silueta</option>
178
+ <option value="isnet-general-use" selected>isnet-general-use</option>
179
+ <option value="isnet-anime">isnet-anime</option>
180
+ <option value="sam">sam</option>
181
+ <option value="birefnet-general">birefnet-general</option>
182
+ <option value="birefnet-general-lite">birefnet-general-lite</option>
183
+ <option value="birefnet-portrait">birefnet-portrait</option>
184
+ <option value="birefnet-dis">birefnet-dis</option>
185
+ <option value="birefnet-hrsod">birefnet-hrsod</option>
186
+ <option value="birefnet-cod">birefnet-cod</option>
187
+ <option value="birefnet-massive">birefnet-massive</option>
188
+ </select>
189
+ </div>
190
+ </div>
191
+ </div>
192
+
193
+ <div id="coords-section" style="display: none;">
194
+ <h3>SAM Model Coordinates</h3>
195
+ <p>Click on the image to set coordinates (for SAM model only)</p>
196
+ <div class="row">
197
+ <div class="column">
198
+ <div class="form-group">
199
+ <label for="x">X Coordinate:</label>
200
+ <input type="number" id="x" class="coords-input">
201
+ </div>
202
+ </div>
203
+ <div class="column">
204
+ <div class="form-group">
205
+ <label for="y">Y Coordinate:</label>
206
+ <input type="number" id="y" class="coords-input">
207
+ </div>
208
+ </div>
209
+ </div>
210
+ </div>
211
+
212
+ <button id="process-btn">Process Image</button>
213
+
214
+ <div class="form-group">
215
+ <label for="fetch-code">Fetch Code:</label>
216
+ <textarea id="fetch-code" readonly></textarea>
217
+ </div>
218
+ </div>
219
+
220
+ <script>
221
+ const fileInput = document.getElementById('file');
222
+ const inputImage = document.getElementById('input-image');
223
+ const outputImage = document.getElementById('output-image');
224
+ const maskSelect = document.getElementById('mask');
225
+ const modelSelect = document.getElementById('model');
226
+ const xInput = document.getElementById('x');
227
+ const yInput = document.getElementById('y');
228
+ const coordsSection = document.getElementById('coords-section');
229
+ const processBtn = document.getElementById('process-btn');
230
+ const fetchCodeTextarea = document.getElementById('fetch-code');
231
+
232
+ // 画像プレビュー
233
+ fileInput.addEventListener('change', function(e) {
234
+ const file = e.target.files[0];
235
+ if (file) {
236
+ const reader = new FileReader();
237
+ reader.onload = function(event) {
238
+ inputImage.src = event.target.result;
239
+ updateFetchCode();
240
+ };
241
+ reader.readAsDataURL(file);
242
+ }
243
+ });
244
+
245
+ // モデル選択でSAMの場合は座標入力表示
246
+ modelSelect.addEventListener('change', function() {
247
+ const isSam = modelSelect.value === 'sam';
248
+ coordsSection.style.display = isSam ? 'block' : 'none';
249
+ document.querySelectorAll('.coords-input').forEach(el => {
250
+ el.style.display = isSam ? 'block' : 'none';
251
+ });
252
+ updateFetchCode();
253
+ });
254
+
255
+ // 画像クリックで座標取得 (SAMモデルのみ)
256
+ inputImage.addEventListener('click', function(e) {
257
+ if (modelSelect.value === 'sam') {
258
+ const rect = e.target.getBoundingClientRect();
259
+ const x = e.clientX - rect.left;
260
+ const y = e.clientY - rect.top;
261
+
262
+ xInput.value = Math.round(x);
263
+ yInput.value = Math.round(y);
264
+ updateFetchCode();
265
+ }
266
+ });
267
+
268
+ // その他の入力変更時
269
+ [maskSelect, xInput, yInput].forEach(el => {
270
+ el.addEventListener('change', updateFetchCode);
271
+ });
272
+
273
+ // 画像処理
274
+ processBtn.addEventListener('click', async function() {
275
+ if (!fileInput.files || fileInput.files.length === 0) {
276
+ alert('Please select an image file');
277
+ return;
278
+ }
279
+
280
+ const formData = new FormData();
281
+ formData.append('file', fileInput.files[0]);
282
+ formData.append('mask', maskSelect.value);
283
+ formData.append('model', modelSelect.value);
284
+
285
+ if (modelSelect.value === 'sam' && xInput.value && yInput.value) {
286
+ formData.append('x', xInput.value);
287
+ formData.append('y', yInput.value);
288
+ }
289
+
290
+ try {
291
+ const response = await fetch('/api/process', {
292
+ method: 'POST',
293
+ body: formData
294
+ });
295
+
296
+ if (!response.ok) {
297
+ const error = await response.json();
298
+ throw new Error(error.error || 'Failed to process image');
299
+ }
300
+
301
+ const blob = await response.blob();
302
+ outputImage.src = URL.createObjectURL(blob);
303
+ } catch (error) {
304
+ alert('Error: ' + error.message);
305
+ console.error(error);
306
+ }
307
+ });
308
+
309
+ // Fetchコード生成
310
+ function updateFetchCode() {
311
+ const file = fileInput.files && fileInput.files[0];
312
+ if (!file) {
313
+ fetchCodeTextarea.value = '// Select an image first';
314
+ return;
315
+ }
316
+
317
+ const mask = maskSelect.value;
318
+ const model = modelSelect.value;
319
+ const x = xInput.value;
320
+ const y = yInput.value;
321
+
322
+ let code = `const formData = new FormData();\n`;
323
+ code += `formData.append('file', fileInput.files[0]);\n`;
324
+ code += `formData.append('mask', '${mask}');\n`;
325
+ code += `formData.append('model', '${model}');\n`;
326
+
327
+ if (model === 'sam' && x && y) {
328
+ code += `formData.append('x', '${x}');\n`;
329
+ code += `formData.append('y', '${y}');\n`;
330
+ }
331
+
332
+ code += `\n`;
333
+ code += `fetch('http://${window.location.host}/api/process', {\n`;
334
+ code += ` method: 'POST',\n`;
335
+ code += ` body: formData\n`;
336
+ code += `})\n`;
337
+ code += `.then(response => {\n`;
338
+ code += ` if (!response.ok) {\n`;
339
+ code += ` return response.json().then(err => { throw new Error(err.error); });\n`;
340
+ code += ` }\n`;
341
+ code += ` return response.blob();\n`;
342
+ code += `})\n`;
343
+ code += `.then(blob => {\n`;
344
+ code += ` // Handle the processed image blob\n`;
345
+ code += ` const imgUrl = URL.createObjectURL(blob);\n`;
346
+ code += ` document.getElementById('output-image').src = imgUrl;\n`;
347
+ code += `})\n`;
348
+ code += `.catch(error => {\n`;
349
+ code += ` console.error('Error:', error);\n`;
350
+ code += ` alert('Error: ' + error.message);\n`;
351
+ code += `});`;
352
+
353
+ fetchCodeTextarea.value = code;
354
+ }
355
+
356
+ // 初期化
357
+ updateFetchCode();
358
+ </script>
359
+ </body>
360
+ </html>
361
+ """
362
 
363
+ @app.route('/')
364
+ def index():
365
+ return render_template_string(HTML_TEMPLATE)
 
 
 
 
 
 
 
366
 
367
+ if __name__ == '__main__':
368
+ app.run(host='0.0.0.0', port=7860)