minor changes
Browse files- train_dist.py +3 -5
train_dist.py
CHANGED
|
@@ -255,9 +255,9 @@ def train_worker(rank, world_size, args):
|
|
| 255 |
dampening=0, weight_decay=1e-4, nesterov=True)
|
| 256 |
criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
|
| 257 |
|
| 258 |
-
if args.checkpoint:
|
| 259 |
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
|
| 260 |
-
checkpoint = torch.load(args.checkpoint, map_location=map_location)
|
| 261 |
model.module.load_state_dict(checkpoint['model_state_dict'])
|
| 262 |
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 263 |
|
|
@@ -359,8 +359,6 @@ def train_worker(rank, world_size, args):
|
|
| 359 |
|
| 360 |
miniplaces_test = MiniPlaces(data_root, split='test', transform=val_transform)
|
| 361 |
test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
|
| 362 |
-
checkpoint = torch.load(args.checkpoint, map_location=device)
|
| 363 |
-
model.module.load_state_dict(checkpoint['model_state_dict'])
|
| 364 |
|
| 365 |
preds = test(model, test_loader, device)
|
| 366 |
if rank == 0: # Only write predictions on rank 0
|
|
@@ -435,7 +433,7 @@ if __name__ == "__main__":
|
|
| 435 |
parser = argparse.ArgumentParser()
|
| 436 |
parser.add_argument('--test', action='store_true')
|
| 437 |
parser.add_argument('--checkpoint')
|
| 438 |
-
parser.add_argument('--epochs', type=int, default=
|
| 439 |
parser.add_argument('--batch_size', type=int, default=64)
|
| 440 |
parser.add_argument('--port', type=int, default=4224)
|
| 441 |
args = parser.parse_args()
|
|
|
|
| 255 |
dampening=0, weight_decay=1e-4, nesterov=True)
|
| 256 |
criterion = torch.nn.CrossEntropyLoss(reduction='mean', label_smoothing=0.1)
|
| 257 |
|
| 258 |
+
if args.checkpoint or args.test:
|
| 259 |
map_location = {'cuda:%d' % 0: 'cuda:%d' % rank}
|
| 260 |
+
checkpoint = torch.load((args.checkpoint if args.checkpoint else 'model.ckpt'), map_location=map_location)
|
| 261 |
model.module.load_state_dict(checkpoint['model_state_dict'])
|
| 262 |
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
|
| 263 |
|
|
|
|
| 359 |
|
| 360 |
miniplaces_test = MiniPlaces(data_root, split='test', transform=val_transform)
|
| 361 |
test_loader = DataLoader(miniplaces_test, batch_size=args.batch_size, num_workers=2, shuffle=False)
|
|
|
|
|
|
|
| 362 |
|
| 363 |
preds = test(model, test_loader, device)
|
| 364 |
if rank == 0: # Only write predictions on rank 0
|
|
|
|
| 433 |
parser = argparse.ArgumentParser()
|
| 434 |
parser.add_argument('--test', action='store_true')
|
| 435 |
parser.add_argument('--checkpoint')
|
| 436 |
+
parser.add_argument('--epochs', type=int, default=200)
|
| 437 |
parser.add_argument('--batch_size', type=int, default=64)
|
| 438 |
parser.add_argument('--port', type=int, default=4224)
|
| 439 |
args = parser.parse_args()
|