| # Copyright 2020-2026 The HuggingFace Team. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import gc | |
| from functools import wraps | |
| import pytest | |
| import torch | |
| # ============================================================================ | |
| # Model Revision Override | |
| # ============================================================================ | |
| # To test a tiny model PR before merging to main: | |
| # 1. Add the full model_id and PR revision to this dict | |
| # 2. Commit and push to trigger CI | |
| # 3. Once CI is green, merge the tiny model PR on HF Hub | |
| # 4. Remove the entry from this dict and commit | |
| # | |
| # Example: | |
| # MODEL_REVISIONS = { | |
| # "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5": "refs/pr/3", | |
| # "trl-internal-testing/tiny-LlavaForConditionalGeneration": "refs/pr/5", | |
| # } | |
| # ============================================================================ | |
# Mapping of full Hub model id -> revision to load instead of the default one
# (e.g. "refs/pr/3"). Keep it empty unless a tiny-model PR is being tested.
MODEL_REVISIONS: dict[str, str] = {}
def apply_model_revisions(monkeypatch):
    """Auto-inject the ``revision`` argument for models listed in MODEL_REVISIONS.

    While MODEL_REVISIONS is empty this is a no-op. Otherwise ``from_pretrained`` on
    ``PreTrainedModel``, ``PreTrainedTokenizerBase`` and ``ProcessorMixin`` is replaced
    through ``monkeypatch`` by a wrapper that adds ``revision=MODEL_REVISIONS[model_id]``
    whenever the requested model id is a key of the dict and the caller did not pass
    ``revision`` as a keyword argument.

    Args:
        monkeypatch: Object whose ``setattr(target, name, value)`` installs the patches;
            presumably pytest's ``monkeypatch`` fixture, which would undo them at
            teardown - confirm against the caller.
    """
    # NOTE(review): no fixture decorator is visible on this function in this view even
    # though `pytest` is imported by the file; confirm it is registered as an (autouse)
    # fixture, otherwise nothing calls it.
    if not MODEL_REVISIONS:
        return

    # Imported only after the early return, so transformers is not imported by this
    # function when there is nothing to override.
    from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin

    def create_classmethod_wrapper(original_classmethod):
        """Return a classmethod that injects the revision, then calls the original."""
        # `cls.from_pretrained` is a bound method; take the underlying function so the
        # class used at call time (possibly a subclass) can be passed through as `cls`.
        original_func = original_classmethod.__func__

        # Keep the name, docstring and `__wrapped__` of the original so the patched
        # method still introspects like the real `from_pretrained`.
        @wraps(original_func)
        def wrapper(cls, pretrained_model_name_or_path, *args, **kwargs):
            # Direct lookup: only model ids present in the override dict are touched,
            # and a `revision` keyword passed by the caller always wins.
            if pretrained_model_name_or_path in MODEL_REVISIONS:
                kwargs.setdefault("revision", MODEL_REVISIONS[pretrained_model_name_or_path])
            return original_func(cls, pretrained_model_name_or_path, *args, **kwargs)

        # Re-wrap as classmethod so it binds to the class like the original did.
        return classmethod(wrapper)

    # Patch the shared base classes (not the Auto* classes themselves); subclasses that
    # inherit `from_pretrained` pick up the wrapper. Presumably Auto* loaders end up in
    # these methods via the concrete class they dispatch to - verify if relied upon.
    for cls in (PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin):
        monkeypatch.setattr(cls, "from_pretrained", create_classmethod_wrapper(cls.from_pretrained))
def cleanup_gpu():
    """
    Free GPU memory once a test has finished.

    Nothing happens before the `yield`. When the generator is resumed afterwards, a garbage collection pass is forced
    and, if CUDA is available, the CUDA memory cache is emptied and the device is synchronized. The intent is to avoid
    CUDA out of memory errors when tests run in parallel with pytest-xdist, by making sure models and tensors are
    collected and cached GPU memory is released between tests.
    """
    # NOTE(review): no fixture decorator is visible on this function in this view; confirm it is registered as an
    # (autouse) fixture so that the teardown part below actually runs after each test.
    yield

    # Teardown: drop unreachable Python objects first, then release cached CUDA memory.
    gc.collect()
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
    torch.cuda.synchronize()