pixelsdesign commited on
Commit
808ad1a
·
verified ·
1 Parent(s): f89d019

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +58 -104
app.py CHANGED
@@ -1,109 +1,63 @@
1
- import gradio as gr
2
- from loadimg import load_img
3
- import spaces
4
- from transformers import AutoModelForImageSegmentation
5
  import torch
6
- from torchvision import transforms
7
- from typing import Union, Tuple
8
  from PIL import Image
9
-
10
- torch.set_float32_matmul_precision(["high", "highest"][0])
11
-
12
- birefnet = AutoModelForImageSegmentation.from_pretrained(
13
- "ZhengPeng7/BiRefNet", trust_remote_code=True
14
- )
15
- birefnet.to("cpu")
16
-
17
- transform_image = transforms.Compose(
18
- [
19
- transforms.Resize((1024, 1024)),
20
- transforms.ToTensor(),
21
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
22
- ]
23
- )
24
-
25
- def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
26
- """
27
- Remove the background from an image and return both the transparent version and the original.
28
-
29
- This function performs background removal using a BiRefNet segmentation model. It is intended for use
30
- with image input (either uploaded or from a URL). The function returns a transparent PNG version of the image
31
- with the background removed, along with the original RGB version for comparison.
32
-
33
- Args:
34
- image (PIL.Image or str): The input image, either as a PIL object or a filepath/URL string.
35
-
36
- Returns:
37
- tuple:
38
- - processed_image (PIL.Image): The input image with the background removed and transparency applied.
39
- - origin (PIL.Image): The original RGB image, unchanged.
40
- """
41
- im = load_img(image, output_type="pil")
42
- im = im.convert("RGB")
43
- origin = im.copy()
44
- processed_image = process(im)
45
- return (origin, processed_image)
46
-
47
- @spaces.GPU
48
- def process(image: Image.Image) -> Image.Image:
49
- """
50
- Apply BiRefNet-based image segmentation to remove the background.
51
-
52
- This function preprocesses the input image, runs it through a BiRefNet segmentation model to obtain a mask,
53
- and applies the mask as an alpha (transparency) channel to the original image.
54
-
55
- Args:
56
- image (PIL.Image): The input RGB image.
57
-
58
- Returns:
59
- PIL.Image: The image with the background removed, using the segmentation mask as transparency.
60
- """
61
- image_size = image.size
62
- input_images = transform_image(image).unsqueeze(0).to("cuda")
63
- # Prediction
64
- with torch.no_grad():
65
- preds = birefnet(input_images)[-1].sigmoid().cpu()
66
- pred = preds[0].squeeze()
67
- pred_pil = transforms.ToPILImage()(pred)
68
- mask = pred_pil.resize(image_size)
69
- image.putalpha(mask)
70
- return image
71
-
72
- def process_file(f: str) -> str:
73
- """
74
- Load an image file from disk, remove the background, and save the output as a transparent PNG.
75
-
76
- Args:
77
- f (str): Filepath of the image to process.
78
-
79
- Returns:
80
- str: Path to the saved PNG image with background removed.
81
- """
82
- name_path = f.rsplit(".", 1)[0] + ".png"
83
- im = load_img(f, output_type="pil")
84
- im = im.convert("RGB")
85
- transparent = process(im)
86
- transparent.save(name_path)
87
- return name_path
88
-
89
- slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
90
- slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
91
- image_upload = gr.Image(label="Upload an image")
92
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
93
- url_input = gr.Textbox(label="Paste an image URL")
94
- output_file = gr.File(label="Output PNG File")
95
-
96
- # Example images
97
- chameleon = load_img("butterfly.jpg", output_type="pil")
98
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
99
-
100
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
101
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
102
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
103
-
104
- demo = gr.TabbedInterface(
105
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
106
  )
107
 
108
  if __name__ == "__main__":
109
- demo.launch(show_error=True, mcp_server=True)
 
 
 
 
 
1
  import torch
 
 
2
  from PIL import Image
3
+ import numpy as np
4
+ import base64
5
+ import io
6
+ import gradio as gr
7
+ from your_model_imports import BiRefNet # replace with your actual model import
8
+
9
+ # Force CPU
10
+ device = torch.device("cpu")
11
+
12
+ # Load model
13
+ birefnet = BiRefNet() # or your model class
14
+ birefnet.to(device)
15
+ birefnet.eval() # set evaluation mode
16
+
17
+ # Helper to convert base64 to PIL
18
+ def b64_to_pil(b64_image):
19
+ header, data = b64_image.split(",", 1)
20
+ img_bytes = base64.b64decode(data)
21
+ return Image.open(io.BytesIO(img_bytes)).convert("RGBA")
22
+
23
+ # Helper to convert PIL to base64
24
+ def pil_to_b64(pil_img):
25
+ buffered = io.BytesIO()
26
+ pil_img.save(buffered, format="PNG")
27
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
28
+ return f"data:image/png;base64,{img_str}"
29
+
30
+ # Background removal function
31
+ def remove_bg(image_b64):
32
+ try:
33
+ # Convert to PIL
34
+ img = b64_to_pil(image_b64)
35
+
36
+ # Convert PIL to tensor
37
+ img_tensor = torch.from_numpy(np.array(img)).permute(2,0,1).unsqueeze(0).float() / 255.0
38
+ img_tensor = img_tensor.to(device)
39
+
40
+ # Run model
41
+ with torch.no_grad():
42
+ output_tensor = birefnet(img_tensor)
43
+
44
+ # Convert output tensor to PIL
45
+ output_np = (output_tensor.squeeze().permute(1,2,0).numpy() * 255).astype(np.uint8)
46
+ output_pil = Image.fromarray(output_np)
47
+
48
+ # Convert to base64
49
+ return pil_to_b64(output_pil)
50
+ except Exception as e:
51
+ return f"ERROR: {str(e)}"
52
+
53
+ # Gradio interface
54
+ iface = gr.Interface(
55
+ fn=remove_bg,
56
+ inputs=gr.Image(type="pil", label="Input Image"),
57
+ outputs=gr.Image(type="auto", label="Background Removed"),
58
+ title="Background Remover Pixels",
59
+ description="Removes background using CPU-only model."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  )
61
 
62
  if __name__ == "__main__":
63
+ iface.launch()