File size: 3,094 Bytes
a45dc04
f41e3e9
331d778
 
 
 
a45dc04
 
 
e571436
 
 
 
 
e67f455
e571436
e67f455
e571436
 
 
 
 
 
 
e67f455
a45dc04
e571436
 
 
 
 
 
 
 
 
 
 
 
 
a45dc04
 
 
36588be
ba2ffd5
a45dc04
 
9ca621b
6128b5a
9ca621b
a45dc04
6128b5a
9ca621b
a45dc04
 
 
 
 
e67f455
a45dc04
 
9ca621b
a45dc04
9ca621b
a45dc04
 
 
 
 
9ca621b
a45dc04
9ca621b
 
 
 
 
e67f455
 
 
 
 
 
a45dc04
e67f455
36588be
855a559
9ca621b
855a559
36588be
9ca621b
855a559
9ca621b
855a559
f41e3e9
36588be
a45dc04
9ca621b
a45dc04
 
 
9ca621b
 
e67f455
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import functools
import io

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import pipeline

@functools.lru_cache(maxsize=1)
def _get_segmentor():
    """Create the RMBG-1.4 segmentation pipeline once and reuse it afterwards.

    Building the pipeline loads the model, which is far too slow to repeat on
    every request. ``lru_cache`` memoizes only a successful result; if
    creation raises, nothing is cached and the next call retries.
    """
    return pipeline(
        task="image-segmentation",
        model="briaai/RMBG-1.4",
        trust_remote_code=True,
    )


def remove_background(input_image):
    """Return a copy of *input_image* with its background made transparent.

    Args:
        input_image: A PIL image, or an array that ``Image.fromarray`` accepts.

    Returns:
        A new RGBA ``PIL.Image.Image`` the same size as the input. The input
        is pasted through the segmentation mask onto a fully transparent
        canvas, so pixels where the mask is 0 end up fully transparent.

    Raises:
        gr.Error: If conversion, inference, or compositing fails. The
            underlying exception is chained as ``__cause__``.
    """
    try:
        # Normalise array input (e.g. numpy) to a PIL image.
        if not isinstance(input_image, Image.Image):
            input_image = Image.fromarray(input_image)

        # Reuse the cached pipeline instead of reloading the model per call.
        segmentor = _get_segmentor()
        result = segmentor(input_image, return_mask=True)

        # NOTE(review): with return_mask=True the RMBG remote-code pipeline
        # presumably returns the mask image itself; a dict carrying a "mask"
        # key is accepted as well.
        mask = result['mask'] if isinstance(result, dict) else result

        # Paste the RGBA source through the mask onto a transparent canvas.
        if input_image.mode != 'RGBA':
            input_image = input_image.convert('RGBA')
        output_image = Image.new('RGBA', input_image.size, (0, 0, 0, 0))
        output_image.paste(input_image, mask=mask)

        return output_image

    except Exception as e:
        # Surface the failure in the Gradio UI while keeping the traceback.
        raise gr.Error(f"Error processing image: {str(e)}") from e

# Status text shown while idle; "Clear" writes it back into the status box.
_READY_TEXT = "Ready to process your image..."

# Header markup, rendered as the first element of the page.
_HEADER_HTML = """
        <div style="text-align: center; max-width: 800px; margin: 0 auto; padding: 20px;">
            <h1 style="font-size: 2.5rem; margin-bottom: 1rem;">
                AI Background Remover
            </h1>
            <p style="color: #666; font-size: 1.1rem;">
                Remove backgrounds instantly using RMBG V1.4 model
            </p>
        </div>
        """


def process_and_update(image):
    """Run background removal for the UI and pair the outcome with a status line.

    Returns an ``(image_or_None, status_text)`` tuple matching the
    ``[output_image, status_msg]`` outputs wired up below. Any exception
    from ``remove_background`` is reported in the status text instead of
    propagating.
    """
    if image is None:
        return None, "Please upload an image first"
    try:
        processed = remove_background(image)
    except Exception as err:
        return None, f"❌ Error: {str(err)}"
    return processed, "✨ Background removed successfully!"


# Create Gradio interface
with gr.Blocks() as demo:
    gr.HTML(_HEADER_HTML)

    # Side-by-side input / result panels.
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(label="Upload Image", type="pil", sources=["upload", "clipboard"])
        with gr.Column():
            output_image = gr.Image(label="Result", type="pil")

    with gr.Row():
        clear_btn = gr.Button("Clear", variant="secondary")
        process_btn = gr.Button("Remove Background", variant="primary")

    # Read-only status line below the buttons.
    status_msg = gr.Textbox(label="Status", placeholder=_READY_TEXT, interactive=False)

    # Event handlers: process the upload, or reset every field.
    process_btn.click(
        fn=process_and_update,
        inputs=[input_image],
        outputs=[output_image, status_msg],
    )
    clear_btn.click(
        fn=lambda: (None, None, _READY_TEXT),
        outputs=[input_image, output_image, status_msg],
    )

# Launch the app only when this file is executed as a script
# (e.g. `python app.py`), so importing the module does not start a server.
if __name__ == "__main__":
    demo.launch()