| import cv2 |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| import numpy as np |
| import gradio as gr |
| from seg import U2NETP |
|
|
| |
def load_image(path: str) -> np.ndarray:
    """Load an image from ``path`` as an RGB float array scaled to [0, 1].

    Args:
        path: Filesystem path of the image to read.

    Returns:
        Float64 array in RGB channel order (converted from OpenCV's BGR) with
        values divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read an image from ``path``.
    """
    img = cv2.imread(path)
    # cv2.imread signals failure (missing file, unsupported format) by
    # returning None rather than raising; fail with a clear message instead of
    # a cryptic error from cvtColor.
    if img is None:
        raise FileNotFoundError(f"Could not read image: {path!r}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return img / 255.0
|
|
def save_image(image: np.ndarray, path: str) -> None:
    """Save an RGB float image with values in [0, 1] to ``path``.

    Values are clipped to [0, 1] and rounded to the nearest 8-bit level before
    being converted to OpenCV's BGR order and written.

    Args:
        image: RGB float array, nominally in [0, 1].
        path: Destination file path; its extension selects the encoder.

    Raises:
        OSError: If ``cv2.imwrite`` reports that the file could not be written.
    """
    # Clip first: out-of-range values would otherwise wrap around in the uint8
    # cast (e.g. 1.01 * 255 -> 257 -> 1). Round instead of truncating so that
    # values like 254.9999 map to 255 rather than 254.
    img = np.rint(np.clip(image, 0.0, 1.0) * 255).astype(np.uint8)
    img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
    # cv2.imwrite reports failure through its boolean return value.
    if not cv2.imwrite(path, img):
        raise OSError(f"Could not write image: {path!r}")
|
|
| |
class U2NETP_DocSeg(nn.Module):
    """Thin ``nn.Module`` wrapper around ``U2NETP`` for document segmentation.

    The wrapped network may return several outputs; only the first one is
    exposed by ``forward``.
    """

    def __init__(self, num_classes=1):
        super().__init__()
        # Keep the attribute name ``u2netp``: it becomes the prefix of every
        # state_dict key, which loaded checkpoints must match.
        self.u2netp = U2NETP(out_ch=num_classes)

    def forward(self, x):
        """Run the wrapped network on ``x`` and return its first output."""
        outputs = self.u2netp(x)
        primary, *_discarded = outputs
        return primary
|
|
| |
# Module-level model instance used by ``get_mask``; the weights are loaded
# once, at import time, onto the CPU.
docseg = U2NETP_DocSeg(num_classes=1)

docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source. If the checkpoint holds only tensors and plain containers,
# consider passing weights_only=True.
checkpoint = torch.load(docseg_weight_path, map_location='cpu')
docseg.load_state_dict(checkpoint["model_state_dict"])
# Switch to inference behaviour (affects layers such as dropout/batch-norm,
# if the network contains any).
docseg.eval()
|
|
| |
def get_mask(image, confidence=0.5, model=None, input_size=(288, 288)):
    """Predict a binary segmentation mask for ``image``.

    Args:
        image: ``(H, W, C)`` numpy array. The callers in this file pass RGB
            floats in [0, 1]; the scaling must match what the model expects.
        confidence: Pixels whose upsampled score is strictly greater than this
            become 1; all others become 0.
        model: Callable that maps a ``(1, C, h, w)`` float tensor to a
            ``(1, 1, h, w)`` score tensor. Defaults to the module-level
            ``docseg`` model.
        input_size: ``(h, w)`` that the image is resized to before inference.

    Returns:
        ``(H, W)`` float tensor containing only 0.0 and 1.0.
    """
    if model is None:
        model = docseg
    org_shape = tuple(image.shape[:2])
    # from_numpy rejects arrays with negative strides (e.g. flipped views).
    image_tensor = torch.from_numpy(np.ascontiguousarray(image)).float()
    image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
    image_tensor = F.interpolate(image_tensor, size=input_size, mode='bilinear')
    with torch.inference_mode():
        scores = model(image_tensor)
    # Upsample the continuous scores and threshold afterwards. Resizing an
    # already-binarized mask bilinearly produces fractional values along the
    # border (which `mask > 0` consumers then count as foreground) and
    # blocky, low-resolution edges.
    scores = F.interpolate(scores, size=org_shape, mode='bilinear')
    mask = (scores > confidence).float()
    return mask[0, 0]
|
|
def overlay_mask(image, mask, color=(1.0, 0.0, 0.0), alpha=0.3):
    """Tint the masked region of ``image`` with ``color``.

    Args:
        image: ``(H, W, 3)`` numpy array, presumably RGB floats in [0, 1].
        mask: ``(H, W)`` torch tensor or numpy array; pixels with a value
            greater than 0 are tinted.
        color: RGB tint colour, on the same scale as ``image``. Defaults to red.
        alpha: Weight of ``color`` in tinted pixels.

    Returns:
        ``(H, W, 3)`` float32 numpy array. Masked pixels become
        ``(1 - alpha) * image + alpha * color``; all others are unchanged.
    """
    img = torch.from_numpy(np.ascontiguousarray(image)).float()
    tint = torch.tensor(color, dtype=img.dtype)
    # as_tensor accepts both numpy masks and tensors (e.g. from get_mask).
    selected = torch.as_tensor(mask, dtype=torch.float32) > 0
    tinted = (1.0 - alpha) * img + alpha * tint
    # unsqueeze(-1) broadcasts the (H, W) mask across the channel axis.
    blended = torch.where(selected.unsqueeze(-1), tinted, img)
    return blended.cpu().numpy()
|
|
def segment_image(image):
    """ Gradio handler: segment the input image and yield the mask overlay.

    Args:
        image: ``(H, W, 3)`` array with 0-255 values from
            ``gr.Image(type="numpy")``, or ``None`` when the input is empty.

    Yields:
        The overlay as a float32 array (see ``overlay_mask``), or ``None``
        when there is no input image.
    """
    # Clearing the input component fires ``change`` with None. Without this
    # guard, ``None.astype`` would raise and show up as an error in the UI.
    if image is None:
        yield None
        return
    image = image.astype(np.float32) / 255.0
    mask = get_mask(image, confidence=0.5)
    yield overlay_mask(image, mask)
|
|
with gr.Blocks() as demo:
    gr.Markdown("## Real-time Document Segmentation")
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="numpy")
        output_image = gr.Image(label="Segmentation Overlay", type="numpy")
    # Selecting an example loads it into ``input_image``, which in turn fires
    # the ``change`` handler registered below.
    examples = gr.Examples(
        examples=[
            "./examples/sample.jpg",
            "./examples/manga.png",
            "./examples/invoice.png",
        ],
        inputs=input_image,
    )
    # Re-run segmentation whenever the input image changes.
    input_image.change(segment_image, inputs=input_image, outputs=output_image)


# Only start the server when this file is run as a script, so importing the
# module (e.g. for tests or Gradio's reload mode) does not launch it.
if __name__ == "__main__":
    demo.launch()