import asyncio
import os
import unittest

import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from transformers import AutoModelForCausalLM

from sglang.srt.entrypoints.engine import Engine
from sglang.srt.weight_sync.utils import update_weights
from sglang.test.ci.ci_register import register_cuda_ci
from sglang.test.test_utils import DEFAULT_SMALL_MODEL_NAME_FOR_TEST

register_cuda_ci(est_time=29, suite="stage-b-test-large-1-gpu")


class AsyncEngine(Engine):
    """Thin wrapper exposing the tokenizer manager's async weight-update API."""

    async def update_weights_from_tensor(self, update_weights_request):
        return await self.tokenizer_manager.update_weights_from_tensor(
            update_weights_request, None
        )


def is_distributed_available():
    """Check if distributed training environment is available"""
    required_vars = ["RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT"]
    return all(var in os.environ for var in required_vars)


def setup_single_process_distributed():
    """Setup distributed environment for single process testing"""
    if not is_distributed_available():
        os.environ["RANK"] = "0"
        os.environ["WORLD_SIZE"] = "1"
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = "12356"
        os.environ["LOCAL_RANK"] = "0"


class TestUtilsUpdateWeights(unittest.TestCase):
    """Test class for utils.update_weights function"""

    @classmethod
    def setUpClass(cls):
        """Setup distributed environment and test fixtures for the entire test class"""
        cls.setup_distributed()
        cls.setup_test_engine()
        cls.setup_test_model()
        cls.setup_device_mesh()

    @classmethod
    def tearDownClass(cls):
        """Cleanup after all tests"""
        if hasattr(cls, "engine") and cls.engine:
            cls.engine.shutdown()

        # Cleanup distributed
        if dist.is_initialized():
            dist.destroy_process_group()

    @classmethod
    def setup_distributed(cls):
        """Setup distributed environment for testing"""
        setup_single_process_distributed()

        if not dist.is_initialized():
            try:
                dist.init_process_group(
                    backend="nccl" if torch.cuda.is_available() else "gloo"
                )
            except Exception as e:
                raise unittest.SkipTest(
                    f"Could not initialize distributed backend: {e}"
                )

        cls.rank = dist.get_rank()
        cls.world_size = dist.get_world_size()

        if torch.cuda.is_available():
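            # Pin each rank to a GPU, wrapping around if ranks exceed the device count.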
            torch.cuda.set_device(cls.rank % torch.cuda.device_count())

        # Environment knobs for a quieter, more predictable run: silence
        # TensorFlow logging, disable NCCL's cuMem allocator, limit CUDA
        # hardware work queues, and let the driver pick the module-loading mode.
        os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
        os.environ["CUDA_MODULE_LOADING"] = "AUTO"

    @classmethod
    def setup_test_engine(cls):
        """Setup test engine"""
        if cls.rank == 0:
            cls.engine = AsyncEngine(
                model_path=DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                dtype="bfloat16",
                mem_fraction_static=0.3,
                enable_memory_saver=True,
                tp_size=cls.world_size,
                disable_cuda_graph=False,
            )
        else:
            cls.engine = None

    @classmethod
    def setup_test_model(cls):
        """Load test model"""
        try:
            cls.model = AutoModelForCausalLM.from_pretrained(
                DEFAULT_SMALL_MODEL_NAME_FOR_TEST,
                device_map="cpu",
                trust_remote_code=True,
                low_cpu_mem_usage=True,
                torch_dtype=(
                    torch.float16 if torch.cuda.is_available() else torch.float32
                ),
            )
        except Exception as e:
            raise unittest.SkipTest(f"Could not load test model: {e}")

    @classmethod
    def setup_device_mesh(cls):
        """Create device mesh for testing"""
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available for device mesh")

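        # A 1-D mesh named "tp" spanning all ranks, matching the engine's
        # tp_size; update_weights takes both the mesh and this key.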
        cls.device_mesh_key = "tp"
        cls.mesh = init_device_mesh(
            "cuda", (cls.world_size,), mesh_dim_names=(cls.device_mesh_key,)
        )

    def create_test_params_batch(self, model, num_params=64):
        """Create a batch of test parameters from the model"""
        param_names = []
        test_tensors = []

        # Take the first `num_params` parameters from the model.
        for i, (name, tensor) in enumerate(model.named_parameters()):
            if i >= num_params:
                break
            param_names.append(name)
            # Fill with a known constant so the update is easy to verify;
            # full_like already preserves the original shape and dtype.
            test_tensor = torch.full_like(tensor, 1.5).cuda()
            test_tensors.append(test_tensor)

        return list(zip(param_names, test_tensors))

    def test_utils_update_weights(self):
        """Test basic functionality of utils.update_weights"""

        async def async_test():
            # Build a small batch of (name, tensor) test parameters
            params_batch = self.create_test_params_batch(self.model, num_params=2)

            # Exercise update_weights; it pushes the batch to the engine over
            # the device mesh and reports the outcome as a message.
            result = await update_weights(
                engine=self.engine,
                params_batch=params_batch,
                device_mesh_key=self.device_mesh_key,
                device_mesh=self.mesh,
                load_format=None,
            )

            self.assertIn("Success", result)

        # Run the async test
        asyncio.run(async_test())


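# Note: to exercise the multi-rank path, launch this file under torchrun
# (e.g. `torchrun --nproc_per_node=2 <this_file>`), which exports the
# RANK/WORLD_SIZE/MASTER_ADDR/MASTER_PORT variables checked by
# is_distributed_available(); a plain `python` invocation falls back to a
# single-process group on localhost.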
if __name__ == "__main__":
    unittest.main()