mohamed12ahmed commited on
Commit
161f282
·
verified ·
1 Parent(s): e9d7349

Update data/MBD/infer.py

Browse files
Files changed (1) hide show
  1. data/MBD/infer.py +43 -55
data/MBD/infer.py CHANGED
@@ -5,21 +5,21 @@ import torch.nn.functional as F
5
  import glob
6
  import cv2
7
  from tqdm import tqdm
8
-
9
  import time
10
  import os
11
  from model.deep_lab_model.deeplab import *
12
  from MBD import mask_base_dewarper
13
  import time
14
-
15
  from utils import cvimg2torch,torch2cvimg
16
 
 
 
 
17
 
18
 
19
- def net1_net2_infer(model,img_paths,args):
20
-
21
  ### validate on the real datasets
22
- seg_model=model
23
  seg_model.eval()
24
  for img_path in tqdm(img_paths):
25
  if os.path.exists(img_path.replace('_origin','_capture')):
@@ -28,16 +28,18 @@ def net1_net2_infer(model,img_paths,args):
28
  ### segmentation mask predict
29
  img_org = cv2.imread(img_path)
30
  h_org,w_org = img_org.shape[:2]
31
- img = cv2.resize(img_org,(448, 448))
32
  img = cv2.GaussianBlur(img,(15,15),0,0)
33
  img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
34
  img = cvimg2torch(img)
35
-
36
  with torch.no_grad():
37
- pred = seg_model(img.cuda())
 
38
  mask_pred = pred[:,0,:,:].unsqueeze(1)
39
  mask_pred = F.interpolate(mask_pred,(h_org,w_org))
40
- mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
 
41
  mask_pred = (mask_pred*255).astype(np.uint8)
42
  kernel = np.ones((3,3))
43
  mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
@@ -59,40 +61,46 @@ def net1_net2_infer(model,img_paths,args):
59
  # cv2.waitKey(0)
60
  cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
61
  cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
62
-
63
  grid0 = cv2.resize(grid[:,:,0],(128,128))
64
  grid1 = cv2.resize(grid[:,:,1],(128,128))
65
  grid = np.stack((grid0,grid1),axis=-1)
66
  np.save(img_path.replace('_origin','_grid1'),grid)
67
 
68
-
69
  def net1_net2_infer_single_im(img,model_path):
70
  seg_model = DeepLab(num_classes=1,
71
  backbone='resnet',
72
  output_stride=16,
73
  sync_bn=None,
74
  freeze_bn=False)
75
- seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
76
- seg_model.cuda()
77
- checkpoint = torch.load(model_path)
 
 
 
 
 
 
 
78
  seg_model.load_state_dict(checkpoint['model_state'])
 
79
  ### validate on the real datasets
80
  seg_model.eval()
 
81
  ### segmentation mask predict
82
  img_org = img
83
  h_org,w_org = img_org.shape[:2]
84
- img = cv2.resize(img_org,(448, 448))
85
  img = cv2.GaussianBlur(img,(15,15),0,0)
86
  img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
87
  img = cvimg2torch(img)
88
-
89
  with torch.no_grad():
90
- # from torchtoolbox.tools import summary
91
- # print(summary(seg_model,torch.rand((1, 3, 448, 448)).cuda())) 59.4M 135.6G
92
-
93
- pred = seg_model(img.cuda())
94
  mask_pred = pred[:,0,:,:].unsqueeze(1)
95
  mask_pred = F.interpolate(mask_pred,(h_org,w_org))
 
96
  mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
97
  mask_pred = (mask_pred*255).astype(np.uint8)
98
  kernel = np.ones((3,3))
@@ -100,52 +108,32 @@ def net1_net2_infer_single_im(img,model_path):
100
  mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
101
  mask_pred[mask_pred>100] = 255
102
  mask_pred[mask_pred<100] = 0
103
- ### tps transform base on the mask
104
- # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
105
- # try:
106
- # dewarp, grid = mask_base_dewarper(img_org,mask_pred)
107
- # except:
108
- # print('fail')
109
- # grid = np.meshgrid(np.arange(w_org),np.arange(h_org))/np.array([w_org,h_org]).reshape(2,1,1)
110
- # grid = torch.from_numpy((grid-0.5)*2).float().unsqueeze(0).permute(0,2,3,1)
111
- # dewarp = torch2cvimg(F.grid_sample(cvimg2torch(img_org),grid))[0]
112
- # grid = grid[0].numpy()
113
- # cv2.imshow('in',cv2.resize(img_org,(512,512)))
114
- # cv2.imshow('out',cv2.resize(dewarp,(512,512)))
115
- # cv2.waitKey(0)
116
- # cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
117
- # cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
118
-
119
- # grid0 = cv2.resize(grid[:,:,0],(128,128))
120
- # grid1 = cv2.resize(grid[:,:,1],(128,128))
121
- # grid = np.stack((grid0,grid1),axis=-1)
122
- # np.save(img_path.replace('_origin','_grid1'),grid)
123
  return mask_pred
124
 
125
-
126
-
127
  if __name__ == '__main__':
128
  parser = argparse.ArgumentParser(description='Hyperparams')
129
  parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',help='Data path to load data')
130
- parser.add_argument('--img_rows', nargs='?', type=int, default=448,
131
- help='Height of the input image')
132
- parser.add_argument('--img_cols', nargs='?', type=int, default=448,
133
- help='Width of the input image')
134
- parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl',
135
- help='Path to previous saved model to restart from')
136
  args = parser.parse_args()
137
-
138
  seg_model = DeepLab(num_classes=1,
139
  backbone='resnet',
140
  output_stride=16,
141
  sync_bn=None,
142
  freeze_bn=False)
143
- seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
144
- seg_model.cuda()
145
- checkpoint = torch.load(args.seg_model_path)
 
 
 
 
 
 
146
  seg_model.load_state_dict(checkpoint['model_state'])
147
-
148
  im_paths = glob.glob(os.path.join(args.img_folder,'*_origin.*'))
149
-
150
  net1_net2_infer(seg_model,im_paths,args)
151
-
 
5
  import glob
6
  import cv2
7
  from tqdm import tqdm
 
8
  import time
9
  import os
10
  from model.deep_lab_model.deeplab import *
11
  from MBD import mask_base_dewarper
12
  import time
 
13
  from utils import cvimg2torch,torch2cvimg
14
 
15
+ # 1. تحديد الجهاز ليكون CPU بشكل إجباري
16
+ device = torch.device('cpu')
17
+ print(f"PyTorch running on device: {device}")
18
 
19
 
20
+ def net1_net2_infer(model, img_paths, args):
 
21
  ### validate on the real datasets
22
+ seg_model = model
23
  seg_model.eval()
24
  for img_path in tqdm(img_paths):
25
  if os.path.exists(img_path.replace('_origin','_capture')):
 
28
  ### segmentation mask predict
29
  img_org = cv2.imread(img_path)
30
  h_org,w_org = img_org.shape[:2]
31
+ img = cv2.resize(img_org,(448, 448))
32
  img = cv2.GaussianBlur(img,(15,15),0,0)
33
  img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
34
  img = cvimg2torch(img)
35
+
36
  with torch.no_grad():
37
+ # التعديل رقم 1: نقل المدخلات إلى الجهاز المحدد (CPU)
38
+ pred = seg_model(img.to(device))
39
  mask_pred = pred[:,0,:,:].unsqueeze(1)
40
  mask_pred = F.interpolate(mask_pred,(h_org,w_org))
41
+ # نقل الناتج إلى CPU قبل تحويله إلى NumPy
42
+ mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
43
  mask_pred = (mask_pred*255).astype(np.uint8)
44
  kernel = np.ones((3,3))
45
  mask_pred = cv2.dilate(mask_pred,kernel,iterations=3)
 
61
  # cv2.waitKey(0)
62
  cv2.imwrite(img_path.replace('_origin','_capture'),dewarp)
63
  cv2.imwrite(img_path.replace('_origin','_mask_new'),mask_pred)
 
64
  grid0 = cv2.resize(grid[:,:,0],(128,128))
65
  grid1 = cv2.resize(grid[:,:,1],(128,128))
66
  grid = np.stack((grid0,grid1),axis=-1)
67
  np.save(img_path.replace('_origin','_grid1'),grid)
68
 
 
69
  def net1_net2_infer_single_im(img,model_path):
70
  seg_model = DeepLab(num_classes=1,
71
  backbone='resnet',
72
  output_stride=16,
73
  sync_bn=None,
74
  freeze_bn=False)
75
+
76
+ # التعديل رقم 2: إزالة DataParallel لأنها مصممة للـ GPU
77
+ # واستبدالها بتحميل النموذج مباشرة
78
+ # seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
79
+
80
+ # التعديل رقم 3: نقل النموذج إلى الجهاز المحدد (CPU) بدلاً من .cuda()
81
+ seg_model.to(device)
82
+
83
+ # تحميل النموذج باستخدام map_location للتأكد من تحميله على CPU
84
+ checkpoint = torch.load(model_path, map_location=device)
85
  seg_model.load_state_dict(checkpoint['model_state'])
86
+
87
  ### validate on the real datasets
88
  seg_model.eval()
89
+
90
  ### segmentation mask predict
91
  img_org = img
92
  h_org,w_org = img_org.shape[:2]
93
+ img = cv2.resize(img_org,(448, 448))
94
  img = cv2.GaussianBlur(img,(15,15),0,0)
95
  img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
96
  img = cvimg2torch(img)
97
+
98
  with torch.no_grad():
99
+ # التعديل رقم 4: نقل المدخلات إلى الجهاز المحدد (CPU)
100
+ pred = seg_model(img.to(device))
 
 
101
  mask_pred = pred[:,0,:,:].unsqueeze(1)
102
  mask_pred = F.interpolate(mask_pred,(h_org,w_org))
103
+ # نقل الناتج إلى CPU قبل تحويله إلى NumPy
104
  mask_pred = mask_pred.squeeze(0).squeeze(0).cpu().numpy()
105
  mask_pred = (mask_pred*255).astype(np.uint8)
106
  kernel = np.ones((3,3))
 
108
  mask_pred = cv2.erode(mask_pred,kernel,iterations=3)
109
  mask_pred[mask_pred>100] = 255
110
  mask_pred[mask_pred<100] = 0
111
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
  return mask_pred
113
 
 
 
114
  if __name__ == '__main__':
115
  parser = argparse.ArgumentParser(description='Hyperparams')
116
  parser.add_argument('--img_folder', nargs='?', type=str, default='./all_data',help='Data path to load data')
117
+ parser.add_argument('--img_rows', nargs='?', type=int, default=448, help='Height of the input image')
118
+ parser.add_argument('--img_cols', nargs='?', type=int, default=448, help='Width of the input image')
119
+ parser.add_argument('--seg_model_path', nargs='?', type=str, default='checkpoints/mbd.pkl', help='Path to previous saved model to restart from')
 
 
 
120
  args = parser.parse_args()
121
+
122
  seg_model = DeepLab(num_classes=1,
123
  backbone='resnet',
124
  output_stride=16,
125
  sync_bn=None,
126
  freeze_bn=False)
127
+
128
+ # التعديل رقم 5: إزالة DataParallel
129
+ # seg_model = torch.nn.DataParallel(seg_model, device_ids=range(torch.cuda.device_count()))
130
+
131
+ # التعديل رقم 6: نقل النموذج إلى الجهاز المحدد (CPU) بدلاً من .cuda()
132
+ seg_model.to(device)
133
+
134
+ # تحميل النموذج باستخدام map_location للتأكد من تحميله على CPU
135
+ checkpoint = torch.load(args.seg_model_path, map_location=device)
136
  seg_model.load_state_dict(checkpoint['model_state'])
137
+
138
  im_paths = glob.glob(os.path.join(args.img_folder,'*_origin.*'))
 
139
  net1_net2_infer(seg_model,im_paths,args)