# Source: MOF-deprecated/script/infer_multi_worker.py
# Uploaded by StarLiu714 via huggingface_hub ("Upload folder using huggingface_hub"),
# commit 24c3614 (verified), 3.73 kB.
import argparse
import os
import subprocess
import sys
import time
from typing import List, Tuple
def parse_gpus(s: str) -> List[int]:
    """Parse a comma-separated list of GPU indices, e.g. ``"0, 1,3"``.

    Surrounding whitespace on each entry is ignored, and empty entries
    (such as one produced by a trailing comma) are skipped.

    Args:
        s: The raw ``--gpus`` string.

    Returns:
        The GPU indices as ints, in the order given.

    Raises:
        ValueError: If a non-empty entry is not an integer.
    """
    gpu_ids: List[int] = []
    for token in s.split(","):
        token = token.strip()
        if token:
            gpu_ids.append(int(token))
    return gpu_ids
def get_free_mem_gb_by_nvml(gpus: List[int]) -> dict:
    """Query NVML for the currently free memory of each requested GPU.

    Args:
        gpus: NVML device indices to query.

    Returns:
        Dict mapping each GPU index to its free memory in GiB
        (``bytes / 1024**3``).

    Raises:
        Whatever pynvml raises (e.g. ``pynvml.NVMLError``) if NVML cannot be
        initialised or an index is invalid. NVML is shut down in either case.
    """
    import pynvml  # local import: only this function needs pynvml

    pynvml.nvmlInit()
    try:
        info = {}
        for gid in gpus:
            handle = pynvml.nvmlDeviceGetHandleByIndex(gid)
            mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
            info[gid] = mem.free / (1024**3)
        return info
    finally:
        # nvmlInit() is reference-counted and this function runs on every
        # scheduler poll, so balance each init with a shutdown.
        pynvml.nvmlShutdown()
def pick_gpu(free_mem: dict, busy: set, min_free_gb: float):
    """Choose the idle GPU with the most free memory.

    Args:
        free_mem: Mapping of GPU index -> free memory (same units as
            ``min_free_gb``).
        busy: GPU indices that already have a job; these are never chosen.
        min_free_gb: Minimum free memory a GPU needs (inclusive).

    Returns:
        The chosen GPU index, or ``None`` if no GPU qualifies. On a tie in
        free memory, the GPU that appears first in ``free_mem`` wins.
    """
    best_gid = None
    best_gb = None
    for gid, gb in free_mem.items():
        # Skip busy GPUs and those under the threshold.
        if gid in busy or not gb >= min_free_gb:
            continue
        # Strict ">" keeps the earliest GPU when free memory ties.
        if best_gb is None or gb > best_gb:
            best_gid, best_gb = gid, gb
    return best_gid
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the launcher's CLI parser (worker options are forwarded as-is)."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--worker_script", default="infer_worker.py")
    ap.add_argument("--gpus", default="0", help="e.g. 0,1,2,3")
    ap.add_argument("--max_jobs", type=int, default=4)
    ap.add_argument("--min_free_gb", type=float, default=12.0)
    ap.add_argument("--poll_sec", type=float, default=5.0)
    ap.add_argument("--num_shards", type=int, required=True)
    ap.add_argument("--ckpt_dir", required=True)
    ap.add_argument("--data_jsonl", required=True)
    ap.add_argument("--out_dir", required=True)
    ap.add_argument("--base_model_name", default="openai")
    ap.add_argument("--cache_dir", default=".")
    ap.add_argument("--batch_size", type=int, default=64)
    ap.add_argument("--max_length", type=int, default=512)
    ap.add_argument("--eval_every_steps", type=int, default=100)
    ap.add_argument("--flush_every_steps", type=int, default=1)
    return ap


def _build_worker_cmd(args: argparse.Namespace, shard_id: int) -> List[str]:
    """Return the argv that runs ``args.worker_script`` on shard ``shard_id``."""
    return [
        # Use the launcher's own interpreter so workers run in the same
        # virtualenv/conda env rather than whatever "python" is first on PATH.
        sys.executable, args.worker_script,
        "--base_model_name", args.base_model_name,
        "--cache_dir", args.cache_dir,
        "--ckpt_dir", args.ckpt_dir,
        "--data_jsonl", args.data_jsonl,
        "--out_dir", args.out_dir,
        "--batch_size", str(args.batch_size),
        "--max_length", str(args.max_length),
        "--num_shards", str(args.num_shards),
        "--shard_id", str(shard_id),
        "--eval_every_steps", str(args.eval_every_steps),
        "--flush_every_steps", str(args.flush_every_steps),
    ]


def _reap_finished(
    running: List[Tuple[subprocess.Popen, int, int]],
    failed: List[int],
) -> List[Tuple[subprocess.Popen, int, int]]:
    """Return the still-running workers; append shard ids that exited non-zero to ``failed``."""
    still_running = []
    for proc, gid, sid in running:
        ret = proc.poll()
        if ret is None:
            still_running.append((proc, gid, sid))
            continue
        print(f"[DONE] shard={sid} gpu={gid} ret={ret}")
        if ret != 0:
            failed.append(sid)
    return still_running


def main():
    """Run ``--worker_script`` once per shard, spreading shards across GPUs.

    At most one worker runs per GPU and at most ``--max_jobs`` run in total.
    A GPU is used only when NVML reports at least ``--min_free_gb`` GiB free.
    Every ``--poll_sec`` seconds, finished workers are reaped and pending
    shards are launched. Exits with status 1 if any worker exited non-zero.
    """
    ap = _build_arg_parser()
    args = ap.parse_args()
    gpus = parse_gpus(args.gpus)
    # Either condition would leave shards queued forever with nothing running.
    if not gpus:
        ap.error("--gpus must list at least one GPU index")
    if args.max_jobs < 1:
        ap.error("--max_jobs must be >= 1")
    os.makedirs(args.out_dir, exist_ok=True)

    shard_ids = list(range(args.num_shards))
    running: List[Tuple[subprocess.Popen, int, int]] = []  # (proc, gpu_id, shard_id)
    failed: List[int] = []
    idx = 0
    while True:
        running = _reap_finished(running, failed)
        if idx >= len(shard_ids) and not running:
            break

        # Query NVML only when there is something left to launch.
        if idx < len(shard_ids) and len(running) < args.max_jobs:
            busy = {gid for _, gid, _ in running}
            free_mem = get_free_mem_gb_by_nvml(gpus)
            while idx < len(shard_ids) and len(running) < args.max_jobs:
                gid = pick_gpu(free_mem, busy, args.min_free_gb)
                if gid is None:
                    break
                sid = shard_ids[idx]
                idx += 1
                env = os.environ.copy()
                # In practice NVML/nvidia-smi indices follow PCI bus order,
                # while CUDA by default enumerates "fastest first"; force PCI
                # order so the NVML index chosen above is the same physical
                # device the worker sees.
                env["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
                env["CUDA_VISIBLE_DEVICES"] = str(gid)  # Let the subprocess see only 1 GPU
                env.setdefault("OMP_NUM_THREADS", "4")
                print(f"[START] shard={sid} -> GPU {gid} free={free_mem[gid]:.1f}GB")
                proc = subprocess.Popen(_build_worker_cmd(args, sid), env=env)
                running.append((proc, gid, sid))
                busy.add(gid)
        time.sleep(args.poll_sec)
    print("[ALL DONE]")
    if failed:
        # Surface worker failures to the caller/CI instead of exiting 0.
        print(f"[FAILED] shards={sorted(failed)}")
        sys.exit(1)
# Entry point: run the scheduler only when this file is executed directly,
# not when it is imported.
if __name__ == "__main__":
    main()