"""
Integration test for NVFP4 model loading.

This tests that the model can be loaded from sharded safetensors
and that all weights have correct shapes and flags.
"""

import sys
import json

import torch

from model import Transformer, ModelArgs
from generate import load_sharded_model

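# Expected on-disk layout for each NVFP4 linear layer (verified by
# Tests 4 and 5 below):
#   weight          uint8, shape (N, K // 2)  - two 4-bit codes packed per byte
#   weight_scale    shape (N, K // 16)        - one scale per 16-element block
#   weight_scale_2  shape [1]                 - single global scale factor
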
def clear_cache():
    """Drop the OS page cache to free memory (requires sudo)."""
    print("Clearing system cache...")
    try:
        import subprocess
        # `echo 3 > /proc/sys/vm/drop_caches` frees the page cache plus
        # dentries and inodes; it is safe but needs root privileges.
        subprocess.run(
            ['sudo', 'sh', '-c', 'echo 3 > /proc/sys/vm/drop_caches'],
            check=False, capture_output=True, text=True
        )
        print("  PASS: Cache cleared\n")
    except Exception as e:
        print(f"  WARN: Could not clear cache: {e}\n")


def check_memory():
    """Report memory usage; return available GB, or None if psutil is missing."""
    try:
        import psutil
        mem = psutil.virtual_memory()
        print(f"Memory: {mem.available / 1e9:.1f}GB available / {mem.total / 1e9:.1f}GB total")
        print(f"  {mem.percent:.1f}% used\n")
        return mem.available / 1e9
    except ImportError:
        print("psutil not available, skipping memory check\n")
        return None


def test_config_loading():
    """Test 1: Load and validate the model config."""
    print("=" * 70)
    print("Test 1: Load Model Config")
    print("=" * 70)

    config_path = "/mnt/models/deepseek-v3.2-nvfp4/inference/config_671B_nvfp4.json"

    print(f"  Loading config from: {config_path}")
    with open(config_path) as f:
        config_dict = json.load(f)

    args = ModelArgs(**config_dict)

    print("  Model parameters:")
    print(f"    - vocab_size: {args.vocab_size:,}")
    print(f"    - dim: {args.dim}")
    print(f"    - n_layers: {args.n_layers}")
    print(f"    - n_routed_experts: {args.n_routed_experts}")
    print(f"    - dtype: {args.dtype}")

    assert args.dtype == "nvfp4", f"Expected dtype='nvfp4', got '{args.dtype}'"
    assert args.n_layers == 61, f"Expected 61 layers, got {args.n_layers}"

    print("  PASS: Config loaded successfully")
    print("  PASS: Test 1 PASSED\n")

    return args


def test_model_creation(args):
    """Test 2: Create the model instance."""
    print("=" * 70)
    print("Test 2: Create Model Instance")
    print("=" * 70)

    print(f"  Creating Transformer model with dtype={args.dtype}...")
    print("  (This may take 1-2 minutes)")

    # Create non-quantized parameters (embeddings, norms, etc.) in bf16;
    # the NVFP4 linears allocate their own uint8 storage.
    torch.set_default_dtype(torch.bfloat16)

    with torch.device("cpu"):
        model = Transformer(args)

    total_params = sum(p.numel() for p in model.parameters())
    total_buffers = sum(b.numel() for b in model.buffers())

    print("  Model created:")
    print(f"    - Parameters: {total_params / 1e9:.2f}B")
    print(f"    - Buffers: {total_buffers / 1e9:.2f}B")
    print(f"    - Total: {(total_params + total_buffers) / 1e9:.2f}B elements")

    assert hasattr(model, 'embed'), "Model missing embed layer"
    assert hasattr(model, 'layers'), "Model missing layers"
    assert len(model.layers) == args.n_layers, \
        f"Expected {args.n_layers} layers, got {len(model.layers)}"

    print("  PASS: Model structure correct")
    print("  PASS: Test 2 PASSED\n")

    return model


def test_weight_loading(model):
    """Test 3: Load weights from the sharded checkpoint."""
    print("=" * 70)
    print("Test 3: Load Weights from Checkpoint")
    print("=" * 70)

    ckpt_path = "/mnt/models/deepseek-v3.2-nvfp4"

    print(f"  Loading from: {ckpt_path}")
    print("  (This will take 5-15 minutes for the full model)")
    print("  Progress will be shown shard-by-shard...\n")

    load_sharded_model(model, ckpt_path)

    print("\n  PASS: Weights loaded successfully")
    print("  PASS: Test 3 PASSED\n")

    return model


def test_nvfp4_layers(model):
    """Test 4: Verify NVFP4 layers have the correct structure."""
    print("=" * 70)
    print("Test 4: Verify NVFP4 Layer Structure")
    print("=" * 70)

    nvfp4_layers = []
    total_layers = 0

    for name, module in model.named_modules():
        # Count every linear-like module that carries the NVFP4 flag slot.
        if hasattr(module, '_nvfp4_mode') and hasattr(module, 'weight'):
            total_layers += 1
            if getattr(module, '_nvfp4_mode', False):
                nvfp4_layers.append((name, module))

    print(f"  Found {len(nvfp4_layers)} NVFP4 layers out of {total_layers} total linear layers")

    if len(nvfp4_layers) == 0:
        print("  WARN: No NVFP4 layers found!")
        print("        This might indicate a dtype configuration issue")
        return

    print("\n  Inspecting first 5 NVFP4 layers:")
    for i, (name, module) in enumerate(nvfp4_layers[:5]):
        weight = module.weight
        weight_scale = getattr(module, 'weight_scale', None)
        weight_scale_2 = getattr(module, 'weight_scale_2', None)

        print(f"\n  [{i+1}] {name}:")
        print(f"      weight: {weight.shape}, dtype={weight.dtype}")

        # The uint8 weight packs two 4-bit values per byte, so the stored
        # second dimension is K // 2.
        N, K_half = weight.shape
        K = K_half * 2

        if weight_scale is not None:
            print(f"      weight_scale: {weight_scale.shape}, dtype={weight_scale.dtype}")
            expected_scale_shape = (N, K // 16)
            if weight_scale.shape != expected_scale_shape:
                print(f"      WARN: Expected scale shape {expected_scale_shape}, got {weight_scale.shape}")
            else:
                print("      PASS: Scale shape correct")
        else:
            print("      WARN: weight_scale not found!")

        if weight_scale_2 is not None:
            print(f"      weight_scale_2: {weight_scale_2.shape}, dtype={weight_scale_2.dtype}, value={weight_scale_2.item():.6e}")
            if weight_scale_2.shape != torch.Size([1]):
                print(f"      WARN: Expected scale_2 shape [1], got {weight_scale_2.shape}")
            else:
                print("      PASS: Scale_2 shape correct")
        else:
            print("      WARN: weight_scale_2 not found!")

        assert weight.dtype == torch.uint8, f"Weight should be uint8, got {weight.dtype}"

    print("\n  PASS: NVFP4 layers have correct structure")
    print("  PASS: Test 4 PASSED\n")


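# Illustrative only: a minimal CPU sketch of how the packed layout verified
# above could be dequantized. It assumes the standard NVFP4 convention
# (two E2M1 nibbles per byte, low nibble first; one block scale per 16
# elements; a single global scale) and is NOT the kernel the model uses
# at inference time.
def dequantize_nvfp4_sketch(weight, weight_scale, weight_scale_2):
    """Expand a packed (N, K // 2) uint8 NVFP4 weight to a bf16 (N, K) tensor."""
    # All 16 E2M1 codes: 1 sign bit, 2 exponent bits, 1 mantissa bit.
    fp4_values = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0,
         -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0])
    lo = (weight & 0x0F).long()         # even columns (assumed low nibble)
    hi = ((weight >> 4) & 0x0F).long()  # odd columns
    N, K_half = weight.shape
    codes = torch.stack([lo, hi], dim=-1).reshape(N, K_half * 2)
    vals = fp4_values[codes]                                   # (N, K)
    # Broadcast each per-16-element block scale across its block.
    block_scales = weight_scale.to(torch.float32).repeat_interleave(16, dim=1)
    return (vals * block_scales * weight_scale_2.float()).to(torch.bfloat16)
# Example (hypothetical usage on a layer found in Test 4):
#   dense = dequantize_nvfp4_sketch(m.weight, m.weight_scale, m.weight_scale_2)
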
def test_weight_statistics(model):
    """Test 5: Check weight statistics to verify they're not zeros or corrupted."""
    print("=" * 70)
    print("Test 5: Weight Statistics")
    print("=" * 70)

    nvfp4_count = 0
    zero_count = 0
    checked = 0

    for name, module in model.named_modules():
        if getattr(module, '_nvfp4_mode', False):
            nvfp4_count += 1

            # Only inspect the first 10 layers; scanning all of them is slow.
            if checked < 10:
                weight = module.weight
                weight_scale = getattr(module, 'weight_scale', None)
                weight_scale_2 = getattr(module, 'weight_scale_2', None)

                num_zeros = (weight == 0).sum().item()
                total_elems = weight.numel()
                zero_percent = 100.0 * num_zeros / total_elems

                if checked == 0:
                    print(f"\n  Sample layer: {name}")
                    print(f"    Weight zeros: {zero_percent:.1f}%")
                    if weight_scale is not None:
                        scale_min = weight_scale.to(torch.float32).min().item()
                        scale_max = weight_scale.to(torch.float32).max().item()
                        print(f"    Scale range: [{scale_min:.6e}, {scale_max:.6e}]")
                    if weight_scale_2 is not None:
                        print(f"    Scale_2: {weight_scale_2.item():.6e}")

                # A 0x00 byte encodes two +0.0 values; a near-all-zero weight
                # usually means the shard was never actually loaded.
                if zero_percent > 95:
                    zero_count += 1
                    print(f"  WARN: {name} has {zero_percent:.1f}% zeros (possibly corrupted)")

                checked += 1

    print(f"\n  Checked {checked} NVFP4 layers:")
    print(f"    - Total NVFP4 layers: {nvfp4_count}")
    print(f"    - Layers with >95% zeros: {zero_count}")

    if zero_count > checked // 2:
        print("  WARN: Many layers appear corrupted (too many zeros)")
    else:
        print("  PASS: Weight statistics look reasonable")

    print("  PASS: Test 5 PASSED\n")


def main():
    """Run all model loading tests."""
    print("\n" + "=" * 70)
    print("NVFP4 Model Loading Integration Test")
    print("=" * 70)
    print("This test will load the full 671B parameter model")
    print("Expected runtime: 5-20 minutes")
    print("Memory required: ~400GB")
    print("=" * 70 + "\n")

    available_gb = check_memory()
    if available_gb is not None and available_gb < 350:
        print(f"WARN: Only {available_gb:.1f}GB available")
        print("      Model may not fit in memory. Consider clearing cache.")
        user_input = input("  Continue anyway? (y/n): ")
        if user_input.lower() != 'y':
            print("  Aborted by user")
            return 1

    user_input = input("Clear system cache before loading? (recommended) (y/n): ")
    if user_input.lower() == 'y':
        clear_cache()
        check_memory()

    try:
        args = test_config_loading()
        model = test_model_creation(args)
        model = test_weight_loading(model)
        test_nvfp4_layers(model)
        test_weight_statistics(model)

        print("=" * 70)
        print("PASS: ALL TESTS PASSED")
        print("=" * 70)
        print("Model loaded successfully with correct NVFP4 structure!")
        print("Ready for forward pass testing.")
        print("=" * 70)

        print("\nKeeping model in memory for next test...")
        print("Run test_forward_pass.py in the same Python session to reuse the loaded model")

        return 0

    except Exception as e:
        print(f"\nFAIL: TEST FAILED: {e}")
        import traceback
        traceback.print_exc()
        return 1


if __name__ == "__main__":
    sys.exit(main())