File size: 2,256 Bytes
c2c25ca
0b41393
 
494a7a7
c2c25ca
0b41393
c2c25ca
0b41393
 
 
 
 
 
 
 
 
494a7a7
c2c25ca
0b41393
c2c25ca
 
 
 
 
494a7a7
 
0b41393
494a7a7
c2c25ca
 
 
 
 
 
 
 
 
 
494a7a7
 
c2c25ca
494a7a7
 
 
0b41393
494a7a7
c2c25ca
0b41393
c2c25ca
 
 
0b41393
c2c25ca
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0b41393
c2c25ca
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import functools

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from transformers import pipeline

# Ensemble members as (display name, Hugging Face repo id, weight).
# Each weight scales that model's mask in process_image's weighted average,
# so a larger weight gives the model more influence on the final alpha.
# NOTE(review): the repo ids are not verified here — confirm each one exists on
# the Hub and works with pipeline("image-segmentation"); a bad id is only
# reported as a printed failure at request time.
_MODEL_SPECS = (
    ("BRIA", "BRIA-AI/bria-rmbg", 1.0),
    ("INSPyReNet", "mattmdjaga/INSPyReNet", 0.9),
    ("U2Net", "silks-road/u2net", 0.8),
    ("U2Net-Human", "mattmdjaga/u2net-human-seg", 0.7),
    ("ISNet-General", "xuebinqin/ISNet-general-use", 0.6),
    ("ISNet-Anime", "skytnt/anime-seg", 0.5),
)

# Model sequence with weights, in the order the models are run.
MODELS = [
    {"name": label, "repo": repo_id, "weight": score}
    for label, repo_id, score in _MODEL_SPECS
]

@functools.lru_cache(maxsize=None)
def load_model(model_repo):
    """Return the ``image-segmentation`` pipeline for *model_repo*, built once.

    ``process_image`` calls this for every entry in ``MODELS`` on every
    request. Without caching, each request rebuilt (and reloaded the weights
    of) every pipeline. The cache keeps one pipeline per repo id for the life
    of the process.

    An unbounded cache is safe here because the keys are the fixed set of repo
    ids in ``MODELS``. ``lru_cache`` does not cache exceptions, so a repo that
    fails to load is retried on the next call.

    Args:
        model_repo: Hugging Face repo id passed to ``transformers.pipeline``.

    Returns:
        The pipeline object created by ``pipeline("image-segmentation", model_repo)``.
    """
    return pipeline("image-segmentation", model_repo)

def _mask_to_array(mask, size):
    """Convert one model's mask into a float32 (height, width) array for *size*.

    Args:
        mask: The ``mask`` field of a segmentation result — a PIL image or an
            array-like. Presumably 0-255 grayscale — TODO confirm for each model
            (a float mask in [0, 1] would be truncated by the uint8 conversion).
        size: Target ``(width, height)``, i.e. ``PIL.Image.size`` of the input.

    Returns:
        ``np.ndarray`` with dtype float32 and shape ``(height, width)``.
    """
    if not isinstance(mask, Image.Image):
        mask = Image.fromarray(np.asarray(mask, dtype=np.uint8))
    mask = mask.convert("L")
    if mask.size != size:
        # Defensive: every mask must share one shape to be averaged element-wise.
        mask = mask.resize(size, Image.BILINEAR)
    return np.asarray(mask, dtype=np.float32)


def process_image(input_image):
    """Remove the background of *input_image* with a weighted model ensemble.

    Every entry in ``MODELS`` is run through its ``image-segmentation``
    pipeline (via ``load_model``). The first mask from each successful model is
    aligned to the input size. All masks are then combined by a weighted
    average using the ``weight`` values in ``MODELS``. The averaged mask
    becomes the alpha channel of the output. A model that raises is logged
    and skipped.

    Args:
        input_image: The image from the Gradio input — a numpy array or a
            ``PIL.Image``; ``None`` when nothing was submitted.

    Returns:
        An RGBA ``PIL.Image``, or ``None`` if there is no input or every model
        failed.
    """
    if input_image is None:
        return None

    # Convert Gradio input to PIL Image
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    # Normalize to RGB so every model sees the same 3-channel input and
    # putalpha below always produces RGBA.
    rgb_image = input_image.convert("RGB")

    masks = []
    weights = []

    for model in MODELS:
        try:
            pipe = load_model(model["repo"])
            # Pass the PIL image itself: transformers image pipelines document
            # PIL images / paths / URLs as inputs; raw numpy arrays may be rejected.
            result = pipe(rgb_image)
            mask = result[0]['mask'] if isinstance(result, list) else result['mask']
            mask_array = _mask_to_array(mask, rgb_image.size)
        except Exception as e:
            # Best-effort ensemble: one broken model must not fail the request.
            print(f"{model['name']} failed: {str(e)}")  # Debug print
            continue
        masks.append(mask_array)
        weights.append(model["weight"])
        print(f"{model['name']} completed successfully")  # Debug print

    if not masks:
        return None

    # Weighted average of masks, normalized by the weight sum; round and clip
    # rather than truncate when going back to 8-bit alpha.
    averaged = np.average(np.stack(masks), axis=0, weights=weights)
    final_mask = np.clip(np.rint(averaged), 0, 255).astype(np.uint8)

    # Create transparent background
    output = rgb_image.copy()
    output.putalpha(Image.fromarray(final_mask))

    return output

# Gradio interface: one image in, one image (with alpha channel) out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Result (PNG with Transparency)"),
    title="🎨 Advanced Background Remover",
    # Derive the model count from MODELS so the text cannot drift from the ensemble.
    description=f"Combines {len(MODELS)} AI models for perfect background removal",
    # NOTE(review): these example files must exist relative to the working
    # directory — confirm they ship with the app.
    examples=[
        ["example1.jpg"],
        ["example2.jpg"],
        ["example3.png"],
    ],
)

# Start the server only when run as a script, so importing this module
# (e.g. from tests or another app) does not launch it as a side effect.
if __name__ == "__main__":
    # 0.0.0.0 listens on every network interface, not just localhost.
    demo.launch(server_name="0.0.0.0", server_port=7860)