mohamed12ahmed committed on
Commit
63d3c38
·
verified ·
1 Parent(s): c180b03

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +53 -80
inference.py CHANGED
@@ -3,16 +3,17 @@ import cv2
3
  import utils
4
  import argparse
5
  import numpy as np
6
-
7
- import torch
8
-
9
  from utils import convert_state_dict
10
  from models import restormer_arch
11
  from data.preprocess.crop_merge_image import stride_integral
12
-
13
  os.sys.path.append('./data/MBD/')
14
  from data.MBD.infer import net1_net2_infer_single_im
15
 
 
 
 
16
 
17
  def dewarp_prompt(img):
18
  mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
@@ -50,10 +51,10 @@ def deshadow_prompt(img):
50
  return bg_imgs
51
 
52
  def deblur_prompt(img):
53
- x = cv2.Sobel(img,cv2.CV_16S,1,0)
54
- y = cv2.Sobel(img,cv2.CV_16S,0,1)
55
- absX = cv2.convertScaleAbs(x) # convert back to uint8
56
- absY = cv2.convertScaleAbs(y)
57
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
58
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
59
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
@@ -82,11 +83,10 @@ def binarization_promptv2(img):
82
  thresh = thresh.astype(np.uint8)
83
  result[result>155]=255
84
  result[result<=155]=0
85
-
86
- x = cv2.Sobel(img,cv2.CV_16S,1,0)
87
- y = cv2.Sobel(img,cv2.CV_16S,0,1)
88
- absX = cv2.convertScaleAbs(x) # convert back to uint8
89
- absY = cv2.convertScaleAbs(y)
90
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
91
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
92
  return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
@@ -95,19 +95,15 @@ def dewarping(model,im_path):
95
  INPUT_SIZE=256
96
  im_org = cv2.imread(im_path)
97
  im_masked, prompt_org = dewarp_prompt(im_org.copy())
98
-
99
  h,w = im_masked.shape[:2]
100
  im_masked = im_masked.copy()
101
  im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
102
  im_masked = im_masked / 255.0
103
  im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
104
  im_masked = im_masked.float().to(DEVICE)
105
-
106
  prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
107
  prompt = prompt.float().to(DEVICE)
108
-
109
  in_im = torch.cat((im_masked,prompt),dim=1)
110
-
111
  # inference
112
  base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
113
  model = model.float()
@@ -117,14 +113,12 @@ def dewarping(model,im_path):
117
  pred = pred+base_coord
118
  ## smooth
119
  for i in range(15):
120
- pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
121
  pred = cv2.resize(pred,(w,h))*(w,h)
122
  pred = pred.astype(np.float32)
123
  out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
124
-
125
  prompt_org = (prompt_org*255).astype(np.uint8)
126
  prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
127
-
128
  return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
129
 
130
  def appearance(model,im_path):
@@ -134,26 +128,24 @@ def appearance(model,im_path):
134
  h,w = im_org.shape[:2]
135
  prompt = appearance_prompt(im_org)
136
  in_im = np.concatenate((im_org,prompt),-1)
137
-
138
- # constrain the max resolution
139
  if max(w,h) < MAX_SIZE:
140
  in_im,padding_h,padding_w = stride_integral(in_im,8)
141
  else:
142
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
143
-
144
- # normalize
145
  in_im = in_im / 255.0
146
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
147
-
148
  # inference
149
- in_im = in_im.half().to(DEVICE)
150
- model = model.half()
 
 
151
  with torch.no_grad():
152
  pred = model(in_im)
153
  pred = torch.clamp(pred,0,1)
154
  pred = pred[0].permute(1,2,0).cpu().numpy()
155
  pred = (pred*255).astype(np.uint8)
156
-
157
  if max(w,h) < MAX_SIZE:
158
  out_im = pred[padding_h:,padding_w:]
159
  else:
@@ -162,9 +154,7 @@ def appearance(model,im_path):
162
  shadow_map = cv2.resize(shadow_map,(w,h))
163
  shadow_map[shadow_map==0]=0.00001
164
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
165
-
166
- return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
167
-
168
 
169
  def deshadowing(model,im_path):
170
  MAX_SIZE=1600
@@ -173,26 +163,24 @@ def deshadowing(model,im_path):
173
  h,w = im_org.shape[:2]
174
  prompt = deshadow_prompt(im_org)
175
  in_im = np.concatenate((im_org,prompt),-1)
176
-
177
- # constrain the max resolution
178
  if max(w,h) < MAX_SIZE:
179
  in_im,padding_h,padding_w = stride_integral(in_im,8)
180
  else:
181
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
182
-
183
- # normalize
184
  in_im = in_im / 255.0
185
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
186
-
187
  # inference
188
- in_im = in_im.half().to(DEVICE)
189
- model = model.half()
 
 
190
  with torch.no_grad():
191
  pred = model(in_im)
192
  pred = torch.clamp(pred,0,1)
193
  pred = pred[0].permute(1,2,0).cpu().numpy()
194
  pred = (pred*255).astype(np.uint8)
195
-
196
  if max(w,h) < MAX_SIZE:
197
  out_im = pred[padding_h:,padding_w:]
198
  else:
@@ -201,10 +189,8 @@ def deshadowing(model,im_path):
201
  shadow_map = cv2.resize(shadow_map,(w,h))
202
  shadow_map[shadow_map==0]=0.00001
203
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
204
-
205
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
206
 
207
-
208
  def deblurring(model,im_path):
209
  # setup image
210
  im_org = cv2.imread(im_path)
@@ -213,34 +199,34 @@ def deblurring(model,im_path):
213
  in_im = np.concatenate((in_im,prompt),-1)
214
  in_im = in_im / 255.0
215
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
216
- in_im = in_im.half().to(DEVICE)
 
217
  # inference
218
  model.to(DEVICE)
219
  model.eval()
220
- model = model.half()
 
221
  with torch.no_grad():
222
  pred = model(in_im)
223
  pred = torch.clamp(pred,0,1)
224
  pred = pred[0].permute(1,2,0).cpu().numpy()
225
  pred = (pred*255).astype(np.uint8)
226
  out_im = pred[padding_h:,padding_w:]
227
-
228
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
229
 
230
-
231
-
232
  def binarization(model,im_path):
233
  im_org = cv2.imread(im_path)
234
  im,padding_h,padding_w = stride_integral(im_org,8)
235
  prompt = binarization_promptv2(im)
236
  h,w = im.shape[:2]
237
  in_im = np.concatenate((im,prompt),-1)
238
-
239
  in_im = in_im / 255.0
240
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
241
  in_im = in_im.to(DEVICE)
242
- model = model.half()
243
- in_im = in_im.half()
 
 
244
  with torch.no_grad():
245
  pred = model(in_im)
246
  pred = pred[:,:2,:,:]
@@ -249,24 +235,15 @@ def binarization(model,im_path):
249
  pred = (pred*255).astype(np.uint8)
250
  pred = cv2.resize(pred,(w,h))
251
  out_im = pred[padding_h:,padding_w:]
252
-
253
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
254
 
255
-
256
-
257
-
258
-
259
  def get_args():
260
  parser = argparse.ArgumentParser(description='Params')
261
  parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
262
- parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/',
263
- help='Path of input document image')
264
- parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/',
265
- help='Folder of the output images')
266
- parser.add_argument('--task', nargs='?', type=str, default='dewarping',
267
- help='task that need to be executed')
268
- parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0,
269
- help='Width of the input image')
270
  args = parser.parse_args()
271
  possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
272
  assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
@@ -274,25 +251,26 @@ def get_args():
274
 
275
  def model_init(args):
276
  # prepare model
277
- model = restormer_arch.Restormer(
278
- inp_channels=6,
279
- out_channels=3,
280
- dim = 48,
281
- num_blocks = [2,3,3,4],
282
- num_refinement_blocks = 4,
283
  heads = [1,2,4,8],
284
  ffn_expansion_factor = 2.66,
285
  bias = False,
286
  LayerNorm_type = 'WithBias',
287
- dual_pixel_task = True
288
- )
289
-
290
  if DEVICE.type == 'cpu':
291
  state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
292
  else:
293
- state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
 
 
294
  model.load_state_dict(state)
295
-
296
  model.eval()
297
  model = model.to(DEVICE)
298
  return model
@@ -316,20 +294,15 @@ def inference_one_im(model,im_path,task):
316
  prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
317
  # os.remove('restorted/step1.jpg')
318
  # os.remove('restorted/step2.jpg')
319
-
320
  return prompt1,prompt2,prompt3,restorted
321
 
322
-
323
-
324
  if __name__ == '__main__':
325
- ## model init
326
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
327
  args = get_args()
328
  model = model_init(args)
329
-
330
  ## inference
331
  prompt1,prompt2,prompt3,restorted = inference_one_im(model,args.im_path,args.task)
332
-
333
  ## results saving
334
  im_name = os.path.split(args.im_path)[-1]
335
  im_format = '.'+im_name.split('.')[-1]
 
3
  import utils
4
  import argparse
5
  import numpy as np
6
+ import torch # fixed typo: torcth -> torch
7
+ from PIL import Image
 
8
  from utils import convert_state_dict
9
  from models import restormer_arch
10
  from data.preprocess.crop_merge_image import stride_integral
 
11
  os.sys.path.append('./data/MBD/')
12
  from data.MBD.infer import net1_net2_infer_single_im
13
 
14
+ # Force the compute device to be the CPU.
15
+ # DEVICE is read by model_init and by the per-task inference functions.
16
+ DEVICE = torch.device('cpu')
17
 
18
  def dewarp_prompt(img):
19
  mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
 
51
  return bg_imgs
52
 
53
  def deblur_prompt(img):
54
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
55
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
56
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
57
+ absY = cv2.convertScaleAbs(y)
58
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
59
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
60
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
 
83
  thresh = thresh.astype(np.uint8)
84
  result[result>155]=255
85
  result[result<=155]=0
86
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
87
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
88
+ absX = cv2.convertScaleAbs(x) # convert back to uint8
89
+ absY = cv2.convertScaleAbs(y)
 
90
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
91
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
92
  return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
 
95
  INPUT_SIZE=256
96
  im_org = cv2.imread(im_path)
97
  im_masked, prompt_org = dewarp_prompt(im_org.copy())
 
98
  h,w = im_masked.shape[:2]
99
  im_masked = im_masked.copy()
100
  im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
101
  im_masked = im_masked / 255.0
102
  im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
103
  im_masked = im_masked.float().to(DEVICE)
 
104
  prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
105
  prompt = prompt.float().to(DEVICE)
 
106
  in_im = torch.cat((im_masked,prompt),dim=1)
 
107
  # inference
108
  base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
109
  model = model.float()
 
113
  pred = pred+base_coord
114
  ## smooth
115
  for i in range(15):
116
+ pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
117
  pred = cv2.resize(pred,(w,h))*(w,h)
118
  pred = pred.astype(np.float32)
119
  out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
 
120
  prompt_org = (prompt_org*255).astype(np.uint8)
121
  prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
 
122
  return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
123
 
124
  def appearance(model,im_path):
 
128
  h,w = im_org.shape[:2]
129
  prompt = appearance_prompt(im_org)
130
  in_im = np.concatenate((im_org,prompt),-1)
131
+ # constrain the max resolution
 
132
  if max(w,h) < MAX_SIZE:
133
  in_im,padding_h,padding_w = stride_integral(in_im,8)
134
  else:
135
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
136
+ # normalize
 
137
  in_im = in_im / 255.0
138
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
 
139
  # inference
140
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
141
+ in_im = in_im.float().to(DEVICE)
142
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
143
+ model = model.float()
144
  with torch.no_grad():
145
  pred = model(in_im)
146
  pred = torch.clamp(pred,0,1)
147
  pred = pred[0].permute(1,2,0).cpu().numpy()
148
  pred = (pred*255).astype(np.uint8)
 
149
  if max(w,h) < MAX_SIZE:
150
  out_im = pred[padding_h:,padding_w:]
151
  else:
 
154
  shadow_map = cv2.resize(shadow_map,(w,h))
155
  shadow_map[shadow_map==0]=0.00001
156
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
157
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
 
 
158
 
159
  def deshadowing(model,im_path):
160
  MAX_SIZE=1600
 
163
  h,w = im_org.shape[:2]
164
  prompt = deshadow_prompt(im_org)
165
  in_im = np.concatenate((im_org,prompt),-1)
166
+ # constrain the max resolution
 
167
  if max(w,h) < MAX_SIZE:
168
  in_im,padding_h,padding_w = stride_integral(in_im,8)
169
  else:
170
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
171
+ # normalize
 
172
  in_im = in_im / 255.0
173
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
 
174
  # inference
175
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
176
+ in_im = in_im.float().to(DEVICE)
177
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
178
+ model = model.float()
179
  with torch.no_grad():
180
  pred = model(in_im)
181
  pred = torch.clamp(pred,0,1)
182
  pred = pred[0].permute(1,2,0).cpu().numpy()
183
  pred = (pred*255).astype(np.uint8)
 
184
  if max(w,h) < MAX_SIZE:
185
  out_im = pred[padding_h:,padding_w:]
186
  else:
 
189
  shadow_map = cv2.resize(shadow_map,(w,h))
190
  shadow_map[shadow_map==0]=0.00001
191
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
 
192
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
193
 
 
194
  def deblurring(model,im_path):
195
  # setup image
196
  im_org = cv2.imread(im_path)
 
199
  in_im = np.concatenate((in_im,prompt),-1)
200
  in_im = in_im / 255.0
201
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
202
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
203
+ in_im = in_im.float().to(DEVICE)
204
  # inference
205
  model.to(DEVICE)
206
  model.eval()
207
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
208
+ model = model.float()
209
  with torch.no_grad():
210
  pred = model(in_im)
211
  pred = torch.clamp(pred,0,1)
212
  pred = pred[0].permute(1,2,0).cpu().numpy()
213
  pred = (pred*255).astype(np.uint8)
214
  out_im = pred[padding_h:,padding_w:]
 
215
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
216
 
 
 
217
  def binarization(model,im_path):
218
  im_org = cv2.imread(im_path)
219
  im,padding_h,padding_w = stride_integral(im_org,8)
220
  prompt = binarization_promptv2(im)
221
  h,w = im.shape[:2]
222
  in_im = np.concatenate((im,prompt),-1)
 
223
  in_im = in_im / 255.0
224
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
225
  in_im = in_im.to(DEVICE)
226
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
227
+ model = model.float()
228
+ # *** Modified: use .float() instead of .half() when running on the CPU ***
229
+ in_im = in_im.float()
230
  with torch.no_grad():
231
  pred = model(in_im)
232
  pred = pred[:,:2,:,:]
 
235
  pred = (pred*255).astype(np.uint8)
236
  pred = cv2.resize(pred,(w,h))
237
  out_im = pred[padding_h:,padding_w:]
 
238
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
239
 
 
 
 
 
240
  def get_args():
241
  parser = argparse.ArgumentParser(description='Params')
242
  parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
243
+ parser.add_argument('--im_path', nargs='?', type=str, default='./distorted/', help='Path of input document image')
244
+ parser.add_argument('--out_folder', nargs='?', type=str, default='./restorted/', help='Folder of the output images')
245
+ parser.add_argument('--task', nargs='?', type=str, default='dewarping', help='task that need to be executed')
246
+ parser.add_argument('--save_dtsprompt', nargs='?', type=int, default=0, help='Width of the input image')
 
 
 
 
247
  args = parser.parse_args()
248
  possible_tasks = ['dewarping','deshadowing','appearance','deblurring','binarization','end2end']
249
  assert args.task in possible_tasks, 'Unsupported task, task must be one of '+', '.join(possible_tasks)
 
251
 
252
  def model_init(args):
253
  # prepare model
254
+ model = restormer_arch.Restormer(
255
+ inp_channels=6,
256
+ out_channels=3,
257
+ dim = 48,
258
+ num_blocks = [2,3,3,4],
259
+ num_refinement_blocks = 4,
260
  heads = [1,2,4,8],
261
  ffn_expansion_factor = 2.66,
262
  bias = False,
263
  LayerNorm_type = 'WithBias',
264
+ dual_pixel_task = True
265
+ )
266
+ # DEVICE is hard-coded to 'cpu' at module level, so this branch is always taken
267
  if DEVICE.type == 'cpu':
268
  state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
269
  else:
270
+ # unreachable while DEVICE is hard-coded to 'cpu'
271
+ state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
272
+
273
  model.load_state_dict(state)
 
274
  model.eval()
275
  model = model.to(DEVICE)
276
  return model
 
294
  prompt1,prompt2,prompt3,restorted = appearance(model,'restorted/step2.jpg')
295
  # os.remove('restorted/step1.jpg')
296
  # os.remove('restorted/step2.jpg')
 
297
  return prompt1,prompt2,prompt3,restorted
298
 
 
 
299
  if __name__ == '__main__':
300
+ # *** Modified: DEVICE is hard-coded to 'cpu' at module level (no longer assigned here) ***
301
+ # This ensures the program runs on the CPU even when CUDA is available
302
  args = get_args()
303
  model = model_init(args)
 
304
  ## inference
305
  prompt1,prompt2,prompt3,restorted = inference_one_im(model,args.im_path,args.task)
 
306
  ## results saving
307
  im_name = os.path.split(args.im_path)[-1]
308
  im_format = '.'+im_name.split('.')[-1]