# coding: utf-8
__author__ = "Ilya Kiselev (kiselecheck): https://github.com/kiselecheck"
__version__ = "1.0.1"

import warnings

import torch
import torch.multiprocessing as mp

from train import train_model
from utils.settings import cleanup_ddp

warnings.filterwarnings("ignore")
def train_model_single(rank: int, world_size: int, args=None):
    """
    Per-process entry point for distributed training, launched once per GPU
    by ``torch.multiprocessing.spawn`` (see ``train_model_ddp``).

    Args:
        rank: Index of the current worker process; injected by ``mp.spawn``
            as the first positional argument.
        world_size: Total number of spawned processes (one per CUDA device).
        args: Command-line arguments containing configuration paths,
            hyperparameters, and other settings; forwarded unchanged to
            ``train_model``.

    Returns:
        None
    """
    # NOTE(review): the original trailing comment "# Close DDP" was
    # misleading — this call runs the training itself; DDP teardown is
    # handled by cleanup_ddp() in train_model_ddp's error path.
    train_model(args, rank, world_size)
def train_model_ddp(args=None):
    """
    Launch distributed (DDP) training with one worker process per visible
    CUDA device.

    Args:
        args: Command-line arguments forwarded unchanged to every worker's
            ``train_model`` call via ``train_model_single``.

    Raises:
        RuntimeError: If no CUDA devices are visible — ``mp.spawn`` with
            ``nprocs=0`` would otherwise silently launch nothing.
        Exception: Any exception raised by a worker is re-raised after
            DDP cleanup.

    Returns:
        None
    """
    world_size = torch.cuda.device_count()
    if world_size == 0:
        raise RuntimeError(
            "train_model_ddp requires at least one CUDA device, but none are visible."
        )
    try:
        # mp.spawn prepends each worker's rank to `args`, matching
        # train_model_single(rank, world_size, args).
        mp.spawn(
            train_model_single, args=(world_size, args), nprocs=world_size, join=True
        )
    except Exception:
        # Tear down the process group before propagating the failure.
        # Bare `raise` preserves the original traceback; the original
        # `raise e` added a redundant re-raise frame.
        cleanup_ddp()
        raise
if __name__ == "__main__":
    # Script entry point: spawn one training worker per available GPU.
    train_model_ddp()