IvanBanny commited on
Commit
db39442
·
1 Parent(s): 9c4195d

minor changes

Browse files
Files changed (1) hide show
  1. train_dist.py +3 -5
train_dist.py CHANGED
@@ -255,9 +255,9 @@ def train_worker(rank, world_size, args):
255
  dampening=0, weight_decay=1e-4, nesterov=True)
256
  criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
257
 
258
- if args.checkpoint:
259
  map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
260
- checkpoint = torch.load(args.checkpoint, map_location=map_location)
261
  model.module.load_state_dict(checkpoint['model_state_dict'])
262
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
263
 
@@ -359,8 +359,6 @@ def train_worker(rank, world_size, args):
359
 
360
  miniplaces_test = MiniPlaces(data_root, split='test', transform=val_transform)
361
  test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
362
- checkpoint = torch.load(args.checkpoint, map_location=device)
363
- model.module.load_state_dict(checkpoint['model_state_dict'])
364
 
365
  preds = test(model, test_loader, device)
366
  if rank == 0: # Only write predictions on rank 0
@@ -435,7 +433,7 @@ if __name__ == "__main__":
435
  parser = argparse.ArgumentParser()
436
  parser.add_argument('--test', action='store_true')
437
  parser.add_argument('--checkpoint')
438
- parser.add_argument('--epochs', type=int, default=100)
439
  parser.add_argument('--batch_size', type=int, default=64)
440
  parser.add_argument('--port', type=int, default=4224)
441
  args = parser.parse_args()
 
255
  dampening=0, weight_decay=1e-4, nesterov=True)
256
  criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
257
 
258
+ if args.checkpoint or args.test:
259
  map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
260
+ checkpoint = torch.load((args.checkpoint if args.checkpoint else 'model.ckpt'), map_location=map_location)
261
  model.module.load_state_dict(checkpoint['model_state_dict'])
262
  optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
263
 
 
359
 
360
  miniplaces_test = MiniPlaces(data_root, split='test', transform=val_transform)
361
  test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
 
 
362
 
363
  preds = test(model, test_loader, device)
364
  if rank == 0: # Only write predictions on rank 0
 
433
  parser = argparse.ArgumentParser()
434
  parser.add_argument('--test', action='store_true')
435
  parser.add_argument('--checkpoint')
436
+ parser.add_argument('--epochs', type=int, default=200)
437
  parser.add_argument('--batch_size', type=int, default=64)
438
  parser.add_argument('--port', type=int, default=4224)
439
  args = parser.parse_args()