File size: 2,296 Bytes
class ThreadSafeTokenizerWrapper:
    """Proxy a tokenizer so that selected entry points run under a lock.

    Attribute reads fall through to the wrapped tokenizer. Reads of the
    methods named in ``_thread_safe_methods`` return a wrapper that holds
    ``lock`` for the duration of the call; every other attribute is returned
    unchanged. Calling the wrapper itself also holds the lock. Assignments
    and deletions of public (non-underscore) attributes are applied to the
    wrapped tokenizer; underscore-prefixed names live on the wrapper.

    Args:
        tokenizer: The object to proxy (presumably a Hugging Face tokenizer,
            judging by the method names — any object exposing them works).
        lock: A context-manager lock (e.g. ``threading.Lock``) shared by all
            callers that must not use the tokenizer concurrently.
    """

    def __init__(self, tokenizer, lock):
        self._tokenizer = tokenizer
        self._lock = lock
        # Instance attribute (not a class constant) so existing code that
        # inspects or extends it per instance keeps working.
        self._thread_safe_methods = {
            # Never reached via __getattr__: ``__call__`` is defined on the
            # class, so both ``w(...)`` and ``w.__call__`` use it directly.
            "__call__",
            "encode",
            "decode",
            "tokenize",
            "encode_plus",
            "batch_encode_plus",
            "batch_decode",
        }

    def __getattr__(self, name):
        """Forward a lookup that was not satisfied by the wrapper itself.

        Returns:
            The tokenizer's attribute, or a lock-holding wrapper around it if
            ``name`` is one of ``_thread_safe_methods`` and it is callable.

        Raises:
            AttributeError: if the tokenizer has no such attribute, or if the
                wrapper was created without ``__init__`` (``copy.copy``,
                unpickling) and holds no tokenizer yet.
        """
        # Read the instance dict directly: ``self._tokenizer`` on a
        # half-built instance would re-enter __getattr__ and recurse until
        # RecursionError instead of raising AttributeError.
        try:
            tokenizer = self.__dict__["_tokenizer"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        attr = getattr(tokenizer, name)
        if name in self._thread_safe_methods and callable(attr):
            from functools import wraps

            @wraps(attr)  # keep __name__/__doc__ for introspection and help()
            def wrapped_method(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped_method
        return attr

    def __call__(self, *args, **kwargs):
        """Call the wrapped tokenizer while holding the lock."""
        with self._lock:
            return self._tokenizer(*args, **kwargs)

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names on the wrapper; forward the rest."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._tokenizer, name, value)

    def __delattr__(self, name):
        """Mirror __setattr__: delete private names here, public ones on the tokenizer."""
        if name.startswith("_"):
            super().__delattr__(name)
        else:
            delattr(self._tokenizer, name)

    def __dir__(self):
        """List the wrapped tokenizer's attributes (for completion/introspection)."""
        return dir(self._tokenizer)
class ThreadSafeVAEWrapper:
    """Proxy a VAE so that its ``decode``/``encode``/``forward`` calls run under a lock.

    Attribute reads fall through to the wrapped object. Reads of the names
    in ``_LOCKED_METHODS`` return a wrapper that holds ``lock`` for the
    duration of the call; every other attribute is returned unchanged.
    Assignments and deletions of public (non-underscore) attributes are
    applied to the wrapped object; underscore-prefixed names live on the
    wrapper.

    Args:
        vae: The object to proxy (presumably a diffusers autoencoder — TODO
            confirm against the caller).
        lock: A context-manager lock shared by every caller of this VAE.

    NOTE(review): no ``__call__`` is defined, so ``wrapper(x)`` raises
    TypeError even if the VAE itself is callable (implicit special-method
    lookup bypasses ``__getattr__``); call ``.forward``/``.decode`` instead.
    """

    # Methods whose invocation is serialised by the lock.
    _LOCKED_METHODS = frozenset({"decode", "encode", "forward"})

    def __init__(self, vae, lock):
        self._vae = vae
        self._lock = lock

    def __getattr__(self, name):
        """Forward a lookup that was not satisfied by the wrapper itself.

        Raises:
            AttributeError: if the VAE has no such attribute, or if the
                wrapper was created without ``__init__`` (``copy.copy``,
                unpickling) and holds no VAE yet.
        """
        # Read the instance dict directly so a half-built instance raises
        # AttributeError instead of recursing through ``self._vae``.
        try:
            vae = self.__dict__["_vae"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        attr = getattr(vae, name)
        if name in self._LOCKED_METHODS and callable(attr):
            from functools import wraps

            @wraps(attr)  # keep __name__/__doc__ for introspection
            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names on the wrapper; forward the rest."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._vae, name, value)

    def __delattr__(self, name):
        """Mirror __setattr__: delete private names here, public ones on the VAE."""
        if name.startswith("_"):
            super().__delattr__(name)
        else:
            delattr(self._vae, name)

    def __dir__(self):
        """List the wrapped VAE's attributes (for completion/introspection)."""
        return dir(self._vae)
class ThreadSafeImageProcessorWrapper:
    """Proxy an image processor so ``preprocess``/``postprocess`` run under a lock.

    Attribute reads fall through to the wrapped object. Reads of the names
    in ``_LOCKED_METHODS`` return a wrapper that holds ``lock`` for the
    duration of the call; every other attribute is returned unchanged.
    Assignments and deletions of public (non-underscore) attributes are
    applied to the wrapped object; underscore-prefixed names live on the
    wrapper.

    Args:
        proc: The object to proxy (presumably a diffusers image processor —
            TODO confirm against the caller).
        lock: A context-manager lock shared by every caller of this processor.
    """

    # Methods whose invocation is serialised by the lock.
    _LOCKED_METHODS = frozenset({"postprocess", "preprocess"})

    def __init__(self, proc, lock):
        self._proc = proc
        self._lock = lock

    def __getattr__(self, name):
        """Forward a lookup that was not satisfied by the wrapper itself.

        Raises:
            AttributeError: if the processor has no such attribute, or if the
                wrapper was created without ``__init__`` (``copy.copy``,
                unpickling) and holds no processor yet.
        """
        # Read the instance dict directly so a half-built instance raises
        # AttributeError instead of recursing through ``self._proc``.
        try:
            proc = self.__dict__["_proc"]
        except KeyError:
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            ) from None
        attr = getattr(proc, name)
        if name in self._LOCKED_METHODS and callable(attr):
            from functools import wraps

            @wraps(attr)  # keep __name__/__doc__ for introspection
            def wrapped(*args, **kwargs):
                with self._lock:
                    return attr(*args, **kwargs)

            return wrapped
        return attr

    def __setattr__(self, name, value):
        """Store ``_``-prefixed names on the wrapper; forward the rest."""
        if name.startswith("_"):
            super().__setattr__(name, value)
        else:
            setattr(self._proc, name, value)

    def __delattr__(self, name):
        """Mirror __setattr__: delete private names here, public ones on the processor."""
        if name.startswith("_"):
            super().__delattr__(name)
        else:
            delattr(self._proc, name)

    def __dir__(self):
        """List the wrapped processor's attributes (for completion/introspection)."""
        return dir(self._proc)
|