Janeka commited on
Commit
f1139b5
·
verified ·
1 Parent(s): 93cbd2a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -57
app.py CHANGED
@@ -2,79 +2,68 @@ import cv2
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
5
- import os
6
 
7
- def refine_edges(image, edge_smoothness=3, blur_radius=2, feather_amount=1, threshold=0.1):
8
  """
9
- Enhanced edge refinement that specifically targets leftover background pixels
 
 
 
10
  """
11
- img = image.convert("RGBA")
12
- np_img = np.array(img)
13
- r, g, b, a = cv2.split(np_img)
 
14
 
15
- # Convert parameters
16
- blur_kernel = blur_radius * 2 + 1
17
- smooth_iterations = edge_smoothness
 
 
18
 
19
- # 1. Create a strict alpha mask (remove semi-transparent pixels)
20
- _, strict_alpha = cv2.threshold(a, 254, 255, cv2.THRESH_BINARY)
 
21
 
22
- # 2. Find the "edge zone" (area between strict alpha and original alpha)
23
- edge_zone = cv2.bitwise_xor(a, strict_alpha)
24
-
25
- # 3. Process only the edge zone
26
- kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
27
- for _ in range(smooth_iterations):
28
- edge_zone = cv2.morphologyEx(edge_zone, cv2.MORPH_OPEN, kernel)
29
- edge_zone = cv2.morphologyEx(edge_zone, cv2.MORPH_CLOSE, kernel)
30
-
31
- # 4. Apply blur only to edge zone
32
- blurred_edge = cv2.GaussianBlur(edge_zone, (blur_kernel, blur_kernel), 0)
33
-
34
- # 5. Feathering with thresholding to remove leftover bg
35
- if feather_amount > 0:
36
- edge_mask = (blurred_edge > threshold * 255).astype(np.uint8) * 255
37
- edge_mask = cv2.erode(edge_mask, np.ones((feather_amount, feather_amount), np.uint8))
38
- final_edge = cv2.GaussianBlur(edge_mask, (blur_kernel, blur_kernel), 0)
39
- else:
40
- final_edge = blurred_edge
41
-
42
- # 6. Combine with strict alpha
43
- new_alpha = cv2.bitwise_or(strict_alpha, final_edge)
44
 
45
- # 7. Remove color information from transparent areas
46
- r = r * (new_alpha > 0)
47
- g = g * (new_alpha > 0)
48
- b = b * (new_alpha > 0)
49
 
50
- # Recombine channels
51
- result = cv2.merge([r, g, b, new_alpha])
 
52
 
 
 
53
  return Image.fromarray(result)
54
 
55
- # Gradio interface
56
- with gr.Blocks(title="✨ Advanced Edge Refiner") as demo:
57
- gr.Markdown("""
58
- # ✨ Advanced Edge Refiner
59
- Removes leftover background artifacts around hair and fine edges
60
- """)
61
-
62
  with gr.Row():
63
  with gr.Column():
64
- input_image = gr.Image(type="pil", label="Input Image")
65
- edge_smoothness = gr.Slider(1, 5, value=3, label="Edge Smoothness")
66
- blur_radius = gr.Slider(1, 5, value=2, label="Blur Strength")
67
- feather_amount = gr.Slider(0, 5, value=1, label="Feather Amount")
68
- threshold = gr.Slider(0, 100, value=10, label="Edge Threshold (%)")
69
- submit_btn = gr.Button("Refine Edges", variant="primary")
70
-
71
  with gr.Column():
72
- output_image = gr.Image(type="pil", label="Refined Image")
73
 
74
- submit_btn.click(
75
  fn=refine_edges,
76
- inputs=[input_image, edge_smoothness, blur_radius, feather_amount, threshold],
77
- outputs=output_image
78
  )
79
 
80
  if __name__ == "__main__":
 
2
  import numpy as np
3
  from PIL import Image
4
  import gradio as gr
 
5
 
6
+ def refine_edges(img, edge_aggressiveness=3, bg_removal_strength=50):
7
  """
8
+ Advanced edge refinement using:
9
+ 1. Alpha matte estimation
10
+ 2. Guided filtering
11
+ 3. Color decontamination
12
  """
13
+ # Convert to numpy array
14
+ np_img = np.array(img.convert("RGBA"))
15
+ rgb = np_img[..., :3].astype(np.float32) / 255.0
16
+ alpha = np_img[..., 3].astype(np.float32) / 255.0
17
 
18
+ # 1. Create trimap from alpha
19
+ trimap = np.zeros_like(alpha)
20
+ trimap[alpha > 0.95] = 1 # Definite foreground
21
+ trimap[alpha < 0.05] = 0 # Definite background
22
+ trimap[(alpha >= 0.05) & (alpha <= 0.95)] = 0.5 # Unknown area
23
 
24
+ # 2. Estimate foreground/background colors
25
+ fg = rgb * (trimap == 1)[..., None]
26
+ bg = rgb * (trimap == 0)[..., None]
27
 
28
+ # 3. Guided filter for alpha refinement
29
+ radius = edge_aggressiveness * 5
30
+ eps = 0.01
31
+ refined_alpha = cv2.ximgproc.guidedFilter(
32
+ guide=rgb,
33
+ src=alpha,
34
+ radius=radius,
35
+ eps=eps
36
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
37
 
38
+ # 4. Color decontamination
39
+ bg_removal = bg_removal_strength / 100.0
40
+ new_rgb = (rgb - bg_removal * bg) / (1 - bg_removal * (1 - refined_alpha[..., None]))
41
+ new_rgb = np.clip(new_rgb, 0, 1)
42
 
43
+ # 5. Final alpha thresholding
44
+ final_alpha = np.clip(refined_alpha * 255, 0, 255).astype(np.uint8)
45
+ new_rgb = (new_rgb * 255).astype(np.uint8)
46
 
47
+ # Combine channels
48
+ result = np.concatenate([new_rgb, final_alpha[..., None]], axis=-1)
49
  return Image.fromarray(result)
50
 
51
+ # Gradio Interface
52
+ with gr.Blocks() as demo:
53
+ gr.Markdown("## ✂️ Advanced Edge Refiner")
 
 
 
 
54
  with gr.Row():
55
  with gr.Column():
56
+ img_input = gr.Image(type="pil", label="Input PNG")
57
+ edge_slider = gr.Slider(1, 10, value=3, label="Edge Precision")
58
+ bg_slider = gr.Slider(1, 100, value=50, label="BG Removal Strength")
59
+ process_btn = gr.Button("Refine Edges")
 
 
 
60
  with gr.Column():
61
+ img_output = gr.Image(type="pil", label="Refined Result")
62
 
63
+ process_btn.click(
64
  fn=refine_edges,
65
+ inputs=[img_input, edge_slider, bg_slider],
66
+ outputs=img_output
67
  )
68
 
69
  if __name__ == "__main__":