File size: 2,666 Bytes
2f65125
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
"""Runtime helpers: dependency install, seeding, device selection, memory."""

from __future__ import annotations

import gc
import random
import subprocess
import sys

import torch

from .constants import *
from .io_utils import log


def install_deps() -> None:
    """Install or upgrade the project's Python dependencies with pip.

    Runs ``<current interpreter> -m pip install --upgrade`` for a fixed list
    of packages, logging a message before and after.

    Raises:
        subprocess.CalledProcessError: If pip exits with a non-zero status.
    """
    # Use the running interpreter so the packages land in its environment.
    pip_command = [sys.executable, "-m", "pip", "install", "--upgrade"]
    pip_command.extend(
        ("torch", "tokenizers", "datasets", "tqdm", "psutil", "numpy")
    )
    log("Installing/updating Python packages...")
    subprocess.check_call(pip_command)
    log("Dependencies are ready.")


def require_package(import_name: str, pip_name: str | None = None):
    try:
        return __import__(import_name)
    except ImportError as exc:
        pkg = pip_name or import_name
        raise SystemExit(
            f"Missing package '{pkg}'. Run:\n  python3 main.py install"
        ) from exc


def set_seed(seed: int) -> None:
    """Seed the random number generators this module knows about.

    Seeds Python's ``random`` module and torch via ``torch.manual_seed``;
    when CUDA is available, ``torch.cuda.manual_seed_all`` is called as well.
    NumPy is not seeded here.

    Args:
        seed: Value handed to every seeding call.
    """
    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)


def device_for_training(force_cpu: bool = False) -> torch.device:
    """Pick the torch device to use.

    Priority: CPU when ``force_cpu`` is truthy, otherwise MPS if available,
    otherwise CUDA if available, otherwise CPU.

    Args:
        force_cpu: When truthy, skip all accelerator checks.

    Returns:
        A ``torch.device`` of type ``"cpu"``, ``"mps"`` or ``"cuda"``.
    """
    if force_cpu:
        device_name = "cpu"
    elif torch.backends.mps.is_available():
        # MPS is checked before CUDA.
        device_name = "mps"
    elif torch.cuda.is_available():
        device_name = "cuda"
    else:
        device_name = "cpu"
    return torch.device(device_name)


def cleanup_device(device: torch.device) -> None:
    """Run garbage collection and release cached memory for ``device``.

    Always calls ``gc.collect()``. For a CUDA device it then empties the CUDA
    cache. For an MPS device it empties the MPS cache and synchronizes; any
    exception raised by those two MPS calls is swallowed. Other device types
    get no further action.

    Args:
        device: Device whose ``type`` selects the cleanup path.
    """
    gc.collect()
    backend = device.type
    if backend == "cuda":
        torch.cuda.empty_cache()
        return
    if backend != "mps":
        return
    try:
        torch.mps.empty_cache()
        torch.mps.synchronize()
    except Exception:
        # Best-effort: a failed MPS cleanup must not interrupt the caller.
        pass


def memory_fraction() -> tuple[float, float, float]:
    """Report this process's memory use relative to total system memory.

    Returns:
        A tuple ``(rss_gib, total_gib, fraction)`` where ``rss_gib`` is the
        current process's resident set size, ``total_gib`` is the total
        virtual memory reported by psutil (both bytes divided by 1024**3),
        and ``fraction`` is ``rss_gib / total_gib``. If psutil cannot be
        imported, the fixed fallback ``(0.0, 32.0, 0.0)`` is returned.
    """
    try:
        import psutil
    except ImportError:
        return 0.0, 32.0, 0.0

    bytes_per_gib = 1024 ** 3
    process_rss_gib = psutil.Process().memory_info().rss / bytes_per_gib
    system_total_gib = psutil.virtual_memory().total / bytes_per_gib
    # Floor the denominator so a zero total cannot divide by zero.
    used_share = process_rss_gib / max(system_total_gib, 1e-6)
    return process_rss_gib, system_total_gib, used_share


def recommended_runtime_settings(total_gb: float | None = None) -> dict:
    if total_gb is None:
        _, total_gb, _ = memory_fraction()
    if total_gb <= 18:
        return {
            "max_len": 256,
            "batch_size": 1,
            "grad_accum": 16,
            "memory_stop_fraction": 0.72,
            "note": "16 GB safety mode",
        }
    if total_gb <= 36:
        return {
            "max_len": 512,
            "batch_size": 2,
            "grad_accum": 8,
            "memory_stop_fraction": 0.78,
            "note": "32 GB recommended: about 7 GB RAM headroom",
        }
    return {
        "max_len": 512,
        "batch_size": 4,
        "grad_accum": 8,
        "memory_stop_fraction": 0.82,
        "note": "64 GB+ throughput mode",
    }