File size: 1,535 Bytes
dfa370d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""
patch_annotators.py
===================
annotator ๅ†…ใฎไบ’ๆ›ๆ€งๅ•้กŒใ‚’ใƒฉใƒณใ‚ฟใ‚คใƒ ใงใƒ‘ใƒƒใƒใ™ใ‚‹ใƒขใ‚ธใƒฅใƒผใƒซใ€‚
app.py ใฎๅ†’้ ญใง import ใ™ใ‚‹ใ ใ‘ใงๆœ‰ๅŠนใซใชใ‚Šใพใ™ใ€‚

ๅฏพๅฟœใ™ใ‚‹ๅ•้กŒ:
  1. annotator/zoe/__init__.py: torch.load ใซ strict=False ใŒๅฟ…่ฆ
     (timm ใฎใƒใƒผใ‚ธใƒงใƒณใ‚ขใƒƒใƒ—ใง relative_position_index ใ‚ญใƒผใŒๆถˆใˆใŸใŸใ‚)
"""

import importlib
import sys
from unittest.mock import patch


# โ”€โ”€ Zoe: torch.load ใ‚’ strict=False ใงใƒฉใƒƒใƒ— โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€

def _patch_zoe():
    """
    annotator.zoe.__init__ ใฎ ZoeDetector.__init__ ใŒ
    model.load_state_dict(...) ใ‚’ strict=True (ใƒ‡ใƒ•ใ‚ฉใƒซใƒˆ) ใงๅ‘ผใถใฎใ‚’
    strict=False ใซๅทฎใ—ๆ›ฟใˆใ‚‹ใ€‚
    """
    try:
        import torch
        _original_load_state_dict = torch.nn.Module.load_state_dict

        def _patched_load_state_dict(self, state_dict, strict=True, **kwargs):
            # ZoeDepth ใƒขใƒ‡ใƒซใฎใƒญใƒผใƒ‰ๆ™‚ใ ใ‘ strict=False ใซ็ทฉใ‚ใ‚‹
            cls_name = type(self).__name__
            if cls_name == "ZoeDepth":
                strict = False
            return _original_load_state_dict(self, state_dict, strict=strict, **kwargs)

        torch.nn.Module.load_state_dict = _patched_load_state_dict
        print("[patch_annotators] Patched: ZoeDepth.load_state_dict -> strict=False")
    except Exception as e:
        print(f"[patch_annotators] Warning: Could not patch ZoeDepth: {e}")


_patch_zoe()