sudo-saidso commited on
Commit
a28c507
·
verified ·
1 Parent(s): 2f5709b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -54
app.py CHANGED
@@ -7,60 +7,34 @@ from torchvision import transforms
7
  from typing import Union, Tuple
8
  from PIL import Image
9
 
10
- torch.set_float32_matmul_precision(["high", "highest"][0])
 
 
11
 
 
12
  birefnet = AutoModelForImageSegmentation.from_pretrained(
13
  "ZhengPeng7/BiRefNet", trust_remote_code=True
14
  )
15
- birefnet.to("cuda")
16
 
17
- transform_image = transforms.Compose(
18
- [
19
- transforms.Resize((1024, 1024)),
20
- transforms.ToTensor(),
21
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
22
- ]
23
- )
24
 
 
25
  def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
26
- """
27
- Remove the background from an image and return both the transparent version and the original.
28
-
29
- This function performs background removal using a BiRefNet segmentation model. It is intended for use
30
- with image input (either uploaded or from a URL). The function returns a transparent PNG version of the image
31
- with the background removed, along with the original RGB version for comparison.
32
-
33
- Args:
34
- image (PIL.Image or str): The input image, either as a PIL object or a filepath/URL string.
35
-
36
- Returns:
37
- tuple:
38
- - processed_image (PIL.Image): The input image with the background removed and transparency applied.
39
- - origin (PIL.Image): The original RGB image, unchanged.
40
- """
41
- im = load_img(image, output_type="pil")
42
- im = im.convert("RGB")
43
  origin = im.copy()
44
  processed_image = process(im)
45
  return (processed_image, origin)
46
 
47
  @spaces.GPU
48
  def process(image: Image.Image) -> Image.Image:
49
- """
50
- Apply BiRefNet-based image segmentation to remove the background.
51
-
52
- This function preprocesses the input image, runs it through a BiRefNet segmentation model to obtain a mask,
53
- and applies the mask as an alpha (transparency) channel to the original image.
54
-
55
- Args:
56
- image (PIL.Image): The input RGB image.
57
-
58
- Returns:
59
- PIL.Image: The image with the background removed, using the segmentation mask as transparency.
60
- """
61
  image_size = image.size
62
- input_images = transform_image(image).unsqueeze(0).to("cpu")
63
- # Prediction
64
  with torch.no_grad():
65
  preds = birefnet(input_images)[-1].sigmoid().cpu()
66
  pred = preds[0].squeeze()
@@ -70,22 +44,13 @@ def process(image: Image.Image) -> Image.Image:
70
  return image
71
 
72
  def process_file(f: str) -> str:
73
- """
74
- Load an image file from disk, remove the background, and save the output as a transparent PNG.
75
-
76
- Args:
77
- f (str): Filepath of the image to process.
78
-
79
- Returns:
80
- str: Path to the saved PNG image with background removed.
81
- """
82
  name_path = f.rsplit(".", 1)[0] + ".png"
83
- im = load_img(f, output_type="pil")
84
- im = im.convert("RGB")
85
  transparent = process(im)
86
  transparent.save(name_path)
87
  return name_path
88
 
 
89
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
90
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
91
  image_upload = gr.Image(label="Upload an image")
@@ -93,7 +58,6 @@ image_file_upload = gr.Image(label="Upload an image", type="filepath")
93
  url_input = gr.Textbox(label="Paste an image URL")
94
  output_file = gr.File(label="Output PNG File")
95
 
96
- # Example images
97
  chameleon = load_img("butterfly.jpg", output_type="pil")
98
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
99
 
@@ -102,8 +66,10 @@ tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example
102
  tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
103
 
104
  demo = gr.TabbedInterface(
105
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
 
 
106
  )
107
 
108
  if __name__ == "__main__":
109
- demo.launch(show_error=True, mcp_server=True)
 
7
  from typing import Union, Tuple
8
  from PIL import Image
9
 
10
+ # Automatically use GPU if available
11
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ torch.set_float32_matmul_precision("high")
13
 
14
+ # Load model
15
  birefnet = AutoModelForImageSegmentation.from_pretrained(
16
  "ZhengPeng7/BiRefNet", trust_remote_code=True
17
  )
18
+ birefnet.to(device) # Use dynamic device
19
 
20
+ # Preprocessing
21
+ transform_image = transforms.Compose([
22
+ transforms.Resize((1024, 1024)),
23
+ transforms.ToTensor(),
24
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
25
+ ])
 
26
 
27
+ # Background removal function
28
  def fn(image: Union[Image.Image, str]) -> Tuple[Image.Image, Image.Image]:
29
+ im = load_img(image, output_type="pil").convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  origin = im.copy()
31
  processed_image = process(im)
32
  return (processed_image, origin)
33
 
34
  @spaces.GPU
35
  def process(image: Image.Image) -> Image.Image:
 
 
 
 
 
 
 
 
 
 
 
 
36
  image_size = image.size
37
+ input_images = transform_image(image).unsqueeze(0).to(device) # Use dynamic device
 
38
  with torch.no_grad():
39
  preds = birefnet(input_images)[-1].sigmoid().cpu()
40
  pred = preds[0].squeeze()
 
44
  return image
45
 
46
  def process_file(f: str) -> str:
 
 
 
 
 
 
 
 
 
47
  name_path = f.rsplit(".", 1)[0] + ".png"
48
+ im = load_img(f, output_type="pil").convert("RGB")
 
49
  transparent = process(im)
50
  transparent.save(name_path)
51
  return name_path
52
 
53
+ # Gradio UI
54
  slider1 = gr.ImageSlider(label="Processed Image", type="pil", format="png")
55
  slider2 = gr.ImageSlider(label="Processed Image from URL", type="pil", format="png")
56
  image_upload = gr.Image(label="Upload an image")
 
58
  url_input = gr.Textbox(label="Paste an image URL")
59
  output_file = gr.File(label="Output PNG File")
60
 
 
61
  chameleon = load_img("butterfly.jpg", output_type="pil")
62
  url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
63
 
 
66
  tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
67
 
68
  demo = gr.TabbedInterface(
69
+ [tab1, tab2, tab3],
70
+ ["Image Upload", "URL Input", "File Output"],
71
+ title="Background Removal Tool"
72
  )
73
 
74
  if __name__ == "__main__":
75
+ demo.launch(show_error=True, mcp_server=True)