| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| Test script to verify TiledMLP accuracy by comparing logits and gradients |
| between regular MLP and TiledMLP under FSDP2. |
| Run with: torchrun --nproc_per_node=2 tests/test_tiled_mlp_accuracy.py |
| """ |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.distributed.device_mesh import init_device_mesh |
| from torch.distributed.fsdp import fully_shard |
|
|
|
|
def setup_distributed():
    """Initialise the NCCL process group and bind this process to a GPU.

    The CUDA device index is taken from the ``LOCAL_RANK`` environment
    variable (exported by ``torchrun``) so that multi-node launches pick a
    device that exists on the local host. When the variable is absent the
    global rank is used, which matches the previous single-node behaviour.

    Returns:
        tuple[int, int]: ``(rank, world_size)`` of the calling process.
    """
    import os

    dist.init_process_group(backend="nccl")
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # The global rank is only a valid device index on a single node; on
    # node N > 0 it exceeds the local GPU count, so prefer LOCAL_RANK.
    local_rank = int(os.environ.get("LOCAL_RANK", rank))
    torch.cuda.set_device(local_rank)
    return rank, world_size
|
|
|
|
def create_model(model_name="Qwen/Qwen3-1.7B", num_layers=2):
    """Load pretrained weights for ``model_name`` with a reduced layer count.

    The pretrained config is fetched first and its ``num_hidden_layers`` is
    overwritten with ``num_layers`` before the causal-LM is instantiated in
    bfloat16 with the ``flash_attention_2`` attention implementation.

    Args:
        model_name: Identifier passed to ``from_pretrained``.
        num_layers: Value written to ``config.num_hidden_layers``.

    Returns:
        The model object returned by ``AutoModelForCausalLM.from_pretrained``.
    """
    from transformers import AutoConfig, AutoModelForCausalLM

    truncated_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True)
    truncated_config.num_hidden_layers = num_layers

    load_kwargs = {
        "config": truncated_config,
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "attn_implementation": "flash_attention_2",
    }
    return AutoModelForCausalLM.from_pretrained(model_name, **load_kwargs)
|
|
|
|
def apply_fsdp2(model, device_mesh):
    """Shard every decoder layer, then the root module, with ``fully_shard``.

    Args:
        model: Module exposing its layers as ``model.model.layers``.
        device_mesh: Mesh forwarded to each ``fully_shard`` call.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Layers are wrapped before the root module, in their original order.
    shard_targets = [*model.model.layers, model]
    for module in shard_targets:
        fully_shard(module, mesh=device_mesh)
    return model
|
|
|
|
def run_forward_backward(model, input_ids, labels):
    """Execute one forward/backward step and snapshot logits and MLP grads.

    Args:
        model: Callable module accepting ``input_ids`` and ``labels`` keyword
            arguments and returning an object with ``logits`` and ``loss``.
        input_ids: Forwarded to the model as ``input_ids``.
        labels: Forwarded to the model as ``labels``.

    Returns:
        tuple: ``(logits, gradients, loss)`` where ``logits`` is a detached
        copy of the output logits, ``gradients`` maps each parameter name
        containing ``"mlp"`` (and having a gradient) to a detached copy of
        that gradient, and ``loss`` is the loss as a Python number.
    """
    # Clear any gradients left over from a previous call.
    model.zero_grad()

    result = model(input_ids=input_ids, labels=labels)
    detached_logits = result.logits.clone().detach()
    loss_tensor = result.loss
    loss_tensor.backward()

    # Only MLP parameters that actually received a gradient are recorded.
    mlp_grads = {
        param_name: tensor.grad.clone().detach()
        for param_name, tensor in model.named_parameters()
        if "mlp" in param_name and tensor.grad is not None
    }
    return detached_logits, mlp_grads, loss_tensor.item()
|
|
|
|
def compare_results(logits1, grads1, logits2, grads2, rank, grad_atol=1e-2):
    """Compare logits and gradients between two runs.

    Logit differences are reported for information only; the returned flag
    reflects the gradient checks alone.

    Args:
        logits1: Logits tensor from the first run.
        grads1: Mapping of parameter name to gradient tensor from run one.
        logits2: Logits tensor from the second run.
        grads2: Mapping of parameter name to gradient tensor from run two.
        rank: Rank of the calling process; only rank 0 prints the report.
        grad_atol: A gradient passes when its maximum absolute difference is
            strictly below this value.

    Returns:
        bool: ``True`` when every parameter appears in both mappings and
        passes the tolerance check. Two empty mappings yield ``True``.
    """
    logits_diff = (logits1 - logits2).abs()
    logits_max_diff = logits_diff.max().item()
    logits_mean_diff = logits_diff.mean().item()

    all_pass = True
    grad_results = []
    # Walk the union of names: a parameter present in only one run means the
    # comparison is incomplete and must not be reported as a pass.
    for name in sorted(grads1.keys() | grads2.keys()):
        if name not in grads1 or name not in grads2:
            all_pass = False
            grad_results.append((name, None, None, False))
            continue

        diff = (grads1[name] - grads2[name]).abs()
        max_diff = diff.max().item()
        mean_diff = diff.mean().item()

        # A NaN difference compares False here and is therefore a failure.
        passed = max_diff < grad_atol
        if not passed:
            all_pass = False
        grad_results.append((name, max_diff, mean_diff, passed))

    if rank == 0:
        print("\n=== Comparison Results ===")
        print("\nLogits:")
        print(f" Max diff: {logits_max_diff:.2e}")
        print(f" Mean diff: {logits_mean_diff:.2e}")

        print("\nMLP Parameter Gradients:")
        if grad_results:
            for name, max_diff, mean_diff, passed in grad_results:
                status = "✓" if passed else "✗"
                if max_diff is None:
                    print(f" {name}: gradient missing from one run {status}")
                else:
                    print(f" {name}: max={max_diff:.2e}, mean={mean_diff:.2e} {status}")
        else:
            print(" (Gradients sharded to other ranks under FSDP2)")

    return all_pass
|
|
|
|
def _print_banner(rank, title):
    """Print a three-line section header on rank 0 only."""
    if rank == 0:
        print("\n" + "=" * 60)
        print(title)
        print("=" * 60)


def _run_comparison(rank, world_size):
    """Run both models on identical inputs and return the agreed verdict.

    Returns ``True`` only when the gradient checks passed on every rank.
    """
    device_mesh = init_device_mesh("cuda", (world_size,))

    model_name = "Qwen/Qwen3-1.7B"
    num_layers = 2

    if rank == 0:
        print(f"Running TiledMLP accuracy test with {world_size} GPUs")
        print(f"Model: {model_name} ({num_layers} layers, from pretrained)")

    dist.barrier()

    # Baseline: unpatched model under FSDP2.
    _print_banner(rank, "Creating Model 1 (without TiledMLP)")

    model1 = create_model(model_name, num_layers)
    model1 = apply_fsdp2(model1, device_mesh)
    model1 = model1.cuda()

    # Seed before sampling so every rank draws the same batch.
    torch.manual_seed(42)
    batch_size, seq_len = 2, 256
    vocab_size = model1.config.vocab_size
    input_ids = torch.randint(0, vocab_size, (batch_size, seq_len), device="cuda")
    labels = input_ids.clone()

    _print_banner(rank, "Running forward/backward on Model 1 (without TiledMLP)")

    logits1, grads1, loss1 = run_forward_backward(model1, input_ids, labels)
    if rank == 0:
        print(f"Loss: {loss1:.4f}")

    # Drop the baseline model before loading the second one.
    del model1
    torch.cuda.empty_cache()

    dist.barrier()

    # Candidate: the patch is applied before the FSDP2 wrapping.
    _print_banner(rank, "Creating Model 2 (with TiledMLP, patch before FSDP2)")

    model2 = create_model(model_name, num_layers)

    if rank == 0:
        print("Applying TiledMLP monkey patch before FSDP2...")

    # NOTE(review): the patch implementation lives in verl and is not visible
    # from this file; confirm it affects already-constructed models.
    from verl.models.transformers.tiled_mlp import apply_tiled_mlp_monkey_patch

    apply_tiled_mlp_monkey_patch(num_shards=4, model_type="qwen3")

    model2 = apply_fsdp2(model2, device_mesh)
    model2 = model2.cuda()

    dist.barrier()

    _print_banner(rank, "Running forward/backward on Model 2 (with TiledMLP)")

    logits2, grads2, loss2 = run_forward_backward(model2, input_ids, labels)
    if rank == 0:
        print(f"Loss: {loss2:.4f}")

    dist.barrier()

    all_pass = compare_results(logits1, grads1, logits2, grads2, rank)

    # Each rank only judged what it holds locally; take the minimum so that a
    # failure on any rank is reflected in the summary and the exit status.
    verdict = torch.tensor([int(all_pass)], dtype=torch.int32, device="cuda")
    dist.all_reduce(verdict, op=dist.ReduceOp.MIN)
    all_pass = bool(verdict.item())

    dist.barrier()

    _print_banner(rank, "SUMMARY")
    if rank == 0:
        print(f"Loss diff: {abs(loss1 - loss2):.2e}")
        print(f"All gradient checks: {'PASS' if all_pass else 'FAIL'}")

    del model2
    torch.cuda.empty_cache()

    return all_pass


def main():
    """Compare a baseline model against a TiledMLP-patched one under FSDP2.

    Sets up the process group, runs both models on the same random batch,
    prints a report on rank 0 and tears the process group down even when the
    run raises.

    Raises:
        SystemExit: With status 1 when the gradient checks failed on any
            rank, so that launchers and CI see the failure.
    """
    rank, world_size = setup_distributed()
    try:
        all_pass = _run_comparison(rank, world_size)
    finally:
        dist.destroy_process_group()

    if not all_pass:
        raise SystemExit(1)
|
|
|
|
# Run the accuracy comparison only when executed as a script (e.g. through
# torchrun), not when this module is imported.
if __name__ == "__main__":
    main()
|
|