VIATEUR-AI commited on
Commit
eb15366
·
verified ·
1 Parent(s): 8490f3a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -130
app.py CHANGED
@@ -1,131 +1,37 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- from transformers import AutoModelForImageSegmentation
4
- import torch
5
- from torchvision import transforms
6
  from PIL import Image
7
- import tempfile
8
-
9
- torch.set_float32_matmul_precision(["high", "highest"][0])
10
-
11
- birefnet = AutoModelForImageSegmentation.from_pretrained(
12
- "ZhengPeng7/BiRefNet", trust_remote_code=True
13
- )
14
- birefnet.to("cpu")
15
- transform_image = transforms.Compose(
16
- [
17
- transforms.Resize((1024, 1024)),
18
- transforms.ToTensor(),
19
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
20
- ]
21
- )
22
-
23
- def fn(image):
24
- im = load_img(image, output_type="pil")
25
- im = im.convert("RGB")
26
- origin = im.copy()
27
- image = process(im)
28
- return image
29
-
30
- def parse_color(color):
31
- if color.startswith('#'):
32
- hex_color = color.lstrip('#')
33
- r = int(hex_color[0:2], 16)
34
- g = int(hex_color[2:4], 16)
35
- b = int(hex_color[4:6], 16)
36
- elif color.startswith('rgba'):
37
- rgba_values = color.replace('rgba(', '').replace(')', '')
38
- parts = [x.strip() for x in rgba_values.split(',')]
39
- r, g, b = int(float(parts[0])), int(float(parts[1])), int(float(parts[2]))
40
- elif color.startswith('rgb'):
41
- rgb_values = color.replace('rgb(', '').replace(')', '')
42
- r, g, b = [int(float(x.strip())) for x in rgb_values.split(',')]
43
- else:
44
- r, g, b = 255, 255, 255
45
- return (r, g, b, 255)
46
-
47
- def process(image):
48
- image_size = image.size
49
- input_images = transform_image(image).unsqueeze(0).to("cpu")
50
- with torch.no_grad():
51
- preds = birefnet(input_images)[-1].sigmoid().cpu()
52
- pred = preds[0].squeeze()
53
- pred_pil = transforms.ToPILImage()(pred)
54
- mask = pred_pil.resize(image_size)
55
- image.putalpha(mask)
56
- return image, mask
57
-
58
- def process_file(f, bg_color):
59
- im = load_img(f, output_type="pil")
60
- im = im.convert("RGB")
61
-
62
- transparent_img, mask = process(im)
63
-
64
- # With background color
65
- rgba_color = parse_color(bg_color)
66
- background = Image.new("RGBA", im.size, rgba_color)
67
- with_bg = Image.alpha_composite(background, transparent_img)
68
- with_bg_rgb = with_bg.convert("RGB")
69
-
70
- bg_png_path = tempfile.mktemp(suffix=".png")
71
- with_bg.save(bg_png_path, "PNG")
72
-
73
- bg_jpeg_path = tempfile.mktemp(suffix=".jpeg")
74
- with_bg_rgb.save(bg_jpeg_path, "JPEG")
75
-
76
- # Transparent (no background)
77
- trans_png_path = tempfile.mktemp(suffix=".png")
78
- transparent_img.save(trans_png_path, "PNG")
79
-
80
- return (with_bg_rgb, bg_png_path, bg_jpeg_path,
81
- transparent_img, trans_png_path)
82
-
83
- css = """
84
- .gradio-container h1 {
85
- margin-bottom: 24px;
86
- }
87
- .small-file, .small-file * {
88
- min-height: 0 !important;
89
- height: auto !important;
90
- }
91
- .small-file svg {
92
- display: none !important;
93
- }
94
- """
95
-
96
- with gr.Blocks(css=css, title="Background Remover") as background_remover_app:
97
- gr.Markdown("<h1 style='text-align: center;'>Background Remover</h1>")
98
-
99
- with gr.Row():
100
- with gr.Column(scale=2):
101
- gr.Markdown("### Input")
102
- input_image = gr.Image(label="Upload an image", type="filepath")
103
- color_picker = gr.ColorPicker(label="Background Color", value="#ffffff")
104
- submit_btn = gr.Button("Submit", variant="primary")
105
-
106
- with gr.Column(scale=2):
107
- with gr.Row():
108
- with gr.Column(scale=1):
109
- gr.Markdown("### Output With Background Color")
110
- bg_preview = gr.Image(label="Preview")
111
- bg_png = gr.File(label="Download PNG", elem_classes="small-file")
112
- bg_jpeg = gr.File(label="Download JPEG", elem_classes="small-file")
113
-
114
- with gr.Column(scale=1):
115
- gr.Markdown("### Output With Transparent Background")
116
- trans_preview = gr.Image(label="Preview")
117
- trans_png = gr.File(label="Download PNG", elem_classes="small-file")
118
-
119
- gr.Examples(
120
- examples=[["butterfly.jpg", "#ffffff"]],
121
- inputs=[input_image, color_picker]
122
- )
123
-
124
- submit_btn.click(
125
- fn=process_file,
126
- inputs=[input_image, color_picker],
127
- outputs=[bg_preview, bg_png, bg_jpeg, trans_preview, trans_png]
128
- )
129
-
130
- if __name__ == "__main__":
131
- background_remover_app.launch(share=True)
 
1
+ import requests
 
 
 
 
2
  from PIL import Image
3
+ from io import BytesIO
4
+ import torch
5
+ from transformers import DPTForDepthEstimation, DPTFeatureExtractor
6
+
7
+ # Example: Load an image
8
+ image_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/tasks/image_classification.png"
9
+ response = requests.get(image_url)
10
+ image = Image.open(BytesIO(response.content)).convert("RGB")
11
+
12
+ # Load Hugging Face DPT depth estimation model (can be used to segment background)
13
+ model_name = "Intel/dpt-large"
14
+ feature_extractor = DPTFeatureExtractor.from_pretrained(model_name)
15
+ model = DPTForDepthEstimation.from_pretrained(model_name)
16
+
17
+ # Preprocess image
18
+ inputs = feature_extractor(images=image, return_tensors="pt")
19
+ with torch.no_grad():
20
+ outputs = model(**inputs)
21
+ predicted_depth = outputs.predicted_depth
22
+
23
+ # Convert to numpy and normalize
24
+ depth = predicted_depth.squeeze().cpu().numpy()
25
+ depth_min, depth_max = depth.min(), depth.max()
26
+ normalized_depth = (depth - depth_min) / (depth_max - depth_min)
27
+
28
+ # Simple threshold to create mask (background = 1)
29
+ import numpy as np
30
+ mask = (normalized_depth < 0.6).astype(np.uint8) * 255
31
+
32
+ # Apply mask to image
33
+ image_np = np.array(image)
34
+ image_rgba = np.dstack([image_np, mask]) # add alpha channel
35
+ result = Image.fromarray(image_rgba)
36
+ result.save("output.png")
37
+ print("Background removed and saved as output.png")