"""Dispatch sharded inference workers onto GPUs based on their free memory.

Each shard is run as a separate worker subprocess pinned to one GPU via
CUDA_VISIBLE_DEVICES; at most one worker runs per GPU at a time.
"""
import argparse
import os
import subprocess
import sys
import time
from typing import Dict, List, Optional, Tuple


def parse_gpus(s: str) -> List[int]:
    """Parse a comma-separated GPU index list such as ``"0, 1,2"``.

    Empty entries (e.g. from a trailing comma) are skipped.

    Raises:
        ValueError: if a non-empty entry is not an integer.
    """
    return [int(x.strip()) for x in s.split(",") if x.strip() != ""]


def get_free_mem_gb_by_nvml(gpus: List[int]) -> Dict[int, float]:
    """Return free device memory in GiB for each NVML device index in *gpus*.

    pynvml is imported lazily so the rest of this module can be imported
    (and tested) without it. Every nvmlInit is paired with an nvmlShutdown,
    since this function is called on every poll.
    """
    import pynvml

    pynvml.nvmlInit()
    try:
        info = {}
        for gid in gpus:
            handle = pynvml.nvmlDeviceGetHandleByIndex(gid)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            info[gid] = mem.free / (1024**3)
        return info
    finally:
        pynvml.nvmlShutdown()


def pick_gpu(free_mem: dict, busy: set, min_free_gb: float) -> Optional[int]:
    """Pick the idle GPU with the most free memory.

    Args:
        free_mem: mapping of GPU index -> free memory in GiB.
        busy: GPU indices that already run a job and must be skipped.
        min_free_gb: minimum free memory a GPU needs to be eligible.

    Returns:
        The chosen GPU index, or None if no GPU qualifies. Ties go to the
        GPU that appears first in *free_mem*.
    """
    eligible = [
        (gid, gb)
        for gid, gb in free_mem.items()
        if gid not in busy and gb >= min_free_gb
    ]
    if not eligible:
        return None
    return max(eligible, key=lambda item: item[1])[0]


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the CLI parser; worker-facing options are forwarded verbatim."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--worker_script", default="infer_worker.py")
    ap.add_argument("--gpus", default="0", help="e.g. 0,1,2,3")
    ap.add_argument("--max_jobs", type=int, default=4)
    ap.add_argument("--min_free_gb", type=float, default=12.0)
    ap.add_argument("--poll_sec", type=float, default=5.0)
    ap.add_argument("--num_shards", type=int, required=True)
    ap.add_argument("--ckpt_dir", required=True)
    ap.add_argument("--data_jsonl", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--base_model_name", default="openai")
    ap.add_argument("--cache_dir", default=".")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--eval_every_steps", type=int, default=100)
    ap.add_argument("--flush_every_steps", type=int, default=1)
    return ap


def _worker_cmd(args: argparse.Namespace, shard_id: int) -> List[str]:
    """Build the argv that runs the worker script for one shard."""
    return [
        # Use the scheduler's own interpreter so workers see the same
        # virtualenv/conda env instead of whatever "python" is on PATH.
        sys.executable, args.worker_script,
        "--base_model_name", args.base_model_name,
        "--cache_dir", args.cache_dir,
        "--ckpt_dir", args.ckpt_dir,
        "--data_jsonl", args.data_jsonl,
        "--out_dir", args.out_dir,
        "--batch_size", str(args.batch_size),
        "--max_length", str(args.max_length),
        "--num_shards", str(args.num_shards),
        "--shard_id", str(shard_id),
        "--eval_every_steps", str(args.eval_every_steps),
        "--flush_every_steps", str(args.flush_every_steps),
    ]


def _worker_env(gpu_id: int) -> Dict[str, str]:
    """Build the environment for a worker pinned to NVML GPU *gpu_id*."""
    env = os.environ.copy()
    # NVML (like nvidia-smi) enumerates GPUs in PCI bus order in practice,
    # while CUDA defaults to FASTEST_FIRST. Force PCI order so that
    # CUDA_VISIBLE_DEVICES=<gpu_id> names the same physical GPU whose free
    # memory was just measured.
    env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    env["CUDA_VISIBLE_DEVICES"] = str(gpu_id)  # the worker sees only this GPU
    env.setdefault("OMP_NUM_THREADS", "4")
    return env


def main() -> int:
    """Run all shards, at most one per GPU and at most --max_jobs at once.

    Returns:
        0 if every worker exited with status 0, otherwise 1.
    """
    parser = _build_arg_parser()
    args = parser.parse_args()
    try:
        gpus = parse_gpus(args.gpus)
    except ValueError:
        parser.error(f"--gpus must be a comma-separated list of ints, got {args.gpus!r}")
    # Either of these would make the scheduling loop wait forever.
    if not gpus:
        parser.error("--gpus must list at least one GPU")
    if args.max_jobs < 1:
        parser.error("--max_jobs must be >= 1")

    os.makedirs(args.out_dir, exist_ok=True)

    shard_ids = list(range(args.num_shards))
    running: List[Tuple[subprocess.Popen, int, int]] = []  # (proc, gpu_id, shard_id)
    failed: List[int] = []
    idx = 0
    try:
        while idx < len(shard_ids) or running:
            # Reap finished workers and remember which shards failed.
            still = []
            for proc, gid, sid in running:
                ret = proc.poll()
                if ret is None:
                    still.append((proc, gid, sid))
                    continue
                print(f"[DONE] shard={sid} gpu={gid} ret={ret}")
                if ret != 0:
                    failed.append(sid)
            running = still

            # Only query NVML when there is pending work and a free slot.
            # NOTE: if no GPU ever reaches --min_free_gb, this keeps waiting.
            if idx < len(shard_ids) and len(running) < args.max_jobs:
                busy = {gid for _, gid, _ in running}
                free_mem = get_free_mem_gb_by_nvml(gpus)
                while idx < len(shard_ids) and len(running) < args.max_jobs:
                    gid = pick_gpu(free_mem, busy, args.min_free_gb)
                    if gid is None:
                        break
                    sid = shard_ids[idx]
                    idx += 1
                    print(f"[START] shard={sid} -> GPU {gid} free={free_mem[gid]:.1f}GB")
                    proc = subprocess.Popen(_worker_cmd(args, sid), env=_worker_env(gid))
                    running.append((proc, gid, sid))
                    busy.add(gid)

            # Skip the final sleep once everything has finished.
            if idx < len(shard_ids) or running:
                time.sleep(args.poll_sec)
    except BaseException:
        # Don't leave orphaned workers holding GPU memory if the scheduler
        # itself is interrupted or crashes.
        for proc, _, _ in running:
            if proc.poll() is None:
                proc.terminate()
        raise

    print("[ALL DONE]")
    if failed:
        print(f"[FAILED] shards={sorted(failed)}")
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())