Janeka committed on
Commit
b67f566
·
verified ·
1 Parent(s): 8ca9a30

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -70
app.py CHANGED
@@ -4,101 +4,123 @@ import torch
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
- from torchvision.transforms import ToTensor, ToPILImage
8
 
9
- # Load MODNet (local weights)
10
- MODEL_URL = "https://drive.google.com/uc?export=download&id=1mcr7ALciuAsHCpLnrtG_eop5-EYhbCmz"
11
- MODEL_PATH = "modnet.pth"
12
-
13
- def download_model():
14
- import requests
15
- import os
16
- if not os.path.exists(MODEL_PATH):
17
- print("Downloading MODNet weights...")
18
- try:
19
- response = requests.get(MODEL_URL, stream=True)
20
- with open(MODEL_PATH, 'wb') as f:
21
- for chunk in response.iter_content(chunk_size=1024):
22
- if chunk:
23
- f.write(chunk)
24
- print("Download complete!")
25
- except Exception as e:
26
- print(f"Download failed: {e}")
27
- return False
28
- return True
29
-
30
- class MODNet(torch.nn.Module):
31
- def __init__(self):
32
- super().__init__()
33
- self.backbone = torch.hub.load('pytorch/vision:v0.10.0', 'mobilenet_v2', pretrained=True)
34
- self.head = torch.nn.Sequential(
35
- torch.nn.Conv2d(1280, 1, kernel_size=3, padding=1),
36
- torch.nn.Sigmoid()
37
  )
38
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
  def forward(self, x):
40
- features = self.backbone.features(x)
41
- return self.head(features)
 
 
 
 
 
 
 
 
 
 
 
 
42
 
43
- # Initialize model
44
- if download_model():
45
- modnet = MODNet()
46
- modnet.load_state_dict(torch.load(MODEL_PATH, map_location='cpu'))
47
- modnet.eval()
48
- else:
49
- modnet = None
 
50
 
51
- def refine_edges(img, bg_color="#FFFFFF"):
52
- """Refine edges using local MODNet"""
53
- if modnet is None:
54
- raise gr.Error("Model failed to load. Please check logs.")
55
-
56
  # Preprocess
57
- img = img.convert("RGB")
58
- img_tensor = ToTensor()(img).unsqueeze(0)
 
 
 
59
 
60
- # Resize to nearest multiple of 32
61
- h, w = img_tensor.shape[2], img_tensor.shape[3]
62
- new_h = h - h % 32
63
- new_w = w - w % 32
64
- img_tensor = F.interpolate(img_tensor, (new_h, new_w), mode='area')
65
 
66
  # Inference
67
  with torch.no_grad():
68
- matte = modnet(img_tensor)
69
 
70
  # Post-process
71
- matte = F.interpolate(matte, (h, w), mode='bilinear')
72
- matte = matte.squeeze().cpu().numpy()
73
- matte = (matte * 255).astype(np.uint8)
74
-
75
- # Composite with background
76
- bg_color = bg_color.lstrip('#')
77
- bg_rgb = tuple(int(bg_color[i:i+2], 16) for i in (0, 2, 4))
78
- bg = Image.new("RGB", img.size, bg_rgb)
79
 
80
- # Create mask
81
- mask = Image.fromarray(matte).convert("L")
82
- result = Image.composite(img, bg, mask)
83
 
84
- return mask, result
85
 
86
  # Gradio Interface
87
  with gr.Blocks() as demo:
88
- gr.Markdown("## Professional Edge Refiner")
89
  with gr.Row():
90
  with gr.Column():
91
  input_img = gr.Image(type="pil", label="Input Image")
92
- bg_color = gr.ColorPicker("#FFFFFF", label="Preview Background")
93
  process_btn = gr.Button("Refine Edges", variant="primary")
94
  with gr.Column():
95
- matte_output = gr.Image(label="Refined Alpha Matte", type="pil")
96
- final_output = gr.Image(label="Composited Result", type="pil")
97
 
98
  process_btn.click(
99
  fn=refine_edges,
100
- inputs=[input_img, bg_color],
101
- outputs=[matte_output, final_output]
102
  )
103
 
104
  if __name__ == "__main__":
 
4
  import torch.nn.functional as F
5
  from PIL import Image
6
  import gradio as gr
7
+ import os
8
 
9
+ # U^2-Net model definition
10
class U2NET(torch.nn.Module):
    """Simplified U^2-Net-style encoder/decoder producing a per-pixel matte.

    Five encoder stages (each two 3x3 conv + ReLU pairs, stages 2-5 preceded
    by a 2x2 max-pool) are followed by four transposed-conv upsampling steps
    with additive skip connections and a final 1x1 conv + sigmoid.

    Args:
        out_ch: Number of output channels of the final 1x1 convolution.

    Input is a float tensor of shape (N, 3, H, W) with H, W >= 16 (four
    2x2 poolings are applied). Output has shape (N, out_ch, H, W) with
    values squashed by a sigmoid.
    """

    def __init__(self, out_ch=1):
        super().__init__()
        # Encoder. Layer order inside every Sequential is unchanged, so
        # state_dict keys such as "stage2.1.weight" stay stable.
        self.stage1 = self._double_conv(3, 64, pool=False)
        self.stage2 = self._double_conv(64, 128)
        self.stage3 = self._double_conv(128, 256)
        self.stage4 = self._double_conv(256, 512)
        self.stage5 = self._double_conv(512, 512)

        # Decoder: each step doubles the spatial size and maps the channel
        # count onto that of the skip tensor it is added to.
        self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2)
        self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2)
        self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2)
        self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2)
        self.conv_final = torch.nn.Conv2d(64, out_ch, 1)

    @staticmethod
    def _double_conv(in_ch, out_ch, pool=True):
        """Return [MaxPool2d] + (Conv3x3 -> ReLU) x 2 as a Sequential."""
        layers = [torch.nn.MaxPool2d(2, 2)] if pool else []
        layers += [
            torch.nn.Conv2d(in_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
            torch.nn.Conv2d(out_ch, out_ch, 3, padding=1),
            torch.nn.ReLU(),
        ]
        return torch.nn.Sequential(*layers)

    @staticmethod
    def _fit(tensor, reference):
        """Resize ``tensor`` to ``reference``'s spatial size when they differ.

        Max-pooling floors odd sizes, so a stride-2 upsample can come out
        one pixel short of the skip tensor; without this the additive skip
        raises a shape error for H/W that are not multiples of 16.
        """
        if tensor.shape[-2:] == reference.shape[-2:]:
            return tensor
        return torch.nn.functional.interpolate(
            tensor, size=reference.shape[-2:], mode="bilinear", align_corners=False
        )

    def forward(self, x):
        # Encoder
        x1 = self.stage1(x)
        x2 = self.stage2(x1)
        x3 = self.stage3(x2)
        x4 = self.stage4(x3)
        x5 = self.stage5(x4)

        # Decoder with additive skip connections
        u5 = self._fit(self.up5(x5), x4)
        u4 = self._fit(self.up4(u5 + x4), x3)
        u3 = self._fit(self.up3(u4 + x3), x2)
        u2 = self._fit(self.up2(u3 + x2), x1)

        return torch.sigmoid(self.conv_final(u2 + x1))
69
 
70
def load_model(weights_path=None):
    """Build the U2NET matting network and return it in eval mode.

    Args:
        weights_path: Optional path to a checkpoint holding a ``state_dict``
            for ``U2NET``. When omitted, the network is only randomly
            initialised, which is enough to exercise the pipeline but does
            NOT produce meaningful mattes.

    Returns:
        The ``U2NET`` instance, switched to evaluation mode.

    Raises:
        FileNotFoundError: If ``weights_path`` is given but does not exist.
    """
    model = U2NET()

    if weights_path is None:
        # Demo fallback: Kaiming-initialise the Conv2d weights only
        # (ConvTranspose2d layers keep PyTorch's default initialisation).
        for m in model.modules():
            if isinstance(m, torch.nn.Conv2d):
                torch.nn.init.kaiming_normal_(m.weight)
        return model.eval()

    if not os.path.isfile(weights_path):
        raise FileNotFoundError(f"Model weights not found: {weights_path}")

    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # from a trusted source.
    state_dict = torch.load(weights_path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model.eval()
78
 
79
+ model = load_model()
80
+
81
def refine_edges(image, threshold=0.5):
    """Cut the foreground out of ``image`` using the module-level ``model``.

    Args:
        image: PIL image in any mode (grayscale, palette, RGB, RGBA, ...).
        threshold: Matte cut-off. Either a fraction in [0, 1] or, when the
            value is greater than 1, a percentage in (1, 100] as sent by a
            0-100 slider. The result is clamped to [0, 1].

    Returns:
        Tuple ``(rgba_image, matte_image)``: the input with the binary matte
        as its alpha channel, and the binary matte itself (0 or 255).

    Raises:
        gr.Error: If no image was supplied.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    # Normalise the threshold: previously a slider value of 50 became
    # int(50 * 255), far above any uint8 pixel, so the matte was always empty.
    level = float(threshold)
    if level > 1.0:
        level /= 100.0
    level = min(max(level, 0.0), 1.0)

    # Preprocess: convert once so the network input and the composited
    # output always have 3 channels, whatever the upload mode was.
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]

    resized = cv2.resize(rgb, (320, 320))
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    # Inference
    with torch.no_grad():
        matte = model(tensor)

    # Post-process: back to the original resolution, then binarise.
    matte = F.interpolate(matte, (height, width), mode="bilinear", align_corners=False)
    # Index instead of squeeze() so 1-pixel-wide/high images keep two dims.
    matte = (matte[0, 0].numpy() * 255).astype(np.uint8)
    _, matte = cv2.threshold(matte, int(level * 255), 255, cv2.THRESH_BINARY)

    # Create transparent result: RGB planes plus the matte as alpha.
    rgba = np.dstack([rgb, matte])

    return Image.fromarray(rgba), Image.fromarray(matte)
107
 
108
  # Gradio Interface
109
  with gr.Blocks() as demo:
110
+ gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
111
  with gr.Row():
112
  with gr.Column():
113
  input_img = gr.Image(type="pil", label="Input Image")
114
+ threshold = gr.Slider(0, 100, 50, label="Edge Threshold")
115
  process_btn = gr.Button("Refine Edges", variant="primary")
116
  with gr.Column():
117
+ output_img = gr.Image(type="pil", label="Refined Image")
118
+ matte_img = gr.Image(type="pil", label="Alpha Matte")
119
 
120
  process_btn.click(
121
  fn=refine_edges,
122
+ inputs=[input_img, threshold],
123
+ outputs=[output_img, matte_img]
124
  )
125
 
126
  if __name__ == "__main__":