Janeka commited on
Commit
2dcab43
·
verified ·
1 Parent(s): 3cd9d3e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +68 -33
app.py CHANGED
@@ -1,41 +1,76 @@
1
- import gradio as gr
2
- import numpy as np
3
  import cv2
 
 
4
  from PIL import Image
 
 
5
 
6
- def smooth_edges(image: Image.Image):
7
- # Convert to numpy array
8
- img_np = np.array(image.convert("RGB"))
9
-
10
- # Create a grayscale version to detect edges
11
- gray = cv2.cvtColor(img_np, cv2.COLOR_RGB2GRAY)
12
-
13
- # Detect edges using Canny
14
- edges = cv2.Canny(gray, threshold1=50, threshold2=150)
15
-
16
- # Dilate edges a bit to make the mask thicker
17
- kernel = np.ones((3, 3), np.uint8)
18
- dilated = cv2.dilate(edges, kernel, iterations=1)
19
-
20
- # Create a mask with blur
21
- mask = cv2.GaussianBlur(dilated, (11, 11), 0)
22
-
23
- # Convert mask to 3-channel
24
- mask_3c = cv2.cvtColor(mask, cv2.COLOR_GRAY2RGB) / 255.0
25
 
26
- # Apply soft blending of edges
27
- softened = (img_np * (1 - mask_3c) + cv2.GaussianBlur(img_np, (15, 15), 0) * mask_3c).astype(np.uint8)
 
 
 
 
 
 
28
 
29
- return Image.fromarray(softened)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
- # Interface
32
- iface = gr.Interface(
33
- fn=smooth_edges,
34
- inputs=gr.Image(type="pil", label="Upload Image (even non-transparent!)"),
35
- outputs=gr.Image(type="pil", label="Softened Edges"),
36
- title="Edge Smoother v2",
37
- description="Smooth the edges of your subject — works for normal and background-removed images."
38
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  if __name__ == "__main__":
41
- iface.launch()
 
 
 
1
  import cv2
2
+ import numpy as np
3
+ import torch
4
  from PIL import Image
5
+ import gradio as gr
6
+ from huggingface_hub import hf_hub_download
7
 
8
+ # Load MODNet (PyTorch version)
9
+ MODNET_REPO = "ZHTX/modnet"
10
+ MODNET_FILE = "modnet_photographic_portrait_matting.ckpt"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
+ try:
13
+ model_path = hf_hub_download(repo_id=MODNET_REPO, filename=MODNET_FILE)
14
+ modnet = torch.hub.load('ZHTX/modnet', 'modnet', pretrained=False)
15
+ modnet.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
16
+ modnet.eval()
17
+ except Exception as e:
18
+ print(f"Error loading MODNet: {e}")
19
+ modnet = None
20
 
21
+ def refine_with_modnet(input_image, bg_color="#FFFFFF", threshold=0.1):
22
+ """Refine alpha matte using MODNet"""
23
+ if modnet is None:
24
+ raise gr.Error("MODNet model failed to load")
25
+
26
+ # Convert input
27
+ img = np.array(input_image.convert("RGB"))
28
+ img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_AREA)
29
+ img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() / 255.0
30
+
31
+ # Inference
32
+ with torch.no_grad():
33
+ _, _, matte = modnet(img, True)
34
+
35
+ # Process output
36
+ matte = matte.squeeze().cpu().numpy()
37
+ matte = (matte * 255).astype(np.uint8)
38
+ matte = cv2.threshold(matte, int(threshold*255), 255, cv2.THRESH_BINARY)[1]
39
+
40
+ # Composite with background
41
+ bg_color = bg_color.lstrip('#')
42
+ bg_rgb = tuple(int(bg_color[i:i+2], 16) for i in (0, 2, 4))
43
+ bg = Image.new("RGB", input_image.size, bg_rgb)
44
+
45
+ # Apply refined matte
46
+ refined = Image.fromarray(matte).resize(input_image.size)
47
+ result = Image.composite(input_image, bg, refined)
48
+
49
+ return refined, result
50
 
51
+ # Gradio Interface
52
+ with gr.Blocks(title="🔍 MODNet Edge Refiner") as demo:
53
+ gr.Markdown("""
54
+ ## 🔍 MODNet Professional Edge Refinement
55
+ Uses AI to perfectly refine hair/fur edges from trimmed images
56
+ """)
57
+
58
+ with gr.Row():
59
+ with gr.Column():
60
+ input_img = gr.Image(type="pil", label="Trimmed Input")
61
+ bg_color = gr.ColorPicker("#FFFFFF", label="Background Color")
62
+ threshold = gr.Slider(0, 100, 10, label="Edge Threshold")
63
+ process_btn = gr.Button("Refine Edges", variant="primary")
64
+
65
+ with gr.Column():
66
+ matte_output = gr.Image(label="Refined Alpha Matte", type="pil")
67
+ final_output = gr.Image(label="Composited Result", type="pil")
68
+
69
+ process_btn.click(
70
+ fn=refine_with_modnet,
71
+ inputs=[input_img, bg_color, threshold],
72
+ outputs=[matte_output, final_output]
73
+ )
74
 
75
  if __name__ == "__main__":
76
+ demo.launch()