WeVi commited on
Commit
86b6923
·
verified ·
1 Parent(s): 3724257

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -46
app.py CHANGED
@@ -1,59 +1,68 @@
1
  import gradio as gr
2
- from PIL import Image
3
  import torch
4
  import numpy as np
 
 
 
5
  import os
6
- from torchvision import transforms
7
- from u2net import U2NET # Load model class
8
 
9
- # Load the model (download once, reuse)
 
10
  model_path = "u2net.pth"
 
11
  if not os.path.exists(model_path):
12
- import requests
13
- url = "https://huggingface.co/akhaliq/U-2-Net/resolve/main/u2net.pth"
14
- with open(model_path, "wb") as f:
15
- f.write(requests.get(url).content)
16
-
17
- # Load model to CPU
18
- net = U2NET(3,1)
19
- net.load_state_dict(torch.load(model_path, map_location='cpu'))
20
  net.eval()
21
 
22
  # Preprocessing
23
- transform = transforms.Compose([
24
- transforms.Resize((320,320)),
25
- transforms.ToTensor(),
26
- transforms.Normalize([0.485, 0.456, 0.406],
27
- [0.229, 0.224, 0.225])
28
- ])
29
-
30
- # Post-process output mask
31
- def normalize_prediction(pred):
32
- ma = torch.max(pred)
33
- mi = torch.min(pred)
34
- return (pred - mi) / (ma - mi)
35
-
36
- # Main background remover function
37
- def remove_bg(input_image):
38
- image = input_image.convert("RGB")
39
- orig_size = image.size
40
- img_tensor = transform(image).unsqueeze(0)
41
 
 
 
 
 
42
  with torch.no_grad():
43
- d1, *_ = net(img_tensor)
44
- mask = normalize_prediction(d1[0][0])
45
- mask = mask.squeeze().cpu().numpy()
46
- mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(orig_size)
47
-
48
- # Add alpha channel using mask
49
- image.putalpha(mask)
50
- return image
51
-
52
- # Gradio app
53
- gr.Interface(
54
- fn=remove_bg,
 
 
 
55
  inputs=gr.Image(type="pil", label="Upload Image"),
56
- outputs=gr.Image(type="pil", label="Transparent PNG"),
57
- title="🪄 Remove.bg Clone - AI Background Remover",
58
- description="Upload any image to remove the background using U-2-Net (Hugging Face version)."
59
- ).launch()
 
 
 
1
  import gradio as gr
 
2
  import torch
3
  import numpy as np
4
+ from PIL import Image
5
+ import torchvision.transforms as transforms
6
+ from u2net import U2NET
7
  import os
8
+ import urllib.request
 
9
 
10
+ # Download model if not present
11
+ model_url = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"
12
  model_path = "u2net.pth"
13
+
14
  if not os.path.exists(model_path):
15
+ print("Downloading model...")
16
+ urllib.request.urlretrieve(model_url, model_path)
17
+
18
+ # Load model
19
+ print("Loading model...")
20
+ net = U2NET(3, 1)
21
+ net.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=False))
 
22
  net.eval()
23
 
24
  # Preprocessing
25
+ def preprocess(img):
26
+ transform = transforms.Compose([
27
+ transforms.Resize((320, 320)),
28
+ transforms.ToTensor(),
29
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
30
+ std=[0.229, 0.224, 0.225])
31
+ ])
32
+ return transform(img).unsqueeze(0)
33
+
34
+ # Postprocess output mask
35
+ def postprocess_mask(d):
36
+ pred = d[0][0]
37
+ pred = (pred - pred.min()) / (pred.max() - pred.min())
38
+ pred = pred.detach().cpu().numpy()
39
+ mask = (pred > 0.5).astype(np.uint8) * 255
40
+ return Image.fromarray(mask)
 
 
41
 
42
+ # Main function
43
+ def remove_background(input_image):
44
+ image = input_image.convert("RGB")
45
+ input_tensor = preprocess(image)
46
  with torch.no_grad():
47
+ d1, *_ = net(input_tensor)
48
+ mask = postprocess_mask(d1)
49
+
50
+ image = image.resize(mask.size)
51
+ image_np = np.array(image)
52
+ mask_np = np.array(mask) / 255
53
+ mask_np = np.expand_dims(mask_np, axis=2)
54
+
55
+ result = image_np * mask_np + (1 - mask_np) * 255
56
+ result = Image.fromarray(result.astype(np.uint8))
57
+ return result
58
+
59
+ # Gradio UI
60
+ demo = gr.Interface(
61
+ fn=remove_background,
62
  inputs=gr.Image(type="pil", label="Upload Image"),
63
+ outputs=gr.Image(type="pil", label="Image without Background"),
64
+ title="🧠 AI Background Remover (U²-Net)",
65
+ description="Removes background from images using U²-Net."
66
+ )
67
+
68
+ demo.launch()