# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Run all tests: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py
# Run specific config: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc"
# Run multiple configs: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py -k "2Proc or 4Proc"
# Run specific test: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc::test_model_dense_forward_train
# Run tests with a specific prefix: RUN_SLOW=1 pytest -v tests/tensor_parallel/test_tensor_parallel.py::TestTensorParallel2Proc -k "forward"
import os
import tempfile
import warnings
from safetensors import safe_open
from transformers import AutoModelForCausalLM, AutoTokenizer, is_torch_available
from transformers.integrations.tensor_parallel import get_packed_weights, get_tensor_shard, repack_weights
from transformers.testing_utils import (
TestCasePlus,
backend_device_count,
get_torch_dist_unique_port,
require_huggingface_hub_greater_or_equal,
require_torch_multi_accelerator,
torch_device,
)
if is_torch_available():
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
def global_wrapper(rank, func, tp, port, func_args, func_kwargs):
    """Entry point executed inside each spawned worker process.

    Exports the distributed environment variables, initialises the default
    process group, calls ``func(rank, *func_args, **func_kwargs)``, then
    synchronises all ranks and destroys the process group.

    Args:
        rank: Index of this worker. Also exported as ``LOCAL_RANK`` and, when
            CUDA is available, used as the CUDA device index.
        func: Callable to run once the process group is initialised.
        tp: Number of workers; used as the world size.
        port: Value exported as ``MASTER_PORT``.
        func_args: Positional arguments forwarded to ``func`` after ``rank``.
        func_kwargs: Keyword arguments forwarded to ``func``.
    """
    os.environ.update(
        {
            "WORLD_SIZE": str(tp),
            "RANK": str(rank),
            "LOCAL_RANK": str(rank),
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": str(port),
        }
    )

    # NCCL when CUDA is present (one device per rank), gloo otherwise.
    has_cuda = torch.cuda.is_available()
    if has_cuda:
        torch.cuda.set_device(rank)
    dist.init_process_group(backend="nccl" if has_cuda else "gloo", rank=rank, world_size=tp)

    func(rank, *func_args, **func_kwargs)

    # Let every rank finish before the group is torn down.
    dist.barrier()
    dist.destroy_process_group()
def init_distributed(tp: int):
    """Return a decorator that runs a function in ``tp`` spawned worker processes.

    The decorated callable takes the arguments that should follow ``rank``;
    ``rank`` itself is supplied per worker by ``global_wrapper``. Each call
    requests a fresh port from ``get_torch_dist_unique_port`` and hands the
    work to ``mp.spawn``.

    Args:
        tp: Number of processes to spawn.
    """

    def _init_distributed(func):
        def wrapper(*args, **kwargs):
            # A new port per invocation, obtained just before spawning.
            port = get_torch_dist_unique_port()
            mp.spawn(global_wrapper, args=(func, tp, port, args, kwargs), nprocs=tp)

        return wrapper

    return _init_distributed
class TestTensorParallelUtils(TestCasePlus):
    """Tests for the packed-weight helpers, run in a single process."""

    def test_packed_unpacked_conversion(self):
        """Shard a packed weight per rank, concatenate the shards, and check that repacking restores it."""
        WORLD_SIZE = 2
        PACKED_BLOCK_SIZE = 800
        SHARDING_DIM = 2
        NUM_BLOCKS = 2

        class MockDeviceMesh:
            # get_packed_weights only calls `.size()`, do this to avoid doing actual distributed run
            def size(self):
                return WORLD_SIZE

        packed_shape = (4, 512, 2 * PACKED_BLOCK_SIZE)
        original_packed_weights = torch.randn(*packed_shape)
        # get_packed_weights expects PySlice object
        original_packed_weights.get_dtype = lambda: "F32"
        empty_param = torch.empty(*packed_shape)
        mock_mesh = MockDeviceMesh()

        per_rank_shards = [
            get_packed_weights(original_packed_weights, empty_param, mock_mesh, shard_rank, SHARDING_DIM)
            for shard_rank in range(WORLD_SIZE)
        ]

        # simulate all gather of sharded weights
        packed_weights = torch.cat(per_rank_shards, dim=SHARDING_DIM)
        unpacked_weights = repack_weights(packed_weights, SHARDING_DIM, WORLD_SIZE, NUM_BLOCKS)

        assert torch.allclose(unpacked_weights, original_packed_weights)
class TestTensorParallelProperties(TestCasePlus):
    """Tests for the `tp_plan` property of a loaded causal LM (setter, getter and validation)."""

    def _load_model(self):
        """Load the small checkpoint that every test in this class operates on."""
        return AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", dtype="auto")

    def test_tp_plan_property_setter_getter(self):
        """Test that tp_plan property can be set and retrieved correctly."""
        model = self._load_model()

        # Test setting empty plan
        model.tp_plan = {}
        self.assertEqual(model.tp_plan, {})

        # Test setting a valid plan
        valid_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        model.tp_plan = valid_plan
        self.assertEqual(model.tp_plan, valid_plan)

        # Test updating the plan
        model.tp_plan.update({"model.layers.*.self_attn.k_proj": "colwise"})
        expected_plan = {"model.layers.*.self_attn.q_proj": "colwise", "model.layers.*.self_attn.k_proj": "colwise"}
        self.assertEqual(model.tp_plan, expected_plan)

        # Test overriding existing entry
        model.tp_plan.update({"model.layers.*.self_attn.q_proj": "colwise_rep"})
        expected_plan = {
            "model.layers.*.self_attn.q_proj": "colwise_rep",
            "model.layers.*.self_attn.k_proj": "colwise",
        }
        self.assertEqual(model.tp_plan, expected_plan)

    def test_tp_plan_validation_invalid_style(self):
        """Test that invalid parallel styles are rejected."""
        model = self._load_model()

        # Test invalid parallel style
        with self.assertRaises(ValueError) as context:
            model.tp_plan = {"layers.*.self_attn.q_proj": "invalid_style"}

        error_text = str(context.exception)
        self.assertIn("Unsupported tensor parallel style 'invalid_style'", error_text)
        self.assertIn("Supported styles are", error_text)

    def test_tp_plan_validation_nonexistent_layer_warning(self):
        """Test that warnings are issued for non-existent layer patterns."""
        model = self._load_model()

        # Test warning for non-existent layer pattern
        with warnings.catch_warnings(record=True) as caught:
            warnings.simplefilter("always")
            model.tp_plan = {"nonexistent.*.layer": "colwise"}

        # Search every recorded warning rather than only the first one: with the
        # "always" filter, unrelated warnings may be recorded ahead of the one
        # under test.
        messages = [str(item.message) for item in caught]
        expected = "Layer pattern 'nonexistent.*.layer' does not match any parameters"
        self.assertTrue(
            any(expected in message for message in messages),
            f"Expected a warning containing {expected!r}, got: {messages}",
        )

    def test_tp_plan_valid_layer_patterns(self):
        """Test that valid layer patterns are accepted without warnings."""
        model = self._load_model()

        # Test valid layer patterns that should match the model structure
        valid_plans = [
            {"model.layers.*.self_attn.q_proj": "colwise"},
            {"model.layers.*.self_attn.k_proj": "rowwise"},
            {"model.layers.*.mlp.gate_proj": "colwise_rep"},
        ]

        for plan in valid_plans:
            with warnings.catch_warnings(record=True) as caught:
                warnings.simplefilter("always")
                model.tp_plan = plan

            # Filter out any warnings that are not about layer patterns
            layer_warnings = [
                item
                for item in caught
                if "Layer pattern" in str(item.message) and "does not match any parameters" in str(item.message)
            ]

            # Should not have layer pattern warnings for valid patterns
            self.assertEqual(
                len(layer_warnings),
                0,
                f"Unexpected warning for valid pattern {plan}: {[str(item.message) for item in layer_warnings]}",
            )

        # Verify the final plan was set correctly
        self.assertEqual(model.tp_plan, valid_plans[-1])

    def test_tp_plan_none_handling(self):
        """Test that None values are handled correctly."""
        model = self._load_model()

        # Test setting None
        model.tp_plan = None
        self.assertEqual(model.tp_plan, {})

        # Test setting a plan after None
        model.tp_plan = {"model.layers.*.self_attn.q_proj": "colwise"}
        self.assertEqual(model.tp_plan, {"model.layers.*.self_attn.q_proj": "colwise"})
# ====== TEST FUNCTIONS ======
def _test_model_dense_forward_impl(rank, mode):
    """Check that the tensor-parallel model and the plain model produce matching logits.

    Args:
        rank: Worker index passed in by the spawn wrapper (not used in this function).
        mode: ``"eval"`` puts both models in eval mode; any other value puts them in train mode.
    """
    checkpoint = "JackFram/llama-68m"

    def _apply_mode(module):
        if mode == "eval":
            module.eval()
        else:
            module.train()

    # Fixed seed so every rank starts from the same RNG state.
    torch.manual_seed(0)

    # One tokenized prompt, shared by both models.
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    encoded = tokenizer("Can I help", return_tensors="pt")

    # The TP model is loaded first; its device decides where everything else goes.
    tp_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", tp_plan="auto")
    dist.barrier()
    _apply_mode(tp_model)

    target_device = tp_model.device
    reference_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto").to(target_device)
    _apply_mode(reference_model)

    input_ids = encoded.input_ids.to(target_device)

    with torch.no_grad():
        logits = reference_model(input_ids).logits
        logits_tp = tp_model(input_ids).logits

    # Both sets of logits are expected to agree within tolerance.
    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
    )

    dist.barrier()
def _test_model_dense_backward_pass_impl(rank):
    """Check that the tensor-parallel model and the plain model agree on loss and gradients.

    Args:
        rank: Worker index passed in by the spawn wrapper; used to select this
            worker's slice of each reference gradient.
    """
    checkpoint = "JackFram/llama-68m"

    torch.manual_seed(0)

    model_tp = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32, tp_plan="auto")
    dist.barrier()
    model_tp.train()

    device = model_tp.device
    model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype=torch.float32).to(device)
    model.train()

    # Separate seed so the random inputs are deterministic.
    torch.manual_seed(42)
    batch_shape = (2, 10)
    vocab_size = model.config.vocab_size
    input_ids = torch.randint(0, vocab_size, batch_shape, device=device)
    labels = torch.randint(0, vocab_size, batch_shape, device=device)

    loss = model(input_ids, labels=labels).loss
    loss.backward()

    loss_tp = model_tp(input_ids, labels=labels).loss
    loss_tp.backward()

    assert torch.allclose(loss, loss_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model losses differ. Non-TP loss: {loss.item()}, TP loss: {loss_tp.item()}, Diff: {(loss - loss_tp).abs().item()}"
    )

    # Compare gradients pairwise. TP parameters may be sharded DTensors, in which
    # case the reference gradient is sliced down to this rank's shard first.
    dtensor_cls = dist.tensor.DTensor
    for (name, param), (_, param_tp) in zip(model.named_parameters(), model_tp.named_parameters()):
        if param.grad is None or param_tp.grad is None:
            continue

        grad_shard = param.grad
        tp_data = param_tp.data
        if isinstance(tp_data, dtensor_cls):
            shard_dim = getattr(tp_data.placements[0], "dim", None)
            if shard_dim is not None:
                grad_shard = get_tensor_shard(param.grad, param.grad, tp_data.device_mesh, rank, shard_dim)

        grad_tp_local = param_tp.grad
        if isinstance(grad_tp_local, dtensor_cls):
            grad_tp_local = grad_tp_local.to_local()

        assert torch.allclose(grad_shard.cpu(), grad_tp_local.cpu(), atol=1e-5, rtol=1e-5), (
            f"Gradients differ for parameter {name}. Max diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().max().item()} | Min diff: {(grad_shard.cpu() - grad_tp_local.cpu()).abs().min().item()}"
        )

    dist.barrier()
def _test_model_dense_forward_compile_impl(rank, mode):
    """Check that the tensor-parallel and plain models agree when their forwards are wrapped in `torch.compile`.

    Args:
        rank: Worker index passed in by the spawn wrapper (not used in this function).
        mode: ``"eval"`` puts both models in eval mode; any other value puts them in train mode.
    """
    checkpoint = "JackFram/llama-68m"

    def _apply_mode(module):
        if mode == "eval":
            module.eval()
        else:
            module.train()

    torch.manual_seed(0)

    tokenizer = AutoTokenizer.from_pretrained(checkpoint, use_fast=False)
    encoded = tokenizer("Can I help", return_tensors="pt")

    tp_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto", tp_plan="auto")
    dist.barrier()
    _apply_mode(tp_model)

    target_device = tp_model.device
    reference_model = AutoModelForCausalLM.from_pretrained(checkpoint, dtype="auto").to(target_device)
    _apply_mode(reference_model)

    # Replace each model's forward with its compiled counterpart.
    reference_model.forward = torch.compile(reference_model.forward)
    tp_model.forward = torch.compile(tp_model.forward)

    input_ids = encoded.input_ids.to(target_device)

    with torch.no_grad():
        logits = reference_model(input_ids).logits
        logits_tp = tp_model(input_ids).logits

    assert torch.allclose(logits, logits_tp, atol=1e-5, rtol=1e-5), (
        f"TP and non-TP model outputs differ. Max diff: {(logits - logits_tp).abs().max().item()} | Min diff: {(logits - logits_tp).abs().min().item()}"
    )

    dist.barrier()
def _test_model_dense_save_impl(rank, tmp_dir):
    """Load the checkpoint and save it under ``tmp_dir``.

    When a process group is initialised the model is loaded with
    ``tp_plan="auto"`` and written to ``<tmp_dir>/tp``; otherwise it is loaded
    without extra arguments and written to ``<tmp_dir>/nontp``.

    Args:
        rank: Worker index (not used in this function).
        tmp_dir: Directory under which the output sub-directory is created.
    """
    running_distributed = dist.is_initialized()
    load_kwargs = {"tp_plan": "auto"} if running_distributed else {}
    result_dir = f"{tmp_dir}/tp" if running_distributed else f"{tmp_dir}/nontp"

    loaded = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m", **load_kwargs)
    loaded.save_pretrained(result_dir)
class TestTensorParallelBase(TestCasePlus):
    """Base class for tensor parallel tests. Subclasses must set nproc_per_node."""

    # Number of worker processes to spawn; while it is None every test skips itself.
    nproc_per_node = None

    def _skip_unless_enough_devices(self):
        """Skip the current test when `nproc_per_node` is unset or exceeds the available device count."""
        if self.nproc_per_node is None:
            self.skipTest("nproc_per_node not set")
        available = backend_device_count(torch_device)
        if available < self.nproc_per_node:
            self.skipTest(f"Need at least {self.nproc_per_node} devices, have {available}")

    @require_torch_multi_accelerator
    def test_model_dense_forward_eval(self):
        """Test that TP and non-TP models produce the same outputs in eval mode."""
        self._skip_unless_enough_devices()
        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_train(self):
        """Test that TP and non-TP models produce the same outputs in train mode."""
        self._skip_unless_enough_devices()
        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_impl)("train")

    @require_torch_multi_accelerator
    def test_model_dense_backward_pass(self):
        """Test that TP and non-TP models produce the same loss and gradients."""
        self._skip_unless_enough_devices()
        init_distributed(tp=self.nproc_per_node)(_test_model_dense_backward_pass_impl)()

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_eval(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in eval mode."""
        self._skip_unless_enough_devices()
        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("eval")

    @require_torch_multi_accelerator
    def test_model_dense_forward_compile_train(self):
        """Test that TP and non-TP models produce the same outputs with torch.compile in train mode."""
        self._skip_unless_enough_devices()
        init_distributed(tp=self.nproc_per_node)(_test_model_dense_forward_compile_impl)("train")

    @require_huggingface_hub_greater_or_equal("0.31.4")
    @require_torch_multi_accelerator
    def test_model_dense_save(self):
        """Test that a model saved from a TP run holds the same tensors as one saved without TP."""
        self._skip_unless_enough_devices()

        with tempfile.TemporaryDirectory() as tmp_dir:
            # First run with TP (distributed)
            init_distributed(tp=self.nproc_per_node)(_test_model_dense_save_impl)(tmp_dir)

            # Then run without TP (non-distributed)
            _test_model_dense_save_impl(0, tmp_dir)

            non_tp_model_path = os.path.join(tmp_dir, "nontp")
            tp_model_path = os.path.join(tmp_dir, "tp")

            for filename in os.listdir(non_tp_model_path):
                if not filename.endswith(".safetensors"):
                    continue

                # Open both files as context managers so the handles are released
                # before the temporary directory is removed.
                with safe_open(os.path.join(non_tp_model_path, filename), device="cpu", framework="pt") as non_tp_model:
                    with safe_open(os.path.join(tp_model_path, filename), device="cpu", framework="pt") as tp_model:
                        for non_tp_key in non_tp_model.keys():
                            non_tp_tensor = non_tp_model.get_tensor(non_tp_key)
                            tp_tensor = tp_model.get_tensor(non_tp_key)
                            assert torch.allclose(non_tp_tensor, tp_tensor), (
                                f"Tensor with key: {non_tp_key} does not match"
                            )
                            del non_tp_tensor, tp_tensor
class TestTensorParallel2Proc(TestTensorParallelBase):
    """Runs the inherited tensor parallel tests with two worker processes."""

    # Read by the inherited tests to decide how many processes to spawn.
    nproc_per_node = 2
class TestTensorParallel4Proc(TestTensorParallelBase):
    """Runs the inherited tensor parallel tests with four worker processes."""

    # Read by the inherited tests to decide how many processes to spawn.
    nproc_per_node = 4
|