bluspater commited on
Commit
db1e9a5
·
verified ·
1 Parent(s): 202082b

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +17 -0
  2. requirements.txt +6 -0
  3. u2net_utils.py +63 -0
app.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image
3
+ from u2net_utils import remove_background
4
+
5
+ def process(image):
6
+ result = remove_background(image)
7
+ return result
8
+
9
+ demo = gr.Interface(
10
+ fn=process,
11
+ inputs=gr.Image(type="pil"),
12
+ outputs=gr.Image(type="pil"),
13
+ title="Remove Background",
14
+ description="Upload an image and remove the background using U²-Net."
15
+ )
16
+
17
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ Pillow
4
+ gradio
5
+ numpy
6
+ requests
u2net_utils.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torchvision.transforms as transforms
3
+ from PIL import Image
4
+ import numpy as np
5
+ import os
6
+ import requests
7
+ from model.u2net import U2NET
8
+
9
+ MODEL_DIR = "saved_models/u2net"
10
+ MODEL_PATH = os.path.join(MODEL_DIR, "u2net.pth")
11
+ MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"
12
+
13
+ def download_model():
14
+ if not os.path.exists(MODEL_PATH):
15
+ os.makedirs(MODEL_DIR, exist_ok=True)
16
+ print("Downloading model...")
17
+ r = requests.get(MODEL_URL, stream=True)
18
+ with open(MODEL_PATH, "wb") as f:
19
+ for chunk in r.iter_content(chunk_size=8192):
20
+ f.write(chunk)
21
+ print("Model downloaded.")
22
+
23
+ download_model()
24
+
25
+ def load_model():
26
+ net = U2NET(3, 1)
27
+ net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
28
+ net.eval()
29
+ return net
30
+
31
+ model = load_model()
32
+
33
+ def preprocess(image):
34
+ transform = transforms.Compose([
35
+ transforms.Resize((320, 320)),
36
+ transforms.ToTensor(),
37
+ transforms.Normalize(mean=[0.485, 0.456, 0.406],
38
+ std=[0.229, 0.224, 0.225])
39
+ ])
40
+ return transform(image).unsqueeze(0)
41
+
42
+ def postprocess(mask, original_size):
43
+ mask = mask.squeeze().cpu().data.numpy()
44
+ mask = (mask - mask.min()) / (mask.max() - mask.min())
45
+ mask = Image.fromarray((mask * 255).astype(np.uint8)).resize(original_size, Image.BILINEAR)
46
+ return mask
47
+
48
+ def remove_background(image):
49
+ input_tensor = preprocess(image)
50
+ with torch.no_grad():
51
+ d1, *_ = model(input_tensor)
52
+ mask = postprocess(d1, image.size)
53
+
54
+ image = image.convert("RGBA")
55
+ datas = image.getdata()
56
+ masks = mask.getdata()
57
+
58
+ new_data = []
59
+ for item, m in zip(datas, masks):
60
+ new_data.append((item[0], item[1], item[2], m))
61
+
62
+ image.putdata(new_data)
63
+ return image