Spaces:
Running
Running
Kyryll Kochkin commited on
Commit ·
d96feba
1
Parent(s): efc975c
AI eventually fixed tests and completions
Browse files- app/core/engine.py +60 -22
- tests/test_core_helpers.py +43 -0
app/core/engine.py
CHANGED
|
@@ -57,6 +57,60 @@ def _is_tie_weights_unexpected_kwarg_error(exc: Exception) -> bool:
|
|
| 57 |
return "tie_weights" in message and "unexpected keyword argument" in message
|
| 58 |
|
| 59 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 60 |
def _iter_subclasses(cls: type) -> Generator[type, None, None]:
|
| 61 |
for sub in cls.__subclasses__():
|
| 62 |
yield sub
|
|
@@ -315,28 +369,12 @@ class _ModelHandle:
|
|
| 315 |
device_pref,
|
| 316 |
" (device_map=auto)" if device_map else "",
|
| 317 |
)
|
| 318 |
-
# Patch _load_pretrained_model to
|
| 319 |
-
#
|
| 320 |
from transformers import modeling_utils
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
_restore_loader_attr = classmethod(_orig_load_pretrained_func)
|
| 325 |
-
else:
|
| 326 |
-
_restore_loader_attr = _orig_load_pretrained_func
|
| 327 |
-
|
| 328 |
-
def _patched_load_pretrained_func(cls, model, *args, **kwargs):
|
| 329 |
-
# Patch tie_weights to absorb kwargs passed by newer transformers.
|
| 330 |
-
restore_tie_weights = _install_tie_weights_compat_patch(
|
| 331 |
-
model,
|
| 332 |
-
extra_classes=(cls,),
|
| 333 |
-
)
|
| 334 |
-
try:
|
| 335 |
-
return _orig_load_pretrained_func(cls, model, *args, **kwargs)
|
| 336 |
-
finally:
|
| 337 |
-
restore_tie_weights()
|
| 338 |
-
|
| 339 |
-
modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
|
| 340 |
try:
|
| 341 |
try:
|
| 342 |
model = AutoModelForCausalLM.from_pretrained(
|
|
@@ -362,7 +400,7 @@ class _ModelHandle:
|
|
| 362 |
finally:
|
| 363 |
restore_global_tie_weights()
|
| 364 |
finally:
|
| 365 |
-
|
| 366 |
logger.info("Model ready in %.2fs", time.perf_counter() - t1)
|
| 367 |
if device_map is None:
|
| 368 |
model = model.to(device_pref)
|
|
|
|
| 57 |
return "tie_weights" in message and "unexpected keyword argument" in message
|
| 58 |
|
| 59 |
|
| 60 |
+
def _install_loader_tie_weights_patch(pretrained_model_cls: type) -> Callable[[], None]:
|
| 61 |
+
"""Wrap _load_pretrained_model while preserving its descriptor type."""
|
| 62 |
+
original_attr = pretrained_model_cls.__dict__.get("_load_pretrained_model")
|
| 63 |
+
if original_attr is None:
|
| 64 |
+
return lambda: None
|
| 65 |
+
|
| 66 |
+
if isinstance(original_attr, classmethod):
|
| 67 |
+
original_callable = original_attr.__func__
|
| 68 |
+
|
| 69 |
+
def _patched_loader(cls, *args, **kwargs):
|
| 70 |
+
model = args[0] if args else kwargs.get("model")
|
| 71 |
+
restore_tie_weights = _install_tie_weights_compat_patch(
|
| 72 |
+
model,
|
| 73 |
+
extra_classes=(cls,),
|
| 74 |
+
)
|
| 75 |
+
try:
|
| 76 |
+
return original_callable(cls, *args, **kwargs)
|
| 77 |
+
finally:
|
| 78 |
+
restore_tie_weights()
|
| 79 |
+
|
| 80 |
+
patched_attr = classmethod(_patched_loader)
|
| 81 |
+
elif isinstance(original_attr, staticmethod):
|
| 82 |
+
original_callable = original_attr.__func__
|
| 83 |
+
|
| 84 |
+
def _patched_loader(*args, **kwargs):
|
| 85 |
+
model = args[0] if args else kwargs.get("model")
|
| 86 |
+
restore_tie_weights = _install_tie_weights_compat_patch(model)
|
| 87 |
+
try:
|
| 88 |
+
return original_callable(*args, **kwargs)
|
| 89 |
+
finally:
|
| 90 |
+
restore_tie_weights()
|
| 91 |
+
|
| 92 |
+
patched_attr = staticmethod(_patched_loader)
|
| 93 |
+
else:
|
| 94 |
+
original_callable = original_attr
|
| 95 |
+
|
| 96 |
+
def _patched_loader(*args, **kwargs):
|
| 97 |
+
model = args[0] if args else kwargs.get("model")
|
| 98 |
+
restore_tie_weights = _install_tie_weights_compat_patch(model)
|
| 99 |
+
try:
|
| 100 |
+
return original_callable(*args, **kwargs)
|
| 101 |
+
finally:
|
| 102 |
+
restore_tie_weights()
|
| 103 |
+
|
| 104 |
+
patched_attr = _patched_loader
|
| 105 |
+
|
| 106 |
+
setattr(pretrained_model_cls, "_load_pretrained_model", patched_attr)
|
| 107 |
+
|
| 108 |
+
def _restore() -> None:
|
| 109 |
+
setattr(pretrained_model_cls, "_load_pretrained_model", original_attr)
|
| 110 |
+
|
| 111 |
+
return _restore
|
| 112 |
+
|
| 113 |
+
|
| 114 |
def _iter_subclasses(cls: type) -> Generator[type, None, None]:
|
| 115 |
for sub in cls.__subclasses__():
|
| 116 |
yield sub
|
|
|
|
| 369 |
device_pref,
|
| 370 |
" (device_map=auto)" if device_map else "",
|
| 371 |
)
|
| 372 |
+
# Patch _load_pretrained_model to inject tie_weights compatibility
|
| 373 |
+
# while preserving whatever descriptor type transformers currently uses.
|
| 374 |
from transformers import modeling_utils
|
| 375 |
+
restore_loader_patch = _install_loader_tie_weights_patch(
|
| 376 |
+
modeling_utils.PreTrainedModel
|
| 377 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
try:
|
| 379 |
try:
|
| 380 |
model = AutoModelForCausalLM.from_pretrained(
|
|
|
|
| 400 |
finally:
|
| 401 |
restore_global_tie_weights()
|
| 402 |
finally:
|
| 403 |
+
restore_loader_patch()
|
| 404 |
logger.info("Model ready in %.2fs", time.perf_counter() - t1)
|
| 405 |
if device_map is None:
|
| 406 |
model = model.to(device_pref)
|
tests/test_core_helpers.py
CHANGED
|
@@ -161,3 +161,46 @@ def test_install_tie_weights_compat_patch_covers_class_dispatch() -> None:
|
|
| 161 |
assert Base.tie_weights(instance, missing_keys=set(), recompute_mapping=False) == "ok"
|
| 162 |
finally:
|
| 163 |
restore()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
assert Base.tie_weights(instance, missing_keys=set(), recompute_mapping=False) == "ok"
|
| 162 |
finally:
|
| 163 |
restore()
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def test_install_loader_tie_weights_patch_handles_plain_function_descriptor() -> None:
|
| 167 |
+
class DemoModel:
|
| 168 |
+
def tie_weights(self): # pragma: no cover - shape-only test
|
| 169 |
+
return "ok"
|
| 170 |
+
|
| 171 |
+
class DummyLoader:
|
| 172 |
+
def _load_pretrained_model(model, state, files, name): # noqa: N805
|
| 173 |
+
return model.tie_weights(missing_keys=set(), recompute_mapping=False), name
|
| 174 |
+
|
| 175 |
+
model = DemoModel()
|
| 176 |
+
with pytest.raises(TypeError):
|
| 177 |
+
model.tie_weights(missing_keys=set(), recompute_mapping=False)
|
| 178 |
+
|
| 179 |
+
restore_loader = engine._install_loader_tie_weights_patch(DummyLoader)
|
| 180 |
+
try:
|
| 181 |
+
result = DummyLoader._load_pretrained_model(model, None, None, "demo")
|
| 182 |
+
assert result == ("ok", "demo")
|
| 183 |
+
finally:
|
| 184 |
+
restore_loader()
|
| 185 |
+
|
| 186 |
+
with pytest.raises(TypeError):
|
| 187 |
+
model.tie_weights(missing_keys=set(), recompute_mapping=False)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
def test_install_loader_tie_weights_patch_handles_classmethod_descriptor() -> None:
|
| 191 |
+
class DemoModel:
|
| 192 |
+
def tie_weights(self): # pragma: no cover - shape-only test
|
| 193 |
+
return "ok"
|
| 194 |
+
|
| 195 |
+
class DummyLoader:
|
| 196 |
+
@classmethod
|
| 197 |
+
def _load_pretrained_model(cls, model, state, files, name): # noqa: N805
|
| 198 |
+
return cls.__name__, model.tie_weights(recompute_mapping=False), name
|
| 199 |
+
|
| 200 |
+
model = DemoModel()
|
| 201 |
+
restore_loader = engine._install_loader_tie_weights_patch(DummyLoader)
|
| 202 |
+
try:
|
| 203 |
+
result = DummyLoader._load_pretrained_model(model, None, None, "demo")
|
| 204 |
+
assert result == ("DummyLoader", "ok", "demo")
|
| 205 |
+
finally:
|
| 206 |
+
restore_loader()
|