File size: 1,905 Bytes
dbc69f3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path

from .utils import repository_root


@dataclass(frozen=True)
class StorageLayout:
    """Immutable record of the directories that make up the storage tree.

    The dataclass is frozen, so fields cannot be reassigned after
    construction. The per-field notes below describe how ``storage_layout``
    in this module populates each field; the class itself does not enforce
    those relationships.
    """

    # Top of the storage tree; every other field is derived from it.
    root: Path
    # root / "datasets"
    datasets: Path
    # root / "models"
    models: Path
    # root / "hf"
    hf: Path
    # root / "artifacts"; parent of ``runs``, ``sweeps`` and ``evals``.
    artifacts: Path
    # artifacts / "runs"
    runs: Path
    # artifacts / "sweeps"
    sweeps: Path
    # artifacts / "eval" -- note the on-disk name is singular.
    evals: Path
    # root / "logs"; parent of ``wandb``.
    logs: Path
    # logs / "wandb"
    wandb: Path


def storage_layout(cache_dir: str | Path = "cache") -> StorageLayout:
    """Describe the storage tree rooted at *cache_dir* without touching disk.

    Args:
        cache_dir: Storage root. It is passed to ``_resolve_storage_root``,
            which turns it into an absolute, resolved path.

    Returns:
        A ``StorageLayout`` whose fields are all located under the resolved
        root. No directory is created by this function.
    """
    base = _resolve_storage_root(cache_dir)
    artifact_dir = base.joinpath("artifacts")
    log_dir = base.joinpath("logs")

    fields = {
        "root": base,
        # Direct children of the root.
        "datasets": base.joinpath("datasets"),
        "models": base.joinpath("models"),
        "hf": base.joinpath("hf"),
        # Everything produced by runs lives below ``artifacts``.
        "artifacts": artifact_dir,
        "runs": artifact_dir.joinpath("runs"),
        "sweeps": artifact_dir.joinpath("sweeps"),
        # The field is plural but the directory name is the singular "eval".
        "evals": artifact_dir.joinpath("eval"),
        "logs": log_dir,
        "wandb": log_dir.joinpath("wandb"),
    }
    return StorageLayout(**fields)


def ensure_storage_layout(cache_dir: str | Path = "cache") -> StorageLayout:
    """Build the storage layout for *cache_dir* and create its directories.

    Args:
        cache_dir: Storage root, forwarded unchanged to ``storage_layout``.

    Returns:
        The same ``StorageLayout`` that ``storage_layout`` produced, after
        every directory it names has been created.
    """
    layout = storage_layout(cache_dir)

    # Parents are listed before their children; ``parents=True`` would cover
    # the ordering anyway, and ``exist_ok=True`` makes repeated calls harmless
    # for directories that are already present.
    required_dirs = (
        layout.root,
        layout.datasets,
        layout.models,
        layout.hf,
        layout.artifacts,
        layout.runs,
        layout.sweeps,
        layout.evals,
        layout.logs,
        layout.wandb,
    )
    for directory in required_dirs:
        directory.mkdir(parents=True, exist_ok=True)

    return layout


def resolve_storage_path(path: str | Path, cache_dir: str | Path = "cache") -> Path:
    """Resolve *path* to an absolute location inside the storage tree.

    Resolution rules:

    * An absolute *path* is returned resolved, independent of *cache_dir*.
    * A relative *path* whose first component is the storage root's own
      directory name (for example ``cache/datasets/x`` with the default
      ``cache_dir``) is treated as already naming the root: that leading
      component is dropped and the remainder is placed under the root.
    * Any other relative *path* is placed under the storage root as is.

    Args:
        path: Absolute path, or a path relative to the storage root
            (optionally prefixed with the root's directory name).
        cache_dir: Storage root, interpreted by ``_resolve_storage_root``.

    Returns:
        The absolute, resolved path.
    """
    candidate = Path(path)
    if candidate.is_absolute():
        return candidate.resolve()

    cache_root = _resolve_storage_root(cache_dir)

    # The root can be referred to by the name it was configured with or by
    # the name it has after symlink resolution; these differ when the cache
    # directory is a symlink. "" (filesystem root / ".") and ".." are not
    # usable directory names, so they never count as a prefix.
    root_names = {Path(cache_dir).name, cache_root.name} - {"", ".."}

    parts = candidate.parts
    if parts and parts[0] in root_names:
        # Anchor at the storage root itself rather than at the repository
        # root, so the result stays inside the root even when ``cache_dir``
        # is absolute, nested below the repository, or a symlink.
        return cache_root.joinpath(*parts[1:]).resolve()

    return (cache_root / candidate).resolve()


def _resolve_storage_root(cache_dir: str | Path) -> Path:
    """Return the absolute, resolved storage root for *cache_dir*.

    An absolute *cache_dir* is used as given; a relative one is taken to be
    relative to ``repository_root()``. Either way the result is passed
    through ``Path.resolve()``.
    """
    root = Path(cache_dir)
    if not root.is_absolute():
        # Relative roots are anchored at the repository, not the process CWD.
        root = repository_root() / root
    return root.resolve()