Janeka commited on
Commit
4e6d0f7
·
verified ·
1 Parent(s): 769e2f0

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +99 -0
app.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from PIL import Image
3
+ import requests
4
+ from io import BytesIO
5
+ import cv2
6
+
7
+ # Load models from Hugging Face Hub
8
+ from huggingface_hub import from_pretrained_keras, hf_hub_download
9
+ import tensorflow as tf
10
+
11
+ # Load U²-Net (we'll use a lightweight version suitable for CPU)
12
+ def load_u2net():
13
+ model_path = hf_hub_download(repo_id="skytnt/anime-remove-background", filename="u2netp.onnx")
14
+ net = cv2.dnn.readNetFromONNX(model_path)
15
+ return net
16
+
17
+ # Load BRIA model (using a CPU-compatible version)
18
+ def load_bria():
19
+ model = from_pretrained_keras("briaai/RMBG-1.4", compile=False)
20
+ return model
21
+
22
+ # Preprocess image for U²-Net
23
+ def preprocess_u2net(image):
24
+ image = image.resize((320, 320))
25
+ image = np.array(image)
26
+ image = image / 255.0
27
+ image = image.transpose(2, 0, 1)
28
+ image = np.expand_dims(image, axis=0).astype('float32')
29
+ return image
30
+
31
+ # Preprocess image for BRIA
32
+ def preprocess_bria(image):
33
+ image = image.resize((1024, 1024))
34
+ image = np.array(image)
35
+ image = image / 255.0
36
+ image = image.astype('float32')
37
+ return np.expand_dims(image, axis=0)
38
+
39
+ # Postprocess mask
40
+ def postprocess_mask(mask):
41
+ mask = mask.squeeze()
42
+ mask = (mask * 255).astype('uint8')
43
+ mask = Image.fromarray(mask).resize((original_width, original_height))
44
+ return mask
45
+
46
+ # Compare masks and select the better one
47
+ def select_better_mask(mask1, mask2):
48
+ # Simple heuristic: select the mask with more defined edges
49
+ # You can implement more sophisticated comparison if needed
50
+ edge1 = cv2.Canny(np.array(mask1), 100, 200)
51
+ edge2 = cv2.Canny(np.array(mask2), 100, 200)
52
+ return mask1 if np.sum(edge1) > np.sum(edge2) else mask2
53
+
54
+ # Load models (we'll do this once when the Space starts)
55
+ u2net = load_u2net()
56
+ bria = load_bria()
57
+
58
+ def remove_background(image):
59
+ global original_width, original_height
60
+ original_width, original_height = image.size
61
+
62
+ # Process with U²-Net
63
+ u2net_input = preprocess_u2net(image)
64
+ u2net.setInput(u2net_input)
65
+ u2net_mask = u2net.forward()
66
+ u2net_mask = postprocess_mask(u2net_mask[0][0])
67
+
68
+ # Process with BRIA
69
+ bria_input = preprocess_bria(image)
70
+ bria_mask = bria.predict(bria_input)
71
+ bria_mask = postprocess_mask(bria_mask[0][:, :, 0])
72
+
73
+ # Select better mask
74
+ final_mask = select_better_mask(u2net_mask, bria_mask)
75
+
76
+ # Apply mask to original image
77
+ image = image.convert("RGBA")
78
+ final_mask = final_mask.convert("L")
79
+ image.putalpha(final_mask)
80
+
81
+ return image
82
+
83
+ # Gradio interface
84
+ import gradio as gr
85
+
86
+ def process_image(input_image):
87
+ image = Image.fromarray(input_image)
88
+ result = remove_background(image)
89
+ return result
90
+
91
+ iface = gr.Interface(
92
+ fn=process_image,
93
+ inputs=gr.Image(),
94
+ outputs=gr.Image(type="pil"),
95
+ title="Background Removal Pipeline (BRIA + U²-Net)",
96
+ description="Combines BRIA and U²-Net models for better background removal (CPU-only version)"
97
+ )
98
+
99
+ iface.launch()