# trl-mcsd / tests/conftest.py
# Author: ihbkaiser
# Commit 1fa3c6c (verified): Implement MCSD for experimental SDPO
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
from functools import wraps
import pytest
import torch
# ============================================================================
# Model Revision Override
# ============================================================================
# To test a tiny model PR before merging to main:
# 1. Add the full model_id and PR revision to this dict
# 2. Commit and push to trigger CI
# 3. Once CI is green, merge the tiny model PR on HF Hub
# 4. Remove the entry from this dict and commit
#
# Example:
# MODEL_REVISIONS = {
# "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5": "refs/pr/3",
# "trl-internal-testing/tiny-LlavaForConditionalGeneration": "refs/pr/5",
# }
# ============================================================================
# Full Hub model id -> revision to load instead of the default (e.g. "refs/pr/3").
# Add model_id: revision mappings here to test PRs; keep it empty otherwise.
MODEL_REVISIONS: dict[str, str] = {}
@pytest.fixture(autouse=True)
def apply_model_revisions(monkeypatch):
    """
    Auto-inject the `revision` parameter for models defined in `MODEL_REVISIONS`.

    If `MODEL_REVISIONS` is empty this fixture does nothing. Otherwise, for the duration of the current test, it uses
    `monkeypatch` to replace `from_pretrained` on the transformers base classes `PreTrainedModel`,
    `PreTrainedTokenizerBase` and `ProcessorMixin`. The replacement adds
    `revision=MODEL_REVISIONS[pretrained_model_name_or_path]` when the requested model id is a key of the dict and
    the caller did not pass `revision` itself. It then forwards to the original `from_pretrained`.

    NOTE(review): only these base classes are patched, not the `Auto*` factory classes. `Auto*` loaders are covered
    only insofar as they end up calling an inherited `from_pretrained` on one of these classes. Anything they load
    before that (e.g. the config used to pick the concrete class) may still come from the default revision. Confirm
    when a PR changes config files.
    """
    if not MODEL_REVISIONS:
        return

    from transformers import PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin

    def create_classmethod_wrapper(original_classmethod):
        """Return a classmethod that forwards to `original_classmethod`, injecting a pinned revision when needed."""
        # `original_classmethod` is the bound method obtained via `<class>.from_pretrained`. Unwrap it to the
        # plain function so it can be re-bound to whichever (sub)class the patched classmethod is called on.
        original_func = original_classmethod.__func__

        @wraps(original_func)
        def wrapper(cls, pretrained_model_name_or_path, *args, **kwargs):
            # Direct lookup: only inject if the model id is in the override dict.
            # An explicit `revision` from the caller always wins.
            if pretrained_model_name_or_path in MODEL_REVISIONS:
                kwargs.setdefault("revision", MODEL_REVISIONS[pretrained_model_name_or_path])
            return original_func(cls, pretrained_model_name_or_path, *args, **kwargs)

        # Re-wrap as classmethod so the calling class is bound as `cls` at call time.
        return classmethod(wrapper)

    # Patch the transformers *base* classes (not the `Auto*` factories); subclasses that inherit
    # `from_pretrained` see the patched version. `monkeypatch` undoes these patches after the test.
    for base_cls in (PreTrainedModel, PreTrainedTokenizerBase, ProcessorMixin):
        monkeypatch.setattr(base_cls, "from_pretrained", create_classmethod_wrapper(base_cls.from_pretrained))
@pytest.fixture(autouse=True)
def cleanup_gpu():
    """
    Free memory after every test (autouse fixture).

    Meant to reduce CUDA out-of-memory failures when tests run in parallel with pytest-xdist. Nothing happens before
    the test body. During teardown (after the `yield`), a full garbage collection is forced so leftover models and
    tensors are collected. If CUDA is available, the GPU memory cache is then emptied and the device synchronized.
    """
    yield

    # Teardown: collect unreachable objects first so any GPU tensors they hold are freed before the cache is emptied.
    gc.collect()

    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()