# bluspater's picture
# Update model/modnet.py
# b6584d4 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import cv2
from PIL import Image
from torchvision import transforms
# Backbone: U2NET-like architecture (simplified for inference only)
class BasicConvBlock(nn.Module):
    """Two stacked 3x3 convolutions, each followed by batch norm and ReLU.

    Both convolutions use stride 1 and padding 1, so the spatial size of the
    input is kept; only the channel count changes from ``in_channels`` to
    ``out_channels``.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # The first unit maps in_channels -> out_channels; the second keeps
        # the width at out_channels.
        for source_channels in (in_channels, out_channels):
            layers.append(
                nn.Conv2d(
                    source_channels,
                    out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                )
            )
            layers.append(nn.BatchNorm2d(out_channels))
            layers.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the conv/BN/ReLU stack and return the result."""
        return self.block(x)
class SimpleMODNetBackbone(nn.Module):
    """Three-stage convolutional feature extractor.

    Channel widths go 3 -> 64 -> 128 -> 256. A 2x2 max-pool with stride 2
    follows each of the first two stages, so height and width are halved
    twice (rounding down) between the input and the returned feature map.
    """

    def __init__(self):
        super().__init__()
        self.stage1 = BasicConvBlock(3, 64)
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.stage2 = BasicConvBlock(64, 128)
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.stage3 = BasicConvBlock(128, 256)

    def forward(self, x):
        """Return the 256-channel feature map computed from ``x``."""
        pipeline = (
            self.stage1,
            self.pool1,
            self.stage2,
            self.pool2,
            self.stage3,
        )
        features = x
        for layer in pipeline:
            features = layer(features)
        return features
class MODNet(nn.Module):
    """Backbone plus a small convolutional head that predicts a matte.

    The head reduces the backbone's 256 channels to 64 with a 3x3 convolution
    and ReLU, then to a single channel with a 1x1 convolution, and squashes
    the result into (0, 1) with a sigmoid. The head does not upsample, so the
    matte has the spatial size of the backbone output.
    """

    def __init__(self):
        super().__init__()
        self.backbone = SimpleMODNetBackbone()
        head_layers = [
            nn.Conv2d(256, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid(),
        ]
        self.seg_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the single-channel matte predicted for the batch ``x``."""
        return self.seg_head(self.backbone(x))
def preprocess_image(image: Image.Image, device: torch.device) -> torch.Tensor:
    """Turn a PIL image into a normalized single-image batch on ``device``.

    The image is converted to RGB, resized to 512x512, converted to a tensor,
    normalized per channel, given a leading batch dimension, and moved to
    ``device``.
    """
    rgb_image = image.convert("RGB")
    pipeline = transforms.Compose([
        transforms.Resize((512, 512)),
        transforms.ToTensor(),
        # Per-channel mean/std; presumably the standard ImageNet statistics.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    batch = pipeline(rgb_image).unsqueeze(0)
    return batch.to(device)
def remove_background_modnet(
    image: Image.Image,
    model: "nn.Module | None" = None,
) -> Image.Image:
    """Return an RGBA copy of ``image`` whose alpha channel is a predicted matte.

    Args:
        image: Input picture. It is converted to RGB for the network and to
            RGBA for the output.
        model: Optional matting network to run. It must map the batch produced
            by ``preprocess_image`` to a tensor whose ``[0][0]`` slice is a 2-D
            matte with values in [0, 1]. When omitted, a new ``MODNet`` is
            built on every call and NO weights are loaded, so the matte comes
            from randomly initialised parameters. Pass a model with trained
            weights to get a meaningful result and to avoid rebuilding the
            network per call.

    Returns:
        An RGBA image of the same size as ``image`` with the matte, scaled to
        0-255, written into the alpha channel.
    """
    if model is None:
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        # Skip loading weights (simple version): parameters stay random.
        model = MODNet().to(device)
        restore_training = False
    else:
        # Run on whatever device the caller already placed the model on.
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        # Remember the caller's mode so inference does not silently flip it.
        restore_training = model.training

    model.eval()
    try:
        img_tensor = preprocess_image(image, device)
        with torch.no_grad():
            pred_matte = model(img_tensor)
    finally:
        if restore_training:
            model.train()

    # First image in the batch, first (only) channel, as a float32 array.
    matte = pred_matte[0][0].float().cpu().numpy()
    # PIL's ``size`` is (width, height), the same order cv2.resize expects.
    matte = cv2.resize(matte, image.size)
    # Clamp before the cast: np.uint8 wraps out-of-range values instead of
    # saturating, which would corrupt the alpha channel for a supplied model
    # that overshoots [0, 1]. For in-range values this is a no-op.
    matte = np.uint8(np.clip(matte, 0.0, 1.0) * 255)

    rgba = np.array(image.convert("RGBA"))
    rgba[:, :, 3] = matte
    return Image.fromarray(rgba)