Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the full set of changes.
- .gitattributes +37 -0
- ml-stable-diffusion/mlx/.circleci/config.yml +579 -0
- ml-stable-diffusion/mlx/.clang-format +87 -0
- ml-stable-diffusion/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
- ml-stable-diffusion/mlx/.github/pull_request_template.md +12 -0
- ml-stable-diffusion/mlx/.github/workflows/pull_request.yml +20 -0
- ml-stable-diffusion/mlx/.gitignore +88 -0
- ml-stable-diffusion/mlx/.pre-commit-config.yaml +21 -0
- ml-stable-diffusion/mlx/ACKNOWLEDGMENTS.md +268 -0
- ml-stable-diffusion/mlx/CITATION.cff +24 -0
- ml-stable-diffusion/mlx/CMakeLists.txt +353 -0
- ml-stable-diffusion/mlx/CODE_OF_CONDUCT.md +132 -0
- ml-stable-diffusion/mlx/CONTRIBUTING.md +38 -0
- ml-stable-diffusion/mlx/LICENSE +21 -0
- ml-stable-diffusion/mlx/MANIFEST.in +6 -0
- ml-stable-diffusion/mlx/README.md +121 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/autograd.cpp +39 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/single_ops.cpp +288 -0
- ml-stable-diffusion/mlx/benchmarks/cpp/time_utils.h +39 -0
- ml-stable-diffusion/mlx/benchmarks/numpy/single_ops.py +39 -0
- ml-stable-diffusion/mlx/benchmarks/numpy/time_utils.py +20 -0
- ml-stable-diffusion/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
- ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
- ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemv.py +221 -0
- ml-stable-diffusion/mlx/benchmarks/python/comparative/README.md +15 -0
- ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
- ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
- ml-stable-diffusion/mlx/benchmarks/python/comparative/compare.py +284 -0
- ml-stable-diffusion/mlx/benchmarks/python/compile_bench.py +107 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv1d_bench.py +123 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv_bench.py +135 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
- ml-stable-diffusion/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
- ml-stable-diffusion/mlx/benchmarks/python/distributed_bench.py +66 -0
- ml-stable-diffusion/mlx/benchmarks/python/einsum_bench.py +84 -0
- ml-stable-diffusion/mlx/benchmarks/python/fft_bench.py +118 -0
- ml-stable-diffusion/mlx/benchmarks/python/gather_bench.py +52 -0
- ml-stable-diffusion/mlx/benchmarks/python/gather_mm_bench.py +74 -0
- ml-stable-diffusion/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
- ml-stable-diffusion/mlx/benchmarks/python/hadamard_bench.py +70 -0
- ml-stable-diffusion/mlx/benchmarks/python/layer_norm_bench.py +82 -0
.gitattributes
CHANGED
@@ -120,3 +120,40 @@
 ml-stable-diffusion/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png filter=lfs diff=lfs merge=lfs -text
 ml-stable-diffusion/assets/palette6_cpuandne_readmereel.png filter=lfs diff=lfs merge=lfs -text
 ml-stable-diffusion/assets/readme_reel.png filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/arg_reduce.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/binary.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/compiled.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/compiled_preamble.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/conv.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/copy.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/fft.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/indexing.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/jit_compiler.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/masked_mm.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/matmul.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/quantized.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/reduce.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/scan.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/select.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/sort.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/unary.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/no_gpu/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/compile.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/distributed/mpi/mpi.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/distributed/ring/ring.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/einsum.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/export.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/fast.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/gguf.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/load.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/safetensors.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/linalg.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/random.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/transforms.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/CMakeFiles/nanobind-static.dir/proj/cvl/users/x_fahkh2/caches/pip-build-env-nyl54h73/overlay/lib/python3.10/site-packages/nanobind/src/nb_type.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/CMakeFiles/nanobind-static.dir/proj/cvl/users/x_fahkh2/caches/pip-build-env-vont0ixn/overlay/lib/python3.10/site-packages/nanobind/src/nb_type.cpp.o filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/libnanobind-static.a filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/docs/src/_static/metal_debugger/capture.png filter=lfs diff=lfs merge=lfs -text
+ml-stable-diffusion/mlx/docs/src/_static/metal_debugger/schema.png filter=lfs diff=lfs merge=lfs -text
ml-stable-diffusion/mlx/.circleci/config.yml
ADDED
@@ -0,0 +1,579 @@
version: 2.1

orbs:
  apple: ml-explore/pr-approval@0.1.0

parameters:
  nightly_build:
    type: boolean
    default: false
  test_release:
    type: boolean
    default: false

jobs:
  build_documentation:
    parameters:
      upload-docs:
        type: boolean
        default: false
    macos:
      xcode: "26.0.0"
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install
          command: |
            xcodebuild -downloadComponent MetalToolchain
            brew install python@3.9
            brew install doxygen
            python3.9 -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install -r docs/requirements.txt
            pip install . -v
      - when:
          condition:
            not: << parameters.upload-docs >>
          steps:
            - run:
                name: Build documentation
                command: |
                  source env/bin/activate
                  cd docs && doxygen && make html O=-W
      - when:
          condition: << parameters.upload-docs >>
          steps:
            - add_ssh_keys:
                fingerprints:
                  - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
            - run:
                name: Upload documentation
                command: |
                  source env/bin/activate
                  git config user.email "mlx@group.apple.com"
                  git config user.name "CircleCI Docs"
                  git checkout gh-pages
                  git rebase main
                  cd docs
                  git rm -rf build/html
                  doxygen && make html O=-W
                  git add -f build/html
                  git commit -m "rebase"
                  git push -f origin gh-pages

  linux_build_and_test:
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Run style checks
          command: |
            pip install pre-commit
            pre-commit run --all
            if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
      - run:
          name: Install dependencies
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            python -m unittest discover python/tests -v
            mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build
            cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
            make -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests

  mac_build_and_test:
    parameters:
      xcode_version:
        type: string
        default: "26.0.0"
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    resource_class: m4pro.medium
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
              brew install openmpi uv
      - run:
          name: Install Python package
          command: |
            uv venv --python 3.9
            uv pip install \
              nanobind==2.4.0 \
              cmake \
              numpy \
              torch \
              tensorflow \
              unittest-xml-reporting
            DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
              uv pip install -e . -v
      - run:
          name: Generate package stubs
          command: |
            uv pip install typing_extensions
            uv run --no-project setup.py generate_stubs
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
            mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
            mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
            if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
      - run:
          name: Build example extension
          command: |
            source .venv/bin/activate
            cd examples/extensions
            uv pip install -r requirements.txt
            uv run --no-project setup.py build_ext --inplace
            uv run --no-project python test.py
      - store_test_results:
          path: test-results
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
      - run:
          name: Run CPP tests
          command: |
            DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
      - run:
          name: Build small binary
          command: |
            source .venv/bin/activate
            cd build/
            cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
              -DBUILD_SHARED_LIBS=ON \
              -DMLX_BUILD_CPU=OFF \
              -DMLX_BUILD_SAFETENSORS=OFF \
              -DMLX_BUILD_GGUF=OFF \
              -DMLX_METAL_JIT=ON
            make -j `sysctl -n hw.ncpu`
      - run:
          name: Run Python tests with JIT
          command: |
            CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
              uv pip install -e . -v
            LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
              METAL_DEBUG_ERROR_MODE=0 \
              uv run --no-project python -m xmlrunner discover \
              -v python/tests \
              -o test-results/gpu_jit

  cuda_build_and_test:
    parameters:
      image_date:
        type: string
        default: "2023.11.1"
    machine:
      image: "linux-cuda-12:<< parameters.image_date >>"
    resource_class: gpu.nvidia.small.gen2
    steps:
      - checkout
      - restore_cache:
          keys:
            - cuda-<< parameters.image_date >>-{{ arch }}-
      - run:
          name: Install dependencies
          command: |
            sudo apt-get update
            sudo apt-get install libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install libnccl2 libnccl-dev
            curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
            sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
            rm -rf ccache-4.11.3-linux-x86_64
            curl -LsSf https://astral.sh/uv/install.sh | sh
      - run:
          name: Set CCache size
          command: ccache --max-size 1G
      - run:
          name: Install Python package
          command: |
            uv venv
            uv pip install cmake
            DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              uv pip install -e ".[dev]" -v
      - run:
          name: Run Python tests
          command: |
            source .venv/bin/activate
            LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
            LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
      - run:
          name: Build CPP only
          command: |
            source .venv/bin/activate
            cmake . -B build \
              -DMLX_BUILD_CUDA=ON \
              -DCMAKE_CUDA_COMPILER=`which nvcc` \
              -DCMAKE_BUILD_TYPE=DEBUG
            cmake --build build -j `nproc`
      - run:
          name: Run CPP tests
          command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
      - run:
          name: CCache report
          command: |
            ccache --show-stats
            ccache --zero-stats
            ccache --cleanup
      - save_cache:
          key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
          paths:
            - /home/circleci/.cache/ccache

  build_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      xcode_version:
        type: string
        default: "26.0.0"
      build_env:
        type: string
        default: ""
      macosx_deployment_target:
        type: string
        default: ""
    macos:
      xcode: << parameters.xcode_version >>
    resource_class: m4pro.medium
    environment:
      MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
    steps:
      - checkout
      - run:
          name: Install dependencies
          command: |
            xcodebuild -downloadComponent MetalToolchain
            mkdir -p ~/miniconda3
            curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
            bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
            rm ~/miniconda3/miniconda.sh
            source ~/miniconda3/bin/activate
            conda init --all
            conda create -n env python=<< parameters.python_version >> -y
            conda activate env
            pip install --upgrade cmake
            pip install nanobind==2.4.0
            pip install --upgrade setuptools
            pip install numpy
            pip install twine
            pip install build
      - run:
          name: Install Python package
          command: |
            conda activate env
            env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
              pip install . -v
      - run:
          name: Generate package stubs
          command: |
            conda activate env
            pip install typing_extensions
            python setup.py generate_stubs
      - run:
          name: Build Python package
          command: |
            conda activate env
            python setup.py clean --all
            << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
      - when:
          condition:
            equal: ["3.9", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  conda activate env
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  conda activate env
                  twine upload dist/*
      - store_artifacts:
          path: dist/

  build_linux_release:
    parameters:
      python_version:
        type: string
        default: "3.9"
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: large
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            PYTHON=python<< parameters.python_version >>
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            sudo apt-get update
            TZ=Etc/UTC sudo apt-get -y install tzdata
            sudo add-apt-repository -y ppa:deadsnakes/ppa
            sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
            sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
            $PYTHON -m venv env
            source env/bin/activate
            pip install --upgrade pip
            pip install --upgrade cmake
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            << parameters.build_env >> pip install ".[dev]" -v
            pip install typing_extensions
            python setup.py generate_stubs
            python setup.py clean --all
            MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
            bash python/scripts/repair_linux.sh
      - when:
          condition:
            equal: ["3.9", << parameters.python_version >>]
          steps:
            - run:
                name: Build common package
                command: |
                  source env/bin/activate
                  python setup.py clean --all
                  << parameters.build_env >> MLX_BUILD_STAGE=2 \
                    python -m build -w
                  auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload packages
                command: |
                  source env/bin/activate
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

  build_cuda_release:
    parameters:
      build_env:
        type: string
        default: ""
    machine:
      image: ubuntu-2204:current
    resource_class: xlarge
    steps:
      - checkout
      - run:
          name: Build wheel
          command: |
            export DEBIAN_FRONTEND=noninteractive
            export NEEDRESTART_MODE=a
            wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
            sudo dpkg -i cuda-keyring_1.1-1_all.deb
            sudo apt-get update
            sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
            sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
            sudo apt-get install zip
            pip install auditwheel
            pip install patchelf
            pip install build
            pip install twine
            export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
            export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
            << parameters.build_env >> MLX_BUILD_STAGE=2 \
              CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
              python -m build -w
            bash python/scripts/repair_cuda.sh
      - when:
          condition: << parameters.build_env >>
          steps:
            - run:
                name: Upload package
                command: |
                  twine upload wheelhouse/*.whl
      - store_artifacts:
          path: wheelhouse/

workflows:
  build_and_test:
    when:
      and:
        - matches:
            pattern: "^(?!pull/)[-\\w]+$"
            value: << pipeline.git.branch >>
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - mac_build_and_test:
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test
      - cuda_build_and_test:
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]
      - build_documentation

  build_pypi_release:
    when:
      and:
        - not: << pipeline.parameters.nightly_build >>
        - not: << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["PYPI_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_documentation:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          upload-docs: true
      - build_linux_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              build_env: ["PYPI_RELEASE=1"]
      - build_cuda_release:
          filters:
            tags:
              only: /^v.*/
            branches:
              ignore: /.*/
          matrix:
            parameters:
              build_env: ["PYPI_RELEASE=1"]

  prb:
    when:
      matches:
        pattern: "^pull/\\d+(/head)?$"
        value: << pipeline.git.branch >>
    jobs:
      - hold:
          type: approval
      - apple/authenticate:
          context: pr-approval
      - mac_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              macosx_deployment_target: ["13.5", "15.0"]
      - linux_build_and_test:
          requires: [ hold ]
      - cuda_build_and_test:
          requires: [ hold ]
          matrix:
            parameters:
              image_date: ["2023.11.1", "2025.05.1"]
  nightly_build:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.nightly_build >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
      - build_cuda_release

  build_dev_release:
    when:
      and:
        - equal: [ main, << pipeline.git.branch >> ]
        - << pipeline.parameters.test_release >>
    jobs:
      - build_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              macosx_deployment_target: ["13.5", "14.0", "15.0"]
              build_env: ["DEV_RELEASE=1"]
              xcode_version: ["26.0.0"]
      - build_linux_release:
          matrix:
            parameters:
              python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
              build_env: ["DEV_RELEASE=1"]
      - build_cuda_release:
          matrix:
            parameters:
              build_env: ["DEV_RELEASE=1"]
ml-stable-diffusion/mlx/.clang-format
ADDED
@@ -0,0 +1,87 @@
---
AccessModifierOffset: -1
AlignAfterOpenBracket: AlwaysBreak
AlignConsecutiveAssignments: false
AlignConsecutiveDeclarations: false
AlignEscapedNewlinesLeft: true
AlignOperands: false
AlignTrailingComments: false
AllowAllParametersOfDeclarationOnNextLine: false
AllowShortBlocksOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: false
AllowShortFunctionsOnASingleLine: Empty
AllowShortIfStatementsOnASingleLine: false
AllowShortLoopsOnASingleLine: false
AlwaysBreakAfterReturnType: None
AlwaysBreakBeforeMultilineStrings: true
AlwaysBreakTemplateDeclarations: true
BinPackArguments: false
BinPackParameters: false
BraceWrapping:
  AfterClass: false
  AfterControlStatement: false
  AfterEnum: false
  AfterFunction: false
  AfterNamespace: false
  AfterObjCDeclaration: false
  AfterStruct: false
  AfterUnion: false
  BeforeCatch: false
  BeforeElse: false
  IndentBraces: false
BreakBeforeBinaryOperators: None
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
BreakConstructorInitializersBeforeComma: false
BreakAfterJavaFieldAnnotations: false
BreakStringLiterals: false
ColumnLimit: 80
CommentPragmas: '^ IWYU pragma:'
ConstructorInitializerAllOnOneLineOrOnePerLine: true
ConstructorInitializerIndentWidth: 4
ContinuationIndentWidth: 4
Cpp11BracedListStyle: true
DerivePointerAlignment: false
DisableFormat: false
ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
IncludeCategories:
  - Regex: '^<.*\.h(pp)?>'
    Priority: 1
  - Regex: '^<.*'
    Priority: 2
  - Regex: '.*'
    Priority: 3
IndentCaseLabels: true
IndentWidth: 2
IndentWrappedFunctionNames: false
KeepEmptyLinesAtTheStartOfBlocks: false
MacroBlockBegin: ''
MacroBlockEnd: ''
MaxEmptyLinesToKeep: 1
NamespaceIndentation: None
ObjCBlockIndentWidth: 2
ObjCSpaceAfterProperty: false
ObjCSpaceBeforeProtocolList: false
PenaltyBreakBeforeFirstCallParameter: 1
PenaltyBreakComment: 300
PenaltyBreakFirstLessLess: 120
PenaltyBreakString: 1000
PenaltyExcessCharacter: 1000000
PenaltyReturnTypeOnItsOwnLine: 200
PointerAlignment: Left
ReflowComments: true
SortIncludes: true
SpaceAfterCStyleCast: false
SpaceBeforeAssignmentOperators: true
SpaceBeforeParens: ControlStatements
SpaceInEmptyParentheses: false
SpacesBeforeTrailingComments: 1
SpacesInAngles: false
SpacesInContainerLiterals: true
SpacesInCStyleCastParentheses: false
SpacesInParentheses: false
SpacesInSquareBrackets: false
Standard: Cpp11
TabWidth: 8
UseTab: Never
...
ml-stable-diffusion/mlx/.github/ISSUE_TEMPLATE/bug_report.md
ADDED
@@ -0,0 +1,28 @@
---
name: Bug report
about: Create a report about an issue you've encountered
title: "[BUG] "
labels: ''
assignees: ''

---

**Describe the bug**
A clear and concise description of what the bug is.

**To Reproduce**

Include code snippet
```python

```

**Expected behavior**
A clear and concise description of what you expected to happen.

**Desktop (please complete the following information):**
 - OS Version: [e.g. MacOS 14.1.2]
 - Version [e.g. 0.7.0]

**Additional context**
Add any other context about the problem here.
ml-stable-diffusion/mlx/.github/pull_request_template.md
ADDED
@@ -0,0 +1,12 @@
## Proposed changes

Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.

## Checklist

Put an `x` in the boxes that apply.

- [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
- [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
- [ ] I have added tests that prove my fix is effective or that my feature works
- [ ] I have updated the necessary documentation (if needed)
ml-stable-diffusion/mlx/.github/workflows/pull_request.yml
ADDED
@@ -0,0 +1,20 @@
on:
  pull_request:
    branches:
      - main

jobs:
  check_lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - uses: actions/setup-python@v4
        with:
          python-version: 3.8
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pre-commit black isort clang-format
      - name: Run lint
        run: |
          pre-commit run --all-files
ml-stable-diffusion/mlx/.gitignore
ADDED
@@ -0,0 +1,88 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# tensor files
*.safe
*.safetensors

# Metal libraries
*.metallib
venv/

# Distribution / packaging
python/mlx/core
python/mlx/share
python/mlx/include
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
uv.lock

# vim
*.swp

# Ignore build dir
build/

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

# Debug symbols
*.pdb

# VSCode
.vscode/
.DS_Store

# Jetbrains
.cache
ml-stable-diffusion/mlx/.pre-commit-config.yaml
ADDED
@@ -0,0 +1,21 @@
repos:
  - repo: https://github.com/pre-commit/mirrors-clang-format
    rev: v19.1.7
    hooks:
      - id: clang-format
  # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
  - repo: https://github.com/psf/black-pre-commit-mirror
    rev: 25.1.0
    hooks:
      - id: black

  - repo: https://github.com/pycqa/isort
    rev: 6.0.0
    hooks:
      - id: isort
        args:
          - --profile=black
  - repo: https://github.com/cheshirekow/cmake-format-precommit
    rev: v0.6.13
    hooks:
      - id: cmake-format
ml-stable-diffusion/mlx/ACKNOWLEDGMENTS.md
ADDED
@@ -0,0 +1,268 @@
# Individual Contributors

If you wish to be acknowledged for your contributions, please list your name
with a short description of your contribution(s) below. For example:

- Jane Smith: Added the `foo` and `bar` ops.

MLX was developed with contributions from the following individuals:

- Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
- Juarez Bochi: Fixed bug in cross attention.
- Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
- Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
- Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
- Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
- Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
- Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
- AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
- Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
- Paul Paczuski: Improved stability of BCE loss calculation
- Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
- Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.

<a href="https://github.com/ml-explore/mlx/graphs/contributors">
  <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
</a>

# Organizations

MLX has received contributions from the following companies:
- NVIDIA Corporation & Affiliates

# Third-Party Software

MLX leverages several third-party software, listed here together with
their license copied verbatim.

## PocketFFT

Copyright (C) 2010-2018 Max-Planck-Society
All rights reserved.

Redistribution and use in source and binary forms, with or without modification,
are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice, this
  list of conditions and the following disclaimer.
* Redistributions in binary form must reproduce the above copyright notice, this
  list of conditions and the following disclaimer in the documentation and/or
  other materials provided with the distribution.
* Neither the name of the copyright holder nor the names of its contributors may
  be used to endorse or promote products derived from this software without
  specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

## metal-cpp

Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

1. Definitions.

"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.

"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.

"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.

"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.

"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.

"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.

"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).

"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.

"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."

"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.

2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.

3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.

4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:

(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and

(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and

(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and

(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.

You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.

5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.

6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.

7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.

8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
|
| 222 |
+
unless required by applicable law (such as deliberate and grossly
|
| 223 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 224 |
+
liable to You for damages, including any direct, indirect, special,
|
| 225 |
+
incidental, or consequential damages of any character arising as a
|
| 226 |
+
result of this License or out of the use or inability to use the
|
| 227 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 228 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 229 |
+
other commercial damages or losses), even if such Contributor
|
| 230 |
+
has been advised of the possibility of such damages.
|
| 231 |
+
|
| 232 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 233 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 234 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 235 |
+
or other liability obligations and/or rights consistent with this
|
| 236 |
+
License. However, in accepting such obligations, You may act only
|
| 237 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 238 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 239 |
+
defend, and hold each Contributor harmless for any liability
|
| 240 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 241 |
+
of your accepting any such warranty or additional liability.
|
| 242 |
+
|
| 243 |
+
END OF TERMS AND CONDITIONS
|
| 244 |
+
|
| 245 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 246 |
+
|
| 247 |
+
To apply the Apache License to your work, attach the following
|
| 248 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 249 |
+
replaced with your own identifying information. (Don't include
|
| 250 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 251 |
+
comment syntax for the file format. We also recommend that a
|
| 252 |
+
file or class name and description of purpose be included on the
|
| 253 |
+
same "printed page" as the copyright notice for easier
|
| 254 |
+
identification within third-party archives.
|
| 255 |
+
|
| 256 |
+
Copyright © 2023 Apple Inc.
|
| 257 |
+
|
| 258 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 259 |
+
you may not use this file except in compliance with the License.
|
| 260 |
+
You may obtain a copy of the License at
|
| 261 |
+
|
| 262 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 263 |
+
|
| 264 |
+
Unless required by applicable law or agreed to in writing, software
|
| 265 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 266 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 267 |
+
See the License for the specific language governing permissions and
|
| 268 |
+
limitations under the License.
|
ml-stable-diffusion/mlx/CITATION.cff
ADDED
@@ -0,0 +1,24 @@
cff-version: 1.2.0
title: mlx
message: >-
  If you use this software, please cite it using the
  metadata from this file.
type: software
authors:
  - given-names: Awni
    family-names: Hannun
    affiliation: Apple
  - given-names: Jagrit
    family-names: Digani
    affiliation: Apple
  - given-names: Angelos
    family-names: Katharopoulos
    affiliation: Apple
  - given-names: Ronan
    family-names: Collobert
    affiliation: Apple
repository-code: 'https://github.com/ml-explore'
abstract: >-
  MLX: efficient and flexible machine learning on Apple
  silicon
license: MIT
ml-stable-diffusion/mlx/CMakeLists.txt
ADDED
@@ -0,0 +1,353 @@
cmake_minimum_required(VERSION 3.25)

if(NOT MLX_VERSION)
  file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
  string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_major ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
  set(_minor ${CMAKE_MATCH_1})
  string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
  set(_patch ${CMAKE_MATCH_1})
  set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
  set(MLX_VERSION ${MLX_PROJECT_VERSION})
else()
  string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
                       ${MLX_VERSION})
endif()

project(
  mlx
  LANGUAGES C CXX
  VERSION ${MLX_PROJECT_VERSION})

# ----------------------------- Setup -----------------------------
set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_INSTALL_MESSAGE NEVER)
set(CMAKE_EXPORT_COMPILE_COMMANDS ON)

# ----------------------------- Configuration -----------------------------
option(MLX_BUILD_TESTS "Build tests for mlx" ON)
option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
option(MLX_BUILD_METAL "Build metal backend" ON)
option(MLX_BUILD_CPU "Build cpu backend" ON)
option(MLX_BUILD_CUDA "Build cuda backend" OFF)
option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)

# --------------------- Processor tests -------------------------
message(
  STATUS
    "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
)

if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
  if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
    if(NOT MLX_ENABLE_X64_MAC)
      message(
        FATAL_ERROR
          "Building for x86_64 on macOS is not supported."
          " If you are on an Apple silicon system, check the build"
          " documentation for possible fixes: "
          "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
      )
    else()
      set(MLX_BUILD_METAL OFF)
      message(WARNING "Building for x86_64 arch is not officially supported.")
    endif()
  endif()
else()
  set(MLX_BUILD_METAL OFF)
endif()

if(MLX_USE_CCACHE)
  find_program(CCACHE_PROGRAM ccache)
  if(CCACHE_PROGRAM)
    set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
    set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
    set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
  endif()
endif()

# ----------------------------- Lib -----------------------------

include(FetchContent)
# Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
cmake_policy(SET CMP0135 NEW)

add_library(mlx)

if(MLX_BUILD_CUDA)
  enable_language(CUDA)
endif()

if(MLX_BUILD_METAL)
  find_library(METAL_LIB Metal)
  find_library(FOUNDATION_LIB Foundation)
  find_library(QUARTZ_LIB QuartzCore)
  if(METAL_LIB)
    message(STATUS "Metal found ${METAL_LIB}")
  else()
    message(
      FATAL_ERROR
        "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
  endif()

  if(MLX_METAL_DEBUG)
    add_compile_definitions(MLX_METAL_DEBUG)
  endif()

  # Throw an error if xcrun not found
  execute_process(
    COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
    OUTPUT_VARIABLE MACOS_SDK_VERSION
    OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)

  if(${MACOS_SDK_VERSION} LESS 14.0)
    message(
      FATAL_ERROR
        "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
  endif()
  message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")

  set(METAL_CPP_URL
      https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)

  if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
    set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
  endif()
  execute_process(
    COMMAND
      zsh "-c"
      "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
    OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
  FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})

  FetchContent_MakeAvailable(metal_cpp)
  target_include_directories(
    mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
               $<INSTALL_INTERFACE:include/metal_cpp>)
  target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
  # With newer clang/gcc versions following libs are implicitly linked, but when
  # building on old distributions they need to be explicitly listed.
  target_link_libraries(mlx PRIVATE dl pthread)
endif()

if(WIN32)
  if(MSVC)
    # GGUF does not build with MSVC.
    set(MLX_BUILD_GGUF OFF)
    # There is no prebuilt OpenBLAS distribution for MSVC.
    set(MLX_BUILD_BLAS_FROM_SOURCE ON)
  endif()
  # Windows implementation of dlfcn.h APIs.
  FetchContent_Declare(
    dlfcn-win32
    GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
    GIT_TAG v1.4.1
    EXCLUDE_FROM_ALL)
  block()
    set(BUILD_SHARED_LIBS OFF)
    FetchContent_MakeAvailable(dlfcn-win32)
  endblock()
  target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
  target_link_libraries(mlx PRIVATE dl)
endif()

if(MLX_BUILD_CPU)
  find_library(ACCELERATE_LIBRARY Accelerate)
  if(ACCELERATE_LIBRARY)
    message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
    set(MLX_BUILD_ACCELERATE ON)
  else()
    message(STATUS "Accelerate not found, using default backend.")
    set(MLX_BUILD_ACCELERATE OFF)
  endif()

  if(MLX_BUILD_ACCELERATE)
    target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
    add_compile_definitions(MLX_USE_ACCELERATE)
    add_compile_definitions(ACCELERATE_NEW_LAPACK)
  elseif(MLX_BUILD_BLAS_FROM_SOURCE)
    # Download and build OpenBLAS from source code.
    FetchContent_Declare(
      openblas
      GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
      GIT_TAG v0.3.28
      EXCLUDE_FROM_ALL)
    set(BUILD_STATIC_LIBS ON) # link statically
    set(NOFORTRAN ON) # msvc has no fortran compiler
    FetchContent_MakeAvailable(openblas)
    target_link_libraries(mlx PRIVATE openblas)
    target_include_directories(
      mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
                  "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
  else()
    if(${CMAKE_HOST_APPLE})
      # The blas shipped in macOS SDK is not supported, search homebrew for
      # openblas instead.
      set(BLA_VENDOR OpenBLAS)
      set(LAPACK_ROOT
          "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
    endif()
    # Search and link with lapack.
    find_package(LAPACK REQUIRED)
    if(NOT LAPACK_FOUND)
      message(FATAL_ERROR "Must have LAPACK installed")
    endif()
    find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
              /usr/local/opt/openblas/include)
    message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
    message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
    # List blas after lapack otherwise we may accidentally include an old
    # version of lapack.h from the include dirs of blas.
    find_package(BLAS REQUIRED)
    if(NOT BLAS_FOUND)
      message(FATAL_ERROR "Must have BLAS installed")
    endif()
    # TODO find a cleaner way to do this
    find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
              $ENV{BLAS_HOME}/include)
    message(STATUS "Blas lib " ${BLAS_LIBRARIES})
    message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
    target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
    target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
  endif()
else()
  set(MLX_BUILD_ACCELERATE OFF)
endif()

message(STATUS "Downloading json")
FetchContent_Declare(
  json
  URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
FetchContent_MakeAvailable(json)
target_include_directories(
  mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)

add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)

target_include_directories(
  mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
             $<INSTALL_INTERFACE:include>)

# Do not add mlx_EXPORTS define for shared library.
set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")

if(USE_SYSTEM_FMT)
  find_package(fmt REQUIRED)
else()
  FetchContent_Declare(
    fmt
    GIT_REPOSITORY https://github.com/fmtlib/fmt.git
    GIT_TAG 10.2.1
    EXCLUDE_FROM_ALL)
  FetchContent_MakeAvailable(fmt)
endif()
target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)

if(MLX_BUILD_PYTHON_BINDINGS)
  message(STATUS "Building Python bindings.")
  find_package(
    Python 3.8
    COMPONENTS Interpreter Development.Module
    REQUIRED)
  execute_process(
    COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
    OUTPUT_STRIP_TRAILING_WHITESPACE
    OUTPUT_VARIABLE nanobind_ROOT)
  find_package(nanobind CONFIG REQUIRED)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
endif()

if(MLX_BUILD_TESTS)
  include(CTest)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
endif()

if(MLX_BUILD_EXAMPLES)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
endif()

if(MLX_BUILD_BENCHMARKS)
  add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
endif()

# ----------------------------- Installation -----------------------------
include(GNUInstallDirs)

# Install library
install(
  TARGETS mlx
  EXPORT MLXTargets
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
  INCLUDES
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})

# Install headers
install(
  DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
  DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
  COMPONENT headers
  FILES_MATCHING
  PATTERN "*.h"
  PATTERN "backend/metal/kernels.h" EXCLUDE)

# Install metal dependencies
if(MLX_BUILD_METAL)

  # Install metal cpp
  install(
    DIRECTORY ${metal_cpp_SOURCE_DIR}/
    DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
    COMPONENT metal_cpp_source)

endif()

# Install cmake config
set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)

install(
  EXPORT MLXTargets
  FILE MLXTargets.cmake
  DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

include(CMakePackageConfigHelpers)

write_basic_package_version_file(
  ${MLX_CMAKE_BUILD_VERSION_CONFIG}
  COMPATIBILITY SameMajorVersion
  VERSION ${MLX_VERSION})

configure_package_config_file(
  ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
  INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
  NO_CHECK_REQUIRED_COMPONENTS_MACRO
  PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
            MLX_CMAKE_INSTALL_MODULE_DIR)

install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})

install(DIRECTORY ${CMAKE_MODULE_PATH}/
        DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
ml-stable-diffusion/mlx/CODE_OF_CONDUCT.md
ADDED
@@ -0,0 +1,132 @@
# Contributor Covenant Code of Conduct

## Our Pledge

We as members, contributors, and leaders pledge to make participation in our
community a harassment-free experience for everyone, regardless of age, body
size, visible or invisible disability, ethnicity, sex characteristics, gender
identity and expression, level of experience, education, socio-economic status,
nationality, personal appearance, race, caste, color, religion, or sexual
identity and orientation.

We pledge to act and interact in ways that contribute to an open, welcoming,
diverse, inclusive, and healthy community.

## Our Standards

Examples of behavior that contributes to a positive environment for our
community include:

* Demonstrating empathy and kindness toward other people
* Being respectful of differing opinions, viewpoints, and experiences
* Giving and gracefully accepting constructive feedback
* Accepting responsibility and apologizing to those affected by our mistakes,
  and learning from the experience
* Focusing on what is best not just for us as individuals, but for the overall
  community

Examples of unacceptable behavior include:

* The use of sexualized language or imagery, and sexual attention or advances of
  any kind
* Trolling, insulting or derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or email address,
  without their explicit permission
* Other conduct which could reasonably be considered inappropriate in a
  professional setting

## Enforcement Responsibilities

Community leaders are responsible for clarifying and enforcing our standards of
acceptable behavior and will take appropriate and fair corrective action in
response to any behavior that they deem inappropriate, threatening, offensive,
or harmful.

Community leaders have the right and responsibility to remove, edit, or reject
comments, commits, code, wiki edits, issues, and other contributions that are
not aligned to this Code of Conduct, and will communicate reasons for moderation
decisions when appropriate.

## Scope

This Code of Conduct applies within all community spaces, and also applies when
an individual is officially representing the community in public spaces.
Examples of representing our community include using an official e-mail address,
posting via an official social media account, or acting as an appointed
representative at an online or offline event.

## Enforcement

Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported to the community leaders responsible for enforcement at
[opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
All complaints will be reviewed and investigated promptly and fairly.

All community leaders are obligated to respect the privacy and security of the
reporter of any incident.

## Enforcement Guidelines

Community leaders will follow these Community Impact Guidelines in determining
the consequences for any action they deem in violation of this Code of Conduct:

### 1. Correction

**Community Impact**: Use of inappropriate language or other behavior deemed
unprofessional or unwelcome in the community.

**Consequence**: A private, written warning from community leaders, providing
clarity around the nature of the violation and an explanation of why the
behavior was inappropriate. A public apology may be requested.

### 2. Warning

**Community Impact**: A violation through a single incident or series of
actions.

**Consequence**: A warning with consequences for continued behavior. No
interaction with the people involved, including unsolicited interaction with
those enforcing the Code of Conduct, for a specified period of time. This
includes avoiding interactions in community spaces as well as external channels
like social media. Violating these terms may lead to a temporary or permanent
ban.

### 3. Temporary Ban

**Community Impact**: A serious violation of community standards, including
sustained inappropriate behavior.

**Consequence**: A temporary ban from any sort of interaction or public
communication with the community for a specified period of time. No public or
private interaction with the people involved, including unsolicited interaction
with those enforcing the Code of Conduct, is allowed during this period.
Violating these terms may lead to a permanent ban.

### 4. Permanent Ban

**Community Impact**: Demonstrating a pattern of violation of community
standards, including sustained inappropriate behavior, harassment of an
individual, or aggression toward or disparagement of classes of individuals.

**Consequence**: A permanent ban from any sort of public interaction within the
community.

## Attribution

This Code of Conduct is adapted from the [Contributor Covenant][homepage],
version 2.1, available at
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].

Community Impact Guidelines were inspired by
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].

For answers to common questions about this code of conduct, see the FAQ at
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
[https://www.contributor-covenant.org/translations][translations].

[homepage]: https://www.contributor-covenant.org
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
[Mozilla CoC]: https://github.com/mozilla/diversity
[FAQ]: https://www.contributor-covenant.org/faq
[translations]: https://www.contributor-covenant.org/translations
ml-stable-diffusion/mlx/CONTRIBUTING.md
ADDED
@@ -0,0 +1,38 @@
# Contributing to MLX

We want to make contributing to this project as easy and transparent as
possible.

## Pull Requests

1. Fork and submit pull requests to the repo.
2. If you've added code that should be tested, add tests.
3. If a change is likely to impact efficiency, run some of the benchmarks before
   and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
4. If you've changed APIs, update the documentation.
5. Every PR should have passing tests and at least one review.
6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
   This should install hooks for running `black` and `clang-format` to ensure
   consistent style for C++ and python code.

You can also run the formatters manually as follows:

```shell
clang-format -i file.cpp
```

```shell
black file.py
```

or run `pre-commit run --all-files` to check all files in the repo.

## Issues

We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

## License

By contributing to MLX, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
ml-stable-diffusion/mlx/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright © 2023 Apple Inc.

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
ml-stable-diffusion/mlx/MANIFEST.in
ADDED
@@ -0,0 +1,6 @@
include CMakeLists.txt
include mlx.pc.in
recursive-include mlx/ *
include cmake/*
include python/src/*
include python/mlx/py.typed # support type hinting as in PEP-561
ml-stable-diffusion/mlx/README.md
ADDED
@@ -0,0 +1,121 @@
# MLX

[**Quickstart**](#quickstart) | [**Installation**](#installation) |
[**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
[**Examples**](#examples)

[](https://circleci.com/gh/ml-explore/mlx)

MLX is an array framework for machine learning on Apple silicon,
brought to you by Apple machine learning research.

Some key features of MLX include:

- **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
  also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
  [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
  the Python API. MLX has higher-level packages like `mlx.nn` and
  `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
  more complex models.

- **Composable function transformations**: MLX supports composable function
  transformations for automatic differentiation, automatic vectorization,
  and computation graph optimization.

- **Lazy computation**: Computations in MLX are lazy. Arrays are only
  materialized when needed.

- **Dynamic graph construction**: Computation graphs in MLX are constructed
  dynamically. Changing the shapes of function arguments does not trigger
  slow compilations, and debugging is simple and intuitive.

- **Multi-device**: Operations can run on any of the supported devices
  (currently the CPU and the GPU).

- **Unified memory**: A notable difference between MLX and other frameworks
  is the *unified memory model*. Arrays in MLX live in shared memory.
  Operations on MLX arrays can be performed on any of the supported
  device types without transferring data.

MLX is designed by machine learning researchers for machine learning
researchers. The framework is intended to be user-friendly, but still efficient
to train and deploy models. The design of the framework itself is also
conceptually simple. We intend to make it easy for researchers to extend and
improve MLX with the goal of quickly exploring new ideas.

The design of MLX is inspired by frameworks like
[NumPy](https://numpy.org/doc/stable/index.html),
[PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
[ArrayFire](https://arrayfire.org/).

## Examples

The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
variety of examples, including:

- [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
- Large-scale text generation with
  [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
  finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
- Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
- Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).

## Quickstart

See the [quick start
guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
in the documentation.

## Installation

MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
macOS, run:

```bash
pip install mlx
```

To install the CUDA backend on Linux, run:

```bash
pip install mlx[cuda]
```

To install a CPU-only Linux package, run:

```bash
pip install mlx[cpu]
```

Check out the
[documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
for more information on building the C++ and Python APIs from source.

## Contributing

Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
on contributing to MLX. See the
[docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
information on building from source and running tests.

We are grateful for all of [our
contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
to MLX and wish to be acknowledged, please add your name to the list in your
pull request.

## Citing MLX

The MLX software suite was initially developed with equal contribution by Awni
Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
MLX useful in your research and wish to cite it, please use the following
BibTeX entry:

```
@software{mlx2023,
  author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
  title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
  url = {https://github.com/ml-explore},
  version = {0.0},
  year = {2023},
}
```
ml-stable-diffusion/mlx/benchmarks/cpp/CMakeLists.txt
ADDED
@@ -0,0 +1,11 @@
function(build_benchmark SRCFILE)
  get_filename_component(src_name ${SRCFILE} NAME_WE)
  set(target "${src_name}")
  add_executable(${target} ${SRCFILE})
  target_link_libraries(${target} PRIVATE mlx)
endfunction(build_benchmark)

build_benchmark(single_ops.cpp)
build_benchmark(irregular_strides.cpp)
build_benchmark(compare_devices.cpp)
build_benchmark(autograd.cpp)
ml-stable-diffusion/mlx/benchmarks/cpp/autograd.cpp
ADDED
@@ -0,0 +1,39 @@
// Copyright © 2023 Apple Inc.

#include <iostream>

#include "mlx/mlx.h"
#include "time_utils.h"

namespace mx = mlx::core;

void time_value_and_grad() {
  auto x = mx::ones({200, 1000});
  mx::eval(x);
  auto fn = [](mx::array x) {
    for (int i = 0; i < 20; ++i) {
      x = mx::log(mx::exp(x));
    }
    return mx::sum(x);
  };

  auto grad_fn = mx::grad(fn);
  auto independent_value_and_grad = [&]() {
    auto value = fn(x);
    auto dfdx = grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(independent_value_and_grad);

  auto value_and_grad_fn = mx::value_and_grad(fn);
  auto combined_value_and_grad = [&]() {
    auto [value, dfdx] = value_and_grad_fn(x);
    return std::vector<mx::array>{value, dfdx};
  };
  TIME(combined_value_and_grad);
}

int main() {
  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_value_and_grad();
}
ml-stable-diffusion/mlx/benchmarks/cpp/compare_devices.cpp
ADDED
@@ -0,0 +1,27 @@
// Copyright © 2023 Apple Inc.

#include <iostream>
#include "mlx/mlx.h"
#include "time_utils.h"

namespace mx = mlx::core;

void time_add_op() {
  std::vector<int> sizes(1, 1);
  for (int i = 0; i < 9; ++i) {
    sizes.push_back(10 * sizes.back());
  }
  set_default_device(mx::Device::cpu);
  for (auto size : sizes) {
    auto a = mx::random::uniform({size});
    auto b = mx::random::uniform({size});
    mx::eval(a, b);
    std::cout << "Size " << size << std::endl;
    TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
    TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
  }
}

int main() {
  time_add_op();
}
ml-stable-diffusion/mlx/benchmarks/cpp/irregular_strides.cpp
ADDED
@@ -0,0 +1,201 @@
// Copyright © 2023 Apple Inc.

#include <cstring>
#include <iostream>
#include <sstream>

#include "mlx/mlx.h"
#include "time_utils.h"

namespace mx = mlx::core;

void time_irregular_binary_ops_1D() {
  auto device = mx::default_device();
  int size = 1000000;
  int step = 2;
  auto a = mx::random::uniform({size});
  auto b = mx::random::uniform({size});
  mx::eval(a, b);
  a = slice(a, {0}, {size}, {step});
  b = slice(b, {0}, {size}, {step});
  TIMEM("1D strided", mx::add, a, b, device);
}

void time_irregular_binary_ops_2D() {
  auto device = mx::default_device();
  int size = 2048;
  auto a = mx::random::uniform({size, size});
  auto b = mx::random::uniform({size, size});
  mx::eval(a, b);
  TIMEM("2D regular", mx::add, a, b, device);

  b = mx::transpose(b);
  mx::eval(b);
  TIMEM("2D mx::transpose", mx::add, a, b, device);

  b = mx::random::uniform({size});
  mx::eval(b);
  TIMEM("2D broadcast dim 0", mx::add, a, b, device);

  b = mx::reshape(b, {size, 1});
  mx::eval(b);
  TIMEM("2D broadcast dim 1", mx::add, a, b, device);
}

void time_irregular_binary_ops_3D() {
  auto device = mx::default_device();
  int d0 = 32;
  int d1 = 512;
  int d2 = 512;
  auto a = mx::random::uniform({d0, d1, d2});
  auto b = mx::random::uniform({d0, d1, d2});
  TIMEM("3D regular", mx::add, a, b, device);

  b = mx::transpose(b, {0, 2, 1});
  TIMEM("3D mx::transpose", mx::add, a, b, device);

  b = mx::random::uniform({d1, d2});
  TIMEM("3D broadcast dim 0", mx::add, a, b, device);

  b = mx::random::uniform({d0, 1, d2});
  TIMEM("3D broadcast dim 1", mx::add, a, b, device);

  b = mx::random::uniform({d0, d1, 1});
  TIMEM("3D broadcast dim 2", mx::add, a, b, device);

  b = mx::random::uniform({d2});
  TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);

  b = mx::random::uniform({d1, 1});
  TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);

  b = mx::random::uniform({d0, 1, 1});
  TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
}

void time_irregular_binary_ops_4D() {
  auto device = mx::default_device();
  std::vector<int> shape = {8, 8, 512, 512};
  auto a = mx::random::uniform(shape);
  auto b = mx::random::uniform(shape);

  TIMEM("4D regular", mx::add, a, b, device);

  b = mx::transpose(b, {0, 1, 3, 2});
  TIMEM("4D mx::transpose", mx::add, a, b, device);

  std::string om = "4D broadcast dims ";
  for (int i = 0; i < shape.size(); ++i) {
    shape[i] = 1;
    b = mx::random::uniform(shape);
    std::ostringstream msg;
    msg << om << i;
    TIMEM(msg.str(), mx::add, a, b, device);

    for (int j = i + 1; j < shape.size(); ++j) {
      shape[j] = 1;
      std::ostringstream msg;
      msg << om << i << ", " << j;
      b = mx::random::uniform(shape);
      TIMEM(msg.str(), mx::add, a, b, device);
      shape[j] = a.shape(j);

      for (int k = j + 1; k < shape.size(); ++k) {
        shape[k] = 1;
        std::ostringstream msg;
        msg << om << i << ", " << j << ", " << k;
        b = mx::random::uniform(shape);
        TIMEM(msg.str(), mx::add, a, b, device);
        shape[k] = a.shape(k);
      }
    }
    shape[i] = a.shape(i);
  }
}

void time_irregular_reshape() {
  auto device = mx::default_device();
  std::vector<int> shape;
  auto reshape_fn = [&shape, device](const mx::array& a) {
    return mx::reshape(a, shape, device);
  };

  int size = 64;
  int d = 2 * size;

  auto a = mx::random::uniform({d, d, d});

  shape = {8 * size, size, size};
  TIMEM("3D contiguous", reshape_fn, a);

  a = mx::transpose(a);
  shape = {8 * size, size, size};
  TIMEM("3D mx::transpose", reshape_fn, a);

  a = mx::transpose(a, {1, 2, 0});
  shape = {8 * size, size, size};
  TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
  TIMEM("3D broadcast dim 0", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
  TIMEM("3D broadcast dim 1", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
  TIMEM("3D broadcast dim 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
  TIMEM("3D broadcast dims 0, 1", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
  TIMEM("3D broadcast dims 0, 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2", reshape_fn, a);

  a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
  TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
}

void time_irregular_astype_1D() {
  auto device = mx::default_device();
  int size = 1000000;
  int step = 2;
  auto a = mx::random::uniform({size});
  a = slice(a, {0}, {size}, {step});
  TIMEM("1D strided", mx::astype, a, mx::int32, device);
}

void time_irregular_astype_2D() {
  auto device = mx::default_device();
  int size = 2048;
  std::vector<int> shape = {size, size};

  auto a = mx::random::uniform(shape);
  TIMEM("2D regular", mx::astype, a, mx::int32, device);

  a = mx::transpose(a);
  TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);

  a = mx::broadcast_to(mx::random::uniform({size}), shape);
  TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);

  a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
  TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
}

int main(int argc, char** argv) {
  if (argc > 1) {
    bool use_gpu = !strcmp(argv[1], "gpu");
    set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
  }
  std::cout << "Benchmarks for " << mx::default_device() << std::endl;
  time_irregular_binary_ops_1D();
  time_irregular_binary_ops_2D();
  time_irregular_binary_ops_3D();
  time_irregular_binary_ops_4D();
  time_irregular_reshape();
  time_irregular_astype_1D();
  time_irregular_astype_2D();
}
ml-stable-diffusion/mlx/benchmarks/cpp/single_ops.cpp
ADDED
|
@@ -0,0 +1,288 @@
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#include "mlx/mlx.h"
|
| 4 |
+
#include "time_utils.h"
|
| 5 |
+
|
| 6 |
+
namespace mx = mlx::core;
|
| 7 |
+
|
| 8 |
+
void time_creation_ops() {
|
| 9 |
+
int M = 2000;
|
| 10 |
+
int N = 500;
|
| 11 |
+
auto shape = {M, N};
|
| 12 |
+
auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
|
| 13 |
+
TIME(full_fp32);
|
| 14 |
+
auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
|
| 15 |
+
TIME(zeros_fp32);
|
| 16 |
+
auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
|
| 17 |
+
TIME(ones_fp32);
|
| 18 |
+
|
| 19 |
+
auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
|
| 20 |
+
TIME(arange_fp32);
|
| 21 |
+
}
|
| 22 |
+
|
| 23 |
+
void time_type_conversions() {
|
| 24 |
+
int M = 2000;
|
| 25 |
+
int N = 500;
|
| 26 |
+
auto shape = {M, N};
|
| 27 |
+
auto device = mx::default_device();
|
| 28 |
+
|
| 29 |
+
auto a = mx::zeros(shape, mx::float32);
|
| 30 |
+
mx::eval(a);
|
| 31 |
+
TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
|
| 32 |
+
TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
|
| 33 |
+
|
| 34 |
+
a = mx::zeros(shape, mx::int32);
|
| 35 |
+
mx::eval(a);
|
| 36 |
+
TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
|
| 37 |
+
|
| 38 |
+
a = mx::zeros(shape, mx::bool_);
|
| 39 |
+
mx::eval(a);
|
| 40 |
+
TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
|
| 41 |
+
TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
|
| 42 |
+
TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
void time_random_generation() {
|
| 46 |
+
int M = 2000;
|
| 47 |
+
int N = 500;
|
| 48 |
+
|
| 49 |
+
auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
|
| 50 |
+
TIME(uniform);
|
| 51 |
+
auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
|
| 52 |
+
TIME(normal);
|
| 53 |
+
}
|
| 54 |
+
|
| 55 |
+
void time_unary_ops() {
|
| 56 |
+
int M = 2000;
|
| 57 |
+
int N = 500;
|
| 58 |
+
auto device = mx::default_device();
|
| 59 |
+
|
| 60 |
+
auto a = mx::random::normal({M, N});
|
| 61 |
+
mx::eval(a);
|
| 62 |
+
TIME(mlx::core::abs, a, device);
|
| 63 |
+
TIME(mx::negative, a, device);
|
| 64 |
+
TIME(mx::sign, a, device);
|
| 65 |
+
TIME(mx::square, a, device);
|
| 66 |
+
TIME(mlx::core::sqrt, a, device);
|
| 67 |
+
TIME(mx::rsqrt, a, device);
|
| 68 |
+
TIME(mlx::core::exp, a, device);
|
| 69 |
+
|
| 70 |
+
a = mx::random::uniform({M, N});
|
| 71 |
+
TIME(mlx::core::log, a, device);
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
void time_binary_ops() {
|
| 75 |
+
int M = 1000, N = 100, K = 10;
|
| 76 |
+
auto condition = mx::random::randint(0, 2, {M, N, K});
|
| 77 |
+
auto a = mx::random::uniform({M, N, K});
|
| 78 |
+
auto b = mx::random::uniform({M, N, K});
|
| 79 |
+
auto device = mx::default_device();
|
| 80 |
+
mx::eval(a, b);
|
| 81 |
+
|
| 82 |
+
TIME(mx::add, a, b, device);
|
| 83 |
+
TIME(mx::subtract, a, b, device);
|
| 84 |
+
TIME(mx::multiply, a, b, device);
|
| 85 |
+
TIME(mx::divide, a, b, device);
|
| 86 |
+
TIME(mx::maximum, a, b, device);
|
| 87 |
+
TIME(mx::minimum, a, b, device);
|
| 88 |
+
TIME(mx::where, condition, a, b, device);
|
| 89 |
+
|
| 90 |
+
condition = mx::array({true});
|
| 91 |
+
b = mx::random::uniform({1});
|
| 92 |
+
mx::eval(b);
|
| 93 |
+
TIMEM("scalar", mx::add, a, b, device);
|
| 94 |
+
TIMEM("vector-scalar", mx::subtract, a, b, device);
|
| 95 |
+
TIMEM("scalar-vector", mx::subtract, b, a, device);
|
| 96 |
+
TIMEM("scalar", mx::multiply, a, b, device);
|
| 97 |
+
TIMEM("vector-scalar", mx::divide, a, b, device);
|
| 98 |
+
TIMEM("scalar-vector", mx::divide, b, a, device);
|
| 99 |
+
TIMEM("scalar-vector", mx::where, condition, a, b, device);
|
| 100 |
+
|
| 101 |
+
condition = mx::broadcast_to(mx::array({true}), {1000, 100});
|
| 102 |
+
a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
| 103 |
+
b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
|
| 104 |
+
mx::eval(a, b);
|
| 105 |
+
TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
|
| 106 |
+
TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
|
| 107 |
+
TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
|
| 108 |
+
TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
|
| 109 |
+
TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
|
| 110 |
+
}
|
| 111 |
+
|
| 112 |
+
void time_strided_ops() {
|
| 113 |
+
int M = 50, N = 50, O = 50, P = 50;
|
| 114 |
+
auto a = mx::random::uniform({M, N, O, P});
|
| 115 |
+
auto b = mx::random::uniform({M, N, O, P});
|
| 116 |
+
auto device = mx::default_device();
|
| 117 |
+
mx::eval(a, b);
|
| 118 |
+
TIMEM("non-strided", mx::add, a, b, device);
|
| 119 |
+
a = mx::transpose(a, {1, 0, 2, 3});
|
| 120 |
+
b = mx::transpose(b, {3, 2, 0, 1});
|
| 121 |
+
mx::eval(a, b);
|
| 122 |
+
TIMEM("strided", mx::add, a, b, device);
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
void time_comparisons() {
|
| 126 |
+
int M = 1000, N = 100, K = 10;
|
| 127 |
+
auto a = mx::random::uniform({M, N, K});
|
| 128 |
+
auto b = mx::random::uniform({M, N, K});
|
| 129 |
+
auto device = mx::default_device();
|
| 130 |
+
mx::eval(a, b);
|
| 131 |
+
TIME(mx::equal, a, b, device);
|
| 132 |
+
TIME(mx::greater, a, b, device);
|
| 133 |
+
TIME(mx::greater_equal, a, b, device);
|
| 134 |
+
TIME(mx::less, a, b, device);
|
| 135 |
+
TIME(mx::less_equal, a, b, device);
|
| 136 |
+
}
|
| 137 |
+
|
| 138 |
+
void time_matvec() {
|
| 139 |
+
int M = 2000, N = 200;
|
| 140 |
+
auto a = mx::random::uniform({M, N});
|
| 141 |
+
auto b = mx::random::uniform({N});
|
| 142 |
+
auto c = mx::random::uniform({M});
|
| 143 |
+
mx::eval(a, b, c);
|
| 144 |
+
auto matvec = [&]() { return mx::matmul(a, b); };
|
| 145 |
+
TIME(matvec);
|
| 146 |
+
|
| 147 |
+
auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
|
| 148 |
+
TIME(matvec_transpose);
|
| 149 |
+
}
|
| 150 |
+
|
| 151 |
+
void time_matmul() {
|
| 152 |
+
int M = 1000, N = 1000, K = 1000;
|
| 153 |
+
auto a = mx::random::uniform({M, K});
|
| 154 |
+
auto b = mx::random::uniform({K, N});
|
| 155 |
+
auto device = mx::default_device();
|
| 156 |
+
mx::eval(a, b);
|
| 157 |
+
TIME(mx::matmul, a, b, device);
|
| 158 |
+
|
| 159 |
+
auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
|
| 160 |
+
TIME(transpose_matmul);
|
| 161 |
+
}
|
| 162 |
+
|
| 163 |
+
void time_reductions() {
|
| 164 |
+
auto a = mx::random::normal({10000, 1000});
|
| 165 |
+
mx::eval(a);
|
| 166 |
+
auto sum_all = [&a]() { return mx::sum(a, false); };
|
| 167 |
+
TIME(sum_all);
|
| 168 |
+
|
| 169 |
+
auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
|
| 170 |
+
TIME(sum_along_0);
|
| 171 |
+
|
| 172 |
+
auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
|
| 173 |
+
TIME(sum_along_1);
|
| 174 |
+
|
| 175 |
+
auto prod_all = [&a]() { return mx::prod(a, false); };
|
| 176 |
+
TIME(prod_all);
|
| 177 |
+
|
| 178 |
+
auto all_true = [&a]() { return mx::all(a, false); };
|
| 179 |
+
TIME(all_true);
|
| 180 |
+
|
| 181 |
+
auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
|
| 182 |
+
TIME(all_along_0);
|
| 183 |
+
|
| 184 |
+
auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
|
| 185 |
+
TIME(all_along_1);
|
| 186 |
+
|
| 187 |
+
auto any_true = [&a]() { return mx::any(a, false); };
|
| 188 |
+
TIME(any_true);
|
| 189 |
+
|
| 190 |
+
auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
|
| 191 |
+
TIME(argmin_along_0);
|
| 192 |
+
|
| 193 |
+
auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
|
| 194 |
+
TIME(argmin_along_1);
|
| 195 |
+
|
| 196 |
+
auto indices = mx::array({1});
|
| 197 |
+
auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
|
| 198 |
+
std::vector<int> axes{0};
|
| 199 |
+
auto b = scatter(a, {indices}, updates, axes);
|
| 200 |
+
mx::eval(b);
|
| 201 |
+
|
| 202 |
+
auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
|
| 203 |
+
TIME(max_along_0);
|
| 204 |
+
auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
|
| 205 |
+
TIME(max_along_1);
|
| 206 |
+
|
| 207 |
+
auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
|
| 208 |
+
TIME(min_along_0);
|
| 209 |
+
auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
|
| 210 |
+
TIME(min_along_1);
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
void time_gather_scatter() {
|
| 214 |
+
auto a = mx::random::normal({1000, 768});
|
| 215 |
+
mx::eval(a);
|
| 216 |
+
auto indices = mx::random::randint(0, 1000, {256});
|
| 217 |
+
mx::eval(indices);
|
| 218 |
+
|
| 219 |
+
auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
|
| 220 |
+
TIME(embedding_lookup);
|
| 221 |
+
|
| 222 |
+
indices = mx::random::randint(0, 768 * 1000, {256 * 768});
|
| 223 |
+
mx::eval(indices);
|
| 224 |
+
|
| 225 |
+
auto single_element_lookup = [&a, &indices]() {
|
| 226 |
+
return mx::take(a, indices);
|
| 227 |
+
};
|
| 228 |
+
TIME(single_element_lookup);
|
| 229 |
+
|
| 230 |
+
indices = mx::random::randint(0, 1000, {256});
|
| 231 |
+
auto updates = mx::random::normal({256, 1, 768});
|
| 232 |
+
mx::eval(indices, updates);
|
| 233 |
+
|
| 234 |
+
auto embedding_update = [&a, &indices, &updates]() {
|
| 235 |
+
return scatter(a, indices, updates, 0);
|
| 236 |
+
};
|
| 237 |
+
TIME(embedding_update);
|
| 238 |
+
|
| 239 |
+
auto embedding_add = [&a, &indices, &updates]() {
|
| 240 |
+
return scatter_add(a, indices, updates, 0);
|
| 241 |
+
};
|
| 242 |
+
TIME(embedding_add);
|
| 243 |
+
|
| 244 |
+
a = mx::reshape(a, {-1});
|
| 245 |
+
indices = mx::random::randint(0, 768 * 1000, {768 * 256});
|
| 246 |
+
updates = mx::random::normal({256 * 768, 1});
|
| 247 |
+
mx::eval(a, indices, updates);
|
| 248 |
+
|
| 249 |
+
auto single_element_update = [&a, &indices, &updates]() {
|
| 250 |
+
return scatter(a, indices, updates, 0);
|
| 251 |
+
};
|
| 252 |
+
TIME(single_element_update);
|
| 253 |
+
|
| 254 |
+
auto single_element_add = [&a, &indices, &updates]() {
|
| 255 |
+
return scatter_add(a, indices, updates, 0);
|
| 256 |
+
};
|
| 257 |
+
TIME(single_element_add);
|
| 258 |
+
}
|
| 259 |
+
|
| 260 |
+
void time_divmod() {
|
| 261 |
+
auto a = mx::random::normal({1000});
|
| 262 |
+
auto b = mx::random::normal({1000});
|
| 263 |
+
mx::eval({a, b});
|
| 264 |
+
|
| 265 |
+
auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
|
| 266 |
+
TIME(divmod_fused);
|
| 267 |
+
|
| 268 |
+
auto divmod_separate = [&a, &b]() {
|
| 269 |
+
return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
|
| 270 |
+
};
|
| 271 |
+
TIME(divmod_separate);
|
| 272 |
+
}
|
| 273 |
+
|
| 274 |
+
int main() {
|
| 275 |
+
std::cout << "Benchmarks for " << mx::default_device() << std::endl;
|
| 276 |
+
time_creation_ops();
|
| 277 |
+
time_type_conversions();
|
| 278 |
+
time_unary_ops();
|
| 279 |
+
time_binary_ops();
|
| 280 |
+
time_strided_ops();
|
| 281 |
+
time_random_generation();
|
| 282 |
+
time_comparisons();
|
| 283 |
+
time_matvec();
|
| 284 |
+
time_matmul();
|
| 285 |
+
time_reductions();
|
| 286 |
+
time_gather_scatter();
|
| 287 |
+
time_divmod();
|
| 288 |
+
}
|
ml-stable-diffusion/mlx/benchmarks/cpp/time_utils.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
| 1 |
+
// Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#pragma once
|
| 4 |
+
|
| 5 |
+
#include <chrono>
|
| 6 |
+
#include <iomanip>
|
| 7 |
+
#include <iostream>
|
| 8 |
+
|
| 9 |
+
#include "mlx/mlx.h"
|
| 10 |
+
|
| 11 |
+
#define milliseconds(x) \
|
| 12 |
+
(std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)
|
| 13 |
+
#define time_now() std::chrono::high_resolution_clock::now()
|
| 14 |
+
|
| 15 |
+
#define TIME(FUNC, ...) \
|
| 16 |
+
std::cout << "Timing " << #FUNC << " ... " << std::flush \
|
| 17 |
+
<< std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__) << " msec" \
|
| 18 |
+
<< std::endl;
|
| 19 |
+
|
| 20 |
+
#define TIMEM(MSG, FUNC, ...) \
|
| 21 |
+
std::cout << "Timing " << "(" << MSG << ") " << #FUNC << " ... " \
|
| 22 |
+
<< std::flush << std::setprecision(5) \
|
| 23 |
+
<< time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
|
| 24 |
+
|
| 25 |
+
template <typename F, typename... Args>
|
| 26 |
+
double time_fn(F fn, Args&&... args) {
|
| 27 |
+
// warmup
|
| 28 |
+
for (int i = 0; i < 5; ++i) {
|
| 29 |
+
eval(fn(std::forward<Args>(args)...));
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
int num_iters = 100;
|
| 33 |
+
auto start = time_now();
|
| 34 |
+
for (int i = 0; i < num_iters; i++) {
|
| 35 |
+
eval(fn(std::forward<Args>(args)...));
|
| 36 |
+
}
|
| 37 |
+
auto end = time_now();
|
| 38 |
+
return milliseconds(end - start) / static_cast<double>(num_iters);
|
| 39 |
+
}
|
ml-stable-diffusion/mlx/benchmarks/numpy/single_ops.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
from time_utils import time_fn
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def time_add():
|
| 8 |
+
a = np.ones((100, 100, 10), dtype=np.float32)
|
| 9 |
+
b = np.ones((100, 100, 10), dtype=np.float32)
|
| 10 |
+
time_fn(np.add, a, b)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def time_matmul():
|
| 14 |
+
a = np.random.rand(1000, 500).astype(np.float32)
|
| 15 |
+
b = np.random.rand(500, 1000).astype(np.float32)
|
| 16 |
+
time_fn(np.matmul, a, b)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def time_exp():
|
| 20 |
+
a = np.random.randn(1000, 100).astype(np.float32)
|
| 21 |
+
time_fn(np.exp, a)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def time_take():
|
| 25 |
+
a = np.random.rand(10000, 500)
|
| 26 |
+
ids = np.random.randint(0, 10000, (20, 10))
|
| 27 |
+
ids = [idx.reshape(-1) for idx in np.split(ids, 20)]
|
| 28 |
+
|
| 29 |
+
def random_take():
|
| 30 |
+
return [np.take(a, idx, 0) for idx in ids]
|
| 31 |
+
|
| 32 |
+
time_fn(random_take)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
if __name__ == "__main__":
|
| 36 |
+
time_add()
|
| 37 |
+
time_matmul()
|
| 38 |
+
time_exp()
|
| 39 |
+
time_take()
|
ml-stable-diffusion/mlx/benchmarks/numpy/time_utils.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def time_fn(fn, *args):
|
| 7 |
+
print(f"Timing {fn.__name__} ...", end=" ")
|
| 8 |
+
|
| 9 |
+
# warmup
|
| 10 |
+
for _ in range(5):
|
| 11 |
+
fn(*args)
|
| 12 |
+
|
| 13 |
+
num_iters = 100
|
| 14 |
+
tic = time.perf_counter()
|
| 15 |
+
for _ in range(num_iters):
|
| 16 |
+
x = fn(*args)
|
| 17 |
+
toc = time.perf_counter()
|
| 18 |
+
|
| 19 |
+
msec = 1e3 * (toc - tic) / num_iters
|
| 20 |
+
print(f"{msec:.5f} msec")
|
ml-stable-diffusion/mlx/benchmarks/python/batch_matmul_bench.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
|
| 5 |
+
import mlx.core as mx
|
| 6 |
+
from time_utils import time_fn
|
| 7 |
+
|
| 8 |
+
B = 8
|
| 9 |
+
T = 1024
|
| 10 |
+
D = 512
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def time_batch_matmul():
|
| 14 |
+
mx.random.seed(3)
|
| 15 |
+
a = mx.random.uniform(shape=(B, T, D))
|
| 16 |
+
b = mx.random.uniform(shape=(D, D))
|
| 17 |
+
c = mx.random.uniform(shape=(B, T, D))
|
| 18 |
+
mx.eval(a, b, c)
|
| 19 |
+
|
| 20 |
+
time_fn(mx.matmul, a, b)
|
| 21 |
+
|
| 22 |
+
def batch_vjp_first():
|
| 23 |
+
return mx.vjp(mx.matmul, [a, b], [c])[1][0]
|
| 24 |
+
|
| 25 |
+
time_fn(batch_vjp_first)
|
| 26 |
+
|
| 27 |
+
def batch_vjp_second():
|
| 28 |
+
return mx.vjp(mx.matmul, [a, b], [c])[1][1]
|
| 29 |
+
|
| 30 |
+
time_fn(batch_vjp_second)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def time_unbatch_matmul():
|
| 34 |
+
mx.random.seed(3)
|
| 35 |
+
a = mx.random.uniform(shape=(B * T, D))
|
| 36 |
+
b = mx.random.uniform(shape=(D, D))
|
| 37 |
+
c = mx.random.uniform(shape=(B * T, D))
|
| 38 |
+
mx.eval(a, b, c)
|
| 39 |
+
time_fn(mx.matmul, a, b)
|
| 40 |
+
|
| 41 |
+
def unbatch_vjp_first():
|
| 42 |
+
return mx.matmul(c, mx.transpose(b))
|
| 43 |
+
|
| 44 |
+
time_fn(unbatch_vjp_first)
|
| 45 |
+
|
| 46 |
+
def unbatch_vjp_second():
|
| 47 |
+
return mx.matmul(mx.transpose(a), c)
|
| 48 |
+
|
| 49 |
+
time_fn(unbatch_vjp_second)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
if __name__ == "__main__":
|
| 53 |
+
parser = argparse.ArgumentParser("MLX benchmarks.")
|
| 54 |
+
parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
|
| 55 |
+
args = parser.parse_args()
|
| 56 |
+
if args.gpu:
|
| 57 |
+
mx.set_default_device(mx.gpu)
|
| 58 |
+
else:
|
| 59 |
+
mx.set_default_device(mx.cpu)
|
| 60 |
+
|
| 61 |
+
time_batch_matmul()
|
| 62 |
+
time_unbatch_matmul()
|
ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemm.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import subprocess
|
| 7 |
+
import time
|
| 8 |
+
|
| 9 |
+
import mlx.core as mx
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
| 14 |
+
device_name = device_name.decode("utf-8").strip("\n")
|
| 15 |
+
|
| 16 |
+
N_warmup = 8
|
| 17 |
+
N_iter_bench = 80
|
| 18 |
+
N_iter_func = 5
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def bench(f, a, b):
|
| 22 |
+
for i in range(N_warmup):
|
| 23 |
+
f(a, b)
|
| 24 |
+
torch.mps.synchronize()
|
| 25 |
+
|
| 26 |
+
s = time.perf_counter_ns()
|
| 27 |
+
for i in range(N_iter_bench):
|
| 28 |
+
f(a, b)
|
| 29 |
+
e = time.perf_counter_ns()
|
| 30 |
+
return (e - s) * 1e-9
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def gemm_nn_mlx(a, b):
|
| 34 |
+
ys = []
|
| 35 |
+
for i in range(N_iter_func):
|
| 36 |
+
y = a @ b
|
| 37 |
+
ys.append(y)
|
| 38 |
+
mx.eval(ys)
|
| 39 |
+
return ys
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def gemm_nt_mlx(a, b):
|
| 43 |
+
ys = []
|
| 44 |
+
for i in range(N_iter_func):
|
| 45 |
+
y = a @ b.transpose((0, 2, 1))
|
| 46 |
+
ys.append(y)
|
| 47 |
+
mx.eval(ys)
|
| 48 |
+
return ys
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def gemm_tn_mlx(a, b):
|
| 52 |
+
ys = []
|
| 53 |
+
for i in range(N_iter_func):
|
| 54 |
+
y = a.transpose((0, 2, 1)) @ b
|
| 55 |
+
ys.append(y)
|
| 56 |
+
mx.eval(ys)
|
| 57 |
+
return ys
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def gemm_tt_mlx(a, b):
|
| 61 |
+
ys = []
|
| 62 |
+
for i in range(N_iter_func):
|
| 63 |
+
y = a.transpose((0, 2, 1)) @ b.transpose((0, 2, 1))
|
| 64 |
+
ys.append(y)
|
| 65 |
+
mx.eval(ys)
|
| 66 |
+
return ys
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
@torch.no_grad()
|
| 70 |
+
def gemm_nn_torch(a, b):
|
| 71 |
+
ys = []
|
| 72 |
+
for i in range(N_iter_func):
|
| 73 |
+
y = a @ b
|
| 74 |
+
ys.append(y)
|
| 75 |
+
torch.mps.synchronize()
|
| 76 |
+
return ys
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@torch.no_grad()
|
| 80 |
+
def gemm_nt_torch(a, b):
|
| 81 |
+
ys = []
|
| 82 |
+
for i in range(N_iter_func):
|
| 83 |
+
y = a @ b.transpose(-1, -2)
|
| 84 |
+
ys.append(y)
|
| 85 |
+
torch.mps.synchronize()
|
| 86 |
+
return ys
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@torch.no_grad()
|
| 90 |
+
def gemm_tn_torch(a, b):
|
| 91 |
+
ys = []
|
| 92 |
+
for i in range(N_iter_func):
|
| 93 |
+
y = a.transpose(-1, -2) @ b
|
| 94 |
+
ys.append(y)
|
| 95 |
+
torch.mps.synchronize()
|
| 96 |
+
return ys
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
@torch.no_grad()
|
| 100 |
+
def gemm_tt_torch(a, b):
|
| 101 |
+
ys = []
|
| 102 |
+
for i in range(N_iter_func):
|
| 103 |
+
y = a.transpose(-1, -2) @ b.transpose(-1, -2)
|
| 104 |
+
ys.append(y)
|
| 105 |
+
torch.mps.synchronize()
|
| 106 |
+
return ys
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
|
| 110 |
+
shape_a = (B, M, K) if transpose[0] == "n" else (B, K, M)
|
| 111 |
+
shape_b = (B, K, N) if transpose[1] == "n" else (B, N, K)
|
| 112 |
+
|
| 113 |
+
a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
|
| 114 |
+
b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)
|
| 115 |
+
|
| 116 |
+
a_mx = mx.array(a_np)
|
| 117 |
+
b_mx = mx.array(b_np)
|
| 118 |
+
|
| 119 |
+
a_pt = torch.from_numpy(a_np).to("mps")
|
| 120 |
+
b_pt = torch.from_numpy(b_np).to("mps")
|
| 121 |
+
|
| 122 |
+
torch.mps.synchronize()
|
| 123 |
+
|
| 124 |
+
f_mx = {
|
| 125 |
+
"nn": gemm_nn_mlx,
|
| 126 |
+
"nt": gemm_nt_mlx,
|
| 127 |
+
"tn": gemm_tn_mlx,
|
| 128 |
+
"tt": gemm_tt_mlx,
|
| 129 |
+
}[transpose]
|
| 130 |
+
|
| 131 |
+
f_pt = {
|
| 132 |
+
"nn": gemm_nn_torch,
|
| 133 |
+
"nt": gemm_nt_torch,
|
| 134 |
+
"tn": gemm_tn_torch,
|
| 135 |
+
"tt": gemm_tt_torch,
|
| 136 |
+
}[transpose]
|
| 137 |
+
|
| 138 |
+
time_torch = bench(f_pt, a_pt, b_pt)
|
| 139 |
+
time_mlx = bench(f_mx, a_mx, b_mx)
|
| 140 |
+
|
| 141 |
+
t_a = (0, 1, 2) if transpose[0] == "n" else (0, 2, 1)
|
| 142 |
+
t_b = (0, 1, 2) if transpose[1] == "n" else (0, 2, 1)
|
| 143 |
+
|
| 144 |
+
c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
|
| 145 |
+
c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)
|
| 146 |
+
|
| 147 |
+
atol = 1e-5 if np_dtype == np.float32 else 1e-4
|
| 148 |
+
|
| 149 |
+
if not np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol):
|
| 150 |
+
print(
|
| 151 |
+
f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
|
| 152 |
+
)
|
| 153 |
+
|
| 154 |
+
return time_mlx, time_torch
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def get_gflop_count(B, M, N, K):
|
| 158 |
+
return float(2.0 * N_iter_bench * N_iter_func * B * M * N * K) / float(1024.0**3)
|
| 159 |
+
|
| 160 |
+
|
| 161 |
+
if __name__ == "__main__":
|
| 162 |
+
parser = argparse.ArgumentParser(description="Run gemm benchmarks")
|
| 163 |
+
|
| 164 |
+
dtypes = ("float32", "float16", "complex64")
|
| 165 |
+
transposes = ("nn", "nt", "tn")
|
| 166 |
+
shapes = (
|
| 167 |
+
(16, 234, 768, 3072),
|
| 168 |
+
(1, 64, 64, 25344),
|
| 169 |
+
(16, 1024, 1024, 1024),
|
| 170 |
+
(1, 1024, 1024, 2048),
|
| 171 |
+
(4, 1024, 1024, 4096),
|
| 172 |
+
(4, 1024, 4096, 1024),
|
| 173 |
+
(1, 4096, 4096, 4096),
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
for dtype in dtypes:
|
| 177 |
+
for transpose in transposes:
|
| 178 |
+
for B, M, N, K in shapes:
|
| 179 |
+
np_dtype = getattr(np, dtype)
|
| 180 |
+
time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)
|
| 181 |
+
|
| 182 |
+
gflop_count = get_gflop_count(B, M, N, K)
|
| 183 |
+
gflops_mx = gflop_count / (time_mlx)
|
| 184 |
+
gflops_pt = gflop_count / (time_torch)
|
| 185 |
+
diff = gflops_mx / gflops_pt - 1.0
|
| 186 |
+
|
| 187 |
+
print(
|
| 188 |
+
f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
|
| 189 |
+
)
|
| 190 |
+
if gflops_pt >= 2.0 * gflops_mx:
|
| 191 |
+
print("ATTENTION ^^^^^^^")
|
ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemv.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import subprocess
|
| 6 |
+
import time
|
| 7 |
+
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import mlx.core as mx
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
results_dir = "./results"
|
| 14 |
+
|
| 15 |
+
if not os.path.isdir(results_dir):
|
| 16 |
+
os.mkdir(results_dir)
|
| 17 |
+
|
| 18 |
+
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
| 19 |
+
device_name = device_name.decode("utf-8").strip("\n")
|
| 20 |
+
|
| 21 |
+
N_warmup = 5
|
| 22 |
+
N_iter_bench = 50
|
| 23 |
+
N_iter_func = 20
|
| 24 |
+
|
| 25 |
+
out_vec_sizes = [128, 512, 2048, 4096]
|
| 26 |
+
in_vec_sizes = [128, 512, 2048, 4096]
|
| 27 |
+
|
| 28 |
+
benchmark_vector_lens = []
|
| 29 |
+
benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
|
| 30 |
+
benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
|
| 31 |
+
benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
|
| 32 |
+
benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
|
| 33 |
+
|
| 34 |
+
benchmark_vector_lens.sort()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def bench(f, m, v):
|
| 38 |
+
for i in range(N_warmup):
|
| 39 |
+
f(m, v)
|
| 40 |
+
torch.mps.synchronize()
|
| 41 |
+
|
| 42 |
+
s = time.perf_counter_ns()
|
| 43 |
+
for i in range(N_iter_bench):
|
| 44 |
+
f(m, v)
|
| 45 |
+
e = time.perf_counter_ns()
|
| 46 |
+
return (e - s) * 1e-9
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def gemv_mlx(m, v):
|
| 50 |
+
ys = []
|
| 51 |
+
for i in range(N_iter_func):
|
| 52 |
+
y = m @ v
|
| 53 |
+
ys.append(y)
|
| 54 |
+
mx.eval(ys)
|
| 55 |
+
return ys
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def gemv_t_mlx(m, v):
|
| 59 |
+
ys = []
|
| 60 |
+
for i in range(N_iter_func):
|
| 61 |
+
y = v @ m
|
| 62 |
+
ys.append(y)
|
| 63 |
+
mx.eval(ys)
|
| 64 |
+
return ys
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
@torch.no_grad()
|
| 68 |
+
def gemv_torch(m, v):
|
| 69 |
+
ys = []
|
| 70 |
+
for i in range(N_iter_func):
|
| 71 |
+
y = m @ v
|
| 72 |
+
ys.append(y)
|
| 73 |
+
torch.mps.synchronize()
|
| 74 |
+
return ys
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
@torch.no_grad()
|
| 78 |
+
def gemv_t_torch(m, v):
|
| 79 |
+
ys = []
|
| 80 |
+
for i in range(N_iter_func):
|
| 81 |
+
y = v @ m
|
| 82 |
+
ys.append(y)
|
| 83 |
+
torch.mps.synchronize()
|
| 84 |
+
return ys
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
|
| 88 |
+
shape_mat = (in_vec_len, out_vec_len) if transpose else (out_vec_len, in_vec_len)
|
| 89 |
+
shape_vec = (1, in_vec_len) if transpose else (in_vec_len, 1)
|
| 90 |
+
|
| 91 |
+
mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
|
| 92 |
+
vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
|
| 93 |
+
mat_mlx = mx.array(mat_npy)
|
| 94 |
+
vec_mlx = mx.array(vec_npy)
|
| 95 |
+
mat_trc = torch.from_numpy(mat_npy).to("mps")
|
| 96 |
+
vec_trc = torch.from_numpy(vec_npy).to("mps")
|
| 97 |
+
|
| 98 |
+
torch.mps.synchronize()
|
| 99 |
+
|
| 100 |
+
time_torch = (
|
| 101 |
+
bench(gemv_t_torch, mat_trc, vec_trc)
|
| 102 |
+
if transpose
|
| 103 |
+
else bench(gemv_torch, mat_trc, vec_trc)
|
| 104 |
+
)
|
| 105 |
+
time_mlx = (
|
| 106 |
+
bench(gemv_t_mlx, mat_mlx, vec_mlx)
|
| 107 |
+
if transpose
|
| 108 |
+
else bench(gemv_mlx, mat_mlx, vec_mlx)
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
c_mlx = (
|
| 112 |
+
np.asarray(vec_mlx @ mat_mlx) if transpose else np.asarray(mat_mlx @ vec_mlx)
|
| 113 |
+
)
|
| 114 |
+
c_npy = (vec_npy @ mat_npy) if transpose else (mat_npy @ vec_npy)
|
| 115 |
+
|
| 116 |
+
if not np.allclose(c_mlx, c_npy, atol=2e-5):
|
| 117 |
+
print(
|
| 118 |
+
f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
|
| 119 |
+
)
|
| 120 |
+
|
| 121 |
+
return time_mlx, time_torch
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
def get_gflop_count(in_vec_len, out_vec_len):
|
| 125 |
+
return float(2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len) / float(
|
| 126 |
+
1024**3
|
| 127 |
+
)
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
|
| 131 |
+
n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
|
| 132 |
+
item_size = 4 if np_dtype == np.float32 else 2
|
| 133 |
+
return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
|
| 137 |
+
np_dtype = getattr(np, dtype)
|
| 138 |
+
mlx_gb_s = []
|
| 139 |
+
mlx_gflops = []
|
| 140 |
+
pyt_gb_s = []
|
| 141 |
+
pyt_gflops = []
|
| 142 |
+
|
| 143 |
+
for out_vec_len in out_vector_lens:
|
| 144 |
+
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
|
| 145 |
+
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
|
| 146 |
+
|
| 147 |
+
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
|
| 148 |
+
|
| 149 |
+
mlx_gb_s.append(gbyte_size / time_mlx)
|
| 150 |
+
pyt_gb_s.append(gbyte_size / time_torch)
|
| 151 |
+
|
| 152 |
+
mlx_gflops.append(gflop_count / time_mlx)
|
| 153 |
+
pyt_gflops.append(gflop_count / time_torch)
|
| 154 |
+
|
| 155 |
+
if transpose:
|
| 156 |
+
title = f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
|
| 157 |
+
else:
|
| 158 |
+
title = f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"
|
| 159 |
+
|
| 160 |
+
ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
|
| 161 |
+
ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
|
| 162 |
+
ax.set_title(title)
|
| 163 |
+
ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
|
| 164 |
+
ax.legend()
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
|
| 168 |
+
np_dtype = getattr(np, dtype)
|
| 169 |
+
mlx_gb_s = []
|
| 170 |
+
mlx_gflops = []
|
| 171 |
+
pyt_gb_s = []
|
| 172 |
+
pyt_gflops = []
|
| 173 |
+
|
| 174 |
+
for in_vec_len in in_vector_lens:
|
| 175 |
+
gflop_count = get_gflop_count(in_vec_len, out_vec_len)
|
| 176 |
+
gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)
|
| 177 |
+
|
| 178 |
+
time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)
|
| 179 |
+
|
| 180 |
+
mlx_gb_s.append(gbyte_size / time_mlx)
|
| 181 |
+
pyt_gb_s.append(gbyte_size / time_torch)
|
| 182 |
+
|
| 183 |
+
mlx_gflops.append(gflop_count / time_mlx)
|
| 184 |
+
pyt_gflops.append(gflop_count / time_torch)
|
| 185 |
+
|
| 186 |
+
if transpose:
|
| 187 |
+
title = f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
|
| 188 |
+
else:
|
| 189 |
+
title = f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"
|
| 190 |
+
|
| 191 |
+
ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
|
| 192 |
+
ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
|
| 193 |
+
ax.set_title(title)
|
| 194 |
+
ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
|
| 195 |
+
ax.legend()
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
for transpose in (False, True):
|
| 199 |
+
for dtype in ("float32", "float16", "complex64"):
|
| 200 |
+
fig, axs = plt.subplots(
|
| 201 |
+
len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
|
| 202 |
+
)
|
| 203 |
+
|
| 204 |
+
for i, in_vec_len in enumerate(in_vec_sizes):
|
| 205 |
+
bench_with_in_len(
|
| 206 |
+
axs[i][0], in_vec_len, benchmark_vector_lens, dtype, transpose
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
for i, out_vec_len in enumerate(out_vec_sizes):
|
| 210 |
+
bench_with_out_len(
|
| 211 |
+
axs[i][1], out_vec_len, benchmark_vector_lens, dtype, transpose
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
op_name = "gemv_t" if transpose else "gemv"
|
| 215 |
+
fig.suptitle(f"{device_name}: {dtype} {op_name}")
|
| 216 |
+
fig.savefig(
|
| 217 |
+
os.path.join(
|
| 218 |
+
results_dir, f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
|
| 219 |
+
)
|
| 220 |
+
)
|
| 221 |
+
plt.close(fig)
|
ml-stable-diffusion/mlx/benchmarks/python/comparative/README.md
ADDED
|
@@ -0,0 +1,15 @@
|
|
| 1 |
+
Microbenchmarks comparing MLX to PyTorch
|
| 2 |
+
========================================
|
| 3 |
+
|
| 4 |
+
Implement the same microbenchmarks in MLX and PyTorch to compare and make a
|
| 5 |
+
list of the biggest possible performance improvements and/or regressions.
|
| 6 |
+
|
| 7 |
+
Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
|
| 8 |
+
instance to measure the time it takes to sum across the 3rd axis of that
|
| 9 |
+
tensor on the CPU.
|
| 10 |
+
|
| 11 |
+
`compare.py` runs several benchmarks and compares the speed-up or lack thereof
|
| 12 |
+
in comparison to PyTorch.
|
| 13 |
+
|
| 14 |
+
Each bench script can be run with `--print-pid` to print the PID and wait for a
|
| 15 |
+
key in order to ease attaching a debugger.
|
ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_mlx.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
import os
|
| 6 |
+
import time
|
| 7 |
+
from functools import partial
|
| 8 |
+
|
| 9 |
+
import mlx.core as mx
|
| 10 |
+
import mlx.nn as nn
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def int_or_list(x):
|
| 14 |
+
try:
|
| 15 |
+
return int(x)
|
| 16 |
+
except ValueError:
|
| 17 |
+
return [int(xi) for xi in x.split(",")]
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def none_or_list(x):
|
| 21 |
+
if x == "":
|
| 22 |
+
return None
|
| 23 |
+
else:
|
| 24 |
+
return [int(xi) for xi in x.split(",")]
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def dtype_from_str(x):
|
| 28 |
+
if x == "":
|
| 29 |
+
return mx.float32
|
| 30 |
+
else:
|
| 31 |
+
dt = getattr(mx, x)
|
| 32 |
+
if not isinstance(dt, mx.Dtype):
|
| 33 |
+
raise ValueError(f"{x} is not an mlx dtype")
|
| 34 |
+
return dt
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def bench(f, *args):
|
| 38 |
+
for i in range(10):
|
| 39 |
+
f(*args)
|
| 40 |
+
|
| 41 |
+
s = time.time()
|
| 42 |
+
for i in range(100):
|
| 43 |
+
f(*args)
|
| 44 |
+
e = time.time()
|
| 45 |
+
return e - s
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def matmul_square(x):
|
| 49 |
+
y = x
|
| 50 |
+
for i in range(10):
|
| 51 |
+
y = y @ x
|
| 52 |
+
mx.eval(y)
|
| 53 |
+
return y
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def matmul(x, y):
|
| 57 |
+
ys = []
|
| 58 |
+
for i in range(10):
|
| 59 |
+
ys.append(x @ y)
|
| 60 |
+
mx.eval(ys)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
|
| 64 |
+
ys = []
|
| 65 |
+
for i in range(10):
|
| 66 |
+
ys.append(
|
| 67 |
+
mx.quantized_matmul(
|
| 68 |
+
x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
|
| 69 |
+
)
|
| 70 |
+
)
|
| 71 |
+
mx.eval(ys)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
quant_matmul = {
|
| 75 |
+
"quant_matmul_32_2": partial(_quant_matmul, transpose=False, group_size=32, bits=2),
|
| 76 |
+
"quant_matmul_32_4": partial(_quant_matmul, transpose=False, group_size=32, bits=4),
|
| 77 |
+
"quant_matmul_32_8": partial(_quant_matmul, transpose=False, group_size=32, bits=8),
|
| 78 |
+
"quant_matmul_64_2": partial(_quant_matmul, transpose=False, group_size=64, bits=2),
|
| 79 |
+
"quant_matmul_64_4": partial(_quant_matmul, transpose=False, group_size=64, bits=4),
|
| 80 |
+
"quant_matmul_64_8": partial(_quant_matmul, transpose=False, group_size=64, bits=8),
|
| 81 |
+
"quant_matmul_128_2": partial(
|
| 82 |
+
_quant_matmul, transpose=False, group_size=128, bits=2
|
| 83 |
+
),
|
| 84 |
+
"quant_matmul_128_4": partial(
|
| 85 |
+
_quant_matmul, transpose=False, group_size=128, bits=4
|
| 86 |
+
),
|
| 87 |
+
"quant_matmul_128_8": partial(
|
| 88 |
+
_quant_matmul, transpose=False, group_size=128, bits=8
|
| 89 |
+
),
|
| 90 |
+
"quant_matmul_t_32_2": partial(
|
| 91 |
+
_quant_matmul, transpose=True, group_size=32, bits=2
|
| 92 |
+
),
|
| 93 |
+
"quant_matmul_t_32_4": partial(
|
| 94 |
+
_quant_matmul, transpose=True, group_size=32, bits=4
|
| 95 |
+
),
|
| 96 |
+
"quant_matmul_t_32_8": partial(
|
| 97 |
+
_quant_matmul, transpose=True, group_size=32, bits=8
|
| 98 |
+
),
|
| 99 |
+
"quant_matmul_t_64_2": partial(
|
| 100 |
+
_quant_matmul, transpose=True, group_size=64, bits=2
|
| 101 |
+
),
|
| 102 |
+
"quant_matmul_t_64_4": partial(
|
| 103 |
+
_quant_matmul, transpose=True, group_size=64, bits=4
|
| 104 |
+
),
|
| 105 |
+
"quant_matmul_t_64_8": partial(
|
| 106 |
+
_quant_matmul, transpose=True, group_size=64, bits=8
|
| 107 |
+
),
|
| 108 |
+
"quant_matmul_t_128_2": partial(
|
| 109 |
+
_quant_matmul, transpose=True, group_size=128, bits=2
|
| 110 |
+
),
|
| 111 |
+
"quant_matmul_t_128_4": partial(
|
| 112 |
+
_quant_matmul, transpose=True, group_size=128, bits=4
|
| 113 |
+
),
|
| 114 |
+
"quant_matmul_t_128_8": partial(
|
| 115 |
+
_quant_matmul, transpose=True, group_size=128, bits=8
|
| 116 |
+
),
|
| 117 |
+
}
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def conv1d(x, y):
|
| 121 |
+
ys = []
|
| 122 |
+
for i in range(10):
|
| 123 |
+
ys.append(mx.conv1d(x, y))
|
| 124 |
+
mx.eval(ys)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def conv2d(x, y):
|
| 128 |
+
ys = []
|
| 129 |
+
for i in range(10):
|
| 130 |
+
ys.append(mx.conv2d(x, y))
|
| 131 |
+
mx.eval(ys)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def binary(op, x, y):
|
| 135 |
+
for i in range(100):
|
| 136 |
+
y = getattr(mx, op)(x, y)
|
| 137 |
+
mx.eval(y)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def reduction(op, axis, x):
|
| 141 |
+
ys = []
|
| 142 |
+
for i in range(100):
|
| 143 |
+
ys.append(getattr(mx, op)(x, axis=axis))
|
| 144 |
+
mx.eval(ys)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
def sum_and_add(axis, x, y):
|
| 148 |
+
z = x.sum(axis=axis, keepdims=True)
|
| 149 |
+
for i in range(50):
|
| 150 |
+
z = (z + y).sum(axis=axis, keepdims=True)
|
| 151 |
+
mx.eval(z)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def softmax(axis, x):
|
| 155 |
+
ys = []
|
| 156 |
+
for i in range(100):
|
| 157 |
+
ex = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
|
| 158 |
+
y = ex / mx.sum(ex, axis=axis, keepdims=True)
|
| 159 |
+
ys.append(y)
|
| 160 |
+
mx.eval(ys)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
def softmax_fused(axis, x):
|
| 164 |
+
ys = []
|
| 165 |
+
for i in range(100):
|
| 166 |
+
y = mx.softmax(x, axis=axis)
|
| 167 |
+
ys.append(y)
|
| 168 |
+
mx.eval(ys)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def relu(x):
|
| 172 |
+
y = x
|
| 173 |
+
for i in range(100):
|
| 174 |
+
y = nn.relu(y)
|
| 175 |
+
mx.eval(y)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def leaky_relu(x: mx.array):
|
| 179 |
+
y = x
|
| 180 |
+
for i in range(100):
|
| 181 |
+
y = nn.leaky_relu(y)
|
| 182 |
+
mx.eval(y)
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
def prelu(x: mx.array):
|
| 186 |
+
y = x
|
| 187 |
+
for i in range(100):
|
| 188 |
+
y = nn.prelu(y, mx.ones(1))
|
| 189 |
+
mx.eval(y)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
def softplus(x: mx.array):
|
| 193 |
+
y = x
|
| 194 |
+
for i in range(100):
|
| 195 |
+
y = nn.softplus(y)
|
| 196 |
+
mx.eval(y)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def mish(x: mx.array):
|
| 200 |
+
y = x
|
| 201 |
+
for i in range(100):
|
| 202 |
+
y = nn.mish(y)
|
| 203 |
+
mx.eval(y)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def leaky_relu(x):
|
| 207 |
+
y = x
|
| 208 |
+
for i in range(100):
|
| 209 |
+
y = nn.leaky_relu(y)
|
| 210 |
+
mx.eval(y)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def elu(x):
|
| 214 |
+
y = x
|
| 215 |
+
for i in range(100):
|
| 216 |
+
y = nn.elu(y)
|
| 217 |
+
mx.eval(y)
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
def relu6(x):
|
| 221 |
+
y = x
|
| 222 |
+
for i in range(100):
|
| 223 |
+
y = nn.relu6(y)
|
| 224 |
+
mx.eval(y)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def softplus(x):
|
| 228 |
+
y = x
|
| 229 |
+
for i in range(100):
|
| 230 |
+
y = nn.softplus(y)
|
| 231 |
+
mx.eval(y)
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def celu(x):
|
| 235 |
+
y = x
|
| 236 |
+
for i in range(100):
|
| 237 |
+
y = nn.celu(y)
|
| 238 |
+
mx.eval(y)
|
| 239 |
+
|
| 240 |
+
|
| 241 |
+
def log_sigmoid(x):
|
| 242 |
+
y = x
|
| 243 |
+
for i in range(100):
|
| 244 |
+
y = nn.log_sigmoid(y)
|
| 245 |
+
mx.eval(y)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def scalar_mult(x):
|
| 249 |
+
y = x
|
| 250 |
+
for i in range(100):
|
| 251 |
+
y = y * (1.0 / (1 + i))
|
| 252 |
+
mx.eval(y)
|
| 253 |
+
|
| 254 |
+
|
| 255 |
+
def cross_entropy(targets, x):
|
| 256 |
+
ys = []
|
| 257 |
+
for i in range(100):
|
| 258 |
+
y = mx.logsumexp(x, axis=-1, keepdims=True) - mx.take_along_axis(
|
| 259 |
+
x, mx.reshape(targets, (-1, 1)), axis=-1
|
| 260 |
+
)
|
| 261 |
+
ys.append(mx.mean(y))
|
| 262 |
+
mx.eval(ys)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def logsumexp(axis, x):
|
| 266 |
+
ys = []
|
| 267 |
+
for i in range(100):
|
| 268 |
+
ys.append(mx.logsumexp(x, axis=axis))
|
| 269 |
+
mx.eval(ys)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def linear(w, b, x):
|
| 273 |
+
ys = []
|
| 274 |
+
for i in range(10):
|
| 275 |
+
ys.append(x @ mx.transpose(w, (1, 0)) + b)
|
| 276 |
+
mx.eval(ys)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
def linear_fused(w, b, x):
|
| 280 |
+
ys = []
|
| 281 |
+
for i in range(10):
|
| 282 |
+
ys.append(mx.addmm(b, x, mx.transpose(w, (1, 0))))
|
| 283 |
+
mx.eval(ys)
|
| 284 |
+
|
| 285 |
+
|
| 286 |
+
def rope(x):
|
| 287 |
+
*_, N, D = x.shape
|
| 288 |
+
ys = []
|
| 289 |
+
for i in range(10):
|
| 290 |
+
shape = x.shape
|
| 291 |
+
x = mx.reshape(x, (-1, N, D))
|
| 292 |
+
positions = mx.arange(N)
|
| 293 |
+
freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
|
| 294 |
+
theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
|
| 295 |
+
costheta = mx.cos(theta)
|
| 296 |
+
sintheta = mx.sin(theta)
|
| 297 |
+
x1 = x[..., ::2]
|
| 298 |
+
x2 = x[..., 1::2]
|
| 299 |
+
rx1 = x1 * costheta - x2 * sintheta
|
| 300 |
+
rx2 = x1 * sintheta + x2 * costheta
|
| 301 |
+
y = mx.concatenate([rx1[..., None], rx2[..., None]], axis=-1)
|
| 302 |
+
y = mx.reshape(y, (-1, N, D))
|
| 303 |
+
ys.append(y)
|
| 304 |
+
mx.eval(ys)
|
| 305 |
+
|
| 306 |
+
|
| 307 |
+
def concatenate(axis, x, y):
|
| 308 |
+
ys = []
|
| 309 |
+
for i in range(10):
|
| 310 |
+
ys.append(mx.concatenate([x, y], axis=axis))
|
| 311 |
+
mx.eval(ys)
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def cumsum(axis, x):
|
| 315 |
+
ys = []
|
| 316 |
+
for i in range(10):
|
| 317 |
+
ys.append(mx.cumsum(x, axis))
|
| 318 |
+
mx.eval(ys)
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
def sort(axis, x):
|
| 322 |
+
ys = []
|
| 323 |
+
for i in range(10):
|
| 324 |
+
ys.append(mx.sort(x, axis))
|
| 325 |
+
mx.eval(ys)
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def topk(axis, x):
|
| 329 |
+
k = x.shape[axis] // 3
|
| 330 |
+
ys = []
|
| 331 |
+
for i in range(10):
|
| 332 |
+
ys.append(mx.topk(x, k, axis))
|
| 333 |
+
mx.eval(ys)
|
| 334 |
+
|
| 335 |
+
|
| 336 |
+
def step_function(x):
|
| 337 |
+
y = x
|
| 338 |
+
for i in range(100):
|
| 339 |
+
y = nn.step(x)
|
| 340 |
+
mx.eval(y)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def selu(x):
|
| 344 |
+
y = x
|
| 345 |
+
for i in range(100):
|
| 346 |
+
y = nn.selu(x)
|
| 347 |
+
mx.eval(y)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
if __name__ == "__main__":
|
| 351 |
+
parser = argparse.ArgumentParser()
|
| 352 |
+
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
| 353 |
+
parser.add_argument(
|
| 354 |
+
"--size",
|
| 355 |
+
default=[(1024, 1024)],
|
| 356 |
+
type=lambda x: list(map(int, x.split("x"))),
|
| 357 |
+
help="Set the matrix size",
|
| 358 |
+
action="append",
|
| 359 |
+
)
|
| 360 |
+
parser.add_argument(
|
| 361 |
+
"--axis",
|
| 362 |
+
default=[1],
|
| 363 |
+
type=int_or_list,
|
| 364 |
+
help="Set a reduction axis",
|
| 365 |
+
action="append",
|
| 366 |
+
)
|
| 367 |
+
parser.add_argument(
|
| 368 |
+
"--transpose",
|
| 369 |
+
type=none_or_list,
|
| 370 |
+
default=[],
|
| 371 |
+
help="Permute the matrix",
|
| 372 |
+
action="append",
|
| 373 |
+
)
|
| 374 |
+
parser.add_argument(
|
| 375 |
+
"--print-pid", action="store_true", help="Print the PID and pause"
|
| 376 |
+
)
|
| 377 |
+
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
|
| 378 |
+
parser.add_argument(
|
| 379 |
+
"--fused", action="store_true", help="Use fused functions where possible"
|
| 380 |
+
)
|
| 381 |
+
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
|
| 382 |
+
|
| 383 |
+
args = parser.parse_args()
|
| 384 |
+
|
| 385 |
+
if len(args.size) > 1:
|
| 386 |
+
args.size.pop(0)
|
| 387 |
+
if len(args.axis) > 1:
|
| 388 |
+
args.axis.pop(0)
|
| 389 |
+
|
| 390 |
+
if args.cpu:
|
| 391 |
+
mx.set_default_device(mx.cpu)
|
| 392 |
+
else:
|
| 393 |
+
mx.set_default_device(mx.gpu)
|
| 394 |
+
|
| 395 |
+
types = args.dtype
|
| 396 |
+
if not types:
|
| 397 |
+
types = [mx.float32]
|
| 398 |
+
if len(types) < len(args.size):
|
| 399 |
+
types = types + [types[0]] * (len(args.size) - len(types))
|
| 400 |
+
|
| 401 |
+
xs = []
|
| 402 |
+
for size, dtype in zip(args.size, types):
|
| 403 |
+
xs.append(mx.random.normal(size).astype(dtype))
|
| 404 |
+
for i, t in enumerate(args.transpose):
|
| 405 |
+
if t is None:
|
| 406 |
+
continue
|
| 407 |
+
xs[i] = mx.transpose(xs[i], t)
|
| 408 |
+
mx.eval(xs)
|
| 409 |
+
x = xs[0]
|
| 410 |
+
axis = args.axis[0]
|
| 411 |
+
|
| 412 |
+
if args.print_pid:
|
| 413 |
+
print(os.getpid())
|
| 414 |
+
input("Press enter to run")
|
| 415 |
+
|
| 416 |
+
if args.benchmark == "matmul_square":
|
| 417 |
+
print(bench(matmul_square, x))
|
| 418 |
+
|
| 419 |
+
elif args.benchmark == "matmul":
|
| 420 |
+
print(bench(matmul, *xs))
|
| 421 |
+
|
| 422 |
+
elif args.benchmark.startswith("quant_matmul"):
|
| 423 |
+
print(bench(quant_matmul[args.benchmark], *xs))
|
| 424 |
+
|
| 425 |
+
elif args.benchmark == "linear":
|
| 426 |
+
if args.fused:
|
| 427 |
+
print(bench(linear_fused, *xs))
|
| 428 |
+
else:
|
| 429 |
+
print(bench(linear, *xs))
|
| 430 |
+
|
| 431 |
+
elif args.benchmark == "sum_axis":
|
| 432 |
+
print(bench(reduction, "sum", axis, x))
|
| 433 |
+
|
| 434 |
+
elif args.benchmark == "sum_all":
|
| 435 |
+
print(bench(reduction, "sum", None, x))
|
| 436 |
+
|
| 437 |
+
elif args.benchmark == "argmax":
|
| 438 |
+
print(bench(reduction, "argmax", axis, x))
|
| 439 |
+
|
| 440 |
+
elif args.benchmark == "add":
|
| 441 |
+
print(bench(binary, "add", *xs))
|
| 442 |
+
|
| 443 |
+
elif args.benchmark == "mul":
|
| 444 |
+
print(bench(binary, "multiply", *xs))
|
| 445 |
+
|
| 446 |
+
elif args.benchmark == "softmax":
|
| 447 |
+
if args.fused:
|
| 448 |
+
print(bench(softmax_fused, axis, x))
|
| 449 |
+
else:
|
| 450 |
+
print(bench(softmax, axis, x))
|
| 451 |
+
|
| 452 |
+
elif args.benchmark == "relu":
|
| 453 |
+
print(bench(relu, x))
|
| 454 |
+
|
| 455 |
+
elif args.benchmark == "elu":
|
| 456 |
+
print(bench(elu, x))
|
| 457 |
+
|
| 458 |
+
elif args.benchmark == "relu6":
|
| 459 |
+
print(bench(relu6, x))
|
| 460 |
+
|
| 461 |
+
elif args.benchmark == "celu":
|
| 462 |
+
print(bench(celu, x))
|
| 463 |
+
|
| 464 |
+
elif args.benchmark == "log_sigmoid":
|
| 465 |
+
print(bench(log_sigmoid, x))
|
| 466 |
+
|
| 467 |
+
elif args.benchmark == "leaky_relu":
|
| 468 |
+
print(bench(leaky_relu, x))
|
| 469 |
+
elif args.benchmark == "prelu":
|
| 470 |
+
print(bench(prelu, x))
|
| 471 |
+
elif args.benchmark == "softplus":
|
| 472 |
+
print(bench(softplus, x))
|
| 473 |
+
elif args.benchmark == "mish":
|
| 474 |
+
print(bench(mish, x))
|
| 475 |
+
elif args.benchmark == "scalar_mul":
|
| 476 |
+
print(bench(scalar_mult, x))
|
| 477 |
+
|
| 478 |
+
elif args.benchmark == "cross_entropy":
|
| 479 |
+
if len(size) != 2:
|
| 480 |
+
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
|
| 481 |
+
|
| 482 |
+
targets = mx.zeros((len(x),), dtype=mx.uint32)
|
| 483 |
+
print(bench(cross_entropy, targets, x))
|
| 484 |
+
|
| 485 |
+
elif args.benchmark == "logsumexp":
|
| 486 |
+
print(bench(logsumexp, axis, x))
|
| 487 |
+
|
| 488 |
+
elif args.benchmark == "rope":
|
| 489 |
+
print(bench(rope, x))
|
| 490 |
+
|
| 491 |
+
elif args.benchmark == "concatenate":
|
| 492 |
+
print(bench(concatenate, axis, *xs))
|
| 493 |
+
|
| 494 |
+
elif args.benchmark == "cumsum":
|
| 495 |
+
print(bench(cumsum, axis, *xs))
|
| 496 |
+
|
| 497 |
+
elif args.benchmark == "conv1d":
|
| 498 |
+
print(bench(conv1d, *xs))
|
| 499 |
+
|
| 500 |
+
elif args.benchmark == "conv2d":
|
| 501 |
+
print(bench(conv2d, *xs))
|
| 502 |
+
|
| 503 |
+
elif args.benchmark == "sort":
|
| 504 |
+
print(bench(sort, axis, x))
|
| 505 |
+
|
| 506 |
+
elif args.benchmark == "topk":
|
| 507 |
+
print(bench(topk, axis, x))
|
| 508 |
+
|
| 509 |
+
elif args.benchmark == "step":
|
| 510 |
+
print(bench(step_function, x))
|
| 511 |
+
|
| 512 |
+
elif args.benchmark == "selu":
|
| 513 |
+
print(bench(selu, x))
|
| 514 |
+
|
| 515 |
+
elif args.benchmark == "sum_and_add":
|
| 516 |
+
print(bench(sum_and_add, axis, *xs))
|
| 517 |
+
|
| 518 |
+
else:
|
| 519 |
+
raise ValueError("Unknown benchmark")
|
ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_torch.py
ADDED
|
@@ -0,0 +1,482 @@
|
|
|
|
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import os
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.cuda
|
| 9 |
+
import torch.mps
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def int_or_list(x):
|
| 13 |
+
try:
|
| 14 |
+
return int(x)
|
| 15 |
+
except ValueError:
|
| 16 |
+
return [int(xi) for xi in x.split(",")]
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def none_or_list(x):
|
| 20 |
+
if x == "":
|
| 21 |
+
return None
|
| 22 |
+
else:
|
| 23 |
+
return [int(xi) for xi in x.split(",")]
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def dtype_from_str(x):
|
| 27 |
+
if x == "":
|
| 28 |
+
return torch.float32
|
| 29 |
+
else:
|
| 30 |
+
dt = getattr(torch, x)
|
| 31 |
+
if not isinstance(dt, torch.dtype):
|
| 32 |
+
raise ValueError(f"{x} is not a torch dtype")
|
| 33 |
+
return dt
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def bench(f, *args):
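    # 10 warm-up calls, then report the wall-clock time of 100 timed calls.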
|
| 37 |
+
for i in range(10):
|
| 38 |
+
f(*args)
|
| 39 |
+
|
| 40 |
+
s = time.time()
|
| 41 |
+
for i in range(100):
|
| 42 |
+
f(*args)
|
| 43 |
+
e = time.time()
|
| 44 |
+
return e - s
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def sync_if_needed(x):
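    # MPS and CUDA dispatch asynchronously; synchronize so timings measure completed work.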
|
| 48 |
+
if x.device == torch.device("mps"):
|
| 49 |
+
torch.mps.synchronize()
|
| 50 |
+
elif x.device == torch.device("cuda"):
|
| 51 |
+
torch.cuda.synchronize()
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@torch.no_grad()
|
| 55 |
+
def matmul_square(x):
|
| 56 |
+
y = x
|
| 57 |
+
for i in range(10):
|
| 58 |
+
y = y @ x
|
| 59 |
+
sync_if_needed(x)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@torch.no_grad()
|
| 63 |
+
def matmul(x, y):
|
| 64 |
+
ys = []
|
| 65 |
+
for i in range(10):
|
| 66 |
+
ys.append(x @ y)
|
| 67 |
+
sync_if_needed(x)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
@torch.no_grad()
|
| 71 |
+
def conv1d(x, y):
|
| 72 |
+
x = torch.transpose(x, -1, -2)
|
| 73 |
+
y = torch.transpose(y, -1, -2)
|
| 74 |
+
ys = []
|
| 75 |
+
for i in range(10):
|
| 76 |
+
ys.append(torch.nn.functional.conv1d(x, y))
|
| 77 |
+
sync_if_needed(x)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
@torch.no_grad()
|
| 81 |
+
def conv2d(x, y):
|
| 82 |
+
x = torch.permute(x, (0, 3, 1, 2))
|
| 83 |
+
y = torch.permute(y, (0, 3, 1, 2))
|
| 84 |
+
ys = []
|
| 85 |
+
for i in range(10):
|
| 86 |
+
ys.append(torch.nn.functional.conv2d(x, y))
|
| 87 |
+
sync_if_needed(x)
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
@torch.no_grad()
|
| 91 |
+
def binary(op, x, y):
|
| 92 |
+
for i in range(100):
|
| 93 |
+
y = getattr(torch, op)(x, y)
|
| 94 |
+
sync_if_needed(x)
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
@torch.no_grad()
|
| 98 |
+
def reduction(op, axis, x):
|
| 99 |
+
ys = []
|
| 100 |
+
for i in range(100):
|
| 101 |
+
ys.append(getattr(x, op)(axis))
|
| 102 |
+
sync_if_needed(x)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@torch.no_grad()
|
| 106 |
+
def sum_and_add(axis, x, y):
|
| 107 |
+
z = x.sum(axis=axis, keepdims=True)
|
| 108 |
+
for i in range(50):
|
| 109 |
+
z = (z + y).sum(axis=axis, keepdims=True)
|
| 110 |
+
sync_if_needed(x)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
@torch.no_grad()
|
| 114 |
+
def softmax(axis, x):
|
| 115 |
+
ys = []
|
| 116 |
+
for i in range(100):
|
| 117 |
+
ex = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
|
| 118 |
+
y = ex / torch.sum(ex, dim=axis, keepdims=True)
|
| 119 |
+
ys.append(y)
|
| 120 |
+
sync_if_needed(x)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@torch.no_grad()
|
| 124 |
+
def softmax_fused(axis, x):
|
| 125 |
+
ys = []
|
| 126 |
+
for i in range(100):
|
| 127 |
+
ys.append(torch.nn.functional.softmax(x, dim=axis))
|
| 128 |
+
sync_if_needed(x)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
@torch.no_grad()
|
| 132 |
+
def relu(x):
|
| 133 |
+
y = x
|
| 134 |
+
for i in range(100):
|
| 135 |
+
y = torch.nn.functional.relu(y)
|
| 136 |
+
sync_if_needed(x)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
@torch.no_grad()
|
| 140 |
+
def leaky_relu(x):
|
| 141 |
+
y = x
|
| 142 |
+
for i in range(100):
|
| 143 |
+
y = torch.nn.functional.leaky_relu(y)
|
| 144 |
+
sync_if_needed(x)
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
@torch.no_grad()
|
| 148 |
+
def elu(x):
|
| 149 |
+
y = x
|
| 150 |
+
for i in range(100):
|
| 151 |
+
y = torch.nn.functional.elu(y)
|
| 152 |
+
sync_if_needed(x)
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@torch.no_grad()
|
| 156 |
+
def celu(x):
|
| 157 |
+
y = x
|
| 158 |
+
for i in range(100):
|
| 159 |
+
y = torch.nn.functional.celu(y)
|
| 160 |
+
sync_if_needed(x)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
@torch.no_grad()
|
| 164 |
+
def relu6(x):
|
| 165 |
+
y = x
|
| 166 |
+
for i in range(100):
|
| 167 |
+
y = torch.nn.functional.relu6(y)
|
| 168 |
+
sync_if_needed(x)
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
@torch.no_grad()
|
| 172 |
+
def softplus(x):
|
| 173 |
+
y = x
|
| 174 |
+
for i in range(100):
|
| 175 |
+
y = torch.nn.functional.softplus(y)
|
| 176 |
+
sync_if_needed(x)
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
@torch.no_grad()
|
| 180 |
+
def log_sigmoid(x):
|
| 181 |
+
y = x
|
| 182 |
+
for i in range(100):
|
| 183 |
+
y = torch.nn.functional.logsigmoid(y)
|
| 184 |
+
sync_if_needed(x)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
@torch.no_grad()
|
| 188 |
+
def prelu(x: torch.Tensor) -> torch.Tensor:
|
| 189 |
+
y = x
|
| 190 |
+
for _ in range(100):
|
| 191 |
+
y = torch.nn.functional.prelu(y, torch.ones(1).to(y.device))
|
| 192 |
+
sync_if_needed(x)
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
|
| 196 |
+
def mish(x: torch.Tensor) -> torch.Tensor:
|
| 197 |
+
y = x
|
| 198 |
+
for _ in range(100):
|
| 199 |
+
y = torch.nn.functional.mish(y)
|
| 200 |
+
sync_if_needed(x)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@torch.no_grad()
|
| 204 |
+
def scalar_mult(x):
|
| 205 |
+
y = x
|
| 206 |
+
for i in range(100):
|
| 207 |
+
y = y * (1.0 / (1 + i))
|
| 208 |
+
sync_if_needed(x)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
@torch.no_grad()
|
| 212 |
+
def cross_entropy(targets, x):
|
| 213 |
+
ys = []
|
| 214 |
+
for i in range(100):
|
| 215 |
+
ys.append(torch.nn.functional.cross_entropy(x, targets))
|
| 216 |
+
sync_if_needed(x)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@torch.no_grad()
|
| 220 |
+
def logsumexp(axis, x):
|
| 221 |
+
ys = []
|
| 222 |
+
for i in range(100):
|
| 223 |
+
ys.append(torch.logsumexp(x, dim=axis))
|
| 224 |
+
sync_if_needed(x)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
@torch.no_grad()
|
| 228 |
+
def linear_fused(w, b, x):
|
| 229 |
+
ys = []
|
| 230 |
+
for i in range(10):
|
| 231 |
+
ys.append(torch.nn.functional.linear(x, w, b))
|
| 232 |
+
sync_if_needed(x)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
@torch.no_grad()
|
| 236 |
+
def linear(w, b, x):
|
| 237 |
+
ys = []
|
| 238 |
+
for i in range(10):
|
| 239 |
+
ys.append((x @ torch.transpose(w, -2, -1)) + b)
|
| 240 |
+
sync_if_needed(x)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@torch.no_grad()
|
| 244 |
+
def rope(x):
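    # Hand-rolled rotary position embedding: rotate (even, odd) channel pairs by position-dependent angles.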
|
| 245 |
+
*_, N, D = x.shape
|
| 246 |
+
ys = []
|
| 247 |
+
for i in range(10):
|
| 248 |
+
x = x.view(-1, N, D)
|
| 249 |
+
positions = torch.arange(N, device=x.device)
|
| 250 |
+
freqs = 10000 ** torch.linspace(0, 1, D // 2, device=x.device)
|
| 251 |
+
theta = positions[:, None] * freqs[None]
|
| 252 |
+
costheta = torch.cos(theta)
|
| 253 |
+
sintheta = torch.sin(theta)
|
| 254 |
+
x1 = x[..., ::2]
|
| 255 |
+
x2 = x[..., 1::2]
|
| 256 |
+
rx1 = x1 * costheta - x2 * sintheta
|
| 257 |
+
rx2 = x1 * sintheta + x2 * costheta
|
| 258 |
+
y = torch.cat([rx1[..., None], rx2[..., None]], dim=-1)
|
| 259 |
+
y = y.reshape(-1, N, D)
|
| 260 |
+
ys.append(y)
|
| 261 |
+
sync_if_needed(x)
|
| 262 |
+
|
| 263 |
+
|
| 264 |
+
@torch.no_grad()
|
| 265 |
+
def concatenate(axis, x, y):
|
| 266 |
+
ys = []
|
| 267 |
+
for i in range(10):
|
| 268 |
+
ys.append(torch.cat([x, y], dim=axis))
|
| 269 |
+
sync_if_needed(x)
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
@torch.no_grad()
|
| 273 |
+
def cumsum(axis, x):
|
| 274 |
+
ys = []
|
| 275 |
+
for i in range(10):
|
| 276 |
+
ys.append(x.cumsum(axis))
|
| 277 |
+
sync_if_needed(x)
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@torch.no_grad()
|
| 281 |
+
def sort(axis, x):
|
| 282 |
+
ys = []
|
| 283 |
+
for i in range(10):
|
| 284 |
+
ys.append(torch.sort(x, dim=axis)[0])
|
| 285 |
+
sync_if_needed(x)
|
| 286 |
+
|
| 287 |
+
|
| 288 |
+
@torch.no_grad()
|
| 289 |
+
def topk(axis, x):
|
| 290 |
+
k = x.shape[axis] // 3
|
| 291 |
+
ys = []
|
| 292 |
+
for i in range(10):
|
| 293 |
+
ys.append(torch.topk(x, k, dim=axis)[0])
|
| 294 |
+
sync_if_needed(x)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
@torch.no_grad()
|
| 298 |
+
def step_function(x):
|
| 299 |
+
y = x
|
| 300 |
+
for i in range(100):
|
| 301 |
+
y = torch.where(y < 0, 0, 1)
|
| 302 |
+
sync_if_needed(x)
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
@torch.no_grad()
|
| 306 |
+
def selu(x):
|
| 307 |
+
y = x
|
| 308 |
+
for i in range(100):
|
| 309 |
+
y = torch.nn.functional.selu(y)
|
| 310 |
+
sync_if_needed(x)
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
if __name__ == "__main__":
|
| 314 |
+
parser = argparse.ArgumentParser()
|
| 315 |
+
parser.add_argument("benchmark", help="Choose the benchmark to run")
|
| 316 |
+
parser.add_argument(
|
| 317 |
+
"--size",
|
| 318 |
+
default=[(1024, 1024)],
|
| 319 |
+
type=lambda x: list(map(int, x.split("x"))),
|
| 320 |
+
help="Set the matrix size",
|
| 321 |
+
action="append",
|
| 322 |
+
)
|
| 323 |
+
parser.add_argument(
|
| 324 |
+
"--axis",
|
| 325 |
+
default=[1],
|
| 326 |
+
type=int_or_list,
|
| 327 |
+
help="Set a reduction axis",
|
| 328 |
+
action="append",
|
| 329 |
+
)
|
| 330 |
+
parser.add_argument(
|
| 331 |
+
"--transpose",
|
| 332 |
+
type=none_or_list,
|
| 333 |
+
default=[],
|
| 334 |
+
help="Permute the matrix",
|
| 335 |
+
action="append",
|
| 336 |
+
)
|
| 337 |
+
parser.add_argument(
|
| 338 |
+
"--print-pid", action="store_true", help="Print the PID and pause"
|
| 339 |
+
)
|
| 340 |
+
parser.add_argument("--cpu", action="store_true", help="Use the CPU")
|
| 341 |
+
parser.add_argument(
|
| 342 |
+
"--fused", action="store_true", help="Use fused functions where possible"
|
| 343 |
+
)
|
| 344 |
+
parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")
|
| 345 |
+
|
| 346 |
+
args = parser.parse_args()
|
| 347 |
+
|
| 348 |
+
if len(args.size) > 1:
|
| 349 |
+
args.size.pop(0)
|
| 350 |
+
if len(args.axis) > 1:
|
| 351 |
+
args.axis.pop(0)
|
| 352 |
+
|
| 353 |
+
torch.set_num_threads(1)
|
| 354 |
+
device = "mps"
|
| 355 |
+
if torch.cuda.is_available():
|
| 356 |
+
device = "cuda"
|
| 357 |
+
if args.cpu:
|
| 358 |
+
device = "cpu"
|
| 359 |
+
|
| 360 |
+
types = args.dtype
|
| 361 |
+
if not types:
|
| 362 |
+
types = [torch.float32]
|
| 363 |
+
if len(types) < len(args.size):
|
| 364 |
+
types = types + [types[0]] * (len(args.size) - len(types))
|
| 365 |
+
|
| 366 |
+
xs = []
|
| 367 |
+
for size, dtype in zip(args.size, types):
|
| 368 |
+
xs.append(torch.randn(*size).to(device).to(dtype))
|
| 369 |
+
for i, t in enumerate(args.transpose):
|
| 370 |
+
if t is None:
|
| 371 |
+
continue
|
| 372 |
+
xs[i] = xs[i].permute(*t)
|
| 373 |
+
x = xs[0]
|
| 374 |
+
axis = args.axis[0]
|
| 375 |
+
|
| 376 |
+
if args.print_pid:
|
| 377 |
+
print(os.getpid())
|
| 378 |
+
input("Press enter to run")
|
| 379 |
+
|
| 380 |
+
if args.benchmark == "matmul_square":
|
| 381 |
+
print(bench(matmul_square, x))
|
| 382 |
+
|
| 383 |
+
elif args.benchmark == "matmul":
|
| 384 |
+
print(bench(matmul, *xs))
|
| 385 |
+
|
| 386 |
+
elif args.benchmark == "linear":
|
| 387 |
+
if args.fused:
|
| 388 |
+
print(bench(linear_fused, *xs))
|
| 389 |
+
else:
|
| 390 |
+
print(bench(linear, *xs))
|
| 391 |
+
|
| 392 |
+
elif args.benchmark == "sum_axis":
|
| 393 |
+
print(bench(reduction, "sum", axis, x))
|
| 394 |
+
|
| 395 |
+
elif args.benchmark == "sum_all":
|
| 396 |
+
print(bench(reduction, "sum", None, x))
|
| 397 |
+
|
| 398 |
+
elif args.benchmark == "argmax":
|
| 399 |
+
print(bench(reduction, "argmax", axis, x))
|
| 400 |
+
|
| 401 |
+
elif args.benchmark == "add":
|
| 402 |
+
print(bench(binary, "add", *xs))
|
| 403 |
+
|
| 404 |
+
elif args.benchmark == "mul":
|
| 405 |
+
print(bench(binary, "mul", *xs))
|
| 406 |
+
|
| 407 |
+
elif args.benchmark == "softmax":
|
| 408 |
+
if args.fused:
|
| 409 |
+
print(bench(softmax_fused, axis, x))
|
| 410 |
+
else:
|
| 411 |
+
print(bench(softmax, axis, x))
|
| 412 |
+
|
| 413 |
+
elif args.benchmark == "relu":
|
| 414 |
+
print(bench(relu, x))
|
| 415 |
+
|
| 416 |
+
elif args.benchmark == "leaky_relu":
|
| 417 |
+
print(bench(leaky_relu, x))
|
| 418 |
+
|
| 419 |
+
elif args.benchmark == "elu":
|
| 420 |
+
print(bench(elu, x))
|
| 421 |
+
|
| 422 |
+
elif args.benchmark == "relu6":
|
| 423 |
+
print(bench(relu6, x))
|
| 424 |
+
|
| 425 |
+
elif args.benchmark == "softplus":
|
| 426 |
+
print(bench(softplus, x))
|
| 427 |
+
|
| 428 |
+
elif args.benchmark == "celu":
|
| 429 |
+
print(bench(celu, x))
|
| 430 |
+
|
| 431 |
+
elif args.benchmark == "log_sigmoid":
|
| 432 |
+
print(bench(log_sigmoid, x))
|
| 433 |
+
|
| 434 |
+
elif args.benchmark == "prelu":
|
| 435 |
+
print(bench(prelu, x))
|
| 436 |
+
elif args.benchmark == "mish":
|
| 437 |
+
print(bench(mish, x))
|
| 438 |
+
elif args.benchmark == "scalar_mul":
|
| 439 |
+
print(bench(scalar_mult, x))
|
| 440 |
+
|
| 441 |
+
elif args.benchmark == "cross_entropy":
|
| 442 |
+
if len(size) != 2:
|
| 443 |
+
raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")
|
| 444 |
+
|
| 445 |
+
targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
|
| 446 |
+
print(bench(cross_entropy, targets, x))
|
| 447 |
+
|
| 448 |
+
elif args.benchmark == "logsumexp":
|
| 449 |
+
print(bench(logsumexp, axis, x))
|
| 450 |
+
|
| 451 |
+
elif args.benchmark == "rope":
|
| 452 |
+
print(bench(rope, x))
|
| 453 |
+
|
| 454 |
+
elif args.benchmark == "concatenate":
|
| 455 |
+
print(bench(concatenate, axis, *xs))
|
| 456 |
+
|
| 457 |
+
elif args.benchmark == "cumsum":
|
| 458 |
+
print(bench(cumsum, axis, *xs))
|
| 459 |
+
|
| 460 |
+
elif args.benchmark == "conv1d":
|
| 461 |
+
print(bench(conv1d, *xs))
|
| 462 |
+
|
| 463 |
+
elif args.benchmark == "conv2d":
|
| 464 |
+
print(bench(conv2d, *xs))
|
| 465 |
+
|
| 466 |
+
elif args.benchmark == "sort":
|
| 467 |
+
print(bench(sort, axis, x))
|
| 468 |
+
|
| 469 |
+
elif args.benchmark == "topk":
|
| 470 |
+
print(bench(topk, axis, x))
|
| 471 |
+
|
| 472 |
+
elif args.benchmark == "step":
|
| 473 |
+
print(bench(step_function, x))
|
| 474 |
+
|
| 475 |
+
elif args.benchmark == "selu":
|
| 476 |
+
print(bench(selu, x))
|
| 477 |
+
|
| 478 |
+
elif args.benchmark == "sum_and_add":
|
| 479 |
+
print(bench(sum_and_add, axis, *xs))
|
| 480 |
+
|
| 481 |
+
else:
|
| 482 |
+
raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
|
ml-stable-diffusion/mlx/benchmarks/python/comparative/compare.py
ADDED
|
@@ -0,0 +1,284 @@
| 1 |
+
# Copyright © 2023 Apple Inc.
|
| 2 |
+
|
| 3 |
+
#!/usr/bin/env python
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
import re
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from subprocess import run
|
| 9 |
+
|
| 10 |
+
BENCH_MLX = Path(__file__).parent / "bench_mlx.py"
|
| 11 |
+
BENCH_TORCH = Path(__file__).parent / "bench_torch.py"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def run_or_raise(*args, **kwargs):
|
| 15 |
+
try:
|
| 16 |
+
result = run(*args, capture_output=True, **kwargs)
|
| 17 |
+
return float(result.stdout)
|
| 18 |
+
except ValueError:
|
| 19 |
+
raise ValueError(
|
| 20 |
+
f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def compare(args):
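    # Prints the fractional time saved by MLX relative to PyTorch; positive means MLX is faster.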
|
| 25 |
+
t_mlx = run_or_raise(["python", BENCH_MLX] + args)
|
| 26 |
+
t_torch = run_or_raise(["python", BENCH_TORCH] + args)
|
| 27 |
+
|
| 28 |
+
print((t_torch - t_mlx) / t_torch, " ".join(args), sep="\t")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def compare_mlx_dtypes(args, dt1, dt2):
|
| 32 |
+
t_mlx_dt1 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt1])
|
| 33 |
+
t_mlx_dt2 = run_or_raise(["python", BENCH_MLX] + args + ["--dtype", dt2])
|
| 34 |
+
|
| 35 |
+
print((t_mlx_dt2 - t_mlx_dt1) / t_mlx_dt2, " ".join(args), sep="\t")
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def make_regex_search(regexes):
|
| 39 |
+
compiled_regexes = list(map(re.compile, regexes))
|
| 40 |
+
|
| 41 |
+
def search(x):
|
| 42 |
+
return (c.search(x) is not None for c in compiled_regexes)
|
| 43 |
+
|
| 44 |
+
return search
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def make_predicate(positive_filter, negative_filter):
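    # Keep a benchmark only if it matches every positive regex and none of the negative ones.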
|
| 48 |
+
if positive_filter is not None:
|
| 49 |
+
positive_filter_search = make_regex_search(positive_filter)
|
| 50 |
+
positive_filter = lambda x: all(positive_filter_search(x))
|
| 51 |
+
else:
|
| 52 |
+
positive_filter = lambda x: True
|
| 53 |
+
|
| 54 |
+
if negative_filter is not None:
|
| 55 |
+
negative_filter_search = make_regex_search(negative_filter)
|
| 56 |
+
negative_filter = lambda x: not any(negative_filter_search(x))
|
| 57 |
+
else:
|
| 58 |
+
negative_filter = lambda x: True
|
| 59 |
+
|
| 60 |
+
def predicate(x):
|
| 61 |
+
return positive_filter(x) and negative_filter(x)
|
| 62 |
+
|
| 63 |
+
return predicate
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
if __name__ == "__main__":
|
| 67 |
+
parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
|
| 68 |
+
parser.add_argument(
|
| 69 |
+
"--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
|
| 70 |
+
)
|
| 71 |
+
parser.add_argument(
|
| 72 |
+
"--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
|
| 73 |
+
)
|
| 74 |
+
parser.add_argument(
|
| 75 |
+
"--mlx_dtypes",
|
| 76 |
+
"-d",
|
| 77 |
+
help="Compare mlx benchmarks between the 2 provided data types",
|
| 78 |
+
nargs=2,
|
| 79 |
+
)
|
| 80 |
+
args, rest = parser.parse_known_args()
|
| 81 |
+
|
| 82 |
+
_filter = make_predicate(args.filter, args.negative_filter)
|
| 83 |
+
|
| 84 |
+
if args.mlx_dtypes:
|
| 85 |
+
compare_filtered = lambda x: (
|
| 86 |
+
compare_mlx_dtypes(x.split() + rest, args.mlx_dtypes[0], args.mlx_dtypes[1])
|
| 87 |
+
if _filter(x)
|
| 88 |
+
else None
|
| 89 |
+
)
|
| 90 |
+
else:
|
| 91 |
+
compare_filtered = lambda x: compare(x.split() + rest) if _filter(x) else None
|
| 92 |
+
|
| 93 |
+
# Binary ops
|
| 94 |
+
compare_filtered("add --size 10x1024x128 --size 1x1024x128 --cpu")
|
| 95 |
+
compare_filtered("add --size 10x1024x128 --size 1x1024x128")
|
| 96 |
+
compare_filtered("add --size 1024x128 --size 1x128 --cpu")
|
| 97 |
+
compare_filtered("add --size 1024x128 --size 1x128")
|
| 98 |
+
compare_filtered("add --size 1024x4096 --size 1x4096 --cpu")
|
| 99 |
+
compare_filtered("add --size 1024x4096 --size 1x4096")
|
| 100 |
+
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu")
|
| 101 |
+
compare_filtered("add --size 1024x4096 --size 1x1024 --transpose 1,0")
|
| 102 |
+
compare_filtered("add --size 1024x1024 --size 1024x1024 --cpu")
|
| 103 |
+
compare_filtered("add --size 1024x1024 --size 1024x1024")
|
| 104 |
+
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu")
|
| 105 |
+
compare_filtered("add --size 1024x1024 --size 1024x1024 --transpose 1,0")
|
| 106 |
+
compare_filtered(
|
| 107 |
+
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu"
|
| 108 |
+
)
|
| 109 |
+
compare_filtered(
|
| 110 |
+
"add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0"
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# Reduction ops
|
| 114 |
+
compare_filtered("sum_all --size 10x1024x128 --cpu")
|
| 115 |
+
compare_filtered("sum_all --size 10x1024x128")
|
| 116 |
+
compare_filtered("sum_axis --size 16x1024x128 --axis 2 --cpu")
|
| 117 |
+
compare_filtered("sum_axis --size 16x1024x128 --axis 2")
|
| 118 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 2 --cpu")
|
| 119 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 2")
|
| 120 |
+
compare_filtered("sum_axis --size 1024x1024 --axis 1 --cpu")
|
| 121 |
+
compare_filtered("sum_axis --size 1024x1024 --axis 1")
|
| 122 |
+
compare_filtered("sum_axis --size 1024x1024 --axis 0 --cpu")
|
| 123 |
+
compare_filtered("sum_axis --size 1024x1024 --axis 0")
|
| 124 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 1 --cpu")
|
| 125 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 1")
|
| 126 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0 --cpu")
|
| 127 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0")
|
| 128 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --cpu")
|
| 129 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1")
|
| 130 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --cpu")
|
| 131 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2")
|
| 132 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu")
|
| 133 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1")
|
| 134 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu")
|
| 135 |
+
compare_filtered("sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1")
|
| 136 |
+
compare_filtered("argmax --size 10x1024x128 --axis 1 --cpu")
|
| 137 |
+
compare_filtered("argmax --size 10x1024x128 --axis 1")
|
| 138 |
+
compare_filtered("argmax --size 10x1024x128 --axis 2 --cpu")
|
| 139 |
+
compare_filtered("argmax --size 10x1024x128 --axis 2")
|
| 140 |
+
compare_filtered("argmax --size 1024x1024 --axis 1 --cpu")
|
| 141 |
+
compare_filtered("argmax --size 1024x1024 --axis 1")
|
| 142 |
+
|
| 143 |
+
# Matmul ops
|
| 144 |
+
compare_filtered("matmul_square --size 1024x1024")
|
| 145 |
+
compare_filtered("matmul_square --size 1024x1024 --cpu")
|
| 146 |
+
compare_filtered("matmul_square --size 16x1024x1024")
|
| 147 |
+
compare_filtered("matmul_square --size 16x1024x1024 --cpu")
|
| 148 |
+
compare_filtered(
|
| 149 |
+
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1"
|
| 150 |
+
)
|
| 151 |
+
compare_filtered(
|
| 152 |
+
"matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu"
|
| 153 |
+
)
|
| 154 |
+
compare_filtered(
|
| 155 |
+
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1"
|
| 156 |
+
)
|
| 157 |
+
compare_filtered(
|
| 158 |
+
"matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu"
|
| 159 |
+
)
|
| 160 |
+
compare_filtered("matmul --size 512x8192 --size 8192x512")
|
| 161 |
+
compare_filtered("matmul --size 512x8192 --size 8192x512 --cpu")
|
| 162 |
+
# compare_filtered("matmul --size 512x131072 --size 131072x512")
|
| 163 |
+
# compare_filtered("matmul --size 512x131072 --size 131072x512 --cpu")
|
| 164 |
+
compare_filtered("matmul --size 8192x512 --size 512x8192")
|
| 165 |
+
compare_filtered("matmul --size 8192x512 --size 512x8192 --cpu")
|
| 166 |
+
# compare_filtered("matmul --size 131072x512 --size 512x512")
|
| 167 |
+
# compare_filtered("matmul --size 131072x512 --size 512x512 --cpu")
|
| 168 |
+
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024")
|
| 169 |
+
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --cpu")
|
| 170 |
+
compare_filtered("linear --size 1024x1024 --size 1024 --size 128x1024 --fused")
|
| 171 |
+
compare_filtered(
|
| 172 |
+
"linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
# Matvec ops
|
| 176 |
+
compare_filtered("matmul --size 1x1x4096 --size 4096x4096 --cpu")
|
| 177 |
+
compare_filtered("matmul --size 1x1x4096 --size 4096x4096")
|
| 178 |
+
compare_filtered(
|
| 179 |
+
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu"
|
| 180 |
+
)
|
| 181 |
+
compare_filtered(
|
| 182 |
+
"matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0"
|
| 183 |
+
)
|
| 184 |
+
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128 --cpu")
|
| 185 |
+
compare_filtered("matmul --size 32x1x1000 --size 32x1000x128")
|
| 186 |
+
compare_filtered(
|
| 187 |
+
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu"
|
| 188 |
+
)
|
| 189 |
+
compare_filtered(
|
| 190 |
+
"matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1"
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# Various ops
|
| 194 |
+
compare_filtered("softmax --size 32x16x1024 --axis 2")
|
| 195 |
+
compare_filtered("softmax --size 32x16x1024 --axis 2 --cpu")
|
| 196 |
+
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused")
|
| 197 |
+
compare_filtered("softmax --size 32x16x1024 --axis 2 --fused --cpu")
|
| 198 |
+
compare_filtered("softmax --size 2x1024x1024 --axis 1")
|
| 199 |
+
compare_filtered("softmax --size 2x1024x1024 --axis 1 --cpu")
|
| 200 |
+
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused")
|
| 201 |
+
compare_filtered("softmax --size 2x1024x1024 --axis 1 --fused --cpu")
|
| 202 |
+
compare_filtered("relu --size 32x16x1024")
|
| 203 |
+
compare_filtered("relu --size 32x16x1024 --cpu")
|
| 204 |
+
compare_filtered("leaky_relu --size 32x16x1024")
|
| 205 |
+
compare_filtered("leaky_relu --size 32x16x1024 --cpu")
|
| 206 |
+
compare_filtered("elu --size 32x16x1024")
|
| 207 |
+
compare_filtered("elu --size 32x16x1024 --cpu")
|
| 208 |
+
compare_filtered("relu6 --size 32x16x1024")
|
| 209 |
+
compare_filtered("relu6 --size 32x16x1024 --cpu")
|
| 210 |
+
compare_filtered("softplus --size 32x16x1024")
|
| 211 |
+
compare_filtered("softplus --size 32x16x1024 --cpu")
|
| 212 |
+
compare_filtered("celu --size 32x16x1024")
|
| 213 |
+
compare_filtered("celu --size 32x16x1024 --cpu")
|
| 214 |
+
compare_filtered("log_sigmoid --size 32x16x1024")
|
| 215 |
+
compare_filtered("log_sigmoid --size 32x16x1024 --cpu")
|
| 216 |
+
compare_filtered("step --size 32x16x1024")
|
| 217 |
+
compare_filtered("step --size 32x16x1024 --cpu")
|
| 218 |
+
compare_filtered("selu --size 32x16x1024")
|
| 219 |
+
compare_filtered("selu --size 32x16x1024 --cpu")
|
| 220 |
+
# compare_filtered("mish --size 32x16x1024") NOTE: Torch does not implement Mish in MPS atm
|
| 221 |
+
compare_filtered("mish --size 32x16x1024 --cpu")
|
| 222 |
+
compare_filtered("prelu --size 32x16x1024")
|
| 223 |
+
compare_filtered("prelu --size 32x16x1024 --cpu")
|
| 224 |
+
|
| 225 |
+
compare_filtered("scalar_mul --size 32x16x1024")
|
| 226 |
+
compare_filtered("scalar_mul --size 32x16x1024 --cpu")
|
| 227 |
+
compare_filtered("cross_entropy --size 256x1024")
|
| 228 |
+
compare_filtered("cross_entropy --size 256x1024 --cpu")
|
| 229 |
+
compare_filtered("logsumexp --size 1024x1024 --axis 1")
|
| 230 |
+
compare_filtered("logsumexp --size 1024x1024 --axis 1 --cpu")
|
| 231 |
+
compare_filtered("logsumexp --size 1024x1024 --axis 0")
|
| 232 |
+
compare_filtered("logsumexp --size 1024x1024 --axis 0 --cpu")
|
| 233 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2")
|
| 234 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu")
|
| 235 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1")
|
| 236 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu")
|
| 237 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0")
|
| 238 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu")
|
| 239 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1")
|
| 240 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu")
|
| 241 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1")
|
| 242 |
+
compare_filtered("concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu")
|
| 243 |
+
compare_filtered("concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2")
|
| 244 |
+
compare_filtered(
|
| 245 |
+
"concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu"
|
| 246 |
+
)
|
| 247 |
+
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80")
|
| 248 |
+
compare_filtered("conv1d --size 1x1000x80 --size 128x11x80 --cpu")
|
| 249 |
+
compare_filtered("conv1d --size 16x1000x80 --size 128x11x80")
|
| 250 |
+
compare_filtered("conv1d --size 4x1000x80 --size 128x11x80 --cpu")
|
| 251 |
+
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3")
|
| 252 |
+
compare_filtered("conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu")
|
| 253 |
+
compare_filtered("conv2d --size 16x256x256x3 --size 8x3x3x3")
|
| 254 |
+
compare_filtered("conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu")
|
| 255 |
+
compare_filtered("cumsum --size 1024x1024 --axis 1 --cpu")
|
| 256 |
+
compare_filtered("cumsum --size 1024x1024 --axis 0 --cpu")
|
| 257 |
+
compare_filtered("cumsum --size 1024x1024 --axis 1")
|
| 258 |
+
compare_filtered("cumsum --size 1024x1024 --axis 0")
|
| 259 |
+
compare_filtered("cumsum --size 128x1024 --axis 1")
|
| 260 |
+
compare_filtered("cumsum --size 128x1024 --axis 0")
|
| 261 |
+
compare_filtered("cumsum --size 1024x4096 --axis 1")
|
| 262 |
+
compare_filtered("cumsum --size 1024x4096 --axis 0")
|
| 263 |
+
compare_filtered("cumsum --size 128x4096 --axis 1")
|
| 264 |
+
compare_filtered("cumsum --size 128x4096 --axis 0")
|
| 265 |
+
compare_filtered("cumsum --size 1024x7777 --axis 1")
|
| 266 |
+
compare_filtered("cumsum --size 1024x7777 --axis 0")
|
| 267 |
+
compare_filtered("cumsum --size 128x7777 --axis 1")
|
| 268 |
+
compare_filtered("cumsum --size 128x7777 --axis 0")
|
| 269 |
+
compare_filtered("cumsum --size 32768x128 --axis 1")
|
| 270 |
+
compare_filtered("cumsum --size 32768x128 --axis 0")
|
| 271 |
+
|
| 272 |
+
compare_filtered("sort --size 1024x1024 --axis 0")
|
| 273 |
+
compare_filtered("sort --size 1024x1024 --axis 1")
|
| 274 |
+
compare_filtered("sort --size 32768x128 --axis 0")
|
| 275 |
+
compare_filtered("sort --size 32768x128 --axis 1")
|
| 276 |
+
compare_filtered("sort --size 128x128 --axis 0 --cpu")
|
| 277 |
+
compare_filtered("sort --size 128x128 --axis 1 --cpu")
|
| 278 |
+
|
| 279 |
+
compare_filtered("topk --size 1024x1024 --axis 0")
|
| 280 |
+
compare_filtered("topk --size 1024x1024 --axis 1")
|
| 281 |
+
compare_filtered("topk --size 32768x128 --axis 0")
|
| 282 |
+
compare_filtered("topk --size 32768x128 --axis 1")
|
| 283 |
+
compare_filtered("topk --size 128x128 --axis 0 --cpu")
|
| 284 |
+
compare_filtered("topk --size 128x128 --axis 1 --cpu")
|
ml-stable-diffusion/mlx/benchmarks/python/compile_bench.py
ADDED
|
@@ -0,0 +1,107 @@
| 1 |
+
# Copyright © 2023-2024 Apple Inc.
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
|
| 7 |
+
import mlx.core as mx
|
| 8 |
+
from time_utils import time_fn
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def bench_gelu():
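    # Compare eager, compiled, and shapeless-compiled GELU on fixed- and variable-length inputs.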
|
| 12 |
+
def gelu(x):
|
| 13 |
+
return x * (1 + mx.erf(x / math.sqrt(2))) / 2
|
| 14 |
+
|
| 15 |
+
x = mx.random.uniform(shape=(1000, 1024))
|
| 16 |
+
|
| 17 |
+
def gen_fun(fun):
|
| 18 |
+
def bench_fun(x):
|
| 19 |
+
for _ in range(10):
|
| 20 |
+
x = fun(x)
|
| 21 |
+
return x
|
| 22 |
+
|
| 23 |
+
return bench_fun
|
| 24 |
+
|
| 25 |
+
time_fn(gen_fun(gelu), x, msg="fixed gelu")
|
| 26 |
+
time_fn(gen_fun(mx.compile(gelu)), x, msg="compiled fixed gelu")
|
| 27 |
+
|
| 28 |
+
def randint():
|
| 29 |
+
return random.randint(1, x.shape[0])
|
| 30 |
+
|
| 31 |
+
def gen_fun(fun):
|
| 32 |
+
def bench_fun(x, y):
|
| 33 |
+
x = x[: randint()]
|
| 34 |
+
for _ in range(10):
|
| 35 |
+
x = fun(x)
|
| 36 |
+
y = fun(y)
|
| 37 |
+
return x, y
|
| 38 |
+
|
| 39 |
+
return bench_fun
|
| 40 |
+
|
| 41 |
+
y = mx.random.uniform(shape=(1000, 1024))
|
| 42 |
+
time_fn(gen_fun(gelu), x, y, msg="variable gelu")
|
| 43 |
+
time_fn(gen_fun(mx.compile(gelu)), x, y, msg="compiled variable gelu")
|
| 44 |
+
time_fn(
|
| 45 |
+
gen_fun(mx.compile(gelu, shapeless=True)),
|
| 46 |
+
x,
|
| 47 |
+
y,
|
| 48 |
+
msg="shapeless variable gelu",
|
| 49 |
+
)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def bench_layernorm():
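    # Same comparison for a hand-written float16 layernorm (mean/variance computed in float32).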
|
| 53 |
+
weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
| 54 |
+
bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
|
| 55 |
+
mx.eval(weight, bias)
|
| 56 |
+
|
| 57 |
+
def layernorm(x):
|
| 58 |
+
x = x.astype(mx.float32)
|
| 59 |
+
means = mx.mean(x, axis=-1, keepdims=True)
|
| 60 |
+
var = mx.var(x, axis=-1, keepdims=True)
|
| 61 |
+
x = (x - means) * mx.rsqrt(var + 1e-4)
|
| 62 |
+
x = x.astype(mx.float16)
|
| 63 |
+
return weight * x + bias
|
| 64 |
+
|
| 65 |
+
x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)
|
| 66 |
+
|
| 67 |
+
def gen_fun(fun):
|
| 68 |
+
def bench_fun(x):
|
| 69 |
+
for _ in range(10):
|
| 70 |
+
x = fun(x)
|
| 71 |
+
return x
|
| 72 |
+
|
| 73 |
+
return bench_fun
|
| 74 |
+
|
| 75 |
+
time_fn(gen_fun(layernorm), x, msg="fixed layernorm")
|
| 76 |
+
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled fixed layernorm")
|
| 77 |
+
|
| 78 |
+
def randint():
|
| 79 |
+
return random.randint(1, x.shape[0])
|
| 80 |
+
|
| 81 |
+
def gen_fun(fun):
|
| 82 |
+
def bench_fun(x):
|
| 83 |
+
x = x[: randint()]
|
| 84 |
+
for _ in range(10):
|
| 85 |
+
x = fun(x)
|
| 86 |
+
return x
|
| 87 |
+
|
| 88 |
+
return bench_fun
|
| 89 |
+
|
| 90 |
+
random.seed(0)
|
| 91 |
+
time_fn(gen_fun(layernorm), x, msg="variable layernorm")
|
| 92 |
+
random.seed(0)
|
| 93 |
+
time_fn(gen_fun(mx.compile(layernorm)), x, msg="compiled variable layernorm")
|
| 94 |
+
random.seed(0)
|
| 95 |
+
time_fn(
|
| 96 |
+
gen_fun(mx.compile(layernorm, shapeless=True)),
|
| 97 |
+
x,
|
| 98 |
+
msg="shapeless variable layernorm",
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
parser = argparse.ArgumentParser("Compile benchmarks.")
|
| 104 |
+
args = parser.parse_args()
|
| 105 |
+
|
| 106 |
+
bench_gelu()
|
| 107 |
+
bench_layernorm()
|
ml-stable-diffusion/mlx/benchmarks/python/conv1d_bench.py
ADDED
|
@@ -0,0 +1,123 @@
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import os
|
| 4 |
+
import subprocess
|
| 5 |
+
import time
|
| 6 |
+
|
| 7 |
+
import mlx.core as mx
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
|
| 12 |
+
device_name = device_name.decode("utf-8").strip("\n")
|
| 13 |
+
|
| 14 |
+
N_warmup = 10
|
| 15 |
+
N_iter_bench = 100
|
| 16 |
+
N_iter_func = 5
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def bench(f, a, b):
|
| 20 |
+
for i in range(N_warmup):
|
| 21 |
+
f(a, b)
|
| 22 |
+
torch.mps.synchronize()
|
| 23 |
+
|
| 24 |
+
s = time.perf_counter_ns()
|
| 25 |
+
for i in range(N_iter_bench):
|
| 26 |
+
f(a, b)
|
| 27 |
+
e = time.perf_counter_ns()
|
| 28 |
+
return (e - s) * 1e-9
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def make_mx_conv_1D(strides=1, padding=0, groups=1):
|
| 32 |
+
def mx_conv_1D(a, b):
|
| 33 |
+
ys = []
|
| 34 |
+
for _ in range(N_iter_func):
|
| 35 |
+
y = mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
| 36 |
+
ys.append(y)
|
| 37 |
+
mx.eval(ys)
|
| 38 |
+
return ys
|
| 39 |
+
|
| 40 |
+
return mx_conv_1D
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def make_pt_conv_1D(strides=1, padding=0, groups=1):
|
| 44 |
+
@torch.no_grad()
|
| 45 |
+
def pt_conv_1D(a, b):
|
| 46 |
+
ys = []
|
| 47 |
+
for _ in range(N_iter_func):
|
| 48 |
+
y = torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
|
| 49 |
+
ys.append(y)
|
| 50 |
+
torch.mps.synchronize()
|
| 51 |
+
return ys
|
| 52 |
+
|
| 53 |
+
return pt_conv_1D
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
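    # Time MLX vs PyTorch (MPS) for one conv1d shape, then cross-check outputs on CPU within atol.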
|
| 57 |
+
scale = 1.0 / math.sqrt(wH * C)
|
| 58 |
+
a_np = np.random.uniform(0, 0.5, (N, iH, C)).astype(np_dtype)
|
| 59 |
+
b_np = np.random.uniform(-scale, scale, (O, wH, int(C / groups))).astype(np_dtype)
|
| 60 |
+
|
| 61 |
+
a_mx = mx.array(a_np)
|
| 62 |
+
b_mx = mx.array(b_np)
|
| 63 |
+
|
| 64 |
+
a_pt = torch.from_numpy(a_np.transpose((0, 2, 1))).to("mps")
|
| 65 |
+
b_pt = torch.from_numpy(b_np.transpose((0, 2, 1))).to("mps")
|
| 66 |
+
|
| 67 |
+
torch.mps.synchronize()
|
| 68 |
+
|
| 69 |
+
f_mx = make_mx_conv_1D(strides, padding, groups)
|
| 70 |
+
f_pt = make_pt_conv_1D(strides, padding, groups)
|
| 71 |
+
|
| 72 |
+
time_torch = bench(f_pt, a_pt, b_pt)
|
| 73 |
+
time_mlx = bench(f_mx, a_mx, b_mx)
|
| 74 |
+
|
| 75 |
+
out_mx = mx.conv1d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
| 76 |
+
out_pt = torch.conv1d(
|
| 77 |
+
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
| 78 |
+
)
|
| 79 |
+
out_pt = torch.permute(out_pt, (0, 2, 1))
|
| 80 |
+
out_pt = out_pt.numpy(force=True)
|
| 81 |
+
|
| 82 |
+
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
| 83 |
+
|
| 84 |
+
if not np.allclose(out_pt, out_mx, atol=atol):
|
| 85 |
+
print(
|
| 86 |
+
f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
| 87 |
+
)
|
| 88 |
+
|
| 89 |
+
return time_mlx, time_torch
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
if __name__ == "__main__":
|
| 93 |
+
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
| 94 |
+
|
| 95 |
+
dtypes = ("float32",)
|
| 96 |
+
shapes = (
|
| 97 |
+
(4, 32, 32, 5, 32, 1, 2, 1),
|
| 98 |
+
(4, 32, 32, 5, 32, 1, 2, 2),
|
| 99 |
+
(4, 32, 32, 5, 32, 1, 2, 4),
|
| 100 |
+
(4, 32, 32, 5, 32, 1, 2, 8),
|
| 101 |
+
(4, 32, 32, 5, 32, 1, 2, 8),
|
| 102 |
+
(4, 32, 32, 5, 32, 1, 2, 16),
|
| 103 |
+
(4, 32, 32, 5, 32, 1, 2, 32),
|
| 104 |
+
(4, 32, 256, 5, 512, 1, 2, 2),
|
| 105 |
+
(4, 32, 256, 5, 512, 1, 2, 128),
|
| 106 |
+
(4, 32, 256, 5, 512, 1, 2, 256),
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
for dtype in dtypes:
|
| 110 |
+
print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
|
| 111 |
+
for N, iH, C, wH, O, strides, padding, groups in shapes:
|
| 112 |
+
np_dtype = getattr(np, dtype)
|
| 113 |
+
time_mlx, time_torch = bench_shape(
|
| 114 |
+
N, iH, C, wH, O, strides, padding, np_dtype, groups
|
| 115 |
+
)
|
| 116 |
+
diff = time_torch / time_mlx - 1.0
|
| 117 |
+
|
| 118 |
+
print(
|
| 119 |
+
f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
if time_mlx >= 2.0 * time_torch:
|
| 123 |
+
print("ATTENTION ^^^^^^^")
|
ml-stable-diffusion/mlx/benchmarks/python/conv2d_bench_cpu.py
ADDED
|
@@ -0,0 +1,127 @@
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import mlx.core as mx
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
N_warmup = 1
|
| 10 |
+
N_iter_bench = 10
|
| 11 |
+
N_iter_func = 5
|
| 12 |
+
mx.set_default_device(mx.cpu)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def bench(f, a, b):
|
| 16 |
+
for i in range(N_warmup):
|
| 17 |
+
f(a, b)
|
| 18 |
+
|
| 19 |
+
s = time.perf_counter_ns()
|
| 20 |
+
for i in range(N_iter_bench):
|
| 21 |
+
f(a, b)
|
| 22 |
+
e = time.perf_counter_ns()
|
| 23 |
+
return (e - s) * 1e-9
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
| 27 |
+
def mx_conv_2D(a, b):
|
| 28 |
+
ys = []
|
| 29 |
+
for i in range(N_iter_func):
|
| 30 |
+
y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
| 31 |
+
ys.append(y)
|
| 32 |
+
mx.eval(ys)
|
| 33 |
+
return ys
|
| 34 |
+
|
| 35 |
+
return mx_conv_2D
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
| 39 |
+
@torch.no_grad()
|
| 40 |
+
def pt_conv_2D(a, b):
|
| 41 |
+
ys = []
|
| 42 |
+
for i in range(N_iter_func):
|
| 43 |
+
y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
|
| 44 |
+
ys.append(y)
|
| 45 |
+
return ys
|
| 46 |
+
|
| 47 |
+
return pt_conv_2D
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
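    # CPU-only comparison: the MLX default device is set to cpu above, and the PyTorch tensors stay on cpu.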
|
| 51 |
+
scale = 1.0 / math.sqrt(kH * kH * C)
|
| 52 |
+
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
| 53 |
+
b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
|
| 54 |
+
np_dtype
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
a_mx = mx.array(a_np)
|
| 58 |
+
b_mx = mx.array(b_np)
|
| 59 |
+
|
| 60 |
+
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
| 61 |
+
b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")
|
| 62 |
+
|
| 63 |
+
f_mx = make_mx_conv_2D(strides, padding, groups)
|
| 64 |
+
f_pt = make_pt_conv_2D(strides, padding, groups)
|
| 65 |
+
|
| 66 |
+
time_torch = bench(f_pt, a_pt, b_pt)
|
| 67 |
+
time_mlx = bench(f_mx, a_mx, b_mx)
|
| 68 |
+
|
| 69 |
+
out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
|
| 70 |
+
out_pt = torch.conv2d(
|
| 71 |
+
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
| 72 |
+
)
|
| 73 |
+
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
| 74 |
+
out_pt = out_pt.numpy(force=True)
|
| 75 |
+
|
| 76 |
+
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
| 77 |
+
|
| 78 |
+
if not np.allclose(out_pt, out_mx, atol=atol):
|
| 79 |
+
print(
|
| 80 |
+
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
| 81 |
+
)
|
| 82 |
+
|
| 83 |
+
return time_mlx, time_torch
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
if __name__ == "__main__":
|
| 87 |
+
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
| 88 |
+
|
| 89 |
+
dtypes = ("float32",)
|
| 90 |
+
shapes = (
|
| 91 |
+
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 92 |
+
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 93 |
+
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 94 |
+
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
| 95 |
+
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
| 96 |
+
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 97 |
+
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 98 |
+
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 99 |
+
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
| 100 |
+
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
|
| 101 |
+
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
|
| 102 |
+
# (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
|
| 103 |
+
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 104 |
+
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 105 |
+
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 106 |
+
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
| 107 |
+
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 108 |
+
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
| 109 |
+
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
for dtype in dtypes:
|
| 113 |
+
print(
|
| 114 |
+
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
| 115 |
+
)
|
| 116 |
+
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
| 117 |
+
np_dtype = getattr(np, dtype)
|
| 118 |
+
time_mlx, time_torch = bench_shape(
|
| 119 |
+
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
| 120 |
+
)
|
| 121 |
+
diff = time_torch / time_mlx - 1.0
|
| 122 |
+
|
| 123 |
+
print(
|
| 124 |
+
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
| 125 |
+
)
|
| 126 |
+
if time_mlx >= 2.0 * time_torch:
|
| 127 |
+
print("ATTENTION ^^^^^^^")
|
ml-stable-diffusion/mlx/benchmarks/python/conv2d_train_bench_cpu.py
ADDED
|
@@ -0,0 +1,143 @@
| 1 |
+
import time
|
| 2 |
+
|
| 3 |
+
import mlx.core as mx
|
| 4 |
+
import mlx.nn
|
| 5 |
+
import mlx.optimizers as opt
|
| 6 |
+
import torch
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def bench_mlx(steps: int = 20) -> float:
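    # Time full training steps (forward, L1 loss, backward, Adam update) of the encoder-decoder net on CPU.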
|
| 10 |
+
mx.set_default_device(mx.cpu)
|
| 11 |
+
|
| 12 |
+
class BenchNetMLX(mlx.nn.Module):
|
| 13 |
+
# simple encoder-decoder net
|
| 14 |
+
|
| 15 |
+
def __init__(self, in_channels, hidden_channels=32):
|
| 16 |
+
super().__init__()
|
| 17 |
+
|
| 18 |
+
self.net = mlx.nn.Sequential(
|
| 19 |
+
mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
| 20 |
+
mlx.nn.ReLU(),
|
| 21 |
+
mlx.nn.Conv2d(
|
| 22 |
+
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
| 23 |
+
),
|
| 24 |
+
mlx.nn.ReLU(),
|
| 25 |
+
mlx.nn.ConvTranspose2d(
|
| 26 |
+
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
| 27 |
+
),
|
| 28 |
+
mlx.nn.ReLU(),
|
| 29 |
+
mlx.nn.ConvTranspose2d(
|
| 30 |
+
hidden_channels, in_channels, kernel_size=3, padding=1
|
| 31 |
+
),
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
def __call__(self, input):
|
| 35 |
+
return self.net(input)
|
| 36 |
+
|
| 37 |
+
benchNet = BenchNetMLX(3)
|
| 38 |
+
mx.eval(benchNet.parameters())
|
| 39 |
+
optim = opt.Adam(learning_rate=1e-3)
|
| 40 |
+
|
| 41 |
+
inputs = mx.random.normal([10, 256, 256, 3])
|
| 42 |
+
|
| 43 |
+
params = benchNet.parameters()
|
| 44 |
+
optim.init(params)
|
| 45 |
+
|
| 46 |
+
state = [benchNet.state, optim.state]
|
| 47 |
+
|
| 48 |
+
def loss_fn(params, image):
|
| 49 |
+
benchNet.update(params)
|
| 50 |
+
pred_image = benchNet(image)
|
| 51 |
+
return (pred_image - image).abs().mean()
|
| 52 |
+
|
| 53 |
+
def step(params, image):
|
| 54 |
+
loss, grads = mx.value_and_grad(loss_fn)(params, image)
|
| 55 |
+
optim.update(benchNet, grads)
|
| 56 |
+
return loss
|
| 57 |
+
|
| 58 |
+
total_time = 0.0
|
| 59 |
+
print("MLX:")
|
| 60 |
+
for i in range(steps):
|
| 61 |
+
start_time = time.perf_counter()
|
| 62 |
+
|
| 63 |
+
step(benchNet.parameters(), inputs)
|
| 64 |
+
mx.eval(state)
|
| 65 |
+
end_time = time.perf_counter()
|
| 66 |
+
|
| 67 |
+
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
| 68 |
+
total_time += (end_time - start_time) * 1000
|
| 69 |
+
|
| 70 |
+
return total_time
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def bench_torch(steps: int = 20) -> float:
|
| 74 |
+
device = torch.device("cpu")
|
| 75 |
+
|
| 76 |
+
class BenchNetTorch(torch.nn.Module):
|
| 77 |
+
# simple encoder-decoder net
|
| 78 |
+
|
| 79 |
+
def __init__(self, in_channels, hidden_channels=32):
|
| 80 |
+
super().__init__()
|
| 81 |
+
|
| 82 |
+
self.net = torch.nn.Sequential(
|
| 83 |
+
torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
|
| 84 |
+
torch.nn.ReLU(),
|
| 85 |
+
torch.nn.Conv2d(
|
| 86 |
+
hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
|
| 87 |
+
),
|
| 88 |
+
torch.nn.ReLU(),
|
| 89 |
+
torch.nn.ConvTranspose2d(
|
| 90 |
+
2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
|
| 91 |
+
),
|
| 92 |
+
torch.nn.ReLU(),
|
| 93 |
+
torch.nn.ConvTranspose2d(
|
| 94 |
+
hidden_channels, in_channels, kernel_size=3, padding=1
|
| 95 |
+
),
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
def forward(self, input):
|
| 99 |
+
return self.net(input)
|
| 100 |
+
|
| 101 |
+
benchNet = BenchNetTorch(3).to(device)
|
| 102 |
+
optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)
|
| 103 |
+
|
| 104 |
+
inputs = torch.randn(10, 3, 256, 256, device=device)
|
| 105 |
+
|
| 106 |
+
def loss_fn(pred_image, image):
|
| 107 |
+
return (pred_image - image).abs().mean()
|
| 108 |
+
|
| 109 |
+
total_time = 0.0
|
| 110 |
+
print("PyTorch:")
|
| 111 |
+
for i in range(steps):
|
| 112 |
+
start_time = time.perf_counter()
|
| 113 |
+
|
| 114 |
+
optim.zero_grad()
|
| 115 |
+
pred_image = benchNet(inputs)
|
| 116 |
+
loss = loss_fn(pred_image, inputs)
|
| 117 |
+
loss.backward()
|
| 118 |
+
optim.step()
|
| 119 |
+
|
| 120 |
+
end_time = time.perf_counter()
|
| 121 |
+
|
| 122 |
+
print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
|
| 123 |
+
total_time += (end_time - start_time) * 1000
|
| 124 |
+
|
| 125 |
+
return total_time
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def main():
|
| 129 |
+
steps = 20
|
| 130 |
+
time_mlx = bench_mlx(steps)
|
| 131 |
+
time_torch = bench_torch(steps)
|
| 132 |
+
|
| 133 |
+
print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
|
| 134 |
+
print(f"total time of MLX: {time_mlx:9.2f} ms")
|
| 135 |
+
print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
|
| 136 |
+
print(f"total time of PyTorch: {time_torch:9.2f} ms")
|
| 137 |
+
|
| 138 |
+
diff = time_torch / time_mlx - 1.0
|
| 139 |
+
print(f"torch/mlx diff: {100. * diff:+5.2f}%")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
if __name__ == "__main__":
|
| 143 |
+
main()
|
ml-stable-diffusion/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py
ADDED
|
@@ -0,0 +1,129 @@
| 1 |
+
import argparse
|
| 2 |
+
import math
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import mlx.core as mx
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
N_warmup = 1
|
| 10 |
+
N_iter_bench = 10
|
| 11 |
+
N_iter_func = 5
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def bench(f, a, b):
|
| 15 |
+
for i in range(N_warmup):
|
| 16 |
+
f(a, b)
|
| 17 |
+
|
| 18 |
+
s = time.perf_counter_ns()
|
| 19 |
+
for i in range(N_iter_bench):
|
| 20 |
+
f(a, b)
|
| 21 |
+
e = time.perf_counter_ns()
|
| 22 |
+
return (e - s) * 1e-9
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
| 26 |
+
def mx_conv_transpose_2D(a, b):
|
| 27 |
+
ys = []
|
| 28 |
+
for i in range(N_iter_func):
|
| 29 |
+
y = mx.conv_transpose2d(
|
| 30 |
+
a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
| 31 |
+
)
|
| 32 |
+
ys.append(y)
|
| 33 |
+
mx.eval(ys)
|
| 34 |
+
return ys
|
| 35 |
+
|
| 36 |
+
return mx_conv_transpose_2D
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
|
| 40 |
+
@torch.no_grad()
|
| 41 |
+
def pt_conv_transpose_2D(a, b):
|
| 42 |
+
ys = []
|
| 43 |
+
for i in range(N_iter_func):
|
| 44 |
+
y = torch.conv_transpose2d(
|
| 45 |
+
a, b, stride=strides, padding=padding, groups=groups
|
| 46 |
+
)
|
| 47 |
+
ys.append(y)
|
| 48 |
+
return ys
|
| 49 |
+
|
| 50 |
+
return pt_conv_transpose_2D
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
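    # Note the weight layouts: MLX takes (O/groups, kH, kW, C); torch.conv_transpose2d takes (C, O/groups, kH, kW).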
|
| 54 |
+
scale = 1.0 / math.sqrt(kH * kH * C)
|
| 55 |
+
a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
|
| 56 |
+
b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
|
| 57 |
+
np_dtype
|
| 58 |
+
)
|
| 59 |
+
|
| 60 |
+
a_mx = mx.array(a_np)
|
| 61 |
+
b_mx = mx.array(b_np)
|
| 62 |
+
|
| 63 |
+
a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
|
| 64 |
+
b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")
|
| 65 |
+
|
| 66 |
+
f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
|
| 67 |
+
f_pt = make_pt_conv_transpose_2D(strides, padding, groups)
|
| 68 |
+
|
| 69 |
+
time_torch = bench(f_pt, a_pt, b_pt)
|
| 70 |
+
time_mlx = bench(f_mx, a_mx, b_mx)
|
| 71 |
+
|
| 72 |
+
out_mx = mx.conv_transpose2d(
|
| 73 |
+
a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
|
| 74 |
+
)
|
| 75 |
+
out_pt = torch.conv_transpose2d(
|
| 76 |
+
a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
|
| 77 |
+
)
|
| 78 |
+
out_pt = torch.permute(out_pt, (0, 2, 3, 1))
|
| 79 |
+
out_pt = out_pt.numpy(force=True)
|
| 80 |
+
|
| 81 |
+
atol = 2e-5 if np_dtype == np.float32 else 1e-4
|
| 82 |
+
|
| 83 |
+
if not np.allclose(out_pt, out_mx, atol=atol):
|
| 84 |
+
print(
|
| 85 |
+
f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
return time_mlx, time_torch
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
if __name__ == "__main__":
|
| 92 |
+
parser = argparse.ArgumentParser(description="Run conv benchmarks")
|
| 93 |
+
|
| 94 |
+
dtypes = ("float32",)
|
| 95 |
+
shapes = (
|
| 96 |
+
(4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 97 |
+
(4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 98 |
+
(4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 99 |
+
(4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
| 100 |
+
(4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
|
| 101 |
+
(4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 102 |
+
(4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 103 |
+
(4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 104 |
+
(4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
|
| 105 |
+
(4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 106 |
+
(4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 107 |
+
(4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
|
| 108 |
+
(4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
|
| 109 |
+
(4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
|
| 110 |
+
(4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
|
| 111 |
+
(4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
for dtype in dtypes:
|
| 115 |
+
print(
|
| 116 |
+
"(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
|
| 117 |
+
)
|
| 118 |
+
for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
|
| 119 |
+
np_dtype = getattr(np, dtype)
|
| 120 |
+
time_mlx, time_torch = bench_shape(
|
| 121 |
+
N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
|
| 122 |
+
)
|
| 123 |
+
diff = time_torch / time_mlx - 1.0
|
| 124 |
+
|
| 125 |
+
print(
|
| 126 |
+
f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
|
| 127 |
+
)
|
| 128 |
+
if time_mlx >= 2.0 * time_torch:
|
| 129 |
+
print("ATTENTION ^^^^^^^")
|
ml-stable-diffusion/mlx/benchmarks/python/conv3d_bench_cpu.py
ADDED
@@ -0,0 +1,110 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 4, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv3d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv3d_train_bench_cpu.py
ADDED
@@ -0,0 +1,143 @@
import time

import mlx.core as mx
import mlx.nn
import mlx.optimizers as opt
import torch


def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time


def main():
    steps = 10
    time_mlx = bench_mlx(steps)
    time_torch = bench_torch(steps)

    print(f"average time of MLX: {time_mlx/steps:9.2f} ms")
    print(f"total time of MLX: {time_mlx:9.2f} ms")
    print(f"average time of PyTorch: {time_torch/steps:9.2f} ms")
    print(f"total time of PyTorch: {time_torch:9.2f} ms")

    diff = time_torch / time_mlx - 1.0
    print(f"torch/mlx diff: {100. * diff:+5.2f}%")


if __name__ == "__main__":
    main()
ml-stable-diffusion/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py
ADDED
@@ -0,0 +1,116 @@
import argparse
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 1
N_iter_bench = 10
N_iter_func = 5
mx.set_default_device(mx.cpu)


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    def mx_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_3D


def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        return ys

    return pt_conv_3D


def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kD * kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kD, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    f_mx = make_mx_conv_3D(strides, padding, groups)
    f_pt = make_pt_conv_3D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose3d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose3d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_bench.py
ADDED
@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
device_name = device_name.decode("utf-8").strip("\n")

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_transpose_bench.py
ADDED
@@ -0,0 +1,135 @@
import argparse
import math
import os
import subprocess
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_transpose_2D


def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_transpose_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_unaligned_bench.py
ADDED
@@ -0,0 +1,107 @@
import math
import time

import mlx.core as mx
import numpy as np
import torch

N_warmup = 10
N_iter_bench = 100
N_iter_func = 5


def bench(f, a, b):
    for i in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    s = time.perf_counter_ns()
    for i in range(N_iter_bench):
        f(a, b)
    e = time.perf_counter_ns()
    return (e - s) * 1e-9


def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    def mx_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        mx.eval(ys)
        return ys

    return mx_conv_2D


def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = []
        for i in range(N_iter_func):
            y = torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        torch.mps.synchronize()
        return ys

    return pt_conv_2D


def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    scale = 1.0 / math.sqrt(kH * kH * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch


if __name__ == "__main__":
    dtype = "float32"
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, dtype
        )
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/distributed_bench.py
ADDED
@@ -0,0 +1,66 @@
# Copyright © 2024 Apple Inc.

"""
Run with:
mpirun -n 2 python /path/to/distributed_bench.py
"""

import time

import mlx.core as mx


def time_fn(fn, *args, **kwargs):
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    if world.rank() == 0:
        if msg:
            print(f"Timing {msg} ...", end=" ")
        else:
            print(f"Timing {fn.__name__} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        x = mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if world.rank() == 0:
        print(f"{msec:.5f} msec")


def time_all_sum():
    shape = (4096,)
    x = mx.random.uniform(shape=shape)
    mx.eval(x)

    def sine(x):
        for _ in range(20):
            x = mx.sin(x)
        return x

    time_fn(sine, x)

    def all_sum_plain(x):
        for _ in range(20):
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_plain, x)

    def all_sum_with_sine(x):
        for _ in range(20):
            x = mx.sin(x)
            x = mx.distributed.all_sum(x)
        return x

    time_fn(all_sum_with_sine, x)


if __name__ == "__main__":
    time_all_sum()
ml-stable-diffusion/mlx/benchmarks/python/einsum_bench.py
ADDED
@@ -0,0 +1,84 @@
# Copyright © 2024 Apple Inc.

import time

import mlx.core as mx
import numpy as np


def timeit(fn, its=100, args=[]):
    for _ in range(5):
        fn(*args)
    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()
    return 1e3 * (toc - tic) / its


def time_little_einsum_path():
    subscripts = "ik,kj->ij"
    x = mx.ones((32, 32))
    y = mx.ones((32, 32))
    mx_time = timeit(mx.einsum_path, args=(subscripts, x, y))

    x = np.array(x)
    y = np.array(y)
    np_time = timeit(np.einsum_path, args=(subscripts, x, y))
    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_big_einsum_path():
    chars = list("abcdefgh")
    char_to_dim = {c: v for v, c in enumerate(chars)}

    num_inputs = 10
    inputs = []
    subscripts = []
    for _ in range(num_inputs):
        subscript = np.random.choice(chars, size=5, replace=False).tolist()
        subscripts.append("".join(subscript))
        inputs.append(np.ones(list(char_to_dim[c] for c in subscript)))
    subscripts = ",".join(subscripts)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    inputs = [mx.array(x) for x in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *inputs))
    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")


def time_attention():
    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = queries.transpose(0, 2, 1, 3) @ keys.transpose(0, 2, 3, 1)
        scores = mx.softmax(scores, axis=-1)
        output = (scores @ values.transpose(0, 2, 1, 3)).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        scores = mx.einsum("itjk,iujk->ijtu", queries, keys)
        scores = mx.softmax(scores, axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", scores, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))
    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")


if __name__ == "__main__":
    time_little_einsum_path()
    time_big_einsum_path()
    time_attention()
ml-stable-diffusion/mlx/benchmarks/python/fft_bench.py
ADDED
@@ -0,0 +1,118 @@
# Copyright © 2024 Apple Inc.

import matplotlib
import mlx.core as mx
import numpy as np
import sympy
import torch
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def bandwidth_gb(runtime_ms, system_size):
    bytes_per_fft = np.dtype(np.complex64).itemsize * 2
    bytes_per_gb = 1e9
    ms_per_s = 1e3
    return system_size * bytes_per_fft / runtime_ms * ms_per_s / bytes_per_gb


def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    def fft_mlx(x):
        if dim == 1:
            out = mx.fft.fft(x)
        elif dim == 2:
            out = mx.fft.fft2(x)
        mx.eval(out)
        return out

    def fft_mps(x):
        if dim == 1:
            out = torch.fft.fft(x)
        elif dim == 2:
            out = torch.fft.fft2(x)
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n for _ in range(dim)]
        if backend == "mlx":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        elif backend == "mps":
            x_np = np.random.uniform(size=(system_size // n, n)).astype(np.complex64)
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps
        else:
            raise NotImplementedError()
        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)


def time_fft():
    x = np.array(range(2, 512))
    system_size = int(2**26)

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    print("MPS GPU")
    mps_bandwidths = run_bench(system_size=system_size, fft_sizes=x, backend="mps")

    print("CPU")
    system_size = int(2**20)
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=system_size, fft_sizes=x)

    x = np.array(x)

    all_indices = x - x[0]
    radix_2to13 = (
        np.array([i for i in x if all(p <= 13 for p in sympy.primefactors(i))]) - x[0]
    )
    bluesteins = (
        np.array([i for i in x if any(p > 13 for p in sympy.primefactors(i))]) - x[0]
    )

    for indices, name in [
        (all_indices, "All"),
        (radix_2to13, "Radix 2-13"),
        (bluesteins, "Bluestein's"),
    ]:
        # plot bandwidths
        print(name)
        plt.scatter(x[indices], gpu_bandwidths[indices], color="green", label="GPU")
        plt.scatter(x[indices], mps_bandwidths[indices], color="blue", label="MPS")
        plt.scatter(x[indices], cpu_bandwidths[indices], color="red", label="CPU")
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        plt.clf()

    av_gpu_bandwidth = np.mean(gpu_bandwidths)
    av_mps_bandwidth = np.mean(mps_bandwidths)
    av_cpu_bandwidth = np.mean(cpu_bandwidths)
    print("Average bandwidths:")
    print("GPU:", av_gpu_bandwidth)
    print("MPS:", av_mps_bandwidth)
    print("CPU:", av_cpu_bandwidth)

    portion_faster = len(np.where(gpu_bandwidths > mps_bandwidths)[0]) / len(x)
    print("Percent MLX faster than MPS: ", portion_faster * 100)


if __name__ == "__main__":
    time_fft()
ml-stable-diffusion/mlx/benchmarks/python/gather_bench.py
ADDED
@@ -0,0 +1,52 @@
# Copyright © 2023-2024 Apple Inc.

import argparse

import mlx.core as mx
import torch
from time_utils import measure_runtime


def benchmark_gather_mlx(x_shape, idx_shape):
    def gather(x, idx):
        mx.eval(x[idx])

    idx = mx.random.randint(0, x_shape[0] - 1, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    runtime = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {runtime:.3f}ms")


def benchmark_gather_torch(x_shape, idx_shape, device):
    def gather(x, idx, device):
        _ = x[idx]
        if device == torch.device("mps"):
            torch.mps.synchronize()

    idx = torch.randint(0, x_shape[0] - 1, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    runtime = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {runtime:.3f}ms")


if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    if args.cpu:
        mx.set_default_device(mx.cpu)
        device = torch.device("cpu")
    else:
        device = torch.device("mps")

    idx_shapes = [(1_000_000,), (100_000,), ()]
    x_shapes = [(100, 64), (100, 1024), (4, 1_000_000)]

    for x_shape, idx_shape in zip(x_shapes, idx_shapes):
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
ml-stable-diffusion/mlx/benchmarks/python/gather_mm_bench.py
ADDED
@@ -0,0 +1,74 @@
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate([x[i] @ w[j].T for i, j in enumerate(idx.tolist())], axis=0)
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_mm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_mm(x, w1.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_mm(x, w2.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = x @ w1.T
        x = x @ w2.T
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_mm()
|
ml-stable-diffusion/mlx/benchmarks/python/gather_qmm_bench.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Copyright © 2025 Apple Inc.

import mlx.core as mx
from time_utils import time_fn

N = 1024
D = 1024
M = 1024
E = 32
I = 4


def gather_sort(x, indices):
    N, M = indices.shape
    indices = indices.flatten()
    order = mx.argsort(indices)
    inv_order = mx.argsort(order)
    return x.flatten(0, -3)[order // M], indices[order], inv_order


def scatter_unsort(x, inv_order, shape=None):
    x = x[inv_order]
    if shape is not None:
        x = mx.unflatten(x, 0, shape)
    return x


def gather_mm_simulate(x, w, indices):
    x, idx, inv_order = gather_sort(x, indices)
    for i in range(2):
        y = mx.concatenate(
            [
                mx.quantized_matmul(x[i], w[0][j], w[1][j], w[2][j], transpose=True)
                for i, j in enumerate(idx.tolist())
            ],
            axis=0,
        )
        x = y[:, None]
    x = scatter_unsort(x, inv_order, indices.shape)
    return x


def time_gather_qmm():
    x = mx.random.normal((N, 1, 1, D)) / 1024**0.5
    w1 = mx.random.normal((E, M, D)) / 1024**0.5
    w2 = mx.random.normal((E, D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx = indices
        inv_order = None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        x = mx.gather_qmm(x, *w1, transpose=True, rhs_indices=idx, sorted_indices=sort)
        x = mx.gather_qmm(x, *w2, transpose=True, rhs_indices=idx, sorted_indices=sort)
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    time_fn(gather_mm, x, w1, w2, indices, False)
    time_fn(gather_mm, x, w1, w2, sorted_indices, False)
    time_fn(gather_mm, x, w1, w2, indices, True)

    x = mx.random.normal((N * I, D)) / 1024**0.5
    w1 = mx.random.normal((M, D)) / 1024**0.5
    w2 = mx.random.normal((D, M)) / 1024**0.5
    w1 = mx.quantize(w1)
    w2 = mx.quantize(w2)
    mx.eval(x, w1, w2)

    def equivalent_matmul(x, w1, w2):
        x = mx.quantized_matmul(x, *w1, transpose=True)
        x = mx.quantized_matmul(x, *w2, transpose=True)
        return x

    time_fn(equivalent_matmul, x, w1, w2)


if __name__ == "__main__":
    time_gather_qmm()
ml-stable-diffusion/mlx/benchmarks/python/hadamard_bench.py
ADDED
@@ -0,0 +1,70 @@
import argparse

import matplotlib
import mlx.core as mx
import numpy as np
from time_utils import measure_runtime

matplotlib.use("Agg")
import matplotlib.pyplot as plt


def had(x):
    y = mx.hadamard_transform(x)
    mx.eval(y)


def copy(x):
    y = x + 1.0
    mx.eval(y)


def run(dtype):
    system_size = 2**26
    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            if test_fn == copy:
                key = "copy"
            elif m == 1:
                key = "had_2^k"
            else:
                key = "had_m*2^k"
            outputs.setdefault(key, {})
            for k in range(7, 14):
                n = m * 2**k
                if n > 2**15:
                    continue
                x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                x = mx.array(x_np)
                runtime_ms = measure_runtime(test_fn, x=x)
                bytes_per_gb = 1e9
                ms_per_s = 1e3
                bytes_per_had = np.dtype(x_np.dtype).itemsize * 2
                bandwidth_gb = (
                    system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                )
                print(n, bandwidth_gb)
                outputs[key][n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    dtype = np.float16 if args.fp16 else np.float32
    run(dtype)
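As a point of reference, a minimal sketch (an assumption, not part of the diff) of calling mx.hadamard_transform directly; the benchmark above sweeps sizes of the form m * 2^k, while this example uses a single power-of-two length.

import mlx.core as mx

x = mx.random.normal((4, 1024))
y = mx.hadamard_transform(x)  # Walsh-Hadamard transform along the last axis
mx.eval(y)
print(y.shape)  # (4, 1024)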
ml-stable-diffusion/mlx/benchmarks/python/layer_norm_bench.py
ADDED
@@ -0,0 +1,82 @@
# Copyright © 2023-2024 Apple Inc.

from functools import partial

import mlx.core as mx
import mlx.nn as nn
from time_utils import time_fn


def layer_norm(x, w, b, eps):
    ot = x.dtype
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y


def time_layer_norm(N, dt):
    L = 1024
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
    w = mx.random.uniform(shape=(N,)).astype(dt)
    b = mx.random.uniform(shape=(N,)).astype(dt)
    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
    mx.eval(x, w, b, y)

    def layer_norm_loop(f, x, w, b):
        for _ in range(32):
            x = f(x, w, b)
        return x

    time_fn(layer_norm_loop, partial(layer_norm, eps=1e-5), x, w, b)
    time_fn(layer_norm_loop, partial(mx.fast.layer_norm, eps=1e-5), x, w, b)

    def layer_norm_grad_loop(g, x, w, b):
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    time_fn(layer_norm_grad_loop, g1, x, w, b)
    time_fn(layer_norm_grad_loop, g2, x, w, b)
    time_fn(layer_norm_grad_loop, mx.compile(g1), x, w, b)
    time_fn(layer_norm_grad_loop, mx.compile(g2), x, w, b)

    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0,))
    g2 = mx.grad(f2, argnums=(0,))

    x = mx.random.uniform(shape=(8, L, N)).astype(dt)
    w = mx.random.uniform(shape=(N,)).astype(dt)
    b = mx.random.uniform(shape=(N,)).astype(dt)
    y = mx.random.uniform(shape=(8, L, N)).astype(dt)
    mx.eval(x, w, b, y)

    def layer_norm_grad_x_loop(g, x):
        gx = x
        for _ in range(32):
            gx = g(gx, y)
        return gx

    time_fn(layer_norm_grad_x_loop, g1, x)
    time_fn(layer_norm_grad_x_loop, g2, x)
    time_fn(layer_norm_grad_x_loop, mx.compile(g1), x)
    time_fn(layer_norm_grad_x_loop, mx.compile(g2), x)


if __name__ == "__main__":
    for dt in [mx.float32, mx.float16, mx.bfloat16]:
        for n in [1024, 2048, 4096, 8192, 8192 + 1024]:
            print(dt, n)
            time_layer_norm(n, dt)
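For reference, a minimal sketch (an assumption, not part of the diff) of calling the fused mx.fast.layer_norm op that the benchmark compares against the pure-MLX reference implementation above.

import mlx.core as mx

x = mx.random.normal((8, 1024, 4096))
w = mx.ones((4096,))
b = mx.zeros((4096,))
y = mx.fast.layer_norm(x, w, b, eps=1e-5)  # fused normalization over the last axis
mx.eval(y)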