Spaces:
Running
Running
Kyryll Kochkin commited on
Commit ·
efc975c
1
Parent(s): 4e203da
AI tries to fix tests v4
Browse files- app/core/engine.py +70 -5
app/core/engine.py
CHANGED
|
@@ -52,6 +52,53 @@ def _filter_supported_kwargs(func, kwargs: dict) -> dict:
|
|
| 52 |
return {key: value for key, value in kwargs.items() if key in supported_kwargs}
|
| 53 |
|
| 54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
def _install_tie_weights_compat_patch(
|
| 56 |
model,
|
| 57 |
*,
|
|
@@ -291,11 +338,29 @@ class _ModelHandle:
|
|
| 291 |
|
| 292 |
modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
|
| 293 |
try:
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 299 |
finally:
|
| 300 |
modeling_utils.PreTrainedModel._load_pretrained_model = _restore_loader_attr
|
| 301 |
logger.info("Model ready in %.2fs", time.perf_counter() - t1)
|
|
|
|
| 52 |
return {key: value for key, value in kwargs.items() if key in supported_kwargs}
|
| 53 |
|
| 54 |
|
| 55 |
+
def _is_tie_weights_unexpected_kwarg_error(exc: Exception) -> bool:
|
| 56 |
+
message = str(exc)
|
| 57 |
+
return "tie_weights" in message and "unexpected keyword argument" in message
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _iter_subclasses(cls: type) -> Generator[type, None, None]:
|
| 61 |
+
for sub in cls.__subclasses__():
|
| 62 |
+
yield sub
|
| 63 |
+
yield from _iter_subclasses(sub)
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _install_global_tie_weights_compat_patch() -> Callable[[], None]:
|
| 67 |
+
"""Patch loaded PreTrainedModel subclasses to absorb unsupported kwargs."""
|
| 68 |
+
from transformers import modeling_utils
|
| 69 |
+
|
| 70 |
+
base_cls = modeling_utils.PreTrainedModel
|
| 71 |
+
classes = [base_cls, *_iter_subclasses(base_cls)]
|
| 72 |
+
restore_stack: list[tuple[type, object]] = []
|
| 73 |
+
|
| 74 |
+
for cls in classes:
|
| 75 |
+
if "tie_weights" not in cls.__dict__:
|
| 76 |
+
continue
|
| 77 |
+
original_attr = cls.__dict__["tie_weights"]
|
| 78 |
+
original_callable = _unwrap_bound_callable(getattr(cls, "tie_weights"))
|
| 79 |
+
|
| 80 |
+
def _make_compat(callable_impl):
|
| 81 |
+
def _compat_tie_weights(self, *args, **kwargs):
|
| 82 |
+
filtered_kwargs = _filter_supported_kwargs(callable_impl, kwargs)
|
| 83 |
+
try:
|
| 84 |
+
return callable_impl(self, *args, **filtered_kwargs)
|
| 85 |
+
except TypeError as error:
|
| 86 |
+
if kwargs and "unexpected keyword argument" in str(error):
|
| 87 |
+
return callable_impl(self, *args)
|
| 88 |
+
raise
|
| 89 |
+
|
| 90 |
+
return _compat_tie_weights
|
| 91 |
+
|
| 92 |
+
setattr(cls, "tie_weights", _make_compat(original_callable))
|
| 93 |
+
restore_stack.append((cls, original_attr))
|
| 94 |
+
|
| 95 |
+
def _restore() -> None:
|
| 96 |
+
for cls, original_attr in reversed(restore_stack):
|
| 97 |
+
setattr(cls, "tie_weights", original_attr)
|
| 98 |
+
|
| 99 |
+
return _restore
|
| 100 |
+
|
| 101 |
+
|
| 102 |
def _install_tie_weights_compat_patch(
|
| 103 |
model,
|
| 104 |
*,
|
|
|
|
| 338 |
|
| 339 |
modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
|
| 340 |
try:
|
| 341 |
+
try:
|
| 342 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 343 |
+
spec.hf_repo,
|
| 344 |
+
device_map=device_map,
|
| 345 |
+
**model_kwargs,
|
| 346 |
+
)
|
| 347 |
+
except TypeError as exc:
|
| 348 |
+
if not _is_tie_weights_unexpected_kwarg_error(exc):
|
| 349 |
+
raise
|
| 350 |
+
logger.warning(
|
| 351 |
+
"Retrying model load for %s after tie_weights kwarg mismatch: %s",
|
| 352 |
+
spec.hf_repo,
|
| 353 |
+
exc,
|
| 354 |
+
)
|
| 355 |
+
restore_global_tie_weights = _install_global_tie_weights_compat_patch()
|
| 356 |
+
try:
|
| 357 |
+
model = AutoModelForCausalLM.from_pretrained(
|
| 358 |
+
spec.hf_repo,
|
| 359 |
+
device_map=device_map,
|
| 360 |
+
**model_kwargs,
|
| 361 |
+
)
|
| 362 |
+
finally:
|
| 363 |
+
restore_global_tie_weights()
|
| 364 |
finally:
|
| 365 |
modeling_utils.PreTrainedModel._load_pretrained_model = _restore_loader_attr
|
| 366 |
logger.info("Model ready in %.2fs", time.perf_counter() - t1)
|