Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| """ | |
| Test script for CPU-optimized training | |
| """ | |
| import torch | |
| import os | |
| import sys | |
| import time | |
def _report_thread_settings():
    """Print the core count, PyTorch thread counts and thread env vars."""
    print("=== Testing CPU Optimizations ===")
    print(f"CPU cores: {os.cpu_count()}")
    print(f"PyTorch threads: {torch.get_num_threads()}")
    print(f"PyTorch interop threads: {torch.get_num_interop_threads()}")
    for var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS", "NUMEXPR_NUM_THREADS"):
        print(f"{var}: {os.environ.get(var, 'Not set')}")


def _benchmark_matmul(size=1000, runs=10):
    """Time ``torch.mm`` on two random ``size x size`` matrices.

    Args:
        size: Side length of the square input matrices.
        runs: Number of timed multiplications to average over.

    Returns:
        Mean seconds per multiplication.
    """
    print("\n=== Testing Tensor Operations ===")
    x = torch.randn(size, size)
    y = torch.randn(size, size)
    # One untimed warm-up call so one-time start-up costs are not folded
    # into the reported average.
    torch.mm(x, y)
    # perf_counter is monotonic and high-resolution; time.time() is the
    # wall clock and can jump or have coarse resolution.
    start = time.perf_counter()
    for _ in range(runs):
        torch.mm(x, y)
    avg_time = (time.perf_counter() - start) / runs
    print(
        f"Matrix multiplication ({size}x{size}): "
        f"{avg_time:.4f}s average over {runs} runs"
    )
    return avg_time


def _check_model_compilation():
    """Compile a small MLP with ``torch.compile`` and run one forward pass.

    Failures are reported on stdout rather than raised.
    """
    if not hasattr(torch, "compile"):
        print("⚠ torch.compile not available")
        return
    print("\n=== Testing Model Compilation ===")
    model = torch.nn.Sequential(
        torch.nn.Linear(100, 200),
        torch.nn.ReLU(),
        torch.nn.Linear(200, 100),
    )
    test_input = torch.randn(32, 100)
    try:
        compiled_model = torch.compile(model, mode="max-autotune")
        # torch.compile is lazy: graph capture and code generation happen on
        # the first call, so success can only be confirmed after it returns.
        with torch.no_grad():
            output = compiled_model(test_input)
    except Exception as e:  # best-effort check: report, don't crash the script
        print(f"⚠ Model compilation failed: {e}")
        return
    print("✓ Model compilation successful")
    print(f"✓ Compiled model forward pass successful, output shape: {output.shape}")


def test_cpu_optimizations():
    """Print CPU/threading configuration and run quick PyTorch smoke checks.

    Reports the CPU core count, PyTorch intra-op and inter-op thread counts,
    and the OMP/MKL/NUMEXPR thread environment variables. Times a 1000x1000
    matrix multiplication. When ``torch.compile`` exists, it also compiles a
    small MLP and runs one forward pass through it.

    Returns:
        True once all checks have run. Compile failures are printed, not
        raised.
    """
    _report_thread_settings()
    _benchmark_matmul()
    _check_model_compilation()
    print("\n=== CPU Optimization Test Complete ===")
    return True
if __name__ == "__main__":
    # Propagate the result as the process exit status so callers can detect
    # failure from the exit code instead of parsing stdout.
    sys.exit(0 if test_cpu_optimizations() else 1)