"""Gradio application that removes image backgrounds with a BiRefNet model.

At import time the module prepares the local model directory, loads the
segmentation model into memory and builds the Gradio ``demo`` interface.
Running the file as a script launches the web server on port 7860.
"""

import json
import logging
import os
from pathlib import Path
import sys
from datetime import datetime

import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ========== CONFIGURATION ==========
MODEL_ID = "mohantesting/remove_background"
MODEL_PATH = "./models/remove_background"
MAX_IMAGE_SIZE = 2048
PROCESSING_SIZE = (1024, 1024)
# ===================================

# Populated by load_model().
model = None
device = None
transform_image = None

# Runtime counters. "total_processed" counts successful removals only;
# failures are counted separately in "total_errors".
stats = {
    "total_processed": 0,
    "total_errors": 0,
    "start_time": datetime.now()
}


def check_model_exists(path):
    """Return True if *path* looks like a saved model directory.

    The directory must exist, contain ``config.json`` at its top level, and
    contain at least one ``.bin`` or ``.safetensors`` file anywhere below it.
    """
    if not os.path.exists(path):
        return False

    required_files = ["config.json"]
    if not all(os.path.exists(os.path.join(path, name)) for name in required_files):
        return False

    return any(
        name.endswith((".bin", ".safetensors"))
        for _root, _dirs, names in os.walk(path)
        for name in names
    )


def get_folder_size(folder_path):
    """Return the total size in bytes of all regular files under *folder_path*."""
    return sum(
        os.path.getsize(os.path.join(dirpath, filename))
        for dirpath, _dirnames, filenames in os.walk(folder_path)
        for filename in filenames
        if os.path.isfile(os.path.join(dirpath, filename))
    )


def download_model():
    """Prepare the local model directory.

    If a model is already present at ``MODEL_PATH`` its location and size are
    logged. Otherwise the directory is created; no weights are downloaded
    here because the model relies on custom code and is loaded straight from
    the HuggingFace Hub by ``load_model``.

    Returns:
        True on success, False if the directory could not be created.
    """
    logger.info("="*60)
    logger.info("CHECKING BACKGROUND REMOVAL MODEL...")
    logger.info("="*60)

    if check_model_exists(MODEL_PATH):
        logger.info("✓ Model sudah ada di local!")
        logger.info(f"✓ Location: {MODEL_PATH}")
        size_bytes = get_folder_size(MODEL_PATH)
        size_mb = size_bytes / (1024 * 1024)
        logger.info(f"✓ Size: {size_mb:.2f} MB")
        logger.info("✓ Skipping download...\n")
        return True

    logger.info("✗ Model tidak ditemukan. Mulai download...")
    logger.info(f"Model ID: {MODEL_ID}")
    logger.info(f"Save to: {MODEL_PATH}")
    logger.info("-" * 60)

    try:
        os.makedirs(MODEL_PATH, exist_ok=True)
        logger.info("Downloading background removal model...")
        # Nothing is saved here: the model uses custom code (BiRefNet), so it
        # is loaded directly from HuggingFace in load_model().
        logger.info("✓ Model akan di-load langsung dari HuggingFace\n")
        return True
    except Exception as e:
        # logger.exception records the traceback through the logging system.
        logger.exception(f"✗ Error: {str(e)}")
        return False


def load_model():
    """Load the segmentation model and preprocessing pipeline into globals.

    Sets the module-level ``model``, ``device`` and ``transform_image``.
    A local copy at ``MODEL_PATH`` is tried first; if that fails for any
    reason the model is fetched from the HuggingFace Hub and a best-effort
    attempt is made to save it locally for the next start.

    Returns:
        True when the model is ready, False if loading failed.
    """
    global model, device, transform_image

    logger.info("="*60)
    logger.info("LOADING MODEL INTO MEMORY...")
    logger.info("="*60)

    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Device: {device}")

        # Try the local copy first; fall back to the HuggingFace Hub.
        try:
            if check_model_exists(MODEL_PATH):
                logger.info("Attempting to load from local...")
                # Make the model's custom modules importable.
                if MODEL_PATH not in sys.path:
                    sys.path.insert(0, MODEL_PATH)

                model = AutoModelForImageSegmentation.from_pretrained(
                    MODEL_PATH,
                    trust_remote_code=True,
                    local_files_only=True
                )
                logger.info("✓ Loaded from local cache")
            else:
                raise FileNotFoundError("Local model not found")
        except Exception as e:
            logger.info(f"Local load failed: {str(e)}")
            logger.info("Loading from HuggingFace Hub...")

            model = AutoModelForImageSegmentation.from_pretrained(
                MODEL_ID,
                trust_remote_code=True
            )
            logger.info("✓ Loaded from HuggingFace Hub")

            # Save for the next start; a failure here is not fatal.
            try:
                logger.info("Saving model to local cache...")
                model.save_pretrained(MODEL_PATH)
                logger.info(f"✓ Model saved to {MODEL_PATH}")
            except Exception as save_err:
                logger.warning(f"Could not save model: {save_err}")

        model.eval().to(device)

        # Preprocessing: fixed-size resize, tensor conversion, normalization.
        transform_image = transforms.Compose([
            transforms.Resize(PROCESSING_SIZE),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        logger.info("="*60)
        logger.info("✓ MODEL READY!")
        logger.info(f" Model: {MODEL_ID}")
        logger.info(f" Device: {device}")
        logger.info(f" Processing Size: {PROCESSING_SIZE}")
        logger.info("="*60 + "\n")

        return True

    except Exception as e:
        logger.exception(f"✗ Failed to load model: {str(e)}")
        return False


# ========== STARTUP SEQUENCE ==========
logger.info("\n" + "="*60)
logger.info(f" APPLICATION STARTUP - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("="*60 + "\n")

if not download_model():
    raise RuntimeError("Failed to prepare model")

if not load_model():
    raise RuntimeError("Failed to load model into memory")
# ========================================


def remove_background(input_image):
    """Remove the background from *input_image*.

    Args:
        input_image: A PIL image, or any object accepted by
            ``Image.fromarray``. ``None`` is reported as an error.

    Returns:
        A tuple ``(output_image, mask, json_text)``. On success
        ``output_image`` is the (possibly downscaled) input with the
        predicted mask as its alpha channel, ``mask`` is that mask as a PIL
        image and ``json_text`` describes the result. On failure the first
        two items are ``None`` and ``json_text`` carries the error message.
    """
    try:
        if model is None or transform_image is None:
            return None, None, json.dumps({
                "success": False,
                "error": "Model belum siap"
            }, indent=2, ensure_ascii=False)

        if input_image is None:
            return None, None, json.dumps({
                "success": False,
                "error": "Image tidak boleh kosong"
            }, indent=2, ensure_ascii=False)

        # Normalize the input to an RGB PIL image.
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image).convert("RGB")
        else:
            input_image = input_image.convert("RGB")

        original_size = input_image.size
        logger.info(f"Processing image... Size: {original_size[0]}x{original_size[1]}")

        # Downscale so the longest side is at most MAX_IMAGE_SIZE.
        max_dim = max(original_size)
        if max_dim > MAX_IMAGE_SIZE:
            scale = MAX_IMAGE_SIZE / max_dim
            new_size = (int(original_size[0] * scale), int(original_size[1] * scale))
            logger.info(f"Resizing large image to {new_size[0]}x{new_size[1]}")
            input_image = input_image.resize(new_size, Image.Resampling.LANCZOS)

        # Preprocess and add a batch dimension.
        input_tensor = transform_image(input_image).unsqueeze(0).to(device)

        # Inference: take the last model output and squash it with a sigmoid.
        with torch.no_grad():
            preds = model(input_tensor)[-1].sigmoid().cpu()

        pred = preds[0].squeeze()
        pred_pil = transforms.ToPILImage()(pred)
        # Bring the mask back to the size of the image it will be applied to.
        mask = pred_pil.resize(input_image.size, Image.Resampling.LANCZOS)

        # Attach the mask as the alpha channel of a copy of the input.
        output_image = input_image.copy()
        output_image.putalpha(mask)

        stats["total_processed"] += 1

        logger.info(f"✓ Background removed. Output: {output_image.width}x{output_image.height}")

        result = {
            "success": True,
            "input_size": f"{input_image.width}x{input_image.height}",
            "output_size": f"{output_image.width}x{output_image.height}",
            "output_format": "PNG with alpha channel",
            "model": MODEL_ID,
            "device": device,
            "processing_time": "~1-3 seconds"
        }

        return output_image, mask, json.dumps(result, indent=2, ensure_ascii=False)

    except Exception as e:
        stats["total_errors"] += 1
        logger.error(f"Error removing background: {str(e)}", exc_info=True)
        return None, None, json.dumps({
            "success": False,
            "error": str(e)
        }, indent=2, ensure_ascii=False)


def get_model_info():
    """Return model metadata and runtime statistics as a JSON string.

    If building the report fails, a JSON object with a single ``error`` key
    is returned instead.
    """
    try:
        uptime = datetime.now() - stats["start_time"]

        # "total_processed" only counts successes, so the number of attempts
        # is successes + errors. Report 0.0% when nothing has been attempted.
        successes = stats["total_processed"]
        failures = stats["total_errors"]
        attempts = successes + failures
        success_rate = successes / attempts * 100 if attempts else 0.0

        info = {
            "model_name": "Background Removal Model (BiRefNet)",
            "model_id": MODEL_ID,
            "model_path": MODEL_PATH,
            "model_type": "Image Segmentation",
            "architecture": "BiRefNet (Bilateral Reference Network)",
            "device": device if device else "unknown",
            "model_loaded": model is not None,
            "processing_size": f"{PROCESSING_SIZE[0]}x{PROCESSING_SIZE[1]}",
            "max_input_size": f"{MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE}",
            "output_format": "PNG with transparency (alpha channel)",
            "statistics": {
                "total_processed": successes,
                "total_errors": failures,
                "uptime": str(uptime).split('.')[0],
                "success_rate": f"{success_rate:.1f}%"
            },
            "capabilities": [
                "Automatic background removal",
                "High-quality segmentation",
                "Preserve original image resolution",
                "Generate alpha mask",
                "Handles complex backgrounds"
            ],
            "use_cases": [
                "Product photography",
                "Portrait editing",
                "E-commerce images",
                "Graphic design",
                "Social media content",
                "Profile pictures"
            ],
            "technical_details": {
                "framework": "PyTorch + Transformers",
                "trust_remote_code": True,
                "normalization": "ImageNet stats",
                "model_size": "~840MB"
            }
        }
        return json.dumps(info, indent=2, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": str(e)}, indent=2, ensure_ascii=False)


# Custom CSS
custom_css = """
#output_json {
    font-family: 'Courier New', monospace;
    font-size: 14px;
}
.gradio-container {
    max-width: 1600px !important;
}
.tab-nav button {
    font-size: 16px;
    font-weight: 500;
}
"""

# Gradio Interface
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Background Removal API
    ### AI-powered automatic background removal using BiRefNet
    Remove backgrounds from images with high-quality segmentation
    """)

    with gr.Tabs():
        # Tab: Background Removal
        with gr.Tab("✂️ Remove Background"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="📸 Input Image",
                        type="pil",
                        height=450
                    )

                    with gr.Row():
                        remove_btn = gr.Button(
                            "✂️ Remove Background",
                            variant="primary",
                            size="lg",
                            scale=2
                        )
                        clear_btn = gr.ClearButton(
                            components=[input_image],
                            value="🗑️ Clear",
                            size="lg",
                            scale=1
                        )

                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="🖼️ Output (No Background)",
                        type="pil",
                        height=450
                    )
                    output_mask = gr.Image(
                        label="🎭 Alpha Mask",
                        type="pil",
                        height=200
                    )
                    output_json = gr.Code(
                        label="📄 JSON Response",
                        language="json",
                        lines=10,
                        elem_id="output_json"
                    )

            gr.Markdown("""
            ### 💡 Tips for Best Results
            - Use images with **clear subject-background separation**
            - **Good lighting** improves accuracy
            - **Higher resolution** = better edge quality
            - Images are automatically resized if too large (max 2048px)
            - Save as **PNG** to preserve transparency
            """)

            remove_btn.click(
                fn=remove_background,
                inputs=[input_image],
                outputs=[output_image, output_mask, output_json]
            )

        # Tab: Model Info
        with gr.Tab("ℹ️ Model Info"):
            model_info_output = gr.Code(
                label="Model Information & Statistics",
                language="json",
                lines=35
            )
            info_btn = gr.Button("🔍 Get Model Info & Stats", variant="secondary", size="lg")

            gr.Markdown("""
            ### About BiRefNet

            **BiRefNet** (Bilateral Reference Network) is a state-of-the-art image segmentation model specifically designed for high-quality background removal.
            It uses bilateral reference mechanisms to achieve precise object segmentation with clean edges.

            **Key Features:**
            - Advanced bilateral architecture for precise segmentation
            - Handles complex backgrounds and fine details
            - Preserves hair, fur, and transparent objects
            - Production-ready quality
            """)

            info_btn.click(
                fn=get_model_info,
                inputs=[],
                outputs=model_info_output
            )

        # Tab: API Documentation
        with gr.Tab("📚 API Usage"):
            gr.Markdown("""
            ## 🚀 API Usage Guide

            ### 1. Python Example with Requests

            ```python
            import requests
            import base64
            from PIL import Image
            from io import BytesIO
            import json

            # Load and encode image
            with open("input.jpg", "rb") as f:
                img_data = base64.b64encode(f.read()).decode()

            # API endpoint
            url = "https://YOUR-SPACE-URL/api/predict"

            payload = {
                "data": [f"data:image/jpeg;base64,{img_data}"]
            }

            # Make request
            response = requests.post(url, json=payload)
            result = response.json()

            # Get output image (PNG with transparency)
            output_image_data = result['data'][0]
            output_json = json.loads(result['data'][2])

            # Decode and save
            img_bytes = base64.b64decode(output_image_data.split(',')[1])
            img = Image.open(BytesIO(img_bytes))
            img.save('output_no_bg.png')

            print(json.dumps(output_json, indent=2))
            ```

            ### 2. Using Gradio Client

            ```python
            from gradio_client import Client
            from PIL import Image

            client = Client("YOUR-SPACE-URL")

            # Process image
            result = client.predict(
                input_image="path/to/image.jpg",
                api_name="/predict"
            )

            # result contains: [output_image, mask, json_response]
            output_path, mask_path, json_data = result

            # Load and use
            output = Image.open(output_path)
            output.save("no_background.png")
            ```

            ### 3. Response Format

            ```json
            {
              "success": true,
              "input_size": "1200x1600",
              "output_size": "1200x1600",
              "output_format": "PNG with alpha channel",
              "model": "mohantesting/remove_background",
              "device": "cuda",
              "processing_time": "~1-3 seconds"
            }
            ```

            ### 4. Batch Processing Script

            ```python
            import os
            from pathlib import Path
            from gradio_client import Client
            from PIL import Image

            client = Client("YOUR-SPACE-URL")

            input_dir = 'input_images'
            output_dir = 'output_images'
            os.makedirs(output_dir, exist_ok=True)

            for img_file in Path(input_dir).glob('*.jpg'):
                print(f"Processing: {img_file.name}")

                result = client.predict(
                    input_image=str(img_file),
                    api_name="/predict"
                )

                output_path = result[0]
                img = Image.open(output_path)

                save_path = Path(output_dir) / f"{img_file.stem}_no_bg.png"
                img.save(save_path)
                print(f"✓ Saved: {save_path}")
            ```

            ### 5. Output Format Details

            **Image Format:**
            - Format: PNG with full alpha transparency
            - Resolution: Same as input (up to 2048x2048)
            - Background: Completely transparent (alpha = 0)
            - Foreground: Fully preserved with smooth edges

            **Alpha Mask:**
            - Grayscale image showing segmentation confidence
            - White (255) = foreground
            - Black (0) = background
            - Gray values = edge transitions

            ### 6. Best Practices

            ✅ **DO:**
            - Use high-resolution images (1000px+ recommended)
            - Ensure good contrast between subject and background
            - Use well-lit, sharp images
            - Save output as PNG to preserve transparency
            - Test with sample images first

            ❌ **DON'T:**
            - Don't use extremely large images (>4K) - they'll be auto-resized
            - Don't expect perfect results on very complex backgrounds
            - Don't save as JPEG (loses transparency!)
            - Don't use blurry or low-quality input images

            ### 7. Common Use Cases

            **E-Commerce Product Photos:**
            ```python
            # Remove background for clean product shots
            result = remove_background('product.jpg')
            result.save('product_transparent.png')
            # Upload to Shopify, Amazon, etc.
            ```

            **Portrait Photography:**
            ```python
            # Create professional headshots
            result = remove_background('portrait.jpg')
            # Composite on professional backgrounds
            ```

            **Social Media Content:**
            ```python
            # Create stickers, cutouts, graphics
            result = remove_background('subject.jpg')
            # Use in Instagram, TikTok, YouTube thumbnails
            ```

            **Graphic Design:**
            ```python
            # Create design elements
            result = remove_background('object.jpg')
            # Import into Photoshop, Illustrator, Canva
            ```

            ### 8. Performance Metrics

            - **Processing Time**: 1-3 seconds per image (GPU) / 5-10 seconds (CPU)
            - **Max Resolution**: 2048x2048 (auto-resized if larger)
            - **Model Size**: ~840MB
            - **GPU Memory**: ~2GB recommended
            - **Accuracy**: High-quality segmentation with clean edges

            ### 9. Error Handling

            ```python
            try:
                result = client.predict(input_image="image.jpg")
                output_data = json.loads(result[2])

                if output_data['success']:
                    print("Success!")
                else:
                    print(f"Error: {output_data['error']}")
            except Exception as e:
                print(f"Request failed: {e}")
            ```

            ### 10. Rate Limits & Quotas

            - No built-in rate limits (depends on hosting)
            - For HuggingFace Spaces: Check your space tier
            - For self-hosted: Limited by GPU/CPU resources
            - Recommended: Process images sequentially for stability

            ---

            **Model:** mohantesting/remove_background (BiRefNet)
            **Framework:** PyTorch + Transformers + Gradio
            **License:** Check model repository for licensing details
            """)

    gr.Markdown("""
    ---

    🚀 Ready to integrate background removal into your app?

    Use the API documentation above to get started!
    """)

# Launch
if __name__ == "__main__":
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )