EDEN / eden /runtime.py
Rybib's picture
Upload EDEN model and code
2f65125 verified
Raw
History Blame Contribute Delete
2.67 kB
"""Runtime helpers: dependency install, seeding, device selection, memory."""
from __future__ import annotations

import gc
import importlib
import random
import subprocess
import sys

import torch

from .constants import *
from .io_utils import log
def install_deps() -> None:
    """Upgrade the project's runtime dependencies using this interpreter's pip.

    Runs ``<sys.executable> -m pip install --upgrade`` over a fixed package
    list. ``subprocess.check_call`` raises ``CalledProcessError`` if pip exits
    with a non-zero status.
    """
    required = ("torch", "tokenizers", "datasets", "tqdm", "psutil", "numpy")
    log("Installing/updating Python packages...")
    # Invoke pip through the running interpreter so packages land in the same env.
    subprocess.check_call(
        [sys.executable, "-m", "pip", "install", "--upgrade", *required]
    )
    log("Dependencies are ready.")
def require_package(import_name: str, pip_name: str | None = None):
try:
return __import__(import_name)
except ImportError as exc:
pkg = pip_name or import_name
raise SystemExit(
f"Missing package '{pkg}'. Run:\n python3 main.py install"
) from exc
def set_seed(seed: int) -> None:
    """Seed Python's ``random`` module and torch, plus all CUDA devices if present.

    Args:
        seed: Value passed unchanged to every RNG seeding call.
    """
    torch.manual_seed(seed)
    random.seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def device_for_training(force_cpu: bool = False) -> torch.device:
    """Choose the device to train on.

    Preference order: CPU when ``force_cpu`` is set; otherwise Apple MPS if
    available, then CUDA if available, and finally CPU.

    Args:
        force_cpu: Skip accelerator detection and always return the CPU device.
    """
    chosen = "cpu"
    if not force_cpu:
        if torch.backends.mps.is_available():
            chosen = "mps"
        elif torch.cuda.is_available():
            chosen = "cuda"
    return torch.device(chosen)
def cleanup_device(device: torch.device) -> None:
    """Run Python GC, then release the backend's cached allocator memory.

    - CUDA: ``torch.cuda.empty_cache()``.
    - MPS: ``torch.mps.empty_cache()`` followed by ``torch.mps.synchronize()``.
      This is best-effort, and any exception is ignored.
    - Any other device type: only the ``gc.collect()`` pass runs.

    Args:
        device: Device whose ``.type`` selects the backend cleanup.
    """
    gc.collect()
    backend = device.type
    if backend == "cuda":
        torch.cuda.empty_cache()
    elif backend == "mps":
        try:
            torch.mps.empty_cache()
            torch.mps.synchronize()
        except Exception:
            # Deliberately best-effort: a failed MPS cache flush must not
            # interrupt the caller.
            pass
def memory_fraction() -> tuple[float, float, float]:
    """Report this process's memory use relative to total system RAM.

    Returns:
        ``(rss_gib, total_gib, rss_gib / total_gib)``. The process RSS and the
        total system memory are both in GiB. The denominator is clamped to at
        least ``1e-6``. If ``psutil`` cannot be imported, returns the fixed
        fallback ``(0.0, 32.0, 0.0)``.
    """
    try:
        import psutil
    except ImportError:
        return 0.0, 32.0, 0.0
    gib = 1024 ** 3
    rss_gib = psutil.Process().memory_info().rss / gib
    total_gib = psutil.virtual_memory().total / gib
    ratio = rss_gib / max(total_gib, 1e-6)
    return rss_gib, total_gib, ratio
def recommended_runtime_settings(total_gb: float | None = None) -> dict:
if total_gb is None:
_, total_gb, _ = memory_fraction()
if total_gb <= 18:
return {
"max_len": 256,
"batch_size": 1,
"grad_accum": 16,
"memory_stop_fraction": 0.72,
"note": "16 GB safety mode",
}
if total_gb <= 36:
return {
"max_len": 512,
"batch_size": 2,
"grad_accum": 8,
"memory_stop_fraction": 0.78,
"note": "32 GB recommended: about 7 GB RAM headroom",
}
return {
"max_len": 512,
"batch_size": 4,
"grad_accum": 8,
"memory_stop_fraction": 0.82,
"note": "64 GB+ throughput mode",
}