phucd committed on
Commit
92b5cbe
·
1 Parent(s): ca8fa7a

Update demo

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -8,8 +8,6 @@ import numpy as np
8
  import gradio as gr
9
  from seg import U2NETP
10
 
11
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
12
-
13
  # Image processing utilities
14
  def load_image(path: str):
15
  """ Loads an image from the specified path and converts it to RGB format. """
@@ -34,7 +32,7 @@ class U2NETP_DocSeg(nn.Module):
34
  return mask
35
 
36
  # Initialize the document segmentation model
37
- docseg = U2NETP_DocSeg(num_classes=1).to(DEVICE)
38
  # Load pretrained weights
39
  docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
40
  checkpoint = torch.load(docseg_weight_path)
@@ -44,7 +42,7 @@ docseg.eval()
44
  # Get segmentation mask
45
  def get_mask(image, confidence=0.5):
46
  org_shape = image.shape[:2]
47
- image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0).to(DEVICE)
48
  image_tensor = F.interpolate(image_tensor, size=(288, 288), mode='bilinear')
49
  with torch.inference_mode(): # faster than no_grad
50
  mask = docseg(image_tensor)
@@ -53,8 +51,8 @@ def get_mask(image, confidence=0.5):
53
  return mask[0, 0] # keep tensor
54
 
55
  def overlay_mask(image, mask):
56
- image = torch.from_numpy(image).float().to(DEVICE)
57
- red = torch.tensor([1.0, 0, 0], device=DEVICE).view(1, 3, 1, 1)
58
  mask = mask.unsqueeze(0) # (1, H, W)
59
  mask = mask.unsqueeze(0) # (1, 1, H, W)
60
  overlay = image.permute(2, 0, 1).unsqueeze(0)
@@ -75,4 +73,6 @@ with gr.Blocks() as demo:
75
  input_image = gr.Image(label="Input Image", type="numpy")
76
  output_image = gr.Image(label="Segmentation Overlay", type="numpy")
77
 
78
- input_image.change(segment_image, inputs=input_image, outputs=output_image)
 
 
 
8
  import gradio as gr
9
  from seg import U2NETP
10
 
 
 
11
  # Image processing utilities
12
  def load_image(path: str):
13
  """ Loads an image from the specified path and converts it to RGB format. """
 
32
  return mask
33
 
34
  # Initialize the document segmentation model
35
+ docseg = U2NETP_DocSeg(num_classes=1)
36
  # Load pretrained weights
37
  docseg_weight_path = './weights/u2netp_docseg_epoch_225_date_2026-01-02.pth'
38
  checkpoint = torch.load(docseg_weight_path)
 
42
  # Get segmentation mask
43
  def get_mask(image, confidence=0.5):
44
  org_shape = image.shape[:2]
45
+ image_tensor = torch.from_numpy(image).float().permute(2, 0, 1).unsqueeze(0)
46
  image_tensor = F.interpolate(image_tensor, size=(288, 288), mode='bilinear')
47
  with torch.inference_mode(): # faster than no_grad
48
  mask = docseg(image_tensor)
 
51
  return mask[0, 0] # keep tensor
52
 
53
  def overlay_mask(image, mask):
54
+ image = torch.from_numpy(image).float()
55
+ red = torch.tensor([1.0, 0, 0]).view(1, 3, 1, 1)
56
  mask = mask.unsqueeze(0) # (1, H, W)
57
  mask = mask.unsqueeze(0) # (1, 1, H, W)
58
  overlay = image.permute(2, 0, 1).unsqueeze(0)
 
73
  input_image = gr.Image(label="Input Image", type="numpy")
74
  output_image = gr.Image(label="Segmentation Overlay", type="numpy")
75
 
76
+ input_image.change(segment_image, inputs=input_image, outputs=output_image)
77
+
78
+ demo.launch()