DocSeg / app.py
phucd
Minor changes
09330e4
import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import gradio as gr
from seg import U2NETP
# Image processing utilities
def load_image(path: str) -> np.ndarray:
    """Load an image file as an RGB float array scaled to [0, 1].

    Args:
        path: Filesystem path to the image (``str`` or any ``os.PathLike``).

    Returns:
        A ``(H, W, 3)`` float64 array in RGB channel order with values in
        ``[0, 1]``.

    Raises:
        FileNotFoundError: If OpenCV cannot read or decode the file.
            ``cv2.imread`` signals failure by returning ``None`` instead of
            raising, so the check has to be explicit.
    """
    img = cv2.imread(str(path))  # str() so pathlib.Path inputs also work
    if img is None:
        raise FileNotFoundError(
            f"Could not read image from {path!r} (missing file or unsupported format)"
        )
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV decodes as BGR
    return img / 255.0
def save_image(image: np.ndarray, path: str) -> None:
    """Write a float image with values in [0, 1] to disk.

    Args:
        image: ``(H, W, 3)`` RGB array (e.g. as returned by ``load_image``) or
            ``(H, W)`` grayscale array, with values nominally in ``[0, 1]``.
        path: Destination path; OpenCV chooses the encoder from the extension.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    # Clip before the uint8 cast: numpy's float->uint8 conversion does not
    # saturate, so out-of-range values would wrap (e.g. 257 -> 1) or be
    # undefined. Round instead of truncating so k / 255.0 maps back to k.
    img = np.rint(np.clip(image, 0.0, 1.0) * 255.0).astype(np.uint8)
    if img.ndim == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR
    if not cv2.imwrite(str(path), img):
        raise OSError(f"Could not write image to {path!r}")
# Document Segmentation Model
class U2NETP_DocSeg(nn.Module):
    """Wrapper around ``U2NETP`` that exposes only its first output.

    The backbone's forward result is unpacked as an iterable; this wrapper
    keeps the leading element so callers receive a single tensor.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # ``num_classes`` becomes the backbone's output channel count.
        self.u2netp = U2NETP(out_ch=num_classes)

    def forward(self, x):
        """Run the backbone on ``x`` and return only its first output."""
        outputs = self.u2netp(x)
        first, *_unused = outputs
        return first
# Initialize the document segmentation model
docseg = U2NETP_DocSeg(num_classes=1)
# Load pretrained weights. The path is relative to the current working
# directory, so the app must be launched from the directory holding ./weights.
docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
# map_location keeps all tensors on the CPU so no GPU is required.
# NOTE(review): torch.load may unpickle arbitrary objects (depending on the
# installed torch version's weights_only default); only load trusted files.
checkpoint = torch.load(docseg_weight_path, map_location=torch.device('cpu'))
# This checkpoint nests the weights under "model_state_dict"; fall back to
# treating the loaded object as a bare state_dict when that key is absent.
docseg.load_state_dict(checkpoint.get("model_state_dict", checkpoint))
docseg.eval()  # evaluation behaviour for dropout / batch-norm layers
# Get segmentation mask
def get_mask(image, confidence=0.5, *, model=None, input_size=288):
    """Predict a binary document mask for ``image``.

    The image is resized to ``input_size x input_size`` for inference. The
    predicted scores are resized back to the original resolution *before*
    thresholding, so the returned mask is strictly binary.

    Args:
        image: ``(H, W, C)`` float numpy array; presumably RGB in ``[0, 1]``
            with ``C == 3`` (what ``segment_image`` and ``load_image`` produce).
        confidence: Score threshold; pixels whose upsampled score is strictly
            greater than this value are marked as foreground.
        model: Module/callable mapping a ``(1, C, S, S)`` float tensor to a
            ``(1, K, h, w)`` score map whose first channel is used. Defaults
            to the module-level ``docseg`` model.
        input_size: Side length ``S`` of the square model input.

    Returns:
        ``(H, W)`` float32 tensor whose values are ``0.0`` or ``1.0``.
    """
    if model is None:
        model = docseg
    height, width = image.shape[:2]
    # HWC numpy -> NCHW float tensor, then resize to the model's input size.
    image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
    image_tensor = F.interpolate(
        image_tensor, size=(input_size, input_size), mode='bilinear', align_corners=False
    )
    with torch.inference_mode():  # like no_grad, but also skips version-counter bookkeeping
        scores = model(image_tensor)
        # Upsample the continuous scores and threshold afterwards. Resizing an
        # already-binarised mask bilinearly produces fractional values along
        # the boundary, so the result would not be a true 0/1 mask.
        scores = F.interpolate(
            scores, size=(height, width), mode='bilinear', align_corners=False
        )
        mask = (scores > confidence).float()
    return mask[0, 0]  # first batch item, first channel
def overlay_mask(image, mask, color=(1.0, 0.0, 0.0), alpha=0.3):
    """Tint the masked region of ``image`` with a solid colour.

    Args:
        image: ``(H, W, 3)`` float numpy array, presumably RGB in ``[0, 1]``.
        mask: ``(H, W)`` tensor or array; pixels with a value ``> 0`` are tinted.
        color: Tint colour as an ``(r, g, b)`` triple in the same value range
            as ``image``. Defaults to red.
        alpha: Weight of the tint inside the mask; ``0`` leaves the image
            unchanged and ``1`` paints the solid colour.

    Returns:
        ``(H, W, 3)`` float32 numpy array. Pixels inside the mask become
        ``(1 - alpha) * image + alpha * color``. Pixels outside are
        ``(1 - alpha) * image + alpha * image``, i.e. unchanged up to float
        rounding.
    """
    img = torch.from_numpy(image).float()  # (H, W, 3)
    tint = torch.tensor(color, dtype=img.dtype)  # (3,), broadcasts over H, W
    inside = torch.as_tensor(mask).unsqueeze(-1) > 0  # (H, W, 1)
    # Work in HWC directly rather than permuting to NCHW and back.
    tinted = torch.where(inside, tint, img)
    blended = (1.0 - alpha) * img + alpha * tinted
    return blended.cpu().numpy()
def segment_image(image):
    """Gradio callback: segment ``image`` and yield the mask overlay.

    Args:
        image: ``(H, W, 3)`` array from ``gr.Image(type="numpy")``, presumably
            uint8 in ``[0, 255]`` (it is divided by 255 below), or ``None``
            when the input component has been cleared.

    Yields:
        A single overlay array produced by ``overlay_mask``, or ``None`` to
        clear the output when there is no input image.
    """
    # The `.change` event also fires when the user clears the image, passing
    # None; without this guard `None.astype` raises inside the callback.
    if image is None:
        yield None
        return
    image = image.astype(np.float32) / 255.0  # Normalize to [0, 1]
    mask = get_mask(image, confidence=0.5)
    yield overlay_mask(image, mask)
# Gradio user interface: input and overlay side by side, examples below.
with gr.Blocks() as demo:
    gr.Markdown("## Real-time Document Segmentation")

    with gr.Row():
        input_image = gr.Image(type="numpy", label="Input Image")
        output_image = gr.Image(type="numpy", label="Segmentation Overlay")

    # Clicking an example loads it into the input component.
    examples = gr.Examples(
        inputs=input_image,
        examples=[
            "./examples/sample.jpg",
            "./examples/manga.png",
            "./examples/invoice.png",
        ],
    )

    # Re-run segmentation whenever the input image's value changes.
    input_image.change(fn=segment_image, inputs=input_image, outputs=output_image)

demo.launch()