File size: 3,147 Bytes
5fee096
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
import sys

sys.dont_write_bytecode = True

import os
import re
import glob
import time
import torch
import argparse
import subprocess
import torch.multiprocessing as mp

from core.config import Config
from core import Trainer

def main(rank, config):
    """Per-process entry point: build a Trainer for this rank and run training.

    Args:
        rank: Process index (0 for single-GPU, spawn rank for multi-GPU).
        config: Configuration dict loaded from a YAML file.
    """
    Trainer(rank, config).train_loop()

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, default=None, help='Name of config file')
    parser.add_argument('--seed', type=int, default=-1, help='Seed')
    parser.add_argument('--device', type=int, default=-1, help='Device')
    args = parser.parse_args()

    # Resolve the config: search ./config recursively for a unique YAML match,
    # falling back to the default InfLoRA config when --config is omitted.
    if args.config:
        if not args.config.endswith('.yaml'):
            args.config = args.config + '.yaml'
        config_files = glob.glob(f'./config/**/{args.config}', recursive=True)
        assert len(config_files) == 1, "Config files conflict"
        config = Config(config_files[0]).get_config_dict()
    else:
        config = Config("./config/InfLoRA.yaml").get_config_dict()

    if config['device_ids'] == 'auto':
        # Auto-select the n_gpu least-utilized GPUs by querying nvidia-smi.
        try:
            result = subprocess.run(
                ['nvidia-smi', '--query-gpu=index,memory.used,memory.total,utilization.gpu', '--format=csv,noheader,nounits'],
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
                text=True
            )
            if result.returncode != 0:
                raise RuntimeError(f"nvidia-smi error: {result.stderr}")

            gpu_utilization = []
            for gpu in result.stdout.strip().split('\n'):
                match = re.match(r'(\d+),\s*(\d+),\s*(\d+),\s*(\d+)', gpu)
                if match:
                    device_id, mem_used, mem_total, gpu_util = map(int, match.groups())
                    # Score combines GPU utilization (%) with memory pressure (%).
                    utilization_score = gpu_util + (mem_used / mem_total) * 100
                    gpu_utilization.append((device_id, utilization_score))

            # Sort ascending by score and keep the n_gpu least-loaded devices.
            gpu_utilization.sort(key=lambda x: x[1])
            config["device_ids"] = [str(gpu[0]) for gpu in gpu_utilization[:config["n_gpu"]]]

        except Exception as e:
            # BUG FIX: was `range(config["n_gpu"])` — a range object that the
            # isinstance(list) normalization below would wrap as [range(...)].
            config["device_ids"] = list(range(config["n_gpu"]))
            print(f"Error while querying GPUs: {e}, using default device {config['device_ids']}")

    if args.seed > -1:
        print(f'Seed : {config["seed"]} -> {args.seed}')
        config['seed'] = args.seed

    if args.device > -1:
        config['device_ids'] = args.device

    # Normalize device_ids to a list regardless of how it was produced.
    if not isinstance(config['device_ids'], list):
        config['device_ids'] = [config['device_ids']]

    print(f'Selected GPUs: {config["device_ids"]}')

    if config["n_gpu"] > 1:
        # BUG FIX: CUDA_VISIBLE_DEVICES must be set BEFORE mp.spawn so child
        # processes inherit it (it was previously set after spawn returned,
        # i.e. only after training finished), and the value must be a
        # comma-separated string — assigning a list raises TypeError.
        os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(d) for d in config["device_ids"])
        mp.spawn(main, nprocs=config["n_gpu"], args=(config,))
    else:
        main(0, config)