""" Apply local compatibility patches to installed TimesFM source files. Patches: 1. Strip unknown kwargs that huggingface_hub >= 0.30 may forward into TimesFM.__init__. 2. Make checkpoint loading work when the model was created with meta tensors by using assign=True (or a to_empty fallback). This script is idempotent and safe to run multiple times. """ from __future__ import annotations import sys from pathlib import Path INIT_MARKER = "# AI Forecast patch: strip unsupported hub kwargs" INIT_TARGET = " # Create an instance of the model wrapper class.\n" INIT_PATCH = """\ # AI Forecast patch: strip unsupported hub kwargs from huggingface_hub. _KNOWN_INIT_KWARGS = {"torch_compile", "config"} model_kwargs = {k: v for k, v in model_kwargs.items() if k in _KNOWN_INIT_KWARGS} """ META_MARKER = " # AI Forecast patch: support loading checkpoints into meta tensors.\n" META_TARGET = """\ def load_checkpoint(self, path: str, **kwargs): \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\" tensors = load_file(path) self.load_state_dict(tensors, strict=True) self.to(self.device) """ META_PATCH = """\ def load_checkpoint(self, path: str, **kwargs): \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\" tensors = load_file(path) # AI Forecast patch: support loading checkpoints into meta tensors. has_meta_parameters = any( getattr(parameter, "is_meta", False) for parameter in self.parameters() ) try: if has_meta_parameters: self.load_state_dict(tensors, strict=True, assign=True) else: self.load_state_dict(tensors, strict=True) except TypeError: if has_meta_parameters: self.to_empty(device=self.device) self.load_state_dict(tensors, strict=True) self.to(self.device) """ def find_timesfm_torch_files() -> list[Path]: candidates: list[Path] = [] for sys_path_entry in sys.path: path = Path(sys_path_entry) / "timesfm" / "timesfm_2p5" / "timesfm_2p5_torch.py" if path.exists(): candidates.append(path) return candidates def apply_patch_once(text: str, *, marker: str, target: str, replacement: str) -> tuple[str, bool, str]: if marker in text: return text, False, "already patched" if target not in text: return text, False, "target not found" return text.replace(target, replacement, 1), True, "patched" def patch_file(path: Path) -> str: original = path.read_text(encoding="utf-8") updated = original updated, _, init_status = apply_patch_once( updated, marker=INIT_MARKER, target=INIT_TARGET, replacement=INIT_PATCH + INIT_TARGET, ) updated, _, meta_status = apply_patch_once( updated, marker=META_MARKER, target=META_TARGET, replacement=META_PATCH, ) if updated != original: path.write_text(updated, encoding="utf-8") return f"patched OK: {path} | init={init_status} | meta={meta_status}" return f"no changes: {path} | init={init_status} | meta={meta_status}" if __name__ == "__main__": files = find_timesfm_torch_files() if not files: print("[patch_timesfm] No timesfm installation found in sys.path.") sys.exit(1) for file_path in files: print(f"[patch_timesfm] {patch_file(file_path)}") sys.exit(0)