Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from PIL import Image | |
| from torchvision import transforms | |
| from transformers import AutoModelForImageSegmentation | |
| import json | |
| import logging | |
| import os | |
| from pathlib import Path | |
| import sys | |
| from datetime import datetime | |
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ========== CONFIGURATION ==========
MODEL_ID = "mohantesting/remove_background"  # Hugging Face Hub repository id
MODEL_PATH = "./models/remove_background"  # local directory the model is cached in
MAX_IMAGE_SIZE = 2048  # longest side (px) accepted before an input is downscaled
PROCESSING_SIZE = (1024, 1024)  # size the model input is resized to
# ===================================

# Runtime state; filled in by load_model() during start-up.
model = device = transform_image = None

# Request counters reported by get_model_info().
stats = dict(
    total_processed=0,
    total_errors=0,
    start_time=datetime.now(),
)
def check_model_exists(path):
    """Return True if *path* looks like a usable local model directory.

    The directory qualifies when it exists, contains every required file
    (currently just ``config.json``) at its top level, and holds at least
    one ``.bin`` or ``.safetensors`` weight file anywhere beneath it.
    """
    if not os.path.exists(path):
        return False

    required = ("config.json",)
    if not all(os.path.exists(os.path.join(path, name)) for name in required):
        return False

    # any() stops at the first weight file found, like the early break did.
    weight_suffixes = (".bin", ".safetensors")
    return any(
        name.endswith(weight_suffixes)
        for _root, _dirs, names in os.walk(path)
        for name in names
    )
def get_folder_size(folder_path):
    """Return the combined size in bytes of every regular file under *folder_path*.

    Walks the tree recursively; a missing folder yields 0.
    """
    candidate_paths = (
        os.path.join(parent, name)
        for parent, _subdirs, names in os.walk(folder_path)
        for name in names
    )
    return sum(os.path.getsize(p) for p in candidate_paths if os.path.isfile(p))
def download_model():
    """Report a cached local copy of the model, or prepare its directory.

    No weights are downloaded here. The model uses custom (BiRefNet) code,
    so when no local copy exists it is loaded straight from the Hugging Face
    Hub by ``load_model``. This step only logs an existing copy (location and
    size) or creates ``MODEL_PATH`` so the model can be saved there later.

    Returns:
        True if a local copy exists or the directory is ready; False if the
        directory could not be created (the traceback is logged).
    """
    logger.info("=" * 60)
    logger.info("CHECKING BACKGROUND REMOVAL MODEL...")
    logger.info("=" * 60)

    if check_model_exists(MODEL_PATH):
        logger.info("✓ Model sudah ada di local!")
        logger.info(f"✓ Location: {MODEL_PATH}")
        size_mb = get_folder_size(MODEL_PATH) / (1024 * 1024)
        logger.info(f"✓ Size: {size_mb:.2f} MB")
        logger.info("✓ Skipping download...\n")
        return True

    # Log only what actually happens: nothing is downloaded at this point.
    logger.info("✗ Model tidak ditemukan di local.")
    logger.info(f"Model ID: {MODEL_ID}")
    logger.info(f"Save to: {MODEL_PATH}")
    logger.info("-" * 60)

    try:
        os.makedirs(MODEL_PATH, exist_ok=True)
    except OSError as e:
        # logger.exception routes the traceback through the configured
        # logging handlers instead of printing it straight to stderr.
        logger.exception(f"✗ Error: {str(e)}")
        return False

    logger.info("✓ Model akan di-load langsung dari HuggingFace\n")
    return True
def _load_local_or_hub():
    """Return the model from ``MODEL_PATH`` if possible, else from the HF Hub.

    Any failure of the local load falls back to the Hub. A model fetched from
    the Hub is then saved to ``MODEL_PATH`` on a best-effort basis so the next
    start-up can load it locally. Hub errors propagate to the caller.
    """
    try:
        if not check_model_exists(MODEL_PATH):
            raise FileNotFoundError("Local model not found")
        logger.info("Attempting to load from local...")
        # Put the model folder on sys.path so custom (trust_remote_code)
        # modules stored next to the weights can be imported.
        if MODEL_PATH not in sys.path:
            sys.path.insert(0, MODEL_PATH)
        local_model = AutoModelForImageSegmentation.from_pretrained(
            MODEL_PATH,
            trust_remote_code=True,
            local_files_only=True
        )
        logger.info("✓ Loaded from local cache")
        return local_model
    except Exception as e:
        logger.info(f"Local load failed: {str(e)}")

    logger.info("Loading from HuggingFace Hub...")
    hub_model = AutoModelForImageSegmentation.from_pretrained(
        MODEL_ID,
        trust_remote_code=True
    )
    logger.info("✓ Loaded from HuggingFace Hub")

    # Cache for the next start-up; failing to save is not fatal.
    try:
        logger.info("Saving model to local cache...")
        hub_model.save_pretrained(MODEL_PATH)
        logger.info(f"✓ Model saved to {MODEL_PATH}")
    except Exception as save_err:
        logger.warning(f"Could not save model: {save_err}")
    return hub_model


def load_model():
    """Load the segmentation model into memory and build its input transform.

    Sets these module globals:
      * ``device``: "cuda" when available, otherwise "cpu".
      * ``model``: in eval mode, on ``device``.
      * ``transform_image``: resize to ``PROCESSING_SIZE``, convert to a
        tensor, then normalise with ImageNet mean/std.

    Returns:
        True on success; False if any step fails (the traceback is logged).
    """
    global model, device, transform_image

    logger.info("=" * 60)
    logger.info("LOADING MODEL INTO MEMORY...")
    logger.info("=" * 60)

    try:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        logger.info(f"Device: {device}")

        model = _load_local_or_hub()
        model.eval().to(device)

        transform_image = transforms.Compose([
            transforms.Resize(PROCESSING_SIZE),
            transforms.ToTensor(),
            # ImageNet per-channel mean / std
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

        logger.info("=" * 60)
        logger.info("✓ MODEL READY!")
        logger.info(f" Model: {MODEL_ID}")
        logger.info(f" Device: {device}")
        logger.info(f" Processing Size: {PROCESSING_SIZE}")
        logger.info("=" * 60 + "\n")
        return True
    except Exception as e:
        # logger.exception records the traceback through the configured
        # logging handlers, consistent with remove_background's exc_info=True.
        logger.exception(f"✗ Failed to load model: {str(e)}")
        return False
# ========== STARTUP SEQUENCE ==========
# Runs at import time, before the Gradio UI below is built. If preparing or
# loading the model fails, start-up is aborted rather than serving a UI that
# cannot process images. RuntimeError is a subclass of Exception, so existing
# `except Exception` handlers still catch it.
logger.info("\n" + "=" * 60)
logger.info(f" APPLICATION STARTUP - {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
logger.info("=" * 60 + "\n")

if not download_model():
    raise RuntimeError("Failed to prepare model")
if not load_model():
    raise RuntimeError("Failed to load model into memory")
# ========================================
def _error_json(message):
    """Serialize a failure payload in the JSON shape the UI/API returns.

    Args:
        message: Human-readable error text stored under ``"error"``.

    Returns:
        Pretty-printed JSON string ``{"success": false, "error": message}``.
    """
    return json.dumps({
        "success": False,
        "error": message
    }, indent=2, ensure_ascii=False)


def remove_background(input_image):
    """Remove the background from an image using the loaded segmentation model.

    Args:
        input_image: A PIL image (the Gradio input uses ``type="pil"``) or an
            array accepted by ``PIL.Image.fromarray``. ``None`` is rejected.

    Returns:
        A tuple ``(output_image, mask, json_text)``:
          * ``output_image``: the RGB input with the predicted mask as its
            alpha channel.
          * ``mask``: the grayscale mask.
          * ``json_text``: a JSON status string.
        On failure both images are ``None`` and the JSON has
        ``"success": false``.

    Side effects:
        Increments ``stats["total_processed"]`` on success and
        ``stats["total_errors"]`` when processing raises.
    """
    import time  # only used to measure the real processing time

    if model is None or transform_image is None:
        return None, None, _error_json("Model belum siap")
    if input_image is None:
        return None, None, _error_json("Image tidak boleh kosong")

    try:
        started = time.perf_counter()

        # Normalise to 3-channel RGB; any existing alpha is discarded because
        # the predicted mask becomes the new alpha channel.
        if isinstance(input_image, Image.Image):
            image = input_image.convert("RGB")
        else:
            image = Image.fromarray(input_image).convert("RGB")

        original_size = image.size
        logger.info(f"Processing image... Size: {original_size[0]}x{original_size[1]}")

        # Downscale so the longest side is MAX_IMAGE_SIZE. max(1, ...) stops
        # very thin images from collapsing to a 0-px side, which PIL rejects.
        max_dim = max(original_size)
        if max_dim > MAX_IMAGE_SIZE:
            scale = MAX_IMAGE_SIZE / max_dim
            new_size = (
                max(1, int(original_size[0] * scale)),
                max(1, int(original_size[1] * scale)),
            )
            logger.info(f"Resizing large image to {new_size[0]}x{new_size[1]}")
            image = image.resize(new_size, Image.Resampling.LANCZOS)

        # transform_image resizes to the model's working resolution, so the
        # predicted mask is scaled back to the image size afterwards.
        input_tensor = transform_image(image).unsqueeze(0).to(device)
        with torch.no_grad():
            preds = model(input_tensor)[-1].sigmoid().cpu()
        pred = preds[0].squeeze()
        mask = transforms.ToPILImage()(pred).resize(image.size, Image.Resampling.LANCZOS)

        output_image = image.copy()
        output_image.putalpha(mask)
        elapsed = time.perf_counter() - started

        stats["total_processed"] += 1
        logger.info(f"✓ Background removed. Output: {output_image.width}x{output_image.height}")

        result = {
            "success": True,
            # Size the caller sent, not the (possibly downscaled) working copy.
            "input_size": f"{original_size[0]}x{original_size[1]}",
            "output_size": f"{output_image.width}x{output_image.height}",
            "output_format": "PNG with alpha channel",
            "model": MODEL_ID,
            "device": device,
            # Legacy static estimate, kept so existing clients still find it.
            "processing_time": "~1-3 seconds",
            "processing_time_seconds": round(elapsed, 3),
        }
        return output_image, mask, json.dumps(result, indent=2, ensure_ascii=False)
    except Exception as e:
        stats["total_errors"] += 1
        logger.error(f"Error removing background: {str(e)}", exc_info=True)
        return None, None, _error_json(str(e))
def get_model_info():
    """Return model metadata and runtime statistics as a pretty-printed JSON string.

    ``remove_background`` increments ``stats["total_processed"]`` only for
    successful requests and ``stats["total_errors"]`` only for failed ones.
    The success rate is therefore successes / (successes + failures). It is
    reported as ``0.0%`` before any request has been processed.

    Returns:
        JSON string with model details and statistics, or ``{"error": ...}``
        if building the payload fails.
    """
    try:
        processed = stats["total_processed"]
        errors = stats["total_errors"]
        attempts = processed + errors
        success_rate = processed / attempts * 100 if attempts else 0.0
        uptime = datetime.now() - stats["start_time"]

        info = {
            "model_name": "Background Removal Model (BiRefNet)",
            "model_id": MODEL_ID,
            "model_path": MODEL_PATH,
            "model_type": "Image Segmentation",
            "architecture": "BiRefNet (Bilateral Reference Network)",
            "device": device if device else "unknown",
            "model_loaded": model is not None,
            "processing_size": f"{PROCESSING_SIZE[0]}x{PROCESSING_SIZE[1]}",
            "max_input_size": f"{MAX_IMAGE_SIZE}x{MAX_IMAGE_SIZE}",
            "output_format": "PNG with transparency (alpha channel)",
            "statistics": {
                "total_processed": processed,
                "total_errors": errors,
                # Drop the microseconds part of the timedelta string.
                "uptime": str(uptime).split('.')[0],
                "success_rate": f"{success_rate:.1f}%"
            },
            "capabilities": [
                "Automatic background removal",
                "High-quality segmentation",
                "Preserve original image resolution",
                "Generate alpha mask",
                "Handles complex backgrounds"
            ],
            "use_cases": [
                "Product photography",
                "Portrait editing",
                "E-commerce images",
                "Graphic design",
                "Social media content",
                "Profile pictures"
            ],
            "technical_details": {
                "framework": "PyTorch + Transformers",
                "trust_remote_code": True,
                "normalization": "ImageNet stats",
                "model_size": "~840MB"
            }
        }
        return json.dumps(info, indent=2, ensure_ascii=False)
    except Exception as e:
        return json.dumps({"error": str(e)}, indent=2, ensure_ascii=False)
# Custom CSS
# Extra styles passed to gr.Blocks(css=...):
#   * a monospace font for the JSON response panel (the gr.Code component
#     created with elem_id="output_json");
#   * a wider .gradio-container (max 1600px);
#   * 16px tab-button labels.
custom_css = """
#output_json {
font-family: 'Courier New', monospace;
font-size: 14px;
}
.gradio-container {
max-width: 1600px !important;
}
.tab-nav button {
font-size: 16px;
font-weight: 500;
}
"""
# Gradio Interface
# The click handlers pin ``api_name`` explicitly. The "API Usage" tab documents
# these endpoint names, and pinning them keeps them stable across Gradio
# versions; on Gradio 4+ they equal the defaults derived from the function names.
with gr.Blocks(css=custom_css, theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🎨 Background Removal API
    ### AI-powered automatic background removal using BiRefNet
    Remove backgrounds from images with high-quality segmentation
    """)

    with gr.Tabs():
        # Tab: background removal (input on the left, results on the right)
        with gr.Tab("✂️ Remove Background"):
            with gr.Row():
                with gr.Column(scale=1):
                    input_image = gr.Image(
                        label="📸 Input Image",
                        type="pil",
                        height=450
                    )
                    with gr.Row():
                        remove_btn = gr.Button(
                            "✂️ Remove Background",
                            variant="primary",
                            size="lg",
                            scale=2
                        )
                        clear_btn = gr.ClearButton(
                            components=[input_image],
                            value="🗑️ Clear",
                            size="lg",
                            scale=1
                        )
                with gr.Column(scale=1):
                    output_image = gr.Image(
                        label="🖼️ Output (No Background)",
                        type="pil",
                        height=450
                    )
                    output_mask = gr.Image(
                        label="🎭 Alpha Mask",
                        type="pil",
                        height=200
                    )
                    output_json = gr.Code(
                        label="📄 JSON Response",
                        language="json",
                        lines=10,
                        elem_id="output_json"
                    )

            # Clearing the input also clears the previous results, so a stale
            # cut-out is never shown next to an empty input.
            clear_btn.add([output_image, output_mask, output_json])

            gr.Markdown("""
            ### 💡 Tips for Best Results
            - Use images with **clear subject-background separation**
            - **Good lighting** improves accuracy
            - **Higher resolution** = better edge quality
            - Images are automatically resized if too large (max 2048px on the longest side)
            - Save as **PNG** to preserve transparency
            """)

            remove_btn.click(
                fn=remove_background,
                inputs=[input_image],
                outputs=[output_image, output_mask, output_json],
                api_name="remove_background"
            )

        # Tab: model info & runtime statistics
        with gr.Tab("ℹ️ Model Info"):
            model_info_output = gr.Code(
                label="Model Information & Statistics",
                language="json",
                lines=35
            )
            info_btn = gr.Button("🔍 Get Model Info & Stats", variant="secondary", size="lg")

            gr.Markdown("""
            ### About BiRefNet
            **BiRefNet** (Bilateral Reference Network) is a state-of-the-art image segmentation model
            specifically designed for high-quality background removal. It uses bilateral reference
            mechanisms to achieve precise object segmentation with clean edges.

            **Key Features:**
            - Advanced bilateral architecture for precise segmentation
            - Handles complex backgrounds and fine details
            - Preserves hair, fur, and transparent objects
            - Production-ready quality
            """)

            info_btn.click(
                fn=get_model_info,
                inputs=[],
                outputs=model_info_output,
                api_name="get_model_info"
            )

        # Tab: API documentation
        with gr.Tab("📚 API Usage"):
            gr.Markdown("""
            ## 🚀 API Usage Guide

            Background removal is exposed as the `/remove_background` endpoint.
            The **"Use via API"** link in the page footer lists the exact endpoints
            and parameters for the Gradio version this Space runs.

            ### 1. Python Example with Requests (raw HTTP)
            Gradio's HTTP API is a two-step call: first submit the inputs, then read
            the result from an event stream. On Gradio 5 the route starts with
            `/gradio_api/call/...`; on Gradio 4 drop the `/gradio_api` prefix.
            ```python
            import json
            from io import BytesIO

            import requests
            from PIL import Image

            BASE_URL = "https://YOUR-SPACE-URL"
            CALL_URL = f"{BASE_URL}/gradio_api/call/remove_background"

            # Step 1: submit the job (the image must be reachable by URL)
            payload = {
                "data": [{
                    "path": "https://example.com/input.jpg",
                    "meta": {"_type": "gradio.FileData"}
                }]
            }
            event_id = requests.post(CALL_URL, json=payload).json()["event_id"]

            # Step 2: read the result from the server-sent event stream
            event = None
            with requests.get(f"{CALL_URL}/{event_id}", stream=True) as response:
                for raw in response.iter_lines():
                    line = raw.decode("utf-8")
                    if line.startswith("event:"):
                        event = line[len("event:"):].strip()
                    elif line.startswith("data:") and event == "error":
                        raise RuntimeError(line)
                    elif line.startswith("data:") and event == "complete":
                        output_image, mask, output_json = json.loads(line[len("data:"):])
                        break

            # Download the cut-out and save it as PNG (keeps transparency)
            img_bytes = requests.get(output_image["url"]).content
            Image.open(BytesIO(img_bytes)).save("output_no_bg.png")

            print(json.dumps(json.loads(output_json), indent=2))
            ```

            ### 2. Using Gradio Client (recommended)
            ```python
            from gradio_client import Client, handle_file
            from PIL import Image

            client = Client("YOUR-SPACE-URL")

            # Process image (handle_file uploads a local path or URL; gradio_client >= 1.0)
            result = client.predict(
                input_image=handle_file("path/to/image.jpg"),
                api_name="/remove_background"
            )

            # result contains: (output_image_path, mask_path, json_response)
            output_path, mask_path, json_data = result

            # Load and use
            output = Image.open(output_path)
            output.save("no_background.png")
            ```

            ### 3. Response Format
            ```json
            {
              "success": true,
              "input_size": "1200x1600",
              "output_size": "1200x1600",
              "output_format": "PNG with alpha channel",
              "model": "mohantesting/remove_background",
              "device": "cuda",
              "processing_time": "~1-3 seconds"
            }
            ```

            ### 4. Batch Processing Script
            ```python
            import os
            from pathlib import Path

            from gradio_client import Client, handle_file
            from PIL import Image

            client = Client("YOUR-SPACE-URL")

            input_dir = 'input_images'
            output_dir = 'output_images'
            os.makedirs(output_dir, exist_ok=True)

            for img_file in Path(input_dir).glob('*.jpg'):
                print(f"Processing: {img_file.name}")
                result = client.predict(
                    input_image=handle_file(str(img_file)),
                    api_name="/remove_background"
                )
                output_path = result[0]
                img = Image.open(output_path)
                save_path = Path(output_dir) / f"{img_file.stem}_no_bg.png"
                img.save(save_path)
                print(f"✓ Saved: {save_path}")
            ```

            ### 5. Output Format Details
            **Image Format:**
            - Format: PNG with full alpha transparency
            - Resolution: Same as input; inputs longer than 2048px on their longest side are downscaled first
            - Background: Completely transparent (alpha = 0)
            - Foreground: Fully preserved with smooth edges

            **Alpha Mask:**
            - Grayscale image showing segmentation confidence
            - White (255) = foreground
            - Black (0) = background
            - Gray values = edge transitions

            ### 6. Best Practices
            ✅ **DO:**
            - Use high-resolution images (1000px+ recommended)
            - Ensure good contrast between subject and background
            - Use well-lit, sharp images
            - Save output as PNG to preserve transparency
            - Test with sample images first

            ❌ **DON'T:**
            - Don't expect full resolution for very large images (longest side >2048px) - they'll be auto-resized
            - Don't expect perfect results on very complex backgrounds
            - Don't save as JPEG (loses transparency!)
            - Don't use blurry or low-quality input images

            ### 7. Common Use Cases
            The snippets below use this small helper built on the Gradio client:
            ```python
            from gradio_client import Client, handle_file
            from PIL import Image

            client = Client("YOUR-SPACE-URL")

            def remove_background(path):
                output_path, _mask_path, _info = client.predict(
                    input_image=handle_file(path),
                    api_name="/remove_background"
                )
                return Image.open(output_path)
            ```

            **E-Commerce Product Photos:**
            ```python
            # Remove background for clean product shots
            result = remove_background('product.jpg')
            result.save('product_transparent.png')
            # Upload to Shopify, Amazon, etc.
            ```

            **Portrait Photography:**
            ```python
            # Create professional headshots
            result = remove_background('portrait.jpg')
            # Composite on professional backgrounds
            ```

            **Social Media Content:**
            ```python
            # Create stickers, cutouts, graphics
            result = remove_background('subject.jpg')
            # Use in Instagram, TikTok, YouTube thumbnails
            ```

            **Graphic Design:**
            ```python
            # Create design elements
            result = remove_background('object.jpg')
            # Import into Photoshop, Illustrator, Canva
            ```

            ### 8. Performance Metrics
            - **Processing Time**: 1-3 seconds per image (GPU) / 5-10 seconds (CPU)
            - **Max Resolution**: 2048px on the longest side (auto-resized if larger)
            - **Model Size**: ~840MB
            - **GPU Memory**: ~2GB recommended
            - **Accuracy**: High-quality segmentation with clean edges

            ### 9. Error Handling
            ```python
            import json

            from gradio_client import Client, handle_file

            client = Client("YOUR-SPACE-URL")

            try:
                result = client.predict(
                    input_image=handle_file("image.jpg"),
                    api_name="/remove_background"
                )
                output_data = json.loads(result[2])
                if output_data['success']:
                    print("Success!")
                else:
                    print(f"Error: {output_data['error']}")
            except Exception as e:
                print(f"Request failed: {e}")
            ```

            ### 10. Rate Limits & Quotas
            - No built-in rate limits (depends on hosting)
            - For HuggingFace Spaces: Check your space tier
            - For self-hosted: Limited by GPU/CPU resources
            - Recommended: Process images sequentially for stability

            ---

            **Model:** mohantesting/remove_background (BiRefNet)
            **Framework:** PyTorch + Transformers + Gradio
            **License:** Check model repository for licensing details
            """)

    gr.Markdown("""
    ---
    <div style="text-align: center; padding: 20px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); border-radius: 10px; color: white;">
    <h3 style="margin: 0; color: white;">🚀 Ready to integrate background removal into your app?</h3>
    <p style="margin: 10px 0 0 0; opacity: 0.9;">Use the API documentation above to get started!</p>
    </div>
    """)
# Launch
if __name__ == "__main__":
    # Bind to every network interface on port 7860, without a public share link.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)