Kyryll Kochkin committed on
Commit
d96feba
·
1 Parent(s): efc975c

AI eventually fixed tests and completions

Browse files
Files changed (2) hide show
  1. app/core/engine.py +60 -22
  2. tests/test_core_helpers.py +43 -0
app/core/engine.py CHANGED
@@ -57,6 +57,60 @@ def _is_tie_weights_unexpected_kwarg_error(exc: Exception) -> bool:
57
  return "tie_weights" in message and "unexpected keyword argument" in message
58
 
59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  def _iter_subclasses(cls: type) -> Generator[type, None, None]:
61
  for sub in cls.__subclasses__():
62
  yield sub
@@ -315,28 +369,12 @@ class _ModelHandle:
315
  device_pref,
316
  " (device_map=auto)" if device_map else "",
317
  )
318
- # Patch _load_pretrained_model to fix tie_weights incompatibility
319
- # with newer transformers that pass missing_keys keyword argument
320
  from transformers import modeling_utils
321
- _orig_load_pretrained = modeling_utils.PreTrainedModel._load_pretrained_model
322
- _orig_load_pretrained_func = _unwrap_bound_callable(_orig_load_pretrained)
323
- if hasattr(_orig_load_pretrained, "__func__"):
324
- _restore_loader_attr = classmethod(_orig_load_pretrained_func)
325
- else:
326
- _restore_loader_attr = _orig_load_pretrained_func
327
-
328
- def _patched_load_pretrained_func(cls, model, *args, **kwargs):
329
- # Patch tie_weights to absorb kwargs passed by newer transformers.
330
- restore_tie_weights = _install_tie_weights_compat_patch(
331
- model,
332
- extra_classes=(cls,),
333
- )
334
- try:
335
- return _orig_load_pretrained_func(cls, model, *args, **kwargs)
336
- finally:
337
- restore_tie_weights()
338
-
339
- modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
340
  try:
341
  try:
342
  model = AutoModelForCausalLM.from_pretrained(
@@ -362,7 +400,7 @@ class _ModelHandle:
362
  finally:
363
  restore_global_tie_weights()
364
  finally:
365
- modeling_utils.PreTrainedModel._load_pretrained_model = _restore_loader_attr
366
  logger.info("Model ready in %.2fs", time.perf_counter() - t1)
367
  if device_map is None:
368
  model = model.to(device_pref)
 
57
  return "tie_weights" in message and "unexpected keyword argument" in message
58
 
59
 
60
+ def _install_loader_tie_weights_patch(pretrained_model_cls: type) -> Callable[[], None]:
61
+ """Wrap _load_pretrained_model while preserving its descriptor type."""
62
+ original_attr = pretrained_model_cls.__dict__.get("_load_pretrained_model")
63
+ if original_attr is None:
64
+ return lambda: None
65
+
66
+ if isinstance(original_attr, classmethod):
67
+ original_callable = original_attr.__func__
68
+
69
+ def _patched_loader(cls, *args, **kwargs):
70
+ model = args[0] if args else kwargs.get("model")
71
+ restore_tie_weights = _install_tie_weights_compat_patch(
72
+ model,
73
+ extra_classes=(cls,),
74
+ )
75
+ try:
76
+ return original_callable(cls, *args, **kwargs)
77
+ finally:
78
+ restore_tie_weights()
79
+
80
+ patched_attr = classmethod(_patched_loader)
81
+ elif isinstance(original_attr, staticmethod):
82
+ original_callable = original_attr.__func__
83
+
84
+ def _patched_loader(*args, **kwargs):
85
+ model = args[0] if args else kwargs.get("model")
86
+ restore_tie_weights = _install_tie_weights_compat_patch(model)
87
+ try:
88
+ return original_callable(*args, **kwargs)
89
+ finally:
90
+ restore_tie_weights()
91
+
92
+ patched_attr = staticmethod(_patched_loader)
93
+ else:
94
+ original_callable = original_attr
95
+
96
+ def _patched_loader(*args, **kwargs):
97
+ model = args[0] if args else kwargs.get("model")
98
+ restore_tie_weights = _install_tie_weights_compat_patch(model)
99
+ try:
100
+ return original_callable(*args, **kwargs)
101
+ finally:
102
+ restore_tie_weights()
103
+
104
+ patched_attr = _patched_loader
105
+
106
+ setattr(pretrained_model_cls, "_load_pretrained_model", patched_attr)
107
+
108
+ def _restore() -> None:
109
+ setattr(pretrained_model_cls, "_load_pretrained_model", original_attr)
110
+
111
+ return _restore
112
+
113
+
114
  def _iter_subclasses(cls: type) -> Generator[type, None, None]:
115
  for sub in cls.__subclasses__():
116
  yield sub
 
369
  device_pref,
370
  " (device_map=auto)" if device_map else "",
371
  )
372
+ # Patch _load_pretrained_model to inject tie_weights compatibility
373
+ # while preserving whatever descriptor type transformers currently uses.
374
  from transformers import modeling_utils
375
+ restore_loader_patch = _install_loader_tie_weights_patch(
376
+ modeling_utils.PreTrainedModel
377
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
378
  try:
379
  try:
380
  model = AutoModelForCausalLM.from_pretrained(
 
400
  finally:
401
  restore_global_tie_weights()
402
  finally:
403
+ restore_loader_patch()
404
  logger.info("Model ready in %.2fs", time.perf_counter() - t1)
405
  if device_map is None:
406
  model = model.to(device_pref)
tests/test_core_helpers.py CHANGED
@@ -161,3 +161,46 @@ def test_install_tie_weights_compat_patch_covers_class_dispatch() -> None:
161
  assert Base.tie_weights(instance, missing_keys=set(), recompute_mapping=False) == "ok"
162
  finally:
163
  restore()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  assert Base.tie_weights(instance, missing_keys=set(), recompute_mapping=False) == "ok"
162
  finally:
163
  restore()
164
+
165
+
166
+ def test_install_loader_tie_weights_patch_handles_plain_function_descriptor() -> None:
167
+ class DemoModel:
168
+ def tie_weights(self): # pragma: no cover - shape-only test
169
+ return "ok"
170
+
171
+ class DummyLoader:
172
+ def _load_pretrained_model(model, state, files, name): # noqa: N805
173
+ return model.tie_weights(missing_keys=set(), recompute_mapping=False), name
174
+
175
+ model = DemoModel()
176
+ with pytest.raises(TypeError):
177
+ model.tie_weights(missing_keys=set(), recompute_mapping=False)
178
+
179
+ restore_loader = engine._install_loader_tie_weights_patch(DummyLoader)
180
+ try:
181
+ result = DummyLoader._load_pretrained_model(model, None, None, "demo")
182
+ assert result == ("ok", "demo")
183
+ finally:
184
+ restore_loader()
185
+
186
+ with pytest.raises(TypeError):
187
+ model.tie_weights(missing_keys=set(), recompute_mapping=False)
188
+
189
+
190
+ def test_install_loader_tie_weights_patch_handles_classmethod_descriptor() -> None:
191
+ class DemoModel:
192
+ def tie_weights(self): # pragma: no cover - shape-only test
193
+ return "ok"
194
+
195
+ class DummyLoader:
196
+ @classmethod
197
+ def _load_pretrained_model(cls, model, state, files, name): # noqa: N805
198
+ return cls.__name__, model.tie_weights(recompute_mapping=False), name
199
+
200
+ model = DemoModel()
201
+ restore_loader = engine._install_loader_tie_weights_patch(DummyLoader)
202
+ try:
203
+ result = DummyLoader._load_pretrained_model(model, None, None, "demo")
204
+ assert result == ("DummyLoader", "ok", "demo")
205
+ finally:
206
+ restore_loader()