AkashKumarave commited on
Commit
02e3e02
·
verified ·
1 Parent(s): ebeb9ee

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -49
app.py CHANGED
@@ -1,34 +1,34 @@
1
  import os
2
  import cv2
3
  import gradio as gr
4
- from PIL import Image
5
  import numpy as np
6
  import torch
7
- from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
 
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  # Initialize device
14
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
15
 
16
- # Clone repository if not exists
17
  if not os.path.exists("DIS"):
18
- os.system("git clone https://github.com/xuebinqin/DIS")
19
  os.system("mv DIS/IS-Net/* .")
20
 
21
  # Import model components
22
- from data_loader_cache import normalize, im_reader, im_preprocess
23
  from models import ISNetDIS
24
 
25
  # Setup model directory
26
  os.makedirs("saved_models", exist_ok=True)
27
  if os.path.exists("isnet.pth"):
28
- os.system("mv isnet.pth saved_models/")
29
 
30
  class GOSNormalize:
31
- def __init__(self, mean=[0.5,0.5,0.5], std=[1.0,1.0,1.0]):
32
  self.mean = mean
33
  self.std = std
34
 
@@ -54,32 +54,31 @@ def build_model(hypar, device):
54
 
55
  net.to(device)
56
 
57
- if hypar["restore_model"]:
58
- net.load_state_dict(torch.load(
59
- os.path.join(hypar["model_path"], hypar["restore_model"]),
60
- map_location=device
61
- ))
62
  net.eval()
63
  return net
64
 
65
  def predict(net, inputs_val, shapes_val, hypar, device):
66
- net.eval()
67
- inputs_val = inputs_val.type(torch.FloatTensor if hypar["model_digit"] == "full" else torch.HalfTensor)
68
- inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
69
- ds_val = net(inputs_val_v)[0]
70
- pred_val = ds_val[0][0,:,:,:]
71
-
72
- pred_val = torch.squeeze(F.interpolate(
73
- torch.unsqueeze(pred_val, 0),
74
- size=(shapes_val[0][0], shapes_val[0][1]),
75
- mode='bilinear'
76
- ))
77
-
78
- pred_val = (pred_val - pred_val.min()) / (pred_val.max() - pred_val.min())
79
-
80
- if device == 'cuda':
81
- torch.cuda.empty_cache()
82
- return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
83
 
84
  # Model configuration
85
  hypar = {
@@ -99,11 +98,12 @@ net = build_model(hypar, device)
99
 
100
  def process_image(image):
101
  try:
102
- if isinstance(image, str):
103
- image_path = image
104
- else:
105
- image_path = image.name
106
-
 
107
  image_tensor, orig_size = load_image(image_path, hypar)
108
  mask = predict(net, image_tensor, orig_size, hypar, device)
109
 
@@ -118,27 +118,43 @@ def process_image(image):
118
  raise gr.Error(f"Error processing image: {str(e)}")
119
 
120
  # Interface setup
121
- title = "Image Segmentation Demo"
122
- description = "Upload an image to extract its foreground"
 
 
 
123
 
 
124
  examples = []
125
- if os.path.exists("robot.png"):
126
- examples.append(["robot.png"])
127
- if os.path.exists("ship.png"):
128
- examples.append(["ship.png"])
129
 
130
- with gr.Blocks() as app:
 
131
  gr.Markdown(f"## {title}")
132
  gr.Markdown(description)
133
 
134
  with gr.Row():
135
  with gr.Column():
136
- input_image = gr.Image(type="filepath", label="Input Image")
137
- submit_btn = gr.Button("Process")
 
 
 
 
138
 
139
  with gr.Column():
140
- output_rgba = gr.Image(label="Transparent Background", type="pil")
141
- output_mask = gr.Image(label="Segmentation Mask", type="pil")
 
 
 
 
 
 
 
 
142
 
143
  if examples:
144
  gr.Examples(
@@ -146,18 +162,22 @@ with gr.Blocks() as app:
146
  inputs=input_image,
147
  outputs=[output_rgba, output_mask],
148
  fn=process_image,
149
- cache_examples=True
 
150
  )
151
 
152
  submit_btn.click(
153
  fn=process_image,
154
  inputs=input_image,
155
- outputs=[output_rgba, output_mask]
 
156
  )
157
 
 
158
  if __name__ == "__main__":
159
  app.launch(
160
  server_name="0.0.0.0",
161
  server_port=7860,
162
- show_error=True
 
163
  )
 
1
  import os
2
  import cv2
3
  import gradio as gr
 
4
  import numpy as np
5
  import torch
6
+ import torch.nn as nn
7
  from torchvision import transforms
8
  import torch.nn.functional as F
9
+ from PIL import Image
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
  # Initialize device
14
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
15
 
16
+ # Clone repository and setup model
17
  if not os.path.exists("DIS"):
18
+ os.system("git clone --depth 1 https://github.com/xuebinqin/DIS")
19
  os.system("mv DIS/IS-Net/* .")
20
 
21
  # Import model components
22
+ from data_loader_cache import normalize, im_reader, im_preprocess
23
  from models import ISNetDIS
24
 
25
  # Setup model directory
26
  os.makedirs("saved_models", exist_ok=True)
27
  if os.path.exists("isnet.pth"):
28
+ os.rename("isnet.pth", "saved_models/isnet.pth")
29
 
30
  class GOSNormalize:
31
+ def __init__(self, mean=[0.5, 0.5, 0.5], std=[1.0, 1.0, 1.0]):
32
  self.mean = mean
33
  self.std = std
34
 
 
54
 
55
  net.to(device)
56
 
57
+ model_path = os.path.join(hypar["model_path"], hypar["restore_model"])
58
+ if os.path.exists(model_path):
59
+ state_dict = torch.load(model_path, map_location=device)
60
+ net.load_state_dict(state_dict)
61
+
62
  net.eval()
63
  return net
64
 
65
  def predict(net, inputs_val, shapes_val, hypar, device):
66
+ with torch.no_grad():
67
+ inputs_val = inputs_val.type(torch.float16 if hypar["model_digit"] == "half" else torch.float32)
68
+ inputs_val = inputs_val.to(device)
69
+ ds_val = net(inputs_val)[0]
70
+ pred_val = ds_val[0][0,:,:,:]
71
+
72
+ pred_val = F.interpolate(
73
+ pred_val.unsqueeze(0).unsqueeze(0),
74
+ size=(shapes_val[0][0], shapes_val[0][1]),
75
+ mode='bilinear',
76
+ align_corners=False
77
+ ).squeeze()
78
+
79
+ pred_val = (pred_val - pred_val.min()) / (pred_val.max() - pred_val.min() + 1e-8)
80
+
81
+ return (pred_val.cpu().numpy() * 255).astype(np.uint8)
 
82
 
83
  # Model configuration
84
  hypar = {
 
98
 
99
  def process_image(image):
100
  try:
101
+ image_path = image if isinstance(image, str) else image.name
102
+
103
+ # Verify image exists
104
+ if not os.path.exists(image_path):
105
+ raise FileNotFoundError(f"Image file not found: {image_path}")
106
+
107
  image_tensor, orig_size = load_image(image_path, hypar)
108
  mask = predict(net, image_tensor, orig_size, hypar, device)
109
 
 
118
  raise gr.Error(f"Error processing image: {str(e)}")
119
 
120
  # Interface setup
121
+ title = "DIS Image Segmentation"
122
+ description = """
123
+ Highly Accurate Dichotomous Image Segmentation
124
+ <br>GitHub: [xuebinqin/DIS](https://github.com/xuebinqin/DIS)
125
+ """
126
 
127
+ # Prepare examples
128
  examples = []
129
+ for example_file in ["robot.png", "ship.png"]:
130
+ if os.path.exists(example_file):
131
+ examples.append([example_file])
 
132
 
133
+ # Create Gradio interface
134
+ with gr.Blocks(title=title) as app:
135
  gr.Markdown(f"## {title}")
136
  gr.Markdown(description)
137
 
138
  with gr.Row():
139
  with gr.Column():
140
+ input_image = gr.Image(
141
+ type="filepath",
142
+ label="Input Image",
143
+ height=400
144
+ )
145
+ submit_btn = gr.Button("Process", variant="primary")
146
 
147
  with gr.Column():
148
+ output_rgba = gr.Image(
149
+ label="Transparent Background",
150
+ type="pil",
151
+ height=400
152
+ )
153
+ output_mask = gr.Image(
154
+ label="Segmentation Mask",
155
+ type="pil",
156
+ height=400
157
+ )
158
 
159
  if examples:
160
  gr.Examples(
 
162
  inputs=input_image,
163
  outputs=[output_rgba, output_mask],
164
  fn=process_image,
165
+ cache_examples=True,
166
+ label="Example Images"
167
  )
168
 
169
  submit_btn.click(
170
  fn=process_image,
171
  inputs=input_image,
172
+ outputs=[output_rgba, output_mask],
173
+ api_name="predict"
174
  )
175
 
176
+ # Launch application
177
  if __name__ == "__main__":
178
  app.launch(
179
  server_name="0.0.0.0",
180
  server_port=7860,
181
+ show_error=True,
182
+ share=False
183
  )