| import os |
| import time |
| import argparse |
| import subprocess |
| from typing import List, Tuple |
|
|
def parse_gpus(s: str) -> List[int]:
    """Parse a comma-separated GPU index list such as "0, 1,3" into ints.

    Blank entries (e.g. from a trailing or doubled comma) are skipped; any
    non-blank entry that is not an integer makes int() raise ValueError.
    """
    ids: List[int] = []
    for token in s.split(","):
        token = token.strip()
        if token:
            ids.append(int(token))
    return ids
|
|
def get_free_mem_gb_by_nvml(gpus: List[int]) -> dict:
    """Return {gpu_index: free device memory in GiB} for each NVML index in gpus.

    NVML is initialised for the duration of this query and shut down again
    afterwards. NVML reference-counts nvmlInit() calls, so pairing every init
    with a shutdown keeps repeated polling from accumulating references.
    NVML errors (e.g. an out-of-range index) are not caught here.

    Args:
        gpus: NVML device indices to query.

    Returns:
        Mapping from each index in ``gpus`` to its free memory in GiB.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        info = {}
        for gid in gpus:
            handle = pynvml.nvmlDeviceGetHandleByIndex(gid)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            info[gid] = mem.free / (1024 ** 3)  # bytes -> GiB
        return info
    finally:
        # Release the reference taken by nvmlInit() even if a query failed.
        pynvml.nvmlShutdown()
|
|
def pick_gpu(free_mem: dict, busy: set, min_free_gb: float):
    """Choose the idle GPU with the most free memory.

    Args:
        free_mem: mapping of GPU index -> free memory in GB.
        busy: GPU indices that already host a job; these are never chosen.
        min_free_gb: minimum free memory a GPU needs to be eligible.

    Returns:
        The chosen GPU index, or None when no GPU is both idle and has at
        least ``min_free_gb`` free. On ties, the GPU that comes first in
        ``free_mem`` iteration order wins.
    """
    best_gid = None
    best_gb = None
    for gid, gb in free_mem.items():
        # `not (gb >= min)` rather than `gb < min` so NaN readings stay excluded.
        if gid in busy or not (gb >= min_free_gb):
            continue
        # Strict ">" keeps the earliest GPU among equally-free candidates.
        if best_gb is None or gb > best_gb:
            best_gid, best_gb = gid, gb
    return best_gid
|
|
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the shard scheduler."""
    ap = argparse.ArgumentParser()

    # Scheduler behaviour.
    ap.add_argument("--worker_script", default="infer_worker.py")
    ap.add_argument("--gpus", default="0", help="e.g. 0,1,2,3")
    ap.add_argument("--max_jobs", type=int, default=4)
    ap.add_argument("--min_free_gb", type=float, default=12.0)
    ap.add_argument("--poll_sec", type=float, default=5.0)

    # Job definition (required); also forwarded to every worker.
    ap.add_argument("--num_shards", type=int, required=True)
    ap.add_argument("--ckpt_dir", required=True)
    ap.add_argument("--data_jsonl", required=True)
    ap.add_argument("--out_dir", required=True)

    # Options forwarded unchanged to the worker script.
    ap.add_argument("--base_model_name", default="openai")
    ap.add_argument("--cache_dir", default=".")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--eval_every_steps", type=int, default=100)
    ap.add_argument("--flush_every_steps", type=int, default=1)
    return ap


def _worker_cmd(python_exe: str, args: argparse.Namespace, shard_id: int) -> List[str]:
    """Return the argv that runs ``args.worker_script`` on a single shard."""
    return [
        python_exe, args.worker_script,
        "--base_model_name", args.base_model_name,
        "--cache_dir", args.cache_dir,
        "--ckpt_dir", args.ckpt_dir,
        "--data_jsonl", args.data_jsonl,
        "--out_dir", args.out_dir,
        "--batch_size", str(args.batch_size),
        "--max_length", str(args.max_length),
        "--num_shards", str(args.num_shards),
        "--shard_id", str(shard_id),
        "--eval_every_steps", str(args.eval_every_steps),
        "--flush_every_steps", str(args.flush_every_steps),
    ]


def _reap_finished(
    running: List[Tuple[subprocess.Popen, int, int]],
    failed: List[Tuple[int, int]],
) -> List[Tuple[subprocess.Popen, int, int]]:
    """Log finished jobs, append (shard, returncode) of failures to ``failed``,
    and return the jobs that are still running."""
    still_running = []
    for proc, gid, sid in running:
        ret = proc.poll()
        if ret is None:
            still_running.append((proc, gid, sid))
            continue
        print(f"[DONE] shard={sid} gpu={gid} ret={ret}")
        if ret != 0:
            failed.append((sid, ret))
    return still_running


def _terminate_all(
    running: List[Tuple[subprocess.Popen, int, int]], grace_sec: float = 10.0
) -> None:
    """terminate() every live worker, then kill() any still alive after grace_sec."""
    for proc, _, _ in running:
        if proc.poll() is None:
            proc.terminate()
    deadline = time.monotonic() + grace_sec
    for proc, _, _ in running:
        try:
            proc.wait(timeout=max(0.0, deadline - time.monotonic()))
        except subprocess.TimeoutExpired:
            proc.kill()
            proc.wait()


def main():
    """Run one worker subprocess per shard on GPUs with enough free memory.

    Every ``poll_sec`` seconds the scheduler does two things. First, it reaps
    finished workers. Second, while shards remain and fewer than ``max_jobs``
    are running, it starts the next shard on the idle GPU with the most free
    memory, as reported by NVML. A GPU counts as idle when it has no job from
    this scheduler and at least ``min_free_gb`` free. Each worker sees only
    its assigned GPU via CUDA_VISIBLE_DEVICES.

    If the scheduler itself dies, for example on Ctrl+C or an NVML error, the
    running workers are terminated. If any worker exits non-zero, the failed
    shards are listed and the process exits with status 1.
    """
    import sys  # not among the file-level imports

    ap = _build_arg_parser()
    args = ap.parse_args()

    gpus = parse_gpus(args.gpus)
    # Without these checks nothing could ever be launched and the loop would spin forever.
    if not gpus:
        ap.error("--gpus must list at least one GPU index")
    if args.max_jobs < 1:
        ap.error("--max_jobs must be >= 1")
    os.makedirs(args.out_dir, exist_ok=True)

    # Use the scheduler's own interpreter rather than whatever "python" is on PATH.
    python_exe = sys.executable or "python"

    shard_ids = list(range(args.num_shards))
    running: List[Tuple[subprocess.Popen, int, int]] = []
    failed: List[Tuple[int, int]] = []
    idx = 0

    try:
        while idx < len(shard_ids) or running:
            running = _reap_finished(running, failed)

            if idx < len(shard_ids):
                # Only one job per GPU: GPUs hosting a running job are skipped.
                busy = {gid for _, gid, _ in running}
                free_mem = get_free_mem_gb_by_nvml(gpus)

                while idx < len(shard_ids) and len(running) < args.max_jobs:
                    gid = pick_gpu(free_mem, busy, args.min_free_gb)
                    if gid is None:
                        break

                    sid = shard_ids[idx]
                    idx += 1

                    env = os.environ.copy()
                    # `gid` is an NVML index (PCI bus order in practice); CUDA's
                    # default FASTEST_FIRST order can map the same number to a
                    # different physical GPU, so pin the PCI order.
                    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
                    env["CUDA_VISIBLE_DEVICES"] = str(gid)
                    env.setdefault("OMP_NUM_THREADS", "4")

                    print(f"[START] shard={sid} -> GPU {gid} free={free_mem[gid]:.1f}GB")
                    proc = subprocess.Popen(_worker_cmd(python_exe, args, sid), env=env)
                    running.append((proc, gid, sid))
                    busy.add(gid)
            elif not running:
                break  # everything launched and finished; no need for a final sleep

            time.sleep(args.poll_sec)
    except BaseException:
        # Don't leave orphaned workers holding GPUs when the scheduler dies.
        _terminate_all(running)
        raise

    if failed:
        for sid, ret in sorted(failed):
            print(f"[FAILED] shard={sid} ret={ret}")
        print(f"[ALL DONE] {len(failed)}/{len(shard_ids)} shard(s) failed")
        sys.exit(1)

    print("[ALL DONE]")
|
|
if __name__ == "__main__":
    # Start the scheduler only when executed as a script, not when imported.
    main()
|
|