WeVi commited on
Commit
7b78269
·
verified ·
1 Parent(s): 7f994d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -62
app.py CHANGED
@@ -1,68 +1,28 @@
1
  import gradio as gr
2
- import torch
3
- import numpy as np
4
  from PIL import Image
5
- import torchvision.transforms as transforms
6
- from u2net import U2NET
7
- import os
8
- import urllib.request
9
 
10
- # Download model if not present
11
- model_url = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"
12
- model_path = "u2net.pth"
13
-
14
- if not os.path.exists(model_path):
15
- print("Downloading model...")
16
- urllib.request.urlretrieve(model_url, model_path)
17
-
18
- # Load model
19
- print("Loading model...")
20
- net = U2NET(3, 1)
21
- net.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
22
- net.eval()
23
-
24
- # Preprocessing
25
- def preprocess(img):
26
- transform = transforms.Compose([
27
- transforms.Resize((320, 320)),
28
- transforms.ToTensor(),
29
- transforms.Normalize(mean=[0.485, 0.456, 0.406],
30
- std=[0.229, 0.224, 0.225])
31
- ])
32
- return transform(img).unsqueeze(0)
33
-
34
- # Postprocess output mask
35
- def postprocess_mask(d):
36
- pred = d[0][0]
37
- pred = (pred - pred.min()) / (pred.max() - pred.min())
38
- pred = pred.detach().cpu().numpy()
39
- mask = (pred > 0.5).astype(np.uint8) * 255
40
- return Image.fromarray(mask)
41
-
42
- # Main function
43
  def remove_background(input_image):
44
- image = input_image.convert("RGB")
45
- input_tensor = preprocess(image)
46
- with torch.no_grad():
47
- d1, *_ = net(input_tensor)
48
- mask = postprocess_mask(d1)
49
-
50
- image = image.resize(mask.size)
51
- image_np = np.array(image)
52
- mask_np = np.array(mask) / 255
53
- mask_np = np.expand_dims(mask_np, axis=2)
54
-
55
- result = image_np * mask_np + (1 - mask_np) * 255
56
- result = Image.fromarray(result.astype(np.uint8))
57
- return result
58
-
59
- # Gradio UI
60
- demo = gr.Interface(
61
- fn=remove_background,
62
- inputs=gr.Image(type="pil", label="Upload Image"),
63
- outputs=gr.Image(type="pil", label="Image without Background"),
64
- title="🧠 AI Background Remover (U²-Net)",
65
- description="Removes background from images using U²-Net."
66
- )
67
 
68
  demo.launch()
 
1
  import gradio as gr
 
 
2
  from PIL import Image
3
+ from rembg import remove
4
+ import io
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
  def remove_background(input_image):
7
+ # Ensure image is RGBA for transparency
8
+ input_image = input_image.convert("RGBA")
9
+
10
+ # Remove background
11
+ result = remove(input_image)
12
+
13
+ return Image.open(io.BytesIO(result))
14
+
15
+ with gr.Blocks(title="AI Background Remover (rembg)") as demo:
16
+ gr.Markdown("🧠 **AI Background Remover** - Powered by `rembg` for transparent results")
17
+
18
+ with gr.Row():
19
+ with gr.Column():
20
+ input_image = gr.Image(label="Upload Image", type="pil")
21
+ submit = gr.Button("Remove Background")
22
+
23
+ with gr.Column():
24
+ output_image = gr.Image(label="Transparent PNG Output")
25
+
26
+ submit.click(fn=remove_background, inputs=input_image, outputs=output_image)
 
 
 
27
 
28
  demo.launch()