# Spaces: Runtime error
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import cv2 | |
| from PIL import Image | |
| from torchvision import transforms | |
# Backbone: small encoder of stacked 3x3 conv blocks with max-pooling, loosely U2NET-inspired (heavily simplified, inference only)
class BasicConvBlock(nn.Module):
    """Two stacked (3x3 conv -> batch norm -> ReLU) stages.

    Both convolutions use stride 1 and padding 1, so height and width are
    preserved; only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            layers += [
                nn.Conv2d(stage_in, out_channels, 3, 1, 1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        # Kept under the attribute name ``block`` so state_dict keys are block.0 ... block.5.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv stack and return the resulting feature map."""
        return self.block(x)
class SimpleMODNetBackbone(nn.Module):
    """Three-stage convolutional encoder with channel progression 3 -> 64 -> 128 -> 256.

    A 2x2 max-pool (stride 2) sits between consecutive stages, so the output
    feature map has 256 channels and roughly 1/4 of the input height and width.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept stage1, pool1, stage2, pool2, stage3
        # (it fixes parameter-init RNG consumption and state_dict ordering).
        self.stage1 = BasicConvBlock(3, 64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.stage2 = BasicConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.stage3 = BasicConvBlock(128, 256)

    def forward(self, x):
        """Encode ``x`` and return the stage-3 feature map."""
        pipeline = (self.stage1, self.pool1, self.stage2, self.pool2, self.stage3)
        for layer in pipeline:
            x = layer(x)
        return x
class MODNet(nn.Module):
    """Minimal matting network: encoder features -> one-channel matte.

    ``forward`` returns a tensor of shape (N, 1, h, w), where h and w are
    about 1/4 of the input size because of the backbone's two poolings.
    The final sigmoid keeps values in (0, 1); callers must upsample the
    matte to the original image size themselves.
    """

    def __init__(self):
        super().__init__()
        self.backbone = SimpleMODNetBackbone()
        head_layers = [
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.seg_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Predict a low-resolution alpha matte for the batch ``x``."""
        return self.seg_head(self.backbone(x))
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
    """Turn a PIL image into a normalized (1, 3, 512, 512) float tensor on ``device``.

    Steps:
    1. Force the image to RGB.
    2. Resize to exactly 512x512; aspect ratio is not preserved.
    3. Scale pixel values to [0, 1].
    4. Normalize each channel with the standard ImageNet mean/std values.
    """
    rgb_image = image.convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    # Add a leading batch dimension of size 1.
    batch = pipeline(rgb_image)[None, ...]
    return batch.to(device)
def remove_background_modnet(image: Image.Image, *, model: "nn.Module | None" = None) -> Image.Image:
    """Return an RGBA copy of ``image`` whose alpha channel is the predicted matte.

    Args:
        image: Input PIL image of any mode. It is not modified.
        model: Optional pre-built matting network, keyword-only. It must map
            the (1, 3, 512, 512) tensor produced by ``preprocess_image`` to a
            (1, 1, h, w) matte with values in [0, 1].
            - It is switched to eval mode.
            - It is run on the device of its first parameter, or CPU if it
              has no parameters.
            - When omitted, a fresh ``MODNet`` is built on CUDA if available,
              otherwise CPU.

    Returns:
        An RGBA ``PIL.Image`` with the same size as ``image``.

    NOTE: the default path loads no weights. The freshly built network is
    randomly initialised, so its matte is not meaningful and (unless the
    global torch RNG is seeded) differs from call to call. Pass a trained
    ``model`` for real results.
    """
    if model is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Skip loading weights (simple version).
        model = MODNet().to(device)
    else:
        # Run on the caller's device rather than moving their model around.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
    model.eval()

    img_tensor = preprocess_image(image, device)
    with torch.no_grad():
        pred_matte = model(img_tensor)

    # float32 because cv2.resize does not accept float16 (e.g. from a half-precision model).
    matte = pred_matte[0, 0].float().cpu().numpy()
    # cv2 takes dsize as (width, height), which is exactly PIL's image.size.
    matte = cv2.resize(matte, image.size, interpolation=cv2.INTER_LINEAR)
    # Round instead of truncating (truncation biases every value downward),
    # and clip so out-of-range model outputs cannot wrap around in uint8.
    alpha = np.clip(np.rint(matte * 255.0), 0, 255).astype(np.uint8)

    rgba = np.array(image.convert("RGBA"))
    rgba[:, :, 3] = alpha
    return Image.fromarray(rgba)