|
|
from functools import lru_cache

import cv2
import gradio as gr
import numpy as np
from PIL import Image
from transformers import pipeline
|
|
|
|
|
|
|
|
# Ensemble members, listed in descending weight order. Each entry is a dict:
#   name   -- label printed in the per-model success/failure messages
#   repo   -- model id handed to load_model() / the image-segmentation pipeline
#   weight -- relative contribution of this model's mask to the blended alpha
# NOTE(review): several of these ids do not look like well-known Hub ids
# (BRIA's published background-removal model is "briaai/RMBG-1.4", for one);
# confirm each resolves, otherwise that model is reported as failed and skipped.
MODELS = [
    {"name": label, "repo": repo_id, "weight": share}
    for label, repo_id, share in (
        ("BRIA", "BRIA-AI/bria-rmbg", 1.0),
        ("INSPyReNet", "mattmdjaga/INSPyReNet", 0.9),
        ("U2Net", "silks-road/u2net", 0.8),
        ("U2Net-Human", "mattmdjaga/u2net-human-seg", 0.7),
        ("ISNet-General", "xuebinqin/ISNet-general-use", 0.6),
        ("ISNet-Anime", "skytnt/anime-seg", 0.5),
    )
]
|
|
|
|
|
# Bounded cache with room for every ensemble member (six today); each cached
# pipeline keeps its model weights resident in memory.
@lru_cache(maxsize=8)
def load_model(model_repo):
    """Return an image-segmentation pipeline for ``model_repo``, built at most once.

    Constructing a transformers pipeline loads the model weights, which is far
    too slow to repeat on every request. The result is memoized per repo id so
    later calls reuse the same pipeline object. A call that raises is not
    cached, so a model that failed to load is retried on the next request.

    Args:
        model_repo: model identifier passed straight to ``transformers.pipeline``.

    Returns:
        The object returned by ``pipeline("image-segmentation", model_repo)``.
    """
    return pipeline("image-segmentation", model_repo)
|
|
|
|
|
def _first_mask(output):
    """Return the ``"mask"`` of the first segment in a segmentation pipeline result.

    The pipeline may return a list of segment dicts or a single dict.
    NOTE(review): when a model yields several segments only the first is used;
    confirm that is the foreground segment for every model in MODELS.

    Raises:
        ValueError: if the pipeline returned an empty list of segments.
    """
    if isinstance(output, list):
        if not output:
            raise ValueError("segmentation pipeline returned no segments")
        output = output[0]
    return output["mask"]


def _mask_to_array(mask, size):
    """Return ``mask`` as a float32 ``(height, width)`` array matching ``size``.

    ``size`` is a PIL ``(width, height)`` tuple. Each model's mask is brought to
    the same single-channel resolution so all masks can be stacked and averaged.
    """
    if not isinstance(mask, Image.Image):
        mask = Image.fromarray(np.asarray(mask))
    if mask.mode != "L":
        mask = mask.convert("L")
    if mask.size != size:
        mask = mask.resize(size, Image.BILINEAR)
    return np.asarray(mask, dtype=np.float32)


def process_image(input_image):
    """Remove the background of ``input_image`` using a weighted ensemble of models.

    Every entry in ``MODELS`` is run as an image-segmentation pipeline. The masks
    of the models that succeed are combined with a weighted average (weights from
    ``MODELS``), and the result becomes the alpha channel of the output.

    Args:
        input_image: a ``PIL.Image.Image`` or a numpy array accepted by
            ``Image.fromarray``. ``None`` (no image submitted) is allowed.

    Returns:
        An RGBA ``PIL.Image.Image`` whose alpha channel is the blended mask, or
        ``None`` when there is no input or every model failed.
    """
    if input_image is None:
        return None
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)

    # One RGB image is used both as pipeline input and as the base of the output.
    # convert() always returns a new image, so putalpha() below never mutates
    # the caller's object, and P/L/RGBA inputs are normalized for the models.
    rgb_image = input_image.convert("RGB")

    masks = []
    weights = []
    for model in MODELS:
        try:
            pipe = load_model(model["repo"])
            # Pass the PIL image itself: transformers' image pipelines read their
            # input via load_image, which accepts URLs/paths/base64/PIL images
            # but not bare numpy arrays.
            output = pipe(rgb_image)
            mask = _mask_to_array(_first_mask(output), rgb_image.size)
        except Exception as e:
            # Best-effort ensemble: a model that cannot be loaded or run is
            # reported and skipped instead of failing the whole request.
            print(f"{model['name']} failed: {str(e)}")
            continue
        # Append mask and weight together so the two lists stay aligned.
        masks.append(mask)
        weights.append(model["weight"])
        print(f"{model['name']} completed successfully")

    if not masks:
        return None

    # A weighted mean of 0-255 masks stays within 0-255; round rather than
    # truncate, and clip defensively before narrowing to uint8.
    blended = np.average(np.stack(masks), axis=0, weights=weights)
    final_mask = np.clip(np.rint(blended), 0, 255).astype(np.uint8)

    rgb_image.putalpha(Image.fromarray(final_mask))
    return rgb_image
|
|
|
|
|
|
|
|
# Gradio UI wrapping process_image: one image in, one image (with alpha) out.
demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(label="Input Image"),
    outputs=gr.Image(label="Result (PNG with Transparency)"),
    title="🎨 Advanced Background Remover",
    # Derive the model count from MODELS so the text cannot drift from the ensemble.
    description=f"Combines {len(MODELS)} AI models for perfect background removal",
    # NOTE(review): these example files must exist relative to the working
    # directory; confirm they ship with the app, as Gradio may fail to load them otherwise.
    examples=[
        ["example1.jpg"],
        ["example2.jpg"],
        ["example3.png"],
    ],
)

# Start the server only when run as a script, so importing this module (e.g.
# Gradio's reload mode or tests) does not launch it. 0.0.0.0 listens on all
# interfaces, presumably for container/hosted deployment — confirm that is intended.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)