File size: 2,751 Bytes
a86e315
 
036811a
 
a86e315
036811a
 
a86e315
036811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86e315
036811a
 
 
 
 
 
 
 
 
 
 
 
 
a86e315
036811a
 
a86e315
036811a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a86e315
036811a
a86e315
 
 
036811a
a86e315
 
 
 
 
 
 
5716a35
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import io
import torch
import numpy as np
from PIL import Image
import os
import sys

# Add current directory to path for model files
sys.path.append("/app")

# Import model components
from briarmbg import BriaRMBG
from utilities import preprocess_image, postprocess_image

class BackgroundRemover:
    """Wrap the BriaRMBG segmentation model to cut the background out of images.

    The model is loaded once, eagerly, from ``__init__``. If loading fails the
    error is printed and ``self.model`` stays ``None``; ``remove_background``
    then raises instead of attempting inference.
    """

    def __init__(self, model_path="/app", model_input_size=(1024, 1024)):
        """Create the remover and load the model.

        Args:
            model_path: Location passed to ``BriaRMBG.from_pretrained``.
                Defaults to ``"/app"``, the previously hard-coded value.
            model_input_size: Two-element size handed to ``preprocess_image``
                (converted to a list). Defaults to ``(1024, 1024)``, the
                previously hard-coded value.
        """
        self.model = None
        self.device = None
        self.model_path = model_path
        self.model_input_size = model_input_size
        self.load_model()

    def load_model(self):
        """Load the RMBG-1.4 model onto CUDA if available, else CPU.

        Any exception is reported on stdout and leaves ``self.model`` as
        ``None``, so the app can still start; inference calls will raise.
        """
        try:
            print("🔄 Loading background removal model...")
            self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            self.model = BriaRMBG.from_pretrained(self.model_path)
            self.model.to(self.device)
            self.model.eval()
            print("✅ Model loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            self.model = None

    def remove_background(self, image):
        """Return a copy of *image* with its background made transparent.

        Args:
            image: A PIL image in any mode; it is converted to RGB first.

        Returns:
            A new PIL image: the RGB input with the mask returned by
            ``postprocess_image`` attached as its alpha channel.

        Raises:
            Exception: If the model failed to load ("Model not loaded").
            ValueError: If *image* is ``None`` (e.g. nothing was uploaded).
            Exception: Wrapping any error raised while pre-processing,
                running inference or post-processing; the original error is
                chained as ``__cause__`` so its traceback is preserved.
        """
        if self.model is None:
            raise Exception("Model not loaded")
        # Gradio passes None when the user submits without uploading anything;
        # report that clearly instead of failing on None.convert().
        if image is None:
            raise ValueError("No image provided")

        try:
            input_image = image.convert("RGB")

            # Preprocess: keep the original (height, width) so the mask can be
            # mapped back to the input resolution after inference.
            orig_im = np.array(input_image)
            orig_im_size = orig_im.shape[0:2]
            processed_image = preprocess_image(
                orig_im, list(self.model_input_size)
            ).to(self.device)

            # Inference without autograd bookkeeping.
            with torch.no_grad():
                result = self.model(processed_image)

            # NOTE(review): assumes result[0][0] is the primary mask output of
            # BriaRMBG — confirm against briarmbg.py.
            result_image = postprocess_image(result[0][0], orig_im_size)

            # Attach the mask as the alpha channel of a copy of the RGB input.
            pil_mask = Image.fromarray(result_image)
            no_bg_image = input_image.copy()
            no_bg_image.putalpha(pil_mask)
            return no_bg_image
        except Exception as e:
            raise Exception(f"Background removal failed: {e}") from e

# Initialize the remover once at import time; every request handled by
# process_image reuses this single instance and its loaded model. If model
# loading fails, BackgroundRemover leaves .model as None (the error is only
# printed), so the app still starts and each request raises instead.
remover = BackgroundRemover()

def process_image(image):
    """Gradio callback: remove the background from the uploaded image.

    Args:
        image: The PIL image supplied by the ``gr.Image(type="pil")`` input.

    Returns:
        The image returned by ``remover.remove_background``.

    Raises:
        gr.Error: For any failure, carrying the underlying message so Gradio
            can display it; the original exception is chained as the cause
            so server-side tracebacks keep the root error.
    """
    try:
        return remover.remove_background(image)
    except Exception as e:
        raise gr.Error(str(e)) from e

# Create the Gradio interface: one PIL image in, the processed PIL image out.
# The emoji in the labels were previously mis-encoded (UTF-8 bytes decoded
# with the wrong codec) and rendered as garbage in the UI.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(type="pil", label="📷 Upload Image"),
    outputs=gr.Image(type="pil", label="🎨 Background Removed"),
    title="🎨 Professional Background Remover",
    description="Upload any image (JPG, PNG, etc) to remove background automatically with AI"
)

if __name__ == "__main__":
    # Serve on all interfaces (0.0.0.0) at port 7860. NOTE(review): binding
    # to 0.0.0.0 is presumably so the app is reachable from outside its
    # container/host — confirm this is intended before exposing publicly.
    demo.launch(server_name="0.0.0.0", server_port=7860)