Fahad-S committed on
Commit
712dbf0
·
verified ·
1 Parent(s): 1e8e70c

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50) hide show
  1. .gitattributes +37 -0
  2. ml-stable-diffusion/mlx/.circleci/config.yml +579 -0
  3. ml-stable-diffusion/mlx/.clang-format +87 -0
  4. ml-stable-diffusion/mlx/.github/ISSUE_TEMPLATE/bug_report.md +28 -0
  5. ml-stable-diffusion/mlx/.github/pull_request_template.md +12 -0
  6. ml-stable-diffusion/mlx/.github/workflows/pull_request.yml +20 -0
  7. ml-stable-diffusion/mlx/.gitignore +88 -0
  8. ml-stable-diffusion/mlx/.pre-commit-config.yaml +21 -0
  9. ml-stable-diffusion/mlx/ACKNOWLEDGMENTS.md +268 -0
  10. ml-stable-diffusion/mlx/CITATION.cff +24 -0
  11. ml-stable-diffusion/mlx/CMakeLists.txt +353 -0
  12. ml-stable-diffusion/mlx/CODE_OF_CONDUCT.md +132 -0
  13. ml-stable-diffusion/mlx/CONTRIBUTING.md +38 -0
  14. ml-stable-diffusion/mlx/LICENSE +21 -0
  15. ml-stable-diffusion/mlx/MANIFEST.in +6 -0
  16. ml-stable-diffusion/mlx/README.md +121 -0
  17. ml-stable-diffusion/mlx/benchmarks/cpp/CMakeLists.txt +11 -0
  18. ml-stable-diffusion/mlx/benchmarks/cpp/autograd.cpp +39 -0
  19. ml-stable-diffusion/mlx/benchmarks/cpp/compare_devices.cpp +27 -0
  20. ml-stable-diffusion/mlx/benchmarks/cpp/irregular_strides.cpp +201 -0
  21. ml-stable-diffusion/mlx/benchmarks/cpp/single_ops.cpp +288 -0
  22. ml-stable-diffusion/mlx/benchmarks/cpp/time_utils.h +39 -0
  23. ml-stable-diffusion/mlx/benchmarks/numpy/single_ops.py +39 -0
  24. ml-stable-diffusion/mlx/benchmarks/numpy/time_utils.py +20 -0
  25. ml-stable-diffusion/mlx/benchmarks/python/batch_matmul_bench.py +62 -0
  26. ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemm.py +191 -0
  27. ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemv.py +221 -0
  28. ml-stable-diffusion/mlx/benchmarks/python/comparative/README.md +15 -0
  29. ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_mlx.py +519 -0
  30. ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_torch.py +482 -0
  31. ml-stable-diffusion/mlx/benchmarks/python/comparative/compare.py +284 -0
  32. ml-stable-diffusion/mlx/benchmarks/python/compile_bench.py +107 -0
  33. ml-stable-diffusion/mlx/benchmarks/python/conv1d_bench.py +123 -0
  34. ml-stable-diffusion/mlx/benchmarks/python/conv2d_bench_cpu.py +127 -0
  35. ml-stable-diffusion/mlx/benchmarks/python/conv2d_train_bench_cpu.py +143 -0
  36. ml-stable-diffusion/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py +129 -0
  37. ml-stable-diffusion/mlx/benchmarks/python/conv3d_bench_cpu.py +110 -0
  38. ml-stable-diffusion/mlx/benchmarks/python/conv3d_train_bench_cpu.py +143 -0
  39. ml-stable-diffusion/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py +116 -0
  40. ml-stable-diffusion/mlx/benchmarks/python/conv_bench.py +135 -0
  41. ml-stable-diffusion/mlx/benchmarks/python/conv_transpose_bench.py +135 -0
  42. ml-stable-diffusion/mlx/benchmarks/python/conv_unaligned_bench.py +107 -0
  43. ml-stable-diffusion/mlx/benchmarks/python/distributed_bench.py +66 -0
  44. ml-stable-diffusion/mlx/benchmarks/python/einsum_bench.py +84 -0
  45. ml-stable-diffusion/mlx/benchmarks/python/fft_bench.py +118 -0
  46. ml-stable-diffusion/mlx/benchmarks/python/gather_bench.py +52 -0
  47. ml-stable-diffusion/mlx/benchmarks/python/gather_mm_bench.py +74 -0
  48. ml-stable-diffusion/mlx/benchmarks/python/gather_qmm_bench.py +84 -0
  49. ml-stable-diffusion/mlx/benchmarks/python/hadamard_bench.py +70 -0
  50. ml-stable-diffusion/mlx/benchmarks/python/layer_norm_bench.py +82 -0
.gitattributes CHANGED
@@ -120,3 +120,40 @@ ml-stable-diffusion/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_
120
  ml-stable-diffusion/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png filter=lfs diff=lfs merge=lfs -text
121
  ml-stable-diffusion/assets/palette6_cpuandne_readmereel.png filter=lfs diff=lfs merge=lfs -text
122
  ml-stable-diffusion/assets/readme_reel.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  ml-stable-diffusion/assets/mbp/a_high_quality_photo_of_a_surfing_dog.7667.final_float16_original.png filter=lfs diff=lfs merge=lfs -text
121
  ml-stable-diffusion/assets/palette6_cpuandne_readmereel.png filter=lfs diff=lfs merge=lfs -text
122
  ml-stable-diffusion/assets/readme_reel.png filter=lfs diff=lfs merge=lfs -text
123
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/arg_reduce.cpp.o filter=lfs diff=lfs merge=lfs -text
124
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/binary.cpp.o filter=lfs diff=lfs merge=lfs -text
125
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/compiled.cpp.o filter=lfs diff=lfs merge=lfs -text
126
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/compiled_preamble.cpp.o filter=lfs diff=lfs merge=lfs -text
127
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/conv.cpp.o filter=lfs diff=lfs merge=lfs -text
128
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/copy.cpp.o filter=lfs diff=lfs merge=lfs -text
129
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/fft.cpp.o filter=lfs diff=lfs merge=lfs -text
130
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/indexing.cpp.o filter=lfs diff=lfs merge=lfs -text
131
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/jit_compiler.cpp.o filter=lfs diff=lfs merge=lfs -text
132
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/masked_mm.cpp.o filter=lfs diff=lfs merge=lfs -text
133
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/matmul.cpp.o filter=lfs diff=lfs merge=lfs -text
134
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
135
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/quantized.cpp.o filter=lfs diff=lfs merge=lfs -text
136
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/reduce.cpp.o filter=lfs diff=lfs merge=lfs -text
137
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/scan.cpp.o filter=lfs diff=lfs merge=lfs -text
138
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/select.cpp.o filter=lfs diff=lfs merge=lfs -text
139
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/sort.cpp.o filter=lfs diff=lfs merge=lfs -text
140
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/cpu/unary.cpp.o filter=lfs diff=lfs merge=lfs -text
141
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/backend/no_gpu/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
142
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/compile.cpp.o filter=lfs diff=lfs merge=lfs -text
143
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/distributed/mpi/mpi.cpp.o filter=lfs diff=lfs merge=lfs -text
144
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/distributed/ring/ring.cpp.o filter=lfs diff=lfs merge=lfs -text
145
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/einsum.cpp.o filter=lfs diff=lfs merge=lfs -text
146
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/export.cpp.o filter=lfs diff=lfs merge=lfs -text
147
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/fast.cpp.o filter=lfs diff=lfs merge=lfs -text
148
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/gguf.cpp.o filter=lfs diff=lfs merge=lfs -text
149
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/load.cpp.o filter=lfs diff=lfs merge=lfs -text
150
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/io/safetensors.cpp.o filter=lfs diff=lfs merge=lfs -text
151
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/linalg.cpp.o filter=lfs diff=lfs merge=lfs -text
152
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/primitives.cpp.o filter=lfs diff=lfs merge=lfs -text
153
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/random.cpp.o filter=lfs diff=lfs merge=lfs -text
154
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/CMakeFiles/mlx.dir/mlx/transforms.cpp.o filter=lfs diff=lfs merge=lfs -text
155
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/CMakeFiles/nanobind-static.dir/proj/cvl/users/x_fahkh2/caches/pip-build-env-nyl54h73/overlay/lib/python3.10/site-packages/nanobind/src/nb_type.cpp.o filter=lfs diff=lfs merge=lfs -text
156
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/CMakeFiles/nanobind-static.dir/proj/cvl/users/x_fahkh2/caches/pip-build-env-vont0ixn/overlay/lib/python3.10/site-packages/nanobind/src/nb_type.cpp.o filter=lfs diff=lfs merge=lfs -text
157
+ ml-stable-diffusion/mlx/build/temp.linux-x86_64-cpython-310/mlx.core/python/src/libnanobind-static.a filter=lfs diff=lfs merge=lfs -text
158
+ ml-stable-diffusion/mlx/docs/src/_static/metal_debugger/capture.png filter=lfs diff=lfs merge=lfs -text
159
+ ml-stable-diffusion/mlx/docs/src/_static/metal_debugger/schema.png filter=lfs diff=lfs merge=lfs -text
ml-stable-diffusion/mlx/.circleci/config.yml ADDED
@@ -0,0 +1,579 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ version: 2.1
2
+
3
+ orbs:
4
+ apple: ml-explore/pr-approval@0.1.0
5
+
6
+ parameters:
7
+ nightly_build:
8
+ type: boolean
9
+ default: false
10
+ test_release:
11
+ type: boolean
12
+ default: false
13
+
14
+ jobs:
15
+ build_documentation:
16
+ parameters:
17
+ upload-docs:
18
+ type: boolean
19
+ default: false
20
+ macos:
21
+ xcode: "26.0.0"
22
+ resource_class: m4pro.medium
23
+ steps:
24
+ - checkout
25
+ - run:
26
+ name: Install
27
+ command: |
28
+ xcodebuild -downloadComponent MetalToolchain
29
+ brew install python@3.9
30
+ brew install doxygen
31
+ python3.9 -m venv env
32
+ source env/bin/activate
33
+ pip install --upgrade pip
34
+ pip install --upgrade cmake
35
+ pip install -r docs/requirements.txt
36
+ pip install . -v
37
+ - when:
38
+ condition:
39
+ not: << parameters.upload-docs >>
40
+ steps:
41
+ - run:
42
+ name: Build documentation
43
+ command: |
44
+ source env/bin/activate
45
+ cd docs && doxygen && make html O=-W
46
+ - when:
47
+ condition: << parameters.upload-docs >>
48
+ steps:
49
+ - add_ssh_keys:
50
+ fingerprints:
51
+ - "SHA256:OhcVVMovbT0pkgMeiVRyxMnjV9R2t+hKBsNcuxq9h+0"
52
+ - run:
53
+ name: Upload documentation
54
+ command: |
55
+ source env/bin/activate
56
+ git config user.email "mlx@group.apple.com"
57
+ git config user.name "CircleCI Docs"
58
+ git checkout gh-pages
59
+ git rebase main
60
+ cd docs
61
+ git rm -rf build/html
62
+ doxygen && make html O=-W
63
+ git add -f build/html
64
+ git commit -m "rebase"
65
+ git push -f origin gh-pages
66
+
67
+ linux_build_and_test:
68
+ machine:
69
+ image: ubuntu-2204:current
70
+ resource_class: large
71
+ steps:
72
+ - checkout
73
+ - run:
74
+ name: Run style checks
75
+ command: |
76
+ pip install pre-commit
77
+ pre-commit run --all
78
+ if ! git diff --quiet; then echo 'Style checks failed, please install pre-commit and run pre-commit run --all and push the change'; exit 1; fi
79
+ - run:
80
+ name: Install dependencies
81
+ command: |
82
+ export DEBIAN_FRONTEND=noninteractive
83
+ export NEEDRESTART_MODE=a
84
+ sudo apt-get update
85
+ sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
86
+ sudo apt-get install openmpi-bin openmpi-common libopenmpi-dev
87
+ curl -LsSf https://astral.sh/uv/install.sh | sh
88
+ - run:
89
+ name: Install Python package
90
+ command: |
91
+ uv venv
92
+ uv pip install cmake
93
+ DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
94
+ uv pip install -e ".[dev]" -v
95
+ - run:
96
+ name: Generate package stubs
97
+ command: |
98
+ uv pip install typing_extensions
99
+ uv run --no-project setup.py generate_stubs
100
+ - run:
101
+ name: Run Python tests
102
+ command: |
103
+ source .venv/bin/activate
104
+ python -m unittest discover python/tests -v
105
+ mpirun --bind-to none -host localhost:8 -np 8 python python/tests/mpi_test_distributed.py
106
+ mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
107
+ if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
108
+ - run:
109
+ name: Build CPP only
110
+ command: |
111
+ source .venv/bin/activate
112
+ mkdir -p build && cd build
113
+ cmake .. -DMLX_BUILD_METAL=OFF -DCMAKE_BUILD_TYPE=DEBUG
114
+ make -j `nproc`
115
+ - run:
116
+ name: Run CPP tests
117
+ command: ./build/tests/tests
118
+
119
+ mac_build_and_test:
120
+ parameters:
121
+ xcode_version:
122
+ type: string
123
+ default: "26.0.0"
124
+ macosx_deployment_target:
125
+ type: string
126
+ default: ""
127
+ macos:
128
+ xcode: << parameters.xcode_version >>
129
+ environment:
130
+ MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
131
+ resource_class: m4pro.medium
132
+ steps:
133
+ - checkout
134
+ - run:
135
+ name: Install dependencies
136
+ command: |
137
+ xcodebuild -downloadComponent MetalToolchain
138
+ HOMEBREW_NO_AUTO_UPDATE=1 HOMEBREW_NO_INSTALL_CLEANUP=1 \
139
+ brew install openmpi uv
140
+ - run:
141
+ name: Install Python package
142
+ command: |
143
+ uv venv --python 3.9
144
+ uv pip install \
145
+ nanobind==2.4.0 \
146
+ cmake \
147
+ numpy \
148
+ torch \
149
+ tensorflow \
150
+ unittest-xml-reporting
151
+ DEBUG=1 CMAKE_ARGS="-DCMAKE_COMPILE_WARNING_AS_ERROR=ON" \
152
+ uv pip install -e . -v
153
+ - run:
154
+ name: Generate package stubs
155
+ command: |
156
+ uv pip install typing_extensions
157
+ uv run --no-project setup.py generate_stubs
158
+ - run:
159
+ name: Run Python tests
160
+ command: |
161
+ source .venv/bin/activate
162
+ LOW_MEMORY=1 DEVICE=cpu python -m xmlrunner discover -v python/tests -o test-results/cpu
163
+ LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 python -m xmlrunner discover -v python/tests -o test-results/gpu
164
+ mpirun --bind-to none -host localhost:8 -np 8 -x DYLD_LIBRARY_PATH=/opt/homebrew/lib/ python python/tests/mpi_test_distributed.py
165
+ mlx.launch --verbose -n 8 python/tests/ring_test_distributed.py -v 2> >(tee -a stderr.log >&2)
166
+ if $(grep "\[WARN\]" stderr.log); then echo "Distributed ring test failed"; exit 1; fi
167
+ - run:
168
+ name: Build example extension
169
+ command: |
170
+ source .venv/bin/activate
171
+ cd examples/extensions
172
+ uv pip install -r requirements.txt
173
+ uv run --no-project setup.py build_ext --inplace
174
+ uv run --no-project python test.py
175
+ - store_test_results:
176
+ path: test-results
177
+ - run:
178
+ name: Build CPP only
179
+ command: |
180
+ source .venv/bin/activate
181
+ mkdir -p build && cd build && cmake .. && make -j `sysctl -n hw.ncpu`
182
+ - run:
183
+ name: Run CPP tests
184
+ command: |
185
+ DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 METAL_DEBUG_ERROR_MODE=0 ./build/tests/tests
186
+ - run:
187
+ name: Build small binary
188
+ command: |
189
+ source .venv/bin/activate
190
+ cd build/
191
+ cmake .. -DCMAKE_BUILD_TYPE=MinSizeRel \
192
+ -DBUILD_SHARED_LIBS=ON \
193
+ -DMLX_BUILD_CPU=OFF \
194
+ -DMLX_BUILD_SAFETENSORS=OFF \
195
+ -DMLX_BUILD_GGUF=OFF \
196
+ -DMLX_METAL_JIT=ON
197
+ make -j `sysctl -n hw.ncpu`
198
+ - run:
199
+ name: Run Python tests with JIT
200
+ command: |
201
+ CMAKE_ARGS="-DMLX_METAL_JIT=ON" \
202
+ uv pip install -e . -v
203
+ LOW_MEMORY=1 DEVICE=gpu METAL_DEVICE_WRAPPER_TYPE=1 \
204
+ METAL_DEBUG_ERROR_MODE=0 \
205
+ uv run --no-project python -m xmlrunner discover \
206
+ -v python/tests \
207
+ -o test-results/gpu_jit
208
+
209
+ cuda_build_and_test:
210
+ parameters:
211
+ image_date:
212
+ type: string
213
+ default: "2023.11.1"
214
+ machine:
215
+ image: "linux-cuda-12:<< parameters.image_date >>"
216
+ resource_class: gpu.nvidia.small.gen2
217
+ steps:
218
+ - checkout
219
+ - restore_cache:
220
+ keys:
221
+ - cuda-<< parameters.image_date >>-{{ arch }}-
222
+ - run:
223
+ name: Install dependencies
224
+ command: |
225
+ sudo apt-get update
226
+ sudo apt-get install libcudnn9-dev-cuda-12
227
+ sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
228
+ sudo apt-get install libnccl2 libnccl-dev
229
+ curl -sL https://github.com/ccache/ccache/releases/download/v4.11.3/ccache-4.11.3-linux-x86_64.tar.xz | tar xJf -
230
+ sudo mv ccache-4.11.3-linux-x86_64/ccache /usr/bin/ccache
231
+ rm -rf ccache-4.11.3-linux-x86_64
232
+ curl -LsSf https://astral.sh/uv/install.sh | sh
233
+ - run:
234
+ name: Set CCache size
235
+ command: ccache --max-size 1G
236
+ - run:
237
+ name: Install Python package
238
+ command: |
239
+ uv venv
240
+ uv pip install cmake
241
+ DEBUG=1 CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_COMPILE_WARNING_AS_ERROR=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
242
+ uv pip install -e ".[dev]" -v
243
+ - run:
244
+ name: Run Python tests
245
+ command: |
246
+ source .venv/bin/activate
247
+ LOW_MEMORY=1 DEVICE=cpu python -m unittest discover python/tests -v
248
+ LOW_MEMORY=1 DEVICE=gpu python -m tests discover python/tests -v
249
+ - run:
250
+ name: Build CPP only
251
+ command: |
252
+ source .venv/bin/activate
253
+ cmake . -B build \
254
+ -DMLX_BUILD_CUDA=ON \
255
+ -DCMAKE_CUDA_COMPILER=`which nvcc` \
256
+ -DCMAKE_BUILD_TYPE=DEBUG
257
+ cmake --build build -j `nproc`
258
+ - run:
259
+ name: Run CPP tests
260
+ command: ./build/tests/tests -sfe="*fft_tests.cpp,*linalg_tests.cpp"
261
+ - run:
262
+ name: CCache report
263
+ command: |
264
+ ccache --show-stats
265
+ ccache --zero-stats
266
+ ccache --cleanup
267
+ - save_cache:
268
+ key: cuda-<< parameters.image_date >>-{{ arch }}-{{ epoch }}
269
+ paths:
270
+ - /home/circleci/.cache/ccache
271
+
272
+ build_release:
273
+ parameters:
274
+ python_version:
275
+ type: string
276
+ default: "3.9"
277
+ xcode_version:
278
+ type: string
279
+ default: "26.0.0"
280
+ build_env:
281
+ type: string
282
+ default: ""
283
+ macosx_deployment_target:
284
+ type: string
285
+ default: ""
286
+ macos:
287
+ xcode: << parameters.xcode_version >>
288
+ resource_class: m4pro.medium
289
+ environment:
290
+ MACOSX_DEPLOYMENT_TARGET: << parameters.macosx_deployment_target >>
291
+ steps:
292
+ - checkout
293
+ - run:
294
+ name: Install dependencies
295
+ command: |
296
+ xcodebuild -downloadComponent MetalToolchain
297
+ mkdir -p ~/miniconda3
298
+ curl https://repo.anaconda.com/miniconda/Miniconda3-latest-MacOSX-arm64.sh -o ~/miniconda3/miniconda.sh
299
+ bash ~/miniconda3/miniconda.sh -b -u -p ~/miniconda3
300
+ rm ~/miniconda3/miniconda.sh
301
+ source ~/miniconda3/bin/activate
302
+ conda init --all
303
+ conda create -n env python=<< parameters.python_version >> -y
304
+ conda activate env
305
+ pip install --upgrade cmake
306
+ pip install nanobind==2.4.0
307
+ pip install --upgrade setuptools
308
+ pip install numpy
309
+ pip install twine
310
+ pip install build
311
+ - run:
312
+ name: Install Python package
313
+ command: |
314
+ conda activate env
315
+ env -u MACOSX_DEPLOYMENT_TARGET DEV_RELEASE=1 \
316
+ pip install . -v
317
+ - run:
318
+ name: Generate package stubs
319
+ command: |
320
+ conda activate env
321
+ pip install typing_extensions
322
+ python setup.py generate_stubs
323
+ - run:
324
+ name: Build Python package
325
+ command: |
326
+ conda activate env
327
+ python setup.py clean --all
328
+ << parameters.build_env >> MLX_BUILD_STAGE=1 python -m build -w
329
+ - when:
330
+ condition:
331
+ equal: ["3.9", << parameters.python_version >>]
332
+ steps:
333
+ - run:
334
+ name: Build common package
335
+ command: |
336
+ conda activate env
337
+ python setup.py clean --all
338
+ << parameters.build_env >> MLX_BUILD_STAGE=2 python -m build -w
339
+ - when:
340
+ condition: << parameters.build_env >>
341
+ steps:
342
+ - run:
343
+ name: Upload package
344
+ command: |
345
+ conda activate env
346
+ twine upload dist/*
347
+ - store_artifacts:
348
+ path: dist/
349
+
350
+ build_linux_release:
351
+ parameters:
352
+ python_version:
353
+ type: string
354
+ default: "3.9"
355
+ build_env:
356
+ type: string
357
+ default: ""
358
+ machine:
359
+ image: ubuntu-2204:current
360
+ resource_class: large
361
+ steps:
362
+ - checkout
363
+ - run:
364
+ name: Build wheel
365
+ command: |
366
+ PYTHON=python<< parameters.python_version >>
367
+ export DEBIAN_FRONTEND=noninteractive
368
+ export NEEDRESTART_MODE=a
369
+ sudo apt-get update
370
+ TZ=Etc/UTC sudo apt-get -y install tzdata
371
+ sudo add-apt-repository -y ppa:deadsnakes/ppa
372
+ sudo apt-get install -y $PYTHON $PYTHON-dev $PYTHON-full
373
+ sudo apt-get install -y libblas-dev liblapack-dev liblapacke-dev
374
+ $PYTHON -m venv env
375
+ source env/bin/activate
376
+ pip install --upgrade pip
377
+ pip install --upgrade cmake
378
+ pip install auditwheel
379
+ pip install patchelf
380
+ pip install build
381
+ pip install twine
382
+ << parameters.build_env >> pip install ".[dev]" -v
383
+ pip install typing_extensions
384
+ python setup.py generate_stubs
385
+ python setup.py clean --all
386
+ MLX_BUILD_STAGE=1 << parameters.build_env >> python -m build -w
387
+ bash python/scripts/repair_linux.sh
388
+ - when:
389
+ condition:
390
+ equal: ["3.9", << parameters.python_version >>]
391
+ steps:
392
+ - run:
393
+ name: Build common package
394
+ command: |
395
+ source env/bin/activate
396
+ python setup.py clean --all
397
+ << parameters.build_env >> MLX_BUILD_STAGE=2 \
398
+ python -m build -w
399
+ auditwheel repair dist/mlx_cpu*.whl --plat manylinux_2_35_x86_64
400
+ - when:
401
+ condition: << parameters.build_env >>
402
+ steps:
403
+ - run:
404
+ name: Upload packages
405
+ command: |
406
+ source env/bin/activate
407
+ twine upload wheelhouse/*.whl
408
+ - store_artifacts:
409
+ path: wheelhouse/
410
+
411
+ build_cuda_release:
412
+ parameters:
413
+ build_env:
414
+ type: string
415
+ default: ""
416
+ machine:
417
+ image: ubuntu-2204:current
418
+ resource_class: xlarge
419
+ steps:
420
+ - checkout
421
+ - run:
422
+ name: Build wheel
423
+ command: |
424
+ export DEBIAN_FRONTEND=noninteractive
425
+ export NEEDRESTART_MODE=a
426
+ wget https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2404/x86_64/cuda-keyring_1.1-1_all.deb
427
+ sudo dpkg -i cuda-keyring_1.1-1_all.deb
428
+ sudo apt-get update
429
+ sudo apt-get install cuda-toolkit-12-9 libcudnn9-dev-cuda-12
430
+ sudo apt-get install libblas-dev liblapack-dev liblapacke-dev
431
+ sudo apt-get install zip
432
+ pip install auditwheel
433
+ pip install patchelf
434
+ pip install build
435
+ pip install twine
436
+ export PATH=/usr/local/cuda/bin${PATH:+:${PATH}}
437
+ export LD_LIBRARY_PATH=/usr/local/cuda/lib64${LD_LIBRARY_PATH:+:${LD_LIBRARY_PATH}}
438
+ << parameters.build_env >> MLX_BUILD_STAGE=2 \
439
+ CMAKE_ARGS="-DMLX_BUILD_CUDA=ON -DCMAKE_CUDA_COMPILER=`which nvcc`" \
440
+ python -m build -w
441
+ bash python/scripts/repair_cuda.sh
442
+ - when:
443
+ condition: << parameters.build_env >>
444
+ steps:
445
+ - run:
446
+ name: Upload package
447
+ command: |
448
+ twine upload wheelhouse/*.whl
449
+ - store_artifacts:
450
+ path: wheelhouse/
451
+
452
+ workflows:
453
+ build_and_test:
454
+ when:
455
+ and:
456
+ - matches:
457
+ pattern: "^(?!pull/)[-\\w]+$"
458
+ value: << pipeline.git.branch >>
459
+ - not: << pipeline.parameters.nightly_build >>
460
+ - not: << pipeline.parameters.test_release >>
461
+ jobs:
462
+ - mac_build_and_test:
463
+ matrix:
464
+ parameters:
465
+ macosx_deployment_target: ["13.5", "15.0"]
466
+ - linux_build_and_test
467
+ - cuda_build_and_test:
468
+ matrix:
469
+ parameters:
470
+ image_date: ["2023.11.1", "2025.05.1"]
471
+ - build_documentation
472
+
473
+ build_pypi_release:
474
+ when:
475
+ and:
476
+ - not: << pipeline.parameters.nightly_build >>
477
+ - not: << pipeline.parameters.test_release >>
478
+ jobs:
479
+ - build_release:
480
+ filters:
481
+ tags:
482
+ only: /^v.*/
483
+ branches:
484
+ ignore: /.*/
485
+ matrix:
486
+ parameters:
487
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
488
+ macosx_deployment_target: ["13.5", "14.0", "15.0"]
489
+ build_env: ["PYPI_RELEASE=1"]
490
+ xcode_version: ["26.0.0"]
491
+ - build_documentation:
492
+ filters:
493
+ tags:
494
+ only: /^v.*/
495
+ branches:
496
+ ignore: /.*/
497
+ upload-docs: true
498
+ - build_linux_release:
499
+ filters:
500
+ tags:
501
+ only: /^v.*/
502
+ branches:
503
+ ignore: /.*/
504
+ matrix:
505
+ parameters:
506
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
507
+ build_env: ["PYPI_RELEASE=1"]
508
+ - build_cuda_release:
509
+ filters:
510
+ tags:
511
+ only: /^v.*/
512
+ branches:
513
+ ignore: /.*/
514
+ matrix:
515
+ parameters:
516
+ build_env: ["PYPI_RELEASE=1"]
517
+
518
+ prb:
519
+ when:
520
+ matches:
521
+ pattern: "^pull/\\d+(/head)?$"
522
+ value: << pipeline.git.branch >>
523
+ jobs:
524
+ - hold:
525
+ type: approval
526
+ - apple/authenticate:
527
+ context: pr-approval
528
+ - mac_build_and_test:
529
+ requires: [ hold ]
530
+ matrix:
531
+ parameters:
532
+ macosx_deployment_target: ["13.5", "15.0"]
533
+ - linux_build_and_test:
534
+ requires: [ hold ]
535
+ - cuda_build_and_test:
536
+ requires: [ hold ]
537
+ matrix:
538
+ parameters:
539
+ image_date: ["2023.11.1", "2025.05.1"]
540
+ nightly_build:
541
+ when:
542
+ and:
543
+ - equal: [ main, << pipeline.git.branch >> ]
544
+ - << pipeline.parameters.nightly_build >>
545
+ jobs:
546
+ - build_release:
547
+ matrix:
548
+ parameters:
549
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
550
+ macosx_deployment_target: ["13.5", "14.0", "15.0"]
551
+ xcode_version: ["26.0.0"]
552
+ - build_linux_release:
553
+ matrix:
554
+ parameters:
555
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
556
+ - build_cuda_release
557
+
558
+ build_dev_release:
559
+ when:
560
+ and:
561
+ - equal: [ main, << pipeline.git.branch >> ]
562
+ - << pipeline.parameters.test_release >>
563
+ jobs:
564
+ - build_release:
565
+ matrix:
566
+ parameters:
567
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
568
+ macosx_deployment_target: ["13.5", "14.0", "15.0"]
569
+ build_env: ["DEV_RELEASE=1"]
570
+ xcode_version: ["26.0.0"]
571
+ - build_linux_release:
572
+ matrix:
573
+ parameters:
574
+ python_version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
575
+ build_env: ["DEV_RELEASE=1"]
576
+ - build_cuda_release:
577
+ matrix:
578
+ parameters:
579
+ build_env: ["DEV_RELEASE=1"]
ml-stable-diffusion/mlx/.clang-format ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ AccessModifierOffset: -1
3
+ AlignAfterOpenBracket: AlwaysBreak
4
+ AlignConsecutiveAssignments: false
5
+ AlignConsecutiveDeclarations: false
6
+ AlignEscapedNewlinesLeft: true
7
+ AlignOperands: false
8
+ AlignTrailingComments: false
9
+ AllowAllParametersOfDeclarationOnNextLine: false
10
+ AllowShortBlocksOnASingleLine: false
11
+ AllowShortCaseLabelsOnASingleLine: false
12
+ AllowShortFunctionsOnASingleLine: Empty
13
+ AllowShortIfStatementsOnASingleLine: false
14
+ AllowShortLoopsOnASingleLine: false
15
+ AlwaysBreakAfterReturnType: None
16
+ AlwaysBreakBeforeMultilineStrings: true
17
+ AlwaysBreakTemplateDeclarations: true
18
+ BinPackArguments: false
19
+ BinPackParameters: false
20
+ BraceWrapping:
21
+ AfterClass: false
22
+ AfterControlStatement: false
23
+ AfterEnum: false
24
+ AfterFunction: false
25
+ AfterNamespace: false
26
+ AfterObjCDeclaration: false
27
+ AfterStruct: false
28
+ AfterUnion: false
29
+ BeforeCatch: false
30
+ BeforeElse: false
31
+ IndentBraces: false
32
+ BreakBeforeBinaryOperators: None
33
+ BreakBeforeBraces: Attach
34
+ BreakBeforeTernaryOperators: true
35
+ BreakConstructorInitializersBeforeComma: false
36
+ BreakAfterJavaFieldAnnotations: false
37
+ BreakStringLiterals: false
38
+ ColumnLimit: 80
39
+ CommentPragmas: '^ IWYU pragma:'
40
+ ConstructorInitializerAllOnOneLineOrOnePerLine: true
41
+ ConstructorInitializerIndentWidth: 4
42
+ ContinuationIndentWidth: 4
43
+ Cpp11BracedListStyle: true
44
+ DerivePointerAlignment: false
45
+ DisableFormat: false
46
+ ForEachMacros: [ FOR_EACH, FOR_EACH_R, FOR_EACH_RANGE, ]
47
+ IncludeCategories:
48
+ - Regex: '^<.*\.h(pp)?>'
49
+ Priority: 1
50
+ - Regex: '^<.*'
51
+ Priority: 2
52
+ - Regex: '.*'
53
+ Priority: 3
54
+ IndentCaseLabels: true
55
+ IndentWidth: 2
56
+ IndentWrappedFunctionNames: false
57
+ KeepEmptyLinesAtTheStartOfBlocks: false
58
+ MacroBlockBegin: ''
59
+ MacroBlockEnd: ''
60
+ MaxEmptyLinesToKeep: 1
61
+ NamespaceIndentation: None
62
+ ObjCBlockIndentWidth: 2
63
+ ObjCSpaceAfterProperty: false
64
+ ObjCSpaceBeforeProtocolList: false
65
+ PenaltyBreakBeforeFirstCallParameter: 1
66
+ PenaltyBreakComment: 300
67
+ PenaltyBreakFirstLessLess: 120
68
+ PenaltyBreakString: 1000
69
+ PenaltyExcessCharacter: 1000000
70
+ PenaltyReturnTypeOnItsOwnLine: 200
71
+ PointerAlignment: Left
72
+ ReflowComments: true
73
+ SortIncludes: true
74
+ SpaceAfterCStyleCast: false
75
+ SpaceBeforeAssignmentOperators: true
76
+ SpaceBeforeParens: ControlStatements
77
+ SpaceInEmptyParentheses: false
78
+ SpacesBeforeTrailingComments: 1
79
+ SpacesInAngles: false
80
+ SpacesInContainerLiterals: true
81
+ SpacesInCStyleCastParentheses: false
82
+ SpacesInParentheses: false
83
+ SpacesInSquareBrackets: false
84
+ Standard: Cpp11
85
+ TabWidth: 8
86
+ UseTab: Never
87
+ ...
ml-stable-diffusion/mlx/.github/ISSUE_TEMPLATE/bug_report.md ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ name: Bug report
3
+ about: Create a report about an issue you've encountered
4
+ title: "[BUG] "
5
+ labels: ''
6
+ assignees: ''
7
+
8
+ ---
9
+
10
+ **Describe the bug**
11
+ A clear and concise description of what the bug is.
12
+
13
+ **To Reproduce**
14
+
15
+ Include code snippet
16
+ ```python
17
+
18
+ ```
19
+
20
+ **Expected behavior**
21
+ A clear and concise description of what you expected to happen.
22
+
23
+ **Desktop (please complete the following information):**
24
+ - OS Version: [e.g. macOS 14.1.2]
25
+ - Version [e.g. 0.7.0]
26
+
27
+ **Additional context**
28
+ Add any other context about the problem here.
ml-stable-diffusion/mlx/.github/pull_request_template.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Proposed changes
2
+
3
+ Please include a description of the problem or feature this PR is addressing. If there is a corresponding issue, include the issue #.
4
+
5
+ ## Checklist
6
+
7
+ Put an `x` in the boxes that apply.
8
+
9
+ - [ ] I have read the [CONTRIBUTING](https://github.com/ml-explore/mlx/blob/main/CONTRIBUTING.md) document
10
+ - [ ] I have run `pre-commit run --all-files` to format my code / installed pre-commit prior to committing changes
11
+ - [ ] I have added tests that prove my fix is effective or that my feature works
12
+ - [ ] I have updated the necessary documentation (if needed)
ml-stable-diffusion/mlx/.github/workflows/pull_request.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ on:
2
+ pull_request:
3
+ branches:
4
+ - main
5
+
6
+ jobs:
7
+ check_lint:
8
+ runs-on: ubuntu-latest
9
+ steps:
10
+ - uses: actions/checkout@v4
11
+ - uses: actions/setup-python@v4
12
+ with:
13
+ python-version: 3.8
14
+ - name: Install dependencies
15
+ run: |
16
+ python -m pip install --upgrade pip
17
+ pip install pre-commit black isort clang-format
18
+ - name: Run lint
19
+ run: |
20
+ pre-commit run --all-files
ml-stable-diffusion/mlx/.gitignore ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # tensor files
10
+ *.safe
11
+ *.safetensors
12
+
13
+ # Metal libraries
14
+ *.metallib
15
+ venv/
16
+
17
+ # Distribution / packaging
18
+ python/mlx/core
19
+ python/mlx/share
20
+ python/mlx/include
21
+ .Python
22
+ build/
23
+ develop-eggs/
24
+ dist/
25
+ downloads/
26
+ eggs/
27
+ .eggs/
28
+ lib/
29
+ lib64/
30
+ parts/
31
+ sdist/
32
+ var/
33
+ wheels/
34
+ share/python-wheels/
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+ uv.lock
40
+
41
+ # vim
42
+ *.swp
43
+
44
+ # Ignore build dir
45
+ build/
46
+
47
+ # Prerequisites
48
+ *.d
49
+
50
+ # Compiled Object files
51
+ *.slo
52
+ *.lo
53
+ *.o
54
+ *.obj
55
+
56
+ # Precompiled Headers
57
+ *.gch
58
+ *.pch
59
+
60
+ # Compiled Dynamic libraries
61
+ *.so
62
+ *.dylib
63
+ *.dll
64
+
65
+ # Fortran module files
66
+ *.mod
67
+ *.smod
68
+
69
+ # Compiled Static libraries
70
+ *.lai
71
+ *.la
72
+ *.a
73
+ *.lib
74
+
75
+ # Executables
76
+ *.exe
77
+ *.out
78
+ *.app
79
+
80
+ # Debug symbols
81
+ *.pdb
82
+
83
+ # VSCode
84
+ .vscode/
85
+ .DS_Store
86
+
87
+ # Jetbrains
88
+ .cache
ml-stable-diffusion/mlx/.pre-commit-config.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/mirrors-clang-format
3
+ rev: v19.1.7
4
+ hooks:
5
+ - id: clang-format
6
+ # Using this mirror lets us use mypyc-compiled black, which is about 2x faster
7
+ - repo: https://github.com/psf/black-pre-commit-mirror
8
+ rev: 25.1.0
9
+ hooks:
10
+ - id: black
11
+
12
+ - repo: https://github.com/pycqa/isort
13
+ rev: 6.0.0
14
+ hooks:
15
+ - id: isort
16
+ args:
17
+ - --profile=black
18
+ - repo: https://github.com/cheshirekow/cmake-format-precommit
19
+ rev: v0.6.13
20
+ hooks:
21
+ - id: cmake-format
ml-stable-diffusion/mlx/ACKNOWLEDGMENTS.md ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Individual Contributors
2
+
3
+ If you wish to be acknowledged for your contributions, please list your name
4
+ with a short description of your contribution(s) below. For example:
5
+
6
+ - Jane Smith: Added the `foo` and `bar` ops.
7
+
8
+ MLX was developed with contributions from the following individuals:
9
+
10
+ - Nripesh Niketan: Added `softsign`, `softmax`, `hardswish`, `logsoftmax` activation functions. Added `dropout3d` ops. Added `LogicalAnd` and `LogicalOR` ops. Added `clip_grad_norm` along with `tree_reduce`. Added `cross`. Added `orthogonal` initializer.
11
+ - Juarez Bochi: Fixed bug in cross attention.
12
+ - Justin Deschenaux: Sine, Cosine, arange, randint, truncated normal, bernoulli, lion optimizer, Dropout2d, linear and logistic regression python example.
13
+ - Diogo Da Cruz: Added `tri`, `tril`, `triu`, `tensordot`, `inner`, `outer`, `tile`, `StreamContext`, `stream`, safetensors support, `einsum`, and `einsum_path`.
14
+ - Gabrijel Boduljak: Added `mlx.core.linalg`, implemented `norm` method and `InstanceNorm` layer. Implemented pooling layers and ``Upsample``.
15
+ - Hinrik Snær Guðmundsson: Added `atleast_1d`, `atleast_2d`, `atleast_3d` ops.
16
+ - Luca Arnaboldi: Added `Ceil` and `Floor` ops; implemented pickling, copy and deepcopy for mlx arrays.
17
+ - Brian Keene & Atila Orhon, with Argmax Inc.: Added `fast.scaled_dot_product_attention`
18
+ - AmirHossein Razlighi: Added chaining support for some of the ops in `nn.Module`. Comparison works for non array objects in `mlx.core.array`. Exception handling for invalid operations in `mlx.core.array`.
19
+ - Gleb Pobudzey: Added the `where` primitive, and groups in 1D and 2D convolutions.
20
+ - Paul Paczuski: Improved stability of BCE loss calculation
21
+ - Max-Heinrich Laves: Added `conv_transpose1d`, `conv_transpose2d`, and `conv_transpose3d` ops.
22
+ - Gökdeniz Gülmez: Added the `Muon (MomentUm Orthogonalized by Newton-schulz)` optimizer, and the `ReLU²` activation function.
23
+
24
+ <a href="https://github.com/ml-explore/mlx/graphs/contributors">
25
+ <img class="dark-light" src="https://contrib.rocks/image?repo=ml-explore/mlx&anon=0&columns=20&max=100&r=true" />
26
+ </a>
27
+
28
+ # Organizations
29
+
30
+ MLX has received contributions from the following companies:
31
+ - NVIDIA Corporation & Affiliates
32
+
33
+ # Third-Party Software
34
+
35
+ MLX leverages several third-party software packages, listed here together with
36
+ their licenses copied verbatim.
37
+
38
+ ## PocketFFT
39
+
40
+ Copyright (C) 2010-2018 Max-Planck-Society
41
+ All rights reserved.
42
+
43
+ Redistribution and use in source and binary forms, with or without modification,
44
+ are permitted provided that the following conditions are met:
45
+
46
+ * Redistributions of source code must retain the above copyright notice, this
47
+ list of conditions and the following disclaimer.
48
+ * Redistributions in binary form must reproduce the above copyright notice, this
49
+ list of conditions and the following disclaimer in the documentation and/or
50
+ other materials provided with the distribution.
51
+ * Neither the name of the copyright holder nor the names of its contributors may
52
+ be used to endorse or promote products derived from this software without
53
+ specific prior written permission.
54
+
55
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
56
+ ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
57
+ WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
58
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
59
+ ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
60
+ (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
61
+ LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
62
+ ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
63
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
64
+ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
65
+
66
+ ## metal-cpp
67
+
68
+ Apache License
69
+ Version 2.0, January 2004
70
+ http://www.apache.org/licenses/
71
+
72
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
73
+
74
+ 1. Definitions.
75
+
76
+ "License" shall mean the terms and conditions for use, reproduction,
77
+ and distribution as defined by Sections 1 through 9 of this document.
78
+
79
+ "Licensor" shall mean the copyright owner or entity authorized by
80
+ the copyright owner that is granting the License.
81
+
82
+ "Legal Entity" shall mean the union of the acting entity and all
83
+ other entities that control, are controlled by, or are under common
84
+ control with that entity. For the purposes of this definition,
85
+ "control" means (i) the power, direct or indirect, to cause the
86
+ direction or management of such entity, whether by contract or
87
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
88
+ outstanding shares, or (iii) beneficial ownership of such entity.
89
+
90
+ "You" (or "Your") shall mean an individual or Legal Entity
91
+ exercising permissions granted by this License.
92
+
93
+ "Source" form shall mean the preferred form for making modifications,
94
+ including but not limited to software source code, documentation
95
+ source, and configuration files.
96
+
97
+ "Object" form shall mean any form resulting from mechanical
98
+ transformation or translation of a Source form, including but
99
+ not limited to compiled object code, generated documentation,
100
+ and conversions to other media types.
101
+
102
+ "Work" shall mean the work of authorship, whether in Source or
103
+ Object form, made available under the License, as indicated by a
104
+ copyright notice that is included in or attached to the work
105
+ (an example is provided in the Appendix below).
106
+
107
+ "Derivative Works" shall mean any work, whether in Source or Object
108
+ form, that is based on (or derived from) the Work and for which the
109
+ editorial revisions, annotations, elaborations, or other modifications
110
+ represent, as a whole, an original work of authorship. For the purposes
111
+ of this License, Derivative Works shall not include works that remain
112
+ separable from, or merely link (or bind by name) to the interfaces of,
113
+ the Work and Derivative Works thereof.
114
+
115
+ "Contribution" shall mean any work of authorship, including
116
+ the original version of the Work and any modifications or additions
117
+ to that Work or Derivative Works thereof, that is intentionally
118
+ submitted to Licensor for inclusion in the Work by the copyright owner
119
+ or by an individual or Legal Entity authorized to submit on behalf of
120
+ the copyright owner. For the purposes of this definition, "submitted"
121
+ means any form of electronic, verbal, or written communication sent
122
+ to the Licensor or its representatives, including but not limited to
123
+ communication on electronic mailing lists, source code control systems,
124
+ and issue tracking systems that are managed by, or on behalf of, the
125
+ Licensor for the purpose of discussing and improving the Work, but
126
+ excluding communication that is conspicuously marked or otherwise
127
+ designated in writing by the copyright owner as "Not a Contribution."
128
+
129
+ "Contributor" shall mean Licensor and any individual or Legal Entity
130
+ on behalf of whom a Contribution has been received by Licensor and
131
+ subsequently incorporated within the Work.
132
+
133
+ 2. Grant of Copyright License. Subject to the terms and conditions of
134
+ this License, each Contributor hereby grants to You a perpetual,
135
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
136
+ copyright license to reproduce, prepare Derivative Works of,
137
+ publicly display, publicly perform, sublicense, and distribute the
138
+ Work and such Derivative Works in Source or Object form.
139
+
140
+ 3. Grant of Patent License. Subject to the terms and conditions of
141
+ this License, each Contributor hereby grants to You a perpetual,
142
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
143
+ (except as stated in this section) patent license to make, have made,
144
+ use, offer to sell, sell, import, and otherwise transfer the Work,
145
+ where such license applies only to those patent claims licensable
146
+ by such Contributor that are necessarily infringed by their
147
+ Contribution(s) alone or by combination of their Contribution(s)
148
+ with the Work to which such Contribution(s) was submitted. If You
149
+ institute patent litigation against any entity (including a
150
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
151
+ or a Contribution incorporated within the Work constitutes direct
152
+ or contributory patent infringement, then any patent licenses
153
+ granted to You under this License for that Work shall terminate
154
+ as of the date such litigation is filed.
155
+
156
+ 4. Redistribution. You may reproduce and distribute copies of the
157
+ Work or Derivative Works thereof in any medium, with or without
158
+ modifications, and in Source or Object form, provided that You
159
+ meet the following conditions:
160
+
161
+ (a) You must give any other recipients of the Work or
162
+ Derivative Works a copy of this License; and
163
+
164
+ (b) You must cause any modified files to carry prominent notices
165
+ stating that You changed the files; and
166
+
167
+ (c) You must retain, in the Source form of any Derivative Works
168
+ that You distribute, all copyright, patent, trademark, and
169
+ attribution notices from the Source form of the Work,
170
+ excluding those notices that do not pertain to any part of
171
+ the Derivative Works; and
172
+
173
+ (d) If the Work includes a "NOTICE" text file as part of its
174
+ distribution, then any Derivative Works that You distribute must
175
+ include a readable copy of the attribution notices contained
176
+ within such NOTICE file, excluding those notices that do not
177
+ pertain to any part of the Derivative Works, in at least one
178
+ of the following places: within a NOTICE text file distributed
179
+ as part of the Derivative Works; within the Source form or
180
+ documentation, if provided along with the Derivative Works; or,
181
+ within a display generated by the Derivative Works, if and
182
+ wherever such third-party notices normally appear. The contents
183
+ of the NOTICE file are for informational purposes only and
184
+ do not modify the License. You may add Your own attribution
185
+ notices within Derivative Works that You distribute, alongside
186
+ or as an addendum to the NOTICE text from the Work, provided
187
+ that such additional attribution notices cannot be construed
188
+ as modifying the License.
189
+
190
+ You may add Your own copyright statement to Your modifications and
191
+ may provide additional or different license terms and conditions
192
+ for use, reproduction, or distribution of Your modifications, or
193
+ for any such Derivative Works as a whole, provided Your use,
194
+ reproduction, and distribution of the Work otherwise complies with
195
+ the conditions stated in this License.
196
+
197
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
198
+ any Contribution intentionally submitted for inclusion in the Work
199
+ by You to the Licensor shall be under the terms and conditions of
200
+ this License, without any additional terms or conditions.
201
+ Notwithstanding the above, nothing herein shall supersede or modify
202
+ the terms of any separate license agreement you may have executed
203
+ with Licensor regarding such Contributions.
204
+
205
+ 6. Trademarks. This License does not grant permission to use the trade
206
+ names, trademarks, service marks, or product names of the Licensor,
207
+ except as required for reasonable and customary use in describing the
208
+ origin of the Work and reproducing the content of the NOTICE file.
209
+
210
+ 7. Disclaimer of Warranty. Unless required by applicable law or
211
+ agreed to in writing, Licensor provides the Work (and each
212
+ Contributor provides its Contributions) on an "AS IS" BASIS,
213
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
214
+ implied, including, without limitation, any warranties or conditions
215
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
216
+ PARTICULAR PURPOSE. You are solely responsible for determining the
217
+ appropriateness of using or redistributing the Work and assume any
218
+ risks associated with Your exercise of permissions under this License.
219
+
220
+ 8. Limitation of Liability. In no event and under no legal theory,
221
+ whether in tort (including negligence), contract, or otherwise,
222
+ unless required by applicable law (such as deliberate and grossly
223
+ negligent acts) or agreed to in writing, shall any Contributor be
224
+ liable to You for damages, including any direct, indirect, special,
225
+ incidental, or consequential damages of any character arising as a
226
+ result of this License or out of the use or inability to use the
227
+ Work (including but not limited to damages for loss of goodwill,
228
+ work stoppage, computer failure or malfunction, or any and all
229
+ other commercial damages or losses), even if such Contributor
230
+ has been advised of the possibility of such damages.
231
+
232
+ 9. Accepting Warranty or Additional Liability. While redistributing
233
+ the Work or Derivative Works thereof, You may choose to offer,
234
+ and charge a fee for, acceptance of support, warranty, indemnity,
235
+ or other liability obligations and/or rights consistent with this
236
+ License. However, in accepting such obligations, You may act only
237
+ on Your own behalf and on Your sole responsibility, not on behalf
238
+ of any other Contributor, and only if You agree to indemnify,
239
+ defend, and hold each Contributor harmless for any liability
240
+ incurred by, or claims asserted against, such Contributor by reason
241
+ of your accepting any such warranty or additional liability.
242
+
243
+ END OF TERMS AND CONDITIONS
244
+
245
+ APPENDIX: How to apply the Apache License to your work.
246
+
247
+ To apply the Apache License to your work, attach the following
248
+ boilerplate notice, with the fields enclosed by brackets "[]"
249
+ replaced with your own identifying information. (Don't include
250
+ the brackets!) The text should be enclosed in the appropriate
251
+ comment syntax for the file format. We also recommend that a
252
+ file or class name and description of purpose be included on the
253
+ same "printed page" as the copyright notice for easier
254
+ identification within third-party archives.
255
+
256
+ Copyright © 2023 Apple Inc.
257
+
258
+ Licensed under the Apache License, Version 2.0 (the "License");
259
+ you may not use this file except in compliance with the License.
260
+ You may obtain a copy of the License at
261
+
262
+ http://www.apache.org/licenses/LICENSE-2.0
263
+
264
+ Unless required by applicable law or agreed to in writing, software
265
+ distributed under the License is distributed on an "AS IS" BASIS,
266
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
267
+ See the License for the specific language governing permissions and
268
+ limitations under the License.
ml-stable-diffusion/mlx/CITATION.cff ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ title: mlx
3
+ message: >-
4
+ If you use this software, please cite it using the
5
+ metadata from this file.
6
+ type: software
7
+ authors:
8
+ - given-names: Awni
9
+ family-names: Hannun
10
+ affiliation: Apple
11
+ - given-names: Jagrit
12
+ family-names: Digani
13
+ affiliation: Apple
14
+ - given-names: Angelos
15
+ family-names: Katharopoulos
16
+ affiliation: Apple
17
+ - given-names: Ronan
18
+ family-names: Collobert
19
+ affiliation: Apple
20
+ repository-code: 'https://github.com/ml-explore'
21
+ abstract: >-
22
+ MLX: efficient and flexible machine learning on Apple
23
+ silicon
24
+ license: MIT
ml-stable-diffusion/mlx/CMakeLists.txt ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cmake_minimum_required(VERSION 3.25)
2
+
3
+ if(NOT MLX_VERSION)
4
+ file(STRINGS "mlx/version.h" _mlx_h_version REGEX "^#define MLX_VERSION_.*$")
5
+ string(REGEX MATCH "#define MLX_VERSION_MAJOR ([0-9]+)" _ "${_mlx_h_version}")
6
+ set(_major ${CMAKE_MATCH_1})
7
+ string(REGEX MATCH "#define MLX_VERSION_MINOR ([0-9]+)" _ "${_mlx_h_version}")
8
+ set(_minor ${CMAKE_MATCH_1})
9
+ string(REGEX MATCH "#define MLX_VERSION_PATCH ([0-9]+)" _ "${_mlx_h_version}")
10
+ set(_patch ${CMAKE_MATCH_1})
11
+ set(MLX_PROJECT_VERSION "${_major}.${_minor}.${_patch}")
12
+ set(MLX_VERSION ${MLX_PROJECT_VERSION})
13
+ else()
14
+ string(REGEX REPLACE "^([0-9]+\.[0-9]+\.[0-9]+).*" "\\1" MLX_PROJECT_VERSION
15
+ ${MLX_VERSION})
16
+ endif()
17
+
18
+ project(
19
+ mlx
20
+ LANGUAGES C CXX
21
+ VERSION ${MLX_PROJECT_VERSION})
22
+
23
+ # ----------------------------- Setup -----------------------------
24
+ set(CMAKE_MODULE_PATH "${PROJECT_SOURCE_DIR}/cmake")
25
+ set(CMAKE_CXX_STANDARD 17)
26
+ set(CMAKE_CXX_STANDARD_REQUIRED ON)
27
+ set(CMAKE_POSITION_INDEPENDENT_CODE ON)
28
+ set(CMAKE_INSTALL_MESSAGE NEVER)
29
+ set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
30
+
31
+ # ----------------------------- Configuration -----------------------------
32
+ option(MLX_BUILD_TESTS "Build tests for mlx" ON)
33
+ option(MLX_BUILD_EXAMPLES "Build examples for mlx" ON)
34
+ option(MLX_BUILD_BENCHMARKS "Build benchmarks for mlx" OFF)
35
+ option(MLX_BUILD_PYTHON_BINDINGS "Build python bindings for mlx" OFF)
36
+ option(MLX_BUILD_METAL "Build metal backend" ON)
37
+ option(MLX_BUILD_CPU "Build cpu backend" ON)
38
+ option(MLX_BUILD_CUDA "Build cuda backend" OFF)
39
+ option(MLX_METAL_DEBUG "Enhance metal debug workflow" OFF)
40
+ option(MLX_ENABLE_X64_MAC "Enable building for x64 macOS" OFF)
41
+ option(MLX_BUILD_GGUF "Include support for GGUF format" ON)
42
+ option(MLX_BUILD_SAFETENSORS "Include support for safetensors format" ON)
43
+ option(MLX_BUILD_BLAS_FROM_SOURCE "Build OpenBLAS from source code" OFF)
44
+ option(MLX_METAL_JIT "Use JIT compilation for Metal kernels" OFF)
45
+ option(MLX_USE_CCACHE "Use CCache for compilation cache when available" ON)
46
+ option(BUILD_SHARED_LIBS "Build mlx as a shared library" OFF)
47
+ option(USE_SYSTEM_FMT "Use system's provided fmt library" OFF)
48
+
49
+ # --------------------- Processor tests -------------------------
50
+ message(
51
+ STATUS
52
+ "Building MLX for ${CMAKE_SYSTEM_PROCESSOR} processor on ${CMAKE_SYSTEM_NAME}"
53
+ )
54
+
55
+ if(${CMAKE_SYSTEM_NAME} MATCHES "Darwin")
56
+ if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "x86_64")
57
+ if(NOT MLX_ENABLE_X64_MAC)
58
+ message(
59
+ FATAL_ERROR
60
+ "Building for x86_64 on macOS is not supported."
61
+ " If you are on an Apple silicon system, check the build"
62
+ " documentation for possible fixes: "
63
+ "https://ml-explore.github.io/mlx/build/html/install.html#build-from-source"
64
+ )
65
+ else()
66
+ set(MLX_BUILD_METAL OFF)
67
+ message(WARNING "Building for x86_64 arch is not officially supported.")
68
+ endif()
69
+ endif()
70
+ else()
71
+ set(MLX_BUILD_METAL OFF)
72
+ endif()
73
+
74
+ if(MLX_USE_CCACHE)
75
+ find_program(CCACHE_PROGRAM ccache)
76
+ if(CCACHE_PROGRAM)
77
+ set(CMAKE_C_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
78
+ set(CMAKE_CXX_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
79
+ set(CMAKE_CUDA_COMPILER_LAUNCHER "${CCACHE_PROGRAM}")
80
+ endif()
81
+ endif()
82
+
83
+ # ----------------------------- Lib -----------------------------
84
+
85
+ include(FetchContent)
86
+ # Avoid warning about DOWNLOAD_EXTRACT_TIMESTAMP in CMake 3.24:
87
+ cmake_policy(SET CMP0135 NEW)
88
+
89
+ add_library(mlx)
90
+
91
+ if(MLX_BUILD_CUDA)
92
+ enable_language(CUDA)
93
+ endif()
94
+
95
+ if(MLX_BUILD_METAL)
96
+ find_library(METAL_LIB Metal)
97
+ find_library(FOUNDATION_LIB Foundation)
98
+ find_library(QUARTZ_LIB QuartzCore)
99
+ if(METAL_LIB)
100
+ message(STATUS "Metal found ${METAL_LIB}")
101
+ else()
102
+ message(
103
+ FATAL_ERROR
104
+ "Metal not found. Set MLX_BUILD_METAL=OFF to build without GPU")
105
+ endif()
106
+
107
+ if(MLX_METAL_DEBUG)
108
+ add_compile_definitions(MLX_METAL_DEBUG)
109
+ endif()
110
+
111
+ # Throw an error if xcrun is not found
112
+ execute_process(
113
+ COMMAND zsh "-c" "/usr/bin/xcrun -sdk macosx --show-sdk-version"
114
+ OUTPUT_VARIABLE MACOS_SDK_VERSION
115
+ OUTPUT_STRIP_TRAILING_WHITESPACE COMMAND_ERROR_IS_FATAL ANY)
116
+
117
+ if(${MACOS_SDK_VERSION} LESS 14.0)
118
+ message(
119
+ FATAL_ERROR
120
+ "MLX requires macOS SDK >= 14.0 to be built with MLX_BUILD_METAL=ON")
121
+ endif()
122
+ message(STATUS "Building with macOS SDK version ${MACOS_SDK_VERSION}")
123
+
124
+ set(METAL_CPP_URL
125
+ https://developer.apple.com/metal/cpp/files/metal-cpp_macOS15_iOS18.zip)
126
+
127
+ if(NOT CMAKE_OSX_DEPLOYMENT_TARGET STREQUAL "")
128
+ set(XCRUN_FLAGS "-mmacosx-version-min=${CMAKE_OSX_DEPLOYMENT_TARGET}")
129
+ endif()
130
+ execute_process(
131
+ COMMAND
132
+ zsh "-c"
133
+ "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
134
+ OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)
135
+ FetchContent_Declare(metal_cpp URL ${METAL_CPP_URL})
136
+
137
+ FetchContent_MakeAvailable(metal_cpp)
138
+ target_include_directories(
139
+ mlx PUBLIC $<BUILD_INTERFACE:${metal_cpp_SOURCE_DIR}>
140
+ $<INSTALL_INTERFACE:include/metal_cpp>)
141
+ target_link_libraries(mlx PUBLIC ${METAL_LIB} ${FOUNDATION_LIB} ${QUARTZ_LIB})
142
+ endif()
143
+
144
+ if(CMAKE_SYSTEM_NAME STREQUAL "Linux")
145
+ # With newer clang/gcc versions the following libs are implicitly linked, but when
146
+ # building on old distributions they need to be explicitly listed.
147
+ target_link_libraries(mlx PRIVATE dl pthread)
148
+ endif()
149
+
150
+ if(WIN32)
151
+ if(MSVC)
152
+ # GGUF does not build with MSVC.
153
+ set(MLX_BUILD_GGUF OFF)
154
+ # There is no prebuilt OpenBLAS distribution for MSVC.
155
+ set(MLX_BUILD_BLAS_FROM_SOURCE ON)
156
+ endif()
157
+ # Windows implementation of dlfcn.h APIs.
158
+ FetchContent_Declare(
159
+ dlfcn-win32
160
+ GIT_REPOSITORY https://github.com/dlfcn-win32/dlfcn-win32.git
161
+ GIT_TAG v1.4.1
162
+ EXCLUDE_FROM_ALL)
163
+ block()
164
+ set(BUILD_SHARED_LIBS OFF)
165
+ FetchContent_MakeAvailable(dlfcn-win32)
166
+ endblock()
167
+ target_include_directories(mlx PRIVATE "${dlfcn-win32_SOURCE_DIR}/src")
168
+ target_link_libraries(mlx PRIVATE dl)
169
+ endif()
170
+
171
+ if(MLX_BUILD_CPU)
172
+ find_library(ACCELERATE_LIBRARY Accelerate)
173
+ if(ACCELERATE_LIBRARY)
174
+ message(STATUS "Accelerate found ${ACCELERATE_LIBRARY}")
175
+ set(MLX_BUILD_ACCELERATE ON)
176
+ else()
177
+ message(STATUS "Accelerate not found, using default backend.")
178
+ set(MLX_BUILD_ACCELERATE OFF)
179
+ endif()
180
+
181
+ if(MLX_BUILD_ACCELERATE)
182
+ target_link_libraries(mlx PUBLIC ${ACCELERATE_LIBRARY})
183
+ add_compile_definitions(MLX_USE_ACCELERATE)
184
+ add_compile_definitions(ACCELERATE_NEW_LAPACK)
185
+ elseif(MLX_BUILD_BLAS_FROM_SOURCE)
186
+ # Download and build OpenBLAS from source code.
187
+ FetchContent_Declare(
188
+ openblas
189
+ GIT_REPOSITORY https://github.com/OpenMathLib/OpenBLAS.git
190
+ GIT_TAG v0.3.28
191
+ EXCLUDE_FROM_ALL)
192
+ set(BUILD_STATIC_LIBS ON) # link statically
193
+ set(NOFORTRAN ON) # msvc has no fortran compiler
194
+ FetchContent_MakeAvailable(openblas)
195
+ target_link_libraries(mlx PRIVATE openblas)
196
+ target_include_directories(
197
+ mlx PRIVATE "${openblas_SOURCE_DIR}/lapack-netlib/LAPACKE/include"
198
+ "${CMAKE_BINARY_DIR}/generated" "${CMAKE_BINARY_DIR}")
199
+ else()
200
+ if(${CMAKE_HOST_APPLE})
201
+ # The blas shipped in macOS SDK is not supported, search homebrew for
202
+ # openblas instead.
203
+ set(BLA_VENDOR OpenBLAS)
204
+ set(LAPACK_ROOT
205
+ "${LAPACK_ROOT};$ENV{LAPACK_ROOT};/usr/local/opt/openblas")
206
+ endif()
207
+ # Search and link with lapack.
208
+ find_package(LAPACK REQUIRED)
209
+ if(NOT LAPACK_FOUND)
210
+ message(FATAL_ERROR "Must have LAPACK installed")
211
+ endif()
212
+ find_path(LAPACK_INCLUDE_DIRS lapacke.h /usr/include /usr/local/include
213
+ /usr/local/opt/openblas/include)
214
+ message(STATUS "Lapack lib " ${LAPACK_LIBRARIES})
215
+ message(STATUS "Lapack include " ${LAPACK_INCLUDE_DIRS})
216
+ target_include_directories(mlx PRIVATE ${LAPACK_INCLUDE_DIRS})
217
+ target_link_libraries(mlx PRIVATE ${LAPACK_LIBRARIES})
218
+ # List blas after lapack otherwise we may accidentally include an old
219
+ # version of lapack.h from the include dirs of blas.
220
+ find_package(BLAS REQUIRED)
221
+ if(NOT BLAS_FOUND)
222
+ message(FATAL_ERROR "Must have BLAS installed")
223
+ endif()
224
+ # TODO find a cleaner way to do this
225
+ find_path(BLAS_INCLUDE_DIRS cblas.h /usr/include /usr/local/include
226
+ $ENV{BLAS_HOME}/include)
227
+ message(STATUS "Blas lib " ${BLAS_LIBRARIES})
228
+ message(STATUS "Blas include " ${BLAS_INCLUDE_DIRS})
229
+ target_include_directories(mlx PRIVATE ${BLAS_INCLUDE_DIRS})
230
+ target_link_libraries(mlx PRIVATE ${BLAS_LIBRARIES})
231
+ endif()
232
+ else()
233
+ set(MLX_BUILD_ACCELERATE OFF)
234
+ endif()
235
+
236
+ message(STATUS "Downloading json")
237
+ FetchContent_Declare(
238
+ json
239
+ URL https://github.com/nlohmann/json/releases/download/v3.11.3/json.tar.xz)
240
+ FetchContent_MakeAvailable(json)
241
+ target_include_directories(
242
+ mlx PRIVATE $<BUILD_INTERFACE:${json_SOURCE_DIR}/single_include/nlohmann>)
243
+
244
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/mlx)
245
+
246
+ target_include_directories(
247
+ mlx PUBLIC $<BUILD_INTERFACE:${CMAKE_CURRENT_LIST_DIR}>
248
+ $<INSTALL_INTERFACE:include>)
249
+
250
+ # Do not add mlx_EXPORTS define for shared library.
251
+ set_target_properties(mlx PROPERTIES DEFINE_SYMBOL "")
252
+
253
+ if(USE_SYSTEM_FMT)
254
+ find_package(fmt REQUIRED)
255
+ else()
256
+ FetchContent_Declare(
257
+ fmt
258
+ GIT_REPOSITORY https://github.com/fmtlib/fmt.git
259
+ GIT_TAG 10.2.1
260
+ EXCLUDE_FROM_ALL)
261
+ FetchContent_MakeAvailable(fmt)
262
+ endif()
263
+ target_link_libraries(mlx PRIVATE $<BUILD_INTERFACE:fmt::fmt-header-only>)
264
+
265
+ if(MLX_BUILD_PYTHON_BINDINGS)
266
+ message(STATUS "Building Python bindings.")
267
+ find_package(
268
+ Python 3.8
269
+ COMPONENTS Interpreter Development.Module
270
+ REQUIRED)
271
+ execute_process(
272
+ COMMAND "${Python_EXECUTABLE}" -m nanobind --cmake_dir
273
+ OUTPUT_STRIP_TRAILING_WHITESPACE
274
+ OUTPUT_VARIABLE nanobind_ROOT)
275
+ find_package(nanobind CONFIG REQUIRED)
276
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/python/src)
277
+ endif()
278
+
279
+ if(MLX_BUILD_TESTS)
280
+ include(CTest)
281
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/tests)
282
+ endif()
283
+
284
+ if(MLX_BUILD_EXAMPLES)
285
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/examples/cpp)
286
+ endif()
287
+
288
+ if(MLX_BUILD_BENCHMARKS)
289
+ add_subdirectory(${CMAKE_CURRENT_LIST_DIR}/benchmarks/cpp)
290
+ endif()
291
+
292
+ # ----------------------------- Installation -----------------------------
293
+ include(GNUInstallDirs)
294
+
295
+ # Install library
296
+ install(
297
+ TARGETS mlx
298
+ EXPORT MLXTargets
299
+ LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
300
+ ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
301
+ RUNTIME DESTINATION ${CMAKE_INSTALL_BINDIR}
302
+ INCLUDES
303
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR})
304
+
305
+ # Install headers
306
+ install(
307
+ DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/mlx
308
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}
309
+ COMPONENT headers
310
+ FILES_MATCHING
311
+ PATTERN "*.h"
312
+ PATTERN "backend/metal/kernels.h" EXCLUDE)
313
+
314
+ # Install metal dependencies
315
+ if(MLX_BUILD_METAL)
316
+
317
+ # Install metal cpp
318
+ install(
319
+ DIRECTORY ${metal_cpp_SOURCE_DIR}/
320
+ DESTINATION ${CMAKE_INSTALL_INCLUDEDIR}/metal_cpp
321
+ COMPONENT metal_cpp_source)
322
+
323
+ endif()
324
+
325
+ # Install cmake config
326
+ set(MLX_CMAKE_BUILD_CONFIG ${CMAKE_BINARY_DIR}/MLXConfig.cmake)
327
+ set(MLX_CMAKE_BUILD_VERSION_CONFIG ${CMAKE_BINARY_DIR}/MLXConfigVersion.cmake)
328
+ set(MLX_CMAKE_INSTALL_MODULE_DIR share/cmake/MLX)
329
+
330
+ install(
331
+ EXPORT MLXTargets
332
+ FILE MLXTargets.cmake
333
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
334
+
335
+ include(CMakePackageConfigHelpers)
336
+
337
+ write_basic_package_version_file(
338
+ ${MLX_CMAKE_BUILD_VERSION_CONFIG}
339
+ COMPATIBILITY SameMajorVersion
340
+ VERSION ${MLX_VERSION})
341
+
342
+ configure_package_config_file(
343
+ ${CMAKE_CURRENT_LIST_DIR}/mlx.pc.in ${MLX_CMAKE_BUILD_CONFIG}
344
+ INSTALL_DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR}
345
+ NO_CHECK_REQUIRED_COMPONENTS_MACRO
346
+ PATH_VARS CMAKE_INSTALL_LIBDIR CMAKE_INSTALL_INCLUDEDIR
347
+ MLX_CMAKE_INSTALL_MODULE_DIR)
348
+
349
+ install(FILES ${MLX_CMAKE_BUILD_CONFIG} ${MLX_CMAKE_BUILD_VERSION_CONFIG}
350
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
351
+
352
+ install(DIRECTORY ${CMAKE_MODULE_PATH}/
353
+ DESTINATION ${MLX_CMAKE_INSTALL_MODULE_DIR})
ml-stable-diffusion/mlx/CODE_OF_CONDUCT.md ADDED
@@ -0,0 +1,132 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributor Covenant Code of Conduct
2
+
3
+ ## Our Pledge
4
+
5
+ We as members, contributors, and leaders pledge to make participation in our
6
+ community a harassment-free experience for everyone, regardless of age, body
7
+ size, visible or invisible disability, ethnicity, sex characteristics, gender
8
+ identity and expression, level of experience, education, socio-economic status,
9
+ nationality, personal appearance, race, caste, color, religion, or sexual
10
+ identity and orientation.
11
+
12
+ We pledge to act and interact in ways that contribute to an open, welcoming,
13
+ diverse, inclusive, and healthy community.
14
+
15
+ ## Our Standards
16
+
17
+ Examples of behavior that contributes to a positive environment for our
18
+ community include:
19
+
20
+ * Demonstrating empathy and kindness toward other people
21
+ * Being respectful of differing opinions, viewpoints, and experiences
22
+ * Giving and gracefully accepting constructive feedback
23
+ * Accepting responsibility and apologizing to those affected by our mistakes,
24
+ and learning from the experience
25
+ * Focusing on what is best not just for us as individuals, but for the overall
26
+ community
27
+
28
+ Examples of unacceptable behavior include:
29
+
30
+ * The use of sexualized language or imagery, and sexual attention or advances of
31
+ any kind
32
+ * Trolling, insulting or derogatory comments, and personal or political attacks
33
+ * Public or private harassment
34
+ * Publishing others' private information, such as a physical or email address,
35
+ without their explicit permission
36
+ * Other conduct which could reasonably be considered inappropriate in a
37
+ professional setting
38
+
39
+ ## Enforcement Responsibilities
40
+
41
+ Community leaders are responsible for clarifying and enforcing our standards of
42
+ acceptable behavior and will take appropriate and fair corrective action in
43
+ response to any behavior that they deem inappropriate, threatening, offensive,
44
+ or harmful.
45
+
46
+ Community leaders have the right and responsibility to remove, edit, or reject
47
+ comments, commits, code, wiki edits, issues, and other contributions that are
48
+ not aligned to this Code of Conduct, and will communicate reasons for moderation
49
+ decisions when appropriate.
50
+
51
+ ## Scope
52
+
53
+ This Code of Conduct applies within all community spaces, and also applies when
54
+ an individual is officially representing the community in public spaces.
55
+ Examples of representing our community include using an official e-mail address,
56
+ posting via an official social media account, or acting as an appointed
57
+ representative at an online or offline event.
58
+
59
+ ## Enforcement
60
+
61
+ Instances of abusive, harassing, or otherwise unacceptable behavior may be
62
+ reported to the community leaders responsible for enforcement at
63
+ [opensource-conduct@group.apple.com](mailto:opensource-conduct@group.apple.com).
64
+ All complaints will be reviewed and investigated promptly and fairly.
65
+
66
+ All community leaders are obligated to respect the privacy and security of the
67
+ reporter of any incident.
68
+
69
+ ## Enforcement Guidelines
70
+
71
+ Community leaders will follow these Community Impact Guidelines in determining
72
+ the consequences for any action they deem in violation of this Code of Conduct:
73
+
74
+ ### 1. Correction
75
+
76
+ **Community Impact**: Use of inappropriate language or other behavior deemed
77
+ unprofessional or unwelcome in the community.
78
+
79
+ **Consequence**: A private, written warning from community leaders, providing
80
+ clarity around the nature of the violation and an explanation of why the
81
+ behavior was inappropriate. A public apology may be requested.
82
+
83
+ ### 2. Warning
84
+
85
+ **Community Impact**: A violation through a single incident or series of
86
+ actions.
87
+
88
+ **Consequence**: A warning with consequences for continued behavior. No
89
+ interaction with the people involved, including unsolicited interaction with
90
+ those enforcing the Code of Conduct, for a specified period of time. This
91
+ includes avoiding interactions in community spaces as well as external channels
92
+ like social media. Violating these terms may lead to a temporary or permanent
93
+ ban.
94
+
95
+ ### 3. Temporary Ban
96
+
97
+ **Community Impact**: A serious violation of community standards, including
98
+ sustained inappropriate behavior.
99
+
100
+ **Consequence**: A temporary ban from any sort of interaction or public
101
+ communication with the community for a specified period of time. No public or
102
+ private interaction with the people involved, including unsolicited interaction
103
+ with those enforcing the Code of Conduct, is allowed during this period.
104
+ Violating these terms may lead to a permanent ban.
105
+
106
+ ### 4. Permanent Ban
107
+
108
+ **Community Impact**: Demonstrating a pattern of violation of community
109
+ standards, including sustained inappropriate behavior, harassment of an
110
+ individual, or aggression toward or disparagement of classes of individuals.
111
+
112
+ **Consequence**: A permanent ban from any sort of public interaction within the
113
+ community.
114
+
115
+ ## Attribution
116
+
117
+ This Code of Conduct is adapted from the [Contributor Covenant][homepage],
118
+ version 2.1, available at
119
+ [https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
120
+
121
+ Community Impact Guidelines were inspired by
122
+ [Mozilla's code of conduct enforcement ladder][Mozilla CoC].
123
+
124
+ For answers to common questions about this code of conduct, see the FAQ at
125
+ [https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
126
+ [https://www.contributor-covenant.org/translations][translations].
127
+
128
+ [homepage]: https://www.contributor-covenant.org
129
+ [v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
130
+ [Mozilla CoC]: https://github.com/mozilla/diversity
131
+ [FAQ]: https://www.contributor-covenant.org/faq
132
+ [translations]: https://www.contributor-covenant.org/translations
ml-stable-diffusion/mlx/CONTRIBUTING.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to MLX
2
+
3
+ We want to make contributing to this project as easy and transparent as
4
+ possible.
5
+
6
+ ## Pull Requests
7
+
8
+ 1. Fork and submit pull requests to the repo.
9
+ 2. If you've added code that should be tested, add tests.
10
+ 3. If a change is likely to impact efficiency, run some of the benchmarks before
11
+ and after the change. Examples of benchmarks can be found in `benchmarks/python/`.
12
+ 4. If you've changed APIs, update the documentation.
13
+ 5. Every PR should have passing tests and at least one review.
14
+ 6. For code formatting install `pre-commit` using something like `pip install pre-commit` and run `pre-commit install`.
15
+ This should install hooks for running `black` and `clang-format` to ensure
16
+ consistent style for C++ and Python code.
17
+
18
+ You can also run the formatters manually as follows:
19
+
20
+ ```shell
21
+ clang-format -i file.cpp
22
+ ```
23
+
24
+ ```shell
25
+ black file.py
26
+ ```
27
+
28
+ or run `pre-commit run --all-files` to check all files in the repo.
29
+
30
+ ## Issues
31
+
32
+ We use GitHub issues to track public bugs. Please ensure your description is
33
+ clear and has sufficient instructions to be able to reproduce the issue.
34
+
35
+ ## License
36
+
37
+ By contributing to MLX, you agree that your contributions will be licensed
38
+ under the LICENSE file in the root directory of this source tree.
ml-stable-diffusion/mlx/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright © 2023 Apple Inc.
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
ml-stable-diffusion/mlx/MANIFEST.in ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ include CMakeLists.txt
2
+ include mlx.pc.in
3
+ recursive-include mlx/ *
4
+ include cmake/*
5
+ include python/src/*
6
+ include python/mlx/py.typed # support type hinting as in PEP-561
ml-stable-diffusion/mlx/README.md ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # MLX
2
+
3
+ [**Quickstart**](#quickstart) | [**Installation**](#installation) |
4
+ [**Documentation**](https://ml-explore.github.io/mlx/build/html/index.html) |
5
+ [**Examples**](#examples)
6
+
7
+ [![CircleCI](https://circleci.com/gh/ml-explore/mlx.svg?style=svg)](https://circleci.com/gh/ml-explore/mlx)
8
+
9
+ MLX is an array framework for machine learning on Apple silicon,
10
+ brought to you by Apple machine learning research.
11
+
12
+ Some key features of MLX include:
13
+
14
+ - **Familiar APIs**: MLX has a Python API that closely follows NumPy. MLX
15
+ also has fully featured C++, [C](https://github.com/ml-explore/mlx-c), and
16
+ [Swift](https://github.com/ml-explore/mlx-swift/) APIs, which closely mirror
17
+ the Python API. MLX has higher-level packages like `mlx.nn` and
18
+ `mlx.optimizers` with APIs that closely follow PyTorch to simplify building
19
+ more complex models.
20
+
21
+ - **Composable function transformations**: MLX supports composable function
22
+ transformations for automatic differentiation, automatic vectorization,
23
+ and computation graph optimization.
24
+
25
+ - **Lazy computation**: Computations in MLX are lazy. Arrays are only
26
+ materialized when needed.
27
+
28
+ - **Dynamic graph construction**: Computation graphs in MLX are constructed
29
+ dynamically. Changing the shapes of function arguments does not trigger
30
+ slow compilations, and debugging is simple and intuitive.
31
+
32
+ - **Multi-device**: Operations can run on any of the supported devices
33
+ (currently the CPU and the GPU).
34
+
35
+ - **Unified memory**: A notable difference between MLX and other frameworks
36
+ is the *unified memory model*. Arrays in MLX live in shared memory.
37
+ Operations on MLX arrays can be performed on any of the supported
38
+ device types without transferring data.
39
+
40
+ MLX is designed by machine learning researchers for machine learning
41
+ researchers. The framework is intended to be user-friendly, but still efficient
42
+ to train and deploy models. The design of the framework itself is also
43
+ conceptually simple. We intend to make it easy for researchers to extend and
44
+ improve MLX with the goal of quickly exploring new ideas.
45
+
46
+ The design of MLX is inspired by frameworks like
47
+ [NumPy](https://numpy.org/doc/stable/index.html),
48
+ [PyTorch](https://pytorch.org/), [Jax](https://github.com/google/jax), and
49
+ [ArrayFire](https://arrayfire.org/).
50
+
51
+ ## Examples
52
+
53
+ The [MLX examples repo](https://github.com/ml-explore/mlx-examples) has a
54
+ variety of examples, including:
55
+
56
+ - [Transformer language model](https://github.com/ml-explore/mlx-examples/tree/main/transformer_lm) training.
57
+ - Large-scale text generation with
58
+ [LLaMA](https://github.com/ml-explore/mlx-examples/tree/main/llms/llama) and
59
+ finetuning with [LoRA](https://github.com/ml-explore/mlx-examples/tree/main/lora).
60
+ - Generating images with [Stable Diffusion](https://github.com/ml-explore/mlx-examples/tree/main/stable_diffusion).
61
+ - Speech recognition with [OpenAI's Whisper](https://github.com/ml-explore/mlx-examples/tree/main/whisper).
62
+
63
+ ## Quickstart
64
+
65
+ See the [quick start
66
+ guide](https://ml-explore.github.io/mlx/build/html/usage/quick_start.html)
67
+ in the documentation.
68
+
69
+ ## Installation
70
+
71
+ MLX is available on [PyPI](https://pypi.org/project/mlx/). To install MLX on
72
+ macOS, run:
73
+
74
+ ```bash
75
+ pip install mlx
76
+ ```
77
+
78
+ To install the CUDA backend on Linux, run:
79
+
80
+ ```bash
81
+ pip install mlx[cuda]
82
+ ```
83
+
84
+ To install a CPU-only Linux package, run:
85
+
86
+ ```bash
87
+ pip install mlx[cpu]
88
+ ```
89
+
90
+ Check out the
91
+ [documentation](https://ml-explore.github.io/mlx/build/html/install.html#)
92
+ for more information on building the C++ and Python APIs from source.
93
+
94
+ ## Contributing
95
+
96
+ Check out the [contribution guidelines](https://github.com/ml-explore/mlx/tree/main/CONTRIBUTING.md) for more information
97
+ on contributing to MLX. See the
98
+ [docs](https://ml-explore.github.io/mlx/build/html/install.html) for more
99
+ information on building from source, and running tests.
100
+
101
+ We are grateful for all of [our
102
+ contributors](https://github.com/ml-explore/mlx/tree/main/ACKNOWLEDGMENTS.md#Individual-Contributors). If you contribute
103
+ to MLX and wish to be acknowledged, please add your name to the list in your
104
+ pull request.
105
+
106
+ ## Citing MLX
107
+
108
+ The MLX software suite was initially developed with equal contribution by Awni
109
+ Hannun, Jagrit Digani, Angelos Katharopoulos, and Ronan Collobert. If you find
110
+ MLX useful in your research and wish to cite it, please use the following
111
+ BibTeX entry:
112
+
113
+ ```
114
+ @software{mlx2023,
115
+ author = {Awni Hannun and Jagrit Digani and Angelos Katharopoulos and Ronan Collobert},
116
+ title = {{MLX}: Efficient and flexible machine learning on Apple silicon},
117
+ url = {https://github.com/ml-explore},
118
+ version = {0.0},
119
+ year = {2023},
120
+ }
121
+ ```
ml-stable-diffusion/mlx/benchmarks/cpp/CMakeLists.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
# Build one benchmark executable from SRCFILE. The target is named after the
# source file's base name (directory and extension stripped via NAME_WE) and is
# linked privately against the mlx target.
function(build_benchmark SRCFILE)
  get_filename_component(benchmark_name ${SRCFILE} NAME_WE)
  add_executable(${benchmark_name} ${SRCFILE})
  target_link_libraries(${benchmark_name} PRIVATE mlx)
endfunction()
7
+
8
+ build_benchmark(single_ops.cpp)
9
+ build_benchmark(irregular_strides.cpp)
10
+ build_benchmark(compare_devices.cpp)
11
+ build_benchmark(autograd.cpp)
ml-stable-diffusion/mlx/benchmarks/cpp/autograd.cpp ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #include <iostream>
4
+
5
+ #include "mlx/mlx.h"
6
+ #include "time_utils.h"
7
+
8
+ namespace mx = mlx::core;
9
+
10
+ void time_value_and_grad() {
11
+ auto x = mx::ones({200, 1000});
12
+ mx::eval(x);
13
+ auto fn = [](mx::array x) {
14
+ for (int i = 0; i < 20; ++i) {
15
+ x = mx::log(mx::exp(x));
16
+ }
17
+ return mx::sum(x);
18
+ };
19
+
20
+ auto grad_fn = mx::grad(fn);
21
+ auto independent_value_and_grad = [&]() {
22
+ auto value = fn(x);
23
+ auto dfdx = grad_fn(x);
24
+ return std::vector<mx::array>{value, dfdx};
25
+ };
26
+ TIME(independent_value_and_grad);
27
+
28
+ auto value_and_grad_fn = mx::value_and_grad(fn);
29
+ auto combined_value_and_grad = [&]() {
30
+ auto [value, dfdx] = value_and_grad_fn(x);
31
+ return std::vector<mx::array>{value, dfdx};
32
+ };
33
+ TIME(combined_value_and_grad);
34
+ }
35
+
36
+ int main() {
37
+ std::cout << "Benchmarks for " << mx::default_device() << std::endl;
38
+ time_value_and_grad();
39
+ }
ml-stable-diffusion/mlx/benchmarks/cpp/compare_devices.cpp ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #include <iostream>
4
+ #include "mlx/mlx.h"
5
+ #include "time_utils.h"
6
+
7
+ namespace mx = mlx::core;
8
+
9
+ void time_add_op() {
10
+ std::vector<int> sizes(1, 1);
11
+ for (int i = 0; i < 9; ++i) {
12
+ sizes.push_back(10 * sizes.back());
13
+ }
14
+ set_default_device(mx::Device::cpu);
15
+ for (auto size : sizes) {
16
+ auto a = mx::random::uniform({size});
17
+ auto b = mx::random::uniform({size});
18
+ mx::eval(a, b);
19
+ std::cout << "Size " << size << std::endl;
20
+ TIMEM("cpu", mx::add, a, b, mx::Device::cpu);
21
+ TIMEM("gpu", mx::add, a, b, mx::Device::gpu);
22
+ }
23
+ }
24
+
25
+ int main() {
26
+ time_add_op();
27
+ }
ml-stable-diffusion/mlx/benchmarks/cpp/irregular_strides.cpp ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #include <cstring>
4
+ #include <iostream>
5
+ #include <sstream>
6
+
7
+ #include "mlx/mlx.h"
8
+ #include "time_utils.h"
9
+
10
+ namespace mx = mlx::core;
11
+
12
+ void time_irregular_binary_ops_1D() {
13
+ auto device = mx::default_device();
14
+ int size = 1000000;
15
+ int step = 2;
16
+ auto a = mx::random::uniform({size});
17
+ auto b = mx::random::uniform({size});
18
+ mx::eval(a, b);
19
+ a = slice(a, {0}, {size}, {step});
20
+ b = slice(b, {0}, {size}, {step});
21
+ TIMEM("1D strided", mx::add, a, b, device);
22
+ }
23
+
24
+ void time_irregular_binary_ops_2D() {
25
+ auto device = mx::default_device();
26
+ int size = 2048;
27
+ auto a = mx::random::uniform({size, size});
28
+ auto b = mx::random::uniform({size, size});
29
+ mx::eval(a, b);
30
+ TIMEM("2D regular", mx::add, a, b, device);
31
+
32
+ b = mx::transpose(b);
33
+ mx::eval(b);
34
+ TIMEM("2D mx::transpose", mx::add, a, b, device);
35
+
36
+ b = mx::random::uniform({size});
37
+ mx::eval(b);
38
+ TIMEM("2D broadcast dim 0", mx::add, a, b, device);
39
+
40
+ b = mx::reshape(b, {size, 1});
41
+ mx::eval(b);
42
+ TIMEM("2D broadcast dim 1", mx::add, a, b, device);
43
+ }
44
+
45
+ void time_irregular_binary_ops_3D() {
46
+ auto device = mx::default_device();
47
+ int d0 = 32;
48
+ int d1 = 512;
49
+ int d2 = 512;
50
+ auto a = mx::random::uniform({d0, d1, d2});
51
+ auto b = mx::random::uniform({d0, d1, d2});
52
+ TIMEM("3D regular", mx::add, a, b, device);
53
+
54
+ b = mx::transpose(b, {0, 2, 1});
55
+ TIMEM("3D mx::transpose", mx::add, a, b, device);
56
+
57
+ b = mx::random::uniform({d1, d2});
58
+ TIMEM("3D broadcast dim 0", mx::add, a, b, device);
59
+
60
+ b = mx::random::uniform({d0, 1, d2});
61
+ TIMEM("3D broadcast dim 1", mx::add, a, b, device);
62
+
63
+ b = mx::random::uniform({d0, d1, 1});
64
+ TIMEM("3D broadcast dim 2", mx::add, a, b, device);
65
+
66
+ b = mx::random::uniform({d2});
67
+ TIMEM("3D broadcast dims 0, 1", mx::add, a, b, device);
68
+
69
+ b = mx::random::uniform({d1, 1});
70
+ TIMEM("3D broadcast dims 0, 2", mx::add, a, b, device);
71
+
72
+ b = mx::random::uniform({d0, 1, 1});
73
+ TIMEM("3D broadcast dims 1, 2", mx::add, a, b, device);
74
+ }
75
+
76
+ void time_irregular_binary_ops_4D() {
77
+ auto device = mx::default_device();
78
+ std::vector<int> shape = {8, 8, 512, 512};
79
+ auto a = mx::random::uniform(shape);
80
+ auto b = mx::random::uniform(shape);
81
+
82
+ TIMEM("4D regular", mx::add, a, b, device);
83
+
84
+ b = mx::transpose(b, {0, 1, 3, 2});
85
+ TIMEM("4D mx::transpose", mx::add, a, b, device);
86
+
87
+ std::string om = "4D broadcast dims ";
88
+ for (int i = 0; i < shape.size(); ++i) {
89
+ shape[i] = 1;
90
+ b = mx::random::uniform(shape);
91
+ std::ostringstream msg;
92
+ msg << om << i;
93
+ TIMEM(msg.str(), mx::add, a, b, device);
94
+
95
+ for (int j = i + 1; j < shape.size(); ++j) {
96
+ shape[j] = 1;
97
+ std::ostringstream msg;
98
+ msg << om << i << ", " << j;
99
+ b = mx::random::uniform(shape);
100
+ TIMEM(msg.str(), mx::add, a, b, device);
101
+ shape[j] = a.shape(j);
102
+
103
+ for (int k = j + 1; k < shape.size(); ++k) {
104
+ shape[k] = 1;
105
+ std::ostringstream msg;
106
+ msg << om << i << ", " << j << ", " << k;
107
+ b = mx::random::uniform(shape);
108
+ TIMEM(msg.str(), mx::add, a, b, device);
109
+ shape[k] = a.shape(k);
110
+ }
111
+ }
112
+ shape[i] = a.shape(i);
113
+ }
114
+ }
115
+
116
+ void time_irregular_reshape() {
117
+ auto device = mx::default_device();
118
+ std::vector<int> shape;
119
+ auto reshape_fn = [&shape, device](const mx::array& a) {
120
+ return mx::reshape(a, shape, device);
121
+ };
122
+
123
+ int size = 64;
124
+ int d = 2 * size;
125
+
126
+ auto a = mx::random::uniform({d, d, d});
127
+
128
+ shape = {8 * size, size, size};
129
+ TIMEM("3D contiguous", reshape_fn, a);
130
+
131
+ a = mx::transpose(a);
132
+ shape = {8 * size, size, size};
133
+ TIMEM("3D mx::transpose", reshape_fn, a);
134
+
135
+ a = mx::transpose(a, {1, 2, 0});
136
+ shape = {8 * size, size, size};
137
+ TIMEM("3D mx::transpose dims 1 2", reshape_fn, a);
138
+
139
+ a = mx::broadcast_to(mx::random::uniform({d, d}), {d, d, d});
140
+ TIMEM("3D broadcast dim 0", reshape_fn, a);
141
+
142
+ a = mx::broadcast_to(mx::random::uniform({d, 1, d}), {d, d, d});
143
+ TIMEM("3D broadcast dim 1", reshape_fn, a);
144
+
145
+ a = mx::broadcast_to(mx::random::uniform({d, d, 1}), {d, d, d});
146
+ TIMEM("3D broadcast dim 2", reshape_fn, a);
147
+
148
+ a = mx::broadcast_to(mx::random::uniform({d}), {d, d, d});
149
+ TIMEM("3D broadcast dims 0, 1", reshape_fn, a);
150
+
151
+ a = mx::broadcast_to(mx::random::uniform({d, 1}), {d, d, d});
152
+ TIMEM("3D broadcast dims 0, 2", reshape_fn, a);
153
+
154
+ a = mx::broadcast_to(mx::random::uniform({d, 1, 1}), {d, d, d});
155
+ TIMEM("3D broadcast dims 1, 2", reshape_fn, a);
156
+
157
+ a = mx::broadcast_to(mx::random::uniform({1, 1, 1}), {d, d, d});
158
+ TIMEM("3D broadcast dims 1, 2, 3", reshape_fn, a);
159
+ }
160
+
161
+ void time_irregular_astype_1D() {
162
+ auto device = mx::default_device();
163
+ int size = 1000000;
164
+ int step = 2;
165
+ auto a = mx::random::uniform({size});
166
+ a = slice(a, {0}, {size}, {step});
167
+ TIMEM("1D strided", mx::astype, a, mx::int32, device);
168
+ }
169
+
170
+ void time_irregular_astype_2D() {
171
+ auto device = mx::default_device();
172
+ int size = 2048;
173
+ std::vector<int> shape = {size, size};
174
+
175
+ auto a = mx::random::uniform(shape);
176
+ TIMEM("2D regular", mx::astype, a, mx::int32, device);
177
+
178
+ a = mx::transpose(a);
179
+ TIMEM("2D mx::transpose", mx::astype, a, mx::int32, device);
180
+
181
+ a = mx::broadcast_to(mx::random::uniform({size}), shape);
182
+ TIMEM("2D broadcast dim 0", mx::astype, a, mx::int32, device);
183
+
184
+ a = mx::broadcast_to(mx::random::uniform({size, 1}), shape);
185
+ TIMEM("2D broadcast dim 1", mx::astype, a, mx::int32, device);
186
+ }
187
+
188
+ int main(int argc, char** argv) {
189
+ if (argc > 1) {
190
+ bool use_gpu = !strcmp(argv[1], "gpu");
191
+ set_default_device(use_gpu ? mx::Device::gpu : mx::Device::cpu);
192
+ }
193
+ std::cout << "Benchmarks for " << mx::default_device() << std::endl;
194
+ time_irregular_binary_ops_1D();
195
+ time_irregular_binary_ops_2D();
196
+ time_irregular_binary_ops_3D();
197
+ time_irregular_binary_ops_4D();
198
+ time_irregular_reshape();
199
+ time_irregular_astype_1D();
200
+ time_irregular_astype_2D();
201
+ }
ml-stable-diffusion/mlx/benchmarks/cpp/single_ops.cpp ADDED
@@ -0,0 +1,288 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #include "mlx/mlx.h"
4
+ #include "time_utils.h"
5
+
6
+ namespace mx = mlx::core;
7
+
8
+ void time_creation_ops() {
9
+ int M = 2000;
10
+ int N = 500;
11
+ auto shape = {M, N};
12
+ auto full_fp32 = [&]() { return mx::full(shape, 3.3f); };
13
+ TIME(full_fp32);
14
+ auto zeros_fp32 = [&]() { return mx::zeros(shape, mx::float32); };
15
+ TIME(zeros_fp32);
16
+ auto ones_fp32 = [&]() { return mx::ones(shape, mx::float32); };
17
+ TIME(ones_fp32);
18
+
19
+ auto arange_fp32 = [&]() { return mx::arange(0.0, 10.0, 1e-4); };
20
+ TIME(arange_fp32);
21
+ }
22
+
23
+ void time_type_conversions() {
24
+ int M = 2000;
25
+ int N = 500;
26
+ auto shape = {M, N};
27
+ auto device = mx::default_device();
28
+
29
+ auto a = mx::zeros(shape, mx::float32);
30
+ mx::eval(a);
31
+ TIMEM("mx::float32 to mx::int32", mx::astype, a, mx::int32, device);
32
+ TIMEM("mx::float32 to mx::uint32", mx::astype, a, mx::uint32, device);
33
+
34
+ a = mx::zeros(shape, mx::int32);
35
+ mx::eval(a);
36
+ TIMEM("mx::int32 to mx::float32", mx::astype, a, mx::float32, device);
37
+
38
+ a = mx::zeros(shape, mx::bool_);
39
+ mx::eval(a);
40
+ TIMEM("bool to mx::float32", mx::astype, a, mx::float32, device);
41
+ TIMEM("bool to mx::int32", mx::astype, a, mx::int32, device);
42
+ TIMEM("bool to mx::uint32", mx::astype, a, mx::uint32, device);
43
+ }
44
+
45
+ void time_random_generation() {
46
+ int M = 2000;
47
+ int N = 500;
48
+
49
+ auto uniform = [&]() { return mx::random::uniform({M, N}, mx::float32); };
50
+ TIME(uniform);
51
+ auto normal = [&]() { return mx::random::normal({M, N}, mx::float32); };
52
+ TIME(normal);
53
+ }
54
+
55
+ void time_unary_ops() {
56
+ int M = 2000;
57
+ int N = 500;
58
+ auto device = mx::default_device();
59
+
60
+ auto a = mx::random::normal({M, N});
61
+ mx::eval(a);
62
+ TIME(mlx::core::abs, a, device);
63
+ TIME(mx::negative, a, device);
64
+ TIME(mx::sign, a, device);
65
+ TIME(mx::square, a, device);
66
+ TIME(mlx::core::sqrt, a, device);
67
+ TIME(mx::rsqrt, a, device);
68
+ TIME(mlx::core::exp, a, device);
69
+
70
+ a = mx::random::uniform({M, N});
71
+ TIME(mlx::core::log, a, device);
72
+ }
73
+
74
+ void time_binary_ops() {
75
+ int M = 1000, N = 100, K = 10;
76
+ auto condition = mx::random::randint(0, 2, {M, N, K});
77
+ auto a = mx::random::uniform({M, N, K});
78
+ auto b = mx::random::uniform({M, N, K});
79
+ auto device = mx::default_device();
80
+ mx::eval(a, b);
81
+
82
+ TIME(mx::add, a, b, device);
83
+ TIME(mx::subtract, a, b, device);
84
+ TIME(mx::multiply, a, b, device);
85
+ TIME(mx::divide, a, b, device);
86
+ TIME(mx::maximum, a, b, device);
87
+ TIME(mx::minimum, a, b, device);
88
+ TIME(mx::where, condition, a, b, device);
89
+
90
+ condition = mx::array({true});
91
+ b = mx::random::uniform({1});
92
+ mx::eval(b);
93
+ TIMEM("scalar", mx::add, a, b, device);
94
+ TIMEM("vector-scalar", mx::subtract, a, b, device);
95
+ TIMEM("scalar-vector", mx::subtract, b, a, device);
96
+ TIMEM("scalar", mx::multiply, a, b, device);
97
+ TIMEM("vector-scalar", mx::divide, a, b, device);
98
+ TIMEM("scalar-vector", mx::divide, b, a, device);
99
+ TIMEM("scalar-vector", mx::where, condition, a, b, device);
100
+
101
+ condition = mx::broadcast_to(mx::array({true}), {1000, 100});
102
+ a = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
103
+ b = mx::broadcast_to(mx::random::uniform({1}), {1000, 100});
104
+ mx::eval(a, b);
105
+ TIMEM("scalar-scalar broadcast", mx::add, a, b, device);
106
+ TIMEM("scalar-scalar broadcast", mx::subtract, a, b, device);
107
+ TIMEM("scalar-scalar broadcast", mx::multiply, a, b, device);
108
+ TIMEM("scalar-scalar broadcast", mx::divide, a, b, device);
109
+ TIMEM("scalar-scalar broadcast", mx::where, condition, a, b, device);
110
+ }
111
+
112
+ void time_strided_ops() {
113
+ int M = 50, N = 50, O = 50, P = 50;
114
+ auto a = mx::random::uniform({M, N, O, P});
115
+ auto b = mx::random::uniform({M, N, O, P});
116
+ auto device = mx::default_device();
117
+ mx::eval(a, b);
118
+ TIMEM("non-strided", mx::add, a, b, device);
119
+ a = mx::transpose(a, {1, 0, 2, 3});
120
+ b = mx::transpose(b, {3, 2, 0, 1});
121
+ mx::eval(a, b);
122
+ TIMEM("strided", mx::add, a, b, device);
123
+ }
124
+
125
+ void time_comparisons() {
126
+ int M = 1000, N = 100, K = 10;
127
+ auto a = mx::random::uniform({M, N, K});
128
+ auto b = mx::random::uniform({M, N, K});
129
+ auto device = mx::default_device();
130
+ mx::eval(a, b);
131
+ TIME(mx::equal, a, b, device);
132
+ TIME(mx::greater, a, b, device);
133
+ TIME(mx::greater_equal, a, b, device);
134
+ TIME(mx::less, a, b, device);
135
+ TIME(mx::less_equal, a, b, device);
136
+ }
137
+
138
+ void time_matvec() {
139
+ int M = 2000, N = 200;
140
+ auto a = mx::random::uniform({M, N});
141
+ auto b = mx::random::uniform({N});
142
+ auto c = mx::random::uniform({M});
143
+ mx::eval(a, b, c);
144
+ auto matvec = [&]() { return mx::matmul(a, b); };
145
+ TIME(matvec);
146
+
147
+ auto matvec_transpose = [&]() { return mx::matmul(mx::transpose(a), c); };
148
+ TIME(matvec_transpose);
149
+ }
150
+
151
+ void time_matmul() {
152
+ int M = 1000, N = 1000, K = 1000;
153
+ auto a = mx::random::uniform({M, K});
154
+ auto b = mx::random::uniform({K, N});
155
+ auto device = mx::default_device();
156
+ mx::eval(a, b);
157
+ TIME(mx::matmul, a, b, device);
158
+
159
+ auto transpose_matmul = [&]() { return mx::matmul(mx::transpose(a), b); };
160
+ TIME(transpose_matmul);
161
+ }
162
+
163
+ void time_reductions() {
164
+ auto a = mx::random::normal({10000, 1000});
165
+ mx::eval(a);
166
+ auto sum_all = [&a]() { return mx::sum(a, false); };
167
+ TIME(sum_all);
168
+
169
+ auto sum_along_0 = [&a]() { return mx::sum(a, 0, false); };
170
+ TIME(sum_along_0);
171
+
172
+ auto sum_along_1 = [&a]() { return mx::sum(a, 1, false); };
173
+ TIME(sum_along_1);
174
+
175
+ auto prod_all = [&a]() { return mx::prod(a, false); };
176
+ TIME(prod_all);
177
+
178
+ auto all_true = [&a]() { return mx::all(a, false); };
179
+ TIME(all_true);
180
+
181
+ auto all_along_0 = [&a]() { return mx::all(a, 0, false); };
182
+ TIME(all_along_0);
183
+
184
+ auto all_along_1 = [&a]() { return mx::all(a, 1, false); };
185
+ TIME(all_along_1);
186
+
187
+ auto any_true = [&a]() { return mx::any(a, false); };
188
+ TIME(any_true);
189
+
190
+ auto argmin_along_0 = [&a]() { return mx::argmin(a, 0, false); };
191
+ TIME(argmin_along_0);
192
+
193
+ auto argmin_along_1 = [&a]() { return mx::argmin(a, 1, false); };
194
+ TIME(argmin_along_1);
195
+
196
+ auto indices = mx::array({1});
197
+ auto updates = mx::reshape(mx::array({NAN}), {1, 1, 1});
198
+ std::vector<int> axes{0};
199
+ auto b = scatter(a, {indices}, updates, axes);
200
+ mx::eval(b);
201
+
202
+ auto max_along_0 = [&b]() { return mx::max(b, 0, false); };
203
+ TIME(max_along_0);
204
+ auto max_along_1 = [&b]() { return mx::max(b, 1, false); };
205
+ TIME(max_along_1);
206
+
207
+ auto min_along_0 = [&b]() { return mx::min(b, 0, false); };
208
+ TIME(min_along_0);
209
+ auto min_along_1 = [&b]() { return mx::min(b, 1, false); };
210
+ TIME(min_along_1);
211
+ }
212
+
213
+ void time_gather_scatter() {
214
+ auto a = mx::random::normal({1000, 768});
215
+ mx::eval(a);
216
+ auto indices = mx::random::randint(0, 1000, {256});
217
+ mx::eval(indices);
218
+
219
+ auto embedding_lookup = [&a, &indices]() { return mx::take(a, indices, 0); };
220
+ TIME(embedding_lookup);
221
+
222
+ indices = mx::random::randint(0, 768 * 1000, {256 * 768});
223
+ mx::eval(indices);
224
+
225
+ auto single_element_lookup = [&a, &indices]() {
226
+ return mx::take(a, indices);
227
+ };
228
+ TIME(single_element_lookup);
229
+
230
+ indices = mx::random::randint(0, 1000, {256});
231
+ auto updates = mx::random::normal({256, 1, 768});
232
+ mx::eval(indices, updates);
233
+
234
+ auto embedding_update = [&a, &indices, &updates]() {
235
+ return scatter(a, indices, updates, 0);
236
+ };
237
+ TIME(embedding_update);
238
+
239
+ auto embedding_add = [&a, &indices, &updates]() {
240
+ return scatter_add(a, indices, updates, 0);
241
+ };
242
+ TIME(embedding_add);
243
+
244
+ a = mx::reshape(a, {-1});
245
+ indices = mx::random::randint(0, 768 * 1000, {768 * 256});
246
+ updates = mx::random::normal({256 * 768, 1});
247
+ mx::eval(a, indices, updates);
248
+
249
+ auto single_element_update = [&a, &indices, &updates]() {
250
+ return scatter(a, indices, updates, 0);
251
+ };
252
+ TIME(single_element_update);
253
+
254
+ auto single_element_add = [&a, &indices, &updates]() {
255
+ return scatter_add(a, indices, updates, 0);
256
+ };
257
+ TIME(single_element_add);
258
+ }
259
+
260
+ void time_divmod() {
261
+ auto a = mx::random::normal({1000});
262
+ auto b = mx::random::normal({1000});
263
+ mx::eval({a, b});
264
+
265
+ auto divmod_fused = [&a, &b]() { return mx::divmod(a, b); };
266
+ TIME(divmod_fused);
267
+
268
+ auto divmod_separate = [&a, &b]() {
269
+ return std::vector<mx::array>{mx::floor_divide(a, b), mx::remainder(a, b)};
270
+ };
271
+ TIME(divmod_separate);
272
+ }
273
+
274
+ int main() {
275
+ std::cout << "Benchmarks for " << mx::default_device() << std::endl;
276
+ time_creation_ops();
277
+ time_type_conversions();
278
+ time_unary_ops();
279
+ time_binary_ops();
280
+ time_strided_ops();
281
+ time_random_generation();
282
+ time_comparisons();
283
+ time_matvec();
284
+ time_matmul();
285
+ time_reductions();
286
+ time_gather_scatter();
287
+ time_divmod();
288
+ }
ml-stable-diffusion/mlx/benchmarks/cpp/time_utils.h ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ // Copyright © 2023 Apple Inc.
2
+
3
+ #pragma once
4
+
5
+ #include <chrono>
6
+ #include <iomanip>
7
+ #include <iostream>
8
+
9
+ #include "mlx/mlx.h"
10
+
11
// Converts a std::chrono duration to fractional milliseconds (nanosecond
// count divided by 1e6).
#define milliseconds(x) \
  (std::chrono::duration_cast<std::chrono::nanoseconds>(x).count() / 1e6)

// Current time point of std::chrono::high_resolution_clock.
#define time_now() std::chrono::high_resolution_clock::now()

// Prints "Timing <FUNC> ... <t> msec", where <t> is whatever
// time_fn(FUNC, args...) returns. The label is flushed before timing starts.
// NOTE: the expansion already ends in ';'.
#define TIME(FUNC, ...)                                                  \
  std::cout << "Timing " << #FUNC << " ... " << std::flush               \
            << std::setprecision(5) << time_fn(FUNC, ##__VA_ARGS__)      \
            << " msec" << std::endl;

// Same as TIME, with an extra message printed in parentheses before the
// function name.
#define TIMEM(MSG, FUNC, ...)                                            \
  std::cout << "Timing (" << MSG << ") " << #FUNC << " ... "             \
            << std::flush << std::setprecision(5)                        \
            << time_fn(FUNC, ##__VA_ARGS__) << " msec" << std::endl;
24
+
25
+ template <typename F, typename... Args>
26
+ double time_fn(F fn, Args&&... args) {
27
+ // warmup
28
+ for (int i = 0; i < 5; ++i) {
29
+ eval(fn(std::forward<Args>(args)...));
30
+ }
31
+
32
+ int num_iters = 100;
33
+ auto start = time_now();
34
+ for (int i = 0; i < num_iters; i++) {
35
+ eval(fn(std::forward<Args>(args)...));
36
+ }
37
+ auto end = time_now();
38
+ return milliseconds(end - start) / static_cast<double>(num_iters);
39
+ }
ml-stable-diffusion/mlx/benchmarks/numpy/single_ops.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import numpy as np
4
+ from time_utils import time_fn
5
+
6
+
7
def time_add():
    """Time ``np.add`` on two float32 arrays of ones shaped (100, 100, 10)."""
    shape = (100, 100, 10)
    lhs = np.ones(shape, dtype=np.float32)
    rhs = np.ones(shape, dtype=np.float32)
    time_fn(np.add, lhs, rhs)
11
+
12
+
13
def time_matmul():
    """Time ``np.matmul`` of a (1000, 500) by a (500, 1000) float32 matrix."""
    outer, inner = 1000, 500
    lhs = np.random.rand(outer, inner).astype(np.float32)
    rhs = np.random.rand(inner, outer).astype(np.float32)
    time_fn(np.matmul, lhs, rhs)
17
+
18
+
19
def time_exp():
    """Time ``np.exp`` on a (1000, 100) float32 array of normal samples."""
    values = np.random.randn(1000, 100).astype(np.float32)
    time_fn(np.exp, values)
22
+
23
+
24
def time_take():
    """Time 20 row gathers of 10 random indices each from a (10000, 500) table."""
    table = np.random.rand(10000, 500)
    index_grid = np.random.randint(0, 10000, (20, 10))
    # One flat index vector per row of the grid.
    index_rows = [chunk.reshape(-1) for chunk in np.split(index_grid, 20)]

    def random_take():
        return [np.take(table, row, 0) for row in index_rows]

    time_fn(random_take)
33
+
34
+
35
if __name__ == "__main__":
    # Run every NumPy benchmark in the order they are defined.
    for benchmark in (time_add, time_matmul, time_exp, time_take):
        benchmark()
ml-stable-diffusion/mlx/benchmarks/numpy/time_utils.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import time
4
+
5
+
6
def time_fn(fn, *args):
    """Print and return the average time of ``fn(*args)`` in milliseconds.

    Runs 5 untimed warm-up calls followed by 100 timed calls.

    Args:
        fn: Callable to benchmark. Callables without a ``__name__`` (for
            example ``functools.partial`` objects) are labelled by ``repr``.
        *args: Positional arguments passed to ``fn`` on every call.

    Returns:
        The mean duration of one timed call, in milliseconds.
    """
    label = getattr(fn, "__name__", None) or repr(fn)
    # Flush so the label is visible while a slow benchmark is still running.
    print(f"Timing {label} ...", end=" ", flush=True)

    # warmup
    for _ in range(5):
        fn(*args)

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        fn(*args)
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    print(f"{msec:.5f} msec")
    return msec
ml-stable-diffusion/mlx/benchmarks/python/batch_matmul_bench.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import argparse
4
+
5
+ import mlx.core as mx
6
+ from time_utils import time_fn
7
+
8
+ B = 8
9
+ T = 1024
10
+ D = 512
11
+
12
+
13
def time_batch_matmul():
    """Time a (B, T, D) @ (D, D) matmul and its two vector-Jacobian products."""
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B, T, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(B, T, D))
    mx.eval(lhs, rhs, cotangent)

    time_fn(mx.matmul, lhs, rhs)

    def batch_vjp_first():
        # Second element of the vjp result, entry for the first input.
        grads = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return grads[0]

    time_fn(batch_vjp_first)

    def batch_vjp_second():
        # Second element of the vjp result, entry for the second input.
        grads = mx.vjp(mx.matmul, [lhs, rhs], [cotangent])[1]
        return grads[1]

    time_fn(batch_vjp_second)
31
+
32
+
33
def time_unbatch_matmul():
    """Time a (B * T, D) @ (D, D) matmul and two explicit transposed products."""
    mx.random.seed(3)
    lhs = mx.random.uniform(shape=(B * T, D))
    rhs = mx.random.uniform(shape=(D, D))
    cotangent = mx.random.uniform(shape=(B * T, D))
    mx.eval(lhs, rhs, cotangent)
    time_fn(mx.matmul, lhs, rhs)

    def unbatch_vjp_first():
        rhs_t = mx.transpose(rhs)
        return mx.matmul(cotangent, rhs_t)

    time_fn(unbatch_vjp_first)

    def unbatch_vjp_second():
        lhs_t = mx.transpose(lhs)
        return mx.matmul(lhs_t, cotangent)

    time_fn(unbatch_vjp_second)
50
+
51
+
52
if __name__ == "__main__":
    parser = argparse.ArgumentParser("MLX benchmarks.")
    parser.add_argument("--gpu", action="store_true", help="Use the Metal back-end.")
    args = parser.parse_args()

    # Default to the CPU unless --gpu was given.
    mx.set_default_device(mx.gpu if args.gpu else mx.cpu)

    for benchmark in (time_batch_matmul, time_unbatch_matmul):
        benchmark()
ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemm.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import argparse
4
+ import math
5
+ import os
6
+ import subprocess
7
+ import time
8
+
9
+ import mlx.core as mx
10
+ import numpy as np
11
+ import torch
12
+
13
+ device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
14
+ device_name = device_name.decode("utf-8").strip("\n")
15
+
16
+ N_warmup = 8
17
+ N_iter_bench = 80
18
+ N_iter_func = 5
19
+
20
+
21
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls run first, followed by an MPS synchronize.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()
    return (stop_ns - start_ns) * 1e-9
31
+
32
+
33
def gemm_nn_mlx(a, b):
    """Build ``N_iter_func`` products ``a @ b``, evaluate them and return them."""
    products = [a @ b for _ in range(N_iter_func)]
    mx.eval(products)
    return products
40
+
41
+
42
def gemm_nt_mlx(a, b):
    """Build ``N_iter_func`` products of ``a`` with ``b`` transposed on its last two axes."""
    products = [a @ b.transpose((0, 2, 1)) for _ in range(N_iter_func)]
    mx.eval(products)
    return products
49
+
50
+
51
def gemm_tn_mlx(a, b):
    """Build ``N_iter_func`` products of ``a`` (last two axes swapped) with ``b``."""
    products = [a.transpose((0, 2, 1)) @ b for _ in range(N_iter_func)]
    mx.eval(products)
    return products
58
+
59
+
60
def gemm_tt_mlx(a, b):
    """Build ``N_iter_func`` products with both operands' last two axes swapped."""
    swap_last_two = (0, 2, 1)
    products = [
        a.transpose(swap_last_two) @ b.transpose(swap_last_two)
        for _ in range(N_iter_func)
    ]
    mx.eval(products)
    return products
67
+
68
+
69
@torch.no_grad()
def gemm_nn_torch(a, b):
    """Compute ``a @ b`` ``N_iter_func`` times, synchronize MPS, return the results."""
    products = [a @ b for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
77
+
78
+
79
@torch.no_grad()
def gemm_nt_torch(a, b):
    """Compute ``a @ b^T`` (last two axes of ``b`` swapped) ``N_iter_func`` times."""
    products = [a @ b.transpose(-1, -2) for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
87
+
88
+
89
@torch.no_grad()
def gemm_tn_torch(a, b):
    """Compute ``a^T @ b`` (last two axes of ``a`` swapped) ``N_iter_func`` times."""
    products = [a.transpose(-1, -2) @ b for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
97
+
98
+
99
@torch.no_grad()
def gemm_tt_torch(a, b):
    """Compute ``a^T @ b^T`` (last two axes of both swapped) ``N_iter_func`` times."""
    products = [
        a.transpose(-1, -2) @ b.transpose(-1, -2) for _ in range(N_iter_func)
    ]
    torch.mps.synchronize()
    return products
107
+
108
+
109
def bench_shape(B, M, N, K, np_dtype, transpose="nn"):
    """Time one batched GEMM configuration in MLX and in PyTorch (MPS).

    Args:
        B, M, N, K: Batch size and matrix dimensions of the product.
        np_dtype: NumPy dtype used to generate the operands.
        transpose: Two-character code; ``"n"`` or ``"t"`` for each operand.

    Returns:
        ``(time_mlx, time_torch)`` as returned by ``bench``.
    """
    a_plain = transpose[0] == "n"
    b_plain = transpose[1] == "n"
    shape_a = (B, M, K) if a_plain else (B, K, M)
    shape_b = (B, K, N) if b_plain else (B, N, K)

    a_np = np.random.normal(0.0, 1.0 / math.sqrt(M + K), shape_a).astype(np_dtype)
    b_np = np.random.normal(0.0, 1.0 / math.sqrt(N + K), shape_b).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    a_pt = torch.from_numpy(a_np).to("mps")
    b_pt = torch.from_numpy(b_np).to("mps")

    torch.mps.synchronize()

    mlx_kernels = {
        "nn": gemm_nn_mlx,
        "nt": gemm_nt_mlx,
        "tn": gemm_tn_mlx,
        "tt": gemm_tt_mlx,
    }
    f_mx = mlx_kernels[transpose]

    torch_kernels = {
        "nn": gemm_nn_torch,
        "nt": gemm_nt_torch,
        "tn": gemm_tn_torch,
        "tt": gemm_tt_torch,
    }
    f_pt = torch_kernels[transpose]

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Cross-check the MLX result against NumPy on the same operands.
    keep, swap = (0, 1, 2), (0, 2, 1)
    t_a = keep if a_plain else swap
    t_b = keep if b_plain else swap

    c_mlx = a_mx.transpose(t_a) @ b_mx.transpose(t_b)
    c_npy = a_np.transpose(t_a).astype(np_dtype) @ b_np.transpose(t_b).astype(np_dtype)

    atol = 1e-5 if np_dtype == np.float32 else 1e-4

    matches = np.allclose(c_mlx, c_npy.astype(np_dtype), atol=atol)
    if not matches:
        print(
            f"Failed at {(B, M, N, K)} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch
155
+
156
+
157
def get_gflop_count(B, M, N, K):
    """Return the total operation count of one ``bench`` run, scaled by 1024**3.

    Counts 2 * B * M * N * K per product, times ``N_iter_bench * N_iter_func``.
    """
    # NOTE(review): the divisor is 1024**3 rather than 1e9 - confirm intended.
    total_flops = 2.0 * N_iter_bench * N_iter_func * B * M * N * K
    return float(total_flops) / float(1024.0**3)
159
+
160
+
161
if __name__ == "__main__":
    # NOTE(review): the parser is constructed but never used to parse arguments.
    parser = argparse.ArgumentParser(description="Run gemm benchmarks")

    dtypes = ("float32", "float16", "complex64")
    transposes = ("nn", "nt", "tn")
    # Each entry is (B, M, N, K).
    shapes = (
        (16, 234, 768, 3072),
        (1, 64, 64, 25344),
        (16, 1024, 1024, 1024),
        (1, 1024, 1024, 2048),
        (4, 1024, 1024, 4096),
        (4, 1024, 4096, 1024),
        (1, 4096, 4096, 4096),
    )

    for dtype in dtypes:
        np_dtype = getattr(np, dtype)
        for transpose in transposes:
            for B, M, N, K in shapes:
                time_mlx, time_torch = bench_shape(B, M, N, K, np_dtype, transpose)

                gflop_count = get_gflop_count(B, M, N, K)
                gflops_mx = gflop_count / time_mlx
                gflops_pt = gflop_count / time_torch
                # Relative MLX throughput versus PyTorch.
                diff = gflops_mx / gflops_pt - 1.0

                print(
                    f"{B:3d}, {M:4d}, {N:4d}, {K:4d}, {dtype}, {transpose}, {gflops_pt:05.3f}, {gflops_mx:05.3f}, {100.0 * diff:+5.2f}%"
                )
                # Flag rows where PyTorch is at least twice as fast.
                if gflops_pt >= 2.0 * gflops_mx:
                    print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/blas/bench_gemv.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import argparse
4
+ import os
5
+ import subprocess
6
+ import time
7
+
8
+ import matplotlib.pyplot as plt
9
+ import mlx.core as mx
10
+ import numpy as np
11
+ import torch
12
+
13
# Directory the benchmark plots are written to.
results_dir = "./results"

# exist_ok=True removes the check-then-create race of isdir() + mkdir().
os.makedirs(results_dir, exist_ok=True)
17
+
18
+ device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
19
+ device_name = device_name.decode("utf-8").strip("\n")
20
+
21
+ N_warmup = 5
22
+ N_iter_bench = 50
23
+ N_iter_func = 20
24
+
25
+ out_vec_sizes = [128, 512, 2048, 4096]
26
+ in_vec_sizes = [128, 512, 2048, 4096]
27
+
28
+ benchmark_vector_lens = []
29
+ benchmark_vector_lens += [(i + 1) * 4096 for i in range(8)][::2]
30
+ benchmark_vector_lens += [(i + 1) * 4095 for i in range(8)][::2]
31
+ benchmark_vector_lens += [(i + 1) * 4097 for i in range(8)][::2]
32
+ benchmark_vector_lens += [64, 128, 512, 1024, 2048, 11008, 32000]
33
+
34
+ benchmark_vector_lens.sort()
35
+
36
+
37
def bench(f, m, v):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(m, v)``.

    ``N_warmup`` untimed calls run first, followed by an MPS synchronize.
    """
    for _ in range(N_warmup):
        f(m, v)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(m, v)
    stop_ns = time.perf_counter_ns()
    return (stop_ns - start_ns) * 1e-9
47
+
48
+
49
def gemv_mlx(m, v):
    """Build ``N_iter_func`` products ``m @ v``, evaluate them and return them."""
    products = [m @ v for _ in range(N_iter_func)]
    mx.eval(products)
    return products
56
+
57
+
58
def gemv_t_mlx(m, v):
    """Build ``N_iter_func`` products ``v @ m``, evaluate them and return them."""
    products = [v @ m for _ in range(N_iter_func)]
    mx.eval(products)
    return products
65
+
66
+
67
@torch.no_grad()
def gemv_torch(m, v):
    """Compute ``m @ v`` ``N_iter_func`` times, synchronize MPS, return the results."""
    products = [m @ v for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
75
+
76
+
77
@torch.no_grad()
def gemv_t_torch(m, v):
    """Compute ``v @ m`` ``N_iter_func`` times, synchronize MPS, return the results."""
    products = [v @ m for _ in range(N_iter_func)]
    torch.mps.synchronize()
    return products
85
+
86
+
87
def bench_lens(in_vec_len, out_vec_len, np_dtype, transpose=False):
    """Time one matrix-vector configuration in MLX and in PyTorch (MPS).

    With ``transpose`` the product is ``vec @ mat``; otherwise ``mat @ vec``.

    Returns:
        ``(time_mlx, time_torch)`` as returned by ``bench``.
    """
    if transpose:
        shape_mat = (in_vec_len, out_vec_len)
        shape_vec = (1, in_vec_len)
    else:
        shape_mat = (out_vec_len, in_vec_len)
        shape_vec = (in_vec_len, 1)

    mat_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_mat).astype(np_dtype)
    vec_npy = np.random.normal(0.0, 2.0 / in_vec_len, shape_vec).astype(np_dtype)
    mat_mlx = mx.array(mat_npy)
    vec_mlx = mx.array(vec_npy)
    mat_trc = torch.from_numpy(mat_npy).to("mps")
    vec_trc = torch.from_numpy(vec_npy).to("mps")

    torch.mps.synchronize()

    # PyTorch is timed first, then MLX.
    if transpose:
        time_torch = bench(gemv_t_torch, mat_trc, vec_trc)
        time_mlx = bench(gemv_t_mlx, mat_mlx, vec_mlx)
    else:
        time_torch = bench(gemv_torch, mat_trc, vec_trc)
        time_mlx = bench(gemv_mlx, mat_mlx, vec_mlx)

    # Cross-check the MLX result against NumPy on the same operands.
    if transpose:
        c_mlx = np.asarray(vec_mlx @ mat_mlx)
        c_npy = vec_npy @ mat_npy
    else:
        c_mlx = np.asarray(mat_mlx @ vec_mlx)
        c_npy = mat_npy @ vec_npy

    if not np.allclose(c_mlx, c_npy, atol=2e-5):
        print(
            f"Failed at {shape_mat} [transpose = {transpose}] with max(|a - b|) = {np.max(np.abs(c_npy - c_mlx))}"
        )

    return time_mlx, time_torch
122
+
123
+
124
def get_gflop_count(in_vec_len, out_vec_len):
    """Return the total operation count of one ``bench`` run, scaled by 1024**3.

    Counts 2 * in_vec_len * out_vec_len per product, times
    ``N_iter_bench * N_iter_func``.
    """
    total_flops = 2.0 * N_iter_bench * N_iter_func * in_vec_len * out_vec_len
    return float(total_flops) / float(1024**3)
128
+
129
+
130
def get_gbyte_size(in_vec_len, out_vec_len, np_dtype):
    """Return the bytes touched by one ``bench`` run, scaled by 1024**3.

    Counts the matrix, the input vector and the output vector once per
    product, times ``N_iter_bench * N_iter_func``.

    Args:
        in_vec_len: Length of the input vector.
        out_vec_len: Length of the output vector.
        np_dtype: NumPy dtype (or anything ``np.dtype`` accepts) of the data.
    """
    n_elem = in_vec_len * out_vec_len + in_vec_len + out_vec_len
    # Ask NumPy for the element width so every dtype is sized correctly
    # (float32 -> 4, float16 -> 2, complex64 -> 8).
    item_size = np.dtype(np_dtype).itemsize
    return float(N_iter_bench * N_iter_func * n_elem * item_size) / float(1024**3)
134
+
135
+
136
def bench_with_in_len(ax, in_vec_len, out_vector_lens, dtype, transpose):
    """Benchmark a fixed input length over several output lengths and plot GB/s.

    Args:
        ax: Axes-like object the MLX and Torch curves are drawn on.
        in_vec_len: Fixed input vector length.
        out_vector_lens: Output vector lengths to sweep.
        dtype: Name of the NumPy dtype, looked up with ``getattr(np, dtype)``.
        transpose: Passed through to ``bench_lens``.
    """
    np_dtype = getattr(np, dtype)
    mlx_gb_s, pyt_gb_s = [], []
    # GFLOP/s values are collected but not plotted.
    mlx_gflops, pyt_gflops = [], []

    for out_vec_len in out_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    title = (
        f"gemv_t ([1, {in_vec_len}] [{in_vec_len}, out_vec_len]) | {dtype}"
        if transpose
        else f"gemv ([out_vec_len, {in_vec_len}] X [{in_vec_len}, 1] ) | {dtype}"
    )

    ax.plot(out_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(out_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="out_vector_len", ylabel="Performance (GB/s)")
    ax.legend()
165
+
166
+
167
def bench_with_out_len(ax, out_vec_len, in_vector_lens, dtype, transpose):
    """Benchmark a fixed output length over several input lengths and plot GB/s.

    Args:
        ax: Axes-like object the MLX and Torch curves are drawn on.
        out_vec_len: Fixed output vector length.
        in_vector_lens: Input vector lengths to sweep.
        dtype: Name of the NumPy dtype, looked up with ``getattr(np, dtype)``.
        transpose: Passed through to ``bench_lens``.
    """
    np_dtype = getattr(np, dtype)
    mlx_gb_s, pyt_gb_s = [], []
    # GFLOP/s values are collected but not plotted.
    mlx_gflops, pyt_gflops = [], []

    for in_vec_len in in_vector_lens:
        gflop_count = get_gflop_count(in_vec_len, out_vec_len)
        gbyte_size = get_gbyte_size(in_vec_len, out_vec_len, np_dtype)

        time_mlx, time_torch = bench_lens(in_vec_len, out_vec_len, np_dtype, transpose)

        mlx_gb_s.append(gbyte_size / time_mlx)
        pyt_gb_s.append(gbyte_size / time_torch)

        mlx_gflops.append(gflop_count / time_mlx)
        pyt_gflops.append(gflop_count / time_torch)

    title = (
        f"([1, in_vec_len] [in_vec_len, {out_vec_len}])"
        if transpose
        else f"([{out_vec_len}, in_vec_len] X [in_vec_len, 1] )"
    )

    ax.plot(in_vector_lens, mlx_gb_s, "tab:blue", label="MLX")
    ax.plot(in_vector_lens, pyt_gb_s, "tab:red", label="Torch")
    ax.set_title(title)
    ax.set(xlabel="in_vector_len", ylabel="Performance (GB/s)")
    ax.legend()
196
+
197
+
198
# One figure per (transpose, dtype) pair: the left column sweeps output
# lengths at a fixed input length, the right column sweeps input lengths at
# a fixed output length.
for transpose in (False, True):
    op_name = "gemv_t" if transpose else "gemv"
    for dtype in ("float32", "float16", "complex64"):
        fig, axs = plt.subplots(
            len(in_vec_sizes), 2, figsize=(8.5, 11), layout="constrained"
        )

        for row, in_vec_len in enumerate(in_vec_sizes):
            left_ax = axs[row][0]
            bench_with_in_len(
                left_ax, in_vec_len, benchmark_vector_lens, dtype, transpose
            )

        for row, out_vec_len in enumerate(out_vec_sizes):
            right_ax = axs[row][1]
            bench_with_out_len(
                right_ax, out_vec_len, benchmark_vector_lens, dtype, transpose
            )

        fig.suptitle(f"{device_name}: {dtype} {op_name}")
        pdf_name = f"{device_name.replace(' ', '_')}_{dtype}_{op_name}.pdf"
        fig.savefig(os.path.join(results_dir, pdf_name))
        plt.close(fig)
ml-stable-diffusion/mlx/benchmarks/python/comparative/README.md ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Microbenchmarks comparing MLX to PyTorch
2
+ ========================================
3
+
4
+ Implement the same microbenchmarks in MLX and PyTorch to compare and make a
5
+ list of the biggest possible performance improvements and/or regressions.
6
+
7
+ Run with `python bench_mlx.py sum_axis --size 8x1024x128 --axis 2 --cpu` for
8
+ instance to measure the time it takes to sum across the 3rd axis of the above
9
+ tensor on the cpu.
10
+
11
+ `compare.py` runs several benchmarks and compares the speed-up or lack thereof
12
+ in comparison to PyTorch.
13
+
14
+ Each bench script can be run with `--print-pid` to print the PID and wait for a
15
+ key in order to ease attaching a debugger.
ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_mlx.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import argparse
4
+ import math
5
+ import os
6
+ import time
7
+ from functools import partial
8
+
9
+ import mlx.core as mx
10
+ import mlx.nn as nn
11
+
12
+
13
def int_or_list(x):
    """Parse ``x`` as an int, or as a comma-separated list of ints if that fails."""
    try:
        parsed = int(x)
    except ValueError:
        parsed = [int(part) for part in x.split(",")]
    return parsed
18
+
19
+
20
def none_or_list(x):
    """Return ``None`` for the empty string, else a list of comma-separated ints."""
    if x != "":
        return [int(part) for part in x.split(",")]
    return None
25
+
26
+
27
def dtype_from_str(x):
    """Map a dtype name to the matching ``mx.Dtype``.

    Args:
        x: Attribute name on ``mx`` (e.g. ``"float16"``); the empty string
            selects ``mx.float32``.

    Returns:
        The ``mx.Dtype`` named by ``x``.

    Raises:
        ValueError: If ``x`` does not name an ``mx.Dtype``. Unknown names are
            included, so argparse can report them as a usage error instead
            of an AttributeError traceback.
    """
    if x == "":
        return mx.float32

    dt = getattr(mx, x, None)
    if not isinstance(dt, mx.Dtype):
        raise ValueError(f"{x} is not an mlx dtype")
    return dt
35
+
36
+
37
def bench(f, *args):
    """Return the seconds taken by 100 calls of ``f(*args)`` after 10 warm-up calls.

    Uses ``time.perf_counter``, which is monotonic and high resolution, so a
    wall-clock adjustment during the run cannot distort the measurement.
    """
    for _ in range(10):
        f(*args)

    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    stop = time.perf_counter()
    return stop - start
46
+
47
+
48
def matmul_square(x):
    """Right-multiply by ``x`` ten times in a chain, evaluate and return the result."""
    acc = x
    for _ in range(10):
        acc = acc @ x
    mx.eval(acc)
    return acc
54
+
55
+
56
def matmul(x, y):
    """Build ten independent ``x @ y`` products and evaluate them together."""
    products = [x @ y for _ in range(10)]
    mx.eval(products)
61
+
62
+
63
def _quant_matmul(x, w, s, b, transpose, group_size, bits):
    """Build ten ``mx.quantized_matmul`` calls with the given settings and evaluate them."""
    outputs = [
        mx.quantized_matmul(
            x, w, s, b, transpose=transpose, group_size=group_size, bits=bits
        )
        for _ in range(10)
    ]
    mx.eval(outputs)
72
+
73
+
74
# Benchmark name -> preconfigured `_quant_matmul`. Keys follow
# "quant_matmul_<group_size>_<bits>", with a "t_" infix for the transposed
# variants; the non-transposed entries come first.
quant_matmul = {
    f"quant_matmul_{'t_' if transposed else ''}{group_size}_{bits}": partial(
        _quant_matmul, transpose=transposed, group_size=group_size, bits=bits
    )
    for transposed in (False, True)
    for group_size in (32, 64, 128)
    for bits in (2, 4, 8)
}
118
+
119
+
120
def conv1d(x, y):
    """Build ten ``mx.conv1d(x, y)`` calls and evaluate them together."""
    outputs = [mx.conv1d(x, y) for _ in range(10)]
    mx.eval(outputs)
125
+
126
+
127
def conv2d(x, y):
    """Build ten ``mx.conv2d(x, y)`` calls and evaluate them together."""
    outputs = [mx.conv2d(x, y) for _ in range(10)]
    mx.eval(outputs)
132
+
133
+
134
def binary(op, x, y):
    """Chain 100 applications of the binary op ``mx.<op>`` and evaluate the result.

    Each step computes ``op(x, previous)``, starting from ``y``.
    """
    binary_fn = getattr(mx, op)
    for _ in range(100):
        y = binary_fn(x, y)
    mx.eval(y)
138
+
139
+
140
def reduction(op, axis, x):
    """Build 100 independent ``mx.<op>(x, axis=axis)`` calls and evaluate them."""
    reduce_fn = getattr(mx, op)
    results = [reduce_fn(x, axis=axis) for _ in range(100)]
    mx.eval(results)
145
+
146
+
147
def sum_and_add(axis, x, y):
    """Reduce ``x`` over ``axis``, then chain 50 add-``y``-and-reduce steps."""
    total = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        total = (total + y).sum(axis=axis, keepdims=True)
    mx.eval(total)
152
+
153
+
154
def softmax(axis, x):
    """Build 100 unfused softmax computations over ``axis`` and evaluate them."""

    def _softmax_once():
        # Subtract the max along the axis before exponentiating.
        shifted_exp = mx.exp(x - mx.max(x, axis=axis, keepdims=True))
        return shifted_exp / mx.sum(shifted_exp, axis=axis, keepdims=True)

    outputs = [_softmax_once() for _ in range(100)]
    mx.eval(outputs)
161
+
162
+
163
def softmax_fused(axis, x):
    """Build 100 ``mx.softmax(x, axis=axis)`` calls and evaluate them together."""
    outputs = [mx.softmax(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
169
+
170
+
171
def relu(x):
    """Chain 100 applications of ``nn.relu`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.relu(out)
    mx.eval(out)
176
+
177
+
178
def leaky_relu(x: mx.array):
    """Chain 100 applications of ``nn.leaky_relu`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
183
+
184
+
185
def prelu(x: mx.array):
    """Chain 100 applications of ``nn.prelu`` with a fresh ``mx.ones(1)`` weight each step."""
    out = x
    for _ in range(100):
        weight = mx.ones(1)
        out = nn.prelu(out, weight)
    mx.eval(out)
190
+
191
+
192
def softplus(x: mx.array):
    """Chain 100 applications of ``nn.softplus`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
197
+
198
+
199
def mish(x: mx.array):
    """Chain 100 applications of ``nn.mish`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.mish(out)
    mx.eval(out)
204
+
205
+
206
# NOTE: this redefines `leaky_relu`, which is already defined earlier in this
# file with the same body; this later definition is the one in effect.
def leaky_relu(x):
    """Chain 100 applications of ``nn.leaky_relu`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.leaky_relu(out)
    mx.eval(out)
211
+
212
+
213
def elu(x):
    """Chain 100 applications of ``nn.elu`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.elu(out)
    mx.eval(out)
218
+
219
+
220
def relu6(x):
    """Chain 100 applications of ``nn.relu6`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.relu6(out)
    mx.eval(out)
225
+
226
+
227
# NOTE: this redefines `softplus`, which is already defined earlier in this
# file with the same body; this later definition is the one in effect.
def softplus(x):
    """Chain 100 applications of ``nn.softplus`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.softplus(out)
    mx.eval(out)
232
+
233
+
234
def celu(x):
    """Chain 100 applications of ``nn.celu`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.celu(out)
    mx.eval(out)
239
+
240
+
241
def log_sigmoid(x):
    """Chain 100 applications of ``nn.log_sigmoid`` starting from ``x`` and evaluate."""
    out = x
    for _ in range(100):
        out = nn.log_sigmoid(out)
    mx.eval(out)
246
+
247
+
248
def scalar_mult(x):
    """Chain 100 scalar multiplications by 1/1, 1/2, ..., 1/100 and evaluate."""
    out = x
    for step in range(100):
        scale = 1.0 / (1 + step)
        out = out * scale
    mx.eval(out)
253
+
254
+
255
def cross_entropy(targets, x):
    """Build 100 mean cross-entropy computations over the last axis and evaluate them.

    Each loss is the mean of ``logsumexp(x)`` minus the entry of ``x``
    selected by ``targets`` along the last axis.
    """
    losses = []
    for _ in range(100):
        log_norm = mx.logsumexp(x, axis=-1, keepdims=True)
        target_logits = mx.take_along_axis(x, mx.reshape(targets, (-1, 1)), axis=-1)
        losses.append(mx.mean(log_norm - target_logits))
    mx.eval(losses)
263
+
264
+
265
def logsumexp(axis, x):
    """Build 100 ``mx.logsumexp(x, axis=axis)`` calls and evaluate them together."""
    outputs = [mx.logsumexp(x, axis=axis) for _ in range(100)]
    mx.eval(outputs)
270
+
271
+
272
def linear(w, b, x):
    """Build ten unfused ``x @ w^T + b`` computations and evaluate them together."""
    outputs = [x @ mx.transpose(w, (1, 0)) + b for _ in range(10)]
    mx.eval(outputs)
277
+
278
+
279
def linear_fused(w, b, x):
    """Build ten ``mx.addmm(b, x, w^T)`` calls and evaluate them together."""
    outputs = [mx.addmm(b, x, mx.transpose(w, (1, 0))) for _ in range(10)]
    mx.eval(outputs)
284
+
285
+
286
def rope(x):
    """Build ten rotary-embedding style computations on ``x`` and evaluate them.

    ``N`` and ``D`` are the last two dimensions of ``x``. Each iteration
    rebuilds the positions, frequencies and sin/cos tables, so that work is
    part of what gets evaluated.
    """
    *_, N, D = x.shape
    outputs = []
    for _ in range(10):
        x = mx.reshape(x, (-1, N, D))
        positions = mx.arange(N)
        freqs = mx.exp(mx.arange(0.0, D // 2) / math.log(10000 / (D // 2 - 1)))
        theta = mx.reshape(positions, (-1, 1)) * mx.reshape(freqs, (1, -1))
        costheta = mx.cos(theta)
        sintheta = mx.sin(theta)

        # Rotate the even/odd feature pairs by theta.
        even = x[..., ::2]
        odd = x[..., 1::2]
        rotated_even = even * costheta - odd * sintheta
        rotated_odd = even * sintheta + odd * costheta

        paired = mx.concatenate(
            [rotated_even[..., None], rotated_odd[..., None]], axis=-1
        )
        outputs.append(mx.reshape(paired, (-1, N, D)))
    mx.eval(outputs)
305
+
306
+
307
def concatenate(axis, x, y):
    """Build ten ``mx.concatenate([x, y], axis=axis)`` calls and evaluate them."""
    outputs = [mx.concatenate([x, y], axis=axis) for _ in range(10)]
    mx.eval(outputs)
312
+
313
+
314
def cumsum(axis, x):
    """Build ten ``mx.cumsum(x, axis)`` calls and evaluate them together."""
    outputs = [mx.cumsum(x, axis) for _ in range(10)]
    mx.eval(outputs)
319
+
320
+
321
def sort(axis, x):
    """Build ten ``mx.sort(x, axis)`` calls and evaluate them together."""
    outputs = [mx.sort(x, axis) for _ in range(10)]
    mx.eval(outputs)
326
+
327
+
328
def topk(axis, x):
    """Build ten ``mx.topk`` calls with ``k`` set to a third of the axis length."""
    k = x.shape[axis] // 3
    outputs = [mx.topk(x, k, axis) for _ in range(10)]
    mx.eval(outputs)
334
+
335
+
336
def step_function(x):
    """Chain 100 applications of ``nn.step`` starting from ``x`` and evaluate.

    Each step feeds the previous output back in, like the other activation
    benchmarks in this file. Applying the op to ``x`` every time would build
    100 independent results of which only the last is ever evaluated.
    """
    y = x
    for _ in range(100):
        y = nn.step(y)
    mx.eval(y)
341
+
342
+
343
def selu(x):
    """Chain 100 applications of ``nn.selu`` starting from ``x`` and evaluate.

    Each step feeds the previous output back in, like the other activation
    benchmarks in this file. Applying the op to ``x`` every time would build
    100 independent results of which only the last is ever evaluated.
    """
    y = x
    for _ in range(100):
        y = nn.selu(y)
    mx.eval(y)
348
+
349
+
350
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()

    # With action="append" the default entry stays at index 0 when the user
    # supplies values, so drop it in that case.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    mx.set_default_device(mx.cpu if args.cpu else mx.gpu)

    # One dtype per size; missing entries repeat the first dtype.
    types = args.dtype
    if not types:
        types = [mx.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = [
        mx.random.normal(size).astype(dtype) for size, dtype in zip(args.size, types)
    ]
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = mx.transpose(xs[i], t)
    mx.eval(xs)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    name = args.benchmark

    # Benchmarks called as f(x) on the first input only.
    first_input_benchmarks = {
        "matmul_square": matmul_square,
        "relu": relu,
        "elu": elu,
        "relu6": relu6,
        "celu": celu,
        "log_sigmoid": log_sigmoid,
        "leaky_relu": leaky_relu,
        "prelu": prelu,
        "softplus": softplus,
        "mish": mish,
        "scalar_mul": scalar_mult,
        "rope": rope,
        "step": step_function,
        "selu": selu,
    }
    # Benchmarks called as f(axis, x) on the first input only.
    axis_first_input_benchmarks = {
        "logsumexp": logsumexp,
        "sort": sort,
        "topk": topk,
    }
    # Benchmarks called as f(*xs) on every input.
    all_input_benchmarks = {
        "matmul": matmul,
        "conv1d": conv1d,
        "conv2d": conv2d,
    }
    # Benchmarks called as f(axis, *xs) on every input.
    axis_all_input_benchmarks = {
        "concatenate": concatenate,
        "cumsum": cumsum,
        "sum_and_add": sum_and_add,
    }

    if name in first_input_benchmarks:
        print(bench(first_input_benchmarks[name], x))

    elif name in axis_first_input_benchmarks:
        print(bench(axis_first_input_benchmarks[name], axis, x))

    elif name in all_input_benchmarks:
        print(bench(all_input_benchmarks[name], *xs))

    elif name in axis_all_input_benchmarks:
        print(bench(axis_all_input_benchmarks[name], axis, *xs))

    elif name.startswith("quant_matmul"):
        print(bench(quant_matmul[name], *xs))

    elif name == "linear":
        print(bench(linear_fused if args.fused else linear, *xs))

    elif name == "sum_axis":
        print(bench(reduction, "sum", axis, x))

    elif name == "sum_all":
        print(bench(reduction, "sum", None, x))

    elif name == "argmax":
        print(bench(reduction, "argmax", axis, x))

    elif name == "add":
        print(bench(binary, "add", *xs))

    elif name == "mul":
        print(bench(binary, "multiply", *xs))

    elif name == "softmax":
        print(bench(softmax_fused if args.fused else softmax, axis, x))

    elif name == "cross_entropy":
        # Check the rank of the tensor actually benchmarked, not the last
        # --size value seen while building the inputs.
        if len(x.shape) != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = mx.zeros((len(x),), dtype=mx.uint32)
        print(bench(cross_entropy, targets, x))

    else:
        raise ValueError("Unknown benchmark")
ml-stable-diffusion/mlx/benchmarks/python/comparative/bench_torch.py ADDED
@@ -0,0 +1,482 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ import argparse
4
+ import os
5
+ import time
6
+
7
+ import torch
8
+ import torch.cuda
9
+ import torch.mps
10
+
11
+
12
def int_or_list(x):
    """Parse a command-line value as one int or a comma-separated list of ints.

    ``"2"`` becomes ``2``; anything ``int`` rejects (for example ``"0,1"``)
    is split on commas and each piece converted, giving ``[0, 1]``.
    """
    try:
        parsed = int(x)
    except ValueError:
        # Not a plain integer: treat it as "a,b,c".
        parsed = list(map(int, x.split(",")))
    return parsed
17
+
18
+
19
def none_or_list(x):
    """Parse a comma-separated list of ints; the empty string means ``None``.

    The empty case lets a caller pass ``--transpose=`` as a placeholder.
    """
    if x != "":
        return [int(piece) for piece in x.split(",")]
    return None
24
+
25
+
26
def dtype_from_str(x):
    """Resolve a dtype name such as ``"float16"`` to the ``torch.dtype`` object.

    Args:
        x: Attribute name of a dtype on the ``torch`` module. The empty
            string selects ``torch.float32``.

    Returns:
        The matching ``torch.dtype``.

    Raises:
        ValueError: If ``x`` does not name a ``torch.dtype``. This covers both
            names that do not exist on ``torch`` and names that exist but are
            something else (e.g. ``"nn"``).
    """
    if x == "":
        return torch.float32
    # Default to None so an unknown name takes the ValueError path below
    # instead of escaping as an AttributeError from getattr.
    dt = getattr(torch, x, None)
    if not isinstance(dt, torch.dtype):
        raise ValueError(f"{x} is not a torch dtype")
    return dt
34
+
35
+
36
def bench(f, *args):
    """Time 100 calls of ``f(*args)`` after 10 untimed warm-up calls.

    Args:
        f: Callable to benchmark.
        *args: Positional arguments forwarded to every call of ``f``.

    Returns:
        Elapsed seconds (float) for the 100 timed calls.
    """
    # Warm-up calls are not timed.
    for _ in range(10):
        f(*args)

    # perf_counter is monotonic and high resolution; time.time() can jump
    # when the system clock is adjusted and may be coarse.
    start = time.perf_counter()
    for _ in range(100):
        f(*args)
    end = time.perf_counter()
    return end - start
45
+
46
+
47
def sync_if_needed(x):
    """Block until queued work on ``x``'s accelerator backend has finished.

    Calls ``torch.mps.synchronize()`` or ``torch.cuda.synchronize()``
    according to the device type of ``x``; any other device type is a no-op.

    Args:
        x: Tensor whose ``device`` decides which backend to synchronize.
    """
    # Compare the device *type*: a tensor moved with ``.to("mps")`` or
    # ``.to("cuda")`` reports an indexed device (``mps:0`` / ``cuda:0``),
    # which is not equal to the index-less ``torch.device("mps")`` /
    # ``torch.device("cuda")``, so an equality test never matches.
    device_type = x.device.type
    if device_type == "mps":
        torch.mps.synchronize()
    elif device_type == "cuda":
        torch.cuda.synchronize()
52
+
53
+
54
@torch.no_grad()
def matmul_square(x):
    """Run a dependent chain of ten ``@ x`` products, then synchronize.

    Each product feeds the next (``acc = acc @ x``). Nothing is returned;
    the function exists to be timed.
    """
    acc = x
    for _ in range(10):
        acc = acc @ x
    sync_if_needed(x)
60
+
61
+
62
@torch.no_grad()
def matmul(x, y):
    """Compute ``x @ y`` ten independent times, then synchronize.

    The products are collected in a list and discarded when the function
    returns; nothing is returned.
    """
    products = [x @ y for _ in range(10)]
    sync_if_needed(x)
68
+
69
+
70
@torch.no_grad()
def conv1d(x, y):
    """Run ``torch.nn.functional.conv1d`` ten times, then synchronize.

    The last two axes of both the input ``x`` and the weight ``y`` are
    swapped before convolving (presumably converting a channels-last layout
    to torch's channels-first one -- confirm against the caller).
    """
    inp = torch.transpose(x, -1, -2)
    weight = torch.transpose(y, -1, -2)
    outputs = [torch.nn.functional.conv1d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
78
+
79
+
80
@torch.no_grad()
def conv2d(x, y):
    """Run ``torch.nn.functional.conv2d`` ten times, then synchronize.

    Input ``x`` and weight ``y`` are both permuted with ``(0, 3, 1, 2)``,
    i.e. the last axis is moved to position 1, before convolving.
    """
    order = (0, 3, 1, 2)
    inp = torch.permute(x, order)
    weight = torch.permute(y, order)
    outputs = [torch.nn.functional.conv2d(inp, weight) for _ in range(10)]
    sync_if_needed(inp)
88
+
89
+
90
@torch.no_grad()
def binary(op, x, y):
    """Apply the torch function named ``op`` 100 times as ``acc = op(x, acc)``.

    ``acc`` starts as ``y``; the backend is synchronized after the loop.
    """
    acc = y
    for _ in range(100):
        acc = getattr(torch, op)(x, acc)
    sync_if_needed(x)
95
+
96
+
97
@torch.no_grad()
def reduction(op, axis, x):
    """Call the tensor method named ``op`` with ``axis`` 100 times, then sync.

    Every call is independent (``x.<op>(axis)``); results are collected in a
    list and discarded on return.
    """
    results = [getattr(x, op)(axis) for _ in range(100)]
    sync_if_needed(x)
103
+
104
+
105
@torch.no_grad()
def sum_and_add(axis, x, y):
    """Alternate a broadcast add with a keep-dims sum 50 times, then sync.

    Starts from ``x.sum(axis, keepdims=True)`` and repeatedly replaces the
    running value with ``(running + y).sum(axis, keepdims=True)``.
    """
    running = x.sum(axis=axis, keepdims=True)
    for _ in range(50):
        running = (running + y).sum(axis=axis, keepdims=True)
    sync_if_needed(x)
111
+
112
+
113
@torch.no_grad()
def softmax(axis, x):
    """Compute an unfused softmax along ``axis`` 100 times, then synchronize.

    Each iteration subtracts the per-axis maximum, exponentiates, and divides
    by the per-axis sum of the exponentials.
    """
    outputs = []
    for _ in range(100):
        numer = torch.exp(x - torch.max(x, dim=axis, keepdims=True).values)
        denom = torch.sum(numer, dim=axis, keepdims=True)
        outputs.append(numer / denom)
    sync_if_needed(x)
121
+
122
+
123
@torch.no_grad()
def softmax_fused(axis, x):
    """Call ``torch.nn.functional.softmax`` along ``axis`` 100 times, then sync."""
    outputs = [torch.nn.functional.softmax(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
129
+
130
+
131
@torch.no_grad()
def relu(x):
    """Apply ReLU 100 times in a chain starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu(out)
    sync_if_needed(x)
137
+
138
+
139
@torch.no_grad()
def leaky_relu(x):
    """Apply leaky ReLU (default slope) 100 times in a chain, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.leaky_relu(out)
    sync_if_needed(x)
145
+
146
+
147
@torch.no_grad()
def elu(x):
    """Apply ELU (default alpha) 100 times in a chain, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.elu(out)
    sync_if_needed(x)
153
+
154
+
155
@torch.no_grad()
def celu(x):
    """Apply CELU (default alpha) 100 times in a chain, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.celu(out)
    sync_if_needed(x)
161
+
162
+
163
@torch.no_grad()
def relu6(x):
    """Apply ReLU6 100 times in a chain starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.relu6(out)
    sync_if_needed(x)
169
+
170
+
171
@torch.no_grad()
def softplus(x):
    """Apply softplus (default parameters) 100 times in a chain, then sync."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.softplus(out)
    sync_if_needed(x)
177
+
178
+
179
@torch.no_grad()
def log_sigmoid(x):
    """Apply ``logsigmoid`` 100 times in a chain starting from ``x``, then sync."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.logsigmoid(out)
    sync_if_needed(x)
185
+
186
+
187
@torch.no_grad()
def prelu(x: torch.Tensor) -> None:
    """Apply PReLU with a single weight of 1 100 times in a chain, then sync.

    Args:
        x: Input tensor; the first activation is applied to it and each
            result feeds the next call.

    Returns:
        None. The function exists to be timed.
    """
    # Build the weight once, on x's device and with x's dtype. A default
    # float32 weight does not match non-float32 inputs, and re-creating and
    # moving it every iteration adds allocation/transfer work to the timing.
    weight = torch.ones(1, device=x.device, dtype=x.dtype)
    y = x
    for _ in range(100):
        y = torch.nn.functional.prelu(y, weight)
    sync_if_needed(x)
193
+
194
+
195
@torch.no_grad()
def mish(x: torch.Tensor) -> torch.Tensor:
    """Apply Mish 100 times in a chain starting from ``x``, then synchronize.

    NOTE(review): despite the annotation, no value is returned.
    """
    out = x
    for _step in range(100):
        out = torch.nn.functional.mish(out)
    sync_if_needed(x)
201
+
202
+
203
@torch.no_grad()
def scalar_mult(x):
    """Multiply a running tensor by ``1/1, 1/2, ..., 1/100`` in turn, then sync."""
    out = x
    for divisor in range(1, 101):
        out = out * (1.0 / divisor)
    sync_if_needed(x)
209
+
210
+
211
@torch.no_grad()
def cross_entropy(targets, x):
    """Evaluate ``F.cross_entropy(x, targets)`` 100 independent times, then sync.

    ``x`` is passed as the input (logits) argument and ``targets`` as the
    target argument.
    """
    losses = [torch.nn.functional.cross_entropy(x, targets) for _ in range(100)]
    sync_if_needed(x)
217
+
218
+
219
@torch.no_grad()
def logsumexp(axis, x):
    """Evaluate ``torch.logsumexp`` along ``axis`` 100 independent times, then sync."""
    outputs = [torch.logsumexp(x, dim=axis) for _ in range(100)]
    sync_if_needed(x)
225
+
226
+
227
@torch.no_grad()
def linear_fused(w, b, x):
    """Evaluate ``F.linear(x, w, b)`` ten independent times, then synchronize."""
    outputs = [torch.nn.functional.linear(x, w, b) for _ in range(10)]
    sync_if_needed(x)
233
+
234
+
235
@torch.no_grad()
def linear(w, b, x):
    """Evaluate the unfused affine map ``x @ w^T + b`` ten times, then sync.

    ``w`` is transposed over its last two axes inside every iteration.
    """
    outputs = []
    for _ in range(10):
        projected = x @ torch.transpose(w, -2, -1)
        outputs.append(projected + b)
    sync_if_needed(x)
241
+
242
+
243
@torch.no_grad()
def rope(x):
    """Apply a hand-written rotary position embedding ten times, then sync.

    ``N`` and ``D`` are the last two sizes of ``x``. Every iteration views
    ``x`` as ``(-1, N, D)``, rebuilds the angle table from positions
    ``0..N-1`` and ``D // 2`` frequencies, rotates the even/odd feature
    pairs, and re-interleaves them into a ``(-1, N, D)`` tensor.
    """
    *_, N, D = x.shape
    rotated = []
    for _ in range(10):
        x = x.view(-1, N, D)
        dev = x.device

        # Angle table: one row per position, one column per frequency.
        pos = torch.arange(N, device=dev)
        freqs = 10000 ** torch.linspace(0, 1, D // 2, device=dev)
        angles = pos[:, None] * freqs[None]
        cos = torch.cos(angles)
        sin = torch.sin(angles)

        # Rotate each (even, odd) feature pair by its angle.
        even = x[..., ::2]
        odd = x[..., 1::2]
        rot_even = even * cos - odd * sin
        rot_odd = even * sin + odd * cos

        # Pair up along a new trailing axis, then flatten back to D features.
        paired = torch.cat([rot_even[..., None], rot_odd[..., None]], dim=-1)
        rotated.append(paired.reshape(-1, N, D))
    sync_if_needed(x)
262
+
263
+
264
@torch.no_grad()
def concatenate(axis, x, y):
    """Concatenate ``x`` and ``y`` along ``axis`` ten times, then synchronize."""
    joined = [torch.cat([x, y], dim=axis) for _ in range(10)]
    sync_if_needed(x)
270
+
271
+
272
@torch.no_grad()
def cumsum(axis, x):
    """Compute the cumulative sum of ``x`` along ``axis`` ten times, then sync."""
    scans = [x.cumsum(axis) for _ in range(10)]
    sync_if_needed(x)
278
+
279
+
280
@torch.no_grad()
def sort(axis, x):
    """Sort ``x`` along ``axis`` ten times, keeping only the values, then sync."""
    ordered = []
    for _ in range(10):
        values = torch.sort(x, dim=axis)[0]  # element 0 holds the sorted values
        ordered.append(values)
    sync_if_needed(x)
286
+
287
+
288
@torch.no_grad()
def topk(axis, x):
    """Take the top third of ``x`` along ``axis`` ten times, then synchronize.

    ``k`` is ``x.shape[axis] // 3``; only the values (element 0 of the
    result) are kept.
    """
    k = x.shape[axis] // 3
    tops = []
    for _ in range(10):
        values = torch.topk(x, k, dim=axis)[0]
        tops.append(values)
    sync_if_needed(x)
295
+
296
+
297
@torch.no_grad()
def step_function(x):
    """Apply a step (0 where negative, else 1) 100 times in a chain, then sync."""
    out = x
    for _ in range(100):
        negative = out < 0
        out = torch.where(negative, 0, 1)
    sync_if_needed(x)
303
+
304
+
305
@torch.no_grad()
def selu(x):
    """Apply SELU 100 times in a chain starting from ``x``, then synchronize."""
    out = x
    for _ in range(100):
        out = torch.nn.functional.selu(out)
    sync_if_needed(x)
311
+
312
+
313
if __name__ == "__main__":
    # Command-line driver: build the tensors described by the flags, run the
    # selected benchmark through ``bench`` and print the elapsed seconds.
    parser = argparse.ArgumentParser()
    parser.add_argument("benchmark", help="Choose the benchmark to run")
    parser.add_argument(
        "--size",
        default=[(1024, 1024)],
        type=lambda x: list(map(int, x.split("x"))),
        help="Set the matrix size",
        action="append",
    )
    parser.add_argument(
        "--axis",
        default=[1],
        type=int_or_list,
        help="Set a reduction axis",
        action="append",
    )
    parser.add_argument(
        "--transpose",
        type=none_or_list,
        default=[],
        help="Permute the matrix",
        action="append",
    )
    parser.add_argument(
        "--print-pid", action="store_true", help="Print the PID and pause"
    )
    parser.add_argument("--cpu", action="store_true", help="Use the CPU")
    parser.add_argument(
        "--fused", action="store_true", help="Use fused functions where possible"
    )
    parser.add_argument("--dtype", type=dtype_from_str, default=[], action="append")

    args = parser.parse_args()

    # ``action="append"`` appends to the default list, so once the user has
    # supplied a value the leading default entry has to be dropped.
    if len(args.size) > 1:
        args.size.pop(0)
    if len(args.axis) > 1:
        args.axis.pop(0)

    torch.set_num_threads(1)
    # Device preference: --cpu overrides everything, then CUDA, else MPS.
    device = "mps"
    if torch.cuda.is_available():
        device = "cuda"
    if args.cpu:
        device = "cpu"

    # One dtype per --size; missing entries repeat the first dtype.
    types = args.dtype
    if not types:
        types = [torch.float32]
    if len(types) < len(args.size):
        types = types + [types[0]] * (len(args.size) - len(types))

    xs = [
        torch.randn(*size).to(device).to(dtype)
        for size, dtype in zip(args.size, types)
    ]
    # The i-th --transpose applies to the i-th tensor; None means "leave as is".
    for i, t in enumerate(args.transpose):
        if t is None:
            continue
        xs[i] = xs[i].permute(*t)
    x = xs[0]
    axis = args.axis[0]

    if args.print_pid:
        print(os.getpid())
        input("Press enter to run")

    name = args.benchmark

    # Benchmarks grouped by call signature.
    unary_ops = {  # f(x)
        "matmul_square": matmul_square,
        "relu": relu,
        "leaky_relu": leaky_relu,
        "elu": elu,
        "relu6": relu6,
        "softplus": softplus,
        "celu": celu,
        "log_sigmoid": log_sigmoid,
        "prelu": prelu,
        "mish": mish,
        "scalar_mul": scalar_mult,
        "rope": rope,
        "step": step_function,
        "selu": selu,
    }
    axis_ops = {  # f(axis, x)
        "logsumexp": logsumexp,
        "sort": sort,
        "topk": topk,
    }
    axis_multi_ops = {  # f(axis, *xs)
        "concatenate": concatenate,
        "cumsum": cumsum,
        "sum_and_add": sum_and_add,
    }
    multi_ops = {  # f(*xs)
        "matmul": matmul,
        "conv1d": conv1d,
        "conv2d": conv2d,
    }

    if name in unary_ops:
        print(bench(unary_ops[name], x))

    elif name in axis_ops:
        print(bench(axis_ops[name], axis, x))

    elif name in axis_multi_ops:
        print(bench(axis_multi_ops[name], axis, *xs))

    elif name in multi_ops:
        print(bench(multi_ops[name], *xs))

    elif name == "linear":
        print(bench(linear_fused if args.fused else linear, *xs))

    elif name == "softmax":
        print(bench(softmax_fused if args.fused else softmax, axis, x))

    elif name == "sum_axis":
        print(bench(reduction, "sum", axis, x))

    elif name == "sum_all":
        print(bench(reduction, "sum", None, x))

    elif name == "argmax":
        print(bench(reduction, "argmax", axis, x))

    elif name in ("add", "mul"):
        # The benchmark name doubles as the torch function name.
        print(bench(binary, name, *xs))

    elif name == "cross_entropy":
        # Check the tensor that is actually benchmarked. The previous check
        # read ``size``, a variable leaked from the tensor-building loop that
        # holds the *last* --size rather than the shape of ``x``.
        if x.dim() != 2:
            raise ValueError("Error: [cross_entropy] benchmark requires a 2 dim size")

        targets = torch.zeros(len(x), dtype=torch.long).to(x.device)
        print(bench(cross_entropy, targets, x))

    else:
        raise ValueError(f"Unknown benchmark `{args.benchmark}`.")
ml-stable-diffusion/mlx/benchmarks/python/comparative/compare.py ADDED
@@ -0,0 +1,284 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023 Apple Inc.
2
+
3
+ #!/usr/bin/env python
4
+
5
+ import argparse
6
+ import re
7
+ from pathlib import Path
8
+ from subprocess import run
9
+
10
+ BENCH_MLX = Path(__file__).parent / "bench_mlx.py"
11
+ BENCH_TORCH = Path(__file__).parent / "bench_torch.py"
12
+
13
+
14
def run_or_raise(*args, **kwargs):
    """Run a subprocess and return its stdout parsed as a float.

    Args:
        *args: Positional arguments forwarded to ``subprocess.run``.
        **kwargs: Keyword arguments forwarded to ``subprocess.run``;
            ``capture_output=True`` is always added.

    Returns:
        ``float(stdout)`` of the finished process.

    Raises:
        ValueError: If stdout is not a number. The message carries the
            decoded stdout and stderr, and the parse error is chained as
            the cause.
    """
    # Launch outside the ``try``: if ``run`` itself raised ValueError, the
    # handler below would reference ``result`` before it was assigned.
    result = run(*args, capture_output=True, **kwargs)
    try:
        return float(result.stdout)
    except ValueError as err:
        raise ValueError(
            f"stdout: {result.stdout.decode()}\nstderr: {result.stderr.decode()}"
        ) from err
22
+
23
+
24
def compare(args):
    """Run one benchmark under MLX and PyTorch and print the relative gain.

    Prints ``(t_torch - t_mlx) / t_torch`` followed by a tab and the
    space-joined benchmark arguments.
    """
    timings = {}
    for label, script in (("mlx", BENCH_MLX), ("torch", BENCH_TORCH)):
        timings[label] = run_or_raise(["python", script] + args)

    relative_gain = (timings["torch"] - timings["mlx"]) / timings["torch"]
    print(relative_gain, " ".join(args), sep="\t")
29
+
30
+
31
def compare_mlx_dtypes(args, dt1, dt2):
    """Run one MLX benchmark with two dtypes and print the relative gain.

    Prints ``(t_dt2 - t_dt1) / t_dt2`` followed by a tab and the
    space-joined benchmark arguments.
    """
    base_cmd = ["python", BENCH_MLX] + args
    first = run_or_raise(base_cmd + ["--dtype", dt1])
    second = run_or_raise(base_cmd + ["--dtype", dt2])

    relative_gain = (second - first) / second
    print(relative_gain, " ".join(args), sep="\t")
36
+
37
+
38
def make_regex_search(regexes):
    """Compile ``regexes`` and return a function that tests a string against them.

    The returned ``search(x)`` yields, lazily and in order, one boolean per
    pattern: whether that pattern matches anywhere in ``x``.
    """
    patterns = [re.compile(regex) for regex in regexes]

    def search(x):
        return (pattern.search(x) is not None for pattern in patterns)

    return search
45
+
46
+
47
def make_predicate(positive_filter, negative_filter):
    """Build a predicate that selects benchmark strings by regex filters.

    A string passes when it matches *every* regex in ``positive_filter`` and
    *no* regex in ``negative_filter``. A filter given as ``None`` imposes no
    constraint.
    """
    if positive_filter is None:

        def is_selected(x):
            return True

    else:
        positive_search = make_regex_search(positive_filter)

        def is_selected(x):
            return all(positive_search(x))

    if negative_filter is None:

        def is_not_excluded(x):
            return True

    else:
        negative_search = make_regex_search(negative_filter)

        def is_not_excluded(x):
            return not any(negative_search(x))

    def predicate(x):
        return is_selected(x) and is_not_excluded(x)

    return predicate
64
+
65
+
66
if __name__ == "__main__":
    # Driver: run every benchmark configuration below that passes the
    # --filter / --negative_filter regexes, in the listed order. Arguments
    # argparse does not recognise are appended to every benchmark command.
    parser = argparse.ArgumentParser(description="Run comparisons against PyTorch")
    parser.add_argument(
        "--filter", "-f", help="Regex filter to select benchmarks", nargs="+"
    )
    parser.add_argument(
        "--negative_filter", "-n", help="Regex filter to remove benchmarks", nargs="+"
    )
    parser.add_argument(
        "--mlx_dtypes",
        "-d",
        help="Compare mlx benchmarks between the 2 provided data types",
        nargs=2,
    )
    args, rest = parser.parse_known_args()

    _filter = make_predicate(args.filter, args.negative_filter)

    def compare_filtered(x):
        # Skip configurations rejected by the filters.
        if not _filter(x):
            return None
        command = x.split() + rest
        if args.mlx_dtypes:
            return compare_mlx_dtypes(command, args.mlx_dtypes[0], args.mlx_dtypes[1])
        return compare(command)

    benchmark_commands = (
        # Binary ops
        "add --size 10x1024x128 --size 1x1024x128 --cpu",
        "add --size 10x1024x128 --size 1x1024x128",
        "add --size 1024x128 --size 1x128 --cpu",
        "add --size 1024x128 --size 1x128",
        "add --size 1024x4096 --size 1x4096 --cpu",
        "add --size 1024x4096 --size 1x4096",
        "add --size 1024x4096 --size 1x1024 --transpose 1,0 --cpu",
        "add --size 1024x4096 --size 1x1024 --transpose 1,0",
        "add --size 1024x1024 --size 1024x1024 --cpu",
        "add --size 1024x1024 --size 1024x1024",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --cpu",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0 --cpu",
        "add --size 1024x1024 --size 1024x1024 --transpose 1,0 --transpose 1,0",
        # Reduction ops
        "sum_all --size 10x1024x128 --cpu",
        "sum_all --size 10x1024x128",
        "sum_axis --size 16x1024x128 --axis 2 --cpu",
        "sum_axis --size 16x1024x128 --axis 2",
        "sum_axis --size 16x128x1024 --axis 2 --cpu",
        "sum_axis --size 16x128x1024 --axis 2",
        "sum_axis --size 1024x1024 --axis 1 --cpu",
        "sum_axis --size 1024x1024 --axis 1",
        "sum_axis --size 1024x1024 --axis 0 --cpu",
        "sum_axis --size 1024x1024 --axis 0",
        "sum_axis --size 16x128x1024 --axis 1 --cpu",
        "sum_axis --size 16x128x1024 --axis 1",
        "sum_axis --size 16x128x1024 --axis 0 --cpu",
        "sum_axis --size 16x128x1024 --axis 0",
        "sum_axis --size 16x128x1024 --axis 0,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,1",
        "sum_axis --size 16x128x1024 --axis 0,2 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,2",
        "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,1 --transpose 0,2,1",
        "sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1 --cpu",
        "sum_axis --size 16x128x1024 --axis 0,2 --transpose 0,2,1",
        "argmax --size 10x1024x128 --axis 1 --cpu",
        "argmax --size 10x1024x128 --axis 1",
        "argmax --size 10x1024x128 --axis 2 --cpu",
        "argmax --size 10x1024x128 --axis 2",
        "argmax --size 1024x1024 --axis 1 --cpu",
        "argmax --size 1024x1024 --axis 1",
        # Matmul ops
        "matmul_square --size 1024x1024",
        "matmul_square --size 1024x1024 --cpu",
        "matmul_square --size 16x1024x1024",
        "matmul_square --size 16x1024x1024 --cpu",
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1",
        "matmul --size 16x768x768 --size 16x768x768 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1",
        "matmul --size 16x768x128 --size 16x768x128 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 512x8192 --size 8192x512",
        "matmul --size 512x8192 --size 8192x512 --cpu",
        # "matmul --size 512x131072 --size 131072x512",
        # "matmul --size 512x131072 --size 131072x512 --cpu",
        "matmul --size 8192x512 --size 512x8192",
        "matmul --size 8192x512 --size 512x8192 --cpu",
        # "matmul --size 131072x512 --size 512x512",
        # "matmul --size 131072x512 --size 512x512 --cpu",
        "linear --size 1024x1024 --size 1024 --size 128x1024",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --cpu",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --fused",
        "linear --size 1024x1024 --size 1024 --size 128x1024 --fused --cpu",
        # Matvec ops
        "matmul --size 1x1x4096 --size 4096x4096 --cpu",
        "matmul --size 1x1x4096 --size 4096x4096",
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0 --cpu",
        "matmul --size 1x1x4096 --size 4096x4096 --transpose= --transpose 1,0",
        "matmul --size 32x1x1000 --size 32x1000x128 --cpu",
        "matmul --size 32x1x1000 --size 32x1000x128",
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1 --cpu",
        "matmul --size 32x1x1000 --size 32x128x1000 --transpose= --transpose 0,2,1",
        # Various ops
        "softmax --size 32x16x1024 --axis 2",
        "softmax --size 32x16x1024 --axis 2 --cpu",
        "softmax --size 32x16x1024 --axis 2 --fused",
        "softmax --size 32x16x1024 --axis 2 --fused --cpu",
        "softmax --size 2x1024x1024 --axis 1",
        "softmax --size 2x1024x1024 --axis 1 --cpu",
        "softmax --size 2x1024x1024 --axis 1 --fused",
        "softmax --size 2x1024x1024 --axis 1 --fused --cpu",
        "relu --size 32x16x1024",
        "relu --size 32x16x1024 --cpu",
        "leaky_relu --size 32x16x1024",
        "leaky_relu --size 32x16x1024 --cpu",
        "elu --size 32x16x1024",
        "elu --size 32x16x1024 --cpu",
        "relu6 --size 32x16x1024",
        "relu6 --size 32x16x1024 --cpu",
        "softplus --size 32x16x1024",
        "softplus --size 32x16x1024 --cpu",
        "celu --size 32x16x1024",
        "celu --size 32x16x1024 --cpu",
        "log_sigmoid --size 32x16x1024",
        "log_sigmoid --size 32x16x1024 --cpu",
        "step --size 32x16x1024",
        "step --size 32x16x1024 --cpu",
        "selu --size 32x16x1024",
        "selu --size 32x16x1024 --cpu",
        # "mish --size 32x16x1024"  NOTE: Torch does not implement Mish in MPS atm
        "mish --size 32x16x1024 --cpu",
        "prelu --size 32x16x1024",
        "prelu --size 32x16x1024 --cpu",
        "scalar_mul --size 32x16x1024",
        "scalar_mul --size 32x16x1024 --cpu",
        "cross_entropy --size 256x1024",
        "cross_entropy --size 256x1024 --cpu",
        "logsumexp --size 1024x1024 --axis 1",
        "logsumexp --size 1024x1024 --axis 1 --cpu",
        "logsumexp --size 1024x1024 --axis 0",
        "logsumexp --size 1024x1024 --axis 0 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 2",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 2 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 1 --cpu",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 0",
        "concatenate --size 32x1024x128 --size 32x1024x128 --axis 0 --cpu",
        "concatenate --size 32x1024x128 --size 32x16x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x16x128 --axis 1 --cpu",
        "concatenate --size 32x1024x128 --size 32x1x128 --axis 1",
        "concatenate --size 32x1024x128 --size 32x1x128 --axis 1 --cpu",
        "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2",
        "concatenate --size 1x32x1024x128 --size 1x32x1x128 --axis 2 --cpu",
        "conv1d --size 1x1000x80 --size 128x11x80",
        "conv1d --size 1x1000x80 --size 128x11x80 --cpu",
        "conv1d --size 16x1000x80 --size 128x11x80",
        "conv1d --size 4x1000x80 --size 128x11x80 --cpu",
        "conv2d --size 1x256x256x3 --size 8x3x3x3",
        "conv2d --size 1x256x256x3 --size 8x3x3x3 --cpu",
        "conv2d --size 16x256x256x3 --size 8x3x3x3",
        "conv2d --size 4x256x256x3 --size 8x3x3x3 --cpu",
        "cumsum --size 1024x1024 --axis 1 --cpu",
        "cumsum --size 1024x1024 --axis 0 --cpu",
        "cumsum --size 1024x1024 --axis 1",
        "cumsum --size 1024x1024 --axis 0",
        "cumsum --size 128x1024 --axis 1",
        "cumsum --size 128x1024 --axis 0",
        "cumsum --size 1024x4096 --axis 1",
        "cumsum --size 1024x4096 --axis 0",
        "cumsum --size 128x4096 --axis 1",
        "cumsum --size 128x4096 --axis 0",
        "cumsum --size 1024x7777 --axis 1",
        "cumsum --size 1024x7777 --axis 0",
        "cumsum --size 128x7777 --axis 1",
        "cumsum --size 128x7777 --axis 0",
        "cumsum --size 32768x128 --axis 1",
        "cumsum --size 32768x128 --axis 0",
        "sort --size 1024x1024 --axis 0",
        "sort --size 1024x1024 --axis 1",
        "sort --size 32768x128 --axis 0",
        "sort --size 32768x128 --axis 1",
        "sort --size 128x128 --axis 0 --cpu",
        "sort --size 128x128 --axis 1 --cpu",
        "topk --size 1024x1024 --axis 0",
        "topk --size 1024x1024 --axis 1",
        "topk --size 32768x128 --axis 0",
        "topk --size 32768x128 --axis 1",
        "topk --size 128x128 --axis 0 --cpu",
        "topk --size 128x128 --axis 1 --cpu",
    )

    for command_line in benchmark_commands:
        compare_filtered(command_line)
ml-stable-diffusion/mlx/benchmarks/python/compile_bench.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023-2024 Apple Inc.
2
+
3
+ import argparse
4
+ import math
5
+ import random
6
+
7
+ import mlx.core as mx
8
+ from time_utils import time_fn
9
+
10
+
11
def bench_gelu():
    """Time an erf-based GELU with and without ``mx.compile``.

    Two scenarios are timed through ``time_fn``:

    * fixed shape: ten chained GELU applications on a ``(1000, 1024)`` input;
    * variable shape: the first input is cut to a random number of leading
      rows on every call, and a second full-size input is processed
      alongside it. This is run uncompiled, compiled, and compiled with
      ``shapeless=True``.
    """

    def gelu(x):
        return x * (1 + mx.erf(x / math.sqrt(2))) / 2

    x = mx.random.uniform(shape=(1000, 1024))

    def gen_fixed(fun):
        # Wrap ``fun`` into a callable that applies it ten times in a chain.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    time_fn(gen_fixed(gelu), x, msg="fixed gelu")
    time_fn(gen_fixed(mx.compile(gelu)), x, msg="compiled fixed gelu")

    def randint():
        return random.randint(1, x.shape[0])

    def gen_variable(fun):
        # Same chain, but the first argument gets a random row count per call.
        def bench_fun(x, y):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
                y = fun(y)
            return x, y

        return bench_fun

    y = mx.random.uniform(shape=(1000, 1024))
    # Reseed before each variant so all three see the same sequence of slice
    # lengths (as bench_layernorm does); otherwise the timings are taken on
    # different workloads.
    random.seed(0)
    time_fn(gen_variable(gelu), x, y, msg="variable gelu")
    random.seed(0)
    time_fn(gen_variable(mx.compile(gelu)), x, y, msg="compiled variable gelu")
    random.seed(0)
    time_fn(
        gen_variable(mx.compile(gelu, shapeless=True)),
        x,
        y,
        msg="shapeless variable gelu",
    )
50
+
51
+
52
def bench_layernorm():
    """Time a hand-written layer norm with and without ``mx.compile``.

    The norm upcasts to float32, normalizes over the last axis, casts back
    to float16 and applies a float16 weight and bias. It is timed on a fixed
    ``(1000, 4096)`` input, then on inputs cut to a random number of leading
    rows (uncompiled, compiled, and compiled with ``shapeless=True``). The
    ``random`` module is reseeded with 0 before each variable-shape run.
    """
    weight = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    bias = mx.random.uniform(shape=(4096,)).astype(mx.float16)
    mx.eval(weight, bias)

    def layernorm(x):
        x = x.astype(mx.float32)
        means = mx.mean(x, axis=-1, keepdims=True)
        var = mx.var(x, axis=-1, keepdims=True)
        x = (x - means) * mx.rsqrt(var + 1e-4)
        x = x.astype(mx.float16)
        return weight * x + bias

    x = mx.random.uniform(shape=(1000, 4096)).astype(mx.float16)

    def chain_fixed(fun):
        # Ten chained applications on the input as given.
        def bench_fun(x):
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    for fun, label in (
        (layernorm, "fixed layernorm"),
        (mx.compile(layernorm), "compiled fixed layernorm"),
    ):
        time_fn(chain_fixed(fun), x, msg=label)

    def randint():
        return random.randint(1, x.shape[0])

    def chain_variable(fun):
        # Ten chained applications on a random-length prefix of the input.
        def bench_fun(x):
            x = x[: randint()]
            for _ in range(10):
                x = fun(x)
            return x

        return bench_fun

    random.seed(0)
    time_fn(chain_variable(layernorm), x, msg="variable layernorm")
    random.seed(0)
    time_fn(
        chain_variable(mx.compile(layernorm)), x, msg="compiled variable layernorm"
    )
    random.seed(0)
    time_fn(
        chain_variable(mx.compile(layernorm, shapeless=True)),
        x,
        msg="shapeless variable layernorm",
    )
100
+
101
+
102
if __name__ == "__main__":
    # No options are defined; the command line is still parsed before the
    # benchmarks run.
    parser = argparse.ArgumentParser("Compile benchmarks.")
    args = parser.parse_args()

    for run_benchmark in (bench_gelu, bench_layernorm):
        run_benchmark()
ml-stable-diffusion/mlx/benchmarks/python/conv1d_bench.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import subprocess
5
+ import time
6
+
7
+ import mlx.core as mx
8
+ import numpy as np
9
+ import torch
10
+
11
# CPU brand string reported by `sysctl`, with newlines stripped from both ends.
device_name = (
    subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
    .decode("utf-8")
    .strip("\n")
)

# Untimed warm-up calls, timed calls, and convolutions issued per call.
N_warmup, N_iter_bench, N_iter_func = 10, 100, 5
17
+
18
+
19
def bench(f, a, b):
    """Return the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls run first, followed by an MPS synchronize so
    warm-up work is finished before the timer starts.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    stop_ns = time.perf_counter_ns()

    # Nanoseconds -> seconds.
    return (stop_ns - start_ns) * 1e-9
29
+
30
+
31
def make_mx_conv_1D(strides=1, padding=0, groups=1):
    """Build an MLX benchmark kernel for the given conv1d settings.

    The returned function issues ``N_iter_func`` ``mx.conv1d`` calls on
    ``(a, b)``, evaluates them together, and returns the list of outputs.
    """

    def mx_conv_1D(a, b):
        outputs = [
            mx.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        mx.eval(outputs)
        return outputs

    return mx_conv_1D
41
+
42
+
43
def make_pt_conv_1D(strides=1, padding=0, groups=1):
    """Build a PyTorch benchmark kernel for the given conv1d settings.

    The returned function (run under ``torch.no_grad``) issues
    ``N_iter_func`` ``torch.conv1d`` calls on ``(a, b)``, synchronizes MPS,
    and returns the list of outputs.
    """

    @torch.no_grad()
    def pt_conv_1D(a, b):
        outputs = [
            torch.conv1d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        torch.mps.synchronize()
        return outputs

    return pt_conv_1D
54
+
55
+
56
def bench_shape(N, iH, C, wH, O, strides, padding, np_dtype, groups):
    """Benchmark one conv1d configuration in MLX and PyTorch (MPS).

    Random data of shape ``(N, iH, C)`` and weights of shape
    ``(O, wH, C / groups)`` are built in NumPy and handed to MLX as-is, and
    to PyTorch with the last two axes swapped. After timing, one MLX result
    is compared against a PyTorch CPU result and a message is printed if
    they differ beyond the tolerance.

    Returns:
        ``(time_mlx, time_torch)`` in seconds.
    """
    scale = 1.0 / math.sqrt(wH * C)
    input_shape = (N, iH, C)
    weight_shape = (O, wH, int(C / groups))
    a_np = np.random.uniform(0, 0.5, input_shape).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    # PyTorch gets the same data with the last two axes swapped.
    swap_last_two = (0, 2, 1)
    a_pt = torch.from_numpy(a_np.transpose(swap_last_two)).to("mps")
    b_pt = torch.from_numpy(b_np.transpose(swap_last_two)).to("mps")

    torch.mps.synchronize()

    conv_kwargs = dict(stride=strides, padding=padding, groups=groups)
    f_mx = make_mx_conv_1D(strides, padding, groups)
    f_pt = make_pt_conv_1D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Numerical cross-check: the PyTorch reference is computed on the CPU and
    # permuted back so both outputs share one layout.
    out_mx = mx.conv1d(a_mx, b_mx, **conv_kwargs)
    out_pt = torch.conv1d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_kwargs)
    out_pt = torch.permute(out_pt, swap_last_two).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, iH, C)}, {(O, wH, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
90
+
91
+
92
if __name__ == "__main__":
    # Driver: benchmark every (dtype, shape) pair and print one row per shape
    # with the relative time difference of PyTorch versus MLX.
    parser = argparse.ArgumentParser(description="Run conv benchmarks")

    dtypes = ("float32",)
    # Each row: (N, iH, C, wH, O, strides, padding, groups)
    shapes = (
        (4, 32, 32, 5, 32, 1, 2, 1),
        (4, 32, 32, 5, 32, 1, 2, 2),
        (4, 32, 32, 5, 32, 1, 2, 4),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 8),
        (4, 32, 32, 5, 32, 1, 2, 16),
        (4, 32, 32, 5, 32, 1, 2, 32),
        (4, 32, 256, 5, 512, 1, 2, 2),
        (4, 32, 256, 5, 512, 1, 2, 128),
        (4, 32, 256, 5, 512, 1, 2, 256),
    )

    for dtype in dtypes:
        print("(N, iH, C), (O, wH, C), dtype, stride, pads, groups, diff%")
        for config in shapes:
            N, iH, C, wH, O, strides, padding, groups = config
            np_dtype = getattr(np, dtype)
            time_mlx, time_torch = bench_shape(
                N, iH, C, wH, O, strides, padding, np_dtype, groups
            )
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {iH:3d}, {C:3d}), ({O:3d}, {wH:2d}, {C:3d}), {dtype}, {strides:5d}, {padding:4d}, {groups:6d}, {100. * diff:+5.2f}%"
            )

            # Flag rows where MLX takes at least twice as long as PyTorch.
            mlx_much_slower = time_mlx >= 2.0 * time_torch
            if mlx_much_slower:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv2d_bench_cpu.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import time
4
+
5
+ import mlx.core as mx
6
+ import numpy as np
7
+ import torch
8
+
9
+ N_warmup = 1
10
+ N_iter_bench = 10
11
+ N_iter_func = 5
12
+ mx.set_default_device(mx.cpu)
13
+
14
+
15
def bench(f, a, b):
    """Return the wall-clock seconds spent calling ``f(a, b)`` repeatedly.

    ``f`` is first called ``N_warmup`` times untimed, then ``N_iter_bench``
    times inside the timed region.
    """
    for _ in range(N_warmup):
        f(a, b)

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - start_ns
    # perf_counter_ns reports nanoseconds; convert to seconds.
    return elapsed_ns * 1e-9
24
+
25
+
26
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv2d`` ``N_iter_func`` times.

    The returned function takes the input and weight arrays, queues the
    convolutions, evaluates them all with ``mx.eval`` and returns the list
    of results.
    """

    def mx_conv_2D(a, b):
        outputs = [
            mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Force evaluation so the work happens inside the timed call.
        mx.eval(outputs)
        return outputs

    return mx_conv_2D
36
+
37
+
38
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``torch.conv2d`` ``N_iter_func`` times.

    The returned function runs under ``torch.no_grad()`` and returns the
    list of outputs.
    """

    @torch.no_grad()
    def pt_conv_2D(a, b):
        return [
            torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]

    return pt_conv_2D
48
+
49
+
50
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 2D convolution configuration.

    Random input of shape ``(N, H, W, C)`` and weights of shape
    ``(O, kH, kW, C / groups)`` are generated once with numpy and shared by
    both frameworks; the torch copies have their axes permuted with
    ``(0, 3, 1, 2)``. After timing, one result from each framework is
    compared and a message is printed if they differ by more than ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    # Weight range is scaled by the kernel fan-in, kH * kW * C, so that
    # non-square kernels are scaled by their real size.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("cpu")

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check on a single, untimed evaluation of each framework.
    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
84
+
85
+
86
if __name__ == "__main__":
    # Script entry point: benchmark every 2D convolution configuration below
    # and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        # Grouped configurations, currently disabled:
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        # (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv2d_train_bench_cpu.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn
5
+ import mlx.optimizers as opt
6
+ import torch
7
+
8
+
9
def bench_mlx(steps: int = 20, shape=(10, 256, 256, 3)) -> float:
    """Train a small 2D conv encoder-decoder with MLX on the CPU and time it.

    Args:
        steps: Number of optimisation steps to run and time.
        shape: Shape of the random input batch passed to
            ``mx.random.normal``. The last axis must be 3 because the
            network is built with 3 input channels.

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = mlx.nn.Sequential(
                mlx.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def __call__(self, input):
            return self.net(input)

    benchNet = BenchNetMLX(3)
    mx.eval(benchNet.parameters())
    optim = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    params = benchNet.parameters()
    optim.init(params)

    # Model and optimizer state that is evaluated after every step.
    state = [benchNet.state, optim.state]

    def loss_fn(params, image):
        # Mean absolute error between the network output and its input.
        benchNet.update(params)
        pred_image = benchNet(image)
        return (pred_image - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optim.update(benchNet, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        start_time = time.perf_counter()

        step(benchNet.parameters(), inputs)
        # Force evaluation inside the timed region.
        mx.eval(state)
        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
71
+
72
+
73
def bench_torch(steps: int = 20, shape=(10, 3, 256, 256)) -> float:
    """Train a small 2D conv encoder-decoder with PyTorch on the CPU and time it.

    Args:
        steps: Number of optimisation steps to run and time.
        shape: Shape of the random input batch passed to ``torch.randn``.
            The second axis must be 3 because the network is built with
            3 input channels.

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=32):
            super().__init__()

            self.net = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv2d(
                    hidden_channels, 2 * hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    2 * hidden_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose2d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            )

        def forward(self, input):
            return self.net(input)

    benchNet = BenchNetTorch(3).to(device)
    optim = torch.optim.Adam(benchNet.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        # Mean absolute error between the network output and its input.
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        start_time = time.perf_counter()

        optim.zero_grad()
        pred_image = benchNet(inputs)
        loss = loss_fn(pred_image, inputs)
        loss.backward()
        optim.step()

        end_time = time.perf_counter()

        print(f"{i:3d}, time={(end_time-start_time) * 1000:7.2f} ms")
        total_time += (end_time - start_time) * 1000

    return total_time
126
+
127
+
128
def main():
    """Run the MLX and PyTorch training benchmarks and print a summary."""
    steps = 20
    # MLX is benchmarked first, then PyTorch.
    totals_ms = (
        ("MLX", bench_mlx(steps)),
        ("PyTorch", bench_torch(steps)),
    )

    for label, total_ms in totals_ms:
        print(f"average time of {label}: {total_ms/steps:9.2f} ms")
        print(f"total time of {label}: {total_ms:9.2f} ms")

    (_, mlx_ms), (_, torch_ms) = totals_ms
    # Relative difference: positive when the MLX run took less time.
    ratio = torch_ms / mlx_ms - 1.0
    print(f"torch/mlx diff: {100. * ratio:+5.2f}%")


if __name__ == "__main__":
    main()
ml-stable-diffusion/mlx/benchmarks/python/conv2d_transpose_bench_cpu.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import time
4
+
5
+ import mlx.core as mx
6
+ import numpy as np
7
+ import torch
8
+
9
+ N_warmup = 1
10
+ N_iter_bench = 10
11
+ N_iter_func = 5
12
+
13
+
14
def bench(f, a, b):
    """Time ``N_iter_bench`` calls of ``f(a, b)`` after ``N_warmup`` warm-ups.

    Returns the elapsed wall-clock time in seconds.
    """
    for _ in range(N_warmup):
        f(a, b)

    t_start = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t_stop = time.perf_counter_ns()
    # Nanoseconds to seconds.
    return (t_stop - t_start) * 1e-9
23
+
24
+
25
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable running ``mx.conv_transpose2d`` on the CPU stream.

    The returned function performs ``N_iter_func`` transposed convolutions,
    evaluates them with ``mx.eval`` and returns the list of results.
    """

    def mx_conv_transpose_2D(a, b):
        results = [
            mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups, stream=mx.cpu
            )
            for _ in range(N_iter_func)
        ]
        # Force evaluation so the work happens inside the timed call.
        mx.eval(results)
        return results

    return mx_conv_transpose_2D
37
+
38
+
39
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable running ``torch.conv_transpose2d`` ``N_iter_func`` times.

    The returned function runs under ``torch.no_grad()`` and returns the
    list of outputs.
    """

    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        return [
            torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]

    return pt_conv_transpose_2D
51
+
52
+
53
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 2D transposed convolution.

    Random input of shape ``(N, H, W, C)`` and weights of shape
    ``(O / groups, kH, kW, C)`` are generated once with numpy and shared by
    both frameworks; the torch input is permuted with ``(0, 3, 1, 2)`` and
    the torch weight with ``(3, 0, 1, 2)``. After timing, one result from
    each framework is compared and a message is printed if they differ by
    more than ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    # Weight range is scaled by the kernel fan-in, kH * kW * C, so that
    # non-square kernels are scaled by their real size.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (int(O / groups), kH, kW, C)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("cpu")

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check on a single, untimed evaluation of each framework.
    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups, stream=mx.cpu
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
89
+
90
+
91
if __name__ == "__main__":
    # Script entry point: benchmark every 2D transposed-convolution
    # configuration below and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv3d_bench_cpu.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import time
4
+
5
+ import mlx.core as mx
6
+ import numpy as np
7
+ import torch
8
+
9
+ N_warmup = 1
10
+ N_iter_bench = 10
11
+ N_iter_func = 5
12
+ mx.set_default_device(mx.cpu)
13
+
14
+
15
def bench(f, a, b):
    """Measure the seconds taken by ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made first.
    """
    for _ in range(N_warmup):
        f(a, b)

    began = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - began
    # Nanoseconds to seconds.
    return elapsed_ns * 1e-9
24
+
25
+
26
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``mx.conv3d`` ``N_iter_func`` times.

    The defaults carry one entry per spatial axis (depth, height, width),
    matching the three-element tuples this script passes in.

    The returned function takes the input and weight arrays, queues the
    convolutions, evaluates them all with ``mx.eval`` and returns the list
    of results.
    """

    def mx_conv_3D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = mx.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        # Force evaluation so the work happens inside the timed call.
        mx.eval(ys)
        return ys

    return mx_conv_3D
36
+
37
+
38
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``torch.conv3d`` ``N_iter_func`` times.

    The defaults carry one entry per spatial axis (depth, height, width),
    matching the three-element tuples this script passes in.

    The returned function runs under ``torch.no_grad()`` and returns the
    list of outputs.
    """

    @torch.no_grad()
    def pt_conv_3D(a, b):
        ys = []
        for _ in range(N_iter_func):
            y = torch.conv3d(a, b, stride=strides, padding=padding, groups=groups)
            ys.append(y)
        return ys

    return pt_conv_3D
48
+
49
+
50
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 3D convolution configuration.

    Random input of shape ``(N, D, H, W, C)`` and weights of shape
    ``(O, kD, kH, kW, C / groups)`` are generated once with numpy and
    shared by both frameworks; the torch copies have their axes permuted
    with ``(0, 4, 1, 2, 3)``. After timing, one result from each framework
    is compared and a message is printed if they differ by more than
    ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    fan_in = kD * kH * kW * C
    scale = 1.0 / math.sqrt(fan_in)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    weight_shape = (O, kD, kH, kW, int(C / groups))
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    torch_axes = (0, 4, 1, 2, 3)
    a_pt = torch.from_numpy(a_np.transpose(torch_axes)).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose(torch_axes)).to("cpu")

    # PyTorch is timed first, then MLX.
    time_torch = bench(make_pt_conv_3D(strides, padding, groups), a_pt, b_pt)
    time_mlx = bench(make_mx_conv_3D(strides, padding, groups), a_mx, b_mx)

    # Correctness check on a single, untimed evaluation of each framework.
    conv_kwargs = dict(stride=strides, padding=padding, groups=groups)
    out_mx = mx.conv3d(a_mx, b_mx, **conv_kwargs)
    out_pt = torch.conv3d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_kwargs)
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
84
+
85
+
86
if __name__ == "__main__":
    # Script entry point: benchmark every 3D convolution configuration below
    # and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is
    # (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv3d_train_bench_cpu.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import mlx.core as mx
4
+ import mlx.nn
5
+ import mlx.optimizers as opt
6
+ import torch
7
+
8
+
9
def bench_mlx(steps: int = 20, shape=(10, 32, 32, 32, 3)) -> float:
    """Train a small 3D conv encoder-decoder with MLX on the CPU and time it.

    Args:
        steps: Number of optimisation steps to run and time.
        shape: Shape of the random input batch passed to
            ``mx.random.normal``.

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    mx.set_default_device(mx.cpu)

    class BenchNetMLX(mlx.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            wide_channels = 2 * hidden_channels
            layers = [
                mlx.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.Conv3d(hidden_channels, wide_channels, kernel_size=3, padding=1),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    wide_channels, hidden_channels, kernel_size=3, padding=1
                ),
                mlx.nn.ReLU(),
                mlx.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = mlx.nn.Sequential(*layers)

        def __call__(self, input):
            return self.net(input)

    model = BenchNetMLX(3)
    mx.eval(model.parameters())
    optimizer = opt.Adam(learning_rate=1e-3)

    inputs = mx.random.normal(shape)

    optimizer.init(model.parameters())

    # Model and optimizer state that is evaluated after every step.
    state = [model.state, optimizer.state]

    def loss_fn(params, image):
        # Mean absolute error between the network output and its input.
        model.update(params)
        return (model(image) - image).abs().mean()

    def step(params, image):
        loss, grads = mx.value_and_grad(loss_fn)(params, image)
        optimizer.update(model, grads)
        return loss

    total_time = 0.0
    print("MLX:")
    for i in range(steps):
        tic = time.perf_counter()

        step(model.parameters(), inputs)
        # Force evaluation inside the timed region.
        mx.eval(state)
        toc = time.perf_counter()

        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
71
+
72
+
73
def bench_torch(steps: int = 20, shape=(10, 3, 32, 32, 32)) -> float:
    """Train a small 3D conv encoder-decoder with PyTorch on the CPU and time it.

    Args:
        steps: Number of optimisation steps to run and time.
        shape: Shape of the random input batch passed to ``torch.randn``.

    Returns:
        Sum of the per-step wall-clock times, in milliseconds.
    """
    device = torch.device("cpu")

    class BenchNetTorch(torch.nn.Module):
        # simple encoder-decoder net

        def __init__(self, in_channels, hidden_channels=16):
            super().__init__()

            wide_channels = 2 * hidden_channels
            layers = [
                torch.nn.Conv3d(in_channels, hidden_channels, kernel_size=3, padding=1),
                torch.nn.ReLU(),
                torch.nn.Conv3d(
                    hidden_channels, wide_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    wide_channels, hidden_channels, kernel_size=3, padding=1
                ),
                torch.nn.ReLU(),
                torch.nn.ConvTranspose3d(
                    hidden_channels, in_channels, kernel_size=3, padding=1
                ),
            ]
            self.net = torch.nn.Sequential(*layers)

        def forward(self, input):
            return self.net(input)

    model = BenchNetTorch(3).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    inputs = torch.randn(*shape, device=device)

    def loss_fn(pred_image, image):
        # Mean absolute error between the network output and its input.
        return (pred_image - image).abs().mean()

    total_time = 0.0
    print("PyTorch:")
    for i in range(steps):
        tic = time.perf_counter()

        optimizer.zero_grad()
        loss = loss_fn(model(inputs), inputs)
        loss.backward()
        optimizer.step()

        toc = time.perf_counter()

        elapsed_ms = (toc - tic) * 1000
        print(f"{i:3d}, time={elapsed_ms:7.2f} ms")
        total_time += elapsed_ms

    return total_time
126
+
127
+
128
def main():
    """Run the MLX and PyTorch 3D training benchmarks and print a summary."""
    steps = 10
    # MLX is benchmarked first, then PyTorch.
    totals_ms = (
        ("MLX", bench_mlx(steps)),
        ("PyTorch", bench_torch(steps)),
    )

    for label, total_ms in totals_ms:
        print(f"average time of {label}: {total_ms/steps:9.2f} ms")
        print(f"total time of {label}: {total_ms:9.2f} ms")

    (_, mlx_ms), (_, torch_ms) = totals_ms
    # Relative difference: positive when the MLX run took less time.
    ratio = torch_ms / mlx_ms - 1.0
    print(f"torch/mlx diff: {100. * ratio:+5.2f}%")


if __name__ == "__main__":
    main()
ml-stable-diffusion/mlx/benchmarks/python/conv3d_transpose_bench_cpu.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import time
4
+
5
+ import mlx.core as mx
6
+ import numpy as np
7
+ import torch
8
+
9
+ N_warmup = 1
10
+ N_iter_bench = 10
11
+ N_iter_func = 5
12
+ mx.set_default_device(mx.cpu)
13
+
14
+
15
def bench(f, a, b):
    """Return the seconds spent in ``N_iter_bench`` calls of ``f(a, b)``.

    The timed loop is preceded by ``N_warmup`` untimed calls.
    """
    for _ in range(N_warmup):
        f(a, b)

    t0 = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t1 = time.perf_counter_ns()
    # Nanoseconds to seconds.
    return (t1 - t0) * 1e-9
24
+
25
+
26
def make_mx_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``mx.conv_transpose3d`` ``N_iter_func`` times.

    The returned function evaluates all results with ``mx.eval`` and
    returns them as a list.
    """

    def mx_conv_3D(a, b):
        results = [
            mx.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Force evaluation so the work happens inside the timed call.
        mx.eval(results)
        return results

    return mx_conv_3D
38
+
39
+
40
def make_pt_conv_3D(strides=(1, 1, 1), padding=(0, 0, 0), groups=1):
    """Build a callable that runs ``torch.conv_transpose3d`` ``N_iter_func`` times.

    The returned function runs under ``torch.no_grad()`` and returns the
    list of outputs.
    """

    @torch.no_grad()
    def pt_conv_3D(a, b):
        return [
            torch.conv_transpose3d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]

    return pt_conv_3D
52
+
53
+
54
def bench_shape(N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (CPU) for one 3D transposed convolution.

    Random input of shape ``(N, D, H, W, C)`` and weights of shape
    ``(O, kD, kH, kW, C / groups)`` are generated once with numpy and
    shared by both frameworks; the torch input is permuted with
    ``(0, 4, 1, 2, 3)`` and the torch weight with ``(4, 0, 1, 2, 3)``.
    After timing, one result from each framework is compared and a message
    is printed if they differ by more than ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    fan_in = kD * kH * kW * C
    scale = 1.0 / math.sqrt(fan_in)
    a_np = np.random.uniform(0, 0.5, (N, D, H, W, C)).astype(np_dtype)
    weight_shape = (O, kD, kH, kW, int(C / groups))
    b_np = np.random.uniform(-scale, scale, weight_shape).astype(np_dtype)

    a_mx, b_mx = mx.array(a_np), mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 4, 1, 2, 3))).to("cpu")
    b_pt = torch.from_numpy(b_np.transpose((4, 0, 1, 2, 3))).to("cpu")

    # PyTorch is timed first, then MLX.
    time_torch = bench(make_pt_conv_3D(strides, padding, groups), a_pt, b_pt)
    time_mlx = bench(make_mx_conv_3D(strides, padding, groups), a_mx, b_mx)

    # Correctness check on a single, untimed evaluation of each framework.
    conv_kwargs = dict(stride=strides, padding=padding, groups=groups)
    out_mx = mx.conv_transpose3d(a_mx, b_mx, **conv_kwargs)
    out_pt = torch.conv_transpose3d(a_pt.to("cpu"), b_pt.to("cpu"), **conv_kwargs)
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 4, 1)).numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, D, H, W, C)}, {(O, kD, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
90
+
91
+
92
if __name__ == "__main__":
    # Script entry point: benchmark every 3D transposed-convolution
    # configuration below and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is
    # (N, D, H, W, C, kD, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 16, 16, 16, 16, 5, 5, 5, 16, (1, 1, 1), (2, 2, 2), 1),
        (4, 16, 16, 16, 32, 5, 5, 5, 32, (1, 1, 1), (2, 2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, D, H, W, C), ( O, kD, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, D, H, W, C, kD, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, D, H, W, C, kD, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {D:3d}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kD:2d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_bench.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import subprocess
5
+ import time
6
+
7
+ import mlx.core as mx
8
+ import numpy as np
9
+ import torch
10
+
11
+ device_name = subprocess.check_output(["sysctl", "-n", "machdep.cpu.brand_string"])
12
+ device_name = device_name.decode("utf-8").strip("\n")
13
+
14
+ N_warmup = 10
15
+ N_iter_bench = 100
16
+ N_iter_func = 5
17
+
18
+
19
def bench(f, a, b):
    """Return the wall-clock seconds spent in ``N_iter_bench`` calls of ``f(a, b)``.

    ``N_warmup`` untimed calls are made first, followed by a
    ``torch.mps.synchronize()`` before the clock starts.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    start_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    elapsed_ns = time.perf_counter_ns() - start_ns
    # Nanoseconds to seconds.
    return elapsed_ns * 1e-9
29
+
30
+
31
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv2d`` ``N_iter_func`` times.

    The returned function evaluates all results with ``mx.eval`` and
    returns them as a list.
    """

    def mx_conv_2D(a, b):
        outputs = [
            mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Force evaluation so the work happens inside the timed call.
        mx.eval(outputs)
        return outputs

    return mx_conv_2D
41
+
42
+
43
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``torch.conv2d`` ``N_iter_func`` times.

    The returned function runs under ``torch.no_grad()``, calls
    ``torch.mps.synchronize()`` after the convolutions have been issued,
    and returns the list of outputs.
    """

    @torch.no_grad()
    def pt_conv_2D(a, b):
        outputs = [
            torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Synchronise before returning so the timed call includes the work.
        torch.mps.synchronize()
        return outputs

    return pt_conv_2D
54
+
55
+
56
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (MPS) for one 2D convolution configuration.

    Random input of shape ``(N, H, W, C)`` and weights of shape
    ``(O, kH, kW, C / groups)`` are generated once with numpy and shared by
    both frameworks; the torch copies have their axes permuted with
    ``(0, 3, 1, 2)`` and are moved to the ``"mps"`` device. After timing,
    the MLX result is compared against a torch result computed on the CPU
    and a message is printed if they differ by more than ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    # Weight range is scaled by the kernel fan-in, kH * kW * C, so that
    # non-square kernels are scaled by their real size.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    # Synchronise after the device transfers, before any timing starts.
    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check on a single, untimed evaluation; the torch
    # reference is computed on the CPU.
    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
92
+
93
+
94
if __name__ == "__main__":
    # Script entry point: benchmark every 2D convolution configuration below
    # and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 2),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 16),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 64),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_transpose_bench.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import math
3
+ import os
4
+ import subprocess
5
+ import time
6
+
7
+ import mlx.core as mx
8
+ import numpy as np
9
+ import torch
10
+
11
+ N_warmup = 10
12
+ N_iter_bench = 100
13
+ N_iter_func = 5
14
+
15
+
16
def bench(f, a, b):
    """Time ``N_iter_bench`` calls of ``f(a, b)`` and return the seconds taken.

    ``N_warmup`` untimed calls and a ``torch.mps.synchronize()`` precede
    the timed loop.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    t_start = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    t_stop = time.perf_counter_ns()
    # Nanoseconds to seconds.
    return (t_stop - t_start) * 1e-9
26
+
27
+
28
def make_mx_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable that runs ``mx.conv_transpose2d`` ``N_iter_func`` times.

    The returned function evaluates all results with ``mx.eval`` and
    returns them as a list.
    """

    def mx_conv_transpose_2D(a, b):
        results = [
            mx.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Force evaluation so the work happens inside the timed call.
        mx.eval(results)
        return results

    return mx_conv_transpose_2D
40
+
41
+
42
def make_pt_conv_transpose_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a callable running ``torch.conv_transpose2d`` ``N_iter_func`` times.

    The returned function runs under ``torch.no_grad()``, calls
    ``torch.mps.synchronize()`` after the convolutions have been issued,
    and returns the list of outputs.
    """

    @torch.no_grad()
    def pt_conv_transpose_2D(a, b):
        results = [
            torch.conv_transpose2d(
                a, b, stride=strides, padding=padding, groups=groups
            )
            for _ in range(N_iter_func)
        ]
        # Synchronise before returning so the timed call includes the work.
        torch.mps.synchronize()
        return results

    return pt_conv_transpose_2D
55
+
56
+
57
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX against PyTorch (MPS) for one 2D transposed convolution.

    Random input of shape ``(N, H, W, C)`` and weights of shape
    ``(O, kH, kW, C / groups)`` are generated once with numpy and shared by
    both frameworks; the torch input is permuted with ``(0, 3, 1, 2)`` and
    the torch weight with ``(3, 0, 1, 2)``, and both are moved to the
    ``"mps"`` device. After timing, the MLX result is compared against a
    torch result computed on the CPU and a message is printed if they
    differ by more than ``atol``.

    Returns:
        Tuple ``(time_mlx, time_torch)`` of wall-clock seconds as reported
        by ``bench``.
    """
    # Weight range is scaled by the kernel fan-in, kH * kW * C, so that
    # non-square kernels are scaled by their real size.
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    # NOTE(review): only groups == 1 is benchmarked by this script; confirm
    # this weight layout against the conv_transpose2d docs before adding
    # grouped configurations.
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, int(C / groups))).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((3, 0, 1, 2))).to("mps")

    # Synchronise after the device transfers, before any timing starts.
    torch.mps.synchronize()

    f_mx = make_mx_conv_transpose_2D(strides, padding, groups)
    f_pt = make_pt_conv_transpose_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check on a single, untimed evaluation; the torch
    # reference is computed on the CPU.
    out_mx = mx.conv_transpose2d(
        a_mx, b_mx, stride=strides, padding=padding, groups=groups
    )
    out_pt = torch.conv_transpose2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Permute the torch result back to the axis order of the MLX result.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
95
+
96
+
97
if __name__ == "__main__":
    # Script entry point: benchmark every 2D transposed-convolution
    # configuration below and print one comparison row per configuration.
    dtypes = ("float32",)
    # Each entry is (N, H, W, C, kH, kW, O, strides, padding, groups).
    shapes = (
        (4, 32, 32, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 32, 32, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 32, 32, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 32, 32, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 32, 32, 512, 5, 5, 512, (1, 1), (2, 2), 1),
        (4, 64, 64, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 64, 64, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 64, 64, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 64, 64, 256, 5, 5, 256, (1, 1), (2, 2), 1),
        (4, 128, 128, 32, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 64, (1, 1), (2, 2), 1),
        (4, 128, 128, 128, 5, 5, 128, (1, 1), (2, 2), 1),
        (4, 256, 256, 32, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 256, 256, 3, 5, 5, 32, (1, 1), (2, 2), 1),
        (4, 128, 128, 64, 5, 5, 3, (1, 1), (2, 2), 1),
        (4, 128, 128, 3, 5, 5, 64, (1, 1), (2, 2), 1),
    )

    for dtype in dtypes:
        print(
            "(N, H, W, C), ( O, kH, kW, C), dtype, stride, pads, groups, diff%"
        )
        # The numpy dtype only depends on the outer loop variable.
        np_dtype = getattr(np, dtype)
        for N, H, W, C, kH, kW, O, strides, padding, groups in shapes:
            time_mlx, time_torch = bench_shape(
                N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype
            )
            # Relative difference: positive when the MLX run took less time.
            diff = time_torch / time_mlx - 1.0

            print(
                f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kH:2d}, {kW:2d}, {C:3d}), {dtype}, {strides}, {padding}, {groups:7d}, {100. * diff:+5.2f}%"
            )
            # Flag configurations where MLX took at least twice as long.
            if time_mlx >= 2.0 * time_torch:
                print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/conv_unaligned_bench.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import time
3
+
4
+ import mlx.core as mx
5
+ import numpy as np
6
+ import torch
7
+
8
+ N_warmup = 10
9
+ N_iter_bench = 100
10
+ N_iter_func = 5
11
+
12
+
13
def bench(f, a, b):
    """Return the wall-clock seconds spent on N_iter_bench calls of f(a, b).

    N_warmup untimed calls are made first, followed by an MPS synchronize,
    so that the timed loop starts from an idle torch queue.
    """
    for _ in range(N_warmup):
        f(a, b)
    torch.mps.synchronize()

    started_ns = time.perf_counter_ns()
    for _ in range(N_iter_bench):
        f(a, b)
    finished_ns = time.perf_counter_ns()

    # perf_counter_ns reports nanoseconds; scale to seconds.
    return (finished_ns - started_ns) * 1e-9
23
+
24
+
25
def make_mx_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a closure that runs N_iter_func MLX conv2d calls and evaluates them.

    The returned callable takes (a, b), forwards them to mx.conv2d with the
    captured stride/padding/groups, and returns the list of outputs.
    """

    def mx_conv_2D(a, b):
        ys = [
            mx.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Evaluate all queued outputs together before returning.
        mx.eval(ys)
        return ys

    return mx_conv_2D
35
+
36
+
37
def make_pt_conv_2D(strides=(1, 1), padding=(0, 0), groups=1):
    """Build a closure that runs N_iter_func torch conv2d calls, then syncs MPS.

    The returned callable takes (a, b), runs under torch.no_grad(), and
    returns the list of outputs.
    """

    @torch.no_grad()
    def pt_conv_2D(a, b):
        ys = [
            torch.conv2d(a, b, stride=strides, padding=padding, groups=groups)
            for _ in range(N_iter_func)
        ]
        # Wait for the MPS queue so the caller's timer sees finished work.
        torch.mps.synchronize()
        return ys

    return pt_conv_2D
48
+
49
+
50
def bench_shape(N, H, W, C, kH, kW, O, strides, padding, groups, np_dtype):
    """Time MLX and torch conv2d for one problem size and sanity-check results.

    Args:
        N, H, W, C: dimensions used for the input array, built as (N, H, W, C).
        kH, kW, O: kernel height/width and output channels; the weight array
            is built as (O, kH, kW, C // groups).
        strides, padding, groups: forwarded to both conv2d implementations.
        np_dtype: NumPy dtype, or anything np.dtype() accepts (e.g. "float32").

    Returns:
        (time_mlx, time_torch) as returned by bench(), in seconds.

    A message is printed (nothing is raised) when the two outputs disagree
    beyond the dtype-dependent absolute tolerance.
    """
    # Normalise so that string dtypes such as "float32" compare equal to
    # np.float32 when the tolerance is selected below.
    np_dtype = np.dtype(np_dtype)

    # Scale weights by the kernel fan-in (kH * kW * C).
    scale = 1.0 / math.sqrt(kH * kW * C)
    a_np = np.random.uniform(0, 0.5, (N, H, W, C)).astype(np_dtype)
    b_np = np.random.uniform(-scale, scale, (O, kH, kW, C // groups)).astype(
        np_dtype
    )

    a_mx = mx.array(a_np)
    b_mx = mx.array(b_np)

    # The torch tensors are handed the same data with the channel axis moved
    # to position 1.
    a_pt = torch.from_numpy(a_np.transpose((0, 3, 1, 2))).to("mps")
    b_pt = torch.from_numpy(b_np.transpose((0, 3, 1, 2))).to("mps")

    torch.mps.synchronize()

    f_mx = make_mx_conv_2D(strides, padding, groups)
    f_pt = make_pt_conv_2D(strides, padding, groups)

    time_torch = bench(f_pt, a_pt, b_pt)
    time_mlx = bench(f_mx, a_mx, b_mx)

    # Correctness check: MLX result against a torch result computed on CPU.
    out_mx = mx.conv2d(a_mx, b_mx, stride=strides, padding=padding, groups=groups)
    out_pt = torch.conv2d(
        a_pt.to("cpu"), b_pt.to("cpu"), stride=strides, padding=padding, groups=groups
    )
    # Move channels back to the last axis to line up with the MLX output.
    out_pt = torch.permute(out_pt, (0, 2, 3, 1))
    out_pt = out_pt.numpy(force=True)

    atol = 2e-5 if np_dtype == np.float32 else 1e-4

    if not np.allclose(out_pt, out_mx, atol=atol):
        print(
            f"Failed at {(N, H, W, C)}, {(O, kH, kW, C)} [strides = {strides}, padding = {padding}, groups = {groups}] with max(|a - b|) = {np.max(np.abs(out_pt - out_mx))}"
        )

    return time_mlx, time_torch
86
+
87
+
88
# Script entry point: benchmark conv2d on shapes whose channel counts are not
# round multiples (21, 37, 370, ...), with unit stride, no padding, one group.
if __name__ == "__main__":
    dtype = "float32"
    # Resolve the name to a NumPy type so bench_shape's float32 tolerance
    # check sees a real dtype rather than a string.
    np_dtype = getattr(np, dtype)
    # Each row is (N, H, W, C, kh, kw, O).
    shapes = (
        (4, 32, 32, 21, 3, 3, 128),
        (4, 32, 32, 21, 3, 3, 37),
        (4, 32, 32, 370, 3, 3, 370),
        (4, 32, 32, 370, 7, 7, 128),
        (2, 320, 640, 21, 7, 7, 21),
    )
    for N, H, W, C, kh, kw, O in shapes:
        time_mlx, time_torch = bench_shape(
            N, H, W, C, kh, kw, O, (1, 1), (0, 0), 1, np_dtype
        )
        # Positive when the torch timing is larger than the MLX timing.
        diff = time_torch / time_mlx - 1.0

        print(
            f"({N}, {H:3d}, {W:3d}, {C:3d}), ({O:3d}, {kh:2d}, {kw:2d}, {C:3d}), {dtype}, {100. * diff:+5.2f}%"
        )
        # Highlight rows where MLX took at least twice as long.
        if time_mlx >= 2.0 * time_torch:
            print("ATTENTION ^^^^^^^")
ml-stable-diffusion/mlx/benchmarks/python/distributed_bench.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2024 Apple Inc.
2
+
3
+ """
4
+ Run with:
5
+ mpirun -n 2 python /path/to/distributed_bench.py
6
+ """
7
+
8
+ import time
9
+
10
+ import mlx.core as mx
11
+
12
+
13
def time_fn(fn, *args, **kwargs):
    """Time fn(*args, **kwargs) and report the mean per-call time from rank 0.

    Args:
        fn: callable whose result is passed to mx.eval.
        *args, **kwargs: forwarded to fn. The optional keyword "msg" is
            removed first and used as the printed label instead of
            fn.__name__.

    Every process runs the warmup and timed loops; only the process whose
    world.rank() is 0 prints.
    """
    msg = kwargs.pop("msg", None)
    world = mx.distributed.init()
    is_root = world.rank() == 0
    if is_root:
        label = msg if msg else fn.__name__
        print(f"Timing {label} ...", end=" ")

    # warmup
    for _ in range(5):
        mx.eval(fn(*args, **kwargs))

    num_iters = 100
    tic = time.perf_counter()
    for _ in range(num_iters):
        # The result of mx.eval is not needed, only the time it takes.
        mx.eval(fn(*args, **kwargs))
    toc = time.perf_counter()

    msec = 1e3 * (toc - tic) / num_iters
    if is_root:
        print(f"{msec:.5f} msec")
35
+
36
+
37
def time_all_sum():
    """Time 20-step chains of sin, all_sum, and sin followed by all_sum."""
    x = mx.random.uniform(shape=(4096,))
    mx.eval(x)

    steps = 20

    # The inner function names are kept because time_fn prints fn.__name__.
    def sine(x):
        for _ in range(steps):
            x = mx.sin(x)
        return x

    def all_sum_plain(x):
        for _ in range(steps):
            x = mx.distributed.all_sum(x)
        return x

    def all_sum_with_sine(x):
        for _ in range(steps):
            x = mx.distributed.all_sum(mx.sin(x))
        return x

    time_fn(sine, x)
    time_fn(all_sum_plain, x)
    time_fn(all_sum_with_sine, x)
63
+
64
+
65
# Run the all_sum timings when executed as a script (see the module
# docstring for the mpirun invocation).
if __name__ == "__main__":
    time_all_sum()
ml-stable-diffusion/mlx/benchmarks/python/einsum_bench.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2024 Apple Inc.
2
+
3
+ import time
4
+
5
+ import mlx.core as mx
6
+ import numpy as np
7
+
8
+
9
def timeit(fn, its=100, args=()):
    """Return the mean wall-clock time of fn(*args) in milliseconds.

    Args:
        fn: callable to time.
        its: number of timed calls.
        args: positional arguments unpacked into every call. The default is
            an immutable empty tuple rather than a shared mutable list.

    Five untimed warmup calls are made before the timed loop.
    """
    for _ in range(5):
        fn(*args)

    tic = time.perf_counter()
    for _ in range(its):
        fn(*args)
    toc = time.perf_counter()

    # perf_counter is in seconds; convert to milliseconds per call.
    return 1e3 * (toc - tic) / its
17
+
18
+
19
def time_little_einsum_path():
    """Compare einsum_path timing in MLX and NumPy for a 32x32 matmul contraction."""
    subscripts = "ik,kj->ij"
    mx_operands = (mx.ones((32, 32)), mx.ones((32, 32)))
    mx_time = timeit(mx.einsum_path, args=(subscripts, *mx_operands))

    # Same operands converted to NumPy arrays for the reference timing.
    np_operands = tuple(np.array(operand) for operand in mx_operands)
    np_time = timeit(np.einsum_path, args=(subscripts, *np_operands))

    print("Timing little einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")
31
+
32
+
33
def time_big_einsum_path():
    """Compare einsum_path timing in MLX and NumPy for a random 10-operand contraction."""
    chars = list("abcdefgh")
    # Each letter's axis length is its position in `chars`, so "a" maps to 0.
    char_to_dim = dict(zip(chars, range(len(chars))))

    num_inputs = 10
    inputs = []
    terms = []
    for _ in range(num_inputs):
        # Five distinct letters per operand.
        letters = np.random.choice(chars, size=5, replace=False).tolist()
        terms.append("".join(letters))
        inputs.append(np.ones([char_to_dim[c] for c in letters]))
    subscripts = ",".join(terms)

    np_time = timeit(np.einsum_path, args=(subscripts, *inputs))

    mx_inputs = [mx.array(operand) for operand in inputs]
    mx_time = timeit(mx.einsum_path, args=(subscripts, *mx_inputs))

    print("Timing big einsum path...")
    print(f"MLX ... {mx_time:.3f} ms")
    print(f"NumPy... {np_time:.3f} ms")
53
+
54
+
55
def time_attention():
    """Time self-attention written with matmul/transpose against an einsum version."""

    def regular_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        q = queries.transpose(0, 2, 1, 3)
        k_t = keys.transpose(0, 2, 3, 1)
        v = values.transpose(0, 2, 1, 3)
        weights = mx.softmax(q @ k_t, axis=-1)
        output = (weights @ v).swapaxes(1, 2)
        mx.eval(output)

    def einsum_attention(x):
        # shape [batch, sequence, num_heads, head_dim]
        queries, keys, values = x, x, x
        weights = mx.softmax(mx.einsum("itjk,iujk->ijtu", queries, keys), axis=-1)
        output = mx.einsum("ijtu,iujk->itjk", weights, values)
        mx.eval(output)

    x = mx.random.uniform(shape=(8, 512, 32, 128))

    regular_time = timeit(regular_attention, args=(x,))
    ein_time = timeit(einsum_attention, args=(x,))

    print("Timing einsum attention...")
    print(f"Regular ... {regular_time:.3f} ms")
    print(f"Einsum ... {ein_time:.3f} ms")
79
+
80
+
81
# Run the three einsum benchmarks in order when executed as a script.
if __name__ == "__main__":
    for benchmark in (time_little_einsum_path, time_big_einsum_path, time_attention):
        benchmark()
ml-stable-diffusion/mlx/benchmarks/python/fft_bench.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2024 Apple Inc.
2
+
3
+ import matplotlib
4
+ import mlx.core as mx
5
+ import numpy as np
6
+ import sympy
7
+ import torch
8
+ from time_utils import measure_runtime
9
+
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+
13
+
14
def bandwidth_gb(runtime_ms, system_size):
    """Convert a runtime in milliseconds into GB/s for `system_size` complex64 values.

    Each element is counted as two complex64 items' worth of bytes.
    """
    element_bytes = 2 * np.dtype(np.complex64).itemsize
    total_bytes = system_size * element_bytes
    # bytes per ms -> bytes per s (x 1e3) -> GB per s (/ 1e9)
    return total_bytes / runtime_ms * 1e3 / 1e9
19
+
20
+
21
def run_bench(system_size, fft_sizes, backend="mlx", dim=1):
    """Measure FFT bandwidth for each transform size in `fft_sizes`.

    Args:
        system_size: approximate total number of elements per measurement;
            the batch size is system_size // n**dim.
        fft_sizes: iterable of transform lengths n.
        backend: "mlx" (mx.fft) or "mps" (torch.fft on the "mps" device).
        dim: 1 for fft, 2 for fft2.

    Returns:
        NumPy array with one bandwidth_gb() value per entry of fft_sizes.

    Raises:
        ValueError: if dim is not 1 or 2.
        NotImplementedError: if backend is not "mlx" or "mps".
    """
    if dim not in (1, 2):
        # Previously an unsupported dim only failed later with an unbound local.
        raise ValueError(f"dim must be 1 or 2, got {dim}")

    def fft_mlx(x):
        out = mx.fft.fft(x) if dim == 1 else mx.fft.fft2(x)
        mx.eval(out)
        return out

    def fft_mps(x):
        out = torch.fft.fft(x) if dim == 1 else torch.fft.fft2(x)
        torch.mps.synchronize()
        return out

    bandwidths = []
    for n in fft_sizes:
        batch_size = system_size // n**dim
        shape = [batch_size] + [n] * dim
        if backend not in ("mlx", "mps"):
            raise NotImplementedError()

        # Build the input with the same shape that is used for the bandwidth
        # figure; for dim == 1 this is (system_size // n, n), as before.
        x_np = np.random.uniform(size=shape).astype(np.complex64)
        if backend == "mlx":
            x = mx.array(x_np)
            mx.eval(x)
            fft = fft_mlx
        else:
            x = torch.tensor(x_np, device="mps")
            torch.mps.synchronize()
            fft = fft_mps

        runtime_ms = measure_runtime(fft, x=x)
        bandwidth = bandwidth_gb(runtime_ms, np.prod(shape))
        print(n, bandwidth)
        bandwidths.append(bandwidth)

    return np.array(bandwidths)
60
+
61
+
62
def time_fft():
    """Benchmark FFT sizes 2..511 on MLX GPU, torch MPS and MLX CPU, then plot.

    One scatter plot is written per size subset ("All", "Radix 2-13",
    "Bluestein's") as "<name>.png", followed by printed averages.
    """
    sizes = np.array(range(2, 512))

    print("MLX GPU")
    with mx.stream(mx.gpu):
        gpu_bandwidths = run_bench(system_size=int(2**26), fft_sizes=sizes)

    print("MPS GPU")
    mps_bandwidths = run_bench(
        system_size=int(2**26), fft_sizes=sizes, backend="mps"
    )

    print("CPU")
    # The CPU run uses a smaller system size than the GPU runs.
    with mx.stream(mx.cpu):
        cpu_bandwidths = run_bench(system_size=int(2**20), fft_sizes=sizes)

    def only_small_primes(n):
        # True when no prime factor of n exceeds 13.
        return all(p <= 13 for p in sympy.primefactors(n))

    # Subtracting the first size turns sizes into positions in the result arrays.
    first = sizes[0]
    all_indices = sizes - first
    radix_2to13 = np.array([n for n in sizes if only_small_primes(n)]) - first
    bluesteins = np.array([n for n in sizes if not only_small_primes(n)]) - first

    series = (
        ("green", "GPU", gpu_bandwidths),
        ("blue", "MPS", mps_bandwidths),
        ("red", "CPU", cpu_bandwidths),
    )
    subsets = (
        ("All", all_indices),
        ("Radix 2-13", radix_2to13),
        ("Bluestein's", bluesteins),
    )
    for name, indices in subsets:
        # plot bandwidths
        print(name)
        for color, label, values in series:
            plt.scatter(sizes[indices], values[indices], color=color, label=label)
        plt.title(f"MLX FFT Benchmark -- {name}")
        plt.xlabel("N")
        plt.ylabel("Bandwidth (GB/s)")
        plt.legend()
        plt.savefig(f"{name}.png")
        plt.clf()

    print("Average bandwidths:")
    print("GPU:", np.mean(gpu_bandwidths))
    print("MPS:", np.mean(mps_bandwidths))
    print("CPU:", np.mean(cpu_bandwidths))

    # Fraction of sizes where the MLX GPU bandwidth beats the MPS bandwidth.
    faster_count = len(np.where(gpu_bandwidths > mps_bandwidths)[0])
    portion_faster = faster_count / len(sizes)
    print("Percent MLX faster than MPS: ", portion_faster * 100)
115
+
116
+
117
# Run the FFT benchmark and write its plots when executed as a script.
if __name__ == "__main__":
    time_fft()
ml-stable-diffusion/mlx/benchmarks/python/gather_bench.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023-2024 Apple Inc.
2
+
3
+ import argparse
4
+
5
+ import mlx.core as mx
6
+ import torch
7
+ from time_utils import measure_runtime
8
+
9
+
10
def benchmark_gather_mlx(x_shape, idx_shape):
    """Time x[idx] in MLX for a random float32 x and random integer idx, and print it."""

    def gather(x, idx):
        mx.eval(x[idx])

    # Indices are drawn with bounds 0 and x_shape[0] - 1 and index the first axis.
    upper = x_shape[0] - 1
    idx = mx.random.randint(0, upper, idx_shape)
    x = mx.random.normal(x_shape).astype(mx.float32)

    elapsed = measure_runtime(gather, x=x, idx=idx)
    print(f"MLX: {elapsed:.3f}ms")
19
+
20
+
21
def benchmark_gather_torch(x_shape, idx_shape, device):
    """Time x[idx] in PyTorch on `device` for random x and idx, and print it."""

    def gather(x, idx, device):
        _ = x[idx]
        # Only the MPS device is explicitly synchronized before returning.
        on_mps = device == torch.device("mps")
        if on_mps:
            torch.mps.synchronize()

    upper = x_shape[0] - 1
    idx = torch.randint(0, upper, idx_shape).to(device)
    x = torch.randn(x_shape, dtype=torch.float32).to(device)

    elapsed = measure_runtime(gather, x=x, idx=idx, device=device)
    print(f"PyTorch: {elapsed:.3f}ms")
32
+
33
+
34
# Script entry point: compare MLX and PyTorch gather timings on three cases.
if __name__ == "__main__":
    parser = argparse.ArgumentParser("Gather benchmarks.")
    parser.add_argument("--cpu", action="store_true", help="Use the CPU.")
    args = parser.parse_args()

    # --cpu moves both frameworks to the CPU; otherwise torch uses "mps" and
    # MLX keeps its default device.
    if args.cpu:
        mx.set_default_device(mx.cpu)
    device = torch.device("cpu" if args.cpu else "mps")

    # (x_shape, idx_shape) pairs, benchmarked in this order.
    cases = [
        ((100, 64), (1_000_000,)),
        ((100, 1024), (100_000,)),
        ((4, 1_000_000), ()),
    ]

    for x_shape, idx_shape in cases:
        print("=" * 20)
        print(f"X {x_shape}, Indices {idx_shape}")
        benchmark_gather_mlx(x_shape, idx_shape)
        benchmark_gather_torch(x_shape, idx_shape, device=device)
ml-stable-diffusion/mlx/benchmarks/python/gather_mm_bench.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2025 Apple Inc.
2
+
3
+ import mlx.core as mx
4
+ from time_utils import time_fn
5
+
6
+ N = 1024
7
+ D = 1024
8
+ M = 1024
9
+ E = 32
10
+ I = 4
11
+
12
+
13
def gather_sort(x, indices):
    """Reorder rows of x so that the flattened indices are sorted.

    Returns:
        (x_sorted, indices_sorted, inv_order), where inv_order is the
        argsort of the sorting permutation and can be used to undo it.
    """
    # indices must be 2-D; only its second dimension is needed below.
    _, cols = indices.shape
    flat = indices.flatten()
    order = mx.argsort(flat)
    inv_order = mx.argsort(order)
    # order // cols maps a position in the flattened indices back to its row.
    rows = x.flatten(0, -3)[order // cols]
    return rows, flat[order], inv_order
19
+
20
+
21
def scatter_unsort(x, inv_order, shape=None):
    """Index x by inv_order and, if `shape` is given, unflatten axis 0 to it."""
    restored = x[inv_order]
    if shape is None:
        return restored
    return mx.unflatten(restored, 0, shape)
26
+
27
+
28
def gather_mm_simulate(x, w, indices):
    """Emulate two rounds of gathered matmul with one explicit matmul per row."""
    x, idx, inv_order = gather_sort(x, indices)
    for _ in range(2):
        products = []
        # Pair each row of x with the weight matrix chosen by its index.
        for row, expert in enumerate(idx.tolist()):
            products.append(x[row] @ w[expert].T)
        y = mx.concatenate(products, axis=0)
        x = y[:, None]
    return scatter_unsort(x, inv_order, indices.shape)
35
+
36
+
37
def time_gather_mm():
    """Time two chained mx.gather_mm calls against an equivalent dense matmul.

    gather_mm is timed with random indices, with pre-sorted indices, and
    with sorting done inside the timed function.
    """
    norm = 1024**0.5
    x = mx.random.normal((N, 1, 1, D)) / norm
    w1 = mx.random.normal((E, M, D)) / norm
    w2 = mx.random.normal((E, D, M)) / norm
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    def gather_mm(x, w1, w2, indices, sort):
        idx, inv_order = indices, None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        for weight in (w1, w2):
            x = mx.gather_mm(
                x, weight.swapaxes(-1, -2), rhs_indices=idx, sorted_indices=sort
            )
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    def equivalent_matmul(x, w1, w2):
        return (x @ w1.T) @ w2.T

    for idx_arg, sort_flag in (
        (indices, False),
        (sorted_indices, False),
        (indices, True),
    ):
        time_fn(gather_mm, x, w1, w2, idx_arg, sort_flag)

    # Dense baseline: N * I rows through a single pair of weight matrices.
    x = mx.random.normal((N * I, D)) / norm
    w1 = mx.random.normal((M, D)) / norm
    w2 = mx.random.normal((D, M)) / norm
    mx.eval(x, w1, w2)

    time_fn(equivalent_matmul, x, w1, w2)
71
+
72
+
73
# Run the gather_mm timings when executed as a script.
if __name__ == "__main__":
    time_gather_mm()
ml-stable-diffusion/mlx/benchmarks/python/gather_qmm_bench.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2025 Apple Inc.
2
+
3
+ import mlx.core as mx
4
+ from time_utils import time_fn
5
+
6
+ N = 1024
7
+ D = 1024
8
+ M = 1024
9
+ E = 32
10
+ I = 4
11
+
12
+
13
def gather_sort(x, indices):
    """Reorder rows of x so that the flattened indices are sorted.

    Returns:
        (x_sorted, indices_sorted, inv_order), where inv_order is the
        argsort of the sorting permutation and can be used to undo it.
    """
    # indices must be 2-D; only its second dimension is needed below.
    _, cols = indices.shape
    flat = indices.flatten()
    order = mx.argsort(flat)
    inv_order = mx.argsort(order)
    # order // cols maps a position in the flattened indices back to its row.
    rows = x.flatten(0, -3)[order // cols]
    return rows, flat[order], inv_order
19
+
20
+
21
def scatter_unsort(x, inv_order, shape=None):
    """Index x by inv_order and, if `shape` is given, unflatten axis 0 to it."""
    restored = x[inv_order]
    if shape is None:
        return restored
    return mx.unflatten(restored, 0, shape)
26
+
27
+
28
def gather_mm_simulate(x, w, indices):
    """Emulate two rounds of gathered quantized matmul, one explicit call per row.

    `w` is indexed as w[0], w[1], w[2]; entry j of each is passed, in that
    order, to mx.quantized_matmul after the row of x.
    """
    x, idx, inv_order = gather_sort(x, indices)
    for _ in range(2):
        products = []
        # Pair each row of x with the quantized weights chosen by its index.
        for row, expert in enumerate(idx.tolist()):
            products.append(
                mx.quantized_matmul(
                    x[row], w[0][expert], w[1][expert], w[2][expert], transpose=True
                )
            )
        y = mx.concatenate(products, axis=0)
        x = y[:, None]
    return scatter_unsort(x, inv_order, indices.shape)
41
+
42
+
43
def time_gather_qmm():
    """Time two chained mx.gather_qmm calls against an equivalent quantized matmul.

    gather_qmm is timed with random indices, with pre-sorted indices, and
    with sorting done inside the timed function.
    """
    norm = 1024**0.5
    x = mx.random.normal((N, 1, 1, D)) / norm
    w1 = mx.quantize(mx.random.normal((E, M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((E, D, M)) / norm)
    indices = (mx.random.uniform(shape=(N, I)) * E).astype(mx.uint32)
    sorted_indices = mx.sort(indices.flatten()).reshape(N, I)
    mx.eval(x, w1, w2, indices, sorted_indices)

    # The inner name is kept as gather_mm because it is the callable handed
    # to time_fn.
    def gather_mm(x, w1, w2, indices, sort):
        idx, inv_order = indices, None
        if sort:
            x, idx, inv_order = gather_sort(x, indices)
        for weight in (w1, w2):
            x = mx.gather_qmm(
                x, *weight, transpose=True, rhs_indices=idx, sorted_indices=sort
            )
        if sort:
            x = scatter_unsort(x, inv_order, indices.shape)
        return x

    def equivalent_matmul(x, w1, w2):
        for weight in (w1, w2):
            x = mx.quantized_matmul(x, *weight, transpose=True)
        return x

    for idx_arg, sort_flag in (
        (indices, False),
        (sorted_indices, False),
        (indices, True),
    ):
        time_fn(gather_mm, x, w1, w2, idx_arg, sort_flag)

    # Baseline: N * I rows through a single pair of quantized weight matrices.
    x = mx.random.normal((N * I, D)) / norm
    w1 = mx.quantize(mx.random.normal((M, D)) / norm)
    w2 = mx.quantize(mx.random.normal((D, M)) / norm)
    mx.eval(x, w1, w2)

    time_fn(equivalent_matmul, x, w1, w2)
81
+
82
+
83
# Run the gather_qmm timings when executed as a script.
if __name__ == "__main__":
    time_gather_qmm()
ml-stable-diffusion/mlx/benchmarks/python/hadamard_bench.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+ import matplotlib
4
+ import mlx.core as mx
5
+ import numpy as np
6
+ from time_utils import measure_runtime
7
+
8
+ matplotlib.use("Agg")
9
+ import matplotlib.pyplot as plt
10
+
11
+
12
def had(x):
    """Apply mx.hadamard_transform to x and force its evaluation."""
    mx.eval(mx.hadamard_transform(x))
15
+
16
+
17
def copy(x):
    """Evaluate x + 1.0; run() records this under the "copy" series."""
    mx.eval(x + 1.0)
20
+
21
+
22
def run(dtype):
    """Benchmark had() and copy() for sizes n = m * 2**k and plot the bandwidth.

    Args:
        dtype: NumPy scalar type for the input; its __name__ appears in the
            plot title and in the output file name "bench_<name>.png".
    """
    system_size = 2**26
    bytes_per_gb = 1e9
    ms_per_s = 1e3

    outputs = {}
    for test_fn in (had, copy):
        for m in [1, 12, 20, 28]:
            # Group results into one series per kind of measurement.
            if test_fn is copy:
                key = "copy"
            else:
                key = "had_2^k" if m == 1 else "had_m*2^k"
            series = outputs.setdefault(key, {})

            for k in range(7, 14):
                n = m * 2**k
                # Sizes above 2**15 are not measured.
                if n <= 2**15:
                    x_np = np.random.normal(size=(system_size // n, n)).astype(dtype)
                    x = mx.array(x_np)
                    runtime_ms = measure_runtime(test_fn, x=x)
                    # Each element is counted as two items of the input dtype.
                    bytes_per_had = x_np.dtype.itemsize * 2
                    bandwidth_gb = (
                        system_size * bytes_per_had / runtime_ms * ms_per_s / bytes_per_gb
                    )
                    print(n, bandwidth_gb)
                    series[n] = bandwidth_gb

    colors = {
        "copy": "black",
        "had_2^k": "steelblue",
        "had_m*2^k": "skyblue",
    }
    for key, output in outputs.items():
        plt.scatter(output.keys(), output.values(), color=colors[key], label=key)
    plt.title(f"MLX Hadamard Benchmark -- {dtype.__name__}")
    plt.xlabel("N")
    plt.ylabel("Bandwidth (GB/s)")
    plt.legend()
    plt.savefig(f"bench_{dtype.__name__}.png")
    plt.clf()
63
+
64
+
65
# Script entry point: --fp16 selects float16 inputs, otherwise float32.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--fp16", action="store_true")
    args = parser.parse_args()
    if args.fp16:
        dtype = np.float16
    else:
        dtype = np.float32
    run(dtype)
ml-stable-diffusion/mlx/benchmarks/python/layer_norm_bench.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright © 2023-2024 Apple Inc.
2
+
3
+ from functools import partial
4
+
5
+ import mlx.core as mx
6
+ import mlx.nn as nn
7
+ from time_utils import time_fn
8
+
9
+
10
def layer_norm(x, w, b, eps):
    """Reference layer norm over the last axis, computed in float32.

    Args:
        x: input array; it is cast to mx.float32 before any arithmetic.
        w: optional scale multiplied into the normalized values (None skips it).
        b: optional offset added afterwards (None skips it).
        eps: value added to the variance before the reciprocal square root.

    Returns:
        The normalized (and optionally scaled and shifted) array.
    """
    # NOTE(review): the result is not cast back to the dtype x had on entry;
    # the original stored that dtype in a local but never used it. Confirm
    # whether a cast back was intended.
    x = x.astype(mx.float32)
    mu = mx.mean(x, -1, keepdims=True)
    v = mx.var(x, -1, keepdims=True)
    y = (x - mu) * mx.rsqrt(v + eps)
    if w is not None:
        y = y * w
    if b is not None:
        y = y + b
    return y
21
+
22
+
23
def time_layer_norm(N, dt):
    """Time the reference layer_norm against mx.fast.layer_norm.

    Covers the forward pass, gradients with respect to (x, w, b), and
    gradients with respect to x alone, each also through mx.compile.

    Args:
        N: size of the normalized (last) axis.
        dt: MLX dtype the random inputs are cast to.
    """
    L = 1024

    def make_inputs():
        # Fresh random x, w, b, y, drawn in that order and evaluated up front.
        x = mx.random.uniform(shape=(8, L, N)).astype(dt)
        w = mx.random.uniform(shape=(N,)).astype(dt)
        b = mx.random.uniform(shape=(N,)).astype(dt)
        y = mx.random.uniform(shape=(8, L, N)).astype(dt)
        mx.eval(x, w, b, y)
        return x, w, b, y

    # --- forward pass and gradients with scale and offset -----------------
    f1 = lambda x, w, b, y: (layer_norm(x, w, b, 1e-5) * y).sum()
    f2 = lambda x, w, b, y: (mx.fast.layer_norm(x, w, b, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0, 1, 2))
    g2 = mx.grad(f2, argnums=(0, 1, 2))

    x, w, b, y = make_inputs()

    def layer_norm_loop(f, x, w, b):
        for _ in range(32):
            x = f(x, w, b)
        return x

    for forward in (partial(layer_norm, eps=1e-5), partial(mx.fast.layer_norm, eps=1e-5)):
        time_fn(layer_norm_loop, forward, x, w, b)

    def layer_norm_grad_loop(g, x, w, b):
        # y is read from the enclosing scope; each gradient feeds the next step.
        gx, gw, gb = x, w, b
        for _ in range(32):
            gx, gw, gb = g(gx, gw, gb, y)
        return gx, gw, gb

    for grad_fn in (g1, g2, mx.compile(g1), mx.compile(g2)):
        time_fn(layer_norm_grad_loop, grad_fn, x, w, b)

    # --- gradient with respect to x only, without scale and offset --------
    f1 = lambda x, y: (layer_norm(x, None, None, 1e-5) * y).sum()
    f2 = lambda x, y: (mx.fast.layer_norm(x, None, None, 1e-5) * y).sum()
    g1 = mx.grad(f1, argnums=(0,))
    g2 = mx.grad(f2, argnums=(0,))

    x, w, b, y = make_inputs()

    def layer_norm_grad_x_loop(g, x):
        gx = x
        for _ in range(32):
            gx = g(gx, y)
        return gx

    for grad_fn in (g1, g2, mx.compile(g1), mx.compile(g2)):
        time_fn(layer_norm_grad_x_loop, grad_fn, x)
76
+
77
+
78
# Script entry point: sweep every dtype against every feature size.
if __name__ == "__main__":
    dtypes = [mx.float32, mx.float16, mx.bfloat16]
    feature_sizes = [1024, 2048, 4096, 8192, 8192 + 1024]
    for dt in dtypes:
        for n in feature_sizes:
            print(dt, n)
            time_layer_norm(n, dt)