Janeka committed on
Commit
ec18915
·
verified ·
1 Parent(s): fe059e7

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +121 -0
app.py ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ import cv2
4
+ from PIL import Image
5
+ import requests
6
+ import torch
7
+ from torchvision.transforms import ToTensor, ToPILImage
8
+
9
# Cache of model instances keyed by display name. Every slot starts out as
# None and is filled in the first time that model is requested.
models = dict.fromkeys(
    (
        "BRIA",
        "INSPyReNet",
        "U2Net",
        "U2NetHumanSeg",
        "ISNetGeneral",
        "ISNetAnime",
    )
)
18
+
19
+ # Model URLs and loading functions
20
def load_model(model_name):
    """Build the segmentation model registered under ``model_name``.

    The packages backing each model are imported inside the matching branch,
    so only the requested backend has to be importable.

    Args:
        model_name: One of ``"BRIA"``, ``"INSPyReNet"``, ``"U2Net"``,
            ``"U2NetHumanSeg"``, ``"ISNetGeneral"`` or ``"ISNetAnime"``.

    Returns:
        The instantiated model with its pretrained weights loaded.

    Raises:
        ValueError: If ``model_name`` is not one of the names listed above.
        ImportError: If the package backing the requested model is missing.
    """
    if model_name == "BRIA":
        from transformers import AutoModelForImageSegmentation

        # NOTE(security): trust_remote_code=True runs Python code shipped in
        # the Hub repository; consider pinning a reviewed `revision`.
        return AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)

    if model_name in ("U2Net", "U2NetHumanSeg"):
        import u2net

        variant = "u2net" if model_name == "U2Net" else "u2net_human_seg"
        return u2net.load_model(model_name=variant)

    if model_name == "INSPyReNet":
        from IS2.models.inspyrenet import INSPyReNet

        model = INSPyReNet()
        weights_url = "https://github.com/helloyufei/INSPyReNet/releases/download/v1.0.0/inspyrenet.pth"
    elif model_name == "ISNetGeneral":
        from isnet import ISNet

        model = ISNet()
        weights_url = "https://github.com/xuebinqin/DIS/raw/main/IS-Net/isnet-general-use.pth"
    elif model_name == "ISNetAnime":
        from isnet import ISNet

        model = ISNet()
        weights_url = "https://github.com/xuebinqin/DIS/raw/main/IS-Net/isnet-anime.pth"
    else:
        # Previously an unknown name fell through and returned None, which
        # only surfaced later as an AttributeError on the caller's side.
        raise ValueError(f"Unknown model name: {model_name!r}")

    # Inference in this file feeds CPU tensors to the model, so map the
    # checkpoint to CPU; otherwise a checkpoint saved from a GPU fails to
    # load on a machine without CUDA.
    state_dict = torch.hub.load_state_dict_from_url(weights_url, map_location="cpu")
    model.load_state_dict(state_dict)
    return model
45
+
46
def apply_model(image, model_name):
    """Run one segmentation model on ``image`` and return its mask.

    The model is created through ``load_model`` the first time its name is
    requested and is kept in the module-level ``models`` dict afterwards.

    Args:
        image: PIL image to segment. Its ``width`` and ``height`` determine
            the size of the returned mask.
        model_name: Key into ``models`` selecting which model to run.

    Returns:
        A PIL image built from a uint8 array: the model output rescaled so
        its minimum maps to 0 and its maximum to 255, then resized to the
        input image's size. A constant model output yields an all-zero mask.

    Raises:
        KeyError: If ``model_name`` is not a key of ``models``.
    """
    if models[model_name] is None:
        models[model_name] = load_model(model_name)

    model = models[model_name]
    model.eval()

    if model_name == "BRIA":
        from transformers import AutoImageProcessor

        processor = AutoImageProcessor.from_pretrained("briaai/RMBG-1.4")
        inputs = processor(images=image, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        # NOTE(review): assumes the model output exposes `.logits` - confirm
        # against the RMBG-1.4 remote code.
        mask = outputs.logits.squeeze().cpu().numpy()
    else:
        # The U2Net variants are fed 320x320 inputs; INSPyReNet and the
        # ISNet variants are fed 1024x1024 inputs.
        side = 320 if model_name in ("U2Net", "U2NetHumanSeg") else 1024
        input_img = cv2.resize(np.array(image), (side, side))
        input_img = ToTensor()(input_img).unsqueeze(0)
        with torch.no_grad():
            # NOTE(review): assumes the forward pass returns a single tensor
            # rather than a tuple of side outputs - confirm per model.
            mask = model(input_img).squeeze().cpu().numpy()

    # Min-max rescale to 0..255. A constant output has zero range; dividing
    # by it would fill the mask with NaN before the uint8 cast, so emit an
    # all-zero mask in that case instead.
    low = mask.min()
    span = mask.max() - low
    if span > 0:
        mask = (mask - low) / span
    else:
        mask = np.zeros_like(mask)
    mask = (mask * 255).astype(np.uint8)
    mask = cv2.resize(mask, (image.width, image.height))
    return Image.fromarray(mask)
79
+
80
def combine_masks(masks):
    """Blend per-model masks into one mask with a fixed weighted average.

    Masks are paired with the weights by position. When fewer masks than
    weights are supplied, the result is renormalised by the total weight
    actually used so that it remains an average instead of a darkened sum.
    Masks beyond the number of weights are ignored.

    Args:
        masks: Sequence of equally shaped numeric arrays with values on a
            0-255 scale.

    Returns:
        A PIL image built from the blended uint8 array.
    """
    weights = [0.3, 0.2, 0.15, 0.15, 0.1, 0.1]  # Adjust weights as needed

    combined = np.zeros_like(masks[0], dtype=np.float32)
    used_weight = 0.0
    for mask, weight in zip(masks, weights):
        combined += mask.astype(np.float32) * weight
        used_weight += weight

    # With the full set of masks the weights are applied exactly as listed;
    # only a shorter input needs rescaling.
    if len(masks) < len(weights):
        combined /= used_weight

    combined = np.clip(combined, 0, 255).astype(np.uint8)
    return Image.fromarray(combined)
90
+
91
def remove_background(image):
    """Attach an alpha channel built from every model's mask to ``image``.

    Args:
        image: A PIL image, a NumPy array convertible with
            ``Image.fromarray``, or ``None``.

    Returns:
        A copy of the image whose alpha channel is the combined mask, or
        ``None`` when no image was supplied.
    """
    # Nothing to process (for example the UI was submitted without an
    # image); bail out before any model gets loaded.
    if image is None:
        return None

    # Convert to PIL Image if needed
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)

    # The order matters: combine_masks pairs masks with its weights by
    # position.
    model_names = (
        "BRIA",
        "INSPyReNet",
        "U2Net",
        "U2NetHumanSeg",
        "ISNetGeneral",
        "ISNetAnime",
    )
    masks = [np.array(apply_model(image, name)) for name in model_names]

    final_mask = combine_masks(masks)

    # Work on a copy so the caller's image is left untouched.
    result = image.copy()
    result.putalpha(final_mask)
    return result
110
+
111
# Gradio UI: one image goes in, the background-removed image comes out.
input_component = gr.Image(label="Input Image")
output_component = gr.Image(label="Background Removed", type="pil")

interface = gr.Interface(
    fn=remove_background,
    inputs=input_component,
    outputs=output_component,
    title="Multi-Model Background Removal",
    description="Combines BRIA, INSPyReNet, U²-Net, U²-Net Human Seg, ISNet-General-Use, and ISNet-Anime for high-quality background removal",
)

# Start the web server only when this file is run as a script.
if __name__ == "__main__":
    interface.launch()