"""Helpers for DeepSpeed + Accelerate launch detection.""" from __future__ import annotations import json import os import re from pathlib import Path from typing import Optional def _project_root() -> Path: return Path(__file__).resolve().parents[1] def resolve_accelerate_config_path(config_name: Optional[str] = None) -> Optional[Path]: candidates: list[str] = [] if config_name: candidates.append(str(config_name).strip()) for env_key in ("ACCELERATE_CONFIG", "ACCELERATE_CONFIG_FILE"): val = os.environ.get(env_key, "").strip() if val: candidates.append(val) for raw in candidates: if not raw: continue path = Path(raw) if not path.is_file(): path = _project_root() / raw if path.is_file(): return path return None def uses_deepspeed_json_file(config_name: Optional[str] = None) -> bool: """True when Accelerate loads DeepSpeed settings from an external JSON file.""" path = resolve_accelerate_config_path(config_name) if path is None: return False return "deepspeed_config_file" in path.read_text(encoding="utf-8") def _yaml_get_str(path: Path, key: str) -> Optional[str]: pattern = re.compile(rf"^{re.escape(key)}\s*:\s*(.+?)\s*$", re.IGNORECASE) for line in path.read_text(encoding="utf-8").splitlines(): m = pattern.match(line.strip()) if m: return m.group(1).strip().strip("'\"") return None def is_deepspeed_accelerate_config(config_name: Optional[str] = None) -> bool: path = resolve_accelerate_config_path(config_name) if path is None: return False dist = (_yaml_get_str(path, "distributed_type") or "").upper() return dist == "DEEPSPEED" def deepspeed_zero_stage(config_name: Optional[str] = None) -> Optional[int]: path = resolve_accelerate_config_path(config_name) if path is None: return None text = path.read_text(encoding="utf-8") m = re.search(r"zero_stage\s*:\s*(\d+)", text, re.IGNORECASE) if m: return int(m.group(1)) m = re.search(r"deepspeed_config_file\s*:\s*(\S+)", text, re.IGNORECASE) if not m: return None json_path = Path(m.group(1).strip().strip("'\"")) if not json_path.is_file(): json_path = _project_root() / json_path if not json_path.is_file(): return None ds_json = json.loads(json_path.read_text(encoding="utf-8")) stage = (ds_json.get("zero_optimization") or {}).get("stage") return int(stage) if stage is not None else None def should_colocate_teacher_with_student(device_map: Optional[str] = None) -> bool: """True when frozen teacher should sit on the same GPU as the trainable student.""" raw = (device_map or os.environ.get("DYME_TEACHER_DEVICE_MAP", "")).strip().lower() if raw in ("same", "colocate", "local"): return True if os.environ.get("DYME_DEEPSPEED_COLOCATE", "").strip().lower() in ("1", "true", "yes", "on"): return True if is_deepspeed_accelerate_config() and raw in ("", "auto"): return True return False def gradient_checkpointing_enable_kwargs(config_name: Optional[str] = None) -> Optional[dict]: """ Kwargs for ``model.gradient_checkpointing_enable``. DeepSpeed ZeRO-1/2 + reentrant checkpointing runs backward twice per segment and hits: "parameter ... has already been reduced". """ if not is_deepspeed_accelerate_config(config_name): return None override = os.environ.get("DYME_GRADIENT_CHECKPOINTING_USE_REENTRANT", "").strip().lower() if override in ("1", "true", "yes", "on"): return {"use_reentrant": True} if override in ("0", "false", "no", "off"): return {"use_reentrant": False} return {"use_reentrant": False} def deepspeed_requires_single_student_forward(config_name: Optional[str] = None) -> bool: """ DeepSpeed ZeRO-1/2 cannot reduce gradients when the student runs multiple forwards in one backward (GRPO micro-chunks + OPSD loop). """ stage = deepspeed_zero_stage(config_name) return stage is not None and stage <= 2 def should_disable_gradient_checkpointing(config_name: Optional[str] = None) -> bool: """Gradient checkpointing also triggers double reduction under ZeRO-1/2.""" return deepspeed_requires_single_student_forward(config_name) def student_forward_chunk_size( batch_size: int, has_vision: bool, config_name: Optional[str] = None, ) -> int: """ Micro-batch size for student forwards in ``_get_per_token_logps``. Under ZeRO-1/2 we must use one forward per backward (full local batch by default). Override with ``DYME_STUDENT_FORWARD_CHUNK`` only if you accept ZeRO-3+ or OOM risk. """ if not has_vision: return batch_size if not deepspeed_requires_single_student_forward(config_name): return 1 override = os.environ.get("DYME_STUDENT_FORWARD_CHUNK", "").strip() if override.isdigit(): return max(1, min(batch_size, int(override))) return batch_size