Kyryll Kochkin committed on
Commit
efc975c
·
1 Parent(s): 4e203da

AI tries to fix tests v4

Browse files
Files changed (1) hide show
  1. app/core/engine.py +70 -5
app/core/engine.py CHANGED
@@ -52,6 +52,53 @@ def _filter_supported_kwargs(func, kwargs: dict) -> dict:
52
  return {key: value for key, value in kwargs.items() if key in supported_kwargs}
53
 
54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
55
  def _install_tie_weights_compat_patch(
56
  model,
57
  *,
@@ -291,11 +338,29 @@ class _ModelHandle:
291
 
292
  modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
293
  try:
294
- model = AutoModelForCausalLM.from_pretrained(
295
- spec.hf_repo,
296
- device_map=device_map,
297
- **model_kwargs,
298
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
299
  finally:
300
  modeling_utils.PreTrainedModel._load_pretrained_model = _restore_loader_attr
301
  logger.info("Model ready in %.2fs", time.perf_counter() - t1)
 
52
  return {key: value for key, value in kwargs.items() if key in supported_kwargs}
53
 
54
 
55
+ def _is_tie_weights_unexpected_kwarg_error(exc: Exception) -> bool:
56
+ message = str(exc)
57
+ return "tie_weights" in message and "unexpected keyword argument" in message
58
+
59
+
60
+ def _iter_subclasses(cls: type) -> Generator[type, None, None]:
61
+ for sub in cls.__subclasses__():
62
+ yield sub
63
+ yield from _iter_subclasses(sub)
64
+
65
+
66
def _install_global_tie_weights_compat_patch() -> Callable[[], None]:
    """Patch loaded PreTrainedModel subclasses to absorb unsupported kwargs.

    Walks ``PreTrainedModel`` and all of its currently-imported subclasses and
    replaces each class's own ``tie_weights`` with a wrapper that drops keyword
    arguments the underlying implementation does not accept.

    Returns:
        A zero-argument callable that restores every patched class to its
        original ``tie_weights`` attribute.
    """
    # Imported lazily so the module can be imported without transformers loaded.
    from transformers import modeling_utils

    base_cls = modeling_utils.PreTrainedModel
    classes = [base_cls, *_iter_subclasses(base_cls)]
    # (class, original attribute) pairs, in patch order, so we can undo later.
    restore_stack: list[tuple[type, object]] = []

    for cls in classes:
        # Only patch classes that define tie_weights themselves; classes that
        # merely inherit it will resolve to the patched base-class version.
        if "tie_weights" not in cls.__dict__:
            continue
        # Keep the raw __dict__ entry (may be a descriptor such as a
        # classmethod/staticmethod) so restoration is exact.
        original_attr = cls.__dict__["tie_weights"]
        original_callable = _unwrap_bound_callable(getattr(cls, "tie_weights"))

        # Factory function binds callable_impl per iteration; a bare closure in
        # the loop would late-bind and every wrapper would call the last class's
        # implementation.
        def _make_compat(callable_impl):
            def _compat_tie_weights(self, *args, **kwargs):
                # First attempt: forward only the kwargs the implementation
                # declares support for.
                filtered_kwargs = _filter_supported_kwargs(callable_impl, kwargs)
                try:
                    return callable_impl(self, *args, **filtered_kwargs)
                except TypeError as error:
                    # Last resort: if kwargs were passed and still rejected
                    # (e.g. the signature could not be introspected), retry
                    # with positional args only.
                    if kwargs and "unexpected keyword argument" in str(error):
                        return callable_impl(self, *args)
                    raise

            return _compat_tie_weights

        setattr(cls, "tie_weights", _make_compat(original_callable))
        restore_stack.append((cls, original_attr))

    def _restore() -> None:
        # Undo in reverse patch order so any class patched more than once ends
        # up with its true original attribute.
        for cls, original_attr in reversed(restore_stack):
            setattr(cls, "tie_weights", original_attr)

    return _restore
100
+
101
+
102
  def _install_tie_weights_compat_patch(
103
  model,
104
  *,
 
338
 
339
  modeling_utils.PreTrainedModel._load_pretrained_model = classmethod(_patched_load_pretrained_func)
340
  try:
341
+ try:
342
+ model = AutoModelForCausalLM.from_pretrained(
343
+ spec.hf_repo,
344
+ device_map=device_map,
345
+ **model_kwargs,
346
+ )
347
+ except TypeError as exc:
348
+ if not _is_tie_weights_unexpected_kwarg_error(exc):
349
+ raise
350
+ logger.warning(
351
+ "Retrying model load for %s after tie_weights kwarg mismatch: %s",
352
+ spec.hf_repo,
353
+ exc,
354
+ )
355
+ restore_global_tie_weights = _install_global_tie_weights_compat_patch()
356
+ try:
357
+ model = AutoModelForCausalLM.from_pretrained(
358
+ spec.hf_repo,
359
+ device_map=device_map,
360
+ **model_kwargs,
361
+ )
362
+ finally:
363
+ restore_global_tie_weights()
364
  finally:
365
  modeling_utils.PreTrainedModel._load_pretrained_model = _restore_loader_attr
366
  logger.info("Model ready in %.2fs", time.perf_counter() - t1)