Update file trainer.cli.py
Browse files- trainer.cli.py +17 -1
trainer.cli.py
CHANGED
|
@@ -18,6 +18,9 @@ parser = ArgumentParser(
|
|
| 18 |
)
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
| 21 |
|
| 22 |
if __name__ == '__main__':
|
| 23 |
|
|
@@ -49,7 +52,20 @@ if __name__ == '__main__':
|
|
| 49 |
|
| 50 |
config.trainer.model = model
|
| 51 |
config.trainer.wandb = wandb
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
|
| 54 |
trainer = Trainer(config.trainer)
|
| 55 |
-
trainer.train(batches)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
)
|
| 19 |
|
| 20 |
|
| 21 |
+
import torch
|
| 22 |
+
import torch.utils.bottleneck
|
| 23 |
+
|
| 24 |
|
| 25 |
if __name__ == '__main__':
|
| 26 |
|
|
|
|
| 52 |
|
| 53 |
config.trainer.model = model
|
| 54 |
config.trainer.wandb = wandb
|
| 55 |
+
|
| 56 |
+
# Create a bottleneck profiler
|
| 57 |
+
profiler = torch.utils.bottleneck.Profiler()
|
| 58 |
+
|
| 59 |
+
# Start profiling
|
| 60 |
+
profiler.start()
|
| 61 |
|
| 62 |
|
| 63 |
trainer = Trainer(config.trainer)
|
| 64 |
+
trainer.train(batches)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
# Stop profiling
|
| 68 |
+
profiler.stop()
|
| 69 |
+
|
| 70 |
+
# Retrieve profiling data
|
| 71 |
+
profiler.report()
|