phucd committed on
Commit ·
92b5cbe
1
Parent(s): ca8fa7a
Update demo
Browse files
app.py
CHANGED
|
@@ -8,8 +8,6 @@ import numpy as np
|
|
| 8 |
import gradio as gr
|
| 9 |
from seg import U2NETP
|
| 10 |
|
| 11 |
-
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 12 |
-
|
| 13 |
# Image processing utilities
|
| 14 |
def load_image(path: str):
|
| 15 |
""" Loads an image from the specified path and converts it to RGB format. """
|
|
@@ -34,7 +32,7 @@ class U2NETP_DocSeg(nn.Module):
|
|
| 34 |
return mask
|
| 35 |
|
| 36 |
# Initialize the document segmentation model
|
| 37 |
-
docseg = U2NETP_DocSeg(num_classes=1)
|
| 38 |
# Load pretrained weights
|
| 39 |
docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
|
| 40 |
checkpoint = torch.load(docseg_weight_path)
|
|
@@ -44,7 +42,7 @@ docseg.eval()
|
|
| 44 |
# Get segmentation mask
|
| 45 |
def get_mask(image, confidence=0.5):
|
| 46 |
org_shape = image.shape[:2]
|
| 47 |
-
image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
|
| 48 |
image_tensor = F.interpolate(image_tensor, size=(288, 288), mode='bilinear')
|
| 49 |
with torch.inference_mode(): # faster than no_grad
|
| 50 |
mask = docseg(image_tensor)
|
|
@@ -53,8 +51,8 @@ def get_mask(image, confidence=0.5):
|
|
| 53 |
return mask[0, 0] # keep tensor
|
| 54 |
|
| 55 |
def overlay_mask(image, mask):
|
| 56 |
-
image = torch.from_numpy(image).float()
|
| 57 |
-
red = torch.tensor([1.0, 0, 0]
|
| 58 |
mask = mask.unsqueeze(0) # (1, H, W)
|
| 59 |
mask = mask.unsqueeze(0) # (1, 1, H, W)
|
| 60 |
overlay = image.permute(2, 0, 1).unsqueeze(0)
|
|
@@ -75,4 +73,6 @@ with gr.Blocks() as demo:
|
|
| 75 |
input_image = gr.Image(label="Input Image", type="numpy")
|
| 76 |
output_image = gr.Image(label="Segmentation Overlay", type="numpy")
|
| 77 |
|
| 78 |
-
input_image.change(segment_image, inputs=input_image, outputs=output_image)
|
|
|
|
|
|
|
|
|
| 8 |
import gradio as gr
|
| 9 |
from seg import U2NETP
|
| 10 |
|
|
|
|
|
|
|
| 11 |
# Image processing utilities
|
| 12 |
def load_image(path: str):
|
| 13 |
""" Loads an image from the specified path and converts it to RGB format. """
|
|
|
|
| 32 |
return mask
|
| 33 |
|
| 34 |
# Initialize the document segmentation model
|
| 35 |
+
docseg = U2NETP_DocSeg(num_classes=1)
|
| 36 |
# Load pretrained weights
|
| 37 |
docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
|
| 38 |
checkpoint = torch.load(docseg_weight_path)
|
|
|
|
| 42 |
# Get segmentation mask
|
| 43 |
def get_mask(image, confidence=0.5):
|
| 44 |
org_shape = image.shape[:2]
|
| 45 |
+
image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
|
| 46 |
image_tensor = F.interpolate(image_tensor, size=(288, 288), mode='bilinear')
|
| 47 |
with torch.inference_mode(): # faster than no_grad
|
| 48 |
mask = docseg(image_tensor)
|
|
|
|
| 51 |
return mask[0, 0] # keep tensor
|
| 52 |
|
| 53 |
def overlay_mask(image, mask):
|
| 54 |
+
image = torch.from_numpy(image).float()
|
| 55 |
+
red = torch.tensor([1.0, 0, 0]).view(1, 3, 1, 1)
|
| 56 |
mask = mask.unsqueeze(0) # (1, H, W)
|
| 57 |
mask = mask.unsqueeze(0) # (1, 1, H, W)
|
| 58 |
overlay = image.permute(2, 0, 1).unsqueeze(0)
|
|
|
|
| 73 |
input_image = gr.Image(label="Input Image", type="numpy")
|
| 74 |
output_image = gr.Image(label="Segmentation Overlay", type="numpy")
|
| 75 |
|
| 76 |
+
input_image.change(segment_image, inputs=input_image, outputs=output_image)
|
| 77 |
+
|
| 78 |
+
demo.launch()
|