# SuperAI_Forecast/scripts/patch_timesfm.py
# (page residue preserved as a comment: author Thang6822,
#  commit "Update branding to SuperAI Forecast", 9734b71)
"""
Apply local compatibility patches to installed TimesFM source files.
Patches:
1. Strip unknown kwargs that huggingface_hub >= 0.30 may forward into
TimesFM.__init__.
2. Make checkpoint loading work when the model was created with meta tensors
by using assign=True (or a to_empty fallback).
This script is idempotent and safe to run multiple times.
"""
from __future__ import annotations
import sys
from pathlib import Path
# Patch 1: filter the kwargs huggingface_hub >= 0.30 forwards into
# TimesFM.__init__. The replacement is inserted *before* the target line
# (see patch_file), so a re-run finds INIT_MARKER and skips the file.
INIT_MARKER = "# AI Forecast patch: strip unsupported hub kwargs"
INIT_TARGET = "    # Create an instance of the model wrapper class.\n"
INIT_PATCH = """\
    # AI Forecast patch: strip unsupported hub kwargs from huggingface_hub.
    _KNOWN_INIT_KWARGS = {"torch_compile", "config"}
    model_kwargs = {k: v for k, v in model_kwargs.items() if k in _KNOWN_INIT_KWARGS}
"""

# Patch 2: replace load_checkpoint wholesale so models created with meta
# tensors can be materialized — load_state_dict(..., assign=True) on new
# torch, or to_empty() + a normal load on torch without the `assign` kwarg.
# NOTE(review): the indentation inside these literals must byte-match the
# installed timesfm source (Google-style 2-space indent assumed — confirm
# against the installed timesfm_2p5_torch.py), otherwise apply_patch_once
# reports "target not found".
META_MARKER = "    # AI Forecast patch: support loading checkpoints into meta tensors.\n"
META_TARGET = """\
  def load_checkpoint(self, path: str, **kwargs):
    \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\"
    tensors = load_file(path)
    self.load_state_dict(tensors, strict=True)
    self.to(self.device)
"""
META_PATCH = """\
  def load_checkpoint(self, path: str, **kwargs):
    \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\"
    tensors = load_file(path)
    # AI Forecast patch: support loading checkpoints into meta tensors.
    has_meta_parameters = any(
        getattr(parameter, "is_meta", False) for parameter in self.parameters()
    )
    try:
      if has_meta_parameters:
        self.load_state_dict(tensors, strict=True, assign=True)
      else:
        self.load_state_dict(tensors, strict=True)
    except TypeError:
      if has_meta_parameters:
        self.to_empty(device=self.device)
      self.load_state_dict(tensors, strict=True)
    self.to(self.device)
"""
def find_timesfm_torch_files() -> list[Path]:
    """Return every installed copy of the TimesFM 2.5 torch module found on sys.path."""
    relative = Path("timesfm") / "timesfm_2p5" / "timesfm_2p5_torch.py"
    return [
        candidate
        for entry in sys.path
        if (candidate := Path(entry) / relative).exists()
    ]
def apply_patch_once(text: str, *, marker: str, target: str, replacement: str) -> tuple[str, bool, str]:
    """Replace the first occurrence of *target* in *text*, once.

    Idempotence guard: if *marker* already appears in *text* the patch was
    applied on a previous run and *text* is returned untouched.

    Returns:
        (new_text, changed, status) where status is one of
        "already patched", "target not found", or "patched".
    """
    if marker in text:
        return text, False, "already patched"
    if target not in text:
        return text, False, "target not found"
    patched_text = text.replace(target, replacement, 1)
    return patched_text, True, "patched"
def patch_file(path: Path) -> str:
    """Apply both compatibility patches to *path*, writing only when something changed.

    Returns a one-line human-readable summary of what happened to the file.
    """
    source = path.read_text(encoding="utf-8")
    text, _, init_status = apply_patch_once(
        source,
        marker=INIT_MARKER,
        target=INIT_TARGET,
        # Insert the kwarg filter immediately before the target line.
        replacement=INIT_PATCH + INIT_TARGET,
    )
    text, _, meta_status = apply_patch_once(
        text,
        marker=META_MARKER,
        target=META_TARGET,
        replacement=META_PATCH,
    )
    summary = f"{path} | init={init_status} | meta={meta_status}"
    if text == source:
        return f"no changes: {summary}"
    path.write_text(text, encoding="utf-8")
    return f"patched OK: {summary}"
if __name__ == "__main__":
    # Patch every installed copy; exit non-zero only when none was found.
    patch_targets = find_timesfm_torch_files()
    if not patch_targets:
        print("[patch_timesfm] No timesfm installation found in sys.path.")
        sys.exit(1)
    for target_path in patch_targets:
        print(f"[patch_timesfm] {patch_file(target_path)}")
    sys.exit(0)