mohamed12ahmed commited on
Commit
0407759
·
verified ·
1 Parent(s): 63d3c38

Update eval.py

Browse files
Files changed (1) hide show
  1. eval.py +68 -82
eval.py CHANGED
@@ -6,16 +6,15 @@ import argparse
6
  import numpy as np
7
  from tqdm import tqdm
8
  from skimage.metrics import structural_similarity,peak_signal_noise_ratio
9
-
10
  import torch
11
-
12
  from utils import convert_state_dict
13
  from models import restormer_arch
14
  from data.preprocess.crop_merge_image import stride_integral
15
-
16
  os.sys.path.append('./data/MBD/')
17
  from data.MBD.infer import net1_net2_infer_single_im
18
 
 
 
19
 
20
  def dewarp_prompt(img):
21
  mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
@@ -26,7 +25,7 @@ def dewarp_prompt(img):
26
 
27
  def deshadow_prompt(img):
28
  h,w = img.shape[:2]
29
- # img = cv2.resize(img,(128,128))
30
  img = cv2.resize(img,(1024,1024))
31
  rgb_planes = cv2.split(img)
32
  result_planes = []
@@ -53,10 +52,10 @@ def deshadow_prompt(img):
53
  return bg_imgs
54
 
55
  def deblur_prompt(img):
56
- x = cv2.Sobel(img,cv2.CV_16S,1,0)
57
- y = cv2.Sobel(img,cv2.CV_16S,0,1)
58
- absX = cv2.convertScaleAbs(x) # 转回uint8
59
- absY = cv2.convertScaleAbs(y)
60
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
61
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
62
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
@@ -85,11 +84,10 @@ def binarization_promptv2(img):
85
  thresh = thresh.astype(np.uint8)
86
  result[result>155]=255
87
  result[result<=155]=0
88
-
89
- x = cv2.Sobel(img,cv2.CV_16S,1,0)
90
- y = cv2.Sobel(img,cv2.CV_16S,0,1)
91
- absX = cv2.convertScaleAbs(x) # 转回uint8
92
- absY = cv2.convertScaleAbs(y)
93
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
94
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
95
  return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
@@ -98,36 +96,31 @@ def dewarping(model,im_path):
98
  INPUT_SIZE=256
99
  im_org = cv2.imread(im_path)
100
  im_masked, prompt_org = dewarp_prompt(im_org.copy())
101
-
102
  h,w = im_masked.shape[:2]
103
  im_masked = im_masked.copy()
104
  im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
105
  im_masked = im_masked / 255.0
106
  im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
107
  im_masked = im_masked.float().to(DEVICE)
108
-
109
  prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
110
  prompt = prompt.float().to(DEVICE)
111
-
112
  in_im = torch.cat((im_masked,prompt),dim=1)
113
-
114
  # inference
115
  base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
116
- model = model.float()
 
117
  with torch.no_grad():
118
  pred = model(in_im)
119
  pred = pred[0][:2].permute(1,2,0).cpu().numpy()
120
  pred = pred+base_coord
121
  ## smooth
122
  for i in range(15):
123
- pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
124
  pred = cv2.resize(pred,(w,h))*(w,h)
125
  pred = pred.astype(np.float32)
126
  out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
127
-
128
  prompt_org = (prompt_org*255).astype(np.uint8)
129
  prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
130
-
131
  return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
132
 
133
  def appearance(model,im_path):
@@ -137,26 +130,23 @@ def appearance(model,im_path):
137
  h,w = im_org.shape[:2]
138
  prompt = appearance_prompt(im_org)
139
  in_im = np.concatenate((im_org,prompt),-1)
140
-
141
- # constrain the max resolution
142
  if max(w,h) < MAX_SIZE:
143
  in_im,padding_h,padding_w = stride_integral(in_im,8)
144
  else:
145
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
146
-
147
- # normalize
148
  in_im = in_im / 255.0
149
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
150
-
151
  # inference
152
- in_im = in_im.half().to(DEVICE)
153
- model = model.half()
 
154
  with torch.no_grad():
155
  pred = model(in_im)
156
  pred = torch.clamp(pred,0,1)
157
  pred = pred[0].permute(1,2,0).cpu().numpy()
158
  pred = (pred*255).astype(np.uint8)
159
-
160
  if max(w,h) < MAX_SIZE:
161
  out_im = pred[padding_h:,padding_w:]
162
  else:
@@ -165,9 +155,7 @@ def appearance(model,im_path):
165
  shadow_map = cv2.resize(shadow_map,(w,h))
166
  shadow_map[shadow_map==0]=0.00001
167
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
168
-
169
- return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
170
-
171
 
172
  def deshadowing(model,im_path):
173
  MAX_SIZE=1600
@@ -176,26 +164,23 @@ def deshadowing(model,im_path):
176
  h,w = im_org.shape[:2]
177
  prompt = deshadow_prompt(im_org)
178
  in_im = np.concatenate((im_org,prompt),-1)
179
-
180
- # constrain the max resolution
181
  if max(w,h) < MAX_SIZE:
182
  in_im,padding_h,padding_w = stride_integral(in_im,8)
183
  else:
184
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
185
-
186
- # normalize
187
  in_im = in_im / 255.0
188
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
189
-
190
  # inference
191
- in_im = in_im.half().to(DEVICE)
192
- model = model.half()
 
193
  with torch.no_grad():
194
  pred = model(in_im)
195
  pred = torch.clamp(pred,0,1)
196
  pred = pred[0].permute(1,2,0).cpu().numpy()
197
  pred = (pred*255).astype(np.uint8)
198
-
199
  if max(w,h) < MAX_SIZE:
200
  out_im = pred[padding_h:,padding_w:]
201
  else:
@@ -204,10 +189,8 @@ def deshadowing(model,im_path):
204
  shadow_map = cv2.resize(shadow_map,(w,h))
205
  shadow_map[shadow_map==0]=0.00001
206
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
207
-
208
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
209
 
210
-
211
  def deblurring(model,im_path):
212
  # setup image
213
  im_org = cv2.imread(im_path)
@@ -216,34 +199,33 @@ def deblurring(model,im_path):
216
  in_im = np.concatenate((in_im,prompt),-1)
217
  in_im = in_im / 255.0
218
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
219
- in_im = in_im.half().to(DEVICE)
 
220
  # inference
221
  model.to(DEVICE)
222
  model.eval()
223
- model = model.half()
 
224
  with torch.no_grad():
225
  pred = model(in_im)
226
  pred = torch.clamp(pred,0,1)
227
  pred = pred[0].permute(1,2,0).cpu().numpy()
228
  pred = (pred*255).astype(np.uint8)
229
  out_im = pred[padding_h:,padding_w:]
230
-
231
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
232
 
233
-
234
-
235
  def binarization(model,im_path):
236
  im_org = cv2.imread(im_path)
237
  im,padding_h,padding_w = stride_integral(im_org,8)
238
  prompt = binarization_promptv2(im)
239
  h,w = im.shape[:2]
240
  in_im = np.concatenate((im,prompt),-1)
241
-
242
  in_im = in_im / 255.0
243
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
244
  in_im = in_im.to(DEVICE)
245
- model = model.half()
246
- in_im = in_im.half()
 
247
  with torch.no_grad():
248
  pred = model(in_im,'binarization')
249
  pred = pred[:,:2,:,:]
@@ -252,42 +234,36 @@ def binarization(model,im_path):
252
  pred = (pred*255).astype(np.uint8)
253
  pred = cv2.resize(pred,(w,h))
254
  out_im = pred[padding_h:,padding_w:]
255
-
256
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
257
 
258
-
259
-
260
-
261
-
262
  def get_args():
263
  parser = argparse.ArgumentParser(description='Params')
264
  parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
265
  parser.add_argument('--dataset', nargs='?', type=str, default='./distorted/',help='Path of input document image')
266
  args = parser.parse_args()
267
- assert args.dataset in all_datasets.keys(), 'Unregisted dataset, dataset must be one of '+', '.join(all_datasets)
 
 
268
  return args
269
 
270
  def model_init(args):
271
  # prepare model
272
- model = restormer_arch.Restormer(
273
- inp_channels=6,
274
- out_channels=3,
275
- dim = 48,
276
- num_blocks = [2,3,3,4],
277
- num_refinement_blocks = 4,
278
  heads = [1,2,4,8],
279
  ffn_expansion_factor = 2.66,
280
  bias = False,
281
- LayerNorm_type = 'WithBias',
282
- dual_pixel_task = True
283
- )
284
-
285
- if DEVICE.type == 'cpu':
286
- state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
287
- else:
288
- state = convert_state_dict(torch.load(args.model_path, map_location='cuda:0')['model_state'])
289
  model.load_state_dict(state)
290
-
291
  model.eval()
292
  model = model.to(DEVICE)
293
  return model
@@ -310,27 +286,27 @@ def inference_one_im(model,im_path,task):
310
  cv2.imwrite('./temp.jpg',restorted)
311
  prompt1,prompt2,prompt3,restorted = appearance(model,'./temp.jpg')
312
  os.remove('./temp.jpg')
313
-
314
  return prompt1,prompt2,prompt3,restorted
315
 
316
-
317
-
318
  if __name__ == '__main__':
319
  all_datasets = {'dir300':'dewarping','kligler':'deshadowing','jung':'deshadowing','osr':'deshadowing','docunet_docaligner':'appearance','realdae':'appearance','tdd':'deblurring','dibco18':'binarization'}
320
-
321
- ## model init
322
- DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
323
  args = get_args()
 
 
 
 
324
  model = model_init(args)
325
-
326
- ## inference
327
  print('Predicting')
328
  task = all_datasets[args.dataset]
329
  im_paths = glob.glob(os.path.join('./data/eval/',args.dataset,'*_in.*'))
330
  for im_path in tqdm(im_paths):
331
  _,_,_,restorted = inference_one_im(model,im_path,task)
332
  cv2.imwrite(im_path.replace('_in','_docres'),restorted)
333
-
334
  ## obtain metric
335
  print('Metric calculating')
336
  if task == 'dewarping':
@@ -341,22 +317,34 @@ if __name__ == '__main__':
341
  for im_path in tqdm(im_paths):
342
  pred = cv2.imread(im_path.replace('_in','_docres'))
343
  gt = cv2.imread(im_path.replace('_in','_gt'))
 
 
 
 
344
  ssim.append(structural_similarity(pred,gt,multichannel=True))
345
  psnr.append(peak_signal_noise_ratio(pred, gt))
346
  print(args.dataset)
347
  print('ssim:',np.mean(ssim))
348
  print('psnr:',np.mean(psnr))
 
349
  elif task=='binarization':
350
  fmeasures, pfmeasures,psnrs = [],[],[]
351
  for im_path in tqdm(im_paths):
352
  pred = cv2.imread(im_path.replace('_in','_docres'))
353
  gt = cv2.imread(im_path.replace('_in','_gt'))
 
 
 
 
 
354
  pred = cv2.cvtColor(pred,cv2.COLOR_BGR2GRAY)
355
  gt = cv2.cvtColor(gt,cv2.COLOR_BGR2GRAY)
 
356
  pred[pred>155]=255
357
  pred[pred<=155]=0
358
  gt[gt>155]=255
359
  gt[gt<=155]=0
 
360
  fmeasure, pfmeasure,psnr,_,_,_ = utils.bin_metric(pred,gt)
361
  fmeasures.append(fmeasure)
362
  pfmeasures.append(pfmeasure)
@@ -365,5 +353,3 @@ if __name__ == '__main__':
365
  print('fmeasure:',np.mean(fmeasures))
366
  print('pfmeasure:',np.mean(pfmeasures))
367
  print('psnr:',np.mean(psnrs))
368
-
369
-
 
6
  import numpy as np
7
  from tqdm import tqdm
8
  from skimage.metrics import structural_similarity,peak_signal_noise_ratio
 
9
  import torch
 
10
  from utils import convert_state_dict
11
  from models import restormer_arch
12
  from data.preprocess.crop_merge_image import stride_integral
 
13
  os.sys.path.append('./data/MBD/')
14
  from data.MBD.infer import net1_net2_infer_single_im
15
 
16
+ # *** تحديد الجهاز ليكون CPU بشكل إجباري لضمان التشغيل الموثوق ***
17
+ DEVICE = torch.device('cpu')
18
 
19
  def dewarp_prompt(img):
20
  mask = net1_net2_infer_single_im(img,'data/MBD/checkpoint/mbd.pkl')
 
25
 
26
  def deshadow_prompt(img):
27
  h,w = img.shape[:2]
28
+ # img = cv2.resize(img,(1024,1024))
29
  img = cv2.resize(img,(1024,1024))
30
  rgb_planes = cv2.split(img)
31
  result_planes = []
 
52
  return bg_imgs
53
 
54
  def deblur_prompt(img):
55
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
56
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
57
+ absX = cv2.convertScaleAbs(x) # 转回uint8
58
+ absY = cv2.convertScaleAbs(y)
59
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
60
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
61
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_GRAY2BGR)
 
84
  thresh = thresh.astype(np.uint8)
85
  result[result>155]=255
86
  result[result<=155]=0
87
+ x = cv2.Sobel(img,cv2.CV_16S,1,0)
88
+ y = cv2.Sobel(img,cv2.CV_16S,0,1)
89
+ absX = cv2.convertScaleAbs(x) # 转回uint8
90
+ absY = cv2.convertScaleAbs(y)
 
91
  high_frequency = cv2.addWeighted(absX,0.5,absY,0.5,0)
92
  high_frequency = cv2.cvtColor(high_frequency,cv2.COLOR_BGR2GRAY)
93
  return np.concatenate((np.expand_dims(thresh,-1),np.expand_dims(high_frequency,-1),np.expand_dims(result,-1)),-1)
 
96
  INPUT_SIZE=256
97
  im_org = cv2.imread(im_path)
98
  im_masked, prompt_org = dewarp_prompt(im_org.copy())
 
99
  h,w = im_masked.shape[:2]
100
  im_masked = im_masked.copy()
101
  im_masked = cv2.resize(im_masked,(INPUT_SIZE,INPUT_SIZE))
102
  im_masked = im_masked / 255.0
103
  im_masked = torch.from_numpy(im_masked.transpose(2,0,1)).unsqueeze(0)
104
  im_masked = im_masked.float().to(DEVICE)
 
105
  prompt = torch.from_numpy(prompt_org.transpose(2,0,1)).unsqueeze(0)
106
  prompt = prompt.float().to(DEVICE)
 
107
  in_im = torch.cat((im_masked,prompt),dim=1)
 
108
  # inference
109
  base_coord = utils.getBasecoord(INPUT_SIZE,INPUT_SIZE)/INPUT_SIZE
110
+ # *** تم التعديل: استخدام .float() لـ CPU ***
111
+ model = model.float()
112
  with torch.no_grad():
113
  pred = model(in_im)
114
  pred = pred[0][:2].permute(1,2,0).cpu().numpy()
115
  pred = pred+base_coord
116
  ## smooth
117
  for i in range(15):
118
+ pred = cv2.blur(pred,(3,3),borderType=cv2.BORDER_REPLICATE)
119
  pred = cv2.resize(pred,(w,h))*(w,h)
120
  pred = pred.astype(np.float32)
121
  out_im = cv2.remap(im_org,pred[:,:,0],pred[:,:,1],cv2.INTER_LINEAR)
 
122
  prompt_org = (prompt_org*255).astype(np.uint8)
123
  prompt_org = cv2.resize(prompt_org,im_org.shape[:2][::-1])
 
124
  return prompt_org[:,:,0],prompt_org[:,:,1],prompt_org[:,:,2],out_im
125
 
126
  def appearance(model,im_path):
 
130
  h,w = im_org.shape[:2]
131
  prompt = appearance_prompt(im_org)
132
  in_im = np.concatenate((im_org,prompt),-1)
133
+ # constrain the max resolution
 
134
  if max(w,h) < MAX_SIZE:
135
  in_im,padding_h,padding_w = stride_integral(in_im,8)
136
  else:
137
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
138
+ # normalize
 
139
  in_im = in_im / 255.0
140
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
 
141
  # inference
142
+ # *** تم التعديل: استخدام .float() بدلاً من .half() لـ CPU ***
143
+ in_im = in_im.float().to(DEVICE)
144
+ model = model.float()
145
  with torch.no_grad():
146
  pred = model(in_im)
147
  pred = torch.clamp(pred,0,1)
148
  pred = pred[0].permute(1,2,0).cpu().numpy()
149
  pred = (pred*255).astype(np.uint8)
 
150
  if max(w,h) < MAX_SIZE:
151
  out_im = pred[padding_h:,padding_w:]
152
  else:
 
155
  shadow_map = cv2.resize(shadow_map,(w,h))
156
  shadow_map[shadow_map==0]=0.00001
157
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
158
+ return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
 
 
159
 
160
  def deshadowing(model,im_path):
161
  MAX_SIZE=1600
 
164
  h,w = im_org.shape[:2]
165
  prompt = deshadow_prompt(im_org)
166
  in_im = np.concatenate((im_org,prompt),-1)
167
+ # constrain the max resolution
 
168
  if max(w,h) < MAX_SIZE:
169
  in_im,padding_h,padding_w = stride_integral(in_im,8)
170
  else:
171
  in_im = cv2.resize(in_im,(MAX_SIZE,MAX_SIZE))
172
+ # normalize
 
173
  in_im = in_im / 255.0
174
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
 
175
  # inference
176
+ # *** تم التعديل: استخدام .float() بدلاً من .half() لـ CPU ***
177
+ in_im = in_im.float().to(DEVICE)
178
+ model = model.float()
179
  with torch.no_grad():
180
  pred = model(in_im)
181
  pred = torch.clamp(pred,0,1)
182
  pred = pred[0].permute(1,2,0).cpu().numpy()
183
  pred = (pred*255).astype(np.uint8)
 
184
  if max(w,h) < MAX_SIZE:
185
  out_im = pred[padding_h:,padding_w:]
186
  else:
 
189
  shadow_map = cv2.resize(shadow_map,(w,h))
190
  shadow_map[shadow_map==0]=0.00001
191
  out_im = np.clip(im_org.astype(float)/shadow_map,0,255).astype(np.uint8)
 
192
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
193
 
 
194
  def deblurring(model,im_path):
195
  # setup image
196
  im_org = cv2.imread(im_path)
 
199
  in_im = np.concatenate((in_im,prompt),-1)
200
  in_im = in_im / 255.0
201
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
202
+ # *** تم التعديل: استخدام .float() بدلاً من .half() لـ CPU ***
203
+ in_im = in_im.float().to(DEVICE)
204
  # inference
205
  model.to(DEVICE)
206
  model.eval()
207
+ # *** تم التعديل: استخدام .float() بدلاً من .half() لـ CPU ***
208
+ model = model.float()
209
  with torch.no_grad():
210
  pred = model(in_im)
211
  pred = torch.clamp(pred,0,1)
212
  pred = pred[0].permute(1,2,0).cpu().numpy()
213
  pred = (pred*255).astype(np.uint8)
214
  out_im = pred[padding_h:,padding_w:]
 
215
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
216
 
 
 
217
  def binarization(model,im_path):
218
  im_org = cv2.imread(im_path)
219
  im,padding_h,padding_w = stride_integral(im_org,8)
220
  prompt = binarization_promptv2(im)
221
  h,w = im.shape[:2]
222
  in_im = np.concatenate((im,prompt),-1)
 
223
  in_im = in_im / 255.0
224
  in_im = torch.from_numpy(in_im.transpose(2,0,1)).unsqueeze(0)
225
  in_im = in_im.to(DEVICE)
226
+ # *** تم التعديل: استخدام .float() بدلاً من .half() لـ CPU ***
227
+ model = model.float()
228
+ in_im = in_im.float()
229
  with torch.no_grad():
230
  pred = model(in_im,'binarization')
231
  pred = pred[:,:2,:,:]
 
234
  pred = (pred*255).astype(np.uint8)
235
  pred = cv2.resize(pred,(w,h))
236
  out_im = pred[padding_h:,padding_w:]
 
237
  return prompt[:,:,0],prompt[:,:,1],prompt[:,:,2],out_im
238
 
 
 
 
 
239
  def get_args():
240
  parser = argparse.ArgumentParser(description='Params')
241
  parser.add_argument('--model_path', nargs='?', type=str, default='./checkpoints/docres.pkl',help='Path of the saved checkpoint')
242
  parser.add_argument('--dataset', nargs='?', type=str, default='./distorted/',help='Path of input document image')
243
  args = parser.parse_args()
244
+ # يتم تعريف all_datasets لاحقًا في __main__
245
+ # سنحذف assert مؤقتًا أو نعتمد على تعريفها لاحقًا
246
+ # assert args.dataset in all_datasets.keys(), 'Unregisted dataset, dataset must be one of '+', '.join(all_datasets)
247
  return args
248
 
249
  def model_init(args):
250
  # prepare model
251
+ model = restormer_arch.Restormer(
252
+ inp_channels=6,
253
+ out_channels=3,
254
+ dim = 48,
255
+ num_blocks = [2,3,3,4],
256
+ num_refinement_blocks = 4,
257
  heads = [1,2,4,8],
258
  ffn_expansion_factor = 2.66,
259
  bias = False,
260
+ LayerNorm_type = 'WithBias',
261
+ dual_pixel_task = True
262
+ )
263
+
264
+ # تحميل النموذج وتعيينه لـ CPU بشكل إجباري
265
+ state = convert_state_dict(torch.load(args.model_path, map_location='cpu')['model_state'])
 
 
266
  model.load_state_dict(state)
 
267
  model.eval()
268
  model = model.to(DEVICE)
269
  return model
 
286
  cv2.imwrite('./temp.jpg',restorted)
287
  prompt1,prompt2,prompt3,restorted = appearance(model,'./temp.jpg')
288
  os.remove('./temp.jpg')
 
289
  return prompt1,prompt2,prompt3,restorted
290
 
 
 
291
  if __name__ == '__main__':
292
  all_datasets = {'dir300':'dewarping','kligler':'deshadowing','jung':'deshadowing','osr':'deshadowing','docunet_docaligner':'appearance','realdae':'appearance','tdd':'deblurring','dibco18':'binarization'}
293
+
294
+ # تم تعيين DEVICE بالفعل لـ 'cpu' في بداية الملف. نستخدمه هنا.
 
295
  args = get_args()
296
+
297
+ # التأكد من أن مجموعة البيانات المدخلة موجودة
298
+ assert args.dataset in all_datasets.keys(), 'Unregisted dataset, dataset must be one of '+', '.join(all_datasets)
299
+
300
  model = model_init(args)
301
+
302
+ ## inference
303
  print('Predicting')
304
  task = all_datasets[args.dataset]
305
  im_paths = glob.glob(os.path.join('./data/eval/',args.dataset,'*_in.*'))
306
  for im_path in tqdm(im_paths):
307
  _,_,_,restorted = inference_one_im(model,im_path,task)
308
  cv2.imwrite(im_path.replace('_in','_docres'),restorted)
309
+
310
  ## obtain metric
311
  print('Metric calculating')
312
  if task == 'dewarping':
 
317
  for im_path in tqdm(im_paths):
318
  pred = cv2.imread(im_path.replace('_in','_docres'))
319
  gt = cv2.imread(im_path.replace('_in','_gt'))
320
+ # لضمان التوافق في الأشكال قبل حساب المقاييس
321
+ if pred.shape != gt.shape:
322
+ gt = cv2.resize(gt, (pred.shape[1], pred.shape[0]))
323
+
324
  ssim.append(structural_similarity(pred,gt,multichannel=True))
325
  psnr.append(peak_signal_noise_ratio(pred, gt))
326
  print(args.dataset)
327
  print('ssim:',np.mean(ssim))
328
  print('psnr:',np.mean(psnr))
329
+
330
  elif task=='binarization':
331
  fmeasures, pfmeasures,psnrs = [],[],[]
332
  for im_path in tqdm(im_paths):
333
  pred = cv2.imread(im_path.replace('_in','_docres'))
334
  gt = cv2.imread(im_path.replace('_in','_gt'))
335
+
336
+ # لضمان التوافق في الأشكال قبل حساب المقاييس
337
+ if pred.shape != gt.shape:
338
+ gt = cv2.resize(gt, (pred.shape[1], pred.shape[0]))
339
+
340
  pred = cv2.cvtColor(pred,cv2.COLOR_BGR2GRAY)
341
  gt = cv2.cvtColor(gt,cv2.COLOR_BGR2GRAY)
342
+
343
  pred[pred>155]=255
344
  pred[pred<=155]=0
345
  gt[gt>155]=255
346
  gt[gt<=155]=0
347
+
348
  fmeasure, pfmeasure,psnr,_,_,_ = utils.bin_metric(pred,gt)
349
  fmeasures.append(fmeasure)
350
  pfmeasures.append(pfmeasure)
 
353
  print('fmeasure:',np.mean(fmeasures))
354
  print('pfmeasure:',np.mean(pfmeasures))
355
  print('psnr:',np.mean(psnrs))