File size: 5,898 Bytes
79e6483
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import importlib
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

from packaging.utils import canonicalize_name


# Default wheelhouse location, joined onto the repository root when no
# --wheel-dir override is given.
DEFAULT_WHEEL_SUBDIR = "vendor/wheels/kaggle_py312_manylinux_x86_64"
# Optional file inside the wheelhouse listing the expected wheel filenames,
# one per line; blank lines and lines starting with "#" are ignored.
WHEEL_MANIFEST_NAME = "_wheel_manifest.txt"
# Distributions that must each have a wheel in the wheelhouse. They are passed
# to a single `pip install --no-index --no-deps` call in this order, so every
# dependency has to be listed explicitly here.
OFFLINE_PACKAGES = [
    "annotated-types",
    "blosc2",
    "certifi",
    "charset-normalizer",
    "dill",
    "filelock",
    "fire",
    "idna",
    "joblib",
    "loguru",
    "msgpack",
    "ndindex",
    "numexpr",
    "packaging",
    "py-cpuinfo",
    "pydantic",
    "pydantic-core",
    "pydantic-settings",
    "pyyaml",
    "pyqlib",
    "python-dotenv",
    "python-redis-lock",
    "redis",
    "requests",
    "ruamel.yaml",
    "setuptools-scm",
    "tables",
    "termcolor",
    "tqdm",
    "typing-extensions",
    "typing-inspection",
    "urllib3",
]


def _can_import_runtime() -> tuple[bool, str]:
    try:
        importlib.import_module("tables")
        import qlib  # noqa: F401
        from qlib.backtest import backtest  # noqa: F401
        from qlib.backtest.executor import SimulatorExecutor  # noqa: F401
        from qlib.contrib.strategy.signal_strategy import TopkDropoutStrategy  # noqa: F401
        return True, ""
    except Exception as exc:  # pragma: no cover - best effort diagnostic
        return False, f"{type(exc).__name__}: {exc}"


def _resolve_repo_root() -> Path:
    return Path(__file__).resolve().parents[2]


def _resolve_wheel_dir(repo_root: Path, override: str | None) -> Path:
    if override:
        return Path(override).expanduser().resolve()
    return (repo_root / DEFAULT_WHEEL_SUBDIR).resolve()


def _read_wheel_manifest(wheel_dir: Path) -> list[str]:
    """Read the expected wheel filenames from the wheelhouse manifest.

    Args:
        wheel_dir: Directory that may contain the manifest file.

    Returns:
        The manifest entries in file order with surrounding whitespace
        removed, skipping blank lines and ``#`` comment lines. An empty list
        is returned when the manifest file does not exist.
    """
    manifest_path = wheel_dir / WHEEL_MANIFEST_NAME
    if not manifest_path.exists():
        return []
    # Decode explicitly as UTF-8 so the result does not depend on the
    # locale of the machine running the installer.
    text = manifest_path.read_text(encoding="utf-8")
    entries = (line.strip() for line in text.splitlines())
    return [entry for entry in entries if entry and not entry.startswith("#")]


def _prepare_wheelhouse(wheel_dir: Path) -> Path:
    """Return a directory of properly named wheels that pip can install from.

    When ``wheel_dir`` already holds ``*.whl`` files and has no manifest it is
    returned unchanged. Otherwise a fresh temporary directory is created and
    populated with:

    * every manifest entry, copied under its manifest name from either a
      file with exactly that name or the single file whose name is a prefix
      of it (a truncated filename); entries with no unambiguous source are
      skipped;
    * every ``*.whl`` file in ``wheel_dir`` not already placed by the step
      above.

    Args:
        wheel_dir: Source wheelhouse directory.

    Returns:
        ``wheel_dir`` itself, or the path of the new temporary directory.
    """
    manifest_names = _read_wheel_manifest(wheel_dir)
    wheels_on_disk = sorted(wheel_dir.glob("*.whl"))
    if wheels_on_disk and not manifest_names:
        return wheel_dir

    # Every regular file except the manifest is a candidate source, keyed by
    # its on-disk (possibly truncated) name.
    candidates: dict[str, Path] = {}
    for entry in sorted(wheel_dir.iterdir()):
        if entry.is_file() and entry.name != WHEEL_MANIFEST_NAME:
            candidates[entry.name] = entry

    staging_dir = Path(tempfile.mkdtemp(prefix="aae_kaggle_wheels_"))

    renamed_count = 0
    for wanted in manifest_names:
        origin = candidates.get(wanted)
        if origin is None:
            # No exact match: accept a file whose name is a prefix of the
            # expected name, but only when exactly one such file exists.
            prefix_hits = [found for stored, found in candidates.items() if wanted.startswith(stored)]
            if len(prefix_hits) != 1:
                continue
            origin = prefix_hits[0]
        shutil.copy2(origin, staging_dir / wanted)
        if origin.name != wanted:
            renamed_count += 1
    if renamed_count:
        print(f"Restored {renamed_count} truncated wheel filename(s) into {staging_dir}", flush=True)

    # Carry over correctly named wheels that the manifest did not cover.
    for wheel in wheels_on_disk:
        destination = staging_dir / wheel.name
        if destination.exists():
            continue
        shutil.copy2(wheel, destination)

    return staging_dir


def _install_from_wheels(wheel_dir: Path, force_reinstall: bool) -> None:
    """Install every package in ``OFFLINE_PACKAGES`` from local wheel files.

    Args:
        wheel_dir: Source wheelhouse directory.
        force_reinstall: When true, pass ``--force-reinstall`` to pip.

    Raises:
        FileNotFoundError: If any required package has no wheel in the
            prepared wheelhouse.
        subprocess.CalledProcessError: If the pip command exits non-zero.
    """
    prepared_wheel_dir = _prepare_wheelhouse(wheel_dir)

    # Map canonical distribution name -> wheel path. The distribution name is
    # the part of the wheel filename before the first "-". If several wheels
    # share a name, the last one in sorted order wins.
    wheel_map: dict[str, Path] = {
        canonicalize_name(wheel_path.name.split("-", 1)[0]): wheel_path
        for wheel_path in sorted(prepared_wheel_dir.glob("*.whl"))
    }

    missing = [pkg for pkg in OFFLINE_PACKAGES if canonicalize_name(pkg) not in wheel_map]
    if missing:
        # The prepared directory is intentionally left in place here so the
        # path reported in the message can still be inspected.
        raise FileNotFoundError(
            "Missing offline wheels for packages: "
            + ", ".join(missing)
            + f". wheel_dir={wheel_dir} prepared_wheel_dir={prepared_wheel_dir}"
        )

    cmd = [
        sys.executable,
        "-m",
        "pip",
        "install",
        "--no-index",
        "--no-deps",
    ]
    if force_reinstall:
        cmd.append("--force-reinstall")
    cmd.extend(str(wheel_map[canonicalize_name(pkg)]) for pkg in OFFLINE_PACKAGES)
    print("Running offline install:\n ", " ".join(cmd), flush=True)
    try:
        subprocess.run(cmd, check=True)
    finally:
        # A prepared directory that differs from the source wheelhouse is a
        # temporary copy; remove it once pip has finished so repeated runs do
        # not accumulate wheel copies. The source wheelhouse is never removed.
        if prepared_wheel_dir != wheel_dir:
            shutil.rmtree(prepared_wheel_dir, ignore_errors=True)


def main() -> None:
    """Command-line entry point: verify, and if needed install, the runtime.

    Flags:
        --wheel-dir: Override the wheelhouse location.
        --force-reinstall: Reinstall even when the imports already work.
        --check-only: Only verify the imports; never install.

    Raises:
        FileNotFoundError: If the wheelhouse directory does not exist.
        RuntimeError: If ``--check-only`` is given and the imports fail, or
            if the imports still fail after installation.
    """
    parser = argparse.ArgumentParser(description="Install Kaggle backtest runtime from wheels bundled in the dataset.")
    parser.add_argument("--wheel-dir", default=None, help="Override local wheelhouse path.")
    parser.add_argument("--force-reinstall", action="store_true", help="Force reinstall even if imports already work.")
    parser.add_argument("--check-only", action="store_true", help="Only verify imports; do not install.")
    args = parser.parse_args()

    repo_root = _resolve_repo_root()
    wheel_dir = _resolve_wheel_dir(repo_root, args.wheel_dir)

    if not wheel_dir.exists():
        raise FileNotFoundError(f"Offline wheel directory not found: {wheel_dir}")

    ok, detail = _can_import_runtime()
    if args.check_only:
        # --check-only takes precedence over --force-reinstall: it must never
        # install, and must only fail when the imports really are broken.
        if not ok:
            raise RuntimeError(f"Offline runtime check failed: {detail}")
        print("Offline runtime already available; skipping install.", flush=True)
    elif ok and not args.force_reinstall:
        print("Offline runtime already available; skipping install.", flush=True)
    else:
        if ok:
            # Only reachable with --force-reinstall; say so instead of
            # reporting a failure that did not happen.
            print("Offline runtime already available; reinstalling because --force-reinstall was given.", flush=True)
        else:
            print(f"Runtime import check failed before install: {detail}", flush=True)
        _install_from_wheels(wheel_dir, force_reinstall=args.force_reinstall)
        # pip ran in a subprocess; drop the import system's cached directory
        # listings so the freshly installed packages are visible below.
        importlib.invalidate_caches()

    ok, detail = _can_import_runtime()
    if not ok:
        raise RuntimeError(f"Offline runtime verification failed after install: {detail}")

    import qlib
    import tables

    print("qlib import OK:", qlib.__file__, flush=True)
    print("tables import OK:", tables.__file__, flush=True)
    print("wheel_dir =", wheel_dir, flush=True)


# Run the installer only when executed as a script, not when imported.
if __name__ == "__main__":
    main()