Janeka commited on
Commit
0b41393
·
verified ·
1 Parent(s): f8034a4

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -0
app.py ADDED
@@ -0,0 +1,120 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import torch
5
+ from transformers import pipeline
6
+ from functools import lru_cache
7
+ import cv2
8
+ import logging
9
+
10
+ # Configure logging
11
+ logging.basicConfig(level=logging.INFO)
12
+ logger = logging.getLogger(__name__)
13
+
14
+ # Cache models to avoid reloading on every request
15
+ @lru_cache(maxsize=1)
16
+ def load_model(model_name):
17
+ try:
18
+ return pipeline("image-segmentation", model_name)
19
+ except Exception as e:
20
+ logger.error(f"Failed to load {model_name}: {e}")
21
+ return None
22
+
23
+ # Model sequence configuration
24
+ MODELS = [
25
+ {"name": "BRIA", "repo": "BRIA-AI/bria-rmbg", "weight": 1.0},
26
+ {"name": "INSPyReNet", "repo": "mattmdjaga/INSPyReNet", "weight": 0.9},
27
+ {"name": "U2Net", "repo": "silks-road/u2net", "weight": 0.8},
28
+ {"name": "U2Net-Human", "repo": "mattmdjaga/u2net-human-seg", "weight": 0.7},
29
+ {"name": "ISNet-General", "repo": "xuebinqin/ISNet-general-use", "weight": 0.6},
30
+ {"name": "ISNet-Anime", "repo": "skytnt/anime-seg", "weight": 0.5}
31
+ ]
32
+
33
+ def process_single_model(image, model):
34
+ """Process image with a single model"""
35
+ try:
36
+ pipe = load_model(model["repo"])
37
+ if pipe is None:
38
+ return None
39
+
40
+ # Convert image to numpy array if needed
41
+ if isinstance(image, Image.Image):
42
+ image_np = np.array(image)
43
+ else:
44
+ image_np = image
45
+
46
+ result = pipe(image_np)
47
+ return result['mask'] if isinstance(result, dict) else result[0]['mask']
48
+ except Exception as e:
49
+ logger.warning(f"{model['name']} failed: {e}")
50
+ return None
51
+
52
+ def combine_masks(masks, weights):
53
+ """Combine masks with weighted averaging"""
54
+ valid_masks = [m for m in masks if m is not None]
55
+ if not valid_masks:
56
+ return None
57
+
58
+ total_weight = sum(w for w, m in zip(weights, masks) if m is not None)
59
+ combined = np.zeros_like(valid_masks[0], dtype=np.float32)
60
+
61
+ for mask, weight in zip(masks, weights):
62
+ if mask is not None:
63
+ combined += (mask.astype(np.float32) * weight
64
+
65
+ return (combined / total_weight).astype(np.uint8)
66
+
67
+ def remove_background(image):
68
+ """Main processing pipeline"""
69
+ try:
70
+ # Convert input to PIL Image
71
+ if isinstance(image, np.ndarray):
72
+ image = Image.fromarray(image)
73
+
74
+ # Process through all models
75
+ masks = []
76
+ for model in MODELS:
77
+ mask = process_single_model(image, model)
78
+ masks.append(mask)
79
+
80
+ # Combine results
81
+ weights = [m["weight"] for m in MODELS]
82
+ final_mask = combine_masks(masks, weights)
83
+
84
+ if final_mask is None:
85
+ raise ValueError("All models failed")
86
+
87
+ # Apply mask
88
+ background = Image.new('RGB', image.size, (0, 0, 0))
89
+ final_image = Image.composite(image, background, Image.fromarray(final_mask))
90
+
91
+ return final_image
92
+ except Exception as e:
93
+ logger.error(f"Processing failed: {e}")
94
+ return None
95
+
96
+ # Gradio interface with API endpoint
97
+ with gr.Blocks() as app:
98
+ gr.Markdown("## 🖼️ Advanced Background Remover")
99
+ with gr.Row():
100
+ with gr.Column():
101
+ input_image = gr.Image(label="Upload Image")
102
+ submit_btn = gr.Button("Remove Background")
103
+ with gr.Column():
104
+ output_image = gr.Image(label="Result")
105
+
106
+ submit_btn.click(
107
+ fn=remove_background,
108
+ inputs=input_image,
109
+ outputs=output_image
110
+ )
111
+
112
+ # API endpoint for mobile apps
113
+ app.api_app = gr.routes.App.create_app(
114
+ fn=remove_background,
115
+ inputs=gr.Image(),
116
+ outputs=gr.Image()
117
+ )
118
+
119
+ if __name__ == "__main__":
120
+ app.launch(server_name="0.0.0.0", server_port=7860)