# agentic-rl-main / opsd_utils / deepspeed_utils.py
# Jack04810's picture
# Add files using upload-large-folder tool
# 36d0b76 verified
# Raw
# History Blame Contribute Delete
# 5.09 kB
"""Helpers for DeepSpeed + Accelerate launch detection."""
from __future__ import annotations
import json
import os
import re
from pathlib import Path
from typing import Optional
def _project_root() -> Path:
    """Return the directory two levels above this file's resolved location."""
    this_file = Path(__file__).resolve()
    # parents[0] is the package directory; parents[1] is its parent.
    return this_file.parents[1]
def resolve_accelerate_config_path(config_name: Optional[str] = None) -> Optional[Path]:
    """Locate the Accelerate config file on disk.

    Candidates are tried in order: the explicit ``config_name`` (if truthy),
    then the ``ACCELERATE_CONFIG`` and ``ACCELERATE_CONFIG_FILE`` environment
    variables. Each candidate is checked as given and, if that is not a
    file, relative to the project root.

    Args:
        config_name: Optional path or file name of the config.

    Returns:
        The first candidate that exists as a file, or ``None`` if none does.
    """
    names: list[str] = [str(config_name).strip()] if config_name else []
    env_values = (
        os.environ.get(env_key, "").strip()
        for env_key in ("ACCELERATE_CONFIG", "ACCELERATE_CONFIG_FILE")
    )
    names.extend(value for value in env_values if value)

    for name in names:
        # A whitespace-only config_name strips down to "" and is ignored.
        if not name:
            continue
        direct = Path(name)
        if direct.is_file():
            return direct
        rooted = _project_root() / name
        if rooted.is_file():
            return rooted
    return None
def uses_deepspeed_json_file(config_name: Optional[str] = None) -> bool:
    """Report whether the Accelerate config mentions ``deepspeed_config_file``.

    This is a plain substring test on the config text, so a mention inside a
    comment counts as well.

    Args:
        config_name: Optional path or file name of the Accelerate config.

    Returns:
        ``False`` when no config file can be resolved; otherwise whether the
        text contains ``deepspeed_config_file``.
    """
    config_path = resolve_accelerate_config_path(config_name)
    if config_path is not None:
        contents = config_path.read_text(encoding="utf-8")
        return "deepspeed_config_file" in contents
    return False
def _yaml_get_str(path: Path, key: str) -> Optional[str]:
    """Return the scalar value of the first ``key: value`` line in a YAML file.

    This is a line-oriented lookup, not a YAML parser: every line is stripped
    and matched case-insensitively against ``key``, so the key is found at
    any nesting depth. Lines that start with ``#`` never match.

    Args:
        path: Text file to scan (read as UTF-8).
        key: Key name to look for; matched literally.

    Returns:
        The value with surrounding quotes and any trailing ``# comment``
        removed, or ``None`` when no line provides a value for ``key``.
    """
    pattern = re.compile(rf"^{re.escape(key)}\s*:\s*(.+?)\s*$", re.IGNORECASE)
    for line in path.read_text(encoding="utf-8").splitlines():
        m = pattern.match(line.strip())
        if not m:
            continue
        value = m.group(1)
        # Quoted scalar: the value runs up to the closing quote, so a "#"
        # inside the quotes is kept and anything after them is dropped.
        if value[0] in "'\"":
            closing = value.find(value[0], 1)
            if closing != -1:
                return value[1:closing]
        # Plain scalar: "#" at the start or after whitespace begins a comment.
        plain = re.sub(r"(?:^|\s)#.*$", "", value).strip()
        if not plain:
            # The line carried only a comment after the colon; keep looking.
            continue
        return plain.strip("'\"")
    return None
def is_deepspeed_accelerate_config(config_name: Optional[str] = None) -> bool:
    """Report whether the Accelerate config sets ``distributed_type`` to DeepSpeed.

    Args:
        config_name: Optional path or file name of the Accelerate config.

    Returns:
        ``True`` only when a config file is found and its ``distributed_type``
        value equals ``DEEPSPEED`` (compared case-insensitively).
    """
    config_path = resolve_accelerate_config_path(config_name)
    if config_path is None:
        return False
    declared = _yaml_get_str(config_path, "distributed_type")
    return declared is not None and declared.upper() == "DEEPSPEED"
def _yaml_scalar_value(raw: str) -> str:
    """Return the scalar in ``raw`` (the text following ``key:``).

    A quoted value is returned without its quotes; otherwise a trailing
    ``# comment`` is removed first.
    """
    raw = raw.strip()
    if raw[:1] in ("'", '"'):
        closing = raw.find(raw[0], 1)
        if closing != -1:
            return raw[1:closing]
    return re.sub(r"(?:^|\s)#.*$", "", raw).strip().strip("'\"")


def deepspeed_zero_stage(config_name: Optional[str] = None) -> Optional[int]:
    """Return the ZeRO stage configured for the current Accelerate launch.

    The stage is taken from an inline ``zero_stage: N`` entry in the
    Accelerate config when present. Otherwise the JSON file named by
    ``deepspeed_config_file`` is read and its ``zero_optimization.stage``
    is used. The JSON path is tried as given, then relative to the project
    root. Commented-out lines in the Accelerate config are ignored.

    Args:
        config_name: Optional path or file name of the Accelerate config.

    Returns:
        The stage as an ``int``, or ``None`` when the config or the JSON
        file cannot be found, or when no integer stage is declared.

    Raises:
        json.JSONDecodeError: If the referenced DeepSpeed JSON is malformed.
    """
    path = resolve_accelerate_config_path(config_name)
    if path is None:
        return None

    # Only consider active lines; "# zero_stage: 3" is a disabled setting.
    active_lines = [
        line
        for line in path.read_text(encoding="utf-8").splitlines()
        if not line.lstrip().startswith("#")
    ]

    # An inline stage anywhere in the file wins over the JSON file.
    for line in active_lines:
        uncommented = re.sub(r"(?:^|\s)#.*$", "", line)
        m = re.search(r"zero_stage\s*:\s*(\d+)", uncommented, re.IGNORECASE)
        if m:
            return int(m.group(1))

    json_ref = ""
    for line in active_lines:
        # Matching per line keeps an empty value from swallowing the next line,
        # and taking the rest of the line keeps quoted paths with spaces intact.
        m = re.search(r"deepspeed_config_file\s*:(.*)$", line, re.IGNORECASE)
        if m:
            json_ref = _yaml_scalar_value(m.group(1))
            if json_ref:
                break
    if not json_ref:
        return None

    json_path = Path(json_ref)
    if not json_path.is_file():
        json_path = _project_root() / json_path
    if not json_path.is_file():
        return None

    ds_json = json.loads(json_path.read_text(encoding="utf-8"))
    if not isinstance(ds_json, dict):
        return None
    zero_cfg = ds_json.get("zero_optimization")
    if not isinstance(zero_cfg, dict):
        return None
    try:
        # A missing stage (None) or a non-numeric placeholder has no integer
        # value; report "unknown" rather than raising.
        return int(zero_cfg.get("stage"))
    except (TypeError, ValueError):
        return None
def should_colocate_teacher_with_student(device_map: Optional[str] = None) -> bool:
    """Decide whether the frozen teacher shares a GPU with the trainable student.

    Returns ``True`` when any of these holds, checked in order:

    * the requested device map (``device_map``, else the
      ``DYME_TEACHER_DEVICE_MAP`` environment variable) is ``same``,
      ``colocate`` or ``local``;
    * ``DYME_DEEPSPEED_COLOCATE`` is ``1``, ``true``, ``yes`` or ``on``;
    * the Accelerate config is a DeepSpeed config and the requested device
      map is empty or ``auto``.

    All string comparisons ignore case and surrounding whitespace.
    """
    if device_map:
        requested = device_map
    else:
        requested = os.environ.get("DYME_TEACHER_DEVICE_MAP", "")
    requested = requested.strip().lower()
    if requested in ("same", "colocate", "local"):
        return True

    forced = os.environ.get("DYME_DEEPSPEED_COLOCATE", "").strip().lower()
    if forced in ("1", "true", "yes", "on"):
        return True

    # The config lookup runs before the device-map test, as in the original.
    return is_deepspeed_accelerate_config() and requested in ("", "auto")
def gradient_checkpointing_enable_kwargs(config_name: Optional[str] = None) -> Optional[dict]:
    """
    Build the kwargs for ``model.gradient_checkpointing_enable``.

    Returns ``None`` when the Accelerate config is not a DeepSpeed config.
    Otherwise returns ``{"use_reentrant": <bool>}``. The value is ``True``
    only when ``DYME_GRADIENT_CHECKPOINTING_USE_REENTRANT`` is ``1``,
    ``true``, ``yes`` or ``on``, and ``False`` in every other case.

    Rationale from the original author: DeepSpeed ZeRO-1/2 with reentrant
    checkpointing runs backward twice per segment and hits
    "parameter ... has already been reduced".
    """
    if not is_deepspeed_accelerate_config(config_name):
        return None
    requested = os.environ.get("DYME_GRADIENT_CHECKPOINTING_USE_REENTRANT", "").strip().lower()
    # Explicit falsy values, unset, and unrecognised values all mean non-reentrant.
    reentrant = requested in ("1", "true", "yes", "on")
    return {"use_reentrant": reentrant}
def deepspeed_requires_single_student_forward(config_name: Optional[str] = None) -> bool:
    """
    Report whether the student must run a single forward per backward.

    Returns ``True`` when a ZeRO stage is configured and it is at most 2.
    Per the original author, DeepSpeed ZeRO-1/2 cannot reduce gradients when
    the student runs multiple forwards in one backward (GRPO micro-chunks +
    OPSD loop).

    NOTE(review): a configured stage of 0 also satisfies ``<= 2``; confirm
    that this is intended.
    """
    zero_stage = deepspeed_zero_stage(config_name)
    if zero_stage is None:
        return False
    return zero_stage <= 2
def should_disable_gradient_checkpointing(config_name: Optional[str] = None) -> bool:
    """Report whether gradient checkpointing should be turned off.

    Uses the same condition as ``deepspeed_requires_single_student_forward``.
    Per the original author, gradient checkpointing also triggers double
    reduction under ZeRO-1/2.
    """
    single_forward_required = deepspeed_requires_single_student_forward(config_name)
    return single_forward_required
def student_forward_chunk_size(
    batch_size: int,
    has_vision: bool,
    config_name: Optional[str] = None,
) -> int:
    """
    Choose the micro-batch size for student forwards in ``_get_per_token_logps``.

    * Without vision inputs the whole batch is used (``batch_size``).
    * With vision inputs, and no single-forward requirement, the size is 1.
    * With vision inputs under the single-forward requirement (ZeRO stage
      <= 2), the whole batch is used unless ``DYME_STUDENT_FORWARD_CHUNK``
      holds a digit string. That value is then clamped to
      ``[1, batch_size]``.

    The original author warns that the override should only be used if you
    accept ZeRO-3+ or OOM risk.
    """
    if not has_vision:
        return batch_size
    if deepspeed_requires_single_student_forward(config_name):
        requested = os.environ.get("DYME_STUDENT_FORWARD_CHUNK", "").strip()
        if not requested.isdigit():
            return batch_size
        return max(1, min(batch_size, int(requested)))
    return 1