Janeka commited on
Commit
1dca306
·
verified ·
1 Parent(s): 899ac53

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +66 -37
app.py CHANGED
@@ -1,74 +1,103 @@
1
  import cv2
2
  import numpy as np
3
  import torch
 
4
  from PIL import Image
5
  import gradio as gr
6
- from huggingface_hub import hf_hub_download
7
 
8
- # Load MODNet (PyTorch version)
9
- MODNET_REPO = "ZHTX/modnet"
10
- MODNET_FILE = "modnet_photographic_portrait_matting.ckpt"
11
 
12
- try:
13
- model_path = hf_hub_download(repo_id=MODNET_REPO, filename=MODNET_FILE)
14
- modnet = torch.hub.load('ZHTX/modnet', 'modnet', pretrained=False)
15
- modnet.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  modnet.eval()
17
- except Exception as e:
18
- print(f"Error loading MODNet: {e}")
19
  modnet = None
20
 
21
- def refine_with_modnet(input_image, bg_color="#FFFFFF", threshold=0.1):
22
- """Refine alpha matte using MODNet"""
23
  if modnet is None:
24
- raise gr.Error("MODNet model failed to load")
25
 
26
- # Convert input
27
- img = np.array(input_image.convert("RGB"))
28
- img = cv2.resize(img, (512, 512), interpolation=cv2.INTER_AREA)
29
- img = torch.from_numpy(img).permute(2,0,1).unsqueeze(0).float() / 255.0
 
 
 
 
 
30
 
31
  # Inference
32
  with torch.no_grad():
33
- _, _, matte = modnet(img, True)
34
 
35
- # Process output
 
36
  matte = matte.squeeze().cpu().numpy()
37
  matte = (matte * 255).astype(np.uint8)
38
- matte = cv2.threshold(matte, int(threshold*255), 255, cv2.THRESH_BINARY)[1]
39
 
40
  # Composite with background
41
  bg_color = bg_color.lstrip('#')
42
  bg_rgb = tuple(int(bg_color[i:i+2], 16) for i in (0, 2, 4))
43
- bg = Image.new("RGB", input_image.size, bg_rgb)
44
 
45
- # Apply refined matte
46
- refined = Image.fromarray(matte).resize(input_image.size)
47
- result = Image.composite(input_image, bg, refined)
48
 
49
- return refined, result
50
 
51
  # Gradio Interface
52
- with gr.Blocks(title="🔍 MODNet Edge Refiner") as demo:
53
- gr.Markdown("""
54
- ## 🔍 MODNet Professional Edge Refinement
55
- Uses AI to perfectly refine hair/fur edges from trimmed images
56
- """)
57
-
58
  with gr.Row():
59
  with gr.Column():
60
- input_img = gr.Image(type="pil", label="Trimmed Input")
61
- bg_color = gr.ColorPicker("#FFFFFF", label="Background Color")
62
- threshold = gr.Slider(0, 100, 10, label="Edge Threshold")
63
  process_btn = gr.Button("Refine Edges", variant="primary")
64
-
65
  with gr.Column():
66
  matte_output = gr.Image(label="Refined Alpha Matte", type="pil")
67
  final_output = gr.Image(label="Composited Result", type="pil")
68
 
69
  process_btn.click(
70
- fn=refine_with_modnet,
71
- inputs=[input_img, bg_color, threshold],
72
  outputs=[matte_output, final_output]
73
  )
74
 
 
1
  import cv2
2
  import numpy as np
3
  import torch
4
+ import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
+ from torchvision.transforms import ToTensor, ToPILImage
8
 
9
+ # Load MODNet (local weights)
10
+ MODEL_URL = "https://drive.google.com/uc?export=download&id=1mcr7ALciuAsHCpLnrtG_eop5-EYhbCmz"
11
+ MODEL_PATH = "modnet.pth"
12
 
13
+ def download_model():
14
+ import requests
15
+ import os
16
+ if not os.path.exists(MODEL_PATH):
17
+ print("Downloading MODNet weights...")
18
+ try:
19
+ response = requests.get(MODEL_URL, stream=True)
20
+ with open(MODEL_PATH, 'wb') as f:
21
+ for chunk in response.iter_content(chunk_size=1024):
22
+ if chunk:
23
+ f.write(chunk)
24
+ print("Download complete!")
25
+ except Exception as e:
26
+ print(f"Download failed: {e}")
27
+ return False
28
+ return True
29
+
30
+ class MODNet(torch.nn.Module):
31
+ def __init__(self):
32
+ super().__init__()
33
+ self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
34
+ self.head = torch.nn.Sequential(
35
+ torch.nn.Conv2d(1280, 1, kernel_size=3, padding=1),
36
+ torch.nn.Sigmoid()
37
+ )
38
+
39
+ def forward(self, x):
40
+ features = self.backbone.features(x)
41
+ return self.head(features)
42
+
43
+ # Initialize model
44
+ if download_model():
45
+ modnet = MODNet()
46
+ modnet.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
47
  modnet.eval()
48
+ else:
 
49
  modnet = None
50
 
51
+ def refine_edges(img, bg_color="#FFFFFF"):
52
+ """Refine edges using local MODNet"""
53
  if modnet is None:
54
+ raise gr.Error("Model failed to load. Please check logs.")
55
 
56
+ # Preprocess
57
+ img = img.convert("RGB")
58
+ img_tensor = ToTensor()(img).unsqueeze(0)
59
+
60
+ # Resize to nearest multiple of 32
61
+ h, w = img_tensor.shape[2], img_tensor.shape[3]
62
+ new_h = h - h % 32
63
+ new_w = w - w % 32
64
+ img_tensor = F.interpolate(img_tensor, (new_h, new_w), mode='area')
65
 
66
  # Inference
67
  with torch.no_grad():
68
+ matte = modnet(img_tensor)
69
 
70
+ # Post-process
71
+ matte = F.interpolate(matte, (h, w), mode='bilinear')
72
  matte = matte.squeeze().cpu().numpy()
73
  matte = (matte * 255).astype(np.uint8)
 
74
 
75
  # Composite with background
76
  bg_color = bg_color.lstrip('#')
77
  bg_rgb = tuple(int(bg_color[i:i+2], 16) for i in (0, 2, 4))
78
+ bg = Image.new("RGB", img.size, bg_rgb)
79
 
80
+ # Create mask
81
+ mask = Image.fromarray(matte).convert("L")
82
+ result = Image.composite(img, bg, mask)
83
 
84
+ return mask, result
85
 
86
  # Gradio Interface
87
+ with gr.Blocks() as demo:
88
+ gr.Markdown("## ✨ Professional Edge Refiner")
 
 
 
 
89
  with gr.Row():
90
  with gr.Column():
91
+ input_img = gr.Image(type="pil", label="Input Image")
92
+ bg_color = gr.ColorPicker("#FFFFFF", label="Preview Background")
 
93
  process_btn = gr.Button("Refine Edges", variant="primary")
 
94
  with gr.Column():
95
  matte_output = gr.Image(label="Refined Alpha Matte", type="pil")
96
  final_output = gr.Image(label="Composited Result", type="pil")
97
 
98
  process_btn.click(
99
+ fn=refine_edges,
100
+ inputs=[input_img, bg_color],
101
  outputs=[matte_output, final_output]
102
  )
103