AkashKumarave committed on
Commit
a54d85c
·
verified ·
1 Parent(s): 14ebe9c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -106
app.py CHANGED
@@ -1,51 +1,41 @@
 
1
  import cv2
2
  import gradio as gr
3
- import os
4
  from PIL import Image
5
  import numpy as np
6
  import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
10
- import gdown
11
- import matplotlib.pyplot as plt
12
  import warnings
13
  warnings.filterwarnings("ignore")
14
 
15
- # Clean up any previous runs
16
- if os.path.exists("DIS"):
17
- os.system("rm -rf DIS")
18
 
19
- # Clone and setup the model
20
- os.system("git clone https://github.com/xuebinqin/DIS")
21
- os.system("mv DIS/IS-Net/* .")
 
22
 
23
- # Project imports
24
  from data_loader_cache import normalize, im_reader, im_preprocess
25
- from models import *
26
 
27
- # Device configuration
28
- device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
 
29
 
30
- # Setup model directory and weights
31
- if not os.path.exists("saved_models"):
32
- os.makedirs("saved_models", exist_ok=True)
33
- if os.path.exists("isnet.pth"):
34
- os.system("mv isnet.pth saved_models/")
35
-
36
- class GOSNormalize(object):
37
- '''
38
- Normalize the Image using torch.transforms
39
- '''
40
- def __init__(self, mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]):
41
  self.mean = mean
42
  self.std = std
43
 
44
  def __call__(self, image):
45
- image = normalize(image, self.mean, self.std)
46
- return image
47
 
48
- transform = transforms.Compose([GOSNormalize([0.5,0.5,0.5],[1.0,1.0,1.0])])
49
 
50
  def load_image(im_path, hypar):
51
  im = im_reader(im_path)
@@ -55,55 +45,45 @@ def load_image(im_path, hypar):
55
  return transform(im).unsqueeze(0), shape.unsqueeze(0)
56
 
57
  def build_model(hypar, device):
58
- net = hypar["model"] # GOSNETINC(3,1)
59
-
60
- # Convert to half precision
61
  if hypar["model_digit"] == "half":
62
  net.half()
63
  for layer in net.modules():
64
  if isinstance(layer, nn.BatchNorm2d):
65
  layer.float()
66
-
67
  net.to(device)
68
-
69
- if hypar["restore_model"] != "":
70
  net.load_state_dict(torch.load(
71
- hypar["model_path"]+"/"+hypar["restore_model"],
72
  map_location=device
73
  ))
74
- net.eval()
75
  return net
76
 
77
  def predict(net, inputs_val, shapes_val, hypar, device):
78
  net.eval()
79
-
80
- if hypar["model_digit"] == "full":
81
- inputs_val = inputs_val.type(torch.FloatTensor)
82
- else:
83
- inputs_val = inputs_val.type(torch.HalfTensor)
84
-
85
  inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
86
- ds_val = net(inputs_val_v)[0] # list of 6 results
87
- pred_val = ds_val[0][0,:,:,:] # B x 1 x H x W
88
-
89
- # Recover the prediction spatial size to the original image size
90
- pred_val = torch.squeeze(F.upsample(
91
  torch.unsqueeze(pred_val, 0),
92
- (shapes_val[0][0], shapes_val[0][1]),
93
  mode='bilinear'
94
  ))
95
-
96
- ma = torch.max(pred_val)
97
- mi = torch.min(pred_val)
98
- pred_val = (pred_val-mi)/(ma-mi) # max = 1
99
-
100
- if device == 'cuda':
101
  torch.cuda.empty_cache()
102
- return (pred_val.detach().cpu().numpy()*255).astype(np.uint8)
103
 
104
- # Set parameters
105
  hypar = {
106
- "model_path": "./saved_models",
107
  "restore_model": "isnet.pth",
108
  "interm_sup": False,
109
  "model_digit": "full",
@@ -114,60 +94,70 @@ hypar = {
114
  "model": ISNetDIS()
115
  }
116
 
117
- # Build model
118
  net = build_model(hypar, device)
119
 
120
- def inference(image):
121
  try:
122
- image_path = image.name if hasattr(image, 'name') else image
123
-
124
- image_tensor, orig_size = load_image(image_path, hypar)
 
 
 
125
  mask = predict(net, image_tensor, orig_size, hypar, device)
126
 
127
- pil_mask = Image.fromarray(mask).convert('L')
128
- im_rgb = Image.open(image_path).convert("RGB")
129
 
130
- im_rgba = im_rgb.copy()
131
- im_rgba.putalpha(pil_mask)
132
-
133
- return [im_rgba, pil_mask]
134
  except Exception as e:
135
- print(f"Error during inference: {str(e)}")
136
- raise e
137
-
138
- title = "Highly Accurate Dichotomous Image Segmentation"
139
- description = """
140
- This is an unofficial demo for DIS, a model that can remove the background from a given image.
141
- To use it, simply upload your image, or click one of the examples to load them.
142
- <br>GitHub: https://github.com/xuebinqin/DIS
143
- <br>Telegram bot: https://t.me/restoration_photo_bot
144
- [![](https://img.shields.io/twitter/follow/DoEvent?label=@DoEvent&style=social)](https://twitter.com/DoEvent)
145
- """
146
- article = "<div><center><img src='https://visitor-badge.glitch.me/badge?page_id=max_skobeev_dis_cmp_public' alt='visitor badge'></center></div>"
147
-
148
- # Create interface
149
- interface = gr.Interface(
150
- fn=inference,
151
- inputs=gr.Image(type="filepath"),
152
- outputs=[
153
- gr.Image(type="pil", label="Image with Transparency"),
154
- gr.Image(type="pil", label="Mask Only")
155
- ],
156
- examples=[
157
- ["robot.png"],
158
- ["ship.png"]
159
- ],
160
- title=title,
161
- description=description,
162
- article=article,
163
- allow_flagging="never"
164
- )
165
-
166
- # Launch with corrected parameters
167
- interface.launch(
168
- server_name="0.0.0.0",
169
- server_port=7860,
170
- share=False,
171
- debug=True,
172
- show_error=True
173
- )
 
 
 
 
 
 
 
 
1
+ import os
2
  import cv2
3
  import gradio as gr
 
4
  from PIL import Image
5
  import numpy as np
6
  import torch
7
  from torch.autograd import Variable
8
  from torchvision import transforms
9
  import torch.nn.functional as F
 
 
10
  import warnings
11
  warnings.filterwarnings("ignore")
12
 
13
# Initialize device: use the GPU when torch can see one, otherwise the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Clone repository if not exists. The contents of DIS/IS-Net are moved into
# the working directory so that `data_loader_cache` and `models` can be
# imported below.
if not os.path.exists("DIS"):
    if os.system("git clone https://github.com/xuebinqin/DIS") != 0:
        # Stop here with a clear message; otherwise the failure would only
        # show up later as an ImportError on the project modules.
        raise RuntimeError(
            "git clone of https://github.com/xuebinqin/DIS failed"
        )
    os.system("mv DIS/IS-Net/* .")

# Import model components (these modules come from the cloned repository,
# so the imports have to stay after the clone step).
from data_loader_cache import normalize, im_reader, im_preprocess
from models import ISNetDIS

# Setup model directory: weights found next to the app are moved into
# saved_models/.
os.makedirs("saved_models", exist_ok=True)
if os.path.exists("isnet.pth"):
    os.system("mv isnet.pth saved_models/")
29
 
30
class GOSNormalize:
    """Callable transform that normalizes an image with a stored mean and std.

    The work is delegated to the ``normalize`` helper imported from
    ``data_loader_cache``; this class only remembers the statistics.
    """

    def __init__(self, mean=(0.5, 0.5, 0.5), std=(1.0, 1.0, 1.0)):
        """Store the normalization statistics.

        Args:
            mean: Sequence of mean values (default: three entries of 0.5).
            std: Sequence of standard deviations (default: three entries
                of 1.0).
        """
        # Copy into fresh lists so that instances never share the default
        # values (or a caller-owned sequence) by reference.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        """Return ``normalize(image, self.mean, self.std)``."""
        return normalize(image, self.mean, self.std)
 
37
 
38
+ transform = transforms.Compose([GOSNormalize()])
39
 
40
  def load_image(im_path, hypar):
41
  im = im_reader(im_path)
 
45
  return transform(im).unsqueeze(0), shape.unsqueeze(0)
46
 
47
def build_model(hypar, device):
    """Prepare the network stored in ``hypar`` for inference on ``device``.

    Args:
        hypar: Configuration mapping. Reads ``"model"`` (the network
            instance), ``"model_digit"`` (``"half"`` switches the network to
            half precision), ``"restore_model"`` (weights file name; an empty
            value skips loading) and ``"model_path"`` (directory that holds
            the weights file, only read when weights are loaded).
        device: Device handed to ``net.to`` and used as ``map_location``
            when loading the weights.

    Returns:
        The same network object, moved to ``device`` and put in eval mode.
    """
    net = hypar["model"]

    if hypar["model_digit"] == "half":
        net.half()
        # Batch-norm layers are switched back to float32. ``torch.nn`` is
        # spelled out because no bare ``nn`` name is imported in this file,
        # so the short form would raise NameError.
        for layer in net.modules():
            if isinstance(layer, torch.nn.BatchNorm2d):
                layer.float()

    net.to(device)

    if hypar["restore_model"]:
        weights_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        net.load_state_dict(torch.load(weights_path, map_location=device))

    net.eval()
    return net
64
 
65
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run ``net`` on one input and return its first prediction as uint8.

    Args:
        net: Model to evaluate. The first element of its output is indexed
            as ``[0][0, :, :, :]`` to pick the prediction map.
        inputs_val: Input tensor; cast to float32 when
            ``hypar["model_digit"] == "full"``, otherwise to float16.
        shapes_val: Indexable whose ``[0][0]`` and ``[0][1]`` give the size
            the prediction is resized to (presumably the original image
            height and width -- confirm against ``load_image``).
        hypar: Configuration mapping; only ``"model_digit"`` is read here.
        device: Device the input is moved to before the forward pass.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8``: the resized prediction,
        min-max normalized and scaled by 255.
    """
    net.eval()

    dtype = torch.float32 if hypar["model_digit"] == "full" else torch.float16
    inputs_val_v = inputs_val.to(dtype=dtype).to(device)

    # Inference only: no autograd graph is needed for the forward pass.
    with torch.no_grad():
        ds_val = net(inputs_val_v)[0]
        pred_val = ds_val[0][0, :, :, :]

        # Plain ints keep the size argument independent of how the shape
        # values are stored (tensor elements or Python numbers).
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        pred_val = torch.squeeze(F.interpolate(
            torch.unsqueeze(pred_val, 0),
            size=target_size,
            mode='bilinear'
        ))

        lowest = pred_val.min()
        value_range = pred_val.max() - lowest
        if value_range > 0:
            pred_val = (pred_val - lowest) / value_range
        else:
            # A constant map has no range to stretch; keep its values,
            # clamped to [0, 1], instead of dividing by zero (NaN mask).
            pred_val = pred_val.clamp(0, 1)

    # Also matches "cuda:0" and torch.device objects, not just 'cuda'.
    if str(device).startswith('cuda'):
        torch.cuda.empty_cache()
    return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
83
 
84
+ # Model configuration
85
  hypar = {
86
+ "model_path": "saved_models",
87
  "restore_model": "isnet.pth",
88
  "interm_sup": False,
89
  "model_digit": "full",
 
94
  "model": ISNetDIS()
95
  }
96
 
97
+ # Initialize model
98
  net = build_model(hypar, device)
99
 
100
def process_image(image):
    """Produce a transparent-background image and its mask for the UI.

    Args:
        image: Path of the image as a string, or an object that exposes the
            path through a ``name`` attribute.

    Returns:
        Tuple ``(rgba_img, mask_img)``: the input converted to RGB with the
        predicted mask applied as alpha channel, and the mask itself as an
        ``"L"`` mode PIL image.

    Raises:
        gr.Error: If no image was supplied or any processing step fails.
    """
    if image is None:
        # Nothing was uploaded: report that directly instead of letting the
        # ``.name`` lookup fail with an AttributeError message.
        raise gr.Error("Error processing image: no image was provided")

    try:
        image_path = image if isinstance(image, str) else image.name

        image_tensor, orig_size = load_image(image_path, hypar)
        mask = predict(net, image_tensor, orig_size, hypar, device)
        mask_img = Image.fromarray(mask).convert('L')

        # ``convert`` returns a new image, so the source file can be closed
        # as soon as the conversion is done.
        with Image.open(image_path) as source_img:
            rgba_img = source_img.convert("RGB")
        rgba_img.putalpha(mask_img)

        return rgba_img, mask_img
    except Exception as e:
        # Keep the original exception chained for the server-side traceback.
        raise gr.Error(f"Error processing image: {str(e)}") from e
119
+
120
# Interface setup
title = "Image Segmentation Demo"
description = "Upload an image to extract its foreground"

# Only sample images that exist in the working directory are offered.
examples = [
    [sample] for sample in ("robot.png", "ship.png") if os.path.exists(sample)
]

with gr.Blocks() as app:
    gr.Markdown(f"## {title}")
    gr.Markdown(description)

    with gr.Row():
        # Left column: upload widget and trigger button.
        with gr.Column():
            input_image = gr.Image(type="filepath", label="Input Image")
            submit_btn = gr.Button("Process")

        # Right column: the two results returned by process_image.
        with gr.Column():
            output_rgba = gr.Image(label="Transparent Background", type="pil")
            output_mask = gr.Image(label="Segmentation Mask", type="pil")

    result_images = [output_rgba, output_mask]

    # The examples widget is created only when at least one sample exists.
    if examples:
        gr.Examples(
            examples=examples,
            inputs=input_image,
            outputs=result_images,
            fn=process_image,
            cache_examples=True,
        )

    submit_btn.click(
        fn=process_image,
        inputs=input_image,
        outputs=result_images,
    )

if __name__ == "__main__":
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "show_error": True,
    }
    app.launch(**launch_options)