dc086989 commited on
Commit
511782b
·
verified ·
1 Parent(s): 8d6efe4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -59
app.py CHANGED
@@ -1,69 +1,29 @@
1
  import gradio as gr
2
- from gradio_imageslider import ImageSlider
3
- from loadimg import load_img
4
- import spaces
5
- from transformers import AutoModelForImageSegmentation
6
- import torch
7
- from torchvision import transforms
8
 
9
- birefnet = AutoModelForImageSegmentation.from_pretrained(
10
- "ZhengPeng7/BiRefNet", trust_remote_code=True
11
- )
12
- birefnet.to("cpu")
13
-
14
- transform_image = transforms.Compose(
15
- [
16
- transforms.Resize((1024, 1024)),
17
- transforms.ToTensor(),
18
- transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
19
- ]
20
- )
21
 
22
- def fn(image):
23
- im = load_img(image, output_type="pil")
24
- im = im.convert("RGB")
25
- origin = im.copy()
26
- processed_image = process(im)
27
- return (processed_image, origin)
28
 
29
- def process(image):
30
- image_size = image.size
31
- input_images = transform_image(image).unsqueeze(0).to("cpu") # GPU -> CPU
32
- # Prediction
33
- with torch.no_grad():
34
- preds = birefnet(input_images)[-1].sigmoid().cpu()
35
- pred = preds[0].squeeze()
36
- pred_pil = transforms.ToPILImage()(pred)
37
- mask = pred_pil.resize(image_size)
38
- image.putalpha(mask)
39
- return image
40
 
41
- def process_file(f):
42
- name_path = f.rsplit(".", 1)[0] + ".png"
43
- im = load_img(f, output_type="pil")
44
- im = im.convert("RGB")
45
- transparent = process(im)
46
- transparent.save(name_path)
47
- return name_path
48
 
49
- slider1 = ImageSlider(label="Processed Image", type="pil")
50
- slider2 = ImageSlider(label="Processed Image from URL", type="pil")
51
- image_upload = gr.Image(label="Upload an image")
52
- image_file_upload = gr.Image(label="Upload an image", type="filepath")
53
- url_input = gr.Textbox(label="Paste an image URL")
54
- output_file = gr.File(label="Output PNG File")
55
 
56
- # Example images
57
- chameleon = load_img("butterfly.jpg", output_type="pil")
58
- url_example = "https://hips.hearstapps.com/hmg-prod/images/gettyimages-1229892983-square.jpg"
59
-
60
- tab1 = gr.Interface(fn, inputs=image_upload, outputs=slider1, examples=[chameleon], api_name="image")
61
- tab2 = gr.Interface(fn, inputs=url_input, outputs=slider2, examples=[url_example], api_name="text")
62
- tab3 = gr.Interface(process_file, inputs=image_file_upload, outputs=output_file, examples=["butterfly.jpg"], api_name="png")
63
-
64
- demo = gr.TabbedInterface(
65
- [tab1, tab2, tab3], ["Image Upload", "URL Input", "File Output"], title="Background Removal Tool"
66
  )
67
 
 
68
  if __name__ == "__main__":
69
- demo.launch(show_error=True, share=True)
 
1
  import gradio as gr
2
+ from rembg import remove
3
+ from PIL import Image
4
+ import numpy as np
 
 
 
5
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ def remove_background(input_image):
8
+ # Конвертируем изображение из Gradio в PIL
9
+ image = Image.fromarray(input_image.astype('uint8'), 'RGB')
 
 
 
10
 
11
+ # Удаляем фон с помощью U²-Net
12
+ output_image = remove(image)
 
 
 
 
 
 
 
 
 
13
 
14
+ # Конвертируем результат обратно в numpy array для Gradio
15
+ return np.array(output_image)
 
 
 
 
 
16
 
 
 
 
 
 
 
17
 
18
+ # Создаем интерфейс с примерами изображений для теста
19
+ demo = gr.Interface(
20
+ fn=remove_background,
21
+ inputs=gr.Image(label="Загрузите фото"),
22
+ outputs=gr.Image(label="Результат без фона"),
23
+ title="Background Remover",
24
+ description="Загрузите фото получите PNG без фона!"
 
 
 
25
  )
26
 
27
+
28
  if __name__ == "__main__":
29
+ demo.launch(share=False)