File size: 2,475 Bytes
24c2665
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
"""
No-sync sharding manager for TTRLVR that disables weight synchronization like original AZR.
"""
import logging
import os
from torch.distributed.device_mesh import DeviceMesh
from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.device import get_torch_device

# Verbosity is taken from the VERL_LOGGING_LEVEL environment variable, defaulting to "WARN".
_log_level = os.environ.get("VERL_LOGGING_LEVEL", "WARN")
logger = logging.getLogger(__file__)
logger.setLevel(_log_level)


class NoSyncFSDPVLLMShardingManager(FSDPVLLMShardingManager):
    """
    A custom sharding manager that disables weight synchronization between FSDP and VLLM.
    This mimics the behavior of original AZR where VLLM weights are not updated during training.

    Behavior is controlled by the ``sync_weights`` instance attribute, which ``__init__`` sets
    to ``False``:

    * ``False`` (default): entering the context only empties the device cache and swaps in the
      generation RNG state; the FSDP -> vLLM weight synchronization is skipped, so vLLM keeps
      the weights it already holds.
    * ``True``: ``__enter__``/``__exit__`` defer entirely to ``FSDPVLLMShardingManager``, i.e.
      normal weight synchronization.

    The mode is captured when the context is entered, so toggling ``sync_weights`` while a
    context is open does not desynchronize ``__enter__`` and ``__exit__``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Disable weight sync by default; set to True on the instance to restore parent behavior.
        self.sync_weights = False
        # Which path the currently open context took; read by __exit__ to mirror __enter__.
        self._synced_in_context = False
        logger.info("🚫 NoSyncFSDPVLLMShardingManager initialized - weight sync disabled")

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
    def __enter__(self):
        """
        Enter the sharding manager context.

        If ``sync_weights`` is True, this delegates to the parent implementation (full weight
        sync). Otherwise it skips weight synchronization, which keeps VLLM on the weights it
        already holds for as long as sync stays disabled.

        NOTE(review): the no-sync path performs no inference-engine wake-up. If the engine is
        put to sleep elsewhere (e.g. a free-cache-engine setting), confirm generation still
        works with this manager.
        """
        # Record the chosen mode so __exit__ takes the matching path.
        self._synced_in_context = bool(self.sync_weights)
        if self._synced_in_context:
            logger.info("🔄 sync_weights=True - delegating to FSDPVLLMShardingManager.__enter__")
            return super().__enter__()

        # Just empty cache and set random states
        get_torch_device().empty_cache()
        log_gpu_memory_usage("After empty_cache in no-sync sharding manager", logger=logger)

        # Important: need to manually set the random states of each tp to be identical.
        # Save the current (training) RNG state and switch to the generation RNG state;
        # __exit__ swaps them back.
        if self.device_mesh is not None:
            self.torch_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.gen_random_states)

        logger.info("✅ Entered no-sync sharding manager - skipping weight synchronization")

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the sharding manager context.

        Mirrors whichever path ``__enter__`` took. On the no-sync path it puts the module back
        in train mode, empties the device cache, and restores the training RNG state. Exceptions
        raised inside the context are not suppressed.
        """
        if getattr(self, "_synced_in_context", False):
            self._synced_in_context = False
            return super().__exit__(exc_type, exc_value, traceback)

        # Set module back to train mode
        self.module.train()

        # Empty cache after compute
        get_torch_device().empty_cache()

        # Restore random states: stash the advanced generation state for the next context and
        # put back the training state saved in __enter__.
        if self.device_mesh is not None:
            self.gen_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.torch_random_states)

        logger.info("✅ Exited no-sync sharding manager")