Abso1ute666 commited on
Commit
9239cd6
·
verified ·
1 Parent(s): 0d6590f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -0
app.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import SegformerImageProcessor, AutoModelForSemanticSegmentation
2
+ from PIL import Image, ImageFilter
3
+ import torch.nn as nn
4
+ import os
5
+ import gradio as gr
6
+
7
+ processor = SegformerImageProcessor.from_pretrained("mattmdjaga/segformer_b2_clothes")
8
+ model = AutoModelForSemanticSegmentation.from_pretrained("mattmdjaga/segformer_b2_clothes")
9
+
10
+ title = "Background remover 👀"
11
+ description = " Image segmentation model which removes the background and optionally adds a white border."
12
+ article = 'Inference done on "mattmdjaga/segformer_b2_clothes" model'
13
+
14
+
15
+ folder_path = "Images"
16
+ example_list = []
17
+ if os.path.exists(folder_path) and os.path.isdir(folder_path):
18
+ file_paths = [os.path.join(folder_path, file_name) for file_name in os.listdir(folder_path)]
19
+ for file_path in file_paths:
20
+ example_list.append(['Large',file_path])
21
+
22
+ def predict(border_size, image):
23
+ sizes = {'Large': 5, 'Medium': 3, 'Small': 1, 'None': 0}
24
+ image = image.convert('RGB')
25
+ inputs = processor(images=image, return_tensors="pt")
26
+
27
+ outputs = model(**inputs)
28
+ logits = outputs.logits.cpu()
29
+
30
+ upsampled_logits = nn.functional.interpolate(
31
+ logits,
32
+ size=image.size[::-1],
33
+ mode="bilinear",
34
+ align_corners=False,
35
+ )
36
+
37
+ pred_seg = upsampled_logits.argmax(dim=1)[0]
38
+
39
+ non_background_mask = pred_seg != 0
40
+
41
+ # Convert tensor mask to PIL Image with an alpha channel
42
+ non_background_pil_mask = Image.fromarray(non_background_mask.numpy().astype('uint8') * 255, 'L')
43
+
44
+ # Create a composite image using the non-background mask
45
+ composite_image = Image.new('RGBA', image.size, color=(0, 0, 0, 0))
46
+ composite_image.paste(image.convert('RGBA'), mask=non_background_pil_mask)
47
+
48
+ if sizes[border_size] != 0:
49
+ stroke_radius = sizes[border_size]
50
+ img = composite_image # RGBA image
51
+ stroke_image = Image.new("RGBA", img.size, (255, 255, 255, 255))
52
+ img_alpha = img.getchannel(3).point(lambda x: 255 if x>0 else 0)
53
+ stroke_alpha = img_alpha.filter(ImageFilter.MaxFilter(stroke_radius))
54
+ stroke_alpha = stroke_alpha.filter(ImageFilter.SMOOTH)
55
+ stroke_image.putalpha(stroke_alpha)
56
+ output = Image.alpha_composite(stroke_image, img)
57
+ return output
58
+ else:
59
+ return composite_image
60
+
61
+ iface = gr.Interface(fn=predict,
62
+ inputs=[gr.Dropdown(['None','Small', 'Medium', 'Large'], label='Select Border Size'),
63
+ gr.Image(type='pil', label='Select Image.')],
64
+ outputs=gr.Image(type='pil', label='Output with background removed (sorta?)'),
65
+ title=title,
66
+ description=description,
67
+ article=article,
68
+ examples=example_list)
69
+ iface.launch()