ToletiSri commited on
Commit
586be7d
·
1 Parent(s): c43984a

Update utils.py

Browse files

Add normalisation before passing image to model during test

Files changed (1) hide show
  1. utils.py +20 -7
utils.py CHANGED
@@ -524,13 +524,25 @@ def plot_couple_examples(model, loader, thresh, iou_thresh, anchors):
524
  plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
525
 
526
  def plot_single_image(model, image, thresh, iou_thresh, anchors):
527
- print('calling to plot single image')
528
  model.eval()
529
- with torch.no_grad():
530
- x = image
531
- print('Train begin')
 
 
 
 
 
 
 
 
 
 
 
 
 
532
  out = model(x)
533
- print('Train end')
534
  bboxes = [[] for _ in range(x.shape[0])]
535
  for i in range(3):
536
  batch_size, A, S, _, _ = out[i].shape
@@ -542,12 +554,13 @@ def plot_single_image(model, image, thresh, iou_thresh, anchors):
542
  bboxes[idx] += box
543
 
544
  model.train()
545
- print('batch size = ',batch_size)
546
  for i in range(batch_size):
 
547
  nms_boxes = non_max_suppression(
548
  bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
549
  )
550
- print('calling to plot bounding boxes image')
551
  fig = plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
552
  return fig
553
 
 
524
  plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
525
 
526
  def plot_single_image(model, image, thresh, iou_thresh, anchors):
 
527
  model.eval()
528
+ with torch.no_grad():
529
+
530
+ import albumentations as A
531
+ from albumentations.pytorch import ToTensorV2
532
+ import config
533
+ #print('----------------')
534
+ #print('type of input image before transform = ',type(image))
535
+ test_transforms = A.Compose(
536
+ [
537
+ A.Normalize(mean=[0, 0, 0], std=[1, 1, 1], max_pixel_value=255,),
538
+ ToTensorV2(),
539
+ ])
540
+
541
+ img = test_transforms(image=image)
542
+ print('type of input image after transform = ',type(img))
543
+ x = img['image'].unsqueeze(0)
544
  out = model(x)
545
+ print('Train end, out size = ',out[0].size())
546
  bboxes = [[] for _ in range(x.shape[0])]
547
  for i in range(3):
548
  batch_size, A, S, _, _ = out[i].shape
 
554
  bboxes[idx] += box
555
 
556
  model.train()
557
+
558
  for i in range(batch_size):
559
+ print('total bboxes input to nms = ', len(bboxes[i][0]))
560
  nms_boxes = non_max_suppression(
561
  bboxes[i], iou_threshold=iou_thresh, threshold=thresh, box_format="midpoint",
562
  )
563
+ print('calling to plot bounding boxes image, nms_boxes = ',nms_boxes)
564
  fig = plot_image(x[i].permute(1,2,0).detach().cpu(), nms_boxes)
565
  return fig
566