Janeka commited on
Commit
2f4bc9a
·
verified ·
1 Parent(s): 626b6d0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -64
app.py CHANGED
@@ -3,103 +3,79 @@ import numpy as np
3
  from PIL import Image
4
  import gradio as gr
5
  import os
6
- from io import BytesIO
7
- import requests
8
 
9
- def refine_edges(image, edge_smoothness=3, blur_radius=2, feather_amount=1):
10
  """
11
- Refines edges of a transparent PNG image with configurable parameters.
12
  """
13
  img = image.convert("RGBA")
14
  np_img = np.array(img)
15
- alpha = np_img[:, :, 3]
16
 
17
- # Scale parameters
18
  blur_kernel = blur_radius * 2 + 1
19
  smooth_iterations = edge_smoothness
20
- feather_size = feather_amount
21
 
22
- # Edge smoothing
 
 
 
 
 
 
23
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
24
  for _ in range(smooth_iterations):
25
- alpha = cv2.morphologyEx(alpha, cv2.MORPH_OPEN, kernel)
26
- alpha = cv2.morphologyEx(alpha, cv2.MORPH_CLOSE, kernel)
27
 
28
- # Gaussian blur
29
- alpha = cv2.GaussianBlur(alpha, (blur_kernel, blur_kernel), 0)
30
 
31
- # Feather edges
32
  if feather_amount > 0:
33
- _, mask = cv2.threshold(alpha, 10, 255, cv2.THRESH_BINARY)
34
- edges = cv2.Canny(mask, 100, 200)
35
- edges = cv2.dilate(edges, np.ones((feather_size, feather_size), np.uint8), iterations=1)
36
- alpha_blurred = cv2.GaussianBlur(alpha, (blur_kernel, blur_kernel), 0)
37
- alpha = np.where(edges > 0, alpha_blurred, alpha)
38
 
39
- alpha = np.clip(alpha, 0, 255).astype(np.uint8)
40
- np_img[:, :, 3] = alpha
41
 
42
- return Image.fromarray(np_img)
43
-
44
- def download_example_images():
45
- """Download example images from alternative sources"""
46
- example_images = {
47
- "hair.png": "https://i.imgur.com/JQJQJQJ.png", # Replace with actual URL
48
- "furry_animal.png": "https://i.imgur.com/ANIMAL.png",
49
- "glasses.png": "https://i.imgur.com/GLASSES.png"
50
- }
51
 
52
- os.makedirs("examples", exist_ok=True)
 
53
 
54
- for filename, url in example_images.items():
55
- try:
56
- response = requests.get(url)
57
- if response.status_code == 200:
58
- with open(f"examples/{filename}", "wb") as f:
59
- f.write(response.content)
60
- except Exception as e:
61
- print(f"Couldn't download {filename}: {str(e)}")
62
- # Provide fallback blank image
63
- blank = Image.new("RGBA", (256, 256), (0, 0, 0, 0))
64
- blank.save(f"examples/{filename}")
65
 
66
- # Download examples at startup
67
- download_example_images()
68
-
69
- # Create Gradio interface
70
- with gr.Blocks(title="✨ Edge Refiner") as demo:
71
  gr.Markdown("""
72
- # ✨ Edge Refiner - Clean Up Your Background-Removed Images!
73
- Refine the edges of transparent PNG images for cleaner results.
74
  """)
75
 
76
  with gr.Row():
77
  with gr.Column():
78
  input_image = gr.Image(type="pil", label="Input Image")
79
- edge_smoothness = gr.Slider(1, 5, value=3, step=1, label="Edge Smoothness")
80
- blur_radius = gr.Slider(1, 5, value=2, step=1, label="Blur Radius")
81
- feather_amount = gr.Slider(0, 5, value=1, step=1, label="Feather Amount")
 
82
  submit_btn = gr.Button("Refine Edges", variant="primary")
83
 
84
  with gr.Column():
85
  output_image = gr.Image(type="pil", label="Refined Image")
86
 
87
- # Use local example files
88
- example_files = [f for f in os.listdir("examples") if f.endswith(".png")]
89
- if example_files:
90
- gr.Examples(
91
- examples=[[f"examples/{f}", 3, 2, 1] for f in example_files],
92
- inputs=[input_image, edge_smoothness, blur_radius, feather_amount],
93
- outputs=output_image,
94
- fn=refine_edges,
95
- cache_examples=False
96
- )
97
-
98
  submit_btn.click(
99
  fn=refine_edges,
100
- inputs=[input_image, edge_smoothness, blur_radius, feather_amount],
101
  outputs=output_image
102
  )
103
 
104
  if __name__ == "__main__":
105
- demo.launch(server_name="0.0.0.0", server_port=7860)
 
3
  from PIL import Image
4
  import gradio as gr
5
  import os
 
 
6
 
7
+ def refine_edges(image, edge_smoothness=3, blur_radius=2, feather_amount=1, threshold=0.1):
8
  """
9
+ Enhanced edge refinement that specifically targets leftover background pixels
10
  """
11
  img = image.convert("RGBA")
12
  np_img = np.array(img)
13
+ r, g, b, a = cv2.split(np_img)
14
 
15
+ # Convert parameters
16
  blur_kernel = blur_radius * 2 + 1
17
  smooth_iterations = edge_smoothness
 
18
 
19
+ # 1. Create a strict alpha mask (remove semi-transparent pixels)
20
+ _, strict_alpha = cv2.threshold(a, 254, 255, cv2.THRESH_BINARY)
21
+
22
+ # 2. Find the "edge zone" (area between strict alpha and original alpha)
23
+ edge_zone = cv2.bitwise_xor(a, strict_alpha)
24
+
25
+ # 3. Process only the edge zone
26
  kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
27
  for _ in range(smooth_iterations):
28
+ edge_zone = cv2.morphologyEx(edge_zone, cv2.MORPH_OPEN, kernel)
29
+ edge_zone = cv2.morphologyEx(edge_zone, cv2.MORPH_CLOSE, kernel)
30
 
31
+ # 4. Apply blur only to edge zone
32
+ blurred_edge = cv2.GaussianBlur(edge_zone, (blur_kernel, blur_kernel), 0)
33
 
34
+ # 5. Feathering with thresholding to remove leftover bg
35
  if feather_amount > 0:
36
+ edge_mask = (blurred_edge > threshold * 255).astype(np.uint8) * 255
37
+ edge_mask = cv2.erode(edge_mask, np.ones((feather_amount, feather_amount), np.uint8))
38
+ final_edge = cv2.GaussianBlur(edge_mask, (blur_kernel, blur_kernel), 0)
39
+ else:
40
+ final_edge = blurred_edge
41
 
42
+ # 6. Combine with strict alpha
43
+ new_alpha = cv2.bitwise_or(strict_alpha, final_edge)
44
 
45
+ # 7. Remove color information from transparent areas
46
+ r = r * (new_alpha > 0)
47
+ g = g * (new_alpha > 0)
48
+ b = b * (new_alpha > 0)
 
 
 
 
 
49
 
50
+ # Recombine channels
51
+ result = cv2.merge([r, g, b, new_alpha])
52
 
53
+ return Image.fromarray(result)
 
 
 
 
 
 
 
 
 
 
54
 
55
+ # Gradio interface
56
+ with gr.Blocks(title="✨ Advanced Edge Refiner") as demo:
 
 
 
57
  gr.Markdown("""
58
+ # ✨ Advanced Edge Refiner
59
+ Removes leftover background artifacts around hair and fine edges
60
  """)
61
 
62
  with gr.Row():
63
  with gr.Column():
64
  input_image = gr.Image(type="pil", label="Input Image")
65
+ edge_smoothness = gr.Slider(1, 5, value=3, label="Edge Smoothness")
66
+ blur_radius = gr.Slider(1, 5, value=2, label="Blur Strength")
67
+ feather_amount = gr.Slider(0, 5, value=1, label="Feather Amount")
68
+ threshold = gr.Slider(0, 100, value=10, label="Edge Threshold (%)")
69
  submit_btn = gr.Button("Refine Edges", variant="primary")
70
 
71
  with gr.Column():
72
  output_image = gr.Image(type="pil", label="Refined Image")
73
 
 
 
 
 
 
 
 
 
 
 
 
74
  submit_btn.click(
75
  fn=refine_edges,
76
+ inputs=[input_image, edge_smoothness, blur_radius, feather_amount, threshold],
77
  outputs=output_image
78
  )
79
 
80
  if __name__ == "__main__":
81
+ demo.launch()