Spaces:
Running
Running
| """ | |
| Apply local compatibility patches to installed TimesFM source files. | |
| Patches: | |
| 1. Strip unknown kwargs that huggingface_hub >= 0.30 may forward into | |
| TimesFM.__init__. | |
| 2. Make checkpoint loading work when the model was created with meta tensors | |
| by using assign=True (or a to_empty fallback). | |
| This script is idempotent and safe to run multiple times. | |
| """ | |
| from __future__ import annotations | |
| import sys | |
| from pathlib import Path | |
| # Patch 1 (hub kwargs): INIT_PATCH is inserted immediately before the first | |
| # occurrence of INIT_TARGET; the INIT_TARGET line itself is kept. INIT_MARKER | |
| # is the idempotency guard: if it already appears in the file, nothing is done. | |
| # NOTE(review): leading whitespace inside every target/patch string below must | |
| # match the installed timesfm_2p5_torch.py byte-for-byte; confirm it has not | |
| # been collapsed (e.g. by copy/paste) before relying on a match. | |
| INIT_MARKER = "# AI Forecast patch: strip unsupported hub kwargs" | |
| INIT_TARGET = " # Create an instance of the model wrapper class.\n" | |
| # Injected code: rebinds model_kwargs to only the keys in _KNOWN_INIT_KWARGS. | |
| # Assumes a local named model_kwargs is in scope at the injection point -- | |
| # TODO confirm against the installed TimesFM version. | |
| INIT_PATCH = """\ | |
| # AI Forecast patch: strip unsupported hub kwargs from huggingface_hub. | |
| _KNOWN_INIT_KWARGS = {"torch_compile", "config"} | |
| model_kwargs = {k: v for k, v in model_kwargs.items() if k in _KNOWN_INIT_KWARGS} | |
| """ | |
| # Patch 2 (meta tensors): the first occurrence of META_TARGET (the | |
| # load_checkpoint text expected in the installed file) is replaced wholesale | |
| # by META_PATCH. META_MARKER must occur verbatim in META_PATCH, including its | |
| # leading whitespace and trailing newline. Otherwise a rerun finds neither | |
| # marker nor target and reports "target not found" instead of "already patched". | |
| META_MARKER = " # AI Forecast patch: support loading checkpoints into meta tensors.\n" | |
| META_TARGET = """\ | |
| def load_checkpoint(self, path: str, **kwargs): | |
| \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\" | |
| tensors = load_file(path) | |
| self.load_state_dict(tensors, strict=True) | |
| self.to(self.device) | |
| """ | |
| # Replacement body: loads with assign=True when any parameter reports is_meta, | |
| # otherwise with a plain strict load. On TypeError (presumably a torch version | |
| # whose load_state_dict lacks `assign` -- confirm), meta parameters are first | |
| # materialized with to_empty(device=self.device) before a plain strict load. | |
| # NOTE(review): the fallback's indentation is not recoverable in this copy. If | |
| # the final strict load sits at the except level, a non-meta TypeError is | |
| # retried with the same call and will most likely raise again -- confirm intent. | |
| META_PATCH = """\ | |
| def load_checkpoint(self, path: str, **kwargs): | |
| \"\"\"Loads a PyTorch TimesFM model from a checkpoint.\"\"\" | |
| tensors = load_file(path) | |
| # AI Forecast patch: support loading checkpoints into meta tensors. | |
| has_meta_parameters = any( | |
| getattr(parameter, "is_meta", False) for parameter in self.parameters() | |
| ) | |
| try: | |
| if has_meta_parameters: | |
| self.load_state_dict(tensors, strict=True, assign=True) | |
| else: | |
| self.load_state_dict(tensors, strict=True) | |
| except TypeError: | |
| if has_meta_parameters: | |
| self.to_empty(device=self.device) | |
| self.load_state_dict(tensors, strict=True) | |
| self.to(self.device) | |
| """ | |
def find_timesfm_torch_files() -> list[Path]:
    """Locate installed copies of TimesFM's 2.5 PyTorch module on ``sys.path``.

    Each ``sys.path`` entry is probed for
    ``timesfm/timesfm_2p5/timesfm_2p5_torch.py``. Only regular files are
    returned. Entries that resolve to the same file (duplicate ``sys.path``
    entries, symlinked directories, ``""`` vs. the current directory) are
    reported once, keeping ``sys.path`` order. Without this, the same file
    would be patched and reported several times.

    Returns:
        Paths of the candidate files, possibly empty.
    """
    relative = Path("timesfm") / "timesfm_2p5" / "timesfm_2p5_torch.py"
    candidates: list[Path] = []
    seen: set[Path] = set()
    for sys_path_entry in sys.path:
        path = Path(sys_path_entry) / relative
        try:
            # is_file() rather than exists(): a directory with this name
            # cannot be read and patched.
            if not path.is_file():
                continue
            key = path.resolve()
        except OSError:
            # Unreadable sys.path entries (e.g. permission errors) are skipped.
            continue
        if key in seen:
            continue
        seen.add(key)
        candidates.append(path)
    return candidates
def apply_patch_once(text: str, *, marker: str, target: str, replacement: str) -> tuple[str, bool, str]:
    """Replace the first ``target`` in ``text`` with ``replacement``, at most once.

    ``marker`` is the idempotency guard. When it is already present, ``text``
    is returned untouched, and this check takes precedence over the target
    lookup.

    Args:
        text: Source text to patch.
        marker: String whose presence means the patch was applied before.
        target: Exact substring to replace; only its first occurrence is used.
        replacement: Text substituted for ``target``.

    Returns:
        ``(text, changed, status)``. ``status`` is ``"already patched"``,
        ``"target not found"`` or ``"patched"``, and ``changed`` is True only
        for ``"patched"``.
    """
    if marker in text:
        outcome = (text, False, "already patched")
    elif target in text:
        outcome = (text.replace(target, replacement, 1), True, "patched")
    else:
        outcome = (text, False, "target not found")
    return outcome
def patch_file(path: Path) -> str:
    """Apply the hub-kwargs patch and then the meta-tensor patch to ``path``.

    The file is rewritten (UTF-8) only when at least one patch changed its
    text. Each patch is skipped if its marker is already present.

    Args:
        path: Installed ``timesfm_2p5_torch.py`` to patch.

    Returns:
        A one-line summary prefixed with ``"patched OK:"`` or ``"no changes:"``.
        It also contains the per-patch status (``"patched"``,
        ``"already patched"`` or ``"target not found"``).
    """
    source_text = path.read_text(encoding="utf-8")

    # (marker, target, replacement) per patch, applied in this order. The init
    # patch re-appends INIT_TARGET so the filter lands just above that line.
    patch_specs = (
        (INIT_MARKER, INIT_TARGET, INIT_PATCH + INIT_TARGET),
        (META_MARKER, META_TARGET, META_PATCH),
    )
    text = source_text
    statuses = []
    for marker, target, replacement in patch_specs:
        text, _changed, status = apply_patch_once(
            text, marker=marker, target=target, replacement=replacement
        )
        statuses.append(status)
    init_status, meta_status = statuses

    details = f"{path} | init={init_status} | meta={meta_status}"
    if text == source_text:
        return f"no changes: {details}"
    path.write_text(text, encoding="utf-8")
    return f"patched OK: {details}"
if __name__ == "__main__":
    # Entry point: patch every TimesFM copy found on sys.path.
    # Exit 1 if none is found or any copy could not be patched, else exit 0.
    files = find_timesfm_torch_files()
    if not files:
        print("[patch_timesfm] No timesfm installation found in sys.path.")
        sys.exit(1)
    failures = 0
    for file_path in files:
        try:
            print(f"[patch_timesfm] {patch_file(file_path)}")
        except (OSError, UnicodeDecodeError) as exc:
            # Keep going: one unreadable/unwritable copy must not leave the
            # remaining candidates unpatched. The exit status still signals it.
            failures += 1
            print(
                f"[patch_timesfm] failed: {file_path} | {type(exc).__name__}: {exc}",
                file=sys.stderr,
            )
    sys.exit(1 if failures else 0)