| | import sys |
| |
|
| | sys.dont_write_bytecode = True |
| |
|
| | import os |
| | import re |
| | import glob |
| | import time |
| | import torch |
| | import argparse |
| | import subprocess |
| | import torch.multiprocessing as mp |
| |
|
| | from core.config import Config |
| | from core import Trainer |
| |
|
def main(rank, config):
    """Entry point for one worker process: build a Trainer for this rank and run it.

    Args:
        rank: Process index (0 for single-GPU runs, spawn index otherwise).
        config: Parsed configuration dict handed to the Trainer.
    """
    Trainer(rank, config).train_loop()
|
if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None, help='Name of config file')
    parser.add_argument('--seed', type=int, default=-1, help='Seed')
    parser.add_argument('--device', type=int, default=-1, help='Device')
    args = parser.parse_args()

    # Resolve the config file: accept a bare name (with or without the .yaml
    # suffix) and search ./config recursively for it.
    if args.config:
        args.config = args.config + '.yaml' if not args.config.endswith('.yaml') else args.config
        config_files = glob.glob(f'./config/**/{args.config}', recursive=True)
        # Raise instead of assert: asserts are stripped under `python -O`.
        if len(config_files) != 1:
            raise RuntimeError("Config files conflict")
        config_path = config_files[0]
        config = Config(config_path).get_config_dict()
    else:
        config = Config("./config/InfLoRA.yaml").get_config_dict()

    # 'auto' device selection: query nvidia-smi and keep the n_gpu
    # least-loaded GPUs. Score = GPU utilization (%) + memory usage (%);
    # lower is better.
    if config['device_ids'] == 'auto':
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=index,memory.used,memory.total,utilization.gpu', '--format=csv,noheader,nounits'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            if result.returncode != 0:
                raise RuntimeError(f"nvidia-smi error: {result.stderr}")

            gpu_utilization = []
            for gpu in result.stdout.strip().split('\n'):
                match = re.match(r'(\d+),\s*(\d+),\s*(\d+),\s*(\d+)', gpu)
                if match:
                    device_id, mem_used, mem_total, gpu_util = map(int, match.groups())
                    utilization_score = gpu_util + (mem_used / mem_total) * 100
                    gpu_utilization.append((device_id, utilization_score))

            gpu_utilization.sort(key=lambda x: x[1])
            config["device_ids"] = [str(gpu[0]) for gpu in gpu_utilization[:config["n_gpu"]]]

        except Exception as e:
            # BUGFIX: was `range(config["n_gpu"])` — a lazy range object that
            # the list-normalization below would wrap as [range(...)]. Make it
            # a concrete list of device indices.
            config["device_ids"] = list(range(config["n_gpu"]))
            print(f"Error while querying GPUs: {e}, using default device {config['device_ids']}")

    if args.seed > -1:
        print(f'Seed : {config["seed"]} -> {args.seed}')
        config['seed'] = args.seed

    if args.device > -1:
        config['device_ids'] = args.device

    # Normalize to a list so per-rank indexing works downstream
    # (args.device and single-entry configs arrive as scalars).
    if not isinstance(config['device_ids'], list):
        config['device_ids'] = [config['device_ids']]

    print(f'Selected GPUs: {config["device_ids"]}')

    if config["n_gpu"] > 1:
        # BUGFIX: CUDA_VISIBLE_DEVICES must be a comma-separated string (the
        # original assigned a list, a TypeError) and must be exported BEFORE
        # the workers are spawned — the original set it after mp.spawn had
        # already returned, i.e. after training finished.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in config["device_ids"])
        mp.spawn(main, nprocs=config["n_gpu"], args=(config,))
    else:
        main(0, config)
|