hann1010 commited on
Commit
8b5a088
·
verified ·
1 Parent(s): 25e98dd

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +466 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,466 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from PIL import Image
4
+ from torchvision import transforms
5
+ from transformers import AutoModelForImageSegmentation
6
+ import json
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+ import io
11
+ import base64
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ # ========== KONFIGURASI ==========
17
+ MODEL_ID = "mohantesting/remove_background"
18
+ MODEL_PATH = "./models/remove_background"
19
+ # ==================================
20
+
21
+ model = None
22
+ device = None
23
+ transform_image = None
24
+
25
+ def check_model_exists(path):
26
+ """Cek apakah model sudah ada"""
27
+ if not os.path.exists(path):
28
+ return False
29
+
30
+ required_files = ["config.json"]
31
+ for file in required_files:
32
+ if not os.path.exists(os.path.join(path, file)):
33
+ return False
34
+
35
+ has_weights = False
36
+ for root, dirs, files in os.walk(path):
37
+ for file in files:
38
+ if file.endswith((".bin", ".safetensors")):
39
+ has_weights = True
40
+ break
41
+ if has_weights:
42
+ break
43
+
44
+ return has_weights
45
+
46
+ def get_folder_size(folder_path):
47
+ """Hitung total ukuran folder"""
48
+ total_size = 0
49
+ for dirpath, dirnames, filenames in os.walk(folder_path):
50
+ for filename in filenames:
51
+ filepath = os.path.join(dirpath, filename)
52
+ if os.path.isfile(filepath):
53
+ total_size += os.path.getsize(filepath)
54
+ return total_size
55
+
56
+ def download_model():
57
+ """Download model jika belum ada"""
58
+ logger.info("="*60)
59
+ logger.info("CHECKING BACKGROUND REMOVAL MODEL...")
60
+ logger.info("="*60)
61
+
62
+ if check_model_exists(MODEL_PATH):
63
+ logger.info("✓ Model sudah ada di local!")
64
+ logger.info(f"✓ Location: {MODEL_PATH}")
65
+
66
+ size_bytes = get_folder_size(MODEL_PATH)
67
+ size_mb = size_bytes / (1024 * 1024)
68
+ logger.info(f"✓ Size: {size_mb:.2f} MB")
69
+ logger.info("✓ Skipping download...\n")
70
+ return True
71
+
72
+ logger.info("✗ Model tidak ditemukan. Mulai download...")
73
+ logger.info(f"Model ID: {MODEL_ID}")
74
+ logger.info(f"Save to: {MODEL_PATH}")
75
+ logger.info("-" * 60)
76
+
77
+ try:
78
+ os.makedirs(MODEL_PATH, exist_ok=True)
79
+
80
+ logger.info("Downloading background removal model...")
81
+ model_download = AutoModelForImageSegmentation.from_pretrained(
82
+ MODEL_ID,
83
+ trust_remote_code=True
84
+ )
85
+ model_download.save_pretrained(MODEL_PATH)
86
+
87
+ logger.info("✓ Model downloaded\n")
88
+
89
+ size_bytes = get_folder_size(MODEL_PATH)
90
+ size_mb = size_bytes / (1024 * 1024)
91
+ logger.info(f"✓ Total size: {size_mb:.2f} MB")
92
+ logger.info(f"✓ Model saved at: {MODEL_PATH}\n")
93
+
94
+ del model_download
95
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
96
+
97
+ return True
98
+
99
+ except Exception as e:
100
+ logger.error(f"✗ Error downloading model: {str(e)}")
101
+ import traceback
102
+ traceback.print_exc()
103
+ return False
104
+
105
+ def load_model():
106
+ """Load model ke memory"""
107
+ global model, device, transform_image
108
+
109
+ logger.info("="*60)
110
+ logger.info("LOADING MODEL INTO MEMORY...")
111
+ logger.info("="*60)
112
+
113
+ try:
114
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
115
+ logger.info(f"Device: {device}")
116
+
117
+ logger.info("Loading model from local...")
118
+ model = AutoModelForImageSegmentation.from_pretrained(
119
+ MODEL_PATH,
120
+ trust_remote_code=True
121
+ ).eval().to(device)
122
+
123
+ # Setup transform
124
+ image_size = (1024, 1024)
125
+ transform_image = transforms.Compose([
126
+ transforms.Resize(image_size),
127
+ transforms.ToTensor(),
128
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
129
+ ])
130
+
131
+ logger.info("="*60)
132
+ logger.info("✓ MODEL READY!")
133
+ logger.info(f" Model: {MODEL_ID}")
134
+ logger.info(f" Device: {device}")
135
+ logger.info(f" Image Size: {image_size}")
136
+ logger.info("="*60 + "\n")
137
+
138
+ return True
139
+
140
+ except Exception as e:
141
+ logger.error(f"✗ Failed to load model: {str(e)}")
142
+ import traceback
143
+ traceback.print_exc()
144
+ return False
145
+
146
+ # ========== STARTUP SEQUENCE ==========
147
+ logger.info("Starting application...")
148
+
149
+ if not download_model():
150
+ raise Exception("Failed to download model")
151
+
152
+ if not load_model():
153
+ raise Exception("Failed to load model into memory")
154
+
155
+ # ========================================
156
+
157
+ def remove_background(input_image):
158
+ """Remove background dari image"""
159
+
160
+ try:
161
+ if model is None or transform_image is None:
162
+ return None, None, json.dumps({
163
+ "success": False,
164
+ "error": "Model belum siap"
165
+ }, indent=2, ensure_ascii=False)
166
+
167
+ if input_image is None:
168
+ return None, None, json.dumps({
169
+ "success": False,
170
+ "error": "Image tidak boleh kosong"
171
+ }, indent=2, ensure_ascii=False)
172
+
173
+ # Convert to PIL Image
174
+ if not isinstance(input_image, Image.Image):
175
+ input_image = Image.fromarray(input_image).convert("RGB")
176
+ else:
177
+ input_image = input_image.convert("RGB")
178
+
179
+ logger.info(f"Processing image... Size: {input_image.width}x{input_image.height}")
180
+
181
+ # Transform image
182
+ input_tensor = transform_image(input_image).unsqueeze(0).to(device)
183
+
184
+ # Prediction
185
+ with torch.no_grad():
186
+ preds = model(input_tensor)[-1].sigmoid().cpu()
187
+
188
+ pred = preds[0].squeeze()
189
+ pred_pil = transforms.ToPILImage()(pred)
190
+ mask = pred_pil.resize(input_image.size)
191
+
192
+ # Create output with alpha channel
193
+ output_image = input_image.copy()
194
+ output_image.putalpha(mask)
195
+
196
+ logger.info(f"✓ Background removed. Output: {output_image.width}x{output_image.height}")
197
+
198
+ # JSON result
199
+ result = {
200
+ "success": True,
201
+ "input_size": f"{input_image.width}x{input_image.height}",
202
+ "output_size": f"{output_image.width}x{output_image.height}",
203
+ "output_format": "PNG with alpha channel",
204
+ "model": MODEL_ID,
205
+ "device": device
206
+ }
207
+
208
+ return output_image, mask, json.dumps(result, indent=2, ensure_ascii=False)
209
+
210
+ except Exception as e:
211
+ logger.error(f"Error removing background: {str(e)}", exc_info=True)
212
+ return None, None, json.dumps({
213
+ "success": False,
214
+ "error": str(e)
215
+ }, indent=2, ensure_ascii=False)
216
+
217
+ def get_model_info():
218
+ """Return model info sebagai JSON"""
219
+ try:
220
+ info = {
221
+ "model_name": "Background Removal Model",
222
+ "model_id": MODEL_ID,
223
+ "model_path": MODEL_PATH,
224
+ "model_type": "Image Segmentation",
225
+ "device": device if device else "unknown",
226
+ "model_loaded": model is not None,
227
+ "image_processing_size": "1024x1024",
228
+ "output_format": "PNG with transparency (alpha channel)",
229
+ "capabilities": [
230
+ "Automatic background removal",
231
+ "High-quality segmentation",
232
+ "Preserve original image resolution",
233
+ "Generate alpha mask"
234
+ ],
235
+ "use_cases": [
236
+ "Product photography",
237
+ "Portrait editing",
238
+ "E-commerce images",
239
+ "Graphic design",
240
+ "Social media content"
241
+ ]
242
+ }
243
+
244
+ return json.dumps(info, indent=2, ensure_ascii=False)
245
+ except Exception as e:
246
+ return json.dumps({"error": str(e)}, indent=2, ensure_ascii=False)
247
+
248
+ # Custom CSS
249
+ custom_css = """
250
+ #output_json {
251
+ font-family: 'Courier New', monospace;
252
+ font-size: 14px;
253
+ }
254
+ .gradio-container {
255
+ max-width: 1600px !important;
256
+ }
257
+ """
258
+
259
+ # Gradio Interface
260
+ with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
261
+ gr.Markdown("""
262
+ # 🎨 Background Removal API
263
+ AI-powered automatic background removal using image segmentation
264
+ """)
265
+
266
+ with gr.Tabs():
267
+ # Tab Background Removal
268
+ with gr.Tab("✂️ Remove Background"):
269
+ with gr.Row():
270
+ with gr.Column(scale=1):
271
+ input_image = gr.Image(
272
+ label="📸 Input Image",
273
+ type="pil",
274
+ height=450
275
+ )
276
+
277
+ remove_btn = gr.Button(
278
+ "✂️ Remove Background",
279
+ variant="primary",
280
+ size="lg"
281
+ )
282
+
283
+ with gr.Column(scale=1):
284
+ output_image = gr.Image(
285
+ label="🖼️ Output (No Background)",
286
+ type="pil",
287
+ height=450
288
+ )
289
+
290
+ output_mask = gr.Image(
291
+ label="🎭 Alpha Mask",
292
+ type="pil",
293
+ height=200
294
+ )
295
+
296
+ output_json = gr.Code(
297
+ label="📄 JSON Output",
298
+ language="json",
299
+ lines=10,
300
+ elem_id="output_json"
301
+ )
302
+
303
+ remove_btn.click(
304
+ fn=remove_background,
305
+ inputs=[input_image],
306
+ outputs=[output_image, output_mask, output_json]
307
+ )
308
+
309
+ # Tab Model Info
310
+ with gr.Tab("ℹ️ Model Info"):
311
+ model_info_output = gr.Code(
312
+ label="Model Information",
313
+ language="json",
314
+ lines=30
315
+ )
316
+ info_btn = gr.Button("🔍 Get Model Info", variant="secondary")
317
+
318
+ info_btn.click(
319
+ fn=get_model_info,
320
+ inputs=[],
321
+ outputs=model_info_output
322
+ )
323
+
324
+ # Tab API Documentation
325
+ with gr.Tab("📚 API Usage"):
326
+ gr.Markdown("""
327
+ ## 🚀 API Usage Guide
328
+
329
+ ### 1. Python Example
330
+ ```python
331
+ import requests
332
+ import base64
333
+ from PIL import Image
334
+ from io import BytesIO
335
+ import json
336
+
337
+ # Load and encode image
338
+ with open("input.jpg", "rb") as f:
339
+ img_data = base64.b64encode(f.read()).decode()
340
+
341
+ url = "https://YOUR-SPACE-URL/api/predict"
342
+
343
+ payload = {
344
+ "data": [f"data:image/jpeg;base64,{img_data}"]
345
+ }
346
+
347
+ response = requests.post(url, json=payload)
348
+ result = response.json()
349
+
350
+ # Get output image (PNG with transparency)
351
+ output_image_data = result['data'][0]
352
+ output_json = json.loads(result['data'][2])
353
+
354
+ # Decode and save
355
+ img_bytes = base64.b64decode(output_image_data.split(',')[1])
356
+ img = Image.open(BytesIO(img_bytes))
357
+ img.save('output_no_bg.png')
358
+
359
+ print(json.dumps(output_json, indent=2))
360
+ ```
361
+
362
+ ### 2. Response Format
363
+ ```json
364
+ {
365
+ "success": true,
366
+ "input_size": "1200x1600",
367
+ "output_size": "1200x1600",
368
+ "output_format": "PNG with alpha channel",
369
+ "model": "mohantesting/remove_background",
370
+ "device": "cuda"
371
+ }
372
+ ```
373
+
374
+ ### 3. Output Format
375
+
376
+ - **Format**: PNG with transparency (alpha channel)
377
+ - **Resolution**: Same as input image
378
+ - **Background**: Completely transparent
379
+ - **Quality**: High-quality segmentation
380
+
381
+ ### 4. Best Practices
382
+
383
+ ✅ **DO:**
384
+ - Use high-resolution images for better results
385
+ - Ensure good contrast between subject and background
386
+ - Use well-lit images
387
+ - Save output as PNG to preserve transparency
388
+
389
+ ❌ **DON'T:**
390
+ - Don't use extremely large images (>4K) - may cause OOM
391
+ - Don't expect perfect results on complex backgrounds
392
+ - Don't save as JPEG (loses transparency)
393
+
394
+ ### 5. Use Cases
395
+
396
+ **E-commerce Product Photos:**
397
+ ```python
398
+ # Remove background from product image
399
+ result = remove_background('product.jpg')
400
+ result.save('product_no_bg.png')
401
+ ```
402
+
403
+ **Portrait Photography:**
404
+ ```python
405
+ # Remove background from portrait
406
+ result = remove_background('portrait.jpg')
407
+ # Can now composite on different backgrounds
408
+ ```
409
+
410
+ **Graphic Design:**
411
+ ```python
412
+ # Create cutouts for design work
413
+ result = remove_background('subject.jpg')
414
+ # Use in Photoshop, Canva, etc.
415
+ ```
416
+
417
+ ### 6. Integration Example
418
+
419
+ **Batch Processing:**
420
+ ```python
421
+ import os
422
+ from pathlib import Path
423
+
424
+ input_dir = 'input_images'
425
+ output_dir = 'output_images'
426
+
427
+ os.makedirs(output_dir, exist_ok=True)
428
+
429
+ for img_file in Path(input_dir).glob('*.jpg'):
430
+ result = remove_background(str(img_file))
431
+ output_path = Path(output_dir) / f"{img_file.stem}_no_bg.png"
432
+ result.save(output_path)
433
+ print(f"Processed: {img_file.name}")
434
+ ```
435
+
436
+ ### 7. Performance
437
+
438
+ - **Processing Time**: ~1-3 seconds per image (GPU)
439
+ - **Max Resolution**: Recommended up to 2048x2048
440
+ - **Model Size**: ~180MB
441
+ - **GPU Memory**: ~2GB recommended
442
+
443
+ ---
444
+
445
+ **Model:** mohantesting/remove_background
446
+ **Type:** Image Segmentation
447
+ **Framework:** PyTorch + Transformers
448
+ """)
449
+
450
+ gr.Markdown("""
451
+ ---
452
+ ### 💡 Tips for Best Results
453
+
454
+ - Use images with **clear subject-background separation**
455
+ - **Good lighting** improves accuracy
456
+ - **Higher resolution** = better edge quality
457
+ - Save as **PNG** to preserve transparency
458
+ """)
459
+
460
+ # Launch
461
+ if __name__ == "__main__":
462
+ demo.launch(
463
+ server_name="0.0.0.0",
464
+ server_port=7860,
465
+ share=False
466
+ )
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ torch
3
+ torchvision
4
+ transformers
5
+ Pillow
6
+ accelerate