|
|
--- |
|
|
--- |
|
|
license: apache-2.0 |
|
|
tags: |
|
|
- image-segmentation |
|
|
- image-matting |
|
|
- background-removal |
|
|
- computer-vision |
|
|
- custom-architecture |
|
|
library_name: transformers |
|
|
pipeline_tag: image-segmentation |
|
|
--- |
|
|
|
|
|
# MODNet – Fast Image Matting for Background Removal |
|
|
|
|
|
This repository provides a Hugging Face–compatible version of **MODNet (Matting Objective Decomposition Network)** for fast and high-quality foreground matting and background removal.
|
|
|
|
|
The model is designed to produce **pixel-perfect alpha mattes**, handling fine details such as hair, fabric edges, and soft shadows, while remaining efficient enough for real-time CPU inference. |
|
|
|
|
|
--- |
|
|
|
|
|
## 🔹 Model Details |
|
|
|
|
|
- **Architecture**: MODNet (custom architecture) |
|
|
- **Task**: Image matting / background removal |
|
|
- **Framework**: PyTorch |
|
|
- **Backbone**: MobileNetV2 |
|
|
- **Input**: RGB image tensor `(B, 3, H, W)` |
|
|
- **Output**: `(semantic, detail, matte)` predictions |
|
|
|
|
|
## How to Load the Model |
|
|
|
|
|
```python |
|
|
from transformers import AutoModel

# MODNet ships its modeling code alongside the weights on the Hub, so
# trust_remote_code=True is required for AutoModel to import and build
# the custom architecture.
model = AutoModel.from_pretrained(
    "boopathiraj/MODNet",
    trust_remote_code=True
)

# Switch to inference mode (disables dropout, freezes batch-norm stats).
model.eval()
|
|
``` |
|
|
|
|
|
## Example Inference |
|
|
|
|
|
```python |
|
|
|
|
|
import torch |
|
|
import cv2 |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from torchvision import transforms |
|
|
|
|
|
# Preprocess |
|
|
def preprocess(image_path, ref_size=512):
    """Load an image, letterbox it to a ref_size x ref_size square,
    and return a normalized tensor plus the original (width, height).

    The image is scaled so its longer side equals ref_size, then pasted
    centred onto a black canvas so the network always sees a fixed-size
    square input.
    """
    img = Image.open(image_path).convert("RGB")
    orig_w, orig_h = img.size

    # Scale so the longer side becomes ref_size (aspect ratio preserved).
    ratio = ref_size / max(orig_h, orig_w)
    scaled_w, scaled_h = int(orig_w * ratio), int(orig_h * ratio)
    scaled = img.resize((scaled_w, scaled_h), Image.BILINEAR)

    # Centre the scaled image on a black ref_size x ref_size canvas
    # (letterbox padding).
    canvas = Image.new("RGB", (ref_size, ref_size), (0, 0, 0))
    canvas.paste(scaled, ((ref_size - scaled_w) // 2, (ref_size - scaled_h) // 2))

    # Map pixels to [-1, 1], matching MODNet's training normalization.
    to_tensor = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5],
                             std=[0.5, 0.5, 0.5]),
    ])

    return to_tensor(canvas).unsqueeze(0), (orig_w, orig_h)
|
|
|
|
|
# Run inference |
|
|
# Run inference
inp, original_size = preprocess("/content/portrait.jpg")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
inp = inp.to(device)  # move the input tensor to the same device as the model

with torch.no_grad():
    # The second argument enables inference mode in MODNet's forward pass,
    # returning (semantic, detail, matte) predictions.
    semantic, detail, matte = model(inp, True)

matte = matte[0, 0].cpu().numpy()

# preprocess() letterboxed the photo into a 512x512 square, so the raw
# matte contains black padding bands. Crop the matte back to the region
# actually occupied by the photo before resizing; resizing the full
# square directly would stretch the padding into the output and
# misalign the alpha with the original image.
ref_size = 512
w, h = original_size
scale = ref_size / max(w, h)
new_w, new_h = int(w * scale), int(h * scale)
x0 = (ref_size - new_w) // 2
y0 = (ref_size - new_h) // 2
matte = matte[y0:y0 + new_h, x0:x0 + new_w]

# Scale the cropped matte back to the original resolution.
matte = cv2.resize(matte, original_size)

# Save alpha matte as an 8-bit grayscale PNG.
matte = (matte * 255).astype(np.uint8)
Image.fromarray(matte).save("alpha_matte.png")
|
|
|
|
|
``` |
|
|
## Visualization |
|
|
|
|
|
```python |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
|
|
|
def combined_display(image, matte):
    """Return a single PIL image showing, side by side:
    the original photo, the white-background composite, and the alpha matte.
    """
    w, h = image.width, image.height
    # Output strip is 800 px wide; three equal panels sit side by side,
    # so the height is scaled to keep each panel's aspect ratio.
    rw, rh = 800, int(h * 800 / (3 * w))

    # Coerce the photo to a 3-channel RGB ndarray.
    image = np.asarray(image)
    if image.ndim == 2:
        image = image[:, :, None]
    if image.shape[2] == 1:
        image = np.repeat(image, 3, axis=2)
    elif image.shape[2] == 4:
        image = image[:, :, :3]

    # Alpha scaled to [0, 1] and broadcast across the 3 color channels.
    matte = np.repeat(np.asarray(matte)[:, :, None], 3, axis=2) / 255
    # Composite the predicted foreground onto a white background.
    foreground = image * matte + np.full(image.shape, 255) * (1 - matte)

    # Stitch the three panels horizontally and shrink to display size.
    panels = np.concatenate((image, foreground, matte * 255), axis=1)
    return Image.fromarray(np.uint8(panels)).resize((rw, rh))
|
|
|
|
|
# visualize all images |
|
|
import os  # required for listdir / path joins; missing from the original snippet

# Prerequisites: `input_folder` holds the source photos and `output_folder`
# the saved alpha mattes (one PNG per photo, same base name); define both
# before running. `display` is provided by IPython/Jupyter environments.
image_names = os.listdir(input_folder)
for image_name in image_names:
    # Mattes are saved as PNG files sharing the photo's base name.
    matte_name = image_name.split('.')[0] + '.png'
    image = Image.open(os.path.join(input_folder, image_name))
    matte = Image.open(os.path.join(output_folder, matte_name))
    display(combined_display(image, matte))
    print(image_name, '\n')
|
|
|
|
|
``` |
|
|
|
|
|
|
|
|
--- |
|
|
|