| """ |
| patch_annotators.py |
| =================== |
| annotator ๅ
ใฎไบๆๆงๅ้กใใฉใณใฟใคใ ใงใใใใใใขใธใฅใผใซใ |
| app.py ใฎๅ้ ญใง import ใใใ ใใงๆๅนใซใชใใพใใ |
| |
| ๅฏพๅฟใใๅ้ก: |
| 1. annotator/zoe/__init__.py: torch.load ใซ strict=False ใๅฟ
่ฆ |
| (timm ใฎใใผใธใงใณใขใใใง relative_position_index ใญใผใๆถใใใใ) |
| """ |
|
|
| import importlib |
| import sys |
| from unittest.mock import patch |
|
|
|
|
| |
|
|
| def _patch_zoe(): |
| """ |
| annotator.zoe.__init__ ใฎ ZoeDetector.__init__ ใ |
| model.load_state_dict(...) ใ strict=True (ใใใฉใซใ) ใงๅผใถใฎใ |
| strict=False ใซๅทฎใๆฟใใใ |
| """ |
| try: |
| import torch |
| _original_load_state_dict = torch.nn.Module.load_state_dict |
|
|
| def _patched_load_state_dict(self, state_dict, strict=True, **kwargs): |
| |
| cls_name = type(self).__name__ |
| if cls_name == "ZoeDepth": |
| strict = False |
| return _original_load_state_dict(self, state_dict, strict=strict, **kwargs) |
|
|
| torch.nn.Module.load_state_dict = _patched_load_state_dict |
| print("[patch_annotators] Patched: ZoeDepth.load_state_dict -> strict=False") |
| except Exception as e: |
| print(f"[patch_annotators] Warning: Could not patch ZoeDepth: {e}") |
|
|
|
|
| _patch_zoe() |
|
|