# Copyright 2020-2026 The HuggingFace Team. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import gc from functools import wraps import pytest import torch # ============================================================================ # Model Revision Override # ============================================================================ # To test a tiny model PR before merging to main: # 1. Add the full model_id and PR revision to this dict # 2. Commit and push to trigger CI # 3. Once CI is green, merge the tiny model PR on HF Hub # 4. Remove the entry from this dict and commit # # Example: # MODEL_REVISIONS = { # "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5": "refs/pr/3", # "trl-internal-testing/tiny-LlavaForConditionalGeneration": "refs/pr/5", # } # ============================================================================ MODEL_REVISIONS = { # Add model_id: revision mappings here to test PRs } @pytest.fixture(autouse=True) def apply_model_revisions(monkeypatch): """Auto-inject revision parameter for models defined in MODEL_REVISIONS.""" if not MODEL_REVISIONS: return from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin def create_classmethod_wrapper(original_classmethod): # Extract the underlying function from the classmethod original_func = original_classmethod.__func__ @wraps(original_func) def wrapper(cls, pretrained_model_name_or_path, *args, **kwargs): # Direct lookup: only inject if model_id is in the override dict if pretrained_model_name_or_path in MODEL_REVISIONS: if "revision" not in kwargs: kwargs["revision"] = MODEL_REVISIONS[pretrained_model_name_or_path] return original_func(cls, pretrained_model_name_or_path, *args, **kwargs) # Re-wrap as classmethod return classmethod(wrapper) # Patch all transformers Auto* classes for cls in [ PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin, ]: monkeypatch.setattr(cls, "from_pretrained", create_classmethod_wrapper(cls.from_pretrained)) @pytest.fixture(autouse=True) def cleanup_gpu(): """ Automatically cleanup GPU memory after each test. This fixture helps prevent CUDA out of memory errors when running tests in parallel with pytest-xdist by ensuring models and tensors are properly garbage collected and GPU memory caches are cleared between tests. """ yield # Cleanup after test gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() torch.cuda.synchronize()