| """ |
| Test for attn_implementation override configuration in FSDP workers. |
| |
| This test verifies that the fix for honoring attn_implementation override config |
| works correctly in the ActorRolloutRefWorker._build_model_optimizer method. |
| """ |

from unittest.mock import Mock, patch

import pytest
import torch
from omegaconf import OmegaConf
from transformers import AutoConfig, AutoModelForCausalLM

# Import VERL components if available; the VERL-specific tests are skipped otherwise.
try:
    from verl.workers.config import FSDPEngineConfig  # noqa: F401
    from verl.workers.fsdp_workers import (  # noqa: F401
        ActorRolloutRefWorker,
        CriticWorker,
    )

    VERL_AVAILABLE = True
except ImportError:
    VERL_AVAILABLE = False


@pytest.mark.skipif(not VERL_AVAILABLE, reason="VERL components not available")
class TestFSDPAttnImplementation:
    """Test cases for attn_implementation override in FSDP workers."""

    def test_attn_implementation_extraction_logic(self):
        """Test the core logic for extracting attn_implementation from the override config."""
        # No override: fall back to the flash_attention_2 default
        override_config = {}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "flash_attention_2"

        # Explicit "eager" override
        override_config = {"attn_implementation": "eager"}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "eager"

        # Explicit "sdpa" override
        override_config = {"attn_implementation": "sdpa"}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "sdpa"

        # Unrelated overrides leave the default untouched
        override_config = {"other_setting": "value", "dropout": 0.1}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "flash_attention_2"
| @patch("transformers.AutoConfig.from_pretrained") |
| @patch("transformers.AutoModelForCausalLM.from_pretrained") |
| def test_attn_implementation_passed_to_autoconfig(self, mock_model_from_pretrained, mock_config_from_pretrained): |
| """Test that attn_implementation is correctly passed to AutoConfig.from_pretrained.""" |
|
|
| |
| mock_config = Mock() |
| mock_config.tie_word_embeddings = False |
| mock_config.architectures = ["LlamaForCausalLM"] |
| mock_config_from_pretrained.return_value = mock_config |
|
|
| |
| mock_model = Mock() |
| mock_model_from_pretrained.return_value = mock_model |
|
|
| |
| test_cases = [ |
| ({}, "flash_attention_2"), |
| ({"attn_implementation": "eager"}, "eager"), |
| ({"attn_implementation": "sdpa"}, "sdpa"), |
| ] |
|
|
| for override_config, expected_attn_impl in test_cases: |
| |
| mock_config_from_pretrained.reset_mock() |
| mock_model_from_pretrained.reset_mock() |
|
|
| |
| attn_implementation = override_config.get("attn_implementation", "flash_attention_2") |
|
|
| |
| AutoConfig.from_pretrained("test_path", trust_remote_code=False, attn_implementation=attn_implementation) |
|
|
| |
| mock_config_from_pretrained.assert_called_once_with( |
| "test_path", trust_remote_code=False, attn_implementation=expected_attn_impl |
| ) |
|
|
| @patch("transformers.AutoConfig.from_pretrained") |
| @patch("transformers.AutoModelForCausalLM.from_pretrained") |
| def test_attn_implementation_passed_to_model(self, mock_model_from_pretrained, mock_config_from_pretrained): |
| """Test that attn_implementation is correctly passed to model.from_pretrained.""" |
|
|
| |
| mock_config = Mock() |
| mock_config.tie_word_embeddings = False |
| mock_config.architectures = ["LlamaForCausalLM"] |
| mock_config_from_pretrained.return_value = mock_config |
|
|
| |
| mock_model = Mock() |
| mock_model_from_pretrained.return_value = mock_model |
|
|
| |
| override_config = {"attn_implementation": "eager"} |
| attn_implementation = override_config.get("attn_implementation", "flash_attention_2") |
|
|
| |
| AutoModelForCausalLM.from_pretrained( |
| pretrained_model_name_or_path="test_path", |
| torch_dtype=torch.bfloat16, |
| config=mock_config, |
| trust_remote_code=False, |
| attn_implementation=attn_implementation, |
| ) |
|
|
| |
| mock_model_from_pretrained.assert_called_once_with( |
| pretrained_model_name_or_path="test_path", |
| torch_dtype=torch.bfloat16, |
| config=mock_config, |
| trust_remote_code=False, |
| attn_implementation="eager", |
| ) |
|
|

    def test_override_config_integration(self):
        """Test that override_config from a Hydra configuration works correctly."""
        # Simulate a Hydra config with an override_config block
        config_dict = {
            "model": {"path": "/test/path", "override_config": {"attn_implementation": "eager", "dropout": 0.1}}
        }

        omegaconf = OmegaConf.create(config_dict)

        # Resolve override_config the same way the worker does
        override_model_config = OmegaConf.to_container(OmegaConf.create(omegaconf.model.get("override_config", {})))

        # The attn_implementation override is honored
        attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "eager"

        # Other overrides are preserved as well
        assert override_model_config.get("dropout") == 0.1

    def test_hydra_plus_prefix_config(self):
        """Test that Hydra +prefix configurations work correctly."""
        # Simulate the structure produced by a Hydra '+' override such as
        # +actor_rollout_ref.model.override_config.attn_implementation=eager
        config_dict = {
            "actor_rollout_ref": {
                "model": {
                    "path": "/test/path",
                    "override_config": {
                        "attn_implementation": "eager",
                    },
                }
            }
        }

        omegaconf = OmegaConf.create(config_dict)

        # Resolve override_config from the nested actor_rollout_ref section
        override_model_config = OmegaConf.to_container(
            OmegaConf.create(omegaconf.actor_rollout_ref.model.get("override_config", {}))
        )

        # The attn_implementation override is honored
        attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "eager"

    def test_backward_compatibility(self):
        """Test that the fix maintains backward compatibility."""
        # No override_config key at all: the default is used
        config_without_override = {}
        attn_implementation = config_without_override.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"

        # Empty override_config: the default is used
        config_with_empty_override = {"override_config": {}}
        override_config = config_with_empty_override.get("override_config", {})
        attn_implementation = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"

        # override_config without attn_implementation: the default is used
        config_with_other_overrides = {"override_config": {"dropout": 0.1, "hidden_size": 1024}}
        override_config = config_with_other_overrides.get("override_config", {})
        attn_implementation = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"

    def test_critic_attn_implementation_extraction_logic(self):
        """Test the core logic for extracting attn_implementation from the override config for CriticWorker."""
        # No override: fall back to the flash_attention_2 default
        override_config = {}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "flash_attention_2"

        # Explicit "eager" override
        override_config = {"attn_implementation": "eager"}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "eager"

        # Explicit "sdpa" override
        override_config = {"attn_implementation": "sdpa"}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "sdpa"

        # Unrelated overrides leave the default untouched
        override_config = {"other_setting": "value", "dropout": 0.1}
        attn_impl = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_impl == "flash_attention_2"
| @patch("transformers.AutoConfig.from_pretrained") |
| def test_critic_attn_implementation_passed_to_autoconfig(self, mock_config_from_pretrained): |
| """Test that attn_implementation is correctly passed to AutoConfig.from_pretrained in CriticWorker.""" |
|
|
| |
| mock_config = Mock() |
| mock_config.tie_word_embeddings = False |
| mock_config.architectures = ["LlamaForCausalLM"] |
| mock_config.num_labels = 1 |
| mock_config_from_pretrained.return_value = mock_config |
|
|
| |
| test_cases = [ |
| ({}, "flash_attention_2"), |
| ({"attn_implementation": "eager"}, "eager"), |
| ({"attn_implementation": "sdpa"}, "sdpa"), |
| ] |
|
|
| for override_config, expected_attn_impl in test_cases: |
| |
| mock_config_from_pretrained.reset_mock() |
|
|
| |
| attn_implementation = override_config.get("attn_implementation", "flash_attention_2") |
|
|
| |
| |
| AutoConfig.from_pretrained( |
| "test_path", |
| attn_implementation=attn_implementation, |
| trust_remote_code=False, |
| ) |
|
|
| |
| mock_config_from_pretrained.assert_called_once_with( |
| "test_path", |
| attn_implementation=expected_attn_impl, |
| trust_remote_code=False, |
| ) |
|
|

    def test_critic_override_config_integration(self):
        """Test that override_config from a Hydra configuration works correctly for CriticWorker."""
        # Simulate a Hydra config with a critic-side override_config block
        config_dict = {
            "critic": {
                "model": {"path": "/test/path", "override_config": {"attn_implementation": "eager", "dropout": 0.1}}
            }
        }

        omegaconf = OmegaConf.create(config_dict)

        # Resolve override_config the same way CriticWorker does
        override_model_config = OmegaConf.to_container(
            OmegaConf.create(omegaconf.critic.model.get("override_config", {}))
        )

        # The attn_implementation override is honored
        attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "eager"

        # Other overrides are preserved as well
        assert override_model_config.get("dropout") == 0.1

    def test_critic_hydra_plus_prefix_config(self):
        """Test that Hydra +prefix configurations work correctly for CriticWorker."""
        # Simulate the structure produced by a Hydra '+' override such as
        # +critic.model.override_config.attn_implementation=eager
        config_dict = {
            "critic": {
                "model": {
                    "path": "/test/path",
                    "override_config": {
                        "attn_implementation": "eager",
                    },
                }
            }
        }

        omegaconf = OmegaConf.create(config_dict)

        # Resolve override_config from the nested critic section
        override_model_config = OmegaConf.to_container(
            OmegaConf.create(omegaconf.critic.model.get("override_config", {}))
        )

        # The attn_implementation override is honored
        attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "eager"

    def test_both_actor_and_critic_configuration(self):
        """Test that actor and critic can have different attn_implementation overrides simultaneously."""
        # Actor uses "eager" while the critic uses "sdpa"
        config_dict = {
            "actor_rollout_ref": {"model": {"override_config": {"attn_implementation": "eager"}}},
            "critic": {"model": {"override_config": {"attn_implementation": "sdpa"}}},
        }

        omegaconf = OmegaConf.create(config_dict)

        # Resolve the actor override
        actor_override_config = OmegaConf.to_container(
            OmegaConf.create(omegaconf.actor_rollout_ref.model.get("override_config", {}))
        )
        actor_attn_implementation = actor_override_config.get("attn_implementation", "flash_attention_2")

        # Resolve the critic override
        critic_override_config = OmegaConf.to_container(
            OmegaConf.create(omegaconf.critic.model.get("override_config", {}))
        )
        critic_attn_implementation = critic_override_config.get("attn_implementation", "flash_attention_2")

        # Each worker sees its own override
        assert actor_attn_implementation == "eager"
        assert critic_attn_implementation == "sdpa"

    def test_critic_backward_compatibility(self):
        """Test that the CriticWorker fix maintains backward compatibility."""
        # No override_config key at all: the default is used
        config_without_override = {}
        attn_implementation = config_without_override.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"

        # Empty override_config: the default is used
        config_with_empty_override = {"override_config": {}}
        override_config = config_with_empty_override.get("override_config", {})
        attn_implementation = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"

        # override_config without attn_implementation: the default is used
        config_with_other_overrides = {"override_config": {"dropout": 0.1, "num_labels": 1}}
        override_config = config_with_other_overrides.get("override_config", {})
        attn_implementation = override_config.get("attn_implementation", "flash_attention_2")
        assert attn_implementation == "flash_attention_2"


def test_attn_implementation_fix_integration():
    """Integration test to verify the entire fix works as expected."""
    # Simulate a training config that overrides the actor's attention implementation,
    # e.g. via +actor_rollout_ref.model.override_config.attn_implementation=eager
    config_dict = {"actor_rollout_ref": {"model": {"override_config": {"attn_implementation": "eager"}}}}

    # Resolve override_config the same way the worker does
    omegaconf = OmegaConf.create(config_dict)
    override_model_config = OmegaConf.to_container(
        OmegaConf.create(omegaconf.actor_rollout_ref.model.get("override_config", {}))
    )

    # Extract attn_implementation
    attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")

    # The override is honored
    assert attn_implementation == "eager"

    # The same value would be passed to both AutoConfig and the model constructor
    config_params = {"attn_implementation": attn_implementation}
    model_params = {"attn_implementation": attn_implementation}

    assert config_params["attn_implementation"] == "eager"
    assert model_params["attn_implementation"] == "eager"


def test_critic_attn_implementation_fix_integration():
    """Integration test to verify the entire fix works as expected for CriticWorker."""
    # Simulate a training config that overrides the critic's attention implementation,
    # e.g. via +critic.model.override_config.attn_implementation=sdpa
    config_dict = {"critic": {"model": {"override_config": {"attn_implementation": "sdpa"}}}}

    # Resolve override_config the same way the worker does
    omegaconf = OmegaConf.create(config_dict)
    override_model_config = OmegaConf.to_container(
        OmegaConf.create(omegaconf.critic.model.get("override_config", {}))
    )

    # Extract attn_implementation
    attn_implementation = override_model_config.get("attn_implementation", "flash_attention_2")

    # The override is honored
    assert attn_implementation == "sdpa"

    # The same value would be passed to AutoConfig
    config_params = {"attn_implementation": attn_implementation}

    assert config_params["attn_implementation"] == "sdpa"


def test_complete_training_configuration():
    """Integration test for a complete training configuration with both actor and critic overrides."""
    # A realistic configuration in which the actor and critic use different attention backends
    config_dict = {
        "actor_rollout_ref": {
            "model": {
                "path": "/shared/models/llama-7b",
                "override_config": {"attn_implementation": "eager", "torch_dtype": "bfloat16"},
            }
        },
        "critic": {
            "model": {
                "path": "/shared/models/llama-7b",
                "override_config": {"attn_implementation": "sdpa", "num_labels": 1},
            }
        },
    }

    omegaconf = OmegaConf.create(config_dict)

    # Resolve the actor and critic overrides separately
    actor_override_config = OmegaConf.to_container(
        OmegaConf.create(omegaconf.actor_rollout_ref.model.get("override_config", {}))
    )
    critic_override_config = OmegaConf.to_container(
        OmegaConf.create(omegaconf.critic.model.get("override_config", {}))
    )

    # Extract attn_implementation for each worker
    actor_attn_implementation = actor_override_config.get("attn_implementation", "flash_attention_2")
    critic_attn_implementation = critic_override_config.get("attn_implementation", "flash_attention_2")

    # Each worker honors its own override
    assert actor_attn_implementation == "eager"
    assert critic_attn_implementation == "sdpa"

    # Other overrides are preserved as well
    assert actor_override_config.get("torch_dtype") == "bfloat16"
    assert critic_override_config.get("num_labels") == 1
| if __name__ == "__main__": |
| |
| test_attn_implementation_fix_integration() |
| test_critic_attn_implementation_fix_integration() |
| test_complete_training_configuration() |
|
|
| if VERL_AVAILABLE: |
| |
| test_class = TestFSDPAttnImplementation() |
| test_class.test_attn_implementation_extraction_logic() |
| test_class.test_override_config_integration() |
| test_class.test_hydra_plus_prefix_config() |
| test_class.test_backward_compatibility() |
|
|
| |
| test_class.test_critic_attn_implementation_extraction_logic() |
| test_class.test_critic_override_config_integration() |
| test_class.test_critic_hydra_plus_prefix_config() |
| test_class.test_both_actor_and_critic_configuration() |
| test_class.test_critic_backward_compatibility() |
|
|
| print("✓ All FSDP attn_implementation tests passed!") |
| print("✓ All CriticWorker attn_implementation tests passed!") |
| else: |
| print("⚠ VERL components not available, skipping VERL-specific tests") |
|
|
| print("✓ Integration tests passed!") |
| print("✓ Critic integration tests passed!") |
|
|