import cv2 import numpy as np import torch import torch.nn.functional as F from PIL import Image import gradio as gr import os # U^2-Net model definition class U2NET(torch.nn.Module): def __init__(self, out_ch=1): super(U2NET, self).__init__() # Simplified U^2-Net architecture self.stage1 = torch.nn.Sequential( torch.nn.Conv2d(3, 64, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(64, 64, 3, padding=1), torch.nn.ReLU() ) self.stage2 = torch.nn.Sequential( torch.nn.MaxPool2d(2, 2), torch.nn.Conv2d(64, 128, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(128, 128, 3, padding=1), torch.nn.ReLU() ) self.stage3 = torch.nn.Sequential( torch.nn.MaxPool2d(2, 2), torch.nn.Conv2d(128, 256, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(256, 256, 3, padding=1), torch.nn.ReLU() ) self.stage4 = torch.nn.Sequential( torch.nn.MaxPool2d(2, 2), torch.nn.Conv2d(256, 512, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(512, 512, 3, padding=1), torch.nn.ReLU() ) self.stage5 = torch.nn.Sequential( torch.nn.MaxPool2d(2, 2), torch.nn.Conv2d(512, 512, 3, padding=1), torch.nn.ReLU(), torch.nn.Conv2d(512, 512, 3, padding=1), torch.nn.ReLU() ) self.up5 = torch.nn.ConvTranspose2d(512, 512, 2, stride=2) self.up4 = torch.nn.ConvTranspose2d(512, 256, 2, stride=2) self.up3 = torch.nn.ConvTranspose2d(256, 128, 2, stride=2) self.up2 = torch.nn.ConvTranspose2d(128, 64, 2, stride=2) self.conv_final = torch.nn.Conv2d(64, out_ch, 1) def forward(self, x): # Encoder x1 = self.stage1(x) x2 = self.stage2(x1) x3 = self.stage3(x2) x4 = self.stage4(x3) x5 = self.stage5(x4) # Decoder with skip connections u5 = self.up5(x5) u4 = self.up4(u5 + x4) u3 = self.up3(u4 + x3) u2 = self.up2(u3 + x2) return torch.sigmoid(self.conv_final(u2 + x1)) def load_model(): model = U2NET() # Load pre-trained weights (dummy initialization for demo) # In production, you would load actual trained weights here for m in model.modules(): if isinstance(m, torch.nn.Conv2d): torch.nn.init.kaiming_normal_(m.weight) return model.eval() 
model = load_model()

# Side length images are resized to before inference. It must be divisible
# by 16, because U2NET down-samples four times and adds skip connections.
_MODEL_INPUT_SIZE = 320


def refine_edges(image, threshold=0.5):
    """Refine edges using U^2-Net.

    Runs the module-level ``model`` on ``image`` and turns its prediction into
    a binary alpha matte.

    Args:
        image: Input ``PIL.Image`` in any mode. It is converted to RGB first,
            so grayscale, palette and RGBA inputs are all accepted.
        threshold: Cut-off on a 0-1 scale. Matte pixels strictly above
            ``threshold * 255`` become opaque. Values outside ``[0, 1]`` are
            clamped.

    Returns:
        ``(rgba, matte)`` as PIL images: the RGB input with the matte as its
        alpha channel, and the single-channel binary matte itself.

    Raises:
        gr.Error: If no image was supplied (e.g. the button was clicked
            before uploading).
    """
    if image is None:
        raise gr.Error("Please upload an input image first.")

    # Normalise every PIL mode to 8-bit RGB once. Pre-processing and the final
    # RGBA composition must see the same 3-channel pixels; cvtColor(RGB2RGBA)
    # rejects 1- or 4-channel arrays.
    rgb = np.array(image.convert("RGB"))
    height, width = rgb.shape[:2]

    # Pre-process: fixed-size, float in [0, 1], NCHW layout.
    resized = cv2.resize(rgb, (_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE))
    tensor = torch.from_numpy(resized).permute(2, 0, 1).float().unsqueeze(0) / 255.0

    with torch.no_grad():
        pred = model(tensor)
        # Scale the prediction back to the original (height, width).
        pred = F.interpolate(pred, size=(height, width), mode="bilinear", align_corners=False)

    # Index explicitly instead of squeeze(), so 1-pixel-wide or 1-pixel-tall
    # images still give a 2-D matte.
    matte = (pred[0, 0].numpy() * 255).astype(np.uint8)

    # Clamp so an out-of-range value (e.g. a percentage) cannot push the
    # cut-off past 255 and blank the whole matte.
    level = int(min(max(float(threshold), 0.0), 1.0) * 255)
    _, matte = cv2.threshold(matte, level, 255, cv2.THRESH_BINARY)

    # Create transparent result
    rgba = cv2.cvtColor(rgb, cv2.COLOR_RGB2RGBA)
    rgba[..., 3] = matte
    return Image.fromarray(rgba), Image.fromarray(matte)


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## ✂️ Professional Edge Refiner (U^2-Net)")
    with gr.Row():
        with gr.Column():
            input_img = gr.Image(type="pil", label="Input Image")
            # Same 0-1 scale that refine_edges expects for `threshold`.
            threshold = gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Edge Threshold")
            process_btn = gr.Button("Refine Edges", variant="primary")
        with gr.Column():
            output_img = gr.Image(type="pil", label="Refined Image")
            matte_img = gr.Image(type="pil", label="Alpha Matte")

    process_btn.click(
        fn=refine_edges,
        inputs=[input_img, threshold],
        outputs=[output_img, matte_img],
    )

if __name__ == "__main__":
    demo.launch()