File size: 2,296 Bytes
ac2243f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
class ThreadSafeTokenizerWrapper:
    """Proxy a tokenizer so that selected calls run while holding a lock.

    Attribute reads fall through to the wrapped tokenizer. The methods named
    in ``_thread_safe_methods``, and calling the wrapper itself, execute
    inside ``with lock:``, so callers sharing the same lock are serialized.
    Attribute writes to names starting with ``_`` are stored on the wrapper;
    all other writes are forwarded to the tokenizer.

    Args:
        tokenizer: The object being proxied.
        lock: A context manager used for mutual exclusion (presumably a
            ``threading.Lock``/``RLock``, possibly shared with other
            wrappers — confirm at the call site).
    """

    def __init__(self, tokenizer, lock):
        self._tokenizer = tokenizer
        self._lock = lock

        # Tokenizer methods whose calls are serialized behind ``_lock``.
        self._thread_safe_methods = {
            "__call__",
            "encode",
            "decode",
            "tokenize",
            "encode_plus",
            "batch_encode_plus",
            "batch_decode",
        }

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the wrapped tokenizer.

        Methods listed in ``_thread_safe_methods`` are returned wrapped so
        that each call holds the lock; anything else is returned as-is.

        Raises:
            AttributeError: if the wrapper is uninitialized or the tokenizer
                has no attribute ``name``.
        """
        # __getattr__ only runs when normal lookup fails. Read the tokenizer
        # straight from __dict__: on an instance whose __init__ never ran
        # (e.g. the blank object copy.copy/pickle build before restoring
        # state), ``self._tokenizer`` would re-enter __getattr__ and recurse
        # until RecursionError, which hasattr() does not swallow.
        try:
            tokenizer = self.__dict__["_tokenizer"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

        attr = getattr(tokenizer, name)

        if name in self._thread_safe_methods and callable(attr):

            def wrapped_method(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped_method

        return attr

    def __call__(self, *args, **kwargs):
        """Call the tokenizer while holding the lock.

        Defined explicitly because implicit ``wrapper(...)`` calls look up
        ``__call__`` on the type and never reach ``__getattr__``.
        """
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names locally; forward the rest to the tokenizer."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._tokenizer, name, value)

    def __dir__(self):
        """Report the tokenizer's attributes, since those are what is reachable."""
        return dir(self._tokenizer)


class ThreadSafeVAEWrapper:
    """Proxy a VAE so its ``decode``/``encode``/``forward`` calls hold a lock.

    Attribute reads fall through to the wrapped object; the three method
    names above are returned wrapped so each call runs inside ``with lock:``.
    Attribute writes to names starting with ``_`` are stored on the wrapper;
    all other writes are forwarded to the VAE.

    Args:
        vae: The object being proxied.
        lock: A context manager used for mutual exclusion (presumably a
            ``threading.Lock``/``RLock`` — confirm at the call site).
    """

    def __init__(self, vae, lock):
        self._vae = vae
        self._lock = lock

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the wrapped VAE.

        Raises:
            AttributeError: if the wrapper is uninitialized or the VAE has
                no attribute ``name``.
        """
        # Read the VAE straight from __dict__ so that an instance whose
        # __init__ never ran (copy.copy / pickle / bare __new__) raises
        # AttributeError instead of recursing through ``self._vae``.
        try:
            vae = self.__dict__["_vae"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

        attr = getattr(vae, name)
        if name in {"decode", "encode", "forward"} and callable(attr):

            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names locally; forward the rest to the VAE."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._vae, name, value)


class ThreadSafeImageProcessorWrapper:
    """Proxy an image processor so ``preprocess``/``postprocess`` hold a lock.

    Attribute reads fall through to the wrapped object; the two method names
    above are returned wrapped so each call runs inside ``with lock:``.
    Attribute writes to names starting with ``_`` are stored on the wrapper;
    all other writes are forwarded to the processor.

    Args:
        proc: The object being proxied.
        lock: A context manager used for mutual exclusion (presumably a
            ``threading.Lock``/``RLock`` — confirm at the call site).
    """

    def __init__(self, proc, lock):
        self._proc = proc
        self._lock = lock

    def __getattr__(self, name):
        """Forward a failed attribute lookup to the wrapped processor.

        Raises:
            AttributeError: if the wrapper is uninitialized or the processor
                has no attribute ``name``.
        """
        # Read the processor straight from __dict__ so that an instance whose
        # __init__ never ran (copy.copy / pickle / bare __new__) raises
        # AttributeError instead of recursing through ``self._proc``.
        try:
            proc = self.__dict__["_proc"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None

        attr = getattr(proc, name)
        if name in {"postprocess", "preprocess"} and callable(attr):

            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names locally; forward the rest to the processor."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._proc, name, value)