File size: 3,397 Bytes
febae40
9734b71
febae40
9734b71
 
 
 
 
 
 
febae40
9734b71
 
 
febae40
 
 
9734b71
 
 
 
febae40
 
 
9734b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
febae40
 
9734b71
 
 
 
 
 
febae40
 
 
9734b71
 
 
 
 
 
febae40
 
9734b71
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
febae40
9734b71
 
 
 
febae40
 
 
9734b71
febae40
 
 
 
9734b71
 
febae40
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
"""
Apply local compatibility patches to installed TimesFM source files.

Patches:
1. Strip unknown kwargs that huggingface_hub >= 0.30 may forward into
   TimesFM.__init__.
2. Make checkpoint loading work when the model was created with meta tensors
   by using assign=True (or a to_empty fallback).

This script is idempotent and safe to run multiple times.
"""

from __future__ import annotations

import sys
from pathlib import Path

# Substring that patch_file() searches for to decide the init-kwargs patch is
# already present; it is the leading part of the comment line in INIT_PATCH.
INIT_MARKER = "# AI Forecast patch: strip unsupported hub kwargs"
# Anchor line expected in the TimesFM source. patch_file() replaces it with
# INIT_PATCH + INIT_TARGET, so the patch lands immediately before the anchor.
INIT_TARGET = "    # Create an instance of the model wrapper class.\n"
# Code inserted before the anchor: rebuilds model_kwargs keeping only the
# "torch_compile" and "config" keys and dropping every other entry.
INIT_PATCH = """\
    # AI Forecast patch: strip unsupported hub kwargs from huggingface_hub.
    _KNOWN_INIT_KWARGS = {"torch_compile", "config"}
    model_kwargs = {k: v for k, v in model_kwargs.items() if k in _KNOWN_INIT_KWARGS}

"""

# Full comment line (indentation and trailing newline included) that appears in
# META_PATCH; finding it in a file means the checkpoint patch is already applied.
META_MARKER = "    # AI Forecast patch: support loading checkpoints into meta tensors.\n"
# Exact text of the load_checkpoint method the patch expects to find. Matching
# is by plain substring, so any difference in the installed source (whitespace
# included) results in "target not found".
META_TARGET = """\
  def load_checkpoint(self, path: str, **kwargs):
    \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\"
    tensors = load_file(path)
    self.load_state_dict(tensors, strict=True)
    self.to(self.device)
"""
# Replacement load_checkpoint. When any parameter reports is_meta it loads with
# assign=True; if load_state_dict raises TypeError it falls back to to_empty()
# (meta case only) followed by a plain strict load.
# NOTE(review): the TypeError fallback presumably targets torch versions whose
# load_state_dict does not accept assign= -- confirm against supported versions.
META_PATCH = """\
  def load_checkpoint(self, path: str, **kwargs):
    \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\"
    tensors = load_file(path)
    # AI Forecast patch: support loading checkpoints into meta tensors.
    has_meta_parameters = any(
      getattr(parameter, "is_meta", False) for parameter in self.parameters()
    )
    try:
      if has_meta_parameters:
        self.load_state_dict(tensors, strict=True, assign=True)
      else:
        self.load_state_dict(tensors, strict=True)
    except TypeError:
      if has_meta_parameters:
        self.to_empty(device=self.device)
      self.load_state_dict(tensors, strict=True)
    self.to(self.device)
"""


def find_timesfm_torch_files() -> list[Path]:
    """Locate installed copies of the TimesFM 2.5 torch module on ``sys.path``.

    Each ``sys.path`` entry is probed for
    ``timesfm/timesfm_2p5/timesfm_2p5_torch.py``.

    Returns:
        The matching paths in ``sys.path`` order, built from the entries as
        they appear in ``sys.path`` (not resolved). A file reachable through
        several entries (duplicates, symlinks, relative vs. absolute spellings)
        is returned once, so it is never processed or reported twice.
    """
    relative_target = Path("timesfm") / "timesfm_2p5" / "timesfm_2p5_torch.py"
    seen: set[Path] = set()
    candidates: list[Path] = []
    for sys_path_entry in sys.path:
        path = Path(sys_path_entry) / relative_target
        # is_file() rather than exists(): a directory with this name cannot be
        # read or patched as text.
        if not path.is_file():
            continue
        # Deduplicate on the resolved location so aliases of the same file
        # collapse to a single candidate.
        resolved = path.resolve()
        if resolved in seen:
            continue
        seen.add(resolved)
        candidates.append(path)
    return candidates


def apply_patch_once(text: str, *, marker: str, target: str, replacement: str) -> tuple[str, bool, str]:
    """Replace the first ``target`` in ``text`` unless ``marker`` is present.

    Args:
        text: Source text to patch.
        marker: Substring whose presence means the patch was applied earlier.
        target: Substring to look for and replace.
        replacement: Text substituted for the first occurrence of ``target``.

    Returns:
        A ``(text, changed, status)`` triple. ``status`` is one of
        ``"already patched"``, ``"target not found"`` or ``"patched"``; the
        text is returned unmodified for the first two.
    """
    # The marker check comes first: an already-patched text is left alone even
    # if the target still occurs in it.
    if marker in text:
        outcome = (text, False, "already patched")
    elif target in text:
        outcome = (text.replace(target, replacement, 1), True, "patched")
    else:
        outcome = (text, False, "target not found")
    return outcome


def patch_file(path: Path) -> str:
    """Apply the init-kwargs and checkpoint-loading patches to one file.

    The file is read as UTF-8, both patches are attempted in order (init
    first, then checkpoint loading), and the file is rewritten only if the
    resulting text differs from what was read.

    Args:
        path: Location of the TimesFM torch source file to patch.

    Returns:
        A one-line report starting with ``"patched OK"`` when the file was
        rewritten or ``"no changes"`` otherwise, followed by the path and the
        per-patch status strings.
    """
    before = path.read_text(encoding="utf-8")

    # The init patch is inserted ahead of its anchor line, so the anchor is
    # re-appended to the replacement text.
    after_init, _, init_status = apply_patch_once(
        before,
        marker=INIT_MARKER,
        target=INIT_TARGET,
        replacement=INIT_PATCH + INIT_TARGET,
    )
    # The checkpoint patch swaps the whole load_checkpoint method.
    after_meta, _, meta_status = apply_patch_once(
        after_init,
        marker=META_MARKER,
        target=META_TARGET,
        replacement=META_PATCH,
    )

    if after_meta == before:
        return f"no changes: {path} | init={init_status} | meta={meta_status}"

    path.write_text(after_meta, encoding="utf-8")
    return f"patched OK: {path} | init={init_status} | meta={meta_status}"


if __name__ == "__main__":
    # Entry point: patch every TimesFM copy found on sys.path and print one
    # report line per file. Exit status is 1 only when no copy was found.
    targets = find_timesfm_torch_files()
    if targets:
        for target_path in targets:
            print(f"[patch_timesfm] {patch_file(target_path)}")
        sys.exit(0)

    print("[patch_timesfm] No timesfm installation found in sys.path.")
    sys.exit(1)