File size: 2,475 Bytes
24c2665 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
"""
No-sync sharding manager for TTRLVR that disables weight synchronization like original AZR.
"""
import logging
import os
from torch.distributed.device_mesh import DeviceMesh
from verl.workers.sharding_manager.fsdp_vllm import FSDPVLLMShardingManager
from verl.utils.debug import GPUMemoryLogger, log_gpu_memory_usage
from verl.utils.device import get_torch_device
# Module-level logger. Use the module's import path (``__name__``) rather than
# ``__file__`` so the logger participates in the standard dotted ``logging``
# hierarchy and can be configured together with other package loggers.
logger = logging.getLogger(__name__)
logger.setLevel(os.getenv("VERL_LOGGING_LEVEL", "WARN"))
class NoSyncFSDPVLLMShardingManager(FSDPVLLMShardingManager):
    """
    A custom sharding manager that disables weight synchronization between FSDP and VLLM.

    This mimics the behavior of original AZR where VLLM weights are not updated during
    training: entering the context skips the FSDP -> VLLM weight transfer entirely,
    so VLLM keeps serving its initial weights throughout the epoch.
    """

    def __init__(self, *args, **kwargs):
        """Forward all arguments to the parent manager and disable weight sync."""
        super().__init__(*args, **kwargs)
        self.sync_weights = False  # Disable weight sync by default
        logger.info("🚫 NoSyncFSDPVLLMShardingManager initialized - weight sync disabled")

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
    def __enter__(self):
        """
        Enter the sharding manager context without syncing weights.

        Unlike the parent implementation, no weights are pushed to VLLM; only the
        device cache is cleared and the generation RNG state is installed, which
        keeps VLLM using the initial weights throughout the epoch.

        Returns:
            NoSyncFSDPVLLMShardingManager: ``self``, so ``with ... as mgr`` binds
            the manager instead of ``None``.
        """
        # Just empty cache and set random states
        get_torch_device().empty_cache()
        log_gpu_memory_usage("After empty_cache in no-sync sharding manager", logger=logger)

        # Important: need to manually set the random states of each tp to be identical.
        # Save the training RNG state and install the generation RNG state.
        if self.device_mesh is not None:
            self.torch_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.gen_random_states)

        logger.info("✅ Entered no-sync sharding manager - skipping weight synchronization")
        # Bug fix: a context manager's __enter__ must return the managed object;
        # without this, `with manager as m:` bound `m` to None.
        return self

    @GPUMemoryLogger(role="no_sync_fsdp_vllm_sharding_manager", logger=logger)
    def __exit__(self, exc_type, exc_value, traceback):
        """
        Exit the sharding manager context.

        Restores train mode on the wrapped module, frees the device cache, and
        swaps the training RNG state back in (mirror of ``__enter__``).
        """
        # Set module back to train mode
        self.module.train()

        # Empty cache after compute
        get_torch_device().empty_cache()

        # Restore random states: stash the generation state for the next
        # __enter__, then reinstall the training state saved on entry.
        if self.device_mesh is not None:
            self.gen_random_states = get_torch_device().get_rng_state()
            get_torch_device().set_rng_state(self.torch_random_states)

        logger.info("✅ Exited no-sync sharding manager")