Darknsu committed on
Commit
70bb740
·
verified ·
1 Parent(s): 72ada7f

Update dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +123 -154
dataset.py CHANGED
@@ -97,7 +97,7 @@ class VideoDataSet(data.Dataset):
97
  self.feature_rgb_file = {}
98
  self.feature_flow_file = {}
99
  for file in self.video_list:
100
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.npz')
101
  if not os.path.exists(feature_path):
102
  raise ValueError(f"Feature file {feature_path} not found")
103
  feature_All[file] = np.load(feature_path)['feats']
@@ -110,7 +110,7 @@ class VideoDataSet(data.Dataset):
110
  self.feature_rgb_file = {}
111
  self.feature_flow_file = {}
112
  for file in self.video_list:
113
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.npz')
114
  if not os.path.exists(feature_path):
115
  raise ValueError(f"Feature file {feature_path} not found")
116
  feature_All[file] = np.load(feature_path)
@@ -123,7 +123,7 @@ class VideoDataSet(data.Dataset):
123
  self.feature_rgb_file = {}
124
  self.feature_flow_file = {}
125
  for file in self.video_list:
126
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.pt')
127
  if not os.path.exists(feature_path):
128
  raise ValueError(f"Feature file {feature_path} not found")
129
  feature_All[file] = torch.load(feature_path)
@@ -164,7 +164,7 @@ class VideoDataSet(data.Dataset):
164
  self.feature_rgb_file = {}
165
  self.feature_flow_file = {}
166
  for file in self.video_list:
167
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.npz')
168
  if not os.path.exists(feature_path):
169
  raise ValueError(f"Feature file {feature_path} not found")
170
  feature_All[file] = np.load(feature_path)['feats']
@@ -177,7 +177,7 @@ class VideoDataSet(data.Dataset):
177
  self.feature_rgb_file = {}
178
  self.feature_flow_file = {}
179
  for file in self.video_list:
180
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.npz')
181
  if not os.path.exists(feature_path):
182
  raise ValueError(f"Feature file {feature_path} not found")
183
  feature_All[file] = np.load(feature_path)
@@ -190,7 +190,7 @@ class VideoDataSet(data.Dataset):
190
  self.feature_rgb_file = {}
191
  self.feature_flow_file = {}
192
  for file in self.video_list:
193
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.pt')
194
  if not os.path.exists(feature_path):
195
  raise ValueError(f"Feature file {feature_path} not found")
196
  feature_All[file] = torch.load(feature_path)
@@ -213,27 +213,15 @@ class VideoDataSet(data.Dataset):
213
  elif opt['data_format'] == "npz":
214
  feature_file = {}
215
  for file in self.video_list:
216
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.npz')
217
- if os.path.exists(feature_path):
218
- feature_file[file] = np.load(feature_path)['feats']
219
- else:
220
- print(f"Warning: Feature file {feature_path} not found for length calculation")
221
  elif opt['data_format'] == "npz_i3d":
222
  feature_file = {}
223
  for file in self.video_list:
224
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.npz')
225
- if os.path.exists(feature_path):
226
- feature_file[file] = np.load(feature_path)
227
- else:
228
- print(f"Warning: Feature file {feature_path} not found for length calculation")
229
  elif opt['data_format'] == "pt":
230
  feature_file = {}
231
  for file in self.video_list:
232
- feature_path = os.path.join(opt["video_feature_all_train"], file + '.pt')
233
- if os.path.exists(feature_path):
234
- feature_file[file] = torch.load(feature_path)
235
- else:
236
- print(f"Warning: Feature file {feature_path} not found for length calculation")
237
  else:
238
  if opt['data_format'] == "h5":
239
  feature_file = h5py.File(opt["video_feature_rgb_test"], 'r')
@@ -242,54 +230,35 @@ class VideoDataSet(data.Dataset):
242
  elif opt['data_format'] == "npz":
243
  feature_file = {}
244
  for file in self.video_list:
245
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.npz')
246
- if os.path.exists(feature_path):
247
- feature_file[file] = np.load(feature_path)['feats']
248
- else:
249
- print(f"Warning: Feature file {feature_path} not found for length calculation")
250
  elif opt['data_format'] == "npz_i3d":
251
  feature_file = {}
252
  for file in self.video_list:
253
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.npz')
254
- if os.path.exists(feature_path):
255
- feature_file[file] = np.load(feature_path)
256
- else:
257
- print(f"Warning: Feature file {feature_path} not found for length calculation")
258
  elif opt['data_format'] == "pt":
259
  feature_file = {}
260
  for file in self.video_list:
261
- feature_path = os.path.join(opt['video_feature_all_test'], file + '.pt')
262
- if os.path.exists(feature_path):
263
- feature_file[file] = torch.load(feature_path)
264
- else:
265
- print(f"Warning: Feature file {feature_path} not found for length calculation")
266
 
267
  keys = self.video_list
268
  if opt['data_format'] == "h5":
269
  for vidx in range(len(keys)):
270
- if keys[vidx] in feature_file:
271
- self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
272
  elif opt['data_format'] == "pickle":
273
  for vidx in range(len(keys)):
274
- if keys[vidx] in feature_file:
275
- self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
276
  elif opt['data_format'] == "npz":
277
  for vidx in range(len(keys)):
278
- if keys[vidx] in feature_file:
279
- self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
280
  elif opt['data_format'] == "npz_i3d":
281
  for vidx in range(len(keys)):
282
- if keys[vidx] in feature_file:
283
- self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
284
  elif opt['data_format'] == "pt":
285
  for vidx in range(len(keys)):
286
- if keys[vidx] in feature_file:
287
- self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
288
-
289
- if self.video_len: # Only save if we have any lengths
290
- outfile = open(self.video_len_path, "w")
291
- json.dump(self.video_len, outfile, indent=2)
292
- outfile.close()
293
 
294
  def _getDatasetDict(self):
295
  anno_database = load_json(self.video_anno_path)
@@ -368,29 +337,29 @@ class VideoDataSet(data.Dataset):
368
  video_name = self.video_list[index]
369
  duration = self.match_score[video_name].shape[0]
370
  for i in range(1, duration + 1):
371
- st = i - self._segment_size
372
  ed = i
373
- self._inputs_all.append([video_name, st, ed, data_idx])
374
  data_idx += 1
375
 
376
- self._inputs = self._inputs_all.copy()
377
- print(f"{self._subset} subset seg numbers: {len(self._inputs)}")
378
 
379
  def _makePropLabelUnit(self, i):
380
- video_name = self._inputs_all[i][0]
381
- st = self._inputs_all[i][1]
382
- ed = self._inputs_all[i][2]
383
  cls_anc = []
384
  reg_anc = []
385
 
386
- for j in range(0, len(self._anchors)):
387
- v1 = np.zeros(self._num_of_class)
388
  v1[-1] = 1
389
  v2 = np.zeros(2)
390
  v2[-1] = -1e3
391
- y_box = [ed - 1, self._anchors[j]]
392
 
393
- subset_label = self._get_train_subset_label(video_name, ed - self._anchors[j], ed)
394
  idx_list = []
395
  for ii in range(0, subset_label.shape[0]):
396
  for jj in range(0, subset_label.shape[1]):
@@ -399,23 +368,23 @@ class VideoDataSet(data.Dataset):
399
  idx_list.append(idx - 1)
400
 
401
  for idx in idx_list:
402
- target_box_idx = self._gt_action_list[video_name][idx]
403
- cls = int(target_box_idx[2])
404
- iou = calc_iou(y_box_idx, target_box)
405
- if iou >= self._pos_threshold or (j == len(self._anchors) - 1 and box_include_idx(y_box, target_box)) or (j == 0 and box_include_idx(target_box, y_box)):
406
  v1[cls] = 1
407
  v1[-1] = 0
408
- v2[0] = 1.0 * (target_box[0] - y_box[0]) / self._anchors[j]
409
  v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])
410
 
411
  cls_anc.append(v1)
412
  reg_anc.append(v2)
413
 
414
- v0 = np.zeros(self._num_of_class)
415
  v0[-1] = 1
416
  segment_size = ed - st
417
- y_box = [ed - 1, self._anchors[-1]]
418
- subset_label = self._get_subset_label(video_name, ed - self._anchors[-1], ed)
419
  idx_list = []
420
  for ii in range(0, subset_label.shape[0]):
421
  for jj in range(0, subset_label.shape[1]):
@@ -424,141 +393,141 @@ class VideoDataSet(data.Dataset):
424
  idx_list.append(idx - 1)
425
 
426
  for idx in idx_list:
427
- target_box = self._gt_action[video_name][idx]
428
  cls = int(target_box[2])
429
  iou = calc_iou(y_box, target_box)
430
  if iou >= 0:
431
  v0[cls] = 1
432
  v0[-1] = 0
433
 
434
- cls_anc = np.stack(cls._anc, idx=0)
435
- reg_anc = np.stack(reg._anc, idx=0)
436
  cls_snip = np.array(v0)
437
  return cls_anc, reg_anc, cls_snip
438
 
439
  def _loadPropLabel(self, filename):
440
  if os.path.exists(filename):
441
  prop_label_file = h5py.File(filename, 'r')
442
- self._cls_label = np.array(prop_label_file['cls_label'][:])
443
- self._reg_label = np.array(prop_label_file['reg_label'][:])
444
- self._snip_label = np.array(prop_label_file['snip_label'][:])
445
  prop_label_file.close()
446
- self._action_frame_count = np.sum(self._cls_label.reshape((-1, self._cls_label.shape[-1])), idx=0)
447
- self._action_frame_count = torch.Tensor(self._action_frame_count)
448
  return
449
 
450
  pool = Pool(os.cpu_count() // 2)
451
- labels = pool.map(self._makePropLabelUnit, range(0, len(self._inputs_all)))
452
  pool.close()
453
- pool pool.join()
454
 
455
  cls_label = []
456
  reg_label = []
457
  snip_label = []
458
  for i in range(0, len(labels)):
459
- cls_label[i].append(labels[i][0])
460
  reg_label.append(labels[i][1])
461
  snip_label.append(labels[i][2])
462
- self._cls_label = np.stack(labels_cls, idx=0)
463
- self._reg_label = np.stack(labels_reg, idx=0)
464
- self._snip_label = np.stack(labels_snip, idx=0)
465
 
466
  outfile = h5py.File(filename, 'w')
467
- dset_cls = outfile._create_dataset('/cls_label', self._cls_label.shape, shape=self._cls._label_shape, chunks=True, type=np.float32)
468
- dset_cls[_._ :] = self._cls._label[_._ :]
469
- dset_reg_label = outfile._create_dataset('/label_reg', self._reg._label.shape, shape=self._reg._label.shape, chunks=True, type=np.float32)
470
- dset_reg[_._ :] = self._reg._reg_label[_._ :]
471
- dset_snip_label = outfile._create_dataset('/snip_label', self._snip._label.shape, shape=self._snip._label.shape, chunks=True, type=np.float32)
472
- dset_snip[_._ :] = self._snip._snip_label[_._ :]
473
- outfile._close()
474
 
475
  return
476
 
477
- def _getitem_item(self, idx):
478
- video_name, st, ed, d_idx_data = self._inputs[idx]
479
  if st >= 0:
480
- feature_data = self._get_base_data(video_name, st, ed)
481
  else:
482
- feature_data = self._get_base_data(video_name, idx=0, st, ed)
483
- pad_func = torch.nn.ConstantPad2d(st, (0, 0, -st, 0), idx=0)
484
- data_feature = pad_func(data_feature)
485
 
486
- cls_label_data = torch.Tensor(self._cls_label[d_idx_data])
487
- reg_label_data = torch.Tensor(self._reg_label[d_idx_data])
488
- snip_label_data = torch.Tensor(self._snip_label[d_idx_data])
489
 
490
- return data_feature, cls_label_data, reg_label_data, snip_label_data
491
 
492
  def _get_base_data(self, video_name, st, ed):
493
- feature_rgb_data = self._feature_rgb_file[video_name]
494
- feature_rgb_data = feature_rgb_data[st:ed, :]
495
 
496
- if self._feature_flow_file is not None:
497
- feature_flow_data = self._feature_flow_file[video_name]
498
- feature_flow_data = feature_flow_data[st:ed, :]
499
- data_feature = np.append(feature_data_rgb, feature_flow_data, idx=1)
500
  else:
501
- data_feature = feature_rgb_data
502
- data_feature = torch.from_numpy(np.array(data_feature))
503
 
504
- return data_feature
505
 
506
- def _get_train_label_with_class(self, video_name, st, idx_ed):
507
- duration_data = len(self._match_score_data[video_name])
508
- st_padding_data = pad_0
509
- ed_padding_data = pad_0
510
  if st < 0:
511
- st_padding_data = -st
512
- st = pad_0
513
- if idx_ed > duration_data:
514
- ed_padding_data = idx_ed - duration_data
515
- idx_ed = duration_data
516
 
517
- match_score_data = torch.Tensor(self._match_score_data[video_name][st:idx_ed])
518
- if st_padding_data > pad_0:
519
- pad_func_2d = torch.nn.ConstantPad(data_2d, (pad_0, pad_0, st_padding_data, pad_0), idx=0)
520
- data_match_score = pad_func_2d(data_match_score)
521
- if ed_padding_data > pad_0:
522
- pad_func_2d = torch.nn(data_ConstantPad2d, (pad_0, pad_0, pad_0, ed_padding_data), idx=pad_0)
523
- pad_func_2d = pad(data_func_2d(data_match_score))
524
- return data_match_score
525
 
526
- def _len__(self):
527
- return len(self._inputs)
528
 
529
- def _reset_sample(self):
530
- self._inputs = self._inputs_all.copy()
531
 
532
- def _select_sample(self, idx):
533
- inputs_data = [self._inputs_all[i] for i in idx]
534
- self._inputs = inputs_data.copy()
535
  return
536
 
537
  class SuppressDataSet(data.Dataset):
538
  def __init__(self, opt, subset="train"):
539
- self._subset = subset
540
- self._mode = opt["mode"]
541
- self._data_file = h5py.File(opt["suppress_label_file"].format(self._subset + "_" + opt['setup']), 'r')
542
- self._video_list = list(self._data_file.keys())
543
- self._inputs = []
544
- for idx in range(0, len(self._video_list)):
545
- video_name = self._video_list[idx]
546
- duration_data = self._data_file[video_name + '/input_seq'].shape[0]
547
- for i in range(0, duration_data):
548
- self._inputs.append([video_name, i])
549
 
550
- print(f"{self._subset} subset seg numbers: {len(self._inputs)}")
551
 
552
- def _getitem__(self, idx):
553
- video_name, idx = self._inputs[idx]
554
 
555
- input_seq_data = self._data_file[video_name + '/input_seq'][idx]
556
- label_data = self._data_file[video_name + '/label_data'][idx]
557
 
558
- input_seq_data = torch.from_numpy(input_seq_data)
559
- label_data = torch.from_numpy(label_data)
560
 
561
- return input_seq_data, label_data
562
 
563
- def _len__(self):
564
- return len(self._inputs)
 
97
  self.feature_rgb_file = {}
98
  self.feature_flow_file = {}
99
  for file in self.video_list:
100
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
101
  if not os.path.exists(feature_path):
102
  raise ValueError(f"Feature file {feature_path} not found")
103
  feature_All[file] = np.load(feature_path)['feats']
 
110
  self.feature_rgb_file = {}
111
  self.feature_flow_file = {}
112
  for file in self.video_list:
113
+ feature_path = opt["video_feature_all_train"] + file + '.npz'
114
  if not os.path.exists(feature_path):
115
  raise ValueError(f"Feature file {feature_path} not found")
116
  feature_All[file] = np.load(feature_path)
 
123
  self.feature_rgb_file = {}
124
  self.feature_flow_file = {}
125
  for file in self.video_list:
126
+ feature_path = opt["video_feature_all_train"] + file + '.pt'
127
  if not os.path.exists(feature_path):
128
  raise ValueError(f"Feature file {feature_path} not found")
129
  feature_All[file] = torch.load(feature_path)
 
164
  self.feature_rgb_file = {}
165
  self.feature_flow_file = {}
166
  for file in self.video_list:
167
+ feature_path = os.path.join(opt['video_feature_all_test'], video_name + '.npz')
168
  if not os.path.exists(feature_path):
169
  raise ValueError(f"Feature file {feature_path} not found")
170
  feature_All[file] = np.load(feature_path)['feats']
 
177
  self.feature_rgb_file = {}
178
  self.feature_flow_file = {}
179
  for file in self.video_list:
180
+ feature_path = os.path.join(opt['video_feature_all_test'], video_name + '.npz')
181
  if not os.path.exists(feature_path):
182
  raise ValueError(f"Feature file {feature_path} not found")
183
  feature_All[file] = np.load(feature_path)
 
190
  self.feature_rgb_file = {}
191
  self.feature_flow_file = {}
192
  for file in self.video_list:
193
+ feature_path = opt["video_feature_all_test"] + file + '.pt'
194
  if not os.path.exists(feature_path):
195
  raise ValueError(f"Feature file {feature_path} not found")
196
  feature_All[file] = torch.load(feature_path)
 
213
  elif opt['data_format'] == "npz":
214
  feature_file = {}
215
  for file in self.video_list:
216
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')['feats']
 
 
 
 
217
  elif opt['data_format'] == "npz_i3d":
218
  feature_file = {}
219
  for file in self.video_list:
220
+ feature_file[file] = np.load(opt["video_feature_all_train"] + file + '.npz')
 
 
 
 
221
  elif opt['data_format'] == "pt":
222
  feature_file = {}
223
  for file in self.video_list:
224
+ feature_file[file] = torch.load(opt["video_feature_all_train"] + file + '.pt')
 
 
 
 
225
  else:
226
  if opt['data_format'] == "h5":
227
  feature_file = h5py.File(opt["video_feature_rgb_test"], 'r')
 
230
  elif opt['data_format'] == "npz":
231
  feature_file = {}
232
  for file in self.video_list:
233
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')['feats']
 
 
 
 
234
  elif opt['data_format'] == "npz_i3d":
235
  feature_file = {}
236
  for file in self.video_list:
237
+ feature_file[file] = np.load(opt["video_feature_all_test"] + file + '.npz')
 
 
 
 
238
  elif opt['data_format'] == "pt":
239
  feature_file = {}
240
  for file in self.video_list:
241
+ feature_file[file] = torch.load(opt["video_feature_all_test"] + file + '.pt')
 
 
 
 
242
 
243
  keys = self.video_list
244
  if opt['data_format'] == "h5":
245
  for vidx in range(len(keys)):
246
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
 
247
  elif opt['data_format'] == "pickle":
248
  for vidx in range(len(keys)):
249
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
 
250
  elif opt['data_format'] == "npz":
251
  for vidx in range(len(keys)):
252
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
 
253
  elif opt['data_format'] == "npz_i3d":
254
  for vidx in range(len(keys)):
255
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]]['rgb'])
 
256
  elif opt['data_format'] == "pt":
257
  for vidx in range(len(keys)):
258
+ self.video_len[keys[vidx]] = len(feature_file[keys[vidx]])
259
+ outfile = open(self.video_len_path, "w")
260
+ json.dump(self.video_len, outfile, indent=2)
261
+ outfile.close()
 
 
 
262
 
263
  def _getDatasetDict(self):
264
  anno_database = load_json(self.video_anno_path)
 
337
  video_name = self.video_list[index]
338
  duration = self.match_score[video_name].shape[0]
339
  for i in range(1, duration + 1):
340
+ st = i - self.segment_size
341
  ed = i
342
+ self.inputs_all.append([video_name, st, ed, data_idx])
343
  data_idx += 1
344
 
345
+ self.inputs = self.inputs_all.copy()
346
+ print(f"{self.subset} subset seg numbers: {len(self.inputs)}")
347
 
348
  def _makePropLabelUnit(self, i):
349
+ video_name = self.inputs_all[i][0]
350
+ st = self.inputs_all[i][1]
351
+ ed = self.inputs_all[i][2]
352
  cls_anc = []
353
  reg_anc = []
354
 
355
+ for j in range(0, len(self.anchors)):
356
+ v1 = np.zeros(self.num_of_class)
357
  v1[-1] = 1
358
  v2 = np.zeros(2)
359
  v2[-1] = -1e3
360
+ y_box = [ed - 1, self.anchors[j]]
361
 
362
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[j], ed)
363
  idx_list = []
364
  for ii in range(0, subset_label.shape[0]):
365
  for jj in range(0, subset_label.shape[1]):
 
368
  idx_list.append(idx - 1)
369
 
370
  for idx in idx_list:
371
+ target_box = self.gt_action[video_name][idx]
372
+ cls = int(target_box[2])
373
+ iou = calc_iou(y_box, target_box)
374
+ if iou >= self.pos_threshold or (j == len(self.anchors) - 1 and box_include(y_box, target_box)) or (j == 0 and box_include(target_box, y_box)):
375
  v1[cls] = 1
376
  v1[-1] = 0
377
+ v2[0] = 1.0 * (target_box[0] - y_box[0]) / self.anchors[j]
378
  v2[1] = np.log(1.0 * max(1, target_box[1]) / y_box[1])
379
 
380
  cls_anc.append(v1)
381
  reg_anc.append(v2)
382
 
383
+ v0 = np.zeros(self.num_of_class)
384
  v0[-1] = 1
385
  segment_size = ed - st
386
+ y_box = [ed - 1, self.anchors[-1]]
387
+ subset_label = self._get_train_label_with_class(video_name, ed - self.anchors[-1], ed)
388
  idx_list = []
389
  for ii in range(0, subset_label.shape[0]):
390
  for jj in range(0, subset_label.shape[1]):
 
393
  idx_list.append(idx - 1)
394
 
395
  for idx in idx_list:
396
+ target_box = self.gt_action[video_name][idx]
397
  cls = int(target_box[2])
398
  iou = calc_iou(y_box, target_box)
399
  if iou >= 0:
400
  v0[cls] = 1
401
  v0[-1] = 0
402
 
403
+ cls_anc = np.stack(cls_anc, axis=0)
404
+ reg_anc = np.stack(reg_anc, axis=0)
405
  cls_snip = np.array(v0)
406
  return cls_anc, reg_anc, cls_snip
407
 
def _loadPropLabel(self, filename):
    """Populate the proposal labels, using ``filename`` as an HDF5 cache.

    If ``filename`` exists, ``cls_label``, ``reg_label`` and ``snip_label``
    are read from it.  Otherwise they are computed by mapping
    ``self._makePropLabelUnit`` over every index of ``self.inputs_all`` in a
    process pool, stacked along a new leading axis, and written to
    ``filename`` as float32 datasets.

    In both cases ``self.action_frame_count`` is set to the per-class sum of
    ``cls_label`` (all leading axes flattened) as a ``torch.Tensor``.

    Args:
        filename: Path of the HDF5 label cache to read or create.
    """
    if os.path.exists(filename):
        # Context manager closes the file even if a dataset is missing.
        with h5py.File(filename, 'r') as prop_label_file:
            self.cls_label = np.array(prop_label_file['cls_label'][:])
            self.reg_label = np.array(prop_label_file['reg_label'][:])
            self.snip_label = np.array(prop_label_file['snip_label'][:])
    else:
        # os.cpu_count() can return None, and halving 1 gives 0; Pool needs
        # at least one worker process.
        num_workers = max(1, (os.cpu_count() or 1) // 2)
        pool = Pool(num_workers)
        try:
            labels = pool.map(self._makePropLabelUnit, range(len(self.inputs_all)))
        finally:
            pool.close()
            pool.join()

        # Each unit is a (cls_anc, reg_anc, cls_snip) triple.
        self.cls_label = np.stack([unit[0] for unit in labels], axis=0)
        self.reg_label = np.stack([unit[1] for unit in labels], axis=0)
        self.snip_label = np.stack([unit[2] for unit in labels], axis=0)

        with h5py.File(filename, 'w') as outfile:
            for dset_name, array in (
                ('/cls_label', self.cls_label),
                ('/reg_label', self.reg_label),
                ('/snip_label', self.snip_label),
            ):
                dset = outfile.create_dataset(
                    dset_name,
                    array.shape,
                    maxshape=array.shape,
                    chunks=True,
                    dtype=np.float32,
                )
                dset[...] = array

    # Computed on both paths so the attribute also exists on the first run,
    # when the cache has just been built.
    class_counts = np.sum(
        self.cls_label.reshape((-1, self.cls_label.shape[-1])), axis=0
    )
    self.action_frame_count = torch.Tensor(class_counts)
    return
445
 
def __getitem__(self, index):
    """Return ``(feature, cls_label, reg_label, snip_label)`` for one sample.

    ``self.inputs[index]`` supplies ``(video_name, st, ed, data_idx)``.
    Features are fetched through ``self._get_base_data``; when ``st`` is
    negative the fetch starts at 0 and ``-st`` zero rows are prepended.
    The three labels are looked up by ``data_idx`` and wrapped in
    ``torch.Tensor``.
    """
    video_name, st, ed, data_idx = self.inputs[index]

    # A negative start means the window begins before the video does.
    missing_rows = -st if st < 0 else 0
    feature = self._get_base_data(video_name, max(st, 0), ed)
    if missing_rows:
        # Same constant padding ConstantPad2d((0, 0, -st, 0), 0) applies.
        feature = torch.nn.functional.pad(
            feature, (0, 0, missing_rows, 0), 'constant', 0
        )

    labels = tuple(
        torch.Tensor(table[data_idx])
        for table in (self.cls_label, self.reg_label, self.snip_label)
    )
    return (feature, *labels)
460
 
def _get_base_data(self, video_name, st, ed):
    """Return rows ``st:ed`` of a video's features as a torch tensor.

    RGB features come from ``self.feature_rgb_file[video_name]``.  When
    ``self.feature_flow_file`` is not None, the same rows of the flow
    features are joined to them along axis 1.
    """
    window = slice(st, ed)
    rgb_rows = self.feature_rgb_file[video_name][window, :]

    if self.feature_flow_file is None:
        combined = rgb_rows
    else:
        flow_rows = self.feature_flow_file[video_name][window, :]
        combined = np.concatenate((rgb_rows, flow_rows), axis=1)

    # np.array makes a fresh array before handing it to torch.
    return torch.from_numpy(np.array(combined))
474
 
def _get_train_label_with_class(self, video_name, st, ed):
    """Return ``self.match_score[video_name][st:ed]`` as a zero-padded tensor.

    Where the requested window starts before row 0 or ends past the last
    row, the out-of-range part is filled with zero rows so the result still
    lines up with ``[st, ed)``.

    NOTE(review): the 4-value padding assumes each match_score entry is at
    least 2-D (rows x columns) -- confirm against the loader.
    """
    scores = self.match_score[video_name]
    duration = len(scores)

    # Rows missing before the start / after the end of the video.
    st_padding = max(-st, 0)
    ed_padding = max(ed - duration, 0)
    first = max(st, 0)
    last = min(ed, duration)

    match_score = torch.Tensor(scores[first:last])
    if st_padding > 0:
        match_score = torch.nn.functional.pad(
            match_score, (0, 0, st_padding, 0), 'constant', 0
        )
    if ed_padding > 0:
        match_score = torch.nn.functional.pad(
            match_score, (0, 0, 0, ed_padding), 'constant', 0
        )
    return match_score
494
 
def __len__(self):
    """Return the number of entries currently held in ``self.inputs``."""
    active_samples = self.inputs
    return len(active_samples)
497
 
def reset_sample(self):
    """Point ``self.inputs`` at a fresh shallow copy of ``self.inputs_all``."""
    self.inputs = list(self.inputs_all)
500
 
def select_sample(self, idx):
    """Restrict ``self.inputs`` to the ``self.inputs_all`` entries named by ``idx``.

    ``idx`` is iterated in order; each element is used as an index into
    ``self.inputs_all``.
    """
    chosen = []
    for position in idx:
        chosen.append(self.inputs_all[position])
    self.inputs = chosen
    return
505
 
class SuppressDataSet(data.Dataset):
    """Dataset over per-step ``input`` / ``label`` rows stored in an HDF5 file.

    The file path is ``opt["suppress_label_file"]`` formatted with
    ``"<subset>_<opt['setup']>"``.  Each top-level key of the file is treated
    as a video name, and one sample is registered per row of that video's
    ``input`` dataset.
    """

    def __init__(self, opt, subset="train"):
        """Open the label file and index every (video, step) pair."""
        self.subset = subset
        self.mode = opt["mode"]
        label_path = opt["suppress_label_file"].format(self.subset + "_" + opt['setup'])
        self.data_file = h5py.File(label_path, 'r')
        self.video_list = list(self.data_file.keys())

        self.inputs = []
        for video_name in self.video_list:
            num_steps = self.data_file[video_name + '/input'].shape[0]
            self.inputs.extend([video_name, step] for step in range(num_steps))

        print(f"{self.subset} subset seg numbers: {len(self.inputs)}")

    def __getitem__(self, index):
        """Return ``(input_seq, label)`` tensors for sample ``index``."""
        video_name, idx = self.inputs[index]
        store = self.data_file

        raw_input = store[video_name + '/input'][idx]
        raw_label = store[video_name + '/label'][idx]

        return torch.from_numpy(raw_input), torch.from_numpy(raw_label)

    def __len__(self):
        """Return the number of indexed (video, step) samples."""
        return len(self.inputs)