Add files using upload-large-folder tool
Browse files- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/__init__.py +687 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/fragmented_array.py +661 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/profiler.py +289 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/utils.py +699 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/wgmma.py +518 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/__init__.py +48 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/gpu.py +18 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/__init__.py +19 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/gpu/attention.py +573 -0
- external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/tpu.py +53 -0
- external/alphageometry/README.md +447 -0
- external/alphageometry/lm_inference_test.py +89 -0
- external/alphageometry/models.py +178 -0
- external/alphageometry/numericals.py +1921 -0
- external/alphageometry/numericals_test.py +313 -0
- external/alphageometry/pretty.py +216 -0
- external/alphageometry/problem.py +1133 -0
- external/alphageometry/problem_test.py +61 -0
- external/alphageometry/requirements.in +17 -0
- external/alphageometry/requirements.txt +0 -0
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/__init__.py
ADDED
|
@@ -0,0 +1,687 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections.abc import Callable
|
| 2 |
+
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
# ==============================================================================
|
| 16 |
+
|
| 17 |
+
import contextlib
|
| 18 |
+
import ctypes
|
| 19 |
+
import dataclasses
|
| 20 |
+
import functools
|
| 21 |
+
import itertools
|
| 22 |
+
import os
|
| 23 |
+
import pathlib
|
| 24 |
+
import subprocess
|
| 25 |
+
import tempfile
|
| 26 |
+
import time
|
| 27 |
+
from typing import Any, Generic, Sequence, TypeVar
|
| 28 |
+
|
| 29 |
+
import jax
|
| 30 |
+
from jax._src import config
|
| 31 |
+
from jax._src import core as jax_core
|
| 32 |
+
from jax._src.interpreters import mlir
|
| 33 |
+
from jax._src.lib import xla_client
|
| 34 |
+
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
|
| 35 |
+
from jaxlib.mlir import ir
|
| 36 |
+
from jaxlib.mlir.dialects import arith
|
| 37 |
+
from jaxlib.mlir.dialects import builtin
|
| 38 |
+
from jaxlib.mlir.dialects import func
|
| 39 |
+
from jaxlib.mlir.dialects import gpu
|
| 40 |
+
from jaxlib.mlir.dialects import llvm
|
| 41 |
+
from jaxlib.mlir.dialects import memref
|
| 42 |
+
from jaxlib.mlir.dialects import nvgpu
|
| 43 |
+
from jaxlib.mlir.dialects import nvvm
|
| 44 |
+
from jaxlib.mlir.passmanager import PassManager
|
| 45 |
+
import numpy as np
|
| 46 |
+
|
| 47 |
+
from . import dsl as mgpu
|
| 48 |
+
from . import profiler
|
| 49 |
+
from . import utils
|
| 50 |
+
|
| 51 |
+
# mypy: ignore-errors
|
| 52 |
+
|
| 53 |
+
# MLIR can't find libdevice unless the CUDA_ROOT environment variable points at
# the CUDA installation, so export a default when the caller has not set one and
# otherwise adopt the caller's value.
# TODO(apaszke): Unify with jax._src.lib.cuda_path
CUDA_ROOT = os.environ.setdefault("CUDA_ROOT", "/usr/local/cuda")

# Locations of the CUDA command line tools, relative to the resolved root.
PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")
|
| 63 |
+
|
| 64 |
+
# Size and alignment (in bytes) used when reserving scratch space for a TMA
# descriptor.
TMA_DESCRIPTOR_BYTES = 128
TMA_DESCRIPTOR_ALIGNMENT = 64


c = mgpu.c  # This is too common to fully qualify.


# The runtime library is looked up next to the Mosaic GPU extension module.
RUNTIME_PATH = pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).with_name(
    "libmosaic_gpu_runtime.so"
)
if RUNTIME_PATH.exists():
  # Set this so that the custom call can find it
  os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
|
| 78 |
+
mosaic_gpu_p.multiple_results = True
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@mosaic_gpu_p.def_abstract_eval
|
| 82 |
+
def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
|
| 83 |
+
del module, gmem_scratch_bytes # Unused.
|
| 84 |
+
return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
|
| 85 |
+
|
| 86 |
+
# TODO(apaszke): Implement a proper system for managing kernel lifetimes
kernel_idx = itertools.count()


def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
  """Lowers ``mosaic_gpu_p`` to a ``mosaic_gpu`` custom call.

  The backend config is an 8-byte little-endian kernel counter followed by the
  binary serialization of ``module``. One extra uint8 result of
  ``gmem_scratch_bytes`` elements is appended for the scratch buffer and is
  dropped from the returned values.
  """
  del out_types  # Unused.
  kernel_id = next(kernel_idx).to_bytes(8, byteorder="little")

  result_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  scratch_aval = jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
  result_types.append(mlir.aval_to_ir_type(scratch_aval))

  serialized_module = module.operation.get_asm(
      binary=True, enable_debug_info=True
  )
  call = mlir.custom_call(
      "mosaic_gpu",
      result_types=result_types,
      operands=args,
      backend_config=kernel_id + serialized_module,
  )
  # The last result is the scratch space, which callers never see.
  return call.results[:-1]


mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclasses.dataclass(frozen=True)
class MemRefTransform:
  """Base class for transformations applied to memrefs.

  A transform has three views that subclasses are expected to implement: one
  on the memref value itself, one on indices into it and one on its (static)
  shape. Every method here raises ``NotImplementedError``.
  """

  def apply(self, ref: ir.Value) -> ir.Value:
    """Returns the transformed version of ``ref``."""
    raise NotImplementedError("Subclasses should override this method")

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    """Maps indices into the untransformed ref to the transformed ref."""
    raise NotImplementedError("Subclasses should override this method")

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    """Maps a shape of the untransformed ref to the transformed shape."""
    raise NotImplementedError("Subclasses should override this method")
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
@dataclasses.dataclass(frozen=True)
class TileTransform(MemRefTransform):
  """Tiles a suffix of memref dimensions.

  For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
  the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with
  the tile shape, and the size of tiled dimensions is divided by the tile size.
  This is especially useful for swizzled WGMMA, which expect tiled layouts in
  shared memory.
  """
  tiling: tuple[int, ...]

  def apply(self, ref: ir.Value) -> ir.Value:
    """Unfolds the trailing dimensions of ``ref`` and moves the tiles last."""
    untiled_rank = ir.MemRefType(ref.type).rank
    tiling_rank = len(self.tiling)
    tiled_rank = untiled_rank + tiling_rank
    # Split each tiled dimension d into (d // t, t), going from the last
    # dimension backwards so that earlier dimension numbers stay valid.
    for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
      ref = mgpu.memref_unfold(ref, d, (None, t))
    # After unfolding, the tiled suffix alternates (outer, tile) pairs. Gather
    # all outer dims first and all tile dims last.
    permutation = (
        *range(untiled_rank - tiling_rank),
        *range(untiled_rank - tiling_rank, tiled_rank, 2),
        *range(untiled_rank - tiling_rank + 1, tiled_rank, 2),
    )
    return mgpu.memref_transpose(ref, permutation)

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    """Splits each tiled index into a (tile number, offset in tile) pair.

    Returns the untiled leading indices, followed by the tile numbers of the
    tiled dimensions, followed by the offsets within those tiles.
    """
    index = ir.IndexType.get()
    # Compute the split point explicitly: `idx[:-0]` would be empty, which
    # silently dropped every index when the tiling was empty.
    split = max(len(idx) - len(self.tiling), 0)
    untiled_idx, tiled_idx = idx[:split], idx[split:]
    return (
        *untiled_idx,
        *(
            arith.divui(i, c(t, index))
            for i, t in zip(tiled_idx, self.tiling)
        ),
        *(
            arith.remui(i, c(t, index))
            for i, t in zip(tiled_idx, self.tiling)
        ),
    )

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    """Returns the shape of a slice of the given ``shape`` after tiling.

    Raises:
      ValueError: If a tiled dimension is not a multiple of its tile size.
    """
    # Same explicit split as in transform_index, so that an empty tiling leaves
    # the shape unchanged instead of dropping all dimensions.
    split = max(len(shape) - len(self.tiling), 0)
    untiled_shape, tiled_shape = shape[:split], shape[split:]
    # Note that this also checks that tiled dims are not squeezed. Their slice
    # size would be 1 if so.
    for size, tile_size in zip(tiled_shape, self.tiling):
      if size % tile_size:
        raise ValueError(
            f"Expected GMEM slice shape {shape} suffix to be a multiple"
            f" of tiling {self.tiling}"
        )
    return (
        *untiled_shape,
        *(s // t for s, t in zip(tiled_shape, self.tiling)),
        *self.tiling,
    )
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
@dataclasses.dataclass(frozen=True)
class TransposeTransform(MemRefTransform):
  """Reorders memref dimensions according to ``permutation``.

  Output dimension ``k`` of every view (ref, index, shape) is taken from input
  dimension ``permutation[k]``.
  """
  permutation: tuple[int, ...]

  def __post_init__(self):
    # Only repeated entries are rejected here.
    distinct_dims = set(self.permutation)
    if len(distinct_dims) != len(self.permutation):
      raise ValueError("Permutation must be a permutation")

  def apply(self, ref: ir.Value) -> ir.Value:
    """Returns ``ref`` with its dimensions transposed."""
    return mgpu.memref_transpose(ref, self.permutation)

  def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
    """Reorders ``idx`` to address the transposed ref."""
    return tuple([idx[dim] for dim in self.permutation])

  def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
    """Reorders ``shape`` to describe the transposed ref."""
    return tuple([shape[dim] for dim in self.permutation])
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
OnDeviceProfiler = profiler.OnDeviceProfiler
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
@dataclasses.dataclass()
class LaunchContext:
  """State and helpers used while emitting the body of one ``gpu.launch`` op.

  Besides the launch op itself, the context tracks a GMEM scratch buffer: byte
  ranges are handed out by ``_alloc_scratch`` and the callbacks that emit their
  host-side initialization are collected in ``host_scratch_init``.
  """
  # Host-side setup code is inserted right before this op; device-side code
  # goes into its first block.
  launch_op: gpu.LaunchOp
  # Base pointer of the GMEM scratch buffer, used when emitting device code.
  gmem_scratch_ptr: ir.Value
  profiler: OnDeviceProfiler | None = None
  # Number of scratch bytes reserved so far (including alignment padding).
  next_scratch_offset: int = 0
  # Callbacks that take the host scratch pointer and emit initialization code.
  host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
      default_factory=list, init=False
  )
  # Cache of TMA descriptors, keyed by (ref, transformed slice shape, swizzle,
  # GMEM transforms).
  tma_descriptors: dict[
      tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
      ir.Value,
  ] = dataclasses.field(default_factory=dict, init=False)

  @contextlib.contextmanager
  def named_region(self, *args, **kwargs):
    """Records the enclosed region with the profiler, if one is attached.

    All arguments are forwarded to ``self.profiler.record``. Without a
    profiler this is a no-op context manager.
    """
    if self.profiler is not None:
      with self.profiler.record(*args, **kwargs):
        yield
    else:
      yield

  def _reserve_scratch(self, size: int, alignment: int | None = None) -> int:
    """Reserves ``size`` scratch bytes and returns the offset of the region.

    Args:
      size: Number of bytes to reserve.
      alignment: Required alignment of the returned offset. Defaults to
        ``size``.

    Returns:
      The offset of the reserved region, a multiple of ``alignment``.

    Raises:
      ValueError: If the (possibly defaulted) alignment is not positive.
    """
    if alignment is None:
      alignment = size
    if alignment <= 0:
      raise ValueError(f"Scratch alignment must be positive, got {alignment}")
    # Round the running offset up to the next multiple of the alignment. When
    # it is already aligned this is a no-op, otherwise padding is inserted.
    alloc_base = -(-self.next_scratch_offset // alignment) * alignment
    self.next_scratch_offset = alloc_base + size
    return alloc_base

  def _alloc_scratch(
      self,
      size: int,
      alignment: int | None = None,
      host_init: Callable[[ir.Value], None] = lambda _: None,
      device_init: Callable[[ir.Value], Any] = lambda x: x,
  ) -> ir.Value:
    """Allocates a GMEM scratch buffer.

    The buffer is initialized on the host and then copied to GMEM before the
    kernel launch.

    Args:
      size: Size of the buffer in bytes.
      alignment: Alignment of the buffer offset within the scratch space.
        Defaults to ``size``. Padding is inserted when the current offset is
        not suitably aligned.
      host_init: Called later with a host pointer to the buffer; it should emit
        the code that initializes it. It runs with the insertion point set
        right before the launch op.
      device_init: Called immediately with the device pointer to the buffer,
        with the insertion point at the beginning of the launch body.

    Returns:
      Whatever ``device_init`` returns (the device pointer by default).
    """
    i8 = ir.IntegerType.get_signless(8)
    ptr_ty = ir.Type.parse("!llvm.ptr")
    alloc_base = self._reserve_scratch(size, alignment)
    def host_init_wrapped(host_ptr):
      with ir.InsertionPoint(self.launch_op):
        host_init(
            llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
        )
    self.host_scratch_init.append(host_init_wrapped)
    with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
      return device_init(llvm.getelementptr(
          ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
      ))

  def _get_tma_desc(
      self,
      ref,
      gmem_transform: tuple[MemRefTransform, ...],
      transformed_slice_shape: tuple[int, ...],
      swizzle: int | None,
  ):
    """Returns a (cached) TMA descriptor for copies to/from ``ref``.

    On a cache miss, scratch space for the descriptor is reserved, code calling
    ``mosaic_gpu_init_tma_desc`` is scheduled for emission on the host, and the
    device pointer is cast to the tensor map type.
    """
    ref_ty = ir.MemRefType(ref.type)
    tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
    if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
      swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none"
      default_tensor_map_attrs = dict(
          swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none"
      )
      tensor_map_ty = utils.get_tensormap_descriptor(
          tensor=(
              f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>"
          ),
          **default_tensor_map_attrs,
      )
      with ir.InsertionPoint(self.launch_op):
        # The transforms are applied on the host, before the launch op.
        for t in gmem_transform:
          ref = t.apply(ref)
        ref_ty = ir.MemRefType(ref.type)

        i64 = ir.IntegerType.get_signless(64)
        ptr_ty = ir.Type.parse("!llvm.ptr")
        def init_tma_desc(host_ptr):
          _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
          aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
          as_i64 = lambda i: arith.index_cast(i64, i)
          alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
          llvm_dyn = -2147483648  # TODO(apaszke): Improve the MLIR bindings...
          base_ptr = llvm.getelementptr(
              ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
          )
          rank = ref_ty.rank
          assert rank * 2 == len(sizes_and_strides)
          args = [
              host_ptr,
              base_ptr,
              c(utils.bytewidth(ref_ty.element_type), i64),
              c(rank, i64),
              utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
              utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
              c(0 if swizzle is None else swizzle, i64),
              utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
          ]
          func.call([], "mosaic_gpu_init_tma_desc", args)
        def cast_tma_desc(device_ptr):
          # TODO(apaszke): Investigate why prefetching can cause launch failures
          # nvvm.prefetch_tensormap(device_ptr)
          return builtin.unrealized_conversion_cast(
              [tensor_map_ty], [device_ptr]
          )
        tma_desc = self._alloc_scratch(
            TMA_DESCRIPTOR_BYTES,
            alignment=TMA_DESCRIPTOR_ALIGNMENT,
            host_init=init_tma_desc,
            device_init=cast_tma_desc,
        )
        self.tma_descriptors[tma_desc_key] = tma_desc
    return tma_desc

  def async_copy(
      self,
      *,
      src_ref,
      dst_ref,
      gmem_slice: Any = (),
      gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
      barrier: mgpu.Barrier | None = None,
      swizzle: int | None = None,
      arrive: bool | None = None,
      uniform: bool = True,
  ):
    """Emits an asynchronous TMA copy between GMEM and SMEM.

    Exactly one of ``src_ref`` and ``dst_ref`` must live in SMEM (workgroup
    address space) and the other must have no memory space (GMEM).

    Args:
      src_ref: Source memref.
      dst_ref: Destination memref.
      gmem_slice: Indices selecting the part of the GMEM ref to copy.
      gmem_transform: Transform(s) applied to the GMEM ref, its indices and
        the slice shape.
      barrier: Required for GMEM -> SMEM copies, rejected otherwise.
      swizzle: Swizzle size passed to the TMA descriptor, if any.
      arrive: For GMEM -> SMEM copies, whether to also emit an arrive with the
        expected byte count on the barrier (defaults to True). Rejected for
        SMEM -> GMEM copies.
      uniform: If True, the copy is wrapped in ``mgpu.single_thread``.

    Raises:
      ValueError: On mismatched element types, unsupported memory space
        combinations or barrier/arrive usage, a GMEM ref that is not produced
        by an unrealized conversion cast, or an SMEM shape that differs from
        the transformed slice shape.
    """
    index = ir.IndexType.get()
    smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
    src_ref_ty = ir.MemRefType(src_ref.type)
    dst_ref_ty = ir.MemRefType(dst_ref.type)
    element_type = src_ref_ty.element_type
    if element_type != dst_ref_ty.element_type:
      raise ValueError(
          f"Expected same element type, got {element_type} and"
          f" {dst_ref_ty.element_type}"
      )
    if not isinstance(gmem_transform, tuple):
      gmem_transform = (gmem_transform,)

    if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
      gmem_ref, smem_ref = src_ref, dst_ref
      if barrier is None:
        raise ValueError("Barriers are required for GMEM -> SMEM copies")
      if arrive is None:
        arrive = True  # Arrive by default
    elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
      gmem_ref, smem_ref = dst_ref, src_ref
      if barrier is not None:
        raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
      if arrive is not None:
        raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
    else:
      raise ValueError("Only SMEM <-> GMEM copies supported")
    # TODO(apaszke): This is a very approximate check. Improve it!
    expected_name = "builtin.unrealized_conversion_cast"
    if (
        gmem_ref.owner is None
        or gmem_ref.owner.opview.OPERATION_NAME != expected_name
    ):
      raise ValueError("GMEM reference in async_copy must be a kernel argument")

    base_indices, slice_shape, is_squeezed = utils.parse_indices(
        gmem_slice, ir.MemRefType(gmem_ref.type).shape
    )
    # Static indices are materialized as index constants.
    dyn_base_indices = tuple(
        c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
    )
    slice_shape = tuple(slice_shape)
    for t in gmem_transform:
      dyn_base_indices = t.transform_index(dyn_base_indices)
      slice_shape = t.transform_shape(slice_shape)
    # Reinsert the squeezed dimensions into the SMEM ref so that its rank
    # matches the slice.
    for dim, squeezed in enumerate(is_squeezed):
      if squeezed:
        smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
    smem_ref_ty = ir.MemRefType(smem_ref.type)

    if slice_shape != tuple(smem_ref_ty.shape):
      raise ValueError(
          "Expected the SMEM reference to have the same shape as the tiled"
          f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
      )
    tma_desc = self._get_tma_desc(
        gmem_ref, gmem_transform, slice_shape, swizzle,
    )

    # nvgpu TMA instructions expect reversed indices...
    rev_dyn_based_indices = reversed(dyn_base_indices)

    uniform_ctx = mgpu.single_thread if uniform else contextlib.nullcontext

    if gmem_ref is src_ref:
      with uniform_ctx():
        assert barrier is not None  # for pytype
        barrier_group = barrier.barrier_array.value
        barrier_idx = barrier.offset
        if arrive:
          slice_bytes = c(
              np.prod(slice_shape) * mgpu.bytewidth(element_type), index
          )
          nvgpu.mbarrier_arrive_expect_tx(
              barrier_group, slice_bytes, barrier_idx
          )
        nvgpu.tma_async_load(
            smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx
        )
    else:
      with uniform_ctx():
        nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices)
        nvvm.cp_async_bulk_commit_group()

  def await_async_copy(
      self, allow_groups: int, await_read_only: bool = False
  ):
    """Emits a bulk async-copy wait-group op followed by a ``gpu.barrier``.

    ``allow_groups`` and ``await_read_only`` are passed straight to
    ``nvvm.cp_async_bulk_wait_group`` (the latter as its ``read`` argument).
    """
    nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
    gpu.barrier()  # Groups are supposedly tracked per-thread
|
| 419 |
+
|
| 420 |
+
|
| 421 |
+
# Pytrees of shapes/refs are not modelled precisely by the type system.
# ShapeTrees currently can not contain unions.
ShapeTree = Any
RefTree = Any
T = TypeVar("T")


@dataclasses.dataclass(frozen=True)
class Union(Generic[T]):
  """An immutable wrapper around a sequence of alternative ``members``."""
  members: Sequence[T]
|
| 430 |
+
|
| 431 |
+
|
| 432 |
+
def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
|
| 433 |
+
return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
|
| 434 |
+
|
| 435 |
+
|
| 436 |
+
def _construct_smem_reftree(
    dynamic_smem: ir.Value, smem_buffers: ShapeTree) -> RefTree:
  """Carves a pytree of SMEM refs out of the dynamic shared memory buffer.

  The leaves of ``smem_buffers`` are laid out back to back, in flattening
  order, starting at byte offset 0 of ``dynamic_smem``.

  Returns:
    A pytree with the structure of ``smem_buffers`` whose leaves are memref
    views in the workgroup address space.
  """
  index_ty = ir.IndexType.get()
  workgroup_space = ir.Attribute.parse("#gpu.address_space<workgroup>")
  leaves, treedef = jax.tree.flatten(smem_buffers)

  views = []
  byte_offset = 0
  for leaf in leaves:
    view_ty = ir.MemRefType.get(
        leaf.shape,
        mlir.dtype_to_ir_type(leaf.dtype),
        memory_space=workgroup_space,
    )
    views.append(
        memref.view(view_ty, dynamic_smem, c(byte_offset, index_ty), [])
    )
    # The next buffer starts right after this one.
    byte_offset += _count_buffer_bytes(leaf)
  return jax.tree.unflatten(treedef, views)
|
| 452 |
+
|
| 453 |
+
|
| 454 |
+
# TODO(apaszke): Inline this
@contextlib.contextmanager
def _launch(
    token,
    grid,
    block,
    gmem_scratch_ptr,
    smem_buffers: ShapeTree | Union[ShapeTree],
    profiler_spec: profiler.ProfilerSpec | None = None,
    maybe_prof_buffer: ir.Value | None = None,
):
  """Emits a ``gpu.launch`` op and yields a context for filling in its body.

  While the context is active, the insertion point is inside the launch body.
  On exit the profiler (if any) is finalized and the body is terminated.

  Args:
    token: Async token the launch depends on; the launch result has its type.
    grid: Grid dimensions (three entries).
    block: Block dimensions (three entries).
    gmem_scratch_ptr: Pointer to the GMEM scratch buffer.
    smem_buffers: Pytree of SMEM buffer shapes, or a ``Union`` of such trees.
      Members of a ``Union`` all start at offset 0 of the dynamic shared
      memory, which is sized for the largest member.
    profiler_spec: Optional profiler specification.
    maybe_prof_buffer: Profiler output buffer; must be given exactly when
      ``profiler_spec`` is.

  Yields:
    A ``(LaunchContext, smem_ref_tree)`` pair, where ``smem_ref_tree`` mirrors
    ``smem_buffers`` (wrapped in a ``Union`` if the input was one).

  Raises:
    ValueError: If only one of the profiler arguments is given, or if the
      kernel needs more shared memory than the hard-coded limit.
  """
  if (profiler_spec is None) != (maybe_prof_buffer is None):
    raise ValueError(
        "profiler_spec and maybe_prof_buffer must be either both specified or"
        " both None"
    )
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)
  i8 = ir.IntegerType.get_signless(8)
  grid_vals = [c(i, index) for i in grid]
  block_vals = [c(i, index) for i in block]

  if isinstance(smem_buffers, Union):
    smem_disjoint_live_buffers_collections = smem_buffers.members
    # `default` keeps an empty Union from raising inside max().
    compute_smem_bytes = max(
        (
            sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s))
            for s in smem_buffers.members
        ),
        default=0,
    )
  else:
    smem_disjoint_live_buffers_collections = [smem_buffers]
    compute_smem_bytes = sum(
        _count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers))

  smem_bytes = compute_smem_bytes
  if profiler_spec is not None:
    smem_bytes += profiler_spec.smem_bytes(block=block)

  # TODO(cperivol): Query the shared memory size programmatically.
  max_smem_bytes = 228 * 1024
  if smem_bytes > max_smem_bytes:
    # Report the limit that is actually enforced.
    raise ValueError(
        "Mosaic GPU kernel exceeds available shared memory"
        f" {smem_bytes=} > {max_smem_bytes}"
    )
  launch_op = gpu.LaunchOp(
      token.type, [token], *grid_vals, *block_vals,
      dynamicSharedMemorySize=c(smem_bytes, i32))
  launch_op.body.blocks.append(*([index] * 12))  # Append an empty block
  smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
  with ir.InsertionPoint(launch_op.body.blocks[0]):
    dynamic_smem = gpu.dynamic_shared_memory(
        ir.MemRefType.get(
            (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
        )
    )
    smem_ref_trees = []

    for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
      smem_ref_tree = _construct_smem_reftree(
          dynamic_smem, smem_live_buffers_collection)
      smem_ref_trees.append(smem_ref_tree)

    if profiler_spec:
      # The profiler's SMEM lives right after the compute buffers.
      prof_smem = memref.view(
          ir.MemRefType.get(
              (profiler_spec.smem_i32_elements(block=block),),
              i32, memory_space=smem,
          ),
          dynamic_smem, c(compute_smem_bytes, index), [],
      )
      prof = profiler.OnDeviceProfiler(
          profiler_spec, prof_smem, maybe_prof_buffer
      )
    else:
      prof = None

    if isinstance(smem_buffers, Union):
      smem_ref_tree: Union[RefTree] = Union(smem_ref_trees)
    else:
      smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []

    yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree
    if prof is not None:
      prof.finalize(grid=grid, block=block)
    gpu.terminator()
|
| 531 |
+
|
| 532 |
+
|
| 533 |
+
def _lower_as_gpu_kernel(
    body,
    grid: tuple[int, ...],
    block: tuple[int, ...],
    in_shapes: tuple[Any, ...],
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
):
  """Builds the MLIR module that launches ``body`` as a GPU kernel.

  The module contains a ``main`` function taking a token pointer and a pointer
  to an array of buffer pointers: inputs first, then outputs (with the profiler
  buffer last, if profiling), then the GMEM scratch buffer.

  Returns:
    A tuple ``(module, out_shape, gmem_scratch_bytes, unwrap_output_tuple)``.
    ``out_shape`` is normalized to a tuple and extended by the profiler buffer
    type when ``prof_spec`` is given. ``unwrap_output_tuple`` is True when the
    caller passed a single output that was not a list or a tuple.
  """
  ptr_ty = ir.Type.parse("!llvm.ptr")
  token_ty = ir.Type.parse("!gpu.async.token")
  i8 = ir.IntegerType.get_signless(8)
  i64 = ir.IntegerType.get_signless(64)

  def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
    return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))

  in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes]

  # Normalize the outputs to a tuple, remembering whether a bare value was given.
  unwrap_output_tuple = not isinstance(out_shape, (list, tuple))
  if unwrap_output_tuple:
    out_shape = (out_shape,)
  elif isinstance(out_shape, list):
    out_shape = tuple(out_shape)
  out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
  if prof_spec is not None:
    # The profiler buffer is passed as one more (trailing) output.
    out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block))
    out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))

  module = ir.Module.create()
  with ir.InsertionPoint(module.body):
    _declare_runtime_functions()
    gmem_scratch_bytes = 0
    @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
    def main(token_ptr, buffers):
      nonlocal gmem_scratch_bytes
      token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
      all_ref_tys = [*in_ref_tys, *out_ref_tys]
      arg_refs = []
      for position, ref_ty in enumerate(all_ref_tys):
        ptr = llvm.LoadOp(
            ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [position], ptr_ty)
        )
        arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
      # The scratch buffer pointer follows the last input/output pointer.
      gmem_scratch_ptr = llvm.LoadOp(
          ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [len(all_ref_tys)], ptr_ty)
      )
      num_inputs = len(in_ref_tys)
      in_refs, out_refs = arg_refs[:num_inputs], arg_refs[num_inputs:]
      prof_buffer = None
      if prof_spec is not None:
        prof_buffer = out_refs.pop()
      with _launch(
          token, grid, block, gmem_scratch_ptr, smem_scratch_shape,
          prof_spec, prof_buffer
      ) as (launch_ctx, smem_refs):
        body(launch_ctx, *in_refs, *out_refs, smem_refs)
        gmem_scratch_bytes = launch_ctx.next_scratch_offset
      # Allocate and initialize the host buffer right before the launch.
      # Note that we couldn't do that before, because we had to run the body
      # to learn what the scratch contains.
      with ir.InsertionPoint(launch_ctx.launch_op):
        scratch_size = c(gmem_scratch_bytes, i64)
        host_scratch_ptr = llvm.alloca(ptr_ty, scratch_size, i8)
        for init_callback in launch_ctx.host_scratch_init:
          init_callback(host_scratch_ptr)
        copy_args = [
            gmem_scratch_ptr,
            host_scratch_ptr,
            c(gmem_scratch_bytes, i64),
            token_ptr,
        ]
        func.call([], "mosaic_gpu_memcpy_async_h2d", copy_args)
    main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
  module.operation.verify()

  return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
|
| 609 |
+
|
| 610 |
+
|
| 611 |
+
def as_gpu_kernel(
    body,
    grid: tuple[int, ...],
    block: tuple[int, ...],
    in_shape,
    out_shape,
    smem_scratch_shape: ShapeTree | Union[ShapeTree],
    prof_spec: profiler.ProfilerSpec | None = None,
):
  """Lowers `body` once and returns a jitted callable that runs the kernel.

  Args:
    body: Kernel body, forwarded unchanged to `_lower_as_gpu_kernel`.
    grid: Launch grid; forwarded to the lowering and to the profile dump.
    block: Launch block; forwarded to the lowering and to the profile dump.
    in_shape: Input shape tree. A list is converted to a tuple and any other
      non-tuple value is wrapped in a 1-tuple, so the expected argument
      structure is always that of a tuple.
    out_shape: Output shape tree, forwarded to the lowering. The value the
      lowering returns in its place is what is passed as `out_types`.
    smem_scratch_shape: Forwarded unchanged to the lowering.
    prof_spec: Optional profiler spec. When given, the last result of the
      primitive is treated as the profile buffer and is dumped to a
      `<time_ns>-trace.json` file under `$TEST_UNDECLARED_OUTPUTS_DIR`.

  Returns:
    A `jax.jit`-wrapped function taking the kernel inputs as positional
    arguments. It raises `ValueError` when the argument tree structure does
    not match `in_shape`.
  """
  if isinstance(in_shape, list):
    in_shape = tuple(in_shape)
  elif not isinstance(in_shape, tuple):
    in_shape = (in_shape,)

  module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
      _lower_as_gpu_kernel(
          body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
      )
  )

  expected_arg_treedef = jax.tree.structure(in_shape)
  def _check_args(*args):
    arg_treedef = jax.tree.structure(args)
    if arg_treedef != expected_arg_treedef:
      raise ValueError(
          f"Invalid argument structure: expected {expected_arg_treedef}, got"
          f" {arg_treedef}, ({args=})"
      )

  def bind(*args):
    return mosaic_gpu_p.bind(
        *args,
        out_types=out_shape,
        module=module,
        gmem_scratch_bytes=gmem_scratch_bytes,
    )

  if prof_spec is not None:
    @jax.jit
    def prof_kernel(*args):
      _check_args(*args)
      # The profile buffer is always the last output of the primitive.
      *results, prof_buffer = bind(*args)
      def dump_profile(prof_buffer):
        out_dir = os.getenv("TEST_UNDECLARED_OUTPUTS_DIR")
        if out_dir is None:
          # Without this check os.path.join(None, ...) fails with an opaque
          # TypeError that does not mention the missing variable.
          raise ValueError(
              "TEST_UNDECLARED_OUTPUTS_DIR must be set to dump the profile"
          )
        # The file name is derived from the clock, so a collision is resolved
        # by simply trying again. Bounded so the callback can never spin.
        for _ in range(8):
          out_file = os.path.join(out_dir, f"{time.time_ns()}-trace.json")
          try:
            f = open(out_file, "x")
          except FileExistsError:
            continue
          with f:
            prof_spec.dump(prof_buffer, f, grid=grid, block=block)
          return
      jax.debug.callback(dump_profile, prof_buffer)
      return results[0] if unwrap_output_tuple else results
    return prof_kernel
  else:
    @jax.jit
    def kernel(*args):
      _check_args(*args)
      results = bind(*args)
      return results[0] if unwrap_output_tuple else results
    return kernel
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
def _declare_runtime_functions():
  """Declares the private runtime entry points generated code may call."""
  llvm_ptr = ir.Type.parse("!llvm.ptr")
  int64 = ir.IntegerType.get_signless(64)
  # (symbol name, argument types); neither function returns a value.
  runtime_signatures = (
      (
          "mosaic_gpu_init_tma_desc",
          [llvm_ptr, llvm_ptr, int64, int64, llvm_ptr, llvm_ptr, int64, llvm_ptr],
      ),
      (
          "mosaic_gpu_memcpy_async_h2d",
          [llvm_ptr, llvm_ptr, int64, llvm_ptr],
      ),
  )
  for symbol, argument_types in runtime_signatures:
    signature = ir.FunctionType.get(argument_types, [])
    func.FuncOp(symbol, signature, visibility="private")
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/fragmented_array.py
ADDED
|
@@ -0,0 +1,661 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utilities for code generator."""
|
| 16 |
+
|
| 17 |
+
import dataclasses
|
| 18 |
+
|
| 19 |
+
import jax
|
| 20 |
+
from jaxlib.mlir import ir
|
| 21 |
+
from jaxlib.mlir.dialects import arith
|
| 22 |
+
from jaxlib.mlir.dialects import gpu
|
| 23 |
+
from jaxlib.mlir.dialects import llvm
|
| 24 |
+
from jaxlib.mlir.dialects import math as mlir_math
|
| 25 |
+
from jaxlib.mlir.dialects import memref
|
| 26 |
+
from jaxlib.mlir.dialects import nvvm
|
| 27 |
+
from jaxlib.mlir.dialects import vector
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
from . import dsl as mgpu
|
| 31 |
+
from . import utils
|
| 32 |
+
|
| 33 |
+
# mypy: ignore-errors
|
| 34 |
+
|
| 35 |
+
WARPGROUP_SIZE = utils.WARPGROUP_SIZE
|
| 36 |
+
c = utils.c
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
@dataclasses.dataclass(frozen=True)
class WGSplatFragLayout:
  """A fragmented array where all the values are equal, held as one register per thread.

  FragmentedArrays in this layout are always the result of a splat: each
  thread in the warpgroup has a single copy of the value, while the
  FragmentedArray pretends it has whatever shape the user wants. This means
  we can trivially broadcast, reshape and do elementwise operations with all
  other layouts.

  Example:

  To load a value and splat it:
  ```
  FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2))
  ```

  A shape is always provided for sanity check reasons.
  """

  # Logical shape the splat value pretends to have.
  shape: tuple[int, ...] = ()

  def can_broadcast_to(self, shape) -> bool:
    """Checks whether `self.shape` can be broadcast to `shape`.

    Shapes are aligned on their trailing dimensions. Only dimensions of size
    1 can be broadcast; all other dimensions must equal the corresponding
    dimension of `shape`. A layout with more dimensions than `shape` can
    never be broadcast to it.
    """
    # zip() stops at the shorter shape, so the rank has to be checked
    # explicitly: otherwise extra leading dimensions are silently ignored.
    if len(self.shape) > len(shape):
      return False
    return all(
        dim1 == dim2 or dim1 == 1
        for dim1, dim2 in zip(self.shape[::-1], shape[::-1])
    )
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@dataclasses.dataclass(frozen=True)
class WGMMAFragLayout:
  """Layout of an [m, n] matrix with m a multiple of 64 and n a multiple of 8."""
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@dataclasses.dataclass(frozen=True)
class WGMMARowFragLayout:
  """Layout of an [m] array with m a multiple of 64."""
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@dataclasses.dataclass(frozen=True)
class WGStridedFragLayout:
  """Layout that flattens the array to 1D and shards it across threads."""

  # Logical shape of the array.
  shape: tuple[int, ...]
  # Number of elements held by each vector register.
  vec_size: int

  def __post_init__(self):
    # The element count must split evenly into vec_size-wide vectors, one
    # per thread of the warpgroup.
    elems_per_round = self.vec_size * WARPGROUP_SIZE
    if np.prod(self.shape) % elems_per_round != 0:
      raise ValueError((self, WARPGROUP_SIZE))

  @classmethod
  def from_memref_type(cls, memref_ty: ir.Type):
    """Builds the layout for a whole memref of the given type.

    The vector size is 8 bytes worth of elements, capped by the number of
    elements each thread owns.

    Raises:
      TypeError: If `memref_ty` is not a memref type.
      ValueError: If the element count is not a multiple of WARPGROUP_SIZE.
    """
    if not ir.MemRefType.isinstance(memref_ty):
      raise TypeError(memref_ty)

    ref_ty = ir.MemRefType(memref_ty)
    bw = mgpu.bytewidth(ref_ty.element_type)
    assert 8 % bw == 0 and 8 // bw != 0, bw
    num_elems = np.prod(ref_ty.shape)
    if num_elems % WARPGROUP_SIZE != 0:
      raise ValueError(
          "Ref must have a number of elements that is a multiple of"
          f" {WARPGROUP_SIZE}"
      )
    elems_per_thread = num_elems // WARPGROUP_SIZE
    widest_vec = 8 // bw
    return cls(
        shape=tuple(ref_ty.shape),
        vec_size=min(widest_vec, elems_per_thread),
    )

  def thread_vec_idxs(self):
    """Yields the 1D indices used for this thread's vector loads/stores.

    Yields:
      A single-element list with the flat start index of each vector that
      belongs to the current thread.
    """
    index = ir.IndexType.get()
    num_elems = np.prod(self.shape)
    round_size = WARPGROUP_SIZE * self.vec_size
    assert num_elems % round_size == 0
    num_regs = num_elems // round_size
    thread_id = gpu.thread_id(gpu.Dimension.x)
    tidx = arith.remui(thread_id, c(WARPGROUP_SIZE, index))
    # Offset of this thread's vector within one round of the warpgroup.
    thread_off = arith.muli(tidx, c(self.vec_size, tidx.type))
    for reg_idx in range(num_regs):
      round_off = c(reg_idx * WARPGROUP_SIZE * self.vec_size, tidx.type)
      yield [arith.addi(thread_off, round_off)]
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
WGMMA_LAYOUT = WGMMAFragLayout()
|
| 130 |
+
WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
@jax.tree_util.register_pytree_node_class
|
| 134 |
+
class FragmentedArray:
|
| 135 |
+
registers: np.ndarray # of ir.Value, see checks in init for shapes.
|
| 136 |
+
layout: FragmentedLayout
|
| 137 |
+
|
| 138 |
+
def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout):
|
| 139 |
+
self.registers = _registers
|
| 140 |
+
self.layout = _layout
|
| 141 |
+
|
| 142 |
+
match self.layout:
|
| 143 |
+
# Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout
|
| 144 |
+
# Each element is a vector<2xdtype>
|
| 145 |
+
case WGMMAFragLayout():
|
| 146 |
+
if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1):
|
| 147 |
+
raise ValueError("Invalid register array shape")
|
| 148 |
+
|
| 149 |
+
# Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
|
| 150 |
+
# Each element is a dtype scalar
|
| 151 |
+
case WGMMARowFragLayout():
|
| 152 |
+
if self.registers.ndim != 2 or self.registers.shape[-1] != 2:
|
| 153 |
+
raise ValueError("Invalid register array shape")
|
| 154 |
+
|
| 155 |
+
# Registers are flat
|
| 156 |
+
case WGStridedFragLayout(shape):
|
| 157 |
+
(reg_size,) = ir.VectorType(_registers.flat[0].type).shape
|
| 158 |
+
if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size:
|
| 159 |
+
raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type)
|
| 160 |
+
|
| 161 |
+
# Just a single register
|
| 162 |
+
case WGSplatFragLayout():
|
| 163 |
+
if _registers.size != 1:
|
| 164 |
+
raise ValueError(f"WGStridedFragLayout requires a single value {_registers.shape} ({_registers.size})")
|
| 165 |
+
|
| 166 |
+
case _:
|
| 167 |
+
raise NotImplementedError
|
| 168 |
+
|
| 169 |
+
@classmethod
def load_strided(cls, ref: ir.Value):
  """Loads a whole memref into an array with a WGStridedFragLayout.

  Raises:
    TypeError: If `ref` is not a memref.
  """
  if not ir.MemRefType.isinstance(ref.type):
    raise TypeError(ref.type)

  memref_ty = ir.MemRefType(ref.type)
  # Collapse every dimension so that the ref can be indexed with flat indices.
  flat_ref = mgpu.memref_fold(ref, 0, len(memref_ty.shape))
  strided = WGStridedFragLayout.from_memref_type(memref_ty)
  reg_ty = ir.VectorType.get((strided.vec_size,), memref_ty.element_type)
  regs = []
  for vec_idx in strided.thread_vec_idxs():
    regs.append(vector.load(reg_ty, flat_ref, vec_idx))
  return cls(_registers=np.array(regs), _layout=strided)
|
| 180 |
+
|
| 181 |
+
@classmethod
def splat(cls, value, shape, layout=None):
  """Creates an array of the given shape with every element equal to `value`.

  Args:
    value: The scalar to replicate.
    shape: Logical shape of the result.
    layout: Target layout; defaults to `WGSplatFragLayout(shape)`.

  Raises:
    ValueError: If `shape` does not satisfy the WGMMA layout constraints.
    NotImplementedError: If the layout type is not recognized.
  """
  layout = layout or WGSplatFragLayout(shape)
  if isinstance(layout, WGMMARowFragLayout):
    if len(shape) != 1 or shape[0] % 64:
      raise ValueError
    reg_shape = (shape[0] // 64, 2)
  elif isinstance(layout, WGMMAFragLayout):
    if len(shape) != 2:
      raise ValueError
    if shape[0] % 64 or shape[1] % 8:
      raise ValueError
    reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
    # Each WGMMA register is a 2-element vector.
    value = vector.splat(ir.VectorType.get((2,), value.type), value)
  elif isinstance(layout, WGStridedFragLayout):
    assert shape == layout.shape
    vec_size = layout.vec_size
    num_elems = np.prod(shape)
    reg_shape = (num_elems // (WARPGROUP_SIZE * vec_size),)
    value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
  elif isinstance(layout, WGSplatFragLayout):
    assert shape == layout.shape
    reg_shape = ()
  else:
    raise NotImplementedError(layout)

  filled = np.full(reg_shape, value, dtype=object)
  return cls(_registers=filled, _layout=layout)
|
| 213 |
+
|
| 214 |
+
@property
def shape(self):
  """Logical shape of the array, derived from the layout and registers."""
  layout = self.layout
  if isinstance(layout, WGMMAFragLayout):
    # One register tile covers 64 rows and 8 columns.
    m_tiles, n_tiles = self.registers.shape[:2]
    return (m_tiles * 64, n_tiles * 8)
  if isinstance(layout, WGMMARowFragLayout):
    return (self.registers.shape[0] * 64,)
  if isinstance(layout, (WGStridedFragLayout, WGSplatFragLayout)):
    # These layouts carry their shape explicitly.
    return layout.shape
  return None
|
| 227 |
+
|
| 228 |
+
@property
def mlir_dtype(self):
  """Element type of the array, read from the first register."""
  first_reg_ty = self.registers.flat[0].type
  if isinstance(self.layout, (WGMMAFragLayout, WGStridedFragLayout)):
    # Registers are vectors in these layouts.
    return ir.VectorType(first_reg_ty).element_type
  if isinstance(self.layout, (WGMMARowFragLayout, WGSplatFragLayout)):
    # Registers are scalars in these layouts.
    return first_reg_ty
  return None
|
| 236 |
+
|
| 237 |
+
def _pointwise(self, op, *other):
  """Applies `op` register by register to `self` and the other operands.

  Operands that are raw `ir.Value`s or splat arrays are first converted to
  this array's layout; all other operands must already match its layout
  and register shape.

  Raises:
    NotImplementedError: If an operand is neither a FragmentedArray nor an
      `ir.Value`.
    ValueError: If an operand cannot be broadcast or is incompatible.
  """
  operands = []
  for arg in other:
    if not isinstance(arg, FragmentedArray):
      if not isinstance(arg, ir.Value):
        raise NotImplementedError(arg)
      arg = FragmentedArray.splat(arg, shape=self.shape, layout=self.layout)

    if isinstance(arg.layout, WGSplatFragLayout):
      if not arg.layout.can_broadcast_to(self.shape):
        raise ValueError("Can't broadcast shape.")
      # Re-splat the single value into this array's layout.
      arg = FragmentedArray.splat(arg.registers.flat[0], shape=self.shape, layout=self.layout)
    elif self.layout != arg.layout:
      raise ValueError("Incompatible FragmentedArray layouts")
    elif self.registers.shape != arg.registers.shape:
      raise ValueError("Incompatible FragmentedArray shapes")

    operands.append(arg)

  result_regs = np.empty_like(self.registers)
  for idx, reg in np.ndenumerate(self.registers):
    result_regs[idx] = op(reg, *[operand.registers[idx] for operand in operands])
  return FragmentedArray(_registers=result_regs, _layout=self.layout)
|
| 262 |
+
|
| 263 |
+
def __add__(self, other):
  """Elementwise addition for float and integer arrays."""
  dtype = self.mlir_dtype
  if ir.FloatType.isinstance(dtype):
    add_op = arith.addf
  elif ir.IntegerType.isinstance(dtype):
    add_op = arith.addi
  else:
    raise NotImplementedError(dtype)
  return self._pointwise(add_op, other)
|
| 270 |
+
|
| 271 |
+
def __mul__(self, other):
  """Elementwise multiplication for float and integer arrays."""
  dtype = self.mlir_dtype
  if ir.FloatType.isinstance(dtype):
    mul_op = arith.mulf
  elif ir.IntegerType.isinstance(dtype):
    mul_op = arith.muli
  else:
    raise NotImplementedError(dtype)
  return self._pointwise(mul_op, other)
|
| 278 |
+
|
| 279 |
+
def __sub__(self, other):
  """Elementwise subtraction; only float arrays are supported."""
  if ir.FloatType.isinstance(self.mlir_dtype):
    return self._pointwise(arith.subf, other)
  raise NotImplementedError
|
| 283 |
+
|
| 284 |
+
def __truediv__(self, other):
  """Elementwise division; only float arrays are supported."""
  if ir.FloatType.isinstance(self.mlir_dtype):
    return self._pointwise(arith.divf, other)
  raise NotImplementedError
|
| 288 |
+
|
| 289 |
+
def max(self, other):
  """Elementwise maximum; only float arrays are supported."""
  if ir.FloatType.isinstance(self.mlir_dtype):
    return self._pointwise(arith.maximumf, other)
  raise NotImplementedError
|
| 293 |
+
|
| 294 |
+
def exp(self, approx: bool = False):
  """Elementwise exponential of a float array.

  Args:
    approx: When true, uses the `ex2.approx.f32` instruction on the input
      scaled by log2(e) instead of `math.exp`. Only f32 is supported on
      this path.

  Raises:
    NotImplementedError: If the array is not floating point, or if
      `approx` is requested for a dtype other than f32.
  """
  if not ir.FloatType.isinstance(self.mlir_dtype):
    raise NotImplementedError
  def fast_exp(x):
    f32 = ir.F32Type.get()
    if self.mlir_dtype != f32:
      raise NotImplementedError
    if x.type == f32:
      # exp(x) == 2 ** (x * log2(e))
      log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
      scaled = arith.mulf(x, log2e)
      return llvm.inline_asm(
          f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0
      )
    elif ir.VectorType.isinstance(x.type):
      index = ir.IndexType.get()
      result = llvm.mlir_undef(x.type)
      # Registers are not always 2 elements wide (strided layouts use
      # vec_size), so take the length from the register type.
      (vec_len,) = ir.VectorType(x.type).shape
      for i in range(vec_len):
        v = vector.extractelement(x, position=c(i, index))
        vr = fast_exp(v)
        result = vector.insertelement(vr, result, position=c(i, index))
      return result
    else:
      raise NotImplementedError(x.type)
  return self._pointwise(fast_exp if approx else mlir_math.exp)
|
| 318 |
+
|
| 319 |
+
def rsqrt(self):
  """Applies `math.rsqrt` to every register of the array."""
  rsqrt_op = mlir_math.rsqrt
  return self._pointwise(rsqrt_op)
|
| 321 |
+
|
| 322 |
+
def __and__(self, other):
  """Elementwise bitwise AND; only integer arrays are supported."""
  if ir.IntegerType.isinstance(self.mlir_dtype):
    return self._pointwise(arith.andi, other)
  raise ValueError(
      "Bitwise operations only defined for integer types, not"
      f" {self.mlir_dtype}"
  )
|
| 330 |
+
|
| 331 |
+
def bitcast(self, elt: ir.Type):
  """Bitcasts every register so that its element type becomes `elt`."""
  first_reg_ty = self.registers.flat[0].type
  # Vector registers keep their shape; scalar registers become `elt` itself.
  target_ty = elt
  if ir.VectorType.isinstance(first_reg_ty):
    target_ty = ir.VectorType.get(ir.VectorType(first_reg_ty).shape, elt)

  def cast_register(reg):
    return arith.bitcast(target_ty, reg)

  return self._pointwise(cast_register)
|
| 340 |
+
|
| 341 |
+
def __getitem__(self, idx):
  """Returns a tile-aligned slice of an array in the WGMMA layout.

  Raises:
    NotImplementedError: For other layouts, for integer (squeezing)
      indices, or when the slice is not aligned to 64x8 tiles.
  """
  if self.layout != WGMMA_LAYOUT:
    raise NotImplementedError("Only WGMMA layouts support slicing")
  base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
  if any(is_squeezed):
    raise NotImplementedError("Only slicing implemented")
  # Rows are tiled by 64 and columns by 8.
  tile_sizes = ((0, 64), (1, 8))
  if any(base_idx[dim] % tile or slice_shape[dim] % tile for dim, tile in tile_sizes):
    raise NotImplementedError("Only tile aligned slicing supported")
  row_start, row_count = base_idx[0] // 64, slice_shape[0] // 64
  col_start, col_count = base_idx[1] // 8, slice_shape[1] // 8
  sliced_regs = self.registers[
      row_start : row_start + row_count,
      col_start : col_start + col_count,
  ]
  return FragmentedArray(_registers=sliced_regs, _layout=self.layout)
|
| 363 |
+
|
| 364 |
+
# TODO(apaszke): Support JAX dtypes here as well?
|
| 365 |
+
def astype(self, new_dtype: ir.Type):
|
| 366 |
+
cur_dtype = self.mlir_dtype
|
| 367 |
+
if cur_dtype == new_dtype:
|
| 368 |
+
return self
|
| 369 |
+
from_float = ir.FloatType.isinstance(cur_dtype)
|
| 370 |
+
to_float = ir.FloatType.isinstance(new_dtype)
|
| 371 |
+
from_integer = ir.IntegerType.isinstance(cur_dtype)
|
| 372 |
+
to_integer = ir.IntegerType.isinstance(new_dtype)
|
| 373 |
+
if from_float and to_float:
|
| 374 |
+
if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
|
| 375 |
+
convert = arith.truncf
|
| 376 |
+
else:
|
| 377 |
+
convert = arith.extf
|
| 378 |
+
elif from_integer and to_integer:
|
| 379 |
+
if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
|
| 380 |
+
convert = arith.trunci
|
| 381 |
+
else:
|
| 382 |
+
convert = arith.extsi
|
| 383 |
+
elif from_integer and to_float:
|
| 384 |
+
convert = arith.sitofp
|
| 385 |
+
elif from_float and to_integer:
|
| 386 |
+
convert = arith.fptosi
|
| 387 |
+
new_registers = np.empty_like(self.registers)
|
| 388 |
+
match self.layout:
|
| 389 |
+
case WGMMAFragLayout():
|
| 390 |
+
new_reg_ty = ir.VectorType.get((2,), new_dtype)
|
| 391 |
+
case WGStridedFragLayout(vec_size=vec_size):
|
| 392 |
+
new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
|
| 393 |
+
case WGMMARowFragLayout() | WGSplatFragLayout():
|
| 394 |
+
new_reg_ty = new_dtype
|
| 395 |
+
case _:
|
| 396 |
+
raise NotImplementedError(f"Unsupported layout {self.layout}")
|
| 397 |
+
for idx, reg in np.ndenumerate(self.registers):
|
| 398 |
+
new_registers[idx] = convert(new_reg_ty, reg)
|
| 399 |
+
return FragmentedArray(_registers=new_registers, _layout=self.layout)
|
| 400 |
+
|
| 401 |
+
def reduce_sum(self, scratch) -> ir.Value:
  """Sums all elements of a strided-layout array.

  Each thread sums its own registers, each warp combines its threads'
  sums, each warp stores its partial sum in `scratch`, and a single thread
  then adds the four partial sums.

  Args:
    scratch: A memref of shape [4] with the array's element type, used to
      exchange the per-warp partial sums.

  Returns:
    The total, loaded from `scratch[0]`.

  Raises:
    NotImplementedError: If the layout is not WGStridedFragLayout or the
      dtype is neither float nor integer.
    ValueError: If `scratch` has the wrong shape or element type.
  """
  if not isinstance(self.layout, WGStridedFragLayout):
    raise NotImplementedError(f"Unsupported layout {self.layout}")
  dtype = self.mlir_dtype
  # Pick the addition op up front so that every accumulation step uses it:
  # the per-thread loop used to call arith.addf even for integer arrays.
  if ir.FloatType.isinstance(dtype):
    op = arith.addf
  elif ir.IntegerType.isinstance(dtype):
    op = arith.addi
  else:
    raise NotImplementedError(dtype)
  # Validate the scratch buffer before emitting any operations.
  scratch_ty = ir.MemRefType(scratch.type)
  if scratch_ty.element_type != dtype or scratch_ty.shape != [4]:
    raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")

  index = ir.IndexType.get()
  result = c(0, dtype)
  for reg in self.registers:
    result = op(
        result,
        vector.reduction(dtype, vector.CombiningKind.ADD, reg),
    )

  warp_result = utils.warp_tree_reduce(result, op, 32)
  # NOTE(review): warp_id is computed from the raw thread id, so it only
  # stays within the 4-element scratch if there are at most 128 threads
  # along x -- confirm against the callers.
  warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
  memref.store(warp_result, scratch, [warp_id])
  utils.commit_shared()
  zero_index = c(0, index)
  with mgpu.single_thread():
    scratch_vec = vector.load(
        ir.VectorType.get((4,), dtype),
        scratch,
        [zero_index],
    )
    scratch_sum = vector.reduction(
        dtype, vector.CombiningKind.ADD, scratch_vec
    )
    memref.store(scratch_sum, scratch, [zero_index])
  utils.commit_shared()
  return memref.load(scratch, [zero_index])
|
| 439 |
+
|
| 440 |
+
def reduce(self, op, axis):
  """Reduces a WGMMA-layout array along axis 1 with the binary `op`.

  Returns:
    A FragmentedArray in the WGMMA row layout.

  Raises:
    NotImplementedError: For other layouts or any axis other than 1.
  """
  if self.layout != WGMMA_LAYOUT:
    raise NotImplementedError(self.layout)
  if axis != 1:
    raise NotImplementedError
  index = ir.IndexType.get()
  i32 = ir.IntegerType.get_signless(32)
  regs = self.registers
  # Keep the row-tile and row-subtile dimensions of the register array.
  row_regs = np.empty(regs.shape[::2], dtype=object)
  assert regs.shape[-1] == 1
  num_n_tiles = regs.shape[1]
  for row_tile, row_subtile in np.ndindex(row_regs.shape):
    # Combine this thread's registers across all column tiles.
    acc_vec = regs[row_tile, 0, row_subtile, 0]
    for n_tile in range(1, num_n_tiles):
      acc_vec = op(acc_vec, regs[row_tile, n_tile, row_subtile, 0])
    # Combine the two elements of the vector register.
    low = vector.extractelement(acc_vec, position=c(0, index))
    high = vector.extractelement(acc_vec, position=c(1, index))
    acc = op(low, high)
    # Butterfly shuffles combine groups of 4 consecutive threads.
    for lane_mask in (1, 2):
      shuffled = nvvm.shfl_sync(
          acc.type,
          c(0xFFFFFFFF, i32),
          acc,
          c(lane_mask, i32),
          c(0x1F, i32),
          nvvm.ShflKind.bfly,
      )
      acc = op(acc, shuffled)
    row_regs[row_tile, row_subtile] = acc
  return FragmentedArray(_registers=row_regs, _layout=WGMMA_ROW_LAYOUT)
|
| 474 |
+
|
| 475 |
+
def broadcast(self, shape):
  """Broadcasts a splat array to `shape`.

  Raises:
    NotImplementedError: If the array is not in the splat layout.
    ValueError: If the current shape cannot be broadcast to `shape`.
  """
  splat_layout = self.layout
  if not isinstance(splat_layout, WGSplatFragLayout):
    raise NotImplementedError(splat_layout)

  if self.shape != shape:
    if not splat_layout.can_broadcast_to(shape):
      raise ValueError(f"Can't broadcast {self.shape} to {shape}")
    # The single register is shared; only the advertised shape changes.
    return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape))

  return self
|
| 486 |
+
|
| 487 |
+
def reshape(self, shape):
  """Reshapes the array; only splat arrays can actually change shape.

  Raises:
    NotImplementedError: If the shape differs and the layout is not splat.
    ValueError: If the element counts of the two shapes differ.
  """
  current_shape = self.shape
  if current_shape == shape:
    return self

  if not isinstance(self.layout, WGSplatFragLayout):
    raise NotImplementedError(self.layout)

  if np.prod(shape) != np.prod(current_shape):
    raise ValueError(f"Can't reshape {self.shape} to {shape}")

  reshaped_layout = WGSplatFragLayout(shape)
  return FragmentedArray(_registers=self.registers, _layout=reshaped_layout)
|
| 498 |
+
|
| 499 |
+
def broadcast_minor(self, n):
  """Broadcasts a WGMMA row array along a new minor dimension of size `n`.

  Returns:
    A FragmentedArray in the WGMMA layout.

  Raises:
    NotImplementedError: If the array is not in the WGMMA row layout.
    ValueError: If `n` is not divisible by 8.
  """
  if self.layout != WGMMA_ROW_LAYOUT:
    raise NotImplementedError
  row_tiles = self.registers.shape[0]
  col_tiles, leftover = divmod(n, 8)
  if leftover:
    raise ValueError("Number of columns must be divisible by 8")
  wide_regs = np.empty((row_tiles, col_tiles, 2, 1), dtype=object)
  vec_ty = ir.VectorType.get((2,), self.mlir_dtype)
  for (row_tile, row_subtile), scalar in np.ndenumerate(self.registers):
    # Every column tile of a row gets the same 2-element splat.
    wide_regs[row_tile, :, row_subtile, :] = vector.splat(vec_ty, scalar)
  return FragmentedArray(_registers=wide_regs, _layout=WGMMA_LAYOUT)
|
| 513 |
+
|
| 514 |
+
def store_untiled(self, ref: ir.Value):
  """Stores the array into `ref`, dispatching on the layout.

  Raises:
    ValueError: If `ref` is not a memref.
    NotImplementedError: For layouts other than WGMMA and strided.
  """
  if not ir.MemRefType.isinstance(ref.type):
    raise ValueError(ref)

  layout = self.layout
  if isinstance(layout, WGMMAFragLayout):
    self._store_untiled_wgmma(ref)
  elif isinstance(layout, WGStridedFragLayout):
    self._store_untiled_wg_strided(ref)
  else:
    raise NotImplementedError(layout)
|
| 525 |
+
|
| 526 |
+
def _store_untiled_wg_strided(self, ref: ir.Value):
  """Stores a strided-layout array into a memref of the same shape.

  Raises:
    ValueError: If the shape of `ref` differs from the array's shape.
  """
  memref_ty = ir.MemRefType(ref.type)
  target_shape = tuple(memref_ty.shape)
  if target_shape != self.shape:
    raise ValueError((target_shape, self.shape))
  # Collapse every dimension so the vectors can be stored at flat indices.
  flat_ref = mgpu.memref_fold(ref, 0, len(memref_ty.shape))
  vec_idxs = self.layout.thread_vec_idxs()
  for vec_idx, reg in zip(vec_idxs, self.registers.flat):
    vector.store(reg, flat_ref, vec_idx)
|
| 534 |
+
|
| 535 |
+
def _store_untiled_wgmma(self, ref: ir.Value):
  """Stores a WGMMA-layout array to a 2D memref, one scalar at a time.

  Raises:
    ValueError: If the shape of `ref` differs from the array's shape.
  """
  assert self.layout == WGMMA_LAYOUT
  index = ir.IndexType.get()
  m, n = self.shape
  ref_ty = ir.MemRefType(ref.type)
  if ref_ty.shape != [m, n]:
    raise ValueError(ref.type, (m, n))

  def idx_const(x):
    return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))

  thread_in_wg = arith.remui(
      gpu.thread_id(gpu.Dimension.x), idx_const(WARPGROUP_SIZE)
  )
  lane = arith.remui(thread_in_wg, idx_const(32))  # {0, 1, ..., 31}
  warp = arith.divui(thread_in_wg, idx_const(32))  # {0, 1, 2, 3}
  # Each warp is offset by 16 rows; 4 consecutive lanes share a row.
  row_in_warp = arith.divui(lane, idx_const(4))
  warp_row_off = arith.muli(warp, idx_const(16))
  row_base = arith.addi(row_in_warp, warp_row_off)
  lane_in_row = arith.remui(lane, idx_const(4))
  col_base = arith.muli(lane_in_row, idx_const(2))  # {0, 2, 4, 6}
  for (row_tile, col_tile, row_idx, _), reg in np.ndenumerate(self.registers):
    row = arith.addi(row_base, idx_const(row_tile * 64 + row_idx * 8))
    # Each register holds two elements in adjacent columns.
    for col_idx in range(2):
      value = vector.extractelement(reg, position=idx_const(col_idx))
      col = arith.addi(col_base, idx_const(col_tile * 8 + col_idx))
      memref.store(value, ref, [row, col])
|
| 562 |
+
|
| 563 |
+
def store_tiled(self, ref, swizzle: int | None):
  """Stores a WGMMA-layout array into a tiled memref.

  `ref` must have shape [m // 64, n // cols_per_tile, 64, cols_per_tile],
  where cols_per_tile is 128 bytes worth of elements.

  Raises:
    NotImplementedError: If the array is not in the WGMMA layout.
    ValueError: If `ref` does not have the expected tiled shape.
  """
  if self.layout != WGMMA_LAYOUT:
    raise NotImplementedError
  elem_ty = self.mlir_dtype
  elem_bytes = mgpu.bytewidth(elem_ty)
  m, n = self.shape
  assert m % 64 == 0  # This is implied by the layout.
  tile_cols = 128 // elem_bytes
  tiled_shape = [m // 64, n // tile_cols, 64, tile_cols]
  if ir.MemRefType(ref.type).shape != tiled_shape:
    raise ValueError(ref.type, (m, n))
  for read_reg, _, idxs in self.transfer_tiled(self.shape, elem_ty, swizzle):
    vector.store(read_reg(self.registers), ref, idxs)
|
| 576 |
+
|
| 577 |
+
@classmethod
def load_tiled(cls, ref, swizzle: int | None):
  """Loads a tiled memref into an array in the WGMMA layout.

  `ref` must be 4D with tiles of 64 rows and 128 bytes worth of columns.

  Raises:
    ValueError: If the tile dimensions of `ref` are not as expected.
  """
  memref_ty = ir.MemRefType(ref.type)
  elem_ty = memref_ty.element_type
  elem_bytes = mgpu.bytewidth(elem_ty)
  m_tiles, n_tiles, m_tile_size, n_tile_size = memref_ty.shape
  if m_tile_size != 64 or n_tile_size != (128 // elem_bytes):
    raise ValueError
  m = m_tiles * m_tile_size
  n = n_tiles * n_tile_size
  assert m % 64 == 0  # This is implied by the layout.
  reg_ty = ir.VectorType.get((2,), elem_ty)
  # Start from zero-filled registers; every entry is overwritten below.
  zero_reg = vector.splat(reg_ty, c(0, elem_ty))
  regs = np.full((m_tiles, n // 8, 2, 1), zero_reg, dtype=object)
  for _, write_reg, idxs in cls.transfer_tiled((m, n), elem_ty, swizzle):
    write_reg(regs, vector.load(reg_ty, ref, idxs))
  return cls(_registers=regs, _layout=WGMMA_LAYOUT)
|
| 595 |
+
|
| 596 |
+
@staticmethod
def transfer_tiled(shape, dtype, swizzle: int | None):
  """Generates the per-thread transfer plan shared by tiled loads and stores.

  This is a generator: the argument checks and the index computations below
  only run once iteration starts.

  Args:
    shape: The logical ``(m, n)`` shape of the array.
    dtype: The MLIR element type; its byte width determines the tile width.
    swizzle: Must be 128; any other value is rejected.

  Yields:
    Triples ``(get_register, update_registers, idx)`` where
    ``get_register(regs)`` selects the register to transfer,
    ``update_registers(regs, new)`` writes a transferred value back into
    ``regs`` in place, and ``idx`` is the 4-element index
    ``(row_group, col_group, row, col)`` into the tiled memref.

  Raises:
    NotImplementedError: If ``n`` is not a multiple of 32 or ``swizzle`` is
      not 128.
  """
  bw = mgpu.bytewidth(dtype)
  m, n = shape
  if n % 32 != 0:
    raise NotImplementedError
  cols_per_tile = 128 // bw
  if swizzle != 128:
    raise NotImplementedError("Only 128B swizzle supported")

  # Shorthand for index-typed constants.
  c = arith.ConstantOp.create_index
  tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
  lane_id = arith.remui(tidx, c(32))  # {0, 1, ..., 31}
  warp_id = arith.divui(tidx, c(32))  # {0, 1, 2, 3}
  sub_row_base = arith.divui(lane_id, c(4))  # {0, 1, ..., 7}
  if bw > 2:  # Stagger is only necessary for values larger than 16bit.
    is_even_row = arith.cmpi(
        arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
    )
  else:
    # We rely on canonicalization to clean up the selects.
    i1 = ir.IntegerType.get_signless(1)
    is_even_row = arith.constant(i1, ir.BoolAttr.get(True))
  row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
  col_base = arith.muli(arith.remui(lane_id, c(4)), c(2))  # {0, 2, 4, 6}
  # The swizzle pattern is constant for a given thread.
  col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw))
  for row_group in range(m // 64):
    for col_group in range(n // cols_per_tile):
      for row_subidx in range(2):
        row = arith.addi(row_base, c(row_subidx * 8))
        for col_subidx in range(cols_per_tile // 8):
          # We stagger the even and odd rows a little to avoid bank conflicts.
          # It seems that the STS.64 is 2x faster (and the hardware reports no
          # conflicts) when the conflicts are split between half-warps, as
          # opposed to having them within the half-warp. This requires a
          # little more work for the selects, but is ultimately worth it.
          col_subidx_even = col_subidx
          col_subidx_odd = col_subidx ^ 2
          col_off = arith.select(
              is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8)
          )
          col = arith.addi(col_base, col_off)
          col = arith.xori(col, col_swizzle_bits)
          reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8)
          reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8)
          even_idx = row_group, reg_idx_even, row_subidx, 0
          odd_idx = row_group, reg_idx_odd, row_subidx, 0
          idx = c(row_group), c(col_group), row, col
          # The register indices are bound as default arguments so that each
          # yielded closure keeps the values of its own iteration.
          def get_register(regs, even_idx=even_idx, odd_idx=odd_idx):
            value_even = regs[even_idx]
            value_odd = regs[odd_idx]
            return arith.select(is_even_row, value_even, value_odd)
          def update_registers(regs, new, even_idx=even_idx, odd_idx=odd_idx):
            # Only the register matching this thread's row parity takes `new`;
            # the other one keeps its previous value.
            regs[even_idx] = arith.select(is_even_row, new, regs[even_idx])
            regs[odd_idx] = arith.select(is_even_row, regs[odd_idx], new)
          yield get_register, update_registers, idx
|
| 653 |
+
|
| 654 |
+
def tree_flatten(self):
  """Flattens the array into register leaves plus static auxiliary data.

  Returns:
    A pair ``(leaves, aux)`` where ``leaves`` is the list of registers in
    flat iteration order and ``aux`` is ``(layout, registers.shape)``.
  """
  leaves = [reg for reg in self.registers.flat]
  aux = (self.layout, self.registers.shape)
  return leaves, aux
|
| 656 |
+
|
| 657 |
+
@classmethod
def tree_unflatten(cls, aux, flat_registers):
  """Rebuilds an instance from the output of ``tree_flatten``.

  Args:
    aux: The ``(layout, register_shape)`` pair produced by ``tree_flatten``.
    flat_registers: The flat sequence of register leaves.

  Returns:
    A new instance of ``cls`` with the registers reshaped to
    ``register_shape``.
  """
  layout, reg_shape = aux
  flat = np.asarray(flat_registers, dtype=object)
  return cls(_registers=flat.reshape(reg_shape), _layout=layout)
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/profiler.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
import contextlib
|
| 17 |
+
import ctypes
|
| 18 |
+
import functools
|
| 19 |
+
import json
|
| 20 |
+
import math
|
| 21 |
+
|
| 22 |
+
import jax
|
| 23 |
+
from jax._src.interpreters import mlir
|
| 24 |
+
from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
|
| 25 |
+
from jax._src.lib import xla_client
|
| 26 |
+
import jax.numpy as jnp
|
| 27 |
+
from jaxlib.mlir import ir
|
| 28 |
+
from jaxlib.mlir.dialects import arith
|
| 29 |
+
from jaxlib.mlir.dialects import gpu
|
| 30 |
+
from jaxlib.mlir.dialects import memref
|
| 31 |
+
from jaxlib.mlir.dialects import scf
|
| 32 |
+
import numpy as np
|
| 33 |
+
|
| 34 |
+
from .utils import * # noqa: F403
|
| 35 |
+
|
| 36 |
+
# ruff: noqa: F405
|
| 37 |
+
# mypy: ignore-errors
|
| 38 |
+
|
| 39 |
+
# Import-time side effect: register the "mosaic_gpu_record_event" custom call
# target for the CUDA platform. The lowering rule of `record_event_p` emits a
# custom call with this target name.
xla_client.register_custom_call_target(
    "mosaic_gpu_record_event",
    mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
    platform="CUDA",
)

# Primitive that takes an `event` parameter and returns its operands
# unchanged; it always produces a sequence of results.
record_event_p = jax.core.Primitive("record_event")
record_event_p.multiple_results = True
|
| 47 |
+
|
| 48 |
+
@record_event_p.def_abstract_eval
def _record_event_abstract_eval(*avals, event):
  """Abstract evaluation rule: outputs have the same avals as the inputs."""
  del event  # Does not influence the output avals.
  return avals
|
| 52 |
+
|
| 53 |
+
@functools.partial(mlir.register_lowering, record_event_p, platform="cuda")
def _record_event_lowering_rule(ctx, *args, event):
  """Lowers ``record_event_p`` to the ``mosaic_gpu_record_event`` custom call.

  The event handle is passed to the custom call as an 8-byte little-endian
  pointer in ``backend_config``; every operand is aliased to the output at
  the same position.
  """
  event_addr = ctypes.cast(event, ctypes.c_void_p).value
  ptr_bytes = event_addr.to_bytes(8, byteorder="little")  # pytype: disable=attribute-error
  out_types = [mlir.aval_to_ir_type(aval) for aval in ctx.avals_out]
  identity_aliases = {i: i for i in range(len(args))}
  call = mlir.custom_call(
      "mosaic_gpu_record_event",
      result_types=out_types,
      operands=args,
      backend_config=ptr_bytes,
      operand_output_aliases=identity_aliases,
  )
  return call.results
|
| 66 |
+
|
| 67 |
+
def _record_event(args, event):
  """Binds ``record_event_p`` over all leaves of the pytree ``args``.

  Returns:
    A pytree with the same structure as ``args`` holding the primitive's
    outputs.
  """
  leaves, treedef = jax.tree.flatten(args)
  recorded = record_event_p.bind(*leaves, event=event)
  return jax.tree.unflatten(treedef, recorded)
|
| 72 |
+
|
| 73 |
+
def measure(f, *args):
  """Runs ``f(*args)`` under ``jax.jit`` between two GPU events.

  The jitted function is executed twice: once as a warmup and once for the
  measurement. Both events are destroyed before returning, also on failure.

  Args:
    f: The function to time.
    *args: Positional arguments passed to ``f``.

  Returns:
    A pair ``(results, elapsed)`` where ``results`` are the outputs of the
    second run and ``elapsed`` is the value reported by the extension's
    ``_gpu_event_elapsed`` for the two events (unit not visible here —
    TODO confirm).
  """
  # TODO(apaszke): Raise if this is called under jit.
  ext = mosaic_gpu_lib._mosaic_gpu_ext
  start_event = ext._gpu_event_create()
  end_event = ext._gpu_event_create()
  try:
    @jax.jit
    def run(*args):
      started = _record_event(args, start_event)
      return _record_event(f(*started), end_event)
    jax.block_until_ready(run(*args))  # Warmup.
    results = jax.block_until_ready(run(*args))
    elapsed = ext._gpu_event_elapsed(start_event, end_event)
  finally:
    ext._gpu_event_destroy(start_event)
    ext._gpu_event_destroy(end_event)
  return results, elapsed
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
class ProfilerSpec:
  """Host-side description of the profiling buffer and its decoding.

  Each warpgroup owns ``entries_per_warpgroup`` 32-bit entries. As decoded by
  ``dump``, the first two entries hold the high and low words of a start
  time, the third holds the number of used entries, and the rest are
  ``(tag, time)`` pairs. A tag is an interned name id, with the ``EXIT`` bit
  set for the end of a region.
  """

  ENTER = 0
  EXIT = 1 << 31

  def __init__(self, entries_per_warpgroup: int):
    self.entries_per_warpgroup = entries_per_warpgroup
    # Maps a region name to its integer id (ids are assigned densely from 0).
    self.interned_names = {}

  def _num_warpgroups(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> int:
    """Returns the total number of warpgroups launched by ``grid x block``.

    Raises:
      ValueError: If the block size is not a multiple of ``WARPGROUP_SIZE``.
    """
    if math.prod(block) % WARPGROUP_SIZE:
      raise ValueError("Block size is not a multiple of warpgroup size")
    return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE

  def mlir_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> ir.Type:
    """Returns the 1D i32 memref type of the profiling buffer."""
    return ir.MemRefType.get(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        ir.IntegerType.get_signless(32),
    )

  def jax_buffer_type(
      self, grid: tuple[int, ...], block: tuple[int, ...]
  ) -> jax.ShapeDtypeStruct:
    """Returns the 1D uint32 ``ShapeDtypeStruct`` of the profiling buffer."""
    return jax.ShapeDtypeStruct(
        (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
        jnp.uint32,
    )

  def smem_i32_elements(self, block: tuple[int, ...]):
    """Returns the number of i32 entries needed by a single block."""
    num_warpgroups = self._num_warpgroups((), block)
    return int(num_warpgroups * self.entries_per_warpgroup)

  def smem_bytes(self, block: tuple[int, ...]):
    """Returns the size in bytes of the entries needed by a single block."""
    bytes_per_entry = 4
    return self.smem_i32_elements(block) * bytes_per_entry

  def intern_name(self, name: str) -> int:
    """Returns the stable integer id of ``name``, allocating one if needed.

    Raises:
      RuntimeError: If the new id would collide with the ``EXIT`` bit.
    """
    # The first name gets id 0, which is falsy, so the lookup result must be
    # compared against None. A truthiness test would re-intern that name
    # under a new id that later collides with the next allocated name.
    if (name_id := self.interned_names.get(name)) is not None:
      return name_id
    name_id = self.interned_names[name] = len(self.interned_names)
    if name_id & self.EXIT:
      raise RuntimeError("Allocated too many names")
    return name_id

  def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
    """Decodes a profiling buffer and writes it to ``f`` as JSON.

    Args:
      buffer: The flat profiling buffer; converted with ``np.asarray``.
      f: A writable file object passed to ``json.dump``.
      grid: The grid dimensions of the profiled kernel.
      block: The block dimensions of the profiled kernel.

    Raises:
      RuntimeError: If any warpgroup reports more used entries than fit.
    """
    buffer = np.asarray(buffer)
    num_blocks = math.prod(grid)
    warpgroups_per_block = self._num_warpgroups((), block)
    entries = buffer.reshape(
        num_blocks, warpgroups_per_block, self.entries_per_warpgroup
    )
    # Entries 0 and 1 are the high and low 32-bit words of the start time.
    start_times = entries[..., :2].astype(np.int64)
    start_times = (start_times[..., 0] << 32) + start_times[..., 1]
    start_times -= start_times.min()  # Normalize
    entries_used = entries[..., 2]
    if np.any(entries_used > self.entries_per_warpgroup - 2):
      raise RuntimeError("Insufficient space to capture a full trace")
    traces = entries[..., 3:]
    unintern = {v: k for k, v in self.interned_names.items()}
    events = []
    for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block):
      # The used-entries count includes the 3 header entries.
      valid_entries = entries_used[block_idx, wg_idx] - 3
      local_clock_offset = None
      assert valid_entries % 2 == 0, valid_entries
      start_time = start_times[block_idx, wg_idx]
      block_events = []
      for i in range(0, valid_entries, 2):
        tag = traces[block_idx, wg_idx, i]
        time = traces[block_idx, wg_idx, i + 1]
        if local_clock_offset is None:
          local_clock_offset = time
        time -= local_clock_offset
        time -= i * 6  # Account for the overhead of profiling.
        if time < 0:
          break  # Detect a timer wraparound
        name_id = tag
        begin = True
        if name_id & ProfilerSpec.EXIT:
          name_id = name_id ^ ProfilerSpec.EXIT
          begin = False
        name = unintern[name_id]
        block_events.append({
            "name": name,
            "ph": "B" if begin else "E",
            "ts": float(start_time + time) / 1e3,
            "pid": 1 + block_idx,
            "tid": 1 + wg_idx,
        })
      else:  # If we didn't break
        # Traces with a detected wraparound are dropped entirely.
        events.extend(block_events)
    return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class OnDeviceProfiler:
  """Emits the device-side code that records and flushes profiling entries.

  Entries are first written to a per-warpgroup slice of ``smem_buffer``;
  ``finalize`` copies them into the matching slice of ``gmem_buffer``.
  """

  def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value):
    """Slices out this warpgroup's entries and zeroes the write offset."""
    self.spec = spec
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    self.entries_per_wg = spec.entries_per_warpgroup
    wg_idx = warpgroup_idx(sync=False)
    # Each warpgroup gets `entries_per_wg` consecutive entries.
    self.smem_buffer = memref_slice(
        smem_buffer,
        ds(
            arith.index_cast(
                index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
            ),
            self.entries_per_wg,
        ),
    )
    self.gmem_buffer = gmem_buffer
    # Hopefully mem2reg will remove the allocation.
    self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
    memref.store(c(0, i32), self.offset, [])

  @contextlib.contextmanager
  def record(self, name: str):
    """Context manager recording an enter and an exit entry for ``name``.

    Each entry is a ``(tag, clock)`` pair written at the current offset,
    which is then advanced by 2. The exit entry is only emitted when the
    ``with`` body finishes without raising.
    """
    i32 = ir.IntegerType.get_signless(32)
    index = ir.IndexType.get()
    name_id = self.spec.intern_name(name)
    def store(modifier):
      cur = arith.index_cast(index, memref.load(self.offset, []))
      # TODO(apaszke): Clamp indices
      # bound = arith.subi(self.entries_per_wg, c(2, index))
      # cur = arith.select(
      #     arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound
      # )
      memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur])
      memref.store(
          clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))]
      )
      memref.store(
          arith.index_cast(i32, arith.addi(cur, c(2, cur.type))),
          self.offset,
          [],
      )
    store(ProfilerSpec.ENTER)
    yield
    store(ProfilerSpec.EXIT)

  def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
    """Copies this warpgroup's entries into its slice of ``gmem_buffer``.

    Only the first thread of each warpgroup performs the copy. It writes a
    3-entry header (two zero words and the used-entries count, which includes
    the header) followed by the recorded entries.

    Args:
      grid: The grid dimensions of the kernel (not read by this method).
      block: The block dimensions, used to compute warpgroups per block.
    """
    index = ir.IndexType.get()
    i32 = ir.IntegerType.get_signless(32)

    gpu.barrier()  # Make sure all warpgroups are done.

    # Linearize the block id over all grid dimensions.
    block_idx = c(0, index)
    for dim in gpu.Dimension:  # pytype: disable=wrong-arg-types
      block_idx = arith.addi(
          arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
      )
    wg_idx = warpgroup_idx(sync=False)
    wg_per_block = math.prod(block) // WARPGROUP_SIZE
    global_wg_idx = arith.addi(
        arith.muli(block_idx, c(wg_per_block, index)),
        arith.index_cast(index, wg_idx),
    )
    start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
    wg_gmem_buffer = memref.subview(
        self.gmem_buffer, [start_offset], [self.entries_per_wg], [1],
        result_type=ir.Type.parse(
            f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>"
        ),
    )
    # Use the shared constant rather than a literal 128 so that this stays
    # consistent with the `wg_per_block` computation above.
    thread_in_wg = arith.remui(thread_idx(), c(WARPGROUP_SIZE, i32))
    if_first = scf.IfOp(
        arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32))
    )
    with ir.InsertionPoint(if_first.then_block):
      # TODO(apaszke): Either use globaltimer or delete
      # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)])
      # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)])
      memref.store(c(0, i32), wg_gmem_buffer, [c(0, index)])
      memref.store(c(0, i32), wg_gmem_buffer, [c(1, index)])
      memref.store(
          arith.addi(memref.load(self.offset, []), c(3, i32)),
          wg_gmem_buffer,
          [c(2, index)],
      )

      # Copy the entries, shifted by the 3-entry header.
      for_op = scf.ForOp(
          c(0, index),
          c(self.entries_per_wg - 3, index),
          c(1, index),
      )
      with ir.InsertionPoint(for_op.body):
        x = memref.load(self.smem_buffer, [for_op.induction_variable])
        memref.store(
            x,
            wg_gmem_buffer,
            [arith.addi(for_op.induction_variable, c(3, index))],
        )
        scf.yield_([])
      scf.yield_([])
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/utils.py
ADDED
|
@@ -0,0 +1,699 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
"""Utilities for code generator."""
|
| 16 |
+
|
| 17 |
+
from collections.abc import Iterator
|
| 18 |
+
import contextlib
|
| 19 |
+
import dataclasses
|
| 20 |
+
from typing import Any, Literal, Sequence
|
| 21 |
+
|
| 22 |
+
import jax
|
| 23 |
+
from jaxlib.mlir import ir
|
| 24 |
+
from jaxlib.mlir.dialects import arith
|
| 25 |
+
from jaxlib.mlir.dialects import builtin
|
| 26 |
+
from jaxlib.mlir.dialects import gpu
|
| 27 |
+
from jaxlib.mlir.dialects import llvm
|
| 28 |
+
from jaxlib.mlir.dialects import memref
|
| 29 |
+
from jaxlib.mlir.dialects import nvgpu
|
| 30 |
+
from jaxlib.mlir.dialects import nvvm
|
| 31 |
+
from jaxlib.mlir.dialects import scf
|
| 32 |
+
from jaxlib.mlir.dialects import vector
|
| 33 |
+
import numpy as np
|
| 34 |
+
|
| 35 |
+
# mypy: ignore-errors
|
| 36 |
+
|
| 37 |
+
WARPGROUP_SIZE: int = 128
|
| 38 |
+
DYNAMIC = -9223372036854775808
|
| 39 |
+
|
| 40 |
+
# pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
  """Wraps an LLVM pointer into a value of type ``memref_ty``.

  Builds an LLVM struct in the memref descriptor layout (allocated pointer,
  aligned pointer, offset, sizes, strides) with a zero offset and contiguous
  strides, then casts it to ``memref_ty``.

  Raises:
    NotImplementedError: If ``memref_ty`` has rank 0.
  """
  rank = len(memref_ty.shape)
  if rank == 0:
    raise NotImplementedError
  i64 = ir.IntegerType.get_signless(64)

  def i64_const(value):
    return llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, value))

  desc_ty = ir.Type.parse(
      f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
  )
  desc = llvm.UndefOp(desc_ty)
  desc = llvm.InsertValueOp(desc, ptr, [0])  # Allocation
  desc = llvm.InsertValueOp(desc, ptr, [1])  # Aligned Base
  desc = llvm.InsertValueOp(desc, i64_const(0), [2])  # Offset
  # Field 3 holds the sizes and field 4 the strides.
  fields = (
      (3, memref_ty.shape),
      (4, get_contiguous_strides(memref_ty.shape)),
  )
  for field, extents in fields:
    for dim, extent in enumerate(extents):
      desc = llvm.InsertValueOp(desc, i64_const(extent), [field, dim])
  return builtin.unrealized_conversion_cast([memref_ty], [desc])
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def pack_array(values):
  """Stores ``values`` into a freshly ``alloca``-ed array.

  All elements are stored using the type of the first value.

  Returns:
    The LLVM pointer to the start of the array.

  Raises:
    ValueError: If ``values`` is empty.
  """
  if not values:
    raise ValueError("Empty array")
  count = len(values)
  elem_ty = values[0].type
  i64 = ir.IntegerType.get_signless(64)
  ptr_ty = ir.Type.parse("!llvm.ptr")
  arr_ptr = llvm.alloca(ptr_ty, c(count, i64), elem_ty)
  for position, value in enumerate(values):
    slot = llvm.getelementptr(ptr_ty, arr_ptr, [], [position], elem_ty)
    llvm.store(value, slot)
  return arr_ptr
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_contiguous_strides(xs):
  """Returns the row-major strides (in elements) for the shape ``xs``.

  The last dimension has stride 1 and each earlier dimension's stride is the
  product of all later extents. An empty shape yields an empty list.
  """
  rank = len(xs)
  strides = [1] * rank
  # Fill from the second-to-last dimension towards the first.
  for axis in range(rank - 2, -1, -1):
    strides[axis] = strides[axis + 1] * xs[axis + 1]
  return strides
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def c(val: int | float, ty):
  """Creates a constant of MLIR type ``ty`` holding ``val``.

  Integer, index and float types produce an ``arith.constant``; vector types
  produce a splat of the corresponding element constant.

  Raises:
    TypeError: If ``ty`` is an integer or index type and ``val`` is not an
      integer.
    NotImplementedError: If ``ty`` is of any other kind.
  """
  if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty):
    if not isinstance(val, (int, np.integer)):
      raise TypeError(type(val))
    return arith.constant(ty, ir.IntegerAttr.get(ty, val))
  if ir.FloatType.isinstance(ty):
    return arith.constant(ty, ir.FloatAttr.get(ty, val))
  if ir.VectorType.isinstance(ty):
    element = c(val, ir.VectorType(ty).element_type)
    return vector.splat(ty, element)
  raise NotImplementedError(ty)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def get_tensormap_descriptor(**attrs):
  """Parses an ``!nvgpu.tensormap.descriptor`` type from keyword attributes.

  Each keyword becomes a ``key=value`` entry; values must already be strings.
  """
  entries = [key + "=" + value for key, value in attrs.items()]
  return ir.Type.parse(
      f"!nvgpu.tensormap.descriptor<{', '.join(entries)}>"
  )
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def debug_print(fmt, *args, uniform=True):
  """Emits a ``gpu.printf`` of ``fmt`` filled with printf format specifiers.

  Each ``{}`` in ``fmt`` is replaced with the specifier for the matching
  argument. Supported argument types are index, i64, i1 (zero-extended to
  i64), f32 and f16 (extended to f32).

  Args:
    fmt: A ``str.format`` template with one placeholder per argument.
    *args: The MLIR values to print.
    uniform: If true, the print is emitted inside ``single_thread()``.

  Raises:
    NotImplementedError: If an argument has an unsupported type.
  """
  specifiers = []
  printed_args = []
  for arg in args:
    arg_ty = arg.type
    spec = None
    if ir.IndexType.isinstance(arg_ty):
      spec = "%llu"
    elif ir.IntegerType.isinstance(arg_ty):
      bits = ir.IntegerType(arg_ty).width
      if bits == 64:
        spec = "%llu"
      elif bits == 1:
        spec = "%llu"
        arg = arith.extui(ir.IntegerType.get_signless(64), arg)
    elif ir.F32Type.isinstance(arg_ty):
      spec = "%f"
    elif ir.F16Type.isinstance(arg_ty):
      spec = "%f"
      arg = arith.extf(ir.F32Type.get(), arg)
    if spec is None:
      raise NotImplementedError(arg_ty)
    specifiers.append(spec)
    printed_args.append(arg)
  if uniform:
    guard = single_thread
  else:
    guard = contextlib.nullcontext
  with guard():
    gpu.printf(fmt.format(*specifiers) + "\n", printed_args)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclasses.dataclass(frozen=True)
class ForResult:
  """The outcome of a loop built with ``fori``."""

  # The loop operation itself.
  op: scf.ForOp
  # The final values of the loop carries.
  results: tuple[Any, ...]

  @property
  def result(self):
    """The only loop result; raises ``ValueError`` unless there is exactly one."""
    if len(self.results) == 1:
      return self.results[0]
    raise ValueError
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def fori(bound, carrys):
  """Returns a decorator that builds an ``scf.for`` loop from 0 to ``bound``.

  The decorated function is called once as ``f(i, carrys)`` inside the loop
  body and must return new carries with the same pytree structure. A carry
  that is not a list or tuple is passed to ``f`` (and expected back)
  unwrapped.

  Args:
    bound: The exclusive upper bound of the loop.
    carrys: The initial loop-carried values (a pytree).

  Returns:
    A function taking the loop body and returning a ``ForResult``.
  """
  unwrap = not isinstance(carrys, (list, tuple))
  if unwrap:
    carrys = [carrys]
  flat_carrys, carry_treedef = jax.tree.flatten(carrys)

  def wrapper(f):
    index = ir.IndexType.get()
    lower = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
    step = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1))
    for_op = scf.ForOp(lower, bound, step, flat_carrys)
    with ir.InsertionPoint(for_op.body):
      body_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args)
      if unwrap:
        [body_carrys] = body_carrys
      updated = f(for_op.induction_variable, body_carrys)
      if unwrap:
        updated = [updated]
      updated_flat, updated_treedef = jax.tree.flatten(updated)
      if updated_treedef != carry_treedef:
        raise ValueError(updated_treedef, carry_treedef)
      scf.YieldOp(updated_flat)
    return ForResult(
        for_op, jax.tree.unflatten(carry_treedef, for_op.results)
    )

  return wrapper
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
def thread_idx():
  """Returns the linearized thread index within the block as an i32 value.

  The x dimension varies fastest, followed by y and then z.
  """
  i32 = ir.IntegerType.get_signless(32)

  def as_i32(value):
    return arith.index_cast(i32, value)

  linear = as_i32(gpu.thread_id(gpu.Dimension.x))
  stride = as_i32(gpu.block_dim(gpu.Dimension.x))
  for dim in (gpu.Dimension.y, gpu.Dimension.z):
    contribution = arith.muli(as_i32(gpu.thread_id(dim)), stride)
    linear = arith.addi(linear, contribution)
    stride = arith.muli(stride, as_i32(gpu.block_dim(dim)))
  return linear
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def _warp_bcast(val, lane_idx=0):
  """Returns ``val`` as read from lane ``lane_idx`` via ``shfl.sync`` (idx mode)."""
  i32 = ir.IntegerType.get_signless(32)
  full_mask = c(0xFFFFFFFF, i32)
  source_lane = c(lane_idx, i32)
  clamp = c(0x1F, i32)
  return nvvm.shfl_sync(
      val.type, full_mask, val, source_lane, clamp, nvvm.ShflKind.idx
  )
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
def warp_idx(sync=True):
  """Returns the index of the calling thread's warp within the block (i32).

  Args:
    sync: If true, the value is additionally broadcast from lane 0.
  """
  i32 = ir.IntegerType.get_signless(32)
  idx = arith.shrui(thread_idx(), c(5, i32))  # thread_idx // 32
  if not sync:
    return idx
  # Performing a warp broadcast improves performance as compiler understands
  # that the value is uniform across the warp.
  return _warp_bcast(idx)
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
def warpgroup_idx(sync=True):
  """Returns the index of the calling thread's warpgroup within the block (i32).

  Args:
    sync: If true, the value is additionally broadcast from lane 0.
  """
  i32 = ir.IntegerType.get_signless(32)
  idx = arith.shrui(thread_idx(), c(7, i32))  # thread_idx // 128
  if not sync:
    return idx
  # Performing a warp broadcast improves performance as compiler understands
  # that the value is uniform across the warp.
  return _warp_bcast(idx)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
# True within `single_thread()` contexts.
|
| 218 |
+
_ONCE_REGION_ACTIVE = False
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
@contextlib.contextmanager
def single_thread():
  """Runs the body of the context from a single thread only.

  The body is emitted inside an `scf.if` that is taken by the elected lane of
  the first warp. Nested uses are no-ops: while a region is already active the
  body is emitted in place, without another guard.
  """
  global _ONCE_REGION_ACTIVE

  if not _ONCE_REGION_ACTIVE:
    current_warp = warp_idx()
    zero = c(0, current_warp.type)
    in_first_warp = arith.cmpi(arith.CmpIPredicate.eq, current_warp, zero)
    is_elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
    guard = scf.IfOp(arith.andi(in_first_warp, is_elected))
    _ONCE_REGION_ACTIVE = True
    try:
      with ir.InsertionPoint(guard.then_block):
        yield
        scf.YieldOp([])
    finally:
      # Always reset the flag, even if the body raised while being traced.
      _ONCE_REGION_ACTIVE = False
  else:
    yield
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
def clock():
  """Emits inline PTX that reads the `%clock` register into an i32 value."""
  result_ty = ir.IntegerType.get_signless(32)
  ptx = "mov.u32 $0,%clock;"
  return llvm.inline_asm(
      result_ty, [], ptx, "=r", asm_dialect=0, has_side_effects=True
  )
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def globaltimer(kind: Literal["low", "high"] | None = None):
|
| 252 |
+
if kind is None:
|
| 253 |
+
i64 = ir.IntegerType.get_signless(64)
|
| 254 |
+
return llvm.inline_asm(
|
| 255 |
+
i64, [], "mov.u32 $0,%globaltimer;",
|
| 256 |
+
"=l", asm_dialect=0, has_side_effects=True,
|
| 257 |
+
)
|
| 258 |
+
i32 = ir.IntegerType.get_signless(32)
|
| 259 |
+
return llvm.inline_asm(
|
| 260 |
+
i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};",
|
| 261 |
+
"=r", asm_dialect=0, has_side_effects=True,
|
| 262 |
+
)
|
| 263 |
+
|
| 264 |
+
|
| 265 |
+
def bytewidth(ty: ir.Type):
  """Returns the bit width of an integer or float type, floor-divided by 8.

  Raises:
    NotImplementedError: If `ty` is neither an integer nor a float type.
  """
  for type_cls in (ir.IntegerType, ir.FloatType):
    if type_cls.isinstance(ty):
      return type_cls(ty).width // 8
  raise NotImplementedError(ty)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@dataclasses.dataclass(frozen=True)
|
| 274 |
+
class DynamicSlice:
|
| 275 |
+
base: ir.Value | int
|
| 276 |
+
length: int
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
ds = DynamicSlice
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def memref_slice(ref: ir.Value, index) -> ir.Value:
  """Takes a unit-stride slice of a memref with `memref.subview`.

  Args:
    ref: A memref-typed value.
    index: An index, or tuple of indices, in any form accepted by
      `parse_indices`. Dimensions that `parse_indices` marks as squeezed are
      dropped from the result.

  Returns:
    The result of the emitted `memref.subview` op.
  """
  ref_ty = ir.MemRefType(ref.type)
  base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape)

  memref_strides, offset = ref_ty.get_strides_and_offset()
  dynamic_offset = ir.ShapedType.get_dynamic_stride_or_offset()
  new_offset = offset
  # A dynamic source offset is a sentinel, not a number: it must not take part
  # in arithmetic, and the offset of the slice stays dynamic.
  if new_offset != dynamic_offset:
    for idx, stride in zip(base_indices, memref_strides):
      if isinstance(idx, int):
        new_offset += idx * stride
      else:
        # Any dynamic start index makes the whole offset dynamic.
        new_offset = dynamic_offset
        break
  new_strides = [
      s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze
  ]
  new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze]
  new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides)

  ref_slice = memref.subview(
      ref, base_indices, slice_shape, [1] * len(ref_ty.shape),
      result_type=ir.MemRefType.get(
          new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
      ),
  )
  return ref_slice
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _is_contiguous_shape_slice(
    ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None)
):
  """Checks whether the selected dimensions of a memref type are contiguous.

  Args:
    ref_ty: The memref type to inspect.
    dim_slice: Selects the dimensions to consider (all of them by default).

  Returns:
    True if the layout is not strided or if, ordered by decreasing stride,
    every selected dimension exactly fills the stride of the one before it.
  """
  layout = ref_ty.layout
  # If it's not a strided layout then we are definitely contiguous.
  if not ir.StridedLayoutAttr.isinstance(layout):
    return True

  selected_strides = ir.StridedLayoutAttr(layout).strides[dim_slice]
  selected_sizes = ref_ty.shape[dim_slice]

  # Check that each dimension fits exactly in the immediately larger stride.
  by_stride = sorted(
      zip(selected_strides, selected_sizes),
      key=lambda pair: pair[0],
      reverse=True,
  )
  return all(
      inner_stride * inner_size == outer_stride
      for (outer_stride, _), (inner_stride, inner_size) in zip(
          by_stride, by_stride[1:]
      )
  )
|
| 326 |
+
|
| 327 |
+
|
| 328 |
+
def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
  """Collapses `fold_rank` dimensions starting at `dim` into a single one.

  Raises:
    NotImplementedError: If the layout is neither an identity layout,
      `strided<[1]>`, nor contiguous over the folded dimensions.
  """
  ref_ty = ir.MemRefType(ref.type)
  fold_end = dim + fold_rank
  folded_shape = list(ref_ty.shape)
  folded_shape[dim:fold_end] = [np.prod(folded_shape[dim:fold_end])]

  layout = ref_ty.layout
  identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
  contig_strided_1d = ir.Attribute.parse("strided<[1]>")
  # Not sure why but MLIR expects the strided 1D layout to disappear in this op.
  if layout == identity or layout == contig_strided_1d:
    folded_layout = ir.AffineMapAttr.get(
        ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1)
    )
  elif _is_contiguous_shape_slice(ref_ty, slice(dim, fold_end)):
    folded_strides, offset = ref_ty.get_strides_and_offset()
    # The folded dimension takes the stride of the last folded dimension.
    folded_strides[dim:fold_end] = [folded_strides[fold_end - 1]]
    folded_layout = ir.StridedLayoutAttr.get(offset, folded_strides)
  else:
    raise NotImplementedError(
        f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
        f" {dim=}, {fold_rank=}"
    )

  folded_ty = ir.MemRefType.get(
      folded_shape, ref_ty.element_type, folded_layout, ref_ty.memory_space
  )
  # Every untouched dimension maps to itself; the folded ones form one group.
  reassociation = [
      *([d] for d in range(dim)),
      list(range(dim, fold_end)),
      *([d] for d in range(fold_end, ref_ty.rank)),
  ]
  assert len(reassociation) == folded_ty.rank
  return memref.collapse_shape(folded_ty, ref, reassociation)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value:
  """Splits dimension `dim` of a memref into `len(factors)` dimensions.

  Args:
    ref: A memref-typed value.
    dim: The dimension to split.
    factors: Sizes of the new dimensions, ordered from major to minor. At most
      one entry may be None, in which case it is inferred from the size of
      `dim`.

  Returns:
    The result of the emitted `memref.expand_shape` op.

  Raises:
    ValueError: If more than one factor is None, or if the factors are
      incompatible with the size of `dim`.
  """
  ref_ty = ir.MemRefType(ref.type)
  new_shape = list(ref_ty.shape)
  num_inferred = sum(f is None for f in factors)
  if num_inferred > 1:
    raise ValueError("Can only infer one dimension")
  # int() keeps the sizes plain Python integers: np.prod returns a NumPy
  # scalar, and a float when the list of known factors is empty.
  known_factor_prod = int(np.prod([f for f in factors if f is not None]))
  if new_shape[dim] % known_factor_prod:
    raise ValueError("Non-divisible unfold:", new_shape[dim], factors)
  if not num_inferred and known_factor_prod != new_shape[dim]:
    # With nothing left to infer the factors must cover the dimension exactly.
    raise ValueError(
        "Factors do not multiply to the dimension size:",
        new_shape[dim],
        factors,
    )
  factors = tuple(
      new_shape[dim] // known_factor_prod if f is None else f for f in factors
  )
  new_shape[dim : dim + 1] = factors
  identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
  if ref_ty.layout == identity:
    new_layout = ir.AffineMapAttr.get(
        ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1)
    )
  else:
    new_strides, offset = ref_ty.get_strides_and_offset()
    # Walk the factors from minor to major: the innermost new dimension keeps
    # the original stride and each outer one is scaled by the sizes inside it.
    prev_stride = new_strides[dim]
    inserted_strides = []
    for f in reversed(factors):
      inserted_strides.append(prev_stride)
      prev_stride *= f
    new_strides[dim : dim + 1] = reversed(inserted_strides)
    new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
  new_ty = ir.MemRefType.get(
      new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
  )
  # Source dimension `dim` expands into len(factors) result dimensions; every
  # other source dimension maps to exactly one (possibly shifted) result one.
  assoc = [[d] for d in range(dim)]
  assoc.append(list(range(dim, dim + len(factors))))
  assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank))
  assert len(assoc) == ref_ty.rank
  return memref.expand_shape(new_ty, ref, assoc, [], new_ty.shape)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value:
  """Inserts a dimension of size 1 at position `dim`."""
  ref_ty = ir.MemRefType(ref.type)
  rank = ref_ty.rank
  if dim != rank:
    # Splitting `dim` into (1, rest) puts the unit dimension in front of it.
    return memref_unfold(ref, dim, (1, None))

  # Appending a trailing unit dimension.
  unsqueezed_shape = list(ref_ty.shape)
  unsqueezed_shape.append(1)
  identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(rank))
  if ref_ty.layout == identity:
    unsqueezed_layout = ir.AffineMapAttr.get(
        ir.AffineMap.get_identity(rank + 1)
    )
  else:
    strides, offset = ref_ty.get_strides_and_offset()
    strides.append(1)
    unsqueezed_layout = ir.StridedLayoutAttr.get(offset, strides)
  unsqueezed_ty = ir.MemRefType.get(
      unsqueezed_shape,
      ref_ty.element_type,
      unsqueezed_layout,
      ref_ty.memory_space,
  )
  # The new dimension is grouped with the last source dimension.
  reassociation = [[d] for d in range(rank)]
  reassociation[-1].append(rank)
  return memref.expand_shape(
      unsqueezed_ty, ref, reassociation, [], unsqueezed_ty.shape
  )
|
| 423 |
+
|
| 424 |
+
|
| 425 |
+
def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
  """Permutes the dimensions of a memref with `memref.transpose`.

  Result dimension `i` takes the size and stride of source dimension
  `permutation[i]`; the offset is unchanged.
  """
  src_ty = ir.MemRefType(ref.type)
  src_strides, offset = src_ty.get_strides_and_offset()
  permuted_strides = [src_strides[axis] for axis in permutation]
  permuted_shape = [src_ty.shape[axis] for axis in permutation]
  permuted_ty = ir.MemRefType.get(
      permuted_shape,
      src_ty.element_type,
      ir.StridedLayoutAttr.get(offset, permuted_strides),
      src_ty.memory_space,
  )
  permutation_map = ir.AffineMap.get_permutation(permutation)
  return memref.transpose(permuted_ty, ref, permutation_map)
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
def parse_indices(
|
| 440 |
+
index, shape: tuple[int, ...]
|
| 441 |
+
) -> tuple[list[ir.Value | int], list[int], list[bool]]:
|
| 442 |
+
if not isinstance(index, tuple):
|
| 443 |
+
index = (index,)
|
| 444 |
+
if trailing_dims := len(shape) - len(index):
|
| 445 |
+
index += (slice(None),) * trailing_dims
|
| 446 |
+
base_indices = []
|
| 447 |
+
slice_shape = []
|
| 448 |
+
is_squeezed = []
|
| 449 |
+
for idx, bound in zip(index, shape):
|
| 450 |
+
if isinstance(idx, (ir.Operation, ir.OpView)):
|
| 451 |
+
idx = idx.result
|
| 452 |
+
if isinstance(idx, int):
|
| 453 |
+
base_indices.append(idx)
|
| 454 |
+
slice_shape.append(1)
|
| 455 |
+
is_squeezed.append(True)
|
| 456 |
+
elif isinstance(idx, slice):
|
| 457 |
+
if idx.step is not None:
|
| 458 |
+
raise NotImplementedError("Strided slices not implemented")
|
| 459 |
+
base_indices.append(idx.start or 0)
|
| 460 |
+
slice_shape.append((idx.stop or bound) - (idx.start or 0))
|
| 461 |
+
is_squeezed.append(False)
|
| 462 |
+
elif isinstance(idx, DynamicSlice):
|
| 463 |
+
base_indices.append(idx.base)
|
| 464 |
+
slice_shape.append(idx.length)
|
| 465 |
+
is_squeezed.append(False)
|
| 466 |
+
elif isinstance(idx, ir.Value):
|
| 467 |
+
if not ir.IndexType.isinstance(idx.type):
|
| 468 |
+
raise ValueError("Expected an index-typed index")
|
| 469 |
+
base_indices.append(idx)
|
| 470 |
+
slice_shape.append(1)
|
| 471 |
+
is_squeezed.append(True)
|
| 472 |
+
else:
|
| 473 |
+
raise NotImplementedError(type(idx))
|
| 474 |
+
assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape)
|
| 475 |
+
return base_indices, slice_shape, is_squeezed
|
| 476 |
+
|
| 477 |
+
|
| 478 |
+
def commit_shared():
  """Emits a block-wide barrier followed by an async-shared proxy fence."""
  gpu.barrier()
  proxy_kind = nvvm.ProxyKind.async_shared
  nvvm.fence_proxy(proxy_kind, space=nvvm.SharedSpace.shared_cta)
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class BarrierArray:
|
| 486 |
+
|
| 487 |
+
def __init__(self, num_barriers: int, arrival_count: int = 1):
|
| 488 |
+
barrier_group_ty = ir.Type.parse(
|
| 489 |
+
"!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
|
| 490 |
+
f" num_barriers={num_barriers}>"
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
self.num_barriers = num_barriers
|
| 494 |
+
self.value = nvgpu.mbarrier_create(barrier_group_ty)
|
| 495 |
+
self.num_barriers = num_barriers
|
| 496 |
+
index = ir.IndexType.get()
|
| 497 |
+
if num_barriers > 32:
|
| 498 |
+
raise NotImplementedError("Only up to 32 barriers per group supported")
|
| 499 |
+
i32 = ir.IntegerType.get_signless(32)
|
| 500 |
+
self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
|
| 501 |
+
memref.store(c(0, i32), self.phases, [])
|
| 502 |
+
with single_thread():
|
| 503 |
+
for i in range(num_barriers):
|
| 504 |
+
nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))
|
| 505 |
+
gpu.barrier()
|
| 506 |
+
|
| 507 |
+
def __iter__(self) -> Iterator["Barrier"]:
|
| 508 |
+
for offset in range(self.num_barriers):
|
| 509 |
+
yield self[offset]
|
| 510 |
+
|
| 511 |
+
def __getitem__(self, offset: ir.Value | int):
|
| 512 |
+
if isinstance(offset, int):
|
| 513 |
+
offset = c(offset, ir.IndexType.get())
|
| 514 |
+
return Barrier(self, offset)
|
| 515 |
+
|
| 516 |
+
|
| 517 |
+
@dataclasses.dataclass(frozen=True)
class Barrier:
  """A single barrier of a `BarrierArray`, addressed by its offset."""
  # The group this barrier belongs to.
  barrier_array: BarrierArray
  # Position of the barrier within the group.
  offset: ir.Value

  def wait_parity(self, parity):
    """Emits a try-wait on this barrier for the given phase parity."""
    ticks = c(10000000, ir.IndexType.get())
    nvgpu.mbarrier_try_wait_parity(
        self.barrier_array.value, parity, ticks, self.offset
    )

  def wait(self):
    """Waits on this barrier using the phase tracked by the barrier array.

    Reads this barrier's bit from the phase word, flips it in place, and
    waits for the parity that was read.
    """
    i32 = ir.IntegerType.get_signless(32)
    phases_ref = self.barrier_array.phases
    phase_bits = memref.load(phases_ref, [])
    bit_pos = arith.index_castui(i32, self.offset)
    one = c(1, i32)
    this_bit = arith.shli(one, bit_pos)
    masked = arith.andi(phase_bits, this_bit)
    zero = c(0, i32)
    expected_parity = arith.cmpi(arith.CmpIPredicate.ne, masked, zero)
    flipped = arith.xori(phase_bits, this_bit)
    memref.store(flipped, phases_ref, [])
    self.wait_parity(expected_parity)

  def arrive(self):
    """Emits an arrival on this barrier; the resulting token is discarded."""
    token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
    group = self.barrier_array.value
    nvgpu.mbarrier_arrive(token_ty, group, self.offset)
|
| 543 |
+
|
| 544 |
+
|
| 545 |
+
class Partition:
|
| 546 |
+
source_bounds: tuple[int, ...]
|
| 547 |
+
target_bounds: tuple[int, ...]
|
| 548 |
+
partition: tuple[int | None, ...]
|
| 549 |
+
base_offset: tuple[ir.Value, ...] | None
|
| 550 |
+
|
| 551 |
+
def __init__(
|
| 552 |
+
self,
|
| 553 |
+
elements: tuple[int, ...],
|
| 554 |
+
*,
|
| 555 |
+
partition: tuple[int | None, ...],
|
| 556 |
+
base_offset: tuple[ir.Value, ...] | None = None,
|
| 557 |
+
num_chunks: tuple[int, ...] | None = None,
|
| 558 |
+
chunk_size: tuple[int, ...] | None = None,
|
| 559 |
+
):
|
| 560 |
+
self.target_bounds = elements
|
| 561 |
+
self.partition = partition
|
| 562 |
+
self.base_offset = base_offset
|
| 563 |
+
if len(self.target_bounds) != len(self.partition):
|
| 564 |
+
raise ValueError
|
| 565 |
+
if num_chunks is None == chunk_size is None:
|
| 566 |
+
raise ValueError(
|
| 567 |
+
"Exactly one of num_chunks and chunk_size must be specified"
|
| 568 |
+
)
|
| 569 |
+
if num_chunks is not None:
|
| 570 |
+
self.source_bounds = num_chunks
|
| 571 |
+
else:
|
| 572 |
+
if len(chunk_size) != len(self.target_bounds):
|
| 573 |
+
raise ValueError
|
| 574 |
+
source_bounds = []
|
| 575 |
+
for els, chunk in zip(elements, chunk_size):
|
| 576 |
+
if els % chunk:
|
| 577 |
+
raise ValueError("Non-divisible partition", elements, chunk_size)
|
| 578 |
+
source_bounds.append(els // chunk)
|
| 579 |
+
self.source_bounds = tuple(source_bounds)
|
| 580 |
+
|
| 581 |
+
seen_dims = set()
|
| 582 |
+
for p in self.partition:
|
| 583 |
+
if p is None:
|
| 584 |
+
continue
|
| 585 |
+
if not (0 <= p < len(self.source_bounds)):
|
| 586 |
+
raise ValueError
|
| 587 |
+
if p in seen_dims:
|
| 588 |
+
raise ValueError
|
| 589 |
+
seen_dims.add(p)
|
| 590 |
+
for tb, p in zip(self.target_bounds, self.partition):
|
| 591 |
+
if p is not None and tb % self.source_bounds[p]:
|
| 592 |
+
raise ValueError("Non-divisible partitioning")
|
| 593 |
+
|
| 594 |
+
@property
|
| 595 |
+
def num_chunks(self) -> tuple[int, ...]:
|
| 596 |
+
return self.source_bounds
|
| 597 |
+
|
| 598 |
+
@property
|
| 599 |
+
def target_block_shape(self):
|
| 600 |
+
return tuple(tb if p is None else tb // self.source_bounds[p]
|
| 601 |
+
for tb, p in zip(self.target_bounds, self.partition))
|
| 602 |
+
|
| 603 |
+
def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]:
|
| 604 |
+
coords = []
|
| 605 |
+
index = ir.IndexType.get()
|
| 606 |
+
for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)):
|
| 607 |
+
if p is None:
|
| 608 |
+
dim_base = c(0, index)
|
| 609 |
+
else:
|
| 610 |
+
dim_base = arith.muli(c(tbs, index), source_coords[p])
|
| 611 |
+
if self.base_offset is not None:
|
| 612 |
+
dim_base = arith.addi(self.base_offset[i], dim_base)
|
| 613 |
+
coords.append(dim_base)
|
| 614 |
+
return coords
|
| 615 |
+
|
| 616 |
+
|
| 617 |
+
class Partition1D:
  """A one-dimensional convenience wrapper around `Partition`."""
  # The underlying single-dimension partition.
  partition: Partition

  def __init__(
      self,
      elements: int,
      *,
      base_offset: ir.Value | None = None,
      num_chunks: int | None = None,
      chunk_size: int | None = None,
  ):
    """Creates a partition of `elements` items.

    Args:
      elements: Number of elements to partition.
      base_offset: Optional offset added by `get_base`.
      num_chunks: Number of chunks to split into.
      chunk_size: Size of each chunk. Exactly one of `num_chunks` and
        `chunk_size` must be given.

    Raises:
      ValueError: If not exactly one of `num_chunks` and `chunk_size` is
        given, or if the underlying `Partition` rejects the arguments.
    """
    self.base_offset = base_offset
    # The parentheses matter: without them Python chains the comparisons and
    # the check only fires when both arguments are None.
    if (num_chunks is None) == (chunk_size is None):
      raise ValueError(
          "Exactly one of num_chunks and chunk_size must be specified"
      )
    common_kwargs = dict(elements=(elements,), partition=(0,))
    if base_offset is not None:
      common_kwargs["base_offset"] = (base_offset,)
    if num_chunks is not None:
      self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs)
    else:
      self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs)

  @property
  def num_chunks(self) -> int:
    """Number of chunks in the partition."""
    return self.partition.source_bounds[0]

  def get_base(self, source_coords: ir.Value) -> ir.Value:
    """Returns the first element of the chunk with the given index."""
    return self.partition.get_base(source_coords)[0]

  def refine(
      self,
      *,
      chunk: ir.Value | None = None,
      num_chunks: int | None = None,
      chunk_size: int | None = None,
  ):
    """Partitions a single chunk of this partition further.

    Args:
      chunk: Index of the chunk to refine. When given, its base becomes the
        base offset of the result; otherwise the result has no base offset.
      num_chunks: Number of sub-chunks.
      chunk_size: Size of each sub-chunk.

    Returns:
      A new `Partition1D` over the elements of one chunk.
    """
    return Partition1D(
        self.partition.target_block_shape[0],
        num_chunks=num_chunks,
        chunk_size=chunk_size,
        base_offset=self.get_base(chunk) if chunk is not None else None,
    )
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
def tile_shape(shape, tiling):
  """Returns `shape` with its trailing dimensions split into tiles.

  The last `len(tiling)` dimensions are replaced by the number of tiles along
  each of them, and the tile sizes are appended at the end. An empty tiling
  returns `shape` unchanged.

  Raises:
    ValueError: If `tiling` has more dimensions than `shape`, or if a tiled
      dimension is not divisible by its tile size.
  """
  rank = len(tiling)
  if rank > len(shape):
    raise ValueError
  if not tiling:
    return shape
  untiled_dims, tiled_dims = shape[:-rank], shape[-rank:]
  if any(dim % tile for dim, tile in zip(tiled_dims, tiling)):
    raise ValueError("Non-divisible tiling:", shape, tiling)
  tile_counts = tuple(dim // tile for dim, tile in zip(tiled_dims, tiling))
  return (*untiled_dims, *tile_counts, *tiling)
|
| 677 |
+
|
| 678 |
+
|
| 679 |
+
def warp_tree_reduce(value, op, group_size):
  """Combines `value` across groups of `group_size` lanes using `op`.

  Every round exchanges the partial results between lanes with a butterfly
  shuffle (lane masks 1, 2, 4, ...) and merges the two values with `op`.

  Args:
    value: The per-lane value to reduce.
    op: A binary function combining two partial results.
    group_size: Number of lanes per group; must divide 32.

  Returns:
    The reduced value (`value` itself when `group_size` is 1).
  """
  assert 32 % group_size == 0 and group_size <= 32
  i32 = ir.IntegerType.get_signless(32)
  num_rounds = np.log2(group_size)
  if not num_rounds.is_integer():
    raise ValueError(f"Warp reduction group size should be a power of 2 (got {group_size})")
  partial = value
  for lane_mask in (1 << r for r in range(int(num_rounds))):
    shuffled = nvvm.shfl_sync(
        partial.type,
        c(0xFFFFFFFF, i32),
        partial,
        c(lane_mask, i32),
        c(0x1F, i32),
        nvvm.ShflKind.bfly,
    )
    partial = op(partial, shuffled)

  return partial
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/wgmma.py
ADDED
|
@@ -0,0 +1,518 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2024 The JAX Authors. All Rights Reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
import dataclasses
|
| 17 |
+
import enum
|
| 18 |
+
import functools
|
| 19 |
+
import itertools
|
| 20 |
+
|
| 21 |
+
import jax
|
| 22 |
+
from jaxlib.mlir import ir
|
| 23 |
+
from jaxlib.mlir.dialects import arith
|
| 24 |
+
from jaxlib.mlir.dialects import builtin
|
| 25 |
+
from jaxlib.mlir.dialects import llvm
|
| 26 |
+
from jaxlib.mlir.dialects import nvvm
|
| 27 |
+
from jaxlib.mlir.dialects import vector
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
from . import dsl as mgpu
|
| 31 |
+
|
| 32 |
+
# mypy: ignore-errors
|
| 33 |
+
|
| 34 |
+
c = mgpu.c
|
| 35 |
+
bytewidth = mgpu.bytewidth
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class WGMMAAccumulator:
  """A FragmentedArray that is synchronized with the async proxy.

  This implies that it requires no additional synchronization when passed in
  as a WGMMA accumulator. In particular, when created from a
  FragmentedArray, the necessary synchronization is inserted at construction.
  """
  value: mgpu.FragmentedArray

  def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True):
    """Wraps `_value`, fencing it first unless `_sync` is false.

    Raises:
      ValueError: If `_value` does not use the WGMMA layout.
    """
    if _value.layout != mgpu.WGMMA_LAYOUT:
      raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
    # Keep the array returned by the fence in `value`: storing it under any
    # other attribute would leave `value` (which is what gets flattened and
    # used) pointing at the unsynchronized input.
    self.value = wgmma_fence(_value) if _sync else _value

  @classmethod
  def zero(cls, m, n, dtype=None):
    """Returns an `(m, n)` accumulator filled with zeros.

    Args:
      m: Number of rows; must be a multiple of 64.
      n: Number of columns; must be a multiple of 8.
      dtype: Element type of the accumulator (f32 when None).

    Raises:
      ValueError: If `m` or `n` is not a suitable multiple.
    """
    if m % 64 or n % 8:
      raise ValueError(
          f"m must be a multiple of 64 and n a multiple of 8, got {m=}, {n=}"
      )
    f32 = ir.F32Type.get()
    if dtype is None:
      dtype = f32
    zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
    return cls(
        _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT)
    )

  @classmethod
  def from_registers(cls, registers):
    """Wraps an existing FragmentedArray, inserting the synchronization."""
    return cls(_value=registers)

  def tree_flatten(self):
    """Flattens to the wrapped array, with no auxiliary data."""
    return (self.value,), ()

  @classmethod
  def tree_unflatten(cls, aux, value):
    """Rebuilds the accumulator without inserting another fence."""
    del aux
    return cls(_value=value[0], _sync=False)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def wgmma_encode(x: int):
  """Encodes `x` as a WGMMA descriptor field (low 18 bits, shifted right by 4).

  Raises:
    ValueError: If `x` cannot be recovered from the encoding, i.e. it is not
      a multiple of 16 or does not fit in 18 bits.
  """
  encoded = (x & 0x3FFFF) >> 4
  if x == encoded << 4:
    return encoded
  raise ValueError("Cannot encode value in a WGMMA descriptor")
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def llvm_mul(x, y):
  """Emits an `llvm.mul` of `x` and `y` with no overflow flags set."""
  no_flags = llvm.IntegerOverflowFlags.none
  return llvm.mul(x, y, overflow_flags=no_flags)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def llvm_add(x, y):
  """Emits an `llvm.add` of `x` and `y` with no overflow flags set."""
  no_flags = llvm.IntegerOverflowFlags.none
  return llvm.add(x, y, overflow_flags=no_flags)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def get_memref_base(memref_arg, memory_space=None):
  """Returns a pointer to the first element of a memref.

  The memref is cast to its LLVM descriptor struct; the result is the aligned
  pointer advanced by the descriptor's offset, converted to bytes.

  Args:
    memref_arg: A memref-typed value of rank at least 1.
    memory_space: Optional address space of the resulting `!llvm.ptr`.

  Raises:
    NotImplementedError: For rank-0 memrefs.
  """
  i64 = ir.IntegerType.get_signless(64)
  memref_ty = ir.MemRefType(memref_arg.type)
  rank = len(memref_ty.shape)
  if rank == 0:
    raise NotImplementedError
  elem_bytewidth = bytewidth(memref_ty.element_type)
  # TODO: Read out memory space from memref
  space = "" if memory_space is None else "<" + str(memory_space) + ">"
  ptr_ty = ir.Type.parse("!llvm.ptr" + space)
  # Descriptor fields: allocated ptr, aligned ptr, offset, sizes, strides.
  desc_ty = ir.Type.parse(
      f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
      f" array<{rank} x i64>)>"
  )
  descriptor = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
  aligned_ptr = llvm.extractvalue(ptr_ty, descriptor, [1])
  offset_elems = llvm.extractvalue(i64, descriptor, [2])
  elem_size = c(elem_bytewidth, i64)
  offset_bytes = llvm_mul(offset_elems, elem_size)
  aligned_addr = llvm.ptrtoint(i64, aligned_ptr)
  return llvm.inttoptr(ptr_ty, llvm_add(aligned_addr, offset_bytes))
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
def create_descriptor(
    memref_arg,
    leading_byte_offset: int,
    stride_byte_offset: int,
    swizzle: int | None,
    memory_space: int | None = None,
    nvgpu_type=None,
):
  """Builds a 64-bit WGMMA descriptor for the given memref.

  Args:
    memref_arg: The memref whose base address goes into the descriptor.
    leading_byte_offset: Leading offset in bytes; must be encodable by
      `wgmma_encode`.
    stride_byte_offset: Stride offset in bytes; must be encodable by
      `wgmma_encode`.
    swizzle: None for no swizzling, or 128.
    memory_space: Forwarded to `get_memref_base`.
    nvgpu_type: If given, the descriptor is cast to this type.

  Returns:
    The descriptor value: an i64, or a value of `nvgpu_type` when given.

  Raises:
    NotImplementedError: For swizzle values other than None and 128.
  """
  i64 = ir.IntegerType.get_signless(64)
  base_ptr = get_memref_base(memref_arg, memory_space)
  base_addr = llvm.ptrtoint(i64, base_ptr)
  if swizzle is None:
    swizzle_bits = 0
  elif swizzle == 128:
    swizzle_bits = 1
  else:
    raise NotImplementedError(swizzle)
  # The address uses the same encoding as wgmma_encode: low 18 bits >> 4.
  low_addr_bits = llvm.AndOp(base_addr, c(0x3FFFF, i64))
  encoded_base_addr = llvm.LShrOp(low_addr_bits, c(4, i64))
  leading_field = wgmma_encode(leading_byte_offset) << 16
  stride_field = wgmma_encode(stride_byte_offset) << 32
  # We ignore the offset
  swizzle_field = swizzle_bits << 62
  static_fields = leading_field | stride_field | swizzle_field
  desc = llvm.OrOp(encoded_base_addr, c(static_fields, i64))
  if nvgpu_type is not None:
    desc = builtin.UnrealizedConversionCastOp([nvgpu_type], [desc])
  return desc.result
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _unpack_i32(vec_ty, r):
  """Bitcasts the i32 value `r` to the vector type `vec_ty`.

  The value is first splatted into a one-element i32 vector.
  """
  i32 = ir.IntegerType.get_signless(32)
  single_i32_vec_ty = ir.VectorType.get((1,), i32)
  packed = vector.splat(single_i32_vec_ty, r)
  return vector.bitcast(vec_ty, packed)
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _supported_wgmma_types(dtype, abtype) -> bool:
  """Checks whether an output / input element type pair is supported.

  An f32 output accepts tf32, bf16 and f16 inputs; an f16 output accepts only
  f16 inputs. Every other output type is rejected.
  """
  if ir.F32Type.isinstance(dtype):
    allowed_inputs = (ir.FloatTF32Type, ir.BF16Type, ir.F16Type)
  elif ir.F16Type.isinstance(dtype):
    allowed_inputs = (ir.F16Type,)
  else:
    return False
  return any(ty.isinstance(abtype) for ty in allowed_inputs)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def wgmma_m64k128B(
|
| 169 |
+
acc: np.ndarray, # of register Values
|
| 170 |
+
a,
|
| 171 |
+
b_descriptor: ir.Value,
|
| 172 |
+
a_transpose: bool | None,
|
| 173 |
+
b_transpose: bool,
|
| 174 |
+
a_k_stride: int | None,
|
| 175 |
+
b_k_stride: int,
|
| 176 |
+
n: int,
|
| 177 |
+
element_type: ir.Type,
|
| 178 |
+
):
|
| 179 |
+
out_ty = ir.VectorType(acc.flat[0].type).element_type
|
| 180 |
+
if not _supported_wgmma_types(out_ty, element_type):
|
| 181 |
+
raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}")
|
| 182 |
+
|
| 183 |
+
f16 = ir.F16Type.get()
|
| 184 |
+
i32 = ir.IntegerType.get_signless(32)
|
| 185 |
+
i64 = ir.IntegerType.get_signless(64)
|
| 186 |
+
index = ir.IndexType.get()
|
| 187 |
+
if b_k_stride % 16:
|
| 188 |
+
raise ValueError
|
| 189 |
+
if n % (128 // bytewidth(element_type)):
|
| 190 |
+
raise ValueError
|
| 191 |
+
# Only 16-bit types support transposes
|
| 192 |
+
supports_transpose = bytewidth(element_type) == 2
|
| 193 |
+
if not supports_transpose and (a_transpose or b_transpose):
|
| 194 |
+
raise ValueError("Only f16 WGMMA supports transposes")
|
| 195 |
+
if a_in_regs := isinstance(a, mgpu.FragmentedArray):
|
| 196 |
+
if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
|
| 197 |
+
raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
|
| 198 |
+
if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
|
| 199 |
+
raise ValueError("Unsupported A register array layout")
|
| 200 |
+
if a_k_stride is not None or a_transpose is not None:
|
| 201 |
+
raise ValueError("Unsupported WGMMA features with A in registers")
|
| 202 |
+
else:
|
| 203 |
+
if a_k_stride is None or a_k_stride % 16:
|
| 204 |
+
raise ValueError
|
| 205 |
+
if a_transpose is None:
|
| 206 |
+
raise ValueError
|
| 207 |
+
|
| 208 |
+
if ir.F32Type.isinstance(out_ty):
|
| 209 |
+
num_acc_regs = n // 2
|
| 210 |
+
out_ty_field = out_ty
|
| 211 |
+
acc_regs = [ # pylint: disable=g-complex-comprehension
|
| 212 |
+
vector.extractelement(reg, position=c(pos, index))
|
| 213 |
+
for reg in acc.flat
|
| 214 |
+
for pos in range(2)
|
| 215 |
+
]
|
| 216 |
+
to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape)
|
| 217 |
+
acc_constraint = "f"
|
| 218 |
+
elif ir.F16Type.isinstance(out_ty):
|
| 219 |
+
num_acc_regs = n // 4
|
| 220 |
+
out_ty_field = i32
|
| 221 |
+
acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
|
| 222 |
+
vec_ty = ir.VectorType(acc.flat[0].type)
|
| 223 |
+
to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
|
| 224 |
+
acc_constraint = "r"
|
| 225 |
+
else:
|
| 226 |
+
raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})")
|
| 227 |
+
|
| 228 |
+
num_imm_regs = 4 if supports_transpose else 2
|
| 229 |
+
|
| 230 |
+
if a_in_regs:
|
| 231 |
+
a_reg_constraints = ["r"] * 4 # 4x f16x2 registers
|
| 232 |
+
num_imm_regs -= 1 # transpose not supported for a in registers
|
| 233 |
+
else:
|
| 234 |
+
a_reg_constraints = ["l"] # descriptor
|
| 235 |
+
# Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
|
| 236 |
+
# Seems like it's not actually documented in LLVM IR docs.
|
| 237 |
+
reg_constraints_list = (
|
| 238 |
+
[f"={acc_constraint}"] * num_acc_regs # accumulator registers
|
| 239 |
+
+ [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too.
|
| 240 |
+
+ a_reg_constraints # a descriptor / registers
|
| 241 |
+
+ ["l"] * 1 # b descriptor
|
| 242 |
+
+ ["n"] * (1 + num_imm_regs) # literal constants
|
| 243 |
+
)
|
| 244 |
+
reg_constraints = ",".join(reg_constraints_list)
|
| 245 |
+
|
| 246 |
+
reg_count = itertools.count()
|
| 247 |
+
|
| 248 |
+
def take_regs(n):
|
| 249 |
+
return (f"${i}" for i in itertools.islice(reg_count, n))
|
| 250 |
+
|
| 251 |
+
acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
|
| 252 |
+
for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing.
|
| 253 |
+
pass
|
| 254 |
+
if a_in_regs:
|
| 255 |
+
a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
|
| 256 |
+
else:
|
| 257 |
+
a_regs, = take_regs(1)
|
| 258 |
+
b_desc_reg, use_out_reg = take_regs(2)
|
| 259 |
+
imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...).
|
| 260 |
+
assert next(reg_count) == len(reg_constraints_list)
|
| 261 |
+
el_ty = element_type
|
| 262 |
+
k_instr = 32 // bytewidth(element_type)
|
| 263 |
+
wgmma_instr = (
|
| 264 |
+
f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} "
|
| 265 |
+
f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
|
| 266 |
+
)
|
| 267 |
+
ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
|
| 268 |
+
|
| 269 |
+
def lc(x):
|
| 270 |
+
return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
|
| 271 |
+
|
| 272 |
+
use_out = scale_a = scale_b = lc(1)
|
| 273 |
+
imms = [use_out, scale_a, scale_b]
|
| 274 |
+
if supports_transpose and a_transpose is not None:
|
| 275 |
+
imms += [lc(int(a_transpose)), lc(int(b_transpose))]
|
| 276 |
+
elif supports_transpose:
|
| 277 |
+
imms += [lc(int(b_transpose))]
|
| 278 |
+
if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
|
| 279 |
+
raise ValueError(acc.shape)
|
| 280 |
+
acc_struct_type = ir.Type.parse(
|
| 281 |
+
f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
|
| 282 |
+
)
|
| 283 |
+
for i in range(4):
|
| 284 |
+
# Slice out the relevant part of A or advance the A descriptor.
|
| 285 |
+
if a_in_regs:
|
| 286 |
+
a_slice = a[:, (i * 16) : ((i + 1) * 16)]
|
| 287 |
+
a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
|
| 288 |
+
else:
|
| 289 |
+
if i > 0:
|
| 290 |
+
a = llvm_add(
|
| 291 |
+
a,
|
| 292 |
+
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
|
| 293 |
+
)
|
| 294 |
+
a_args = [a]
|
| 295 |
+
# Advance the B descriptor.
|
| 296 |
+
if i > 0:
|
| 297 |
+
b_descriptor = llvm_add(
|
| 298 |
+
b_descriptor,
|
| 299 |
+
llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
|
| 300 |
+
)
|
| 301 |
+
assert len(a_args) == len(a_reg_constraints)
|
| 302 |
+
acc_struct = llvm.inline_asm(
|
| 303 |
+
acc_struct_type,
|
| 304 |
+
[*acc_regs, *a_args, b_descriptor, *imms],
|
| 305 |
+
ptx,
|
| 306 |
+
reg_constraints,
|
| 307 |
+
asm_dialect=0,
|
| 308 |
+
has_side_effects=True,
|
| 309 |
+
)
|
| 310 |
+
acc_regs = [
|
| 311 |
+
llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
|
| 312 |
+
]
|
| 313 |
+
return to_acc_vec_regs(acc_regs)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class WGMMALayout(enum.Enum):
  """Element order of a WGMMA operand tile.

  Passed to `wgmma` as `a_order` / `b_order`; the order applies within each
  tile only, not to the arrangement of tiles.
  """

  ROW_MAJOR = enum.auto()
  COL_MAJOR = enum.auto()
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
|
| 322 |
+
# transpositions from memref strides.
|
| 323 |
+
def wgmma(
    acc: WGMMAAccumulator,
    a,
    b,
    *,
    # Order only applies within each tile!
    a_order: WGMMALayout | None = None,
    b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
):
  """Issues the WGMMA instructions for a tiled `a @ b` on top of `acc`.

  The work is split into `groups_m x groups_k` calls to `wgmma_m64k128B`, each
  of which updates the slice of accumulator registers belonging to one group
  of 64 rows.

  Args:
    acc: Accumulator whose `value.registers` are copied and then updated; the
      input accumulator object itself is not modified here.
    a: Either an `mgpu.FragmentedArray` (A held in registers) of f16/bf16 with
      shape `(groups_m * 64, groups_k * kn_tile)`, or a value whose type is a
      memref of shape `[groups_m, groups_k, 64, kn_tile]`.
    b: A value whose type is a memref of shape
      `[groups_k, groups_n, kn_tile, kn_tile]` with the same element type as
      `a`.
    a_order: Order of each A tile. Must be left as `None` when `a` is a
      `FragmentedArray`; defaults to `ROW_MAJOR` otherwise.
    b_order: Order of each B tile.

  Returns:
    A new `WGMMAAccumulator` built with `_sync=False` that wraps the updated
    registers in the `mgpu.WGMMA_LAYOUT` layout.

  Raises:
    ValueError: If element types are unsupported or differ between `a` and
      `b`, if shapes or byte strides do not match the tiling described above,
      or if `a_order` is given for an in-register `a`.
  """
  if a_in_regs := isinstance(a, mgpu.FragmentedArray):
    a_element_type = a.mlir_dtype
    a_shape = a.shape
  else:
    a_ty = ir.MemRefType(a.type)
    a_element_type = a_ty.element_type
    a_shape = a_ty.shape
  b_ty = ir.MemRefType(b.type)
  supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
  if a_element_type not in supported_types:
    raise ValueError(a_element_type)
  if b_ty.element_type not in supported_types:
    raise ValueError(b_ty.element_type)
  if (element_type := a_element_type) != b_ty.element_type:
    raise ValueError(
        "a and b must have the same element type, got"
        f" {a_element_type} and {b_ty.element_type}"
    )
  element_bytewidth = bytewidth(element_type)
  # Number of elements that fit in 128 bytes; K and N are tiled by this amount.
  kn_tile = 128 // element_bytewidth

  groups_k, groups_n = b_ty.shape[:2]
  if b_ty.shape[2:] != [kn_tile, kn_tile]:
    raise ValueError(b_ty.shape)

  if a_in_regs:
    if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
      raise ValueError(a_element_type)
    if a_shape[0] % 64 or a_shape[1] % kn_tile:
      raise ValueError(a_shape)
    if a_shape[1] // kn_tile != groups_k:
      raise ValueError(a_shape[1] // kn_tile, groups_k)
    groups_m = a_shape[0] // 64
    if a_order is not None:
      raise ValueError(
          "a_order can only be specified when A is in shared memory"
      )
  else:
    groups_m = a_shape[0]
    if a_shape[1] != groups_k:
      raise ValueError(a_shape[1], groups_k)
    if a_shape[2:] != [64, kn_tile]:
      raise ValueError(a_shape)
    if a_order is None:
      a_order = WGMMALayout.ROW_MAJOR

  row_major = WGMMALayout.ROW_MAJOR
  col_major = WGMMALayout.COL_MAJOR
  a_desc_fields = dict(
      leading_byte_offset=((1 if a_order == row_major else 512) << 4),
      stride_byte_offset=(64 << 4),
      swizzle=128,
      memory_space=3,
  )
  b_desc_fields = dict(
      leading_byte_offset=((512 if b_order == row_major else 1) << 4),
      stride_byte_offset=(64 << 4),
      swizzle=128,
      memory_space=3,
  )
  wgmma_params = dict(
      a_transpose=a_order == col_major,
      b_transpose=b_order == row_major,
      a_k_stride=(2 if a_order == row_major else 128) * 16,
      b_k_stride=(128 if b_order == row_major else 2) * 16,
      n=(groups_n * kn_tile),
      # f32 operands are passed to the instruction as tf32.
      element_type=ir.FloatTF32Type.get()
      if ir.F32Type.isinstance(element_type)
      else element_type,
  )

  if a_in_regs:
    # There is no A descriptor to advance or transpose when A is in registers.
    wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
    a = wgmma_fence(a)  # Make sure the registers are ready.
    a_m_byte_stride = a_k_byte_stride = a_desc_base = None  # Silence pytype.
  else:
    a_desc_base = create_descriptor(a, **a_desc_fields)
    a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
    a_byte_strides = [s * element_bytewidth for s in a_strides]
    a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
    # Each tile must be contiguous with 128-byte rows.
    if a_byte_strides[2:] != [128, element_bytewidth]:
      raise ValueError(a_byte_strides)
  b_desc_base = create_descriptor(b, **b_desc_fields)
  b_strides, _ = b_ty.get_strides_and_offset()
  b_byte_strides = [s * element_bytewidth for s in b_strides]
  b_k_byte_stride = b_byte_strides[0]
  if b_byte_strides[1:] != [128 * kn_tile, 128, element_bytewidth]:
    raise ValueError(b_byte_strides)

  i64 = ir.IntegerType.get_signless(64)
  new_acc_regs = acc.value.registers.copy()
  for mi in range(groups_m):
    for ki in range(groups_k):
      if a_in_regs:
        a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
      else:
        # Offset the base descriptor to the (mi, ki) tile of A.
        a_mk = llvm_add(
            a_desc_base,
            c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
        )
      b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
      new_acc_regs[mi : mi + 1] = wgmma_m64k128B(
          new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
      )
  return WGMMAAccumulator(
      _value=mgpu.FragmentedArray(
          _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT
      ),
      _sync=False,
  )
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
def wgmma_fence(array: mgpu.FragmentedArray):
  """Fences the array construction from WGMMA instructions.

  This is a little workaround to force LLVM to initialize the PTX registers
  before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats
  in-register computation as pure and can move it after the fence, which is
  explicitly disallowed by the PTX programming model.

  Args:
    array: Array whose registers are 2-element vectors of f32, f16 or bf16.

  Returns:
    A new `mgpu.FragmentedArray` with the same layout as `array`, whose
    registers are the outputs of the inline-asm block that ends in the fence.

  Raises:
    NotImplementedError: If the element type is not f32, f16 or bf16.
  """
  i32 = ir.IntegerType.get_signless(32)
  index = ir.IndexType.get()
  dtype = array.mlir_dtype
  src_vec_ty = ir.VectorType(array.registers.flat[0].type)
  assert src_vec_ty.shape == [2]

  is_f32 = dtype == ir.F32Type.get()
  is_16bit = dtype == ir.F16Type.get() or dtype == ir.BF16Type.get()

  if is_f32:
    # Every vector register is split into its two scalar f32 elements.
    regs = [  # pylint: disable=g-complex-comprehension
        vector.extractelement(reg, position=c(pos, index))
        for reg in array.registers.flat
        for pos in range(2)
    ]
    reg_dtype = dtype
    reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs)
    ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
  elif is_16bit:
    # A pair of 16-bit elements is handled as a single 32-bit register.
    regs = [_as_i32_reg(reg) for reg in array.registers.flat]
    reg_dtype = i32
    reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs)
    ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
  else:
    raise NotImplementedError(dtype)
  reg_constraints = ",".join(reg_constraints_list)
  # Copy over the registers. ptxas should be able to remove the moves.
  ptx_lines.append("wgmma.fence.sync.aligned")
  ptx = ";\n".join(ptx_lines) + ";\n"
  dtype_str = str(reg_dtype)
  struct_ty = ir.Type.parse(
      f"!llvm.struct<({','.join(dtype_str for _ in regs)})>"
  )
  acc_struct = llvm.inline_asm(
      struct_ty, regs, ptx, reg_constraints,
      asm_dialect=0, has_side_effects=True,
  )
  regs = [
      llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs))
  ]
  if is_f32:
    registers = _as_fragmented_reg_ndarray(
        regs, array.mlir_dtype, array.registers.shape
    )
  else:
    # Only the 16-bit case can reach here: other dtypes were rejected above.
    regs = [_unpack_i32(src_vec_ty, r) for r in regs]
    registers = np.asarray(regs, dtype=object).reshape(array.registers.shape)
  return mgpu.FragmentedArray(_registers=registers, _layout=array.layout)
|
| 497 |
+
|
| 498 |
+
|
| 499 |
+
def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
  """Packs consecutive pairs of scalar values into 2-element vector registers.

  Args:
    flat_regs: Sliceable sequence of scalar values; elements `2*i` and
      `2*i + 1` become lanes 0 and 1 of the i-th vector.
    dtype: Element type of the vectors that are built.
    shape: Shape of the returned array of vector registers.

  Returns:
    A NumPy object array of shape `shape` holding the 2-element vectors.
  """
  vec_ty = ir.VectorType.get((2,), dtype)
  vec_regs = []
  for first, second in zip(flat_regs[::2], flat_regs[1::2]):
    vec = llvm.mlir_undef(vec_ty)
    vec = llvm.insertelement(vec, first, position=_lc(0))
    vec = llvm.insertelement(vec, second, position=_lc(1))
    vec_regs.append(vec)
  return np.asarray(vec_regs, dtype=object).reshape(shape)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def _as_i32_reg(v):
  """Bitcasts the vector `v` to a 1-element i32 vector and returns that i32."""
  i32 = ir.IntegerType.get_signless(32)
  return llvm.extractelement(
      vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
  )
|
| 514 |
+
|
| 515 |
+
|
| 516 |
+
def _lc(x):
  """Returns the result of an LLVM-dialect signless i32 constant with value `x`."""
  i32 = ir.IntegerType.get_signless(32)
  return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/__init__.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Module for pallas, a JAX extension for custom kernels."""
|
| 16 |
+
|
| 17 |
+
from jax._src import pallas
|
| 18 |
+
from jax._src.pallas.core import BlockSpec
|
| 19 |
+
from jax._src.pallas.core import no_block_spec
|
| 20 |
+
from jax._src.pallas.core import Unblocked
|
| 21 |
+
from jax._src.pallas.core import unblocked
|
| 22 |
+
from jax._src.pallas.pallas_call import pallas_call
|
| 23 |
+
from jax._src.pallas.pallas_call import pallas_call_p
|
| 24 |
+
from jax._src.pallas.primitives import atomic_add
|
| 25 |
+
from jax._src.pallas.primitives import atomic_and
|
| 26 |
+
from jax._src.pallas.primitives import atomic_cas
|
| 27 |
+
from jax._src.pallas.primitives import atomic_max
|
| 28 |
+
from jax._src.pallas.primitives import atomic_min
|
| 29 |
+
from jax._src.pallas.primitives import atomic_or
|
| 30 |
+
from jax._src.pallas.primitives import atomic_xchg
|
| 31 |
+
from jax._src.pallas.primitives import atomic_xor
|
| 32 |
+
from jax._src.pallas.primitives import debug_print
|
| 33 |
+
from jax._src.pallas.primitives import dot
|
| 34 |
+
from jax._src.pallas.primitives import load
|
| 35 |
+
from jax._src.pallas.primitives import max_contiguous
|
| 36 |
+
from jax._src.pallas.primitives import multiple_of
|
| 37 |
+
from jax._src.pallas.primitives import num_programs
|
| 38 |
+
from jax._src.pallas.primitives import program_id
|
| 39 |
+
from jax._src.pallas.primitives import store
|
| 40 |
+
from jax._src.pallas.primitives import swap
|
| 41 |
+
from jax._src.pallas.utils import cdiv
|
| 42 |
+
from jax._src.pallas.utils import next_power_of_2
|
| 43 |
+
from jax._src.pallas.utils import strides_from_shape
|
| 44 |
+
from jax._src.pallas.utils import when
|
| 45 |
+
from jax._src.state.indexing import ds
|
| 46 |
+
from jax._src.state.indexing import dslice
|
| 47 |
+
from jax._src.state.indexing import Slice
|
| 48 |
+
from jax._src.state.primitives import broadcast_to
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/gpu.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Triton-specific Pallas APIs."""
|
| 16 |
+
|
| 17 |
+
from jax._src.pallas.triton.primitives import approx_tanh
|
| 18 |
+
from jax._src.pallas.triton.primitives import elementwise_inline_asm
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# All files within ops should be treated as user code.
import os
import jax._src.source_info_util
# Register this package's directory so that source info attributes frames
# from files under it the same way it does for user code.
jax._src.source_info_util.register_inclusion(os.path.dirname(__file__))
# Remove the helper names so they are not left in the package namespace.
del os, jax
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/gpu/attention.py
ADDED
|
@@ -0,0 +1,573 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Module containing fused attention forward and backward pass."""
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import functools
|
| 19 |
+
from typing import Any, Optional
|
| 20 |
+
|
| 21 |
+
import jax
|
| 22 |
+
from jax import lax
|
| 23 |
+
from jax.experimental import pallas as pl
|
| 24 |
+
import jax.numpy as jnp
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def mha_forward_kernel(
    q_ref,
    k_ref,
    v_ref,  # Input arrays
    segment_ids_ref: jax.Array | None,  # segment_id arrays
    o_ref: Any,  # Output
    *residual_refs: Any,  # Residual outputs
    num_heads: int,
    sm_scale: float,
    causal: bool,
    block_q: int,
    block_d: int,
    block_k: int,
):
  """Pallas kernel body: attention forward pass for one block of queries.

  The query block is selected by `pl.program_id(0)`. The kernel loops over
  key/value blocks and maintains a running row maximum `m`, a running
  normaliser `l` and an output accumulator `o` that is re-normalised on every
  iteration (see the FlashAttention paper referenced in the comments below).

  Args:
    q_ref: Query ref; `q_ref.shape[0]` is taken as the sequence length and a
      `[block_q, block_d]` tile is loaded from it.
    k_ref: Key ref, read in blocks of `block_k` rows.
    v_ref: Value ref, read in blocks of `block_k` rows and `block_d` columns.
    segment_ids_ref: Optional segment-id ref. When given, a query only attends
      to keys with the same segment id.
    o_ref: Output ref; receives the result cast to `o_ref.dtype`.
    *residual_refs: Either empty or exactly `(l_ref, m_ref)`, which receive
      the final normaliser and row maximum for this query block.
    num_heads: Not used by the live code; it is only referenced by the
      disabled workaround inside the loop body.
    sm_scale: Factor applied to `q @ k.T` (skipped when equal to 1).
    causal: Whether to mask out keys at positions after the query position.
    block_q: Number of query rows handled by one program.
    block_d: Head dimension (width of the `o` accumulator).
    block_k: Number of key/value rows handled per loop iteration.
  """
  seq_len = q_ref.shape[0]
  start_q = pl.program_id(0)

  # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
  m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
  l_i = jnp.zeros(block_q, dtype=jnp.float32)
  # o is the buffer where we accumulate the output on sram.
  o = jnp.zeros((block_q, block_d), dtype=jnp.float32)

  # Load q: it will stay in L1 throughout. Indices form a matrix because we
  # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
  # q tile has shape [block_q, block_d], block_d == head_dim.
  curr_q_slice = pl.dslice(start_q * block_q, block_q)
  q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
  q_segment_ids = (
      None
      if segment_ids_ref is None
      else pl.load(segment_ids_ref, (curr_q_slice,))
  )
  # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
  # (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
  # Here we only loop over blocks of kv to process entire seq_len, the loop over
  # blocks of q is carried out by the grid.
  def body(start_k, carry):
    o_prev, m_prev, l_prev = carry
    curr_k_slice = pl.dslice(start_k * block_k, block_k)

    k = pl.load(k_ref, (curr_k_slice, slice(None)))
    kv_segment_ids = (
        None
        if segment_ids_ref is None
        else pl.load(segment_ids_ref, (curr_k_slice,))
    )
    qk = pl.dot(q, k.T)   # [block_q, block_k]
    if sm_scale != 1.:
      qk *= sm_scale  # [block_q, block_k]

    # Avoids Triton crash.
    # if num_heads > 2:
    #   qk = qk.astype(q_ref.dtype)
    #   qk = qk.astype(jnp.float32)

    if causal or segment_ids_ref is not None:
      mask = None
      if segment_ids_ref is not None:
        mask = segment_mask(q_segment_ids, kv_segment_ids)
      if causal:
        span_q = start_q * block_q + jnp.arange(block_q)
        span_k = start_k * block_k + jnp.arange(block_k)
        causal_mask = span_q[:, None] >= span_k[None, :]
        mask = (
            causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
        )
      # Apply mask to qk.
      qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

    m_curr = qk.max(axis=-1)
    m_next = jnp.maximum(m_prev, m_curr)
    # Rescales the previous normaliser to the new running maximum.
    correction = jnp.exp(m_prev - m_next)
    l_prev_corr = correction * l_prev
    s_curr = jnp.exp(
        qk - m_next[:, None]
    )  # Use m_next instead of m_curr to avoid a correction on l_curr
    l_curr = s_curr.sum(axis=-1)
    l_next = l_prev_corr + l_curr
    l_next_rcp = 1. / l_next
    s_curr = s_curr * l_next_rcp[:, None]
    o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev
    v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
    o_curr = pl.dot(s_curr.astype(v.dtype), v)

    o_next = o_prev_corr + o_curr
    return o_next, m_next, l_next
  if causal:
    # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
    upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
  else:
    upper_bound = pl.cdiv(seq_len, block_k)  # type: ignore
  o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))

  if residual_refs:
    l_ref, m_ref = residual_refs
    pl.store(l_ref, (curr_q_slice,), l_i)
    pl.store(m_ref, (curr_q_slice,), m_i)
  # Write output to dram.
  o = o.astype(o_ref.dtype)
  pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def segment_mask(
    q_segment_ids: jax.Array,
    kv_segment_ids: jax.Array,
):
  """Returns a boolean mask that is True where q and kv segment ids match.

  Args:
    q_segment_ids: Segment ids of the queries, shaped `[T]` or `[B, T]`.
    kv_segment_ids: Segment ids of the keys/values, shaped `[S]` or `[B, S]`.

  Returns:
    A boolean array of shape `[T, S]` (or `[B, T, S]`) obtained by
    broadcasting an equality comparison of the two inputs.
  """
  # [B, T, 1] or [T, 1]
  q_ids = jnp.expand_dims(q_segment_ids, axis=-1)
  # [B, 1, S] or [1, S]
  kv_axis = 0 if kv_segment_ids.ndim == 1 else 1
  kv_ids = jnp.expand_dims(kv_segment_ids, axis=kv_axis)
  return jnp.equal(q_ids, kv_ids).astype(jnp.bool_)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
@functools.partial(
    jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
)
@functools.partial(
    jax.jit,
    static_argnames=[
        "sm_scale",
        "causal",
        "block_q",
        "block_k",
        "backward_pass_impl",
        "num_warps",
        "num_stages",
        "grid",
        "interpret",
        "debug",
    ],
)
def mha(
    q,
    k,
    v,
    segment_ids: jnp.ndarray | None,
    sm_scale: float = 1.0,
    causal: bool = False,
    block_q: int = 128,
    block_k: int = 128,
    backward_pass_impl: str = "triton",
    num_warps: int | None = None,
    num_stages: int = 2,
    grid: tuple[int, ...] | None = None,
    interpret: bool = False,
    debug: bool = False,
):
  """Computes multi-head attention with the `mha_forward_kernel` Pallas kernel.

  Args:
    q: Queries of shape `(batch_size, seq_len, num_heads, head_dim)`.
    k: Keys; blocked exactly like `q`.
    v: Values; blocked exactly like `q`.
    segment_ids: Optional segment ids, blocked as `(batch_size, seq_len)`.
      When `None`, no segment masking is applied.
    sm_scale: Scale applied to the attention logits.
    causal: Whether to apply a causal mask.
    block_q: Query block size; clamped to `seq_len`.
    block_k: Key/value block size; clamped to `seq_len`.
    backward_pass_impl: Not used by the forward computation.
    num_warps: Triton `num_warps`; when `None`, 4 is used for
      `head_dim <= 64` and 8 otherwise.
    num_stages: Triton `num_stages`.
    grid: Launch grid; when `None`, defaults to
      `(cdiv(seq_len, block_q), batch_size, num_heads)`.
    interpret: Forwarded to `pl.pallas_call`.
    debug: Forwarded to `pl.pallas_call`.

  Returns:
    The attention output, with the same shape and dtype as `q`.
  """
  del backward_pass_impl
  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  # Heuristics.
  grid_ = grid
  if grid_ is None:
    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)

  num_warps_ = num_warps
  if num_warps_ is None:
    num_warps_ = 4 if head_dim <= 64 else 8
  kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
                             sm_scale=sm_scale, block_q=block_q,
                             block_k=block_k, block_d=head_dim,
                             causal=causal)

  def per_head_spec():
    # One full (seq_len, head_dim) slab for each (batch, head) grid point; the
    # first grid axis (query block) does not affect which slab is mapped.
    return pl.BlockSpec(
        lambda _, b, h: (b, 0, h, 0), (None, seq_len, None, head_dim)
    )

  in_specs = [per_head_spec(), per_head_spec(), per_head_spec()]
  in_specs.append(
      None  # type: ignore[arg-type]
      if segment_ids is None
      else pl.BlockSpec(lambda _, b, h: (b, 0), (None, seq_len))
  )
  out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
  return pl.pallas_call(
      kernel,
      grid=grid_,
      in_specs=in_specs,
      out_specs=per_head_spec(),
      compiler_params=dict(
          triton=dict(num_warps=num_warps_, num_stages=num_stages)
      ),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_forward",
  )(q, k, v, segment_ids)
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
def _mha_forward(
    q,
    k,
    v,
    segment_ids: jax.Array | None,
    sm_scale: float,
    causal: bool,
    block_q: int,
    block_k: int,
    backward_pass_impl: str,
    num_warps: int | None,
    num_stages: int,
    grid: Any,
    interpret: bool,
    debug: bool,
):
  """Runs the attention forward kernel and also returns its residuals.

  Same computation as `mha`, except that the kernel is additionally given two
  float32 outputs of shape `(batch_size, num_heads, seq_len)` for the softmax
  normaliser `l` and the row maximum `m`.
  NOTE(review): presumably this is the forward rule of `mha`'s custom VJP;
  confirm at the `defvjp` call site.

  Args:
    q: Queries of shape `(batch_size, seq_len, num_heads, head_dim)`.
    k: Keys; blocked exactly like `q`.
    v: Values; blocked exactly like `q`.
    segment_ids: Optional segment ids, blocked as `(batch_size, seq_len)`.
    sm_scale: Scale applied to the attention logits.
    causal: Whether to apply a causal mask.
    block_q: Query block size; clamped to `seq_len`.
    block_k: Key/value block size; clamped to `seq_len`.
    backward_pass_impl: Not used by the forward computation.
    num_warps: Triton `num_warps`; when `None`, 4 is used for
      `head_dim <= 64` and 8 otherwise.
    num_stages: Triton `num_stages`.
    grid: Launch grid; when `None`, defaults to
      `(cdiv(seq_len, block_q), batch_size, num_heads)`.
    interpret: Forwarded to `pl.pallas_call`.
    debug: Forwarded to `pl.pallas_call`.

  Returns:
    A pair `(out, residuals)` where `residuals` is
    `(q, k, v, segment_ids, out, l, m)`.
  """
  del backward_pass_impl
  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  # Heuristics.
  grid_ = grid
  if grid_ is None:
    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)

  num_warps_ = num_warps
  if num_warps_ is None:
    num_warps_ = 4 if head_dim <= 64 else 8
  kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
                             sm_scale=sm_scale, causal=causal, block_q=block_q,
                             block_k=block_k, block_d=head_dim)

  def per_head_spec():
    # One full (seq_len, head_dim) slab for each (batch, head) grid point.
    return pl.BlockSpec(
        lambda _, b, h: (b, 0, h, 0), (None, seq_len, None, head_dim)
    )

  def per_row_spec():
    # One length-seq_len vector for each (batch, head) grid point.
    return pl.BlockSpec(lambda _, b, h: (b, h, 0), (None, None, seq_len))

  out_shape = [
      jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # l
                           dtype=jnp.float32),
      jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len),  # m
                           dtype=jnp.float32)
  ]
  in_specs = [per_head_spec(), per_head_spec(), per_head_spec()]
  in_specs.append(
      None  # type: ignore[arg-type]
      if segment_ids is None
      else pl.BlockSpec(lambda _, b, h: (b, 0), (None, seq_len))
  )
  out, l, m = pl.pallas_call(
      kernel,
      grid=grid_,
      in_specs=in_specs,
      out_specs=[per_head_spec(), per_row_spec(), per_row_spec()],
      compiler_params=dict(
          triton=dict(num_warps=num_warps_, num_stages=num_stages)
      ),
      out_shape=out_shape,
      debug=debug,
      interpret=interpret,
      name="mha_forward",
  )(q, k, v, segment_ids)
  return out, (q, k, v, segment_ids, out, l, m)
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
|
| 310 |
+
new_dout_ref, delta_ref, *,
|
| 311 |
+
block_q: int):
|
| 312 |
+
pid_m = pl.program_id(0)
|
| 313 |
+
|
| 314 |
+
off_m = pl.ds(pid_m * block_q, block_q)
|
| 315 |
+
# load
|
| 316 |
+
o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
|
| 317 |
+
do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
|
| 318 |
+
denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
|
| 319 |
+
# compute
|
| 320 |
+
do = do / denom[:, None]
|
| 321 |
+
delta = jnp.sum(o * do, axis=1)
|
| 322 |
+
# write-back
|
| 323 |
+
pl.store(new_dout_ref, (off_m, slice(None)),
|
| 324 |
+
do.astype(new_dout_ref.dtype))
|
| 325 |
+
pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
|
| 326 |
+
|
| 327 |
+
@jax.named_scope("preprocess_backward")
def _preprocess_backward(out, do, l, block_q: int,
                         debug: bool, interpret: bool):
  """Launch the backward preprocessing kernel over every (batch, head) pair.

  Args:
    out: Forward output; its shape is unpacked as
      (batch_size, seq_len, num_heads, head_dim).
    do: Cotangent of `out`; the first result has its shape and dtype.
    l: Per-row divisor; the second result has its shape and dtype.
    block_q: Row-block size forwarded to the kernel.
    debug: Forwarded to `pl.pallas_call`.
    interpret: Forwarded to `pl.pallas_call`.

  Returns:
    A `(do_scaled, delta)` tuple as produced by `_preprocess_backward_kernel`.
  """
  batch_size, seq_len, num_heads, head_dim = out.shape

  def full_seq_tile():
    # One (seq_len, head_dim) slab for a given batch index and head index.
    return pl.BlockSpec(
        lambda _, b, h: (b, 0, h, 0), (None, seq_len, None, head_dim))

  def full_seq_row():
    # One length-seq_len vector for a given batch index and head index.
    return pl.BlockSpec(lambda _, b, h: (b, h, 0), (None, None, seq_len))

  kernel = functools.partial(_preprocess_backward_kernel, block_q=block_q)
  launch = pl.pallas_call(
      kernel,
      grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
      in_specs=[full_seq_tile(), full_seq_tile(), full_seq_row()],
      out_specs=[full_seq_tile(), full_seq_row()],
      compiler_params={"triton": {"num_warps": 4, "num_stages": 3}},
      out_shape=[
          jax.ShapeDtypeStruct(do.shape, do.dtype),
          jax.ShapeDtypeStruct(l.shape, l.dtype),
      ],
      debug=debug,
      interpret=interpret,
      name="mha_preprocess_backward",
  )
  do_scaled, delta = launch(out, do, l)
  return do_scaled, delta
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def mha_backward_kernel(
    # Inputs
    q_ref,
    k_ref,
    v_ref,
    segment_ids_ref: jax.Array | None,
    out_ref,
    do_scaled_ref,
    l_ref,
    m_ref,
    delta_ref,
    _,
    # Outputs
    dq_ref,
    dk_ref,
    dv_ref,
    *,
    sm_scale: float,
    causal: bool,
    block_q: int,
    block_d: int,
    block_k: int,
):
  """Pallas kernel computing dq, dk and dv for one attention slice.

  The kernel walks key/value blocks in an outer loop and query blocks in an
  inner loop. dk and dv for a key block are accumulated in float32 and stored
  once per outer iteration; dq is accumulated in place in `dq_ref` through a
  load/add/store on every inner iteration.

  Args:
    q_ref, k_ref, v_ref: Refs for queries, keys and values; axis 0 of `q_ref`
      is used as the sequence length.
    segment_ids_ref: Optional Ref of per-position segment ids, or None to
      disable segment masking.
    out_ref: Unused (deleted on entry).
    do_scaled_ref: Ref for the output cotangent as loaded by the kernel
      (named "scaled"; the scaling itself happens outside this function).
    l_ref: Unused (deleted on entry).
    m_ref: Ref with one value per query row, subtracted from the logits before
      exponentiation. NOTE(review): presumably the forward-pass row maximum;
      confirm against the forward kernel.
    delta_ref: Ref with one value per query row, subtracted when forming `dp`.
    _: Ignored positional input. The caller in this file passes the
      zero-initialised dq buffer here and aliases it to the first output.
    dq_ref, dk_ref, dv_ref: Output Refs for the three gradients.
    sm_scale: Multiplier applied to the logits and to `ds` (skipped when 1.0).
    causal: Whether to apply the `query_pos >= key_pos` mask.
    block_q: Number of query rows per inner-loop step.
    block_d: Width of the dk/dv accumulators.
    block_k: Number of key rows per outer-loop step.
  """
  del out_ref, l_ref  # Not needed
  seq_len = q_ref.shape[0]

  def outer_loop(start_k, _):
    # One iteration handles the key/value block with index `start_k`.

    dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
    k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
    # Absolute key positions of this block, used by the causal mask.
    span_k = start_k * block_k + jnp.arange(block_k)
    kv_segment_ids = (
        None
        if segment_ids_ref is None
        else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),))
    )

    def inner_loop(start_q, carry):
      # One iteration handles the query block with index `start_q`.
      dv, dk = carry
      q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      qk = pl.dot(q, k.T)
      # Round-trip through the input dtype before continuing in float32.
      qk = qk.astype(q_ref.dtype)
      qk = qk.astype(jnp.float32)
      if sm_scale != 1.0:
        qk *= sm_scale

      q_segment_ids = (
          None
          if segment_ids_ref is None
          else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),))
      )

      if causal or segment_ids_ref is not None:
        mask = None
        if segment_ids_ref is not None:
          mask = segment_mask(q_segment_ids, kv_segment_ids)

        if causal:
          span_q = start_q * block_q + jnp.arange(block_q)
          causal_mask = span_q[:, None] >= span_k[None, :]
          mask = (
              causal_mask
              if mask is None
              else jnp.logical_and(mask, causal_mask)
          )
        # Masked-out logits are replaced by DEFAULT_MASK_VALUE.
        qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)

      m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
      p = jnp.exp(qk - m[:, None])
      do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
      dv = dv + pl.dot(p.astype(do.dtype).T, do)
      di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
      # dp = do @ v.T - delta, with delta broadcast across the key axis.
      dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
      dp = dp + pl.dot(do, v.T)
      ds = p * dp
      if sm_scale != 1.0:
        ds = ds * sm_scale
      dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
      # dq is accumulated across outer iterations directly in dq_ref, so the
      # load/add/store order below must be preserved.
      dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
                            slice(None)), eviction_policy="evict_last")
      dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
      pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
                        slice(None)), dq, eviction_policy="evict_last")
      return dv, dk
    if causal:
      # Start at the query block containing this key block's first position;
      # earlier query blocks are skipped.
      lower_bound = lax.div(start_k * block_k, block_q)
    else:
      lower_bound = 0
    dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
                           (dv, dk))
    pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dv.astype(dv_ref.dtype))
    pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
                      slice(None)), dk.astype(dk_ref.dtype))
  lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
                  backward_pass_impl: str, num_warps: int | None,
                  num_stages: int, grid: Any, interpret: bool,
                  debug: bool, res, do):
  """Backward rule for `mha`: return cotangents for (q, k, v, segment_ids).

  Args:
    sm_scale: Softmax scale forwarded to the reference function or kernel.
    causal: Whether causal masking is applied.
    block_q: Requested query block size (clamped to the sequence length).
    block_k: Requested key block size (clamped to the sequence length).
    backward_pass_impl: Either "xla" (differentiate `mha_reference` with
      `jax.vjp`) or "triton" (run the Pallas backward kernel).
    num_warps: Ignored by the backward pass.
    num_stages: Ignored by the backward pass.
    grid: Ignored by the backward pass.
    interpret: Forwarded to `pl.pallas_call`.
    debug: Forwarded to `pl.pallas_call`.
    res: Residuals `(q, k, v, segment_ids, out, l, m)` from the forward rule.
    do: Cotangent of the forward output.

  Returns:
    For "xla", whatever the `jax.vjp` pullback returns. For "triton", the
    tuple `(dq, dk, dv, None)` with `dq` cast to `q.dtype`.

  Raises:
    ValueError: If `backward_pass_impl` is neither "xla" nor "triton".
  """
  # The forward-pass launch parameters are deliberately not reused here.
  del num_warps, num_stages, grid
  q, k, v, segment_ids, out, l, m = res

  if backward_pass_impl == "xla":
    reference_fn = functools.partial(
        mha_reference, sm_scale=sm_scale, causal=causal)
    _, pullback = jax.vjp(reference_fn, q, k, v, segment_ids)
    return pullback(do)

  if backward_pass_impl != "triton":
    raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")

  batch_size, seq_len, num_heads, head_dim = q.shape
  block_q = min(block_q, seq_len)
  block_k = min(block_k, seq_len)
  do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)

  def head_tile():
    # One (seq_len, head_dim) slab for a given batch index and head index.
    return pl.BlockSpec(
        lambda b, h: (b, 0, h, 0), (None, seq_len, None, head_dim))

  def head_row():
    # One length-seq_len vector for a given batch index and head index.
    return pl.BlockSpec(lambda b, h: (b, h, 0), (None, None, seq_len))

  if segment_ids is None:
    segment_spec = None
    # The aliased dq operand is referenced as input 8 when segment_ids is
    # absent and as input 9 when it is present.
    dq_operand_index = 8
  else:
    segment_spec = pl.BlockSpec(lambda b, h: (b, 0), (None, seq_len))
    dq_operand_index = 9

  in_specs = [
      head_tile(),   # q
      head_tile(),   # k
      head_tile(),   # v
      segment_spec,  # segment_ids (None when unused)
      head_tile(),   # out
      head_tile(),   # do_scaled
      head_row(),    # l
      head_row(),    # m
      head_row(),    # delta
      head_tile(),   # dq (accumulator)
  ]

  # We accumulate into dq so we need to initialize it to zeros.
  dq_init = jnp.zeros(q.shape, jnp.float32)
  result_shapes = [
      jax.ShapeDtypeStruct(dq_init.shape, dq_init.dtype),
      jax.ShapeDtypeStruct(k.shape, k.dtype),
      jax.ShapeDtypeStruct(v.shape, v.dtype),
  ]

  kernel = functools.partial(
      mha_backward_kernel,
      block_q=block_q,
      block_d=head_dim,
      block_k=block_k,
      sm_scale=sm_scale,
      causal=causal,
  )
  # TODO(sharadmv): figure out why num_warps=8 doesn't work!
  # NOTE(review): the value used below is nevertheless 8; confirm the intent.
  backward_num_warps = 8
  launch = pl.pallas_call(
      kernel,
      grid=(batch_size, num_heads),
      out_shape=result_shapes,
      in_specs=in_specs,
      out_specs=[head_tile(), head_tile(), head_tile()],
      name="mha_backward",
      debug=debug,
      interpret=interpret,
      compiler_params={
          "triton": {"num_warps": backward_num_warps, "num_stages": 1}
      },
      input_output_aliases={dq_operand_index: 0},
  )
  dq, dk, dv = launch(
      q, k, v, segment_ids, out, do_scaled, l, m, delta, dq_init)
  return dq.astype(q.dtype), dk, dv, None
|
| 548 |
+
# Register the forward/backward pair as the custom VJP rule of `mha`.
mha.defvjp(_mha_forward, _mha_backward)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
@functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
|
| 552 |
+
def mha_reference(
|
| 553 |
+
q,
|
| 554 |
+
k,
|
| 555 |
+
v,
|
| 556 |
+
segment_ids: jnp.ndarray | None,
|
| 557 |
+
sm_scale=1.0,
|
| 558 |
+
causal: bool = False,
|
| 559 |
+
):
|
| 560 |
+
q_seq_len = q.shape[1]
|
| 561 |
+
kv_seq_len = k.shape[1]
|
| 562 |
+
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
|
| 563 |
+
mask = None
|
| 564 |
+
if segment_ids is not None:
|
| 565 |
+
mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
|
| 566 |
+
mask = jnp.broadcast_to(mask, logits.shape)
|
| 567 |
+
if causal:
|
| 568 |
+
causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
|
| 569 |
+
causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
|
| 570 |
+
mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
|
| 571 |
+
logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
|
| 572 |
+
weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
|
| 573 |
+
return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
|
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/tpu.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 The JAX Authors.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""Mosaic-specific Pallas APIs."""
|
| 16 |
+
|
| 17 |
+
from jax._src.pallas.mosaic import core
|
| 18 |
+
from jax._src.pallas.mosaic.core import dma_semaphore
|
| 19 |
+
from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
|
| 20 |
+
from jax._src.pallas.mosaic.core import semaphore
|
| 21 |
+
from jax._src.pallas.mosaic.core import SemaphoreType
|
| 22 |
+
from jax._src.pallas.mosaic.core import TPUMemorySpace
|
| 23 |
+
from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
|
| 24 |
+
from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
|
| 25 |
+
from jax._src.pallas.mosaic.lowering import LoweringException
|
| 26 |
+
from jax._src.pallas.mosaic.pipeline import BufferedRef
|
| 27 |
+
from jax._src.pallas.mosaic.pipeline import emit_pipeline
|
| 28 |
+
from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
|
| 29 |
+
from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
|
| 30 |
+
from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
|
| 31 |
+
from jax._src.pallas.mosaic.primitives import async_copy
|
| 32 |
+
from jax._src.pallas.mosaic.primitives import async_remote_copy
|
| 33 |
+
from jax._src.pallas.mosaic.primitives import bitcast
|
| 34 |
+
from jax._src.pallas.mosaic.primitives import delay
|
| 35 |
+
from jax._src.pallas.mosaic.primitives import device_id
|
| 36 |
+
from jax._src.pallas.mosaic.primitives import DeviceIdType
|
| 37 |
+
from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
|
| 38 |
+
from jax._src.pallas.mosaic.primitives import make_async_copy
|
| 39 |
+
from jax._src.pallas.mosaic.primitives import make_async_remote_copy
|
| 40 |
+
from jax._src.pallas.mosaic.primitives import repeat
|
| 41 |
+
from jax._src.pallas.mosaic.primitives import roll
|
| 42 |
+
from jax._src.pallas.mosaic.primitives import run_scoped
|
| 43 |
+
from jax._src.pallas.mosaic.primitives import semaphore_read
|
| 44 |
+
from jax._src.pallas.mosaic.primitives import semaphore_signal
|
| 45 |
+
from jax._src.pallas.mosaic.primitives import semaphore_wait
|
| 46 |
+
from jax._src.pallas.mosaic.primitives import prng_seed
|
| 47 |
+
from jax._src.pallas.mosaic.primitives import prng_random_bits
|
| 48 |
+
from jax._src.tpu_custom_call import CostEstimate
|
| 49 |
+
|
| 50 |
+
# Module-level shorthands for the members of the TPUMemorySpace enum imported
# above, so callers can refer to them directly from this module.
ANY = TPUMemorySpace.ANY
CMEM = TPUMemorySpace.CMEM
SMEM = TPUMemorySpace.SMEM
VMEM = TPUMemorySpace.VMEM
|
external/alphageometry/README.md
ADDED
|
@@ -0,0 +1,447 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Solving Olympiad Geometry without Human Demonstrations
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
This repository contains the code necessary to
|
| 6 |
+
reproduce DDAR and AlphaGeometry,
|
| 7 |
+
the two geometry theorem provers
|
| 8 |
+
introduced in the [Nature 2024](https://www.nature.com/articles/s41586-023-06747-5) paper:
|
| 9 |
+
|
| 10 |
+
*<center>"Solving Olympiad Geometry without Human Demonstrations".</center>*
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
<br>
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
<center>
|
| 17 |
+
<img alt="fig1" width="800px" src="fig1.svg">
|
| 18 |
+
</center>
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## Dependencies
|
| 22 |
+
|
| 23 |
+
For the instructions presented below,
|
| 24 |
+
we use Python 3.10.9, and dependencies with their exact
|
| 25 |
+
version numbers listed in `requirements.txt`.
|
| 26 |
+
|
| 27 |
+
Our code depends on `meliad`, which is
|
| 28 |
+
not a registered package with `pip`. See instructions below
|
| 29 |
+
for how to manually install `meliad`.
|
| 30 |
+
|
| 31 |
+
Note that one can still run the DDAR solver
|
| 32 |
+
without the `meliad` and `sentencepiece` dependencies.
|
| 33 |
+
|
| 34 |
+
## Run the instructions
|
| 35 |
+
|
| 36 |
+
All instructions in this `README.md` can be run in one go by:
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
bash run.sh
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
Below, we explain these instructions step-by-step.
|
| 43 |
+
|
| 44 |
+
## Install dependencies, download weights and vocabulary.
|
| 45 |
+
|
| 46 |
+
Installation is done in a virtual environment:
|
| 47 |
+
|
| 48 |
+
```
|
| 49 |
+
virtualenv -p python3 .
|
| 50 |
+
source ./bin/activate
|
| 51 |
+
pip install --require-hashes -r requirements.txt
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
Download weights and vocabulary:
|
| 55 |
+
|
| 56 |
+
```
|
| 57 |
+
bash download.sh
|
| 58 |
+
DATA=ag_ckpt_vocab
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
Finally, install `meliad` separately as it is not
|
| 62 |
+
registered with `pip`:
|
| 63 |
+
|
| 64 |
+
```
|
| 65 |
+
MELIAD_PATH=meliad_lib/meliad
|
| 66 |
+
mkdir -p $MELIAD_PATH
|
| 67 |
+
git clone https://github.com/google-research/meliad $MELIAD_PATH
|
| 68 |
+
export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
## Set up common flags
|
| 72 |
+
|
| 73 |
+
Before running the python scripts,
|
| 74 |
+
let us first prepare some commonly used flags.
|
| 75 |
+
The symbolic engine needs definitions and deduction rules to operate.
|
| 76 |
+
These definitions and rules are provided in two text files
|
| 77 |
+
`defs.txt` and `rules.txt`.
|
| 78 |
+
|
| 79 |
+
```shell
|
| 80 |
+
DDAR_ARGS=(
|
| 81 |
+
--defs_file=$(pwd)/defs.txt \
|
| 82 |
+
--rules_file=$(pwd)/rules.txt \
|
| 83 |
+
);
|
| 84 |
+
```
|
| 85 |
+
|
| 86 |
+
Next, we define the flags relevant to the proof search.
|
| 87 |
+
To reproduce the simple examples below,
|
| 88 |
+
we use lightweight values for the proof search parameters:
|
| 89 |
+
|
| 90 |
+
```shell
|
| 91 |
+
BATCH_SIZE=2
|
| 92 |
+
BEAM_SIZE=2
|
| 93 |
+
DEPTH=2
|
| 94 |
+
|
| 95 |
+
SEARCH_ARGS=(
|
| 96 |
+
--beam_size=$BEAM_SIZE
|
| 97 |
+
--search_depth=$DEPTH
|
| 98 |
+
)
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
NOTE: The results in our paper can be obtained by setting
|
| 102 |
+
`BATCH_SIZE=32`, `BEAM_SIZE=512`, `DEPTH=16`
|
| 103 |
+
as described in section Methods.
|
| 104 |
+
To stay under IMO time limits, 4 V100-GPUs and 250 CPU workers
|
| 105 |
+
are needed as shown in Extended Data - Figure 1.
|
| 106 |
+
Note that we also strip away other memory/speed optimizations
|
| 107 |
+
due to internal dependencies and to promote code clarity.
|
| 108 |
+
|
| 109 |
+
Assume the downloaded checkpoint and vocabulary are placed in `DATA`,
|
| 110 |
+
and the installed `meliad` source code is at `MELIAD_PATH`.
|
| 111 |
+
We make use of the `gin` library to manage model configurations,
|
| 112 |
+
following `meliad` conventions. We now define the flags relevant to the
|
| 113 |
+
language model:
|
| 114 |
+
|
| 115 |
+
```shell
|
| 116 |
+
LM_ARGS=(
|
| 117 |
+
--ckpt_path=$DATA \
|
| 118 |
+
--vocab_path=$DATA/geometry.757.model
|
| 119 |
+
--gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \
|
| 120 |
+
--gin_file=base_htrans.gin \
|
| 121 |
+
--gin_file=size/medium_150M.gin \
|
| 122 |
+
--gin_file=options/positions_t5.gin \
|
| 123 |
+
--gin_file=options/lr_cosine_decay.gin \
|
| 124 |
+
--gin_file=options/seq_1024_nocache.gin \
|
| 125 |
+
--gin_file=geometry_150M_generate.gin \
|
| 126 |
+
--gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
|
| 127 |
+
--gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
|
| 128 |
+
--gin_param=TransformerTaskConfig.sequence_length=128 \
|
| 129 |
+
--gin_param=Trainer.restore_state_variables=False
|
| 130 |
+
);
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
TIP: Note that you can still run the DDAR solver
|
| 134 |
+
without defining `SEARCH_ARGS` and `LM_ARGS`.
|
| 135 |
+
In such a case, simply disable the import of the `lm_inference` module
|
| 136 |
+
inside `alphageometry.py`.
|
| 137 |
+
|
| 138 |
+
## Run DDAR
|
| 139 |
+
|
| 140 |
+
The script loads a problem by reading a list of problems
|
| 141 |
+
from a text file and solves the specific problem in the list according
|
| 142 |
+
to its name. We pass these two pieces of information through the flags
|
| 143 |
+
`--problems_file` and `--problem_name`.
|
| 144 |
+
We use `--mode=ddar` to indicate that we want to use the DDAR solver.
|
| 145 |
+
|
| 146 |
+
Below, we show this solver solving IMO 2000 P1:
|
| 147 |
+
|
| 148 |
+
```shell
|
| 149 |
+
python -m alphageometry \
|
| 150 |
+
--alsologtostderr \
|
| 151 |
+
--problems_file=$(pwd)/imo_ag_30.txt \
|
| 152 |
+
--problem_name=translated_imo_2000_p1 \
|
| 153 |
+
--mode=ddar \
|
| 154 |
+
"${DDAR_ARGS[@]}"
|
| 155 |
+
```
|
| 156 |
+
|
| 157 |
+
Expect the following output:
|
| 158 |
+
|
| 159 |
+
```shell
|
| 160 |
+
graph.py:468] translated_imo_2000_p1
|
| 161 |
+
graph.py:469] a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q
|
| 162 |
+
ddar.py:41] Depth 1/1000 time = 1.7772269248962402
|
| 163 |
+
ddar.py:41] Depth 2/1000 time = 5.63526177406311
|
| 164 |
+
ddar.py:41] Depth 3/1000 time = 6.883412837982178
|
| 165 |
+
ddar.py:41] Depth 4/1000 time = 10.275688409805298
|
| 166 |
+
ddar.py:41] Depth 5/1000 time = 12.048273086547852
|
| 167 |
+
alphageometry.py:190]
|
| 168 |
+
==========================
|
| 169 |
+
* From theorem premises:
|
| 170 |
+
A B G1 G2 M N C D E P Q : Points
|
| 171 |
+
AG_1 ⟂ AB [00]
|
| 172 |
+
BA ⟂ G_2B [01]
|
| 173 |
+
G_2M = G_2B [02]
|
| 174 |
+
G_1M = G_1A [03]
|
| 175 |
+
|
| 176 |
+
...
|
| 177 |
+
[log omitted]
|
| 178 |
+
...
|
| 179 |
+
|
| 180 |
+
036. ∠QEB = ∠(QP-EA) [46] & ∠(BE-QP) = ∠AEP [55] ⇒ ∠EQP = ∠QPE [56]
|
| 181 |
+
037. ∠PQE = ∠EPQ [56] ⇒ EP = EQ
|
| 182 |
+
|
| 183 |
+
==========================
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
The output first includes a list of relevant premises that it uses,
|
| 187 |
+
and then proof steps that gradually build up the proof.
|
| 188 |
+
All predicates are numbered to track how they are derived
|
| 189 |
+
from the premises, and to show that the proof is fully justified.
|
| 190 |
+
|
| 191 |
+
TIP: Additionally passing the flag `--out_file=path/to/output/text/file.txt`
|
| 192 |
+
will write the proof to a text file.
|
| 193 |
+
|
| 194 |
+
Running on all problems in `imo_ag_30.txt` will yield solutions to
|
| 195 |
+
14 of them, as reported in Table 1 in our paper.
|
| 196 |
+
|
| 197 |
+
## Run AlphaGeometry:
|
| 198 |
+
|
| 199 |
+
As a simple example, we load `--problem_name=orthocenter`
|
| 200 |
+
from `--problems_file=examples.txt`.
|
| 201 |
+
This time, we pass `--mode=alphageometry` to use the AlphaGeometry solver
|
| 202 |
+
and pass the `SEARCH_ARGS` and `LM_ARGS` flags.
|
| 203 |
+
|
| 204 |
+
```shell
|
| 205 |
+
python -m alphageometry \
|
| 206 |
+
--alsologtostderr \
|
| 207 |
+
--problems_file=$(pwd)/examples.txt \
|
| 208 |
+
--problem_name=orthocenter \
|
| 209 |
+
--mode=alphageometry \
|
| 210 |
+
"${DDAR_ARGS[@]}" \
|
| 211 |
+
"${SEARCH_ARGS[@]}" \
|
| 212 |
+
"${LM_ARGS[@]}"
|
| 213 |
+
```
|
| 214 |
+
|
| 215 |
+
Expect the following output:
|
| 216 |
+
|
| 217 |
+
```shell
|
| 218 |
+
...
|
| 219 |
+
[log omitted]
|
| 220 |
+
...
|
| 221 |
+
training_loop.py:725] Total parameters: 152072288
|
| 222 |
+
training_loop.py:739] Total state size: 0
|
| 223 |
+
training_loop.py:492] Training loop: creating task for mode beam_search
|
| 224 |
+
|
| 225 |
+
graph.py:468] orthocenter
|
| 226 |
+
graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c
|
| 227 |
+
ddar.py:41] Depth 1/1000 time = 0.009987592697143555 branch = 4
|
| 228 |
+
ddar.py:41] Depth 2/1000 time = 0.00672602653503418 branch = 0
|
| 229 |
+
alphageometry.py:221] DD+AR failed to solve the problem.
|
| 230 |
+
alphageometry.py:457] Depth 0. There are 1 nodes to expand:
|
| 231 |
+
alphageometry.py:460] {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
|
| 232 |
+
alphageometry.py:465] Decoding from {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
|
| 233 |
+
...
|
| 234 |
+
[log omitted]
|
| 235 |
+
...
|
| 236 |
+
alphageometry.py:470] LM output (score=-1.102287): "e : C a c e 02 C b d e 03 ;"
|
| 237 |
+
alphageometry.py:471] Translation: "e = on_line e a c, on_line e b d"
|
| 238 |
+
|
| 239 |
+
alphageometry.py:480] Solving: "a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c"
|
| 240 |
+
graph.py:468]
|
| 241 |
+
graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c
|
| 242 |
+
ddar.py:41] Depth 1/1000 time = 0.021120786666870117
|
| 243 |
+
ddar.py:41] Depth 2/1000 time = 0.033370018005371094
|
| 244 |
+
ddar.py:41] Depth 3/1000 time = 0.04297471046447754
|
| 245 |
+
alphageometry.py:140]
|
| 246 |
+
==========================
|
| 247 |
+
* From theorem premises:
|
| 248 |
+
A B C D : Points
|
| 249 |
+
BD ⟂ AC [00]
|
| 250 |
+
CD ⟂ AB [01]
|
| 251 |
+
|
| 252 |
+
* Auxiliary Constructions:
|
| 253 |
+
E : Points
|
| 254 |
+
E,B,D are collinear [02]
|
| 255 |
+
E,C,A are collinear [03]
|
| 256 |
+
|
| 257 |
+
* Proof steps:
|
| 258 |
+
001. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEA = ∠CED [04]
|
| 259 |
+
002. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEC = ∠AED [05]
|
| 260 |
+
003. A,E,C are collinear [03] & E,B,D are collinear [02] & AC ⟂ BD [00] ⇒ EC ⟂ EB [06]
|
| 261 |
+
004. EC ⟂ EB [06] & CD ⟂ AB [01] ⇒ ∠(EC-BA) = ∠(EB-CD) [07]
|
| 262 |
+
005. E,C,A are collinear [03] & E,B,D are collinear [02] & ∠(EC-BA) = ∠(EB-CD) [07] ⇒ ∠BAE = ∠CDE [08]
|
| 263 |
+
006. ∠BEA = ∠CED [04] & ∠BAE = ∠CDE [08] (Similar Triangles)⇒ EB:EC = EA:ED [09]
|
| 264 |
+
007. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠BCE = ∠ADE [10]
|
| 265 |
+
008. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠EBC = ∠EAD [11]
|
| 266 |
+
009. ∠BCE = ∠ADE [10] & E,C,A are collinear [03] & E,B,D are collinear [02] & ∠EBC = ∠EAD [11] ⇒ AD ⟂ BC
|
| 267 |
+
==========================
|
| 268 |
+
|
| 269 |
+
alphageometry.py:505] Solved.
|
| 270 |
+
```
|
| 271 |
+
|
| 272 |
+
NOTE: Point `H` is automatically renamed to `D`,
|
| 273 |
+
as the LM is trained on synthetic problems
|
| 274 |
+
where the points are named alphabetically, and so it expects
|
| 275 |
+
the same during test time.
|
| 276 |
+
|
| 277 |
+
NOTE: In this implementation of AlphaGeometry,
|
| 278 |
+
we removed all optimizations that are dependent on
|
| 279 |
+
internal infrastructure, e.g.,
|
| 280 |
+
parallelized model inference on multiple GPUs,
|
| 281 |
+
parallelized DDAR on multiple CPUs,
|
| 282 |
+
parallel execution of LM and DDAR,
|
| 283 |
+
shared pool of CPU workers across different problems, etc.
|
| 284 |
+
We also removed some memory/speed optimizations and code
|
| 285 |
+
abstractions in favor of code clarity.
|
| 286 |
+
|
| 287 |
+
As can be seen in the output, initially DDAR failed to solve the problem.
|
| 288 |
+
The LM proposes two auxiliary constructions (because `BATCH_SIZE=2`):
|
| 289 |
+
|
| 290 |
+
* `e = eqdistance e c a b, eqdistance e b a c`, i.e.,
|
| 291 |
+
construct `E` as the intersection of circle (center=C, radius=AB) and
|
| 292 |
+
circle (center=B, radius=AC). This construction has a score of `-1.186`.
|
| 293 |
+
* `e = on_line e a c, on_line e b d`, i.e.,
|
| 294 |
+
`E` is the intersection of `AC` and `BD`.
|
| 295 |
+
This construction has a higher score (`-1.102287`) than the previous.
|
| 296 |
+
|
| 297 |
+
Since the second construction has a higher score, DDAR attempted the second
|
| 298 |
+
construction first and found the solution right away.
|
| 299 |
+
The proof search therefore terminates and there is no second iteration.
|
| 300 |
+
|
| 301 |
+
## Results
|
| 302 |
+
|
| 303 |
+
Before attempting to reproduce the AlphaGeometry numbers in our paper,
|
| 304 |
+
please make sure to pass all tests in the prepared test suite:
|
| 305 |
+
|
| 306 |
+
```
|
| 307 |
+
bash run_tests.sh
|
| 308 |
+
```
|
| 309 |
+
|
| 310 |
+
NOTE: [Issues#14](https://github.com/google-deepmind/alphageometry/issues/14) reports that although the top beam decodes are still the same, the LM is not giving the same score for different users.
|
| 311 |
+
|
| 312 |
+
Then, pass the corresponding values for `--problems_file` (column)
|
| 313 |
+
and `--mode` (row), and
|
| 314 |
+
iterate on all problems to obtain the following results:
|
| 315 |
+
|
| 316 |
+
<center>
|
| 317 |
+
|
| 318 |
+
<b>Number of solved problems:</b>
|
| 319 |
+
|
| 320 |
+
| | `imo_ag_30.txt` | `jgex_ag_231.txt` |
|
| 321 |
+
|----------|------------------|-------------------|
|
| 322 |
+
| `ddar` | 14 | 198 |
|
| 323 |
+
| `alphageometry` | 25 | 228 |
|
| 324 |
+
|
| 325 |
+
</center>
|
| 326 |
+
|
| 327 |
+
## Source code description
|
| 328 |
+
|
| 329 |
+
Files in this repository include python modules/scripts to run the solvers and
|
| 330 |
+
resource files necessary for the script to execute. We list below
|
| 331 |
+
each of them and their description.
|
| 332 |
+
|
| 333 |
+
| File name | Description |
|
| 334 |
+
|------------------------|------------------------------------------------------------------------------------|
|
| 335 |
+
| `geometry.py` | Implements nodes (Point, Line, Circle, etc) in the proof state graph. |
|
| 336 |
+
| `numericals.py` | Implements the numerical engine in the dynamic geometry environment. |
|
| 337 |
+
| `graph_utils.py` | Implements utilities for the proof state graph. |
|
| 338 |
+
| `graph.py` | Implements the proof state graph. |
|
| 339 |
+
| `problem.py` | Implements the classes that represent the problem premises, conclusion, DAG nodes. |
|
| 340 |
+
| `dd.py` | Implements DD and its traceback. |
|
| 341 |
+
| `ar.py` | Implements AR and its traceback. |
|
| 342 |
+
| `trace_back.py` | Implements the recursive traceback and dependency difference algorithm. |
|
| 343 |
+
| `ddar.py` | Implements the combination DD+AR. |
|
| 344 |
+
| `beam_search.py` | Implements beam decoding of a language model in JAX. |
|
| 345 |
+
| `models.py` | Implements the transformer model. |
|
| 346 |
+
| `transformer_layer.py` | Implements the transformer layer. |
|
| 347 |
+
| `decoder_stack.py` | Implements the transformer decoder stack. |
|
| 348 |
+
| `lm_inference.py` | Implements an interface to a trained LM to perform decoding. |
|
| 349 |
+
| `alphageometry.py` | Main script that loads problems, calls DD+AR or AlphaGeometry solver, and prints solutions. |
|
| 350 |
+
| `pretty.py`            | Pretty formatting of the solutions output by solvers.                              |
|
| 351 |
+
| `*_test.py` | Tests for the corresponding module. |
|
| 352 |
+
| `download.sh` | Script to download model checkpoints and LM |
|
| 353 |
+
| `run.sh` | Script to execute instructions in README. |
|
| 354 |
+
| `run_tests.sh` | Script to execute the test suite. |
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
Resource files:
|
| 358 |
+
|
| 359 |
+
| Resource file name | Description |
|
| 360 |
+
|------------------------|------------------------------------------------------------------------------------|
|
| 361 |
+
| `defs.txt` | Definitions of different geometric construction actions. |
|
| 362 |
+
| `rules.txt` | Deduction rules for DD. |
|
| 363 |
+
| `geometry_150M_generate.gin`| Gin config of the LM implemented in meliad. |
|
| 364 |
+
| `imo_ag_30.txt` | Problems in IMO-AG-30. |
|
| 365 |
+
| `jgex_ag_231.txt` | Problems in JGEX-AG-231. |
|
| 366 |
+
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
## Citing this work
|
| 370 |
+
|
| 371 |
+
```bibtex
|
| 372 |
+
@Article{AlphaGeometryTrinh2024,
|
| 373 |
+
author = {Trinh, Trieu and Wu, Yuhuai and Le, Quoc and He, He and Luong, Thang},
|
| 374 |
+
journal = {Nature},
|
| 375 |
+
title = {Solving Olympiad Geometry without Human Demonstrations},
|
| 376 |
+
year = {2024},
|
| 377 |
+
doi = {10.1038/s41586-023-06747-5}
|
| 378 |
+
}
|
| 379 |
+
```
|
| 380 |
+
|
| 381 |
+
## Acknowledgements
|
| 382 |
+
|
| 383 |
+
This research is a collaboration between the Google Brain team
|
| 384 |
+
(now Google DeepMind) and
|
| 385 |
+
the Computer Science Department of New York University.
|
| 386 |
+
We thank Rif A. Saurous, Denny Zhou, Christian Szegedy, Delesley Hutchins,
|
| 387 |
+
Thomas Kipf, Hieu Pham, Petar Veličković, Debidatta Dwibedi,
|
| 388 |
+
Kyunghyun Cho, Lerrel Pinto, Alfredo Canziani,
|
| 389 |
+
Thomas Wies, He He’s research group,
|
| 390 |
+
Evan Chen (the USA’s IMO team coach),
|
| 391 |
+
Mirek Olsak, Patrik Bak,
|
| 392 |
+
and all three of Nature's referees for their help and support.
|
| 393 |
+
|
| 394 |
+
The code of AlphaGeometry communicates with and/or references the following
|
| 395 |
+
separate libraries and packages:
|
| 396 |
+
|
| 397 |
+
* [Abseil](https://github.com/abseil/abseil-py)
|
| 398 |
+
* [JAX](https://github.com/google/jax/)
|
| 399 |
+
* [matplotlib](https://matplotlib.org/)
|
| 400 |
+
* [NumPy](https://numpy.org)
|
| 401 |
+
* [SciPy](https://scipy.org)
|
| 402 |
+
* [TensorFlow](https://github.com/tensorflow/tensorflow)
|
| 403 |
+
* [Meliad](https://github.com/google-research/meliad)
|
| 404 |
+
* [Flax](https://github.com/google/flax)
|
| 405 |
+
* [Gin](https://github.com/google/gin-config)
|
| 406 |
+
* [T5](https://github.com/google-research/text-to-text-transfer-transformer)
|
| 407 |
+
* [SentencePiece](https://github.com/google/sentencepiece)
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
|
| 411 |
+
We thank all their contributors and maintainers!
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
## Disclaimer
|
| 415 |
+
|
| 416 |
+
This is not an officially supported Google product.
|
| 417 |
+
|
| 418 |
+
This research code is provided "as-is" to the broader research community.
|
| 419 |
+
Google does not promise to maintain or otherwise support this code in any way.
|
| 420 |
+
|
| 421 |
+
## Code License
|
| 422 |
+
|
| 423 |
+
Copyright 2023 DeepMind Technologies Limited
|
| 424 |
+
|
| 425 |
+
All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
|
| 426 |
+
you may not use this file except in compliance with the Apache 2.0 license.
|
| 427 |
+
You may obtain a copy of the Apache 2.0 license at:
|
| 428 |
+
https://www.apache.org/licenses/LICENSE-2.0
|
| 429 |
+
|
| 430 |
+
All other materials are licensed under the Creative Commons Attribution 4.0
|
| 431 |
+
International License (CC-BY). You may obtain a copy of the CC-BY license at:
|
| 432 |
+
https://creativecommons.org/licenses/by/4.0/legalcode
|
| 433 |
+
|
| 434 |
+
Unless required by applicable law or agreed to in writing, all software and
|
| 435 |
+
materials distributed here under the Apache 2.0 or CC-BY licenses are
|
| 436 |
+
distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
|
| 437 |
+
either express or implied. See the licenses for the specific language governing
|
| 438 |
+
permissions and limitations under those licenses.
|
| 439 |
+
|
| 440 |
+
## Model Parameters License
|
| 441 |
+
|
| 442 |
+
The AlphaGeometry checkpoints and vocabulary are made available
|
| 443 |
+
under the terms of the Creative Commons Attribution 4.0
|
| 444 |
+
International (CC BY 4.0) license.
|
| 445 |
+
You can find details at:
|
| 446 |
+
https://creativecommons.org/licenses/by/4.0/legalcode
|
| 447 |
+
|
external/alphageometry/lm_inference_test.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for lm_inference.py."""
|
| 17 |
+
import os
|
| 18 |
+
import unittest
|
| 19 |
+
|
| 20 |
+
from absl import flags
|
| 21 |
+
from absl.testing import absltest
|
| 22 |
+
import lm_inference as lm
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
_DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.')
|
| 26 |
+
_MELIAD_PATH = flags.DEFINE_string(
|
| 27 |
+
'meliad_path', '', 'path to meliad repository.'
|
| 28 |
+
) # pylint: disable=line-too-long
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class LmInferenceTest(unittest.TestCase):
|
| 32 |
+
|
| 33 |
+
@classmethod
|
| 34 |
+
def setUpClass(cls):
|
| 35 |
+
super().setUpClass()
|
| 36 |
+
gin_file = [
|
| 37 |
+
'base_htrans.gin',
|
| 38 |
+
'size/medium_150M.gin',
|
| 39 |
+
'options/positions_t5.gin',
|
| 40 |
+
'options/lr_cosine_decay.gin',
|
| 41 |
+
'options/seq_1024_nocache.gin',
|
| 42 |
+
'geometry_150M_generate.gin',
|
| 43 |
+
]
|
| 44 |
+
|
| 45 |
+
gin_param = [
|
| 46 |
+
'DecoderOnlyLanguageModelGenerate.output_token_losses=True',
|
| 47 |
+
'TransformerTaskConfig.batch_size=2',
|
| 48 |
+
'TransformerTaskConfig.sequence_length=128',
|
| 49 |
+
'Trainer.restore_state_variables=False',
|
| 50 |
+
]
|
| 51 |
+
|
| 52 |
+
gin_search_paths = [
|
| 53 |
+
os.path.join(_MELIAD_PATH.value, 'transformer/configs'),
|
| 54 |
+
os.getcwd(),
|
| 55 |
+
]
|
| 56 |
+
|
| 57 |
+
vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model')
|
| 58 |
+
|
| 59 |
+
lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths)
|
| 60 |
+
|
| 61 |
+
cls.loaded_lm = lm.LanguageModelInference(
|
| 62 |
+
vocab_path, _DATA_PATH.value, mode='beam_search'
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def test_lm_decode(self):
|
| 66 |
+
outputs = LmInferenceTest.loaded_lm.beam_decode(
|
| 67 |
+
'{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
|
| 68 |
+
' {F1} x00',
|
| 69 |
+
eos_tokens=[';'],
|
| 70 |
+
)
|
| 71 |
+
self.assertEqual(
|
| 72 |
+
outputs['seqs_str'],
|
| 73 |
+
['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'],
|
| 74 |
+
)
|
| 75 |
+
|
| 76 |
+
def test_lm_score_may_fail_numerically_for_external_meliad(self):
|
| 77 |
+
outputs = LmInferenceTest.loaded_lm.beam_decode(
|
| 78 |
+
'{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
|
| 79 |
+
' {F1} x00',
|
| 80 |
+
eos_tokens=[';'],
|
| 81 |
+
)
|
| 82 |
+
self.assertEqual(
|
| 83 |
+
outputs['scores'],
|
| 84 |
+
[-1.18607294559478759765625, -1.10228693485260009765625],
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
if __name__ == '__main__':
|
| 89 |
+
absltest.main()
|
external/alphageometry/models.py
ADDED
|
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Transformer language model generate mode."""
|
| 17 |
+
|
| 18 |
+
from typing import Any, Tuple
|
| 19 |
+
import beam_search
|
| 20 |
+
import decoder_stack
|
| 21 |
+
import gin
|
| 22 |
+
import jax
|
| 23 |
+
import jax.numpy as jnp
|
| 24 |
+
from transformer import models
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@gin.configurable
|
| 28 |
+
class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
|
| 29 |
+
"""Decoder only language modeling in inference mode."""
|
| 30 |
+
|
| 31 |
+
decoder_factory = decoder_stack.DecoderStackGenerate
|
| 32 |
+
|
| 33 |
+
num_heads: int = gin.REQUIRED
|
| 34 |
+
head_size: int = gin.REQUIRED
|
| 35 |
+
|
| 36 |
+
def get_fake_input(self) -> dict[str, Any]:
|
| 37 |
+
fake_input_dict = super().get_fake_input()
|
| 38 |
+
b = self.task_config.batch_size
|
| 39 |
+
n = self.num_heads
|
| 40 |
+
h = self.head_size
|
| 41 |
+
fake_input_dict.update({
|
| 42 |
+
'dstate': tuple(
|
| 43 |
+
[{
|
| 44 |
+
'current_index': jnp.array([0] * b, dtype=jnp.int32),
|
| 45 |
+
'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
|
| 46 |
+
'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
|
| 47 |
+
'recurrent_kvq': None,
|
| 48 |
+
'relative_position_bias': jnp.zeros(
|
| 49 |
+
(b, n, 1, 1024), dtype=jnp.bfloat16
|
| 50 |
+
),
|
| 51 |
+
}]
|
| 52 |
+
* 12
|
| 53 |
+
),
|
| 54 |
+
'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
|
| 55 |
+
'mask': jnp.ones([1024], dtype=jnp.bfloat16),
|
| 56 |
+
'length': 1,
|
| 57 |
+
'temperature': 1.0,
|
| 58 |
+
})
|
| 59 |
+
return fake_input_dict
|
| 60 |
+
|
| 61 |
+
def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
|
| 62 |
+
# Make sure this code is not used on untested cases.
|
| 63 |
+
if self.mode not in ['init', 'beam_search']:
|
| 64 |
+
raise ValueError(f'{type(self)} cannot do mode {self.mode}')
|
| 65 |
+
if self.decoder.supports_generate():
|
| 66 |
+
raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
|
| 67 |
+
|
| 68 |
+
self.decoder(
|
| 69 |
+
input_tokens=inputs['targets'][:, 0:1],
|
| 70 |
+
target_tokens=None,
|
| 71 |
+
start_of_sequence=inputs['start_of_sequence'],
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
b = inputs['targets'].shape[0]
|
| 75 |
+
no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
|
| 76 |
+
|
| 77 |
+
# This fn is used in both beam_search and topk_sampling.
|
| 78 |
+
def tokens_to_logits_fn(
|
| 79 |
+
input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
|
| 80 |
+
) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
|
| 81 |
+
(logits, dstate, _) = self.decoder(
|
| 82 |
+
input_tokens=input_token,
|
| 83 |
+
target_tokens=None,
|
| 84 |
+
start_of_sequence=no_start_of_seq,
|
| 85 |
+
decoder_state=dstate,
|
| 86 |
+
)
|
| 87 |
+
return logits[:, -1, :], dstate
|
| 88 |
+
|
| 89 |
+
last_token = jax.lax.dynamic_slice_in_dim(
|
| 90 |
+
inputs['targets'], inputs['length'] - 1, 1, axis=1
|
| 91 |
+
)
|
| 92 |
+
|
| 93 |
+
# last token is used to seed beam_search
|
| 94 |
+
inputs['targets'] = inputs['targets'][:, 0:-1]
|
| 95 |
+
dstate = jax.lax.cond(
|
| 96 |
+
inputs['start_of_sequence'][0],
|
| 97 |
+
lambda: self.generate(inputs)[0],
|
| 98 |
+
lambda: inputs['dstate'],
|
| 99 |
+
)
|
| 100 |
+
|
| 101 |
+
# Then we run beam search, init with last_token & dstate.
|
| 102 |
+
finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
|
| 103 |
+
last_token,
|
| 104 |
+
dstate,
|
| 105 |
+
tokens_to_logits_fn,
|
| 106 |
+
max_decode_len=512,
|
| 107 |
+
eos=inputs['eos'].reshape((1, 1, -1)),
|
| 108 |
+
mask=inputs['mask'].reshape((1, 1, -1)),
|
| 109 |
+
)
|
| 110 |
+
|
| 111 |
+
return 0.0, {
|
| 112 |
+
'finished_seqs': finished_seqs,
|
| 113 |
+
'finished_scores': finished_scores,
|
| 114 |
+
'dstate': dstate,
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
def generate(
|
| 118 |
+
self, inputs: ...
|
| 119 |
+
) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
|
| 120 |
+
"""Generate an output sequence.
|
| 121 |
+
|
| 122 |
+
Args:
|
| 123 |
+
inputs: the same as the argument to __call__.
|
| 124 |
+
|
| 125 |
+
Returns:
|
| 126 |
+
A tuple (dstate, logits): the decoder state after the decoding loop and the
temperature-scaled logits carried out of that loop's last step.
|
| 127 |
+
"""
|
| 128 |
+
input_tokens = inputs['targets'] # [b,seq_len]
|
| 129 |
+
start_of_sequence = inputs['start_of_sequence'] # [b]
|
| 130 |
+
target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
|
| 131 |
+
batch_size = target_tokens.shape[0]
|
| 132 |
+
|
| 133 |
+
# Assuming all sequences start at the same time.
|
| 134 |
+
start0 = inputs['start_of_sequence'][0]
|
| 135 |
+
dstate = jax.lax.cond(
|
| 136 |
+
start0,
|
| 137 |
+
lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
|
| 138 |
+
1024, start_of_sequence
|
| 139 |
+
),
|
| 140 |
+
lambda: inputs['dstate'],
|
| 141 |
+
)
|
| 142 |
+
|
| 143 |
+
first_token = input_tokens[:, 0:1]
|
| 144 |
+
no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
|
| 145 |
+
temperature = 1
|
| 146 |
+
if 'temperature' in inputs:
|
| 147 |
+
temperature = inputs['temperature']
|
| 148 |
+
|
| 149 |
+
num_steps = inputs['length']
|
| 150 |
+
if self.mode == 'beam_search':
|
| 151 |
+
num_steps -= 1
|
| 152 |
+
|
| 153 |
+
def cond_fn(scan_state) -> jnp.bool_:
|
| 154 |
+
_, _, i, _ = scan_state
|
| 155 |
+
return i < num_steps
|
| 156 |
+
|
| 157 |
+
def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
|
| 158 |
+
(dstate, input_token, i, _) = scan_state
|
| 159 |
+
|
| 160 |
+
(logits, dstate, _) = self.decoder(
|
| 161 |
+
input_tokens=input_token,
|
| 162 |
+
target_tokens=None,
|
| 163 |
+
start_of_sequence=no_start_of_seq,
|
| 164 |
+
decoder_state=dstate,
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
logits = logits / temperature
|
| 168 |
+
output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
|
| 169 |
+
|
| 170 |
+
return (dstate, output_token, i + 1, logits)
|
| 171 |
+
|
| 172 |
+
# Scan over the sequence length.
|
| 173 |
+
dummy_logits = jnp.zeros((batch_size, 1, 1024))
|
| 174 |
+
initial_scan_state = (dstate, first_token, 0, dummy_logits)
|
| 175 |
+
dstate, _, _, logits = jax.lax.while_loop(
|
| 176 |
+
cond_fn, loop_fn, initial_scan_state
|
| 177 |
+
)
|
| 178 |
+
return dstate, logits
|
external/alphageometry/numericals.py
ADDED
|
@@ -0,0 +1,1921 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Numerical representation of geometry."""
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import math
|
| 20 |
+
from typing import Any, Optional, Union
|
| 21 |
+
|
| 22 |
+
import geometry as gm
|
| 23 |
+
import matplotlib
|
| 24 |
+
from matplotlib import pyplot as plt
|
| 25 |
+
import matplotlib.colors as mcolors
|
| 26 |
+
import numpy as np
|
| 27 |
+
from numpy.random import uniform as unif # pylint: disable=g-importing-member
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
matplotlib.use('TkAgg')
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
ATOM = 1e-12
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# Some variables are there for better code reading.
|
| 37 |
+
# pylint: disable=unused-assignment
|
| 38 |
+
# pylint: disable=unused-argument
|
| 39 |
+
# pylint: disable=unused-variable
|
| 40 |
+
|
| 41 |
+
# Naming in geometry is a little different
|
| 42 |
+
# we stick to geometry naming to better read the code.
|
| 43 |
+
# pylint: disable=invalid-name
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class Point:
|
| 47 |
+
"""Numerical point."""
|
| 48 |
+
|
| 49 |
+
def __init__(self, x, y):
|
| 50 |
+
self.x = x
|
| 51 |
+
self.y = y
|
| 52 |
+
|
| 53 |
+
def __lt__(self, other: Point) -> bool:
|
| 54 |
+
return (self.x, self.y) < (other.x, other.y)
|
| 55 |
+
|
| 56 |
+
def __gt__(self, other: Point) -> bool:
|
| 57 |
+
return (self.x, self.y) > (other.x, other.y)
|
| 58 |
+
|
| 59 |
+
def __add__(self, p: Point) -> Point:
|
| 60 |
+
return Point(self.x + p.x, self.y + p.y)
|
| 61 |
+
|
| 62 |
+
def __sub__(self, p: Point) -> Point:
|
| 63 |
+
return Point(self.x - p.x, self.y - p.y)
|
| 64 |
+
|
| 65 |
+
def __mul__(self, f: float) -> Point:
|
| 66 |
+
return Point(self.x * f, self.y * f)
|
| 67 |
+
|
| 68 |
+
def __rmul__(self, f: float) -> Point:
|
| 69 |
+
return self * f
|
| 70 |
+
|
| 71 |
+
def __truediv__(self, f: float) -> Point:
|
| 72 |
+
return Point(self.x / f, self.y / f)
|
| 73 |
+
|
| 74 |
+
def __floordiv__(self, f: float) -> Point:
|
| 75 |
+
div = self / f # true div
|
| 76 |
+
return Point(int(div.x), int(div.y))
|
| 77 |
+
|
| 78 |
+
def __str__(self) -> str:
|
| 79 |
+
return 'P({},{})'.format(self.x, self.y)
|
| 80 |
+
|
| 81 |
+
def close(self, point: Point, tol: float = 1e-12) -> bool:
|
| 82 |
+
return abs(self.x - point.x) < tol and abs(self.y - point.y) < tol
|
| 83 |
+
|
| 84 |
+
def midpoint(self, p: Point) -> Point:
|
| 85 |
+
return Point(0.5 * (self.x + p.x), 0.5 * (self.y + p.y))
|
| 86 |
+
|
| 87 |
+
def distance(self, p: Union[Point, Line, Circle]) -> float:
|
| 88 |
+
if isinstance(p, Line):
|
| 89 |
+
return p.distance(self)
|
| 90 |
+
if isinstance(p, Circle):
|
| 91 |
+
return abs(p.radius - self.distance(p.center))
|
| 92 |
+
dx = self.x - p.x
|
| 93 |
+
dy = self.y - p.y
|
| 94 |
+
return np.sqrt(dx * dx + dy * dy)
|
| 95 |
+
|
| 96 |
+
def distance2(self, p: Point) -> float:
|
| 97 |
+
if isinstance(p, Line):
|
| 98 |
+
return p.distance(self)
|
| 99 |
+
dx = self.x - p.x
|
| 100 |
+
dy = self.y - p.y
|
| 101 |
+
return dx * dx + dy * dy
|
| 102 |
+
|
| 103 |
+
def rotatea(self, ang: float) -> Point:
|
| 104 |
+
sinb, cosb = np.sin(ang), np.cos(ang)
|
| 105 |
+
return self.rotate(sinb, cosb)
|
| 106 |
+
|
| 107 |
+
def rotate(self, sinb: float, cosb: float) -> Point:
|
| 108 |
+
x, y = self.x, self.y
|
| 109 |
+
return Point(x * cosb - y * sinb, x * sinb + y * cosb)
|
| 110 |
+
|
| 111 |
+
def flip(self) -> Point:
|
| 112 |
+
return Point(-self.x, self.y)
|
| 113 |
+
|
| 114 |
+
def perpendicular_line(self, line: Line) -> Line:
|
| 115 |
+
return line.perpendicular_line(self)
|
| 116 |
+
|
| 117 |
+
def foot(self, line: Line) -> Point:
    """Project this point onto a Line or a Circle.

    For a Line, the result is where the perpendicular through this point
    meets the line. For a Circle, the point is moved along the ray from the
    center so that it lies at distance ``radius`` from the center.

    Raises:
      ValueError: if ``line`` is neither a Line nor a Circle.
    """
    if isinstance(line, Line):
        perpendicular = line.perpendicular_line(self)
        return line_line_intersection(perpendicular, line)

    if isinstance(line, Circle):
        center, radius = line.center, line.radius
        return center + (self - center) * radius / self.distance(center)

    raise ValueError('Dropping foot to weird type {}'.format(type(line)))
|
| 125 |
+
|
| 126 |
+
def parallel_line(self, line: Line) -> Line:
    """Return ``line.parallel_line(self)``: the parallel to ``line`` through this point."""
    parallel = line.parallel_line(self)
    return parallel
|
| 128 |
+
|
| 129 |
+
def norm(self) -> float:
    """Return the Euclidean length of the vector ``(x, y)``."""
    squared_length = self.x**2 + self.y**2
    return np.sqrt(squared_length)
|
| 131 |
+
|
| 132 |
+
def cos(self, other: Point) -> float:
    """Return the cosine of the angle between this vector and ``other``.

    Computed as the dot product divided by both norms; no guard is applied
    for zero-length vectors.
    """
    inner = self.x * other.x + self.y * other.y
    return inner / self.norm() / other.norm()
|
| 136 |
+
|
| 137 |
+
def dot(self, other: Point) -> float:
    """Return the dot product of this vector with ``other``."""
    x_term = self.x * other.x
    y_term = self.y * other.y
    return x_term + y_term
|
| 139 |
+
|
| 140 |
+
def sign(self, line: Line) -> int:
    """Return ``line.sign(self)``, i.e. which side of ``line`` this point is on."""
    side = line.sign(self)
    return side
|
| 142 |
+
|
| 143 |
+
def is_same(self, other: Point) -> bool:
    """Return True when ``other`` lies within ``ATOM`` of this point."""
    separation = self.distance(other)
    return separation <= ATOM
|
| 145 |
+
|
| 146 |
+
|
| 147 |
+
class Line:
    """Numerical line ``a*x + b*y + c = 0``, stored as ``self.coefficients``."""

    def __init__(
        self,
        p1: Point = None,
        p2: Point = None,
        coefficients: tuple[float, float, float] = None,
    ):
        """Build the line through ``p1`` and ``p2`` or from explicit coefficients.

        With no arguments at all, ``coefficients`` becomes ``(None, None, None)``.
        Explicit ``coefficients`` take precedence over the two points.
        """
        if p1 is None and p2 is None and coefficients is None:
            self.coefficients = None, None, None
            return

        a, b, c = coefficients or (
            p1.y - p2.y,
            p2.x - p1.x,
            p1.x * p2.y - p2.x * p1.y,
        )

        # Make sure a is always positive (or always negative for that matter)
        # With a == 0, Assuming a = +epsilon > 0
        # Then b such that ax + by = 0 with y>0 should be negative.
        if a < 0.0 or a == 0.0 and b > 0.0:
            a, b, c = -a, -b, -c

        self.coefficients = a, b, c

    def parallel_line(self, p: Point) -> Line:
        """Return the line through ``p`` with the same (a, b) as this line."""
        a, b, _ = self.coefficients
        return Line(coefficients=(a, b, -a * p.x - b * p.y))  # pylint: disable=invalid-unary-operand-type

    def perpendicular_line(self, p: Point) -> Line:
        """Return the line through ``p`` along this line's normal ``(a, b)``."""
        a, b, _ = self.coefficients
        return Line(p, p + Point(a, b))

    def greater_than(self, other: Line) -> bool:
        """Compare directions: True when ``b * x > a * y``."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        # b/a > y/x
        return b * x > a * y

    def __gt__(self, other: Line) -> bool:
        return self.greater_than(other)

    def __lt__(self, other: Line) -> bool:
        return other.greater_than(self)

    def same(self, other: Line) -> bool:
        """Return True when both coefficient triples describe the same line.

        Two lines coincide iff their coefficient vectors are proportional,
        i.e. all three 2x2 minors vanish. The (a, c) minor is required as
        well: with only the (a, b) and (b, c) minors, any two vertical lines
        (b == 0 on both) would wrongly compare as identical.
        """
        a, b, c = self.coefficients
        x, y, z = other.coefficients
        return (
            close_enough(a * y, b * x)
            and close_enough(b * z, c * y)
            and close_enough(a * z, c * x)
        )

    def equal(self, other: Line) -> bool:
        """Exact direction equality: ``b * x == a * y``."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        # b/a == y/x
        return b * x == a * y

    def less_than(self, other: Line) -> bool:
        """Compare directions: True when ``b * x < a * y``."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        # b/a < y/x
        return b * x < a * y

    def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]:
        """Intersect with a Line or Circle; other types fall through to None."""
        if isinstance(obj, Line):
            return line_line_intersection(self, obj)
        if isinstance(obj, Circle):
            return line_circle_intersection(self, obj)

    def distance(self, p: Point) -> float:
        """Return the perpendicular distance from ``p`` to this line."""
        a, b, _ = self.coefficients
        return abs(self(p.x, p.y)) / math.sqrt(a * a + b * b)

    def __call__(self, x: Point, y: Point = None) -> float:
        """Evaluate ``a*x + b*y + c``; accepts either a Point or two numbers."""
        if isinstance(x, Point) and y is None:
            return self(x.x, x.y)
        a, b, c = self.coefficients
        return x * a + y * b + c

    def is_parallel(self, other: Line) -> bool:
        """True when the cross product of the two normals is below ``ATOM``."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        return abs(a * y - b * x) < ATOM

    def is_perp(self, other: Line) -> bool:
        """True when the dot product of the two normals is below ``ATOM``."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        return abs(a * x + b * y) < ATOM

    def cross(self, other: Line) -> float:
        """Return the cross product ``a*y - b*x`` of the two normals."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        return a * y - b * x

    def dot(self, other: Line) -> float:
        """Return the dot product ``a*x + b*y`` of the two normals."""
        a, b, _ = self.coefficients
        x, y, _ = other.coefficients
        return a * x + b * y

    def point_at(self, x: float = None, y: float = None) -> Optional[Point]:
        """Return the point on the line with the given ``x`` and/or ``y``.

        With only ``y`` given, solve for x (None if a == 0); with only ``x``
        given, solve for y (None if b == 0). With both given, return the point
        only if it satisfies the line equation exactly. With neither given,
        return None.
        """
        a, b, c = self.coefficients
        # ax + by + c = 0
        if x is None and y is not None:
            if a != 0:
                return Point((-c - b * y) / a, y)  # pylint: disable=invalid-unary-operand-type
            else:
                return None
        elif x is not None and y is None:
            if b != 0:
                return Point(x, (-c - a * x) / b)  # pylint: disable=invalid-unary-operand-type
            else:
                return None
        elif x is not None and y is not None:
            if a * x + b * y + c == 0.0:
                return Point(x, y)
        return None

    def diff_side(self, p1: Point, p2: Point) -> Optional[bool]:
        """True if p1 and p2 are on opposite sides; None if either is on the line."""
        d1 = self(p1.x, p1.y)
        d2 = self(p2.x, p2.y)
        if d1 == 0 or d2 == 0:
            return None
        return d1 * d2 < 0

    def same_side(self, p1: Point, p2: Point) -> Optional[bool]:
        """True if p1 and p2 are on the same side; None if either is on the line."""
        d1 = self(p1.x, p1.y)
        d2 = self(p2.x, p2.y)
        if d1 == 0 or d2 == 0:
            return None
        return d1 * d2 > 0

    def sign(self, point: Point) -> int:
        """Return 1, -1 or 0 according to the sign of the line equation at ``point``."""
        s = self(point.x, point.y)
        if s > 0:
            return 1
        elif s < 0:
            return -1
        return 0

    def is_same(self, other: Line) -> bool:
        """Like ``same`` but with the ``ATOM`` tolerance on each minor.

        The (a, c) minor is included so that two distinct vertical lines
        (b == 0 on both) are not reported as the same line.
        """
        a, b, c = self.coefficients
        x, y, z = other.coefficients
        return (
            abs(a * y - b * x) <= ATOM
            and abs(b * z - c * y) <= ATOM
            and abs(a * z - c * x) <= ATOM
        )

    def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
        """Sample a point within the boundary of points.

        Draws ``n`` random candidates on the chord cut out by a circle around
        the centroid of ``points`` and keeps the one whose nearest neighbour in
        ``points`` is farthest away. Returns it as a one-element list.
        """
        center = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
        radius = max([p.distance(center) for p in points])
        if close_enough(center.distance(self), radius):
            center = center.foot(self)
        a, b = line_circle_intersection(self, Circle(center.foot(self), radius))

        result = None
        best = -1.0
        for _ in range(n):
            rand = unif(0.0, 1.0)
            x = a + (b - a) * rand
            mind = min([x.distance(p) for p in points])
            if mind > best:
                best = mind
                result = x

        return [result]
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class InvalidLineIntersectError(Exception):
    """Raised when a line intersection cannot be computed."""
|
| 316 |
+
|
| 317 |
+
|
| 318 |
+
class HalfLine(Line):
    """Numerical ray starting at ``tail`` and passing through ``head``."""

    def __init__(self, tail: Point, head: Point):  # pylint: disable=super-init-not-called
        """Store the supporting line, its coefficients and both endpoints."""
        support = Line(tail, head)
        self.line = support
        self.coefficients = support.coefficients
        self.tail = tail
        self.head = head

    def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point:
        """Intersect the ray with a line-like or circle-like object.

        Lines and rays are intersected as full lines. For circles, a hit that
        coincides with the tail (or with the hole of a HoleCircle) is
        discarded in favour of the other one; otherwise the first hit lying
        in the ray's direction is returned.

        Raises:
          InvalidLineIntersectError: if neither circle hit lies ahead of the tail.
        """
        if isinstance(obj, (HalfLine, Line)):
            return line_line_intersection(self.line, obj)

        excluded = [self.tail]
        if isinstance(obj, HoleCircle):
            excluded.append(obj.hole)

        first, second = line_circle_intersection(self.line, obj)
        if any(first.close(pt) for pt in excluded):
            return second
        if any(second.close(pt) for pt in excluded):
            return first

        direction = self.head - self.tail
        to_first = first - self.tail
        to_second = second - self.tail
        if direction.dot(to_first) > 0:
            return first
        if direction.dot(to_second) > 0:
            return second
        raise InvalidLineIntersectError()

    def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
        """Sample a point on the ray within the extent of ``points``.

        Candidates are drawn between the tail and the circle hit that lies in
        the ray's direction; the candidate farthest from its nearest neighbour
        in ``points`` is returned as a one-element list.
        """
        centroid = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
        reach = max(p.distance(centroid) for p in points)
        if close_enough(centroid.distance(self.line), reach):
            centroid = centroid.foot(self)
        hit_a, hit_b = line_circle_intersection(
            self, Circle(centroid.foot(self), reach)
        )

        # Keep the hit on the ray's side of the tail as the far end.
        if (hit_a - self.tail).dot(self.head - self.tail) > 0:
            far_end = hit_a
        else:
            far_end = hit_b
        start = self.tail

        chosen = None
        best_gap = -1.0
        for _ in range(n):
            candidate = start + (far_end - start) * unif(0.0, 1.0)
            nearest = min(candidate.distance(p) for p in points)
            if nearest > best_gap:
                best_gap = nearest
                chosen = candidate

        return [chosen]
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
def _perpendicular_bisector(p1: Point, p2: Point) -> Line:
    """Return the line through the midpoint of p1p2, perpendicular to p1p2."""
    middle = (p1 + p2) * 0.5
    normal_step = Point(p2.y - p1.y, p1.x - p2.x)
    return Line(middle, middle + normal_step)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def same_sign(
    a: Point, b: Point, c: Point, d: Point, e: Point, f: Point
) -> bool:
    """Return True when angles abc and def have the same orientation.

    Works on the ``.sym`` attribute of each argument and compares the signs
    of the two cross products (strictly positive product required).
    """
    a, b, c, d, e, f = [pt.sym for pt in (a, b, c, d, e, f)]
    ab, cb = a - b, c - b
    de, fe = d - e, f - e
    cross_abc = ab.x * cb.y - ab.y * cb.x
    cross_def = de.x * fe.y - de.y * fe.x
    return cross_abc * cross_def > 0
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class Circle:
    """Numerical circle described by ``center``, ``radius`` and ``r2`` (radius squared)."""

    def __init__(
        self,
        center: Optional[Point] = None,
        radius: Optional[float] = None,
        p1: Optional[Point] = None,
        p2: Optional[Point] = None,
        p3: Optional[Point] = None,
    ):
        """Build a circle from center/radius and/or points on it.

        Without a center, three points are needed and the center is the
        intersection of two perpendicular bisectors; if they are not all
        given, the circle is left empty (center, radius and r2 set to None).
        Without a radius, it is measured from the center to the first of
        p1, p2, p3 that is provided.

        Raises:
          ValueError: if a center is known but neither radius nor any point is given.
        """
        if not center:
            if not p1 or not p2 or not p3:
                self.center = None
                self.radius = None
                self.r2 = None
                return
            # raise ValueError('Circle without center need p1 p2 p3')

            bisector_12 = _perpendicular_bisector(p1, p2)
            bisector_23 = _perpendicular_bisector(p2, p3)
            center = line_line_intersection(bisector_12, bisector_23)

        self.center = center
        self.a, self.b = center.x, center.y

        if radius:
            self.radius = radius
            self.r2 = radius * radius
            return

        on_circle = p1 or p2 or p3
        if not on_circle:
            raise ValueError('Circle needs radius or p1 or p2 or p3')
        self.r2 = (self.a - on_circle.x) ** 2 + (self.b - on_circle.y) ** 2
        self.radius = math.sqrt(self.r2)

    def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]:
        """Intersect with a Line or another Circle; other types yield None."""
        if isinstance(obj, Line):
            return obj.intersect(self)
        if isinstance(obj, Circle):
            return circle_circle_intersection(self, obj)

    def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
        """Sample a point on the circle away from ``points``.

        Draws ``n`` random angles and keeps the candidate whose nearest
        neighbour in ``points`` is farthest; returns it as a one-element list.
        """
        chosen = None
        best_gap = -1.0
        for _ in range(n):
            ang = unif(0.0, 2.0) * np.pi
            candidate = self.center + Point(np.cos(ang), np.sin(ang)) * self.radius
            nearest = min(candidate.distance(p) for p in points)
            if nearest > best_gap:
                best_gap = nearest
                chosen = candidate

        return [chosen]
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
class HoleCircle(Circle):
    """Numerical circle with a missing point."""

    def __init__(self, center: Point, radius: float, hole: Point):
        """Build the circle and remember ``hole``, the point to exclude."""
        super().__init__(center, radius)
        self.hole = hole

    def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point:
        """Return one intersection with ``obj`` that is not a hole.

        Subclasses are tested before their base classes: HalfLine derives
        from Line and HoleCircle from Circle, so checking the base first (as
        the previous code did) made the HalfLine and HoleCircle branches
        unreachable and ignored the ray direction and ``obj.hole``.

        Returns None for unsupported types.
        """
        if isinstance(obj, HalfLine):
            # The ray knows how to skip its tail and this circle's hole.
            return obj.intersect(self)
        if isinstance(obj, Line):
            a, b = line_circle_intersection(obj, self)
            if a.close(self.hole):
                return b
            return a
        if isinstance(obj, HoleCircle):
            a, b = circle_circle_intersection(obj, self)
            if a.close(self.hole) or a.close(obj.hole):
                return b
            return a
        if isinstance(obj, Circle):
            a, b = circle_circle_intersection(obj, self)
            if a.close(self.hole):
                return b
            return a
|
| 468 |
+
|
| 469 |
+
|
| 470 |
+
def solve_quad(a: float, b: float, c: float) -> tuple[float, float]:
    """Solve a x^2 + bx + c = 0.

    Returns the two real roots, smaller-sign branch first, or None when the
    discriminant is negative (callers must check for None).
    """
    discriminant = b * b - 4 * a * c
    if discriminant < 0:
        return None  # the caller should expect this result.

    root = math.sqrt(discriminant)
    denominator = 2 * a
    return (-b - root) / denominator, (-b + root) / denominator
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def circle_circle_intersection(c1: Circle, c2: Circle) -> tuple[Point, Point]:
    """Returns a pair of Points as intersections of c1 and c2.

    Raises:
      InvalidQuadSolveError: if the centers coincide or the circles do not meet.
    """
    # circle 1: (x0, y0), radius r0
    # circle 2: (x1, y1), radius r1
    x0, y0, r0 = c1.a, c1.b, c1.radius
    x1, y1, r1 = c2.a, c2.b, c2.radius
    dx, dy = x1 - x0, y1 - y0

    d = math.sqrt(dx ** 2 + dy ** 2)
    if d == 0:
        raise InvalidQuadSolveError()

    # Distance from center 1 to the foot of the common chord.
    a = (r0**2 - r1**2 + d**2) / (2 * d)
    h_squared = r0**2 - a**2
    if h_squared < 0:
        raise InvalidQuadSolveError()
    h = np.sqrt(h_squared)

    # Foot of the chord on the center line, then step +-h along the chord.
    foot_x = x0 + a * dx / d
    foot_y = y0 + a * dy / d
    step_x = h * dy / d
    step_y = h * dx / d

    return (
        Point(foot_x + step_x, foot_y - step_y),
        Point(foot_x - step_x, foot_y + step_y),
    )
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class InvalidQuadSolveError(Exception):
    """Raised when a circle intersection has no real solution."""
|
| 509 |
+
|
| 510 |
+
|
| 511 |
+
def line_circle_intersection(line: Line, circle: Circle) -> tuple[Point, Point]:
    """Returns a pair of points as intersections of line and circle.

    Vertical (b == 0) and horizontal (a == 0) lines are handled separately;
    otherwise x is eliminated and a quadratic in y is solved.

    Raises:
      InvalidQuadSolveError: if the quadratic has no real root.
    """
    a, b, c = line.coefficients
    r = float(circle.radius)
    center = circle.center
    p, q = center.x, center.y

    def _real_roots(qa, qb, qc):
        # solve_quad signals "no real root" with None; turn that into an error.
        roots = solve_quad(qa, qb, qc)
        if roots is None:
            raise InvalidQuadSolveError()
        return roots

    if b == 0:
        x = -c / a
        x_p = x - p
        x_p2 = x_p * x_p
        y1, y2 = _real_roots(1, -2 * q, q * q + x_p2 - r * r)
        return (Point(x, y1), Point(x, y2))

    if a == 0:
        y = -c / b
        y_q = y - q
        y_q2 = y_q * y_q
        x1, x2 = _real_roots(1, -2 * p, p * p + y_q2 - r * r)
        return (Point(x1, y), Point(x2, y))

    c_ap = c + a * p
    a2 = a * a
    y1, y2 = _real_roots(
        a2 + b * b, 2 * (b * c_ap - a2 * q), c_ap * c_ap + a2 * (q * q - r * r)
    )

    return Point(-(b * y1 + c) / a, y1), Point(-(b * y2 + c) / a, y2)
|
| 548 |
+
|
| 549 |
+
|
| 550 |
+
def _check_between(a: Point, b: Point, c: Point) -> bool:
    """Whether a is between b & c."""
    ahead_of_b = (a - b).dot(c - b) > 0
    return ahead_of_b and (a - c).dot(b - c) > 0
|
| 553 |
+
|
| 554 |
+
|
| 555 |
+
def circle_segment_intersect(
    circle: Circle, p1: Point, p2: Point
) -> list[Point]:
    """Return the intersections of ``circle`` with line p1p2 that lie between p1 and p2."""
    carrier = Line(p1, p2)
    hits = line_circle_intersection(carrier, circle)
    first, second = hits
    return [hit for hit in (first, second) if _check_between(hit, p1, p2)]
|
| 567 |
+
|
| 568 |
+
|
| 569 |
+
def line_segment_intersection(l: Line, A: Point, B: Point) -> Point:  # pylint: disable=invalid-name
    """Return the point where ``l`` meets the line through A and B.

    The point is expressed as ``A + alpha * (B - A)``; ``alpha`` is not
    clamped to [0, 1] and no guard is applied when AB is parallel to ``l``.
    """
    a, b, c = l.coefficients
    start_x, start_y = A.x, A.y
    dx = B.x - start_x
    dy = B.y - start_y
    alpha = (-c - a * start_x - b * start_y) / (a * dx + b * dy)
    return Point(start_x + alpha * dx, start_y + alpha * dy)
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
def line_line_intersection(l1: Line, l2: Line) -> Point:
    """Return the intersection point of two lines.

    Raises:
      InvalidLineIntersectError: if the determinant is exactly zero.
    """
    # a1x + b1y + c1 = 0
    # a2x + b2y + c2 = 0
    a1, b1, c1 = l1.coefficients
    a2, b2, c2 = l2.coefficients
    determinant = a1 * b2 - a2 * b1
    if determinant == 0:
        raise InvalidLineIntersectError
    x = (c2 * b1 - c1 * b2) / determinant
    y = (c1 * a2 - c2 * a1) / determinant
    return Point(x, y)
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
def check_too_close(
    newpoints: list[Point], points: list[Point], tol: float = 0.1
) -> bool:
    """Return True if any new point is too close to an existing point.

    Args:
      newpoints: candidate points.
      points: existing points; an empty list always yields False.
      tol: fraction of the reference length below which two points count as
        too close (a float; the previous ``int`` annotation was wrong).

    The reference length is the smallest distance from the centroid of
    ``points`` to any of them.
    """
    if not points:
        return False
    avg = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
    mindist = min(p.distance(avg) for p in points)
    threshold = tol * mindist
    return any(
        p0.distance(p1) < threshold for p0 in newpoints for p1 in points
    )
|
| 600 |
+
|
| 601 |
+
|
| 602 |
+
def check_too_far(
    newpoints: list[Point], points: list[Point], tol: int = 4
) -> bool:
    """Return True if any new point is far outside the existing point cloud.

    "Far" means farther from the centroid of ``points`` than ``tol`` times the
    largest centroid-to-point distance. Fewer than two points yields False.
    """
    if len(points) < 2:
        return False
    centroid = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
    spread = max(p.distance(centroid) for p in points)
    return any(p.distance(centroid) > spread * tol for p in newpoints)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def check_aconst(args: list[Point]) -> bool:
    """Check that an angle equals ``num * pi / den``.

    ``args`` is ``(a, b, c, d, num, den)``; ``d`` is translated by ``a - c``
    and the angle is measured with ``ang_between(a, b, d)``, with negative
    values shifted up by pi.
    """
    a, b, c, d, num, den = args
    shifted_d = d + a - c
    angle = ang_between(a, b, shifted_d)
    if angle < 0:
        angle += np.pi
    target = num * np.pi / den
    return close_enough(angle, target)
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def check(name: str, args: list[Union[gm.Point, Point]]) -> bool:
    """Numerical check.

    Normalises a few predicate aliases, returns True outright for predicates
    with nothing to verify, and otherwise dispatches to the module-level
    ``check_<name>`` function. Returns None when no such function exists.
    """
    aliases = {
        'eqangle6': 'eqangle',
        'eqratio6': 'eqratio',
        'simtri2': 'simtri',
        'simtri*': 'simtri',
        'contri2': 'contri',
        'contri*': 'contri',
        'para': 'para_or_coll',
        'on_line': 'coll',
    }
    trivially_true = (
        'rcompute', 'acompute', 'fixl', 'fixc', 'fixb', 'fixt', 'fixp',
    )

    if name in aliases:
        name = aliases[name]
    elif name in trivially_true:
        return True

    fn_name = 'check_' + name
    module_globals = globals()
    if fn_name not in module_globals:
        return None

    checker = module_globals[fn_name]
    # gm.Point arguments carry their numerical counterpart in ``.num``.
    numeric_args = [p.num if isinstance(p, gm.Point) else p for p in args]
    return checker(numeric_args)
|
| 650 |
+
|
| 651 |
+
|
| 652 |
+
def check_circle(points: list[Point]) -> bool:
    """Check that the first point is equidistant from the other three.

    Returns False unless exactly four points are given.
    """
    if len(points) != 4:
        return False
    center, a, b, c = points
    dist_a = center.distance(a)
    dist_b = center.distance(b)
    dist_c = center.distance(c)
    return close_enough(dist_a, dist_b) and close_enough(dist_b, dist_c)
|
| 658 |
+
|
| 659 |
+
|
| 660 |
+
def check_coll(points: list[Point]) -> bool:
    """Check that every point lies on the line through the first two.

    A point is accepted when the line equation evaluated at it is at most
    ``ATOM`` in absolute value.
    """
    first, second = points[:2]
    carrier = Line(first, second)
    return not any(abs(carrier(p.x, p.y)) > ATOM for p in points[2:])
|
| 667 |
+
|
| 668 |
+
|
| 669 |
+
def check_ncoll(points: list[Point]) -> bool:
    """Negation of ``check_coll``: True when the points are not all collinear."""
    collinear = check_coll(points)
    return not collinear
|
| 671 |
+
|
| 672 |
+
|
| 673 |
+
def check_sameside(points: list[Point]) -> bool:
    """Whether b sits relative to a & c the same way y sits relative to x & z.

    Compares the signs of the dot products (b-a).(b-c) and (y-x).(y-z); the
    product must be strictly positive.
    """
    b, a, c, y, x, z = points
    first = (b - a).dot(b - c)
    second = (y - x).dot(y - z)
    return first * second > 0
|
| 681 |
+
|
| 682 |
+
|
| 683 |
+
def check_para_or_coll(points: list[Point]) -> bool:
    """True when the points pass ``check_para`` or, failing that, ``check_coll``."""
    if check_para(points):
        return True
    return check_coll(points)
|
| 685 |
+
|
| 686 |
+
|
| 687 |
+
def check_para(points: list[Point]) -> bool:
    """Check that lines ab and cd are parallel but not the same line."""
    a, b, c, d = points
    line_ab = Line(a, b)
    line_cd = Line(c, d)
    # Coincident lines are deliberately not counted as parallel.
    return False if line_ab.same(line_cd) else line_ab.is_parallel(line_cd)
|
| 694 |
+
|
| 695 |
+
|
| 696 |
+
def check_perp(points: list[Point]) -> bool:
    """Check that line ab is perpendicular to line cd."""
    a, b, c, d = points
    return Line(a, b).is_perp(Line(c, d))
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def check_cyclic(points: list[Point]) -> bool:
    """Check that all points lie on one circle.

    Duplicates are dropped (keeping first-seen order, so the reference circle
    does not depend on hash ordering), the circle through the first three
    distinct points is built, and every remaining point must be at distance
    ``radius`` from its center.

    The previous code unpacked with ``(a, b, c), *ps = points``, which tries
    to split the *first point* into three values instead of taking the first
    three points.
    """
    distinct = list(dict.fromkeys(points))
    a, b, c, *rest = distinct
    circle = Circle(p1=a, p2=b, p3=c)
    return all(
        close_enough(d.distance(circle.center), circle.radius) for d in rest
    )
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def bring_together(
    a: Point, b: Point, c: Point, d: Point
) -> tuple[Point, Point, Point, Point]:
    """Re-anchor lines ab and cd at their common intersection.

    Returns ``(x, y, x, z)`` where ``x`` is the intersection of the two lines
    and ``y``/``z`` are the first intersections of ab/cd with the unit circle
    centered at ``x``.
    """
    line_ab = Line(a, b)
    line_cd = Line(c, d)
    vertex = line_line_intersection(line_ab, line_cd)
    unit_circle = Circle(center=vertex, radius=1.0)
    on_ab, _ = line_circle_intersection(line_ab, unit_circle)
    on_cd, _ = line_circle_intersection(line_cd, unit_circle)
    return vertex, on_ab, vertex, on_cd
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
def same_clock(
    a: Point, b: Point, c: Point, d: Point, e: Point, f: Point
) -> bool:
    """Return True when a->b->c and d->e->f turn in the same direction.

    Compares the signs of the two cross products; the product must be
    strictly positive.
    """
    ba, cb = b - a, c - b
    ed, fe = e - d, f - e
    turn_abc = ba.x * cb.y - ba.y * cb.x
    turn_def = ed.x * fe.y - ed.y * fe.x
    return turn_abc * turn_def > 0
|
| 733 |
+
|
| 734 |
+
|
| 735 |
+
def check_const_angle(points: list[Point]) -> bool:
    """Check if the angle is equal to the given constant.

    ``points`` is ``(a, b, c, d, m, n)``; the angle between ab and cd, as a
    fraction of pi modulo 1, is compared with ``m / n`` modulo 1.
    """
    a, b, c, d, m, n = points
    a, b, c, d = bring_together(a, b, c, d)
    dir_ab = b - a
    dir_cd = d - c

    angle_ab = np.arctan2(dir_ab.y, dir_ab.x)
    angle_cd = np.arctan2(dir_cd.y, dir_cd.x)
    measured = angle_ab - angle_cd

    return close_enough(m / n % 1, measured / np.pi % 1)
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
def check_eqangle(points: list[Point]) -> bool:
    """Check if 8 points make 2 equal angles.

    The angles are between lines ab/cd and ef/gh. If either pair is parallel
    the other must be too. Otherwise both pairs are re-anchored at their
    intersections and the directed angles are compared modulo 2*pi.
    """
    a, b, c, d, e, f, g, h = points

    line_ab, line_cd = Line(a, b), Line(c, d)
    line_ef, line_gh = Line(e, f), Line(g, h)

    if line_ab.is_parallel(line_cd):
        return line_ef.is_parallel(line_gh)
    if line_ef.is_parallel(line_gh):
        return line_ab.is_parallel(line_cd)

    a, b, c, d = bring_together(a, b, c, d)
    e, f, g, h = bring_together(e, f, g, h)

    ba, dc = b - a, d - c
    fe, hg = f - e, h - g

    # Flip one direction when the two angles have opposite orientation.
    turn_first = ba.x * dc.y - ba.y * dc.x
    turn_second = fe.x * hg.y - fe.y * hg.x
    if not turn_first * turn_second > 0:
        ba = ba * -1.0

    second_angle = np.arctan2(fe.y, fe.x) - np.arctan2(hg.y, hg.x)
    first_angle = np.arctan2(ba.y, ba.x) - np.arctan2(dc.y, dc.x)

    difference = (second_angle - first_angle) % (2 * np.pi)
    return close_enough(difference, 0, tol=1e-11) or close_enough(
        difference, 2 * np.pi, tol=1e-11
    )
|
| 787 |
+
|
| 788 |
+
|
| 789 |
+
def check_eqratio(points: list[Point]) -> bool:
    """Check that |ab| / |cd| equals |ef| / |gh| via cross-multiplication."""
    a, b, c, d, e, f, g, h = points
    len_ab, len_cd = a.distance(b), c.distance(d)
    len_ef, len_gh = e.distance(f), g.distance(h)
    return close_enough(len_ab * len_gh, len_cd * len_ef)
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
def check_cong(points: list[Point]) -> bool:
    """Check that segments ab and cd have the same length."""
    a, b, c, d = points
    len_ab = a.distance(b)
    len_cd = c.distance(d)
    return close_enough(len_ab, len_cd)
|
| 801 |
+
|
| 802 |
+
|
| 803 |
+
def check_midp(points: list[Point]) -> bool:
    """Check that a is the midpoint of bc: collinear and equidistant from b and c."""
    a, b, c = points
    if not check_coll(points):
        return False
    return close_enough(a.distance(b), a.distance(c))
|
| 806 |
+
|
| 807 |
+
|
| 808 |
+
def check_simtri(points: list[Point]) -> bool:
    """Check if 6 points make a pair of similar triangles.

    Compares side ratios of abc and xyz by cross-multiplication with a
    tolerance of 1e-9.
    """
    a, b, c, x, y, z = points
    side_ab, side_bc, side_ca = a.distance(b), b.distance(c), c.distance(a)
    side_xy, side_yz, side_zx = x.distance(y), y.distance(z), z.distance(x)
    tol = 1e-9
    return close_enough(side_ab * side_yz, side_bc * side_xy, tol) and close_enough(
        side_bc * side_zx, side_ca * side_yz, tol
    )
|
| 821 |
+
|
| 822 |
+
|
| 823 |
+
def check_contri(points: list[Point]) -> bool:
    """Check that triangles abc and xyz have pairwise equal sides (tolerance 1e-9)."""
    a, b, c, x, y, z = points
    side_ab, side_bc, side_ca = a.distance(b), b.distance(c), c.distance(a)
    side_xy, side_yz, side_zx = x.distance(y), y.distance(z), z.distance(x)
    tol = 1e-9
    return (
        close_enough(side_ab, side_xy, tol)
        and close_enough(side_bc, side_yz, tol)
        and close_enough(side_ca, side_zx, tol)
    )
|
| 837 |
+
|
| 838 |
+
|
| 839 |
+
def check_ratio(points: list[Point]) -> bool:
    """Check that |ab| / |cd| equals m / n, given ``(a, b, c, d, m, n)``."""
    a, b, c, d, m, n = points
    len_ab = a.distance(b)
    len_cd = c.distance(d)
    return close_enough(len_ab * n, len_cd * m)
|
| 844 |
+
|
| 845 |
+
|
| 846 |
+
def draw_angle(
    ax: matplotlib.axes.Axes,
    head: Point,
    p1: Point,
    p2: Point,
    color: Any = 'red',
    alpha: float = 0.5,
    frac: float = 1.0,
) -> None:
    """Draw an angle on plt ax.

    Adds a wedge at ``head`` spanning the smaller angle between the rays to
    ``p1`` and ``p2``. The wedge radius is randomised and scaled so that
    narrow angles get a larger wedge.
    """

    def _bearing(vec):
        # Direction of ``vec`` in degrees, reduced to [0, 360).
        return np.arctan2(float(vec.y), float(vec.x)) * 180 / np.pi % 360

    start = _bearing(p1 - head)
    end = _bearing(p2 - head)

    if start > end:
        start, end = end, start

    # Sweep the short way round: swap if the span exceeds half a turn.
    if end - start > 180:
        start, end = end, start

    if start > end:
        sweep = (end + 360) - start
    else:
        sweep = end - start
    # if d >= 90:
    #   return

    scale = max(min(2.0, 90 / sweep), 0.4)
    wedge = matplotlib.patches.Wedge(
        (float(head.x), float(head.y)),
        unif(0.075, 0.125) * scale * frac,
        start,
        end,
        color=color,
        alpha=alpha,
    )
    ax.add_artist(wedge)
|
| 888 |
+
|
| 889 |
+
|
| 890 |
+
def naming_position(
    ax: matplotlib.axes.Axes,
    p: Point,
    lines: list[tuple[Point, Point]],
    circles: list[Circle],
) -> tuple[float, float]:
    """Figure out a good naming position on the drawing.

    Args:
      ax: unused; kept for interface compatibility.
      p: the point to label.
      lines: segments given as (endpoint, endpoint) pairs.
      circles: circles drawn in the figure.

    Returns:
      An ``(x, y)`` tuple. Without nearby geometry this is a small fixed
      offset from ``p`` (previously returned as a list, unlike the other
      branch); otherwise the label goes in the middle of the widest angular
      gap between the places where geometry crosses a small circle around ``p``.
    """
    _ = ax
    r = 0.08
    probe = Circle(center=p, radius=r)
    avoid = []
    for p1, p2 in lines:
        try:
            avoid.extend(circle_segment_intersect(probe, p1, p2))
        except InvalidQuadSolveError:
            continue
    for other in circles:
        try:
            avoid.extend(circle_circle_intersection(probe, other))
        except InvalidQuadSolveError:
            continue

    if not avoid:
        return p.x + 0.01, p.y + 0.01

    angs = sorted(ang_of(p, a) for a in avoid)
    angs.append(angs[0] + 2 * np.pi)  # close the loop around the circle
    gaps = [(nxt - cur, cur) for cur, nxt in zip(angs, angs[1:])]

    width, start = max(gaps)
    ang = start + width / 2

    name_pos = p + Point(np.cos(ang), np.sin(ang)) * r

    return name_pos.x - r / 1.5, name_pos.y - r / 1.5
|
| 923 |
+
|
| 924 |
+
|
| 925 |
+
def draw_point(
    ax: matplotlib.axes.Axes,
    p: Point,
    name: str,
    lines: list[Line],
    circles: list[Circle],
    color: Any = 'white',
    size: float = 15,
) -> None:
    """draw a point.

    Scatters the point, then annotates it with its upper-cased name (an
    underscore is inserted after the first character of longer names) at the
    position chosen by ``naming_position``.
    """
    ax.scatter(p.x, p.y, color=color, s=size)

    label_color = 'lightgreen' if color == 'white' else 'grey'

    label = name.upper()
    if len(label) > 1:
        label = label[0] + '_' + label[1:]

    position = naming_position(ax, p, lines, circles)
    ax.annotate(label, position, color=label_color, fontsize=15)
|
| 949 |
+
|
| 950 |
+
|
| 951 |
+
def _draw_line(
    ax: matplotlib.axes.Axes,
    p1: Point,
    p2: Point,
    color: Any = 'white',
    lw: float = 1.2,
    alpha: float = 0.8,
) -> None:
    """Draw a line in matplotlib.

    Passing ``color='--'`` requests a black dashed line.
    """
    dashed = color == '--'
    line_style = '--' if dashed else '-'
    if dashed:
        color = 'black'

    xs = (p1.x, p2.x)
    ys = (p1.y, p2.y)
    ax.plot(xs, ys, color=color, lw=lw, alpha=alpha, ls=line_style)
|
| 967 |
+
|
| 968 |
+
|
| 969 |
+
def draw_line(
    ax: matplotlib.axes.Axes, line: Line, color: Any = 'white'
) -> tuple[Point, Point]:
    """Draw a line.

    Takes the ``gm.Point`` neighbors of ``line``, finds the two extreme
    points along the direction given by the first two, draws the segment
    between them and returns that pair. Returns None when the line has at
    most one point.
    """
    neighbors = line.neighbors(gm.Point)
    if len(neighbors) <= 1:
        return

    nums = [p.num for p in neighbors]
    origin, second = nums[:2]

    # Track the extremes by their projection onto (second - origin).
    low_pt, low_val = origin, 0.0
    high_pt, high_val = second, (second - origin).dot(second - origin)

    for p in nums[2:]:
        projection = (p - origin).dot(second - origin)
        if projection < low_val:
            low_pt, low_val = p, projection
        if projection > high_val:
            high_pt, high_val = p, projection

    _draw_line(ax, low_pt, high_pt, color=color)
    return low_pt, high_pt
|
| 992 |
+
|
| 993 |
+
|
| 994 |
+
def _draw_circle(
    ax: matplotlib.axes.Axes, c: Circle, color: Any = 'cyan', lw: float = 1.2
) -> None:
    """Add an unfilled circle patch for ``c``; ``color='--'`` means black dashed."""
    dashed = color == '--'
    line_style = '--' if dashed else '-'
    if dashed:
        color = 'black'

    patch = plt.Circle(
        (c.center.x, c.center.y),
        c.radius,
        color=color,
        alpha=0.8,
        fill=False,
        lw=lw,
        ls=line_style,
    )
    ax.add_patch(patch)
|
| 1013 |
+
|
| 1014 |
+
|
| 1015 |
+
def draw_circle(
    ax: matplotlib.axes.Axes, circle: Circle, color: Any = 'cyan'
) -> Circle:
    """Draw a circle.

    Uses ``circle.num`` when it is set; otherwise builds the numerical circle
    through the first three ``gm.Point`` neighbors. Returns the numerical
    circle drawn, or None when fewer than three points are available.
    """
    if circle.num is None:
        neighbors = circle.neighbors(gm.Point)
        if len(neighbors) <= 2:
            return
        nums = [p.num for p in neighbors]
        p1, p2, p3 = nums[:3]
        numeric = Circle(p1=p1, p2=p2, p3=p3)
    else:
        numeric = circle.num

    _draw_circle(ax, numeric, color)
    return numeric
|
| 1031 |
+
|
| 1032 |
+
|
| 1033 |
+
def mark_segment(
    ax: matplotlib.axes.Axes, p1: Point, p2: Point, color: Any, alpha: float
) -> None:
    """Mark the midpoint of segment p1p2 with a dot; ``alpha`` is ignored."""
    _ = alpha
    mid_x = (p1.x + p2.x) / 2
    mid_y = (p1.y + p2.y) / 2
    ax.scatter(mid_x, mid_y, color=color, alpha=1.0, marker='o', s=50)
|
| 1039 |
+
|
| 1040 |
+
|
| 1041 |
+
def highlight_angle(
    ax: matplotlib.axes.Axes,
    a: Point,
    b: Point,
    c: Point,
    d: Point,
    color: Any,
    alpha: float,
) -> None:
  """Highlight an angle between ab and cd with (color, alpha).

  The four points are first passed through `bring_together`. If that call
  raises, the angle is skipped silently (best effort) and nothing is drawn.

  Args:
    ax: Axes to draw on.
    a: First point of the first line.
    b: Second point of the first line.
    c: First point of the second line.
    d: Second point of the second line.
    color: Color passed to `draw_angle`.
    alpha: Opacity passed to `draw_angle`.
  """
  try:
    a, b, c, d = bring_together(a, b, c, d)
  except Exception:  # pylint: disable=broad-except
    # Still best effort, but unlike a bare `except:` this lets
    # KeyboardInterrupt and SystemExit propagate.
    return
  draw_angle(ax, a, b, d, color=color, alpha=alpha, frac=1.0)
|
| 1056 |
+
|
| 1057 |
+
|
| 1058 |
+
def highlight(
    ax: matplotlib.axes.Axes,
    name: str,
    args: list[gm.Point],
    lcolor: Any,
    color1: Any,
    color2: Any,
) -> None:
  """Draw highlights.

  Dispatches on the predicate `name`; names not handled below draw nothing.
  gm.Point arguments are replaced by their `.num` before use.

  Args:
    ax: Axes to draw on.
    name: One of 'cyclic', 'coll', 'para', 'eqangle', 'perp', 'ratio', 'cong',
      'midp', 'eqratio'.
    args: Predicate arguments; the expected count depends on `name`.
    lcolor: Line color used for the arms of 'eqangle'.
    color1: Primary highlight color.
    color2: Secondary highlight color.
  """
  args = [x.num if isinstance(x, gm.Point) else x for x in args]

  if name == 'cyclic':
    a, b, c, _ = args
    _draw_circle(ax, Circle(p1=a, p2=b, p3=c), color=color1, lw=2.0)

  elif name == 'coll':
    a, b, c = args
    # Draw between the largest and smallest of the three points.
    hi, lo = max(a, b, c), min(a, b, c)
    _draw_line(ax, hi, lo, color=color1, lw=2.0)

  elif name in ('para', 'cong'):
    a, b, c, d = args
    _draw_line(ax, a, b, color=color1, lw=2.0)
    _draw_line(ax, c, d, color=color2, lw=2.0)

  elif name == 'eqangle':
    a, b, c, d, e, f, g, h = args

    def corner(p, q, r, s):
      """Return (vertex, farther end of pq, farther end of rs)."""
      vertex = line_line_intersection(Line(p, q), Line(r, s))
      arm1 = q if q.distance(vertex) > p.distance(vertex) else p
      arm2 = s if s.distance(vertex) > r.distance(vertex) else r
      return vertex, arm1, arm2

    v1, v1_arm1, v1_arm2 = corner(a, b, c, d)
    v2, v2_arm1, v2_arm2 = corner(e, f, g, h)

    _draw_line(ax, v1, v1_arm1, color=lcolor, lw=2.0)
    _draw_line(ax, v1, v1_arm2, color=lcolor, lw=2.0)
    _draw_line(ax, v2, v2_arm1, color=lcolor, lw=2.0)
    _draw_line(ax, v2, v2_arm2, color=lcolor, lw=2.0)

    # '--' is not a valid fill color for the angle marker; fall back to red.
    if color1 == '--':
      color1 = 'red'
    draw_angle(ax, v1, v1_arm1, v1_arm2, color=color1, alpha=0.5)
    if color2 == '--':
      color2 = 'red'
    draw_angle(ax, v2, v2_arm1, v2_arm2, color=color2, alpha=0.5)

  elif name == 'perp':
    a, b, c, d = args
    _draw_line(ax, a, b, color=color1, lw=2.0)
    _draw_line(ax, c, d, color=color1, lw=2.0)

  elif name == 'ratio':
    # The last two arguments are not drawn.
    a, b, c, d, _, _ = args
    _draw_line(ax, a, b, color=color1, lw=2.0)
    _draw_line(ax, c, d, color=color2, lw=2.0)

  elif name == 'midp':
    m, a, b = args
    _draw_line(ax, a, m, color=color1, lw=2.0, alpha=0.5)
    _draw_line(ax, b, m, color=color2, lw=2.0, alpha=0.5)

  elif name == 'eqratio':
    a, b, c, d, m, n, p, q = args
    _draw_line(ax, a, b, color=color1, lw=2.0, alpha=0.5)
    _draw_line(ax, c, d, color=color2, lw=2.0, alpha=0.5)
    _draw_line(ax, m, n, color=color1, lw=2.0, alpha=0.5)
    _draw_line(ax, p, q, color=color2, lw=2.0, alpha=0.5)
|
| 1129 |
+
|
| 1130 |
+
|
| 1131 |
+
# Highlight color palette; left as None here and filled in by `_draw` the
# first time highlights are drawn.
HCOLORS = None
|
| 1132 |
+
|
| 1133 |
+
|
| 1134 |
+
def _draw(
    ax: matplotlib.axes.Axes,
    points: list[gm.Point],
    lines: list[gm.Line],
    circles: list[gm.Circle],
    goal: Any,
    equals: Optional[dict[str, Any]],
    highlights: list[tuple[str, list[gm.Point]]],
):
  """Draw everything.

  Args:
    ax: Axes to draw on.
    points: Points to draw; each needs `.num` and `.name`.
    lines: Lines to draw via `draw_line`.
    circles: Circles to draw via `draw_circle`.
    goal: Falsy, or a `(name, args)` pair highlighted in red.
    equals: Falsy, or a mapping with keys 'segments' and 'angles', each an
      iterable of groups; every group is marked in one color.
    highlights: Falsy, or `(name, args)` pairs passed to `highlight`.
  """
  global HCOLORS

  colors = ['red', 'green', 'blue', 'orange', 'magenta', 'purple']
  pcolor = 'black'
  lcolor = 'black'
  ccolor = 'grey'
  theme = get_theme()
  if theme == 'dark':
    pcolor, lcolor, ccolor = 'white', 'white', 'cyan'
  elif theme == 'light':
    pcolor, lcolor, ccolor = 'black', 'black', 'blue'
  elif theme == 'grey':
    pcolor, lcolor, ccolor = 'black', 'black', 'grey'
    colors = ['grey']

  line_boundaries = []
  for l in lines:
    ends = draw_line(ax, l, color=lcolor)
    if ends is None:
      # Nothing was drawn for this line; unpacking None would raise TypeError.
      continue
    p1, p2 = ends
    line_boundaries.append((p1, p2))

  # draw_circle returns None when it has too few points to build a circle;
  # drop those so only circles that were drawn are forwarded to draw_point.
  drawn_circles = (draw_circle(ax, c, color=ccolor) for c in circles)
  circles = [c for c in drawn_circles if c is not None]

  for p in points:
    draw_point(ax, p.num, p.name, line_boundaries, circles, color=pcolor)

  if equals:
    for i, segs in enumerate(equals['segments']):
      color = colors[i % len(colors)]
      for a, b in segs:
        mark_segment(ax, a, b, color, 0.5)

    for i, angs in enumerate(equals['angles']):
      color = colors[i % len(colors)]
      for a, b, c, d in angs:
        highlight_angle(ax, a, b, c, d, color, 0.5)

  if highlights:
    if HCOLORS is None:
      # Red is reserved for the goal, so exclude it from the palette.
      HCOLORS = [k for k in mcolors.TABLEAU_COLORS.keys() if 'red' not in k]

    for i, (name, args) in enumerate(highlights):
      color_i = HCOLORS[i % len(HCOLORS)]
      highlight(ax, name, args, 'black', color_i, color_i)

  if goal:
    name, args = goal
    lcolor = color1 = color2 = 'red'
    highlight(ax, name, args, lcolor, color1, color2)
|
| 1189 |
+
|
| 1190 |
+
|
| 1191 |
+
# Current drawing theme; read through get_theme() and replaced by set_theme().
THEME = 'dark'
|
| 1192 |
+
|
| 1193 |
+
|
| 1194 |
+
def set_theme(theme) -> None:
  """Store `theme` in the module-level THEME variable."""
  global THEME
  THEME = theme
|
| 1197 |
+
|
| 1198 |
+
|
| 1199 |
+
def get_theme() -> str:
  """Return the current value of the module-level THEME variable."""
  return THEME
|
| 1201 |
+
|
| 1202 |
+
|
| 1203 |
+
def draw(
    points: list[gm.Point],
    lines: list[gm.Line],
    circles: list[gm.Circle],
    segments: list[gm.Segment],
    goal: Any = None,
    highlights: Optional[list[tuple[str, list[gm.Point]]]] = None,
    equals: Optional[dict[str, Any]] = None,
    block: bool = True,
    save_to: Optional[str] = None,
    theme: str = 'dark',
) -> None:
  """Draw everything on the same canvas.

  Closes the current pyplot figure, creates a new square one, draws the
  diagram with `_draw`, then shows it.

  Args:
    points: Points to draw; each needs `.num` (with `.x`, `.y`) and `.name`.
    lines: Lines to draw.
    circles: Circles to draw.
    segments: Accepted for interface compatibility; not used.
    goal: Optional `(name, args)` pair to highlight.
    highlights: Optional `(name, args)` pairs to highlight.
    equals: Optional mapping with keys 'segments' and 'angles'.
    block: Forwarded to `plt.show(block=...)`.
    save_to: If not None, path the figure is written to before it is shown.
    theme: Theme name stored with `set_theme`; 'dark' gives a black
      background, anything else a white one.
  """
  plt.close()
  imsize = 512 / 100
  fig, ax = plt.subplots(figsize=(imsize, imsize), dpi=100)

  set_theme(theme)

  if get_theme() == 'dark':
    ax.set_facecolor((0.0, 0.0, 0.0))
  else:
    ax.set_facecolor((1.0, 1.0, 1.0))

  _draw(ax, points, lines, circles, goal, equals, highlights)

  plt.axis('equal')
  fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
  if points:
    xs = [p.num.x for p in points]
    ys = [p.num.y for p in points]
    plt.margins((max(xs) - min(xs)) * 0.1, (max(ys) - min(ys)) * 0.1)

  if save_to is not None:
    # Previously `save_to` was silently ignored. Save before showing, since a
    # blocking show may leave nothing to save once the window is closed.
    fig.savefig(save_to)

  plt.show(block=block)
|
| 1239 |
+
|
| 1240 |
+
|
| 1241 |
+
def close_enough(a: float, b: float, tol: float = 1e-12) -> bool:
  """Return True when `a` and `b` differ by strictly less than `tol`."""
  gap = abs(a - b)
  return gap < tol
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def assert_close_enough(a: float, b: float, tol: float = 1e-12) -> None:
  """Assert that `close_enough(a, b, tol)` holds, reporting the gap if not."""
  within_tol = close_enough(a, b, tol)
  assert within_tol, f'|{a}-{b}| = {abs(a-b)} >= {tol}'
|
| 1247 |
+
|
| 1248 |
+
|
| 1249 |
+
def ang_of(tail: Point, head: Point) -> float:
  """Return the arctan2 angle of `head - tail`, reduced modulo 2*pi."""
  delta = head - tail
  return np.arctan2(delta.y, delta.x) % (2 * np.pi)
|
| 1253 |
+
|
| 1254 |
+
|
| 1255 |
+
def ang_between(tail: Point, head1: Point, head2: Point) -> float:
  """Return ang_of(tail, head1) - ang_of(tail, head2), wrapped once.

  A difference above pi is lowered by 2*pi and a difference below -pi is
  raised by 2*pi; otherwise it is returned unchanged.
  """
  diff = ang_of(tail, head1) - ang_of(tail, head2)
  if diff > np.pi:
    diff = diff - 2 * np.pi
  elif diff < -np.pi:
    diff = 2 * np.pi + diff
  return diff
|
| 1265 |
+
|
| 1266 |
+
|
| 1267 |
+
def head_from(tail: Point, ang: float, length: float = 1) -> Point:
  """Return `tail` displaced by `length` in direction `ang` (radians)."""
  dx = np.cos(ang) * length
  dy = np.sin(ang) * length
  return tail + Point(dx, dy)
|
| 1270 |
+
|
| 1271 |
+
|
| 1272 |
+
def random_points(n: int = 3) -> list[Point]:
  """Return `n` points whose coordinates are each drawn with unif(-1, 1)."""
  cloud = []
  for _ in range(n):
    # x is drawn before y, matching the order of the random stream.
    x = unif(-1, 1)
    y = unif(-1, 1)
    cloud.append(Point(x, y))
  return cloud
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
def random_rfss(*points: list[Point]) -> list[Point]:
  """Random rotate-flip-scale-shift a point cloud.

  Points are passed as separate positional arguments. They are centered on
  their average, rotated by a random angle, scaled by a factor drawn from
  unif(0.5, 2.0), shifted by a random offset, and flipped with probability 0.5.

  Returns:
    The transformed points as a list, in the input order.
  """
  # Center the cloud on its average.
  centroid = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
  centered = [p - centroid for p in points]

  # Draw the rotation, scale and shift (in this order).
  ang = unif(0.0, 2 * np.pi)
  sin, cos = np.sin(ang), np.cos(ang)
  scale = unif(0.5, 2.0)
  shift = Point(unif(-1, 1), unif(-1, 1))
  moved = [p.rotate(sin, cos) * scale + shift for p in centered]

  # Flip half of the time.
  if np.random.rand() < 0.5:
    moved = [p.flip() for p in moved]

  return moved
|
| 1295 |
+
|
| 1296 |
+
|
| 1297 |
+
def reduce(
    objs: list[Union[Point, Line, Circle, HalfLine, HoleCircle]],
    existing_points: list[Point],
) -> list[Point]:
  """Reduce intersecting objects into one point of intersections.

  Args:
    objs: Either all Points (returned unchanged), a single object (sampled
      with `sample_within`), or two objects (intersected).
    existing_points: Points already placed. When two intersections exist, the
      one not `close` to an existing point is preferred.

  Returns:
    `objs` itself when it holds only Points; otherwise whatever
    `sample_within` returns, or a one-element list with the chosen
    intersection.

  Raises:
    ValueError: If `objs` has more than two elements and is not all Points.
  """
  if all(isinstance(o, Point) for o in objs):
    return objs

  if len(objs) == 1:
    return objs[0].sample_within(existing_points)

  if len(objs) != 2:
    raise ValueError(f'Cannot reduce {objs}')

  first, second = objs
  result = first.intersect(second)
  if isinstance(result, Point):
    return [result]

  x, y = result
  if any(x.close(e) for e in existing_points):
    return [y]
  if any(y.close(e) for e in existing_points):
    return [x]
  return [np.random.choice([x, y])]
|
| 1324 |
+
|
| 1325 |
+
|
| 1326 |
+
def sketch(
    name: str, args: list[Union[Point, gm.Point]]
) -> list[Union[Point, Line, Circle, HalfLine, HoleCircle]]:
  """Call the module-level function `sketch_<name>` on `args`.

  gm.Point arguments are replaced by their `.num`; other arguments are passed
  through unchanged. A tuple or list result is returned as a list, and any
  other result is wrapped in a one-element list.
  """
  sketch_fn = globals()['sketch_' + name]
  numeric_args = [p.num if isinstance(p, gm.Point) else p for p in args]
  out = sketch_fn(numeric_args)

  # out can be one or multiple {Point/Line/HalfLine}
  return list(out) if isinstance(out, (tuple, list)) else [out]
|
| 1337 |
+
|
| 1338 |
+
|
| 1339 |
+
def sketch_on_opline(args: tuple[gm.Point, ...]) -> HalfLine:
  """Return HalfLine(a, a + a - b), i.e. through the reflection of b in a."""
  a, b = args
  reflected_b = a + a - b
  return HalfLine(a, reflected_b)
|
| 1342 |
+
|
| 1343 |
+
|
| 1344 |
+
def sketch_on_hline(args: tuple[gm.Point, ...]) -> HalfLine:
  """Return the HalfLine built from the two given points, in order."""
  tail, head = args
  return HalfLine(tail, head)
|
| 1347 |
+
|
| 1348 |
+
|
| 1349 |
+
def sketch_ieq_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a triangle on the fixed base (0, 0)-(1, 0).

  The third vertex is the first intersection of the circle centered at a
  through b with the circle centered at b through a. `args` is ignored.
  """
  del args  # Unused.
  a = Point(0.0, 0.0)
  b = Point(1.0, 0.0)
  around_a = Circle(a, p1=b)
  around_b = Circle(b, p1=a)
  c, _ = around_a.intersect(around_b)
  return a, b, c
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
def sketch_incenter2(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Intersect the bisectors at a and b and project onto the three sides.

  Returns:
    (x, y, z, i): i is the intersection of sketch_bisect([b, a, c]) and
    sketch_bisect([a, b, c]); x, y, z are the feet of i on bc, ca and ab.
  """
  a, b, c = args
  bisector_at_a = sketch_bisect([b, a, c])
  bisector_at_b = sketch_bisect([a, b, c])
  i = line_line_intersection(bisector_at_a, bisector_at_b)
  x, y, z = (i.foot(Line(p, q)) for p, q in ((b, c), (c, a), (a, b)))
  return x, y, z, i
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
def sketch_excenter2(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Intersect the bisector at a with the external bisector at b.

  Returns:
    (x, y, z, i): i is the intersection of sketch_bisect([b, a, c]) and
    sketch_exbisect([a, b, c]); x, y, z are the feet of i on bc, ca and ab.
  """
  a, b, c = args
  bisector_at_a = sketch_bisect([b, a, c])
  exbisector_at_b = sketch_exbisect([a, b, c])
  i = line_line_intersection(bisector_at_a, exbisector_at_b)
  x, y, z = (i.foot(Line(p, q)) for p, q in ((b, c), (c, a), (a, b)))
  return x, y, z, i
|
| 1377 |
+
|
| 1378 |
+
|
| 1379 |
+
def sketch_centroid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Return the side midpoints of triangle abc and the medians' intersection.

  Returns:
    (x, y, z, i): midpoints of bc, ca, ab, and the intersection of lines
    a-x and b-y.
  """
  a, b, c = args
  x, y, z = ((p + q) * 0.5 for p, q in ((b, c), (c, a), (a, b)))
  i = line_line_intersection(Line(a, x), Line(b, y))
  return x, y, z, i
|
| 1386 |
+
|
| 1387 |
+
|
| 1388 |
+
def sketch_ninepoints(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Return the side midpoints of triangle abc and their circle's center.

  Returns:
    (x, y, z, center): midpoints of bc, ca, ab, and the center of the circle
    through those three midpoints.
  """
  a, b, c = args
  x, y, z = ((p + q) * 0.5 for p, q in ((b, c), (c, a), (a, b)))
  midpoint_circle = Circle(p1=x, p2=y, p3=z)
  return x, y, z, midpoint_circle.center
|
| 1395 |
+
|
| 1396 |
+
|
| 1397 |
+
def sketch_2l1c(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a circle touching two lines and another circle.

  The two lines are ac and bc; the other circle is centered at p and passes
  through a.

  Returns:
    (x.foot(ac), x.foot(bc), g, x), where g lies on the given circle and x is
    the intersection of line p-g with the line from c through the midpoint of
    the unit steps from c towards a and towards b.
  """
  a, b, c, p = args
  bc, ac = Line(b, c), Line(a, c)
  circle = Circle(p, p1=a)

  # d: where the perpendicular from p to bc meets the circle; take the second
  # intersection when it is on the other side of bc from a.
  d, d_alt = line_circle_intersection(p.perpendicular_line(bc), circle)
  d = d_alt if bc.diff_side(d_alt, a) else d

  # e: same construction with line ac and reference point b.
  e, e_alt = line_circle_intersection(p.perpendicular_line(ac), circle)
  e = e_alt if ac.diff_side(e_alt, b) else e

  # f: intersection of the lines through d and e perpendicular to pd and pe.
  through_d = d.perpendicular_line(Line(p, d))
  through_e = e.perpendicular_line(Line(p, e))
  f = line_line_intersection(through_d, through_e)

  # g: where line cf meets the circle, on the same side of bc as a if the
  # second intersection satisfies that.
  g, g_alt = line_circle_intersection(Line(c, f), circle)
  g = g_alt if bc.same_side(g_alt, a) else g

  b_unit = c + (b - c) / b.distance(c)
  a_unit = c + (a - c) / a.distance(c)
  m = (a_unit + b_unit) * 0.5
  x = line_line_intersection(Line(c, m), Line(p, g))
  return x.foot(ac), x.foot(bc), g, x
|
| 1424 |
+
|
| 1425 |
+
|
| 1426 |
+
def sketch_3peq(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch x on line ab, z on line bc and y with z the midpoint of x and y.

  z is picked at a random parameter in [-0.5, 1.5) along b->c. x is the
  intersection of ab with the line parallel to ca through the reflection of c
  in z, and y is the reflection of x in z.
  """
  a, b, c = args
  # Line(b, c) is constructed but not used below; kept so the set of
  # constructor calls is unchanged.
  ab, _, ca = Line(a, b), Line(b, c), Line(c, a)

  z = b + (c - b) * np.random.uniform(-0.5, 1.5)

  c_mirrored = z * 2 - c
  parallel_to_ca = c_mirrored.parallel_line(ca)
  x = line_line_intersection(parallel_to_ca, ab)
  y = z * 2 - x
  return x, y, z
|
| 1437 |
+
|
| 1438 |
+
|
| 1439 |
+
def try_to_sketch_intersect(
    name1: str,
    args1: list[Union[gm.Point, Point]],
    name2: str,
    args2: list[Union[gm.Point, Point]],
    existing_points: list[Point],
) -> Optional[Point]:
  """Try to sketch an intersection between two objects.

  Args:
    name1: Sketch name of the first object (see `sketch`).
    args1: Arguments for the first sketch.
    name2: Sketch name of the second object.
    args2: Arguments for the second sketch.
    existing_points: Points used by `check_too_close` / `check_too_far` to
      filter candidate intersections.

  Returns:
    The intersection when it is a single Point; otherwise the first of the
    two candidates that is neither too close nor too far; None if the
    intersection routine raises or both candidates are rejected.
  """
  obj1 = sketch(name1, args1)[0]
  obj2 = sketch(name2, args2)[0]

  if isinstance(obj1, Line) and isinstance(obj2, Line):
    fn = line_line_intersection
  elif isinstance(obj1, Circle) and isinstance(obj2, Circle):
    fn = circle_circle_intersection
  else:
    fn = line_circle_intersection
    # line_circle_intersection is called as (line, circle).
    if isinstance(obj2, Line) and isinstance(obj1, Circle):
      obj1, obj2 = obj2, obj1

  try:
    x = fn(obj1, obj2)
  except Exception:  # pylint: disable=broad-except
    # Best effort: no intersection is reported as None. Unlike a bare
    # `except:`, KeyboardInterrupt and SystemExit still propagate.
    return None

  if isinstance(x, Point):
    return x

  x1, x2 = x
  for candidate in (x1, x2):
    too_close = check_too_close([candidate], existing_points)
    too_far = check_too_far([candidate], existing_points)
    if not too_close and not too_far:
      return candidate

  return None
|
| 1479 |
+
|
| 1480 |
+
|
| 1481 |
+
def sketch_acircle(args: tuple[gm.Point, ...]) -> Circle:
  """Return the circle through d, e and f.

  e is the intersection of sketch_aline([c, a, b, f, d]) and
  sketch_aline([a, c, b, d, f]).
  """
  a, b, c, d, f = args
  ray_from_d = sketch_aline([c, a, b, f, d])
  ray_from_f = sketch_aline([a, c, b, d, f])
  e = line_line_intersection(ray_from_d, ray_from_f)
  return Circle(p1=d, p2=e, p3=f)
|
| 1487 |
+
|
| 1488 |
+
|
| 1489 |
+
def sketch_aline(args: tuple[gm.Point, ...]) -> HalfLine:
  """Sketch the construction aline.

  Given A, B, C, D, E, returns the half-line from E to X, where the direction
  of EX is dir(ED) + dir(BC) - dir(BA) and X is one unit away from E.
  """
  pa, pb, pc, pd, pe = args

  def direction(head, tail):
    """Return the arctan2 angle of the normalized vector head - tail."""
    vec = head - tail
    norm = head.distance(tail)
    return np.arctan2(vec.y / norm, vec.x / norm)

  ang_ab = direction(pa, pb)
  ang_bc = direction(pc, pb)
  ang_de = direction(pd, pe)

  ang_ex = ang_de + ang_bc - ang_ab
  px = pe + Point(np.cos(ang_ex), np.sin(ang_ex))
  return HalfLine(pe, px)
|
| 1508 |
+
|
| 1509 |
+
|
| 1510 |
+
def sketch_amirror(args: tuple[gm.Point, ...]) -> HalfLine:
  """Sketch the angle mirror.

  Given A, B, C, returns the half-line from B to X, where the direction of BX
  is 2 * dir(BC) - dir(BA) and X is one unit away from B.
  """
  pa, pb, pc = args

  def direction(head, tail):
    """Return the arctan2 angle of the normalized vector head - tail."""
    vec = head - tail
    norm = head.distance(tail)
    return np.arctan2(vec.y / norm, vec.x / norm)

  ang_ab = direction(pa, pb)
  ang_bc = direction(pc, pb)

  ang_bx = 2 * ang_bc - ang_ab
  px = pb + Point(np.cos(ang_bx), np.sin(ang_bx))
  return HalfLine(pb, px)
|
| 1524 |
+
|
| 1525 |
+
|
| 1526 |
+
def sketch_bisect(args: tuple[gm.Point, ...]) -> Line:
  """Return the line through b and the midpoint of a and x.

  x is the point on ray b->c at distance |ab| from b, so a and x are
  equidistant from b.
  """
  a, b, c = args
  len_ab = a.distance(b)
  len_bc = b.distance(c)
  x = b + (c - b) * (len_ab / len_bc)
  midpoint = (a + x) * 0.5
  return Line(b, midpoint)
|
| 1533 |
+
|
| 1534 |
+
|
| 1535 |
+
def sketch_exbisect(args: tuple[gm.Point, ...]) -> Line:
  """Return `sketch_bisect(args).perpendicular_line(b)` for args (a, b, c)."""
  _, b, _ = args
  inner_bisector = sketch_bisect(args)
  return inner_bisector.perpendicular_line(b)
|
| 1538 |
+
|
| 1539 |
+
|
| 1540 |
+
def sketch_bline(args: tuple[gm.Point, ...]) -> Line:
  """Return the line through the midpoint of ab perpendicular to line ab."""
  a, b = args
  midpoint = (a + b) * 0.5
  segment_line = Line(a, b)
  return midpoint.perpendicular_line(segment_line)
|
| 1544 |
+
|
| 1545 |
+
|
| 1546 |
+
def sketch_dia(args: tuple[gm.Point, ...]) -> Circle:
  """Return the circle centered at the midpoint of ab and passing through a."""
  a, b = args
  center = (a + b) * 0.5
  return Circle(center, p1=a)
|
| 1549 |
+
|
| 1550 |
+
|
| 1551 |
+
def sketch_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Intersect the circle centered at o through b with sketch_dia([a, o])."""
  a, o, b = args
  circle_on_ao = sketch_dia([a, o])
  circle_o = Circle(o, p1=b)
  return circle_circle_intersection(circle_o, circle_on_ao)
|
| 1555 |
+
|
| 1556 |
+
|
| 1557 |
+
def sketch_circle(args: tuple[gm.Point, ...]) -> Circle:
  """Return the circle centered at a whose radius is the distance b-c."""
  a, b, c = args
  radius = b.distance(c)
  return Circle(center=a, radius=radius)
|
| 1560 |
+
|
| 1561 |
+
|
| 1562 |
+
def sketch_cc_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch tangents to two circles.

  The circles are centered at o through a, and at w through b.

  Returns:
    (x, y, z, t). In the equal-radius branch x and z come from the circle
    around o, and y and t are x and z shifted by w - o. Otherwise the points
    come from the construction through q below, swapped back when the second
    circle is the larger one.
  """
  o, a, w, b = args
  ra, rb = o.distance(a), w.distance(b)

  center_line = Line(o, w)
  if close_enough(ra, rb):
    # Equal radii: intersect the circle around o with the perpendicular to
    # the center line at o, then shift the two points by w - o.
    perpendicular_at_o = center_line.perpendicular_line(o)
    circle_o = Circle(o, ra)
    x, z = line_circle_intersection(perpendicular_at_o, circle_o)
    y = x + w - o
    t = z + w - o
    return x, y, z, t

  # Work with the larger circle first; remember whether the roles swapped.
  swapped = rb > ra
  if swapped:
    o, w = w, o
    ra, rb = rb, ra

  larger = Circle(o, ra)
  q = o + (w - o) * ra / (ra - rb)

  x, z = circle_circle_intersection(sketch_dia([o, q]), larger)
  y = w.foot(Line(x, q))
  t = w.foot(Line(z, q))

  return (y, x, t, z) if swapped else (x, y, z, t)
|
| 1592 |
+
|
| 1593 |
+
|
| 1594 |
+
def sketch_hcircle(args: tuple[gm.Point, ...]) -> HoleCircle:
  """Return a HoleCircle centered at a with radius |ab| and hole b."""
  a, b = args
  radius = a.distance(b)
  return HoleCircle(center=a, radius=radius, hole=b)
|
| 1597 |
+
|
| 1598 |
+
|
| 1599 |
+
def sketch_e5128(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Return (e, g) for points a, b, c, d.

  g is the midpoint of ab. e is the intersection of line d-g with the circle
  centered at c through b; of the two intersections, the second is taken
  whenever it is strictly farther from d than the first.
  """
  a, b, c, d = args
  # Line(a, d) is constructed but not used below; kept so the set of
  # constructor calls is unchanged.
  _ = Line(a, d)

  g = (a + b) * 0.5
  line_dg = Line(d, g)

  first, second = line_circle_intersection(line_dg, Circle(c, p1=b))

  e = second if first.distance(d) < second.distance(d) else first
  return e, g
|
| 1611 |
+
|
| 1612 |
+
|
| 1613 |
+
def sketch_eq_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch quadrangle with two equal opposite sides.

  Starting from a = (0, 0) and b = (1, 0), d and c are placed at the same
  random distance from a and b respectively, then the four points go through
  `random_rfss`. `args` is ignored.
  """
  a = Point(0.0, 0.0)
  b = Point(1.0, 0.0)

  side = np.random.uniform(0.5, 2.0)
  ang_ad = np.random.uniform(np.pi / 3, np.pi * 2 / 3)
  d = head_from(a, ang_ad, side)

  ang_bd = ang_of(b, d)
  ang_bc = np.random.uniform(ang_bd / 10, ang_bd / 9)
  c = head_from(b, ang_bc, side)

  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1627 |
+
|
| 1628 |
+
|
| 1629 |
+
def sketch_eq_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a trapezoid symmetric about x = 0.5, then randomly transform it.

  The base is (0, 0)-(1, 0); the opposite side has random length and height
  and is centered above the base. `args` is ignored.
  """
  a = Point(0.0, 0.0)
  b = Point(1.0, 0.0)

  top_length = unif(0.5, 2.0)
  height = unif(0.5, 2.0)
  half_top = top_length / 2.0
  c = Point(0.5 + half_top, height)
  d = Point(0.5 - half_top, height)

  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1640 |
+
|
| 1641 |
+
|
| 1642 |
+
def sketch_eqangle2(args: tuple[gm.Point, ...]) -> Point:
  """Sketch the def eqangle2.

  Picks a random length `be`, places e on line bc at distance `be` from b and
  y on line ba at distance be * |bc| / |ba| from b, and returns the
  intersection of lines c-y and a-e.
  """
  a, b, c = args

  # NOTE: this point is computed but not used below.
  d = c * 2 - b

  ba = b.distance(a)
  bc = b.distance(c)
  l = ba * ba / bc

  # Half of the time pick `be` below min(l, bc), otherwise above max(l, bc).
  if unif(0.0, 1.0) < 0.5:
    bound = min(l, bc)
    be = unif(bound * 0.1, bound * 0.9)
  else:
    bound = max(l, bc)
    be = unif(bound * 1.1, bound * 1.5)

  e = b + (c - b) * (be / bc)
  y = b + (a - b) * (be / l)
  return line_line_intersection(Line(c, y), Line(a, e))
|
| 1662 |
+
|
| 1663 |
+
|
| 1664 |
+
def sketch_eqangle3(args: tuple[gm.Point, ...]) -> Circle:
  """Return the circle through a, b and a constructed point x.

  x is placed from a at direction ang_of(a, b) + ang_between(e, d, f) and at
  distance |de| / |ef| * |ab|.
  """
  a, b, d, e, f = args
  len_de = d.distance(e)
  len_ef = e.distance(f)
  len_ab = b.distance(a)
  ang_ax = ang_of(a, b) + ang_between(e, d, f)
  x = head_from(a, ang_ax, length=len_de / len_ef * len_ab)
  return Circle(p1=a, p2=b, p3=x)
|
| 1672 |
+
|
| 1673 |
+
|
| 1674 |
+
def sketch_eqdia_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch quadrangle with two equal diagonals.

  a-c lies on the x axis and b-d on the y axis, both spanning length 1 with
  random offsets; b and d are then rotated by a random angle in
  [-pi/4, pi/4) before the final `random_rfss`. `args` is ignored.
  """
  m = unif(0.3, 0.7)
  n = unif(0.3, 0.7)
  a, c = Point(-m, 0.0), Point(1 - m, 0.0)
  b, d = Point(0.0, -n), Point(0.0, 1 - n)

  tilt = unif(-0.25 * np.pi, 0.25 * np.pi)
  sin, cos = np.sin(tilt), np.cos(tilt)
  b = b.rotate(sin, cos)
  d = d.rotate(sin, cos)

  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1689 |
+
|
| 1690 |
+
|
| 1691 |
+
def sketch_free(args: tuple[gm.Point, ...]) -> Point:
  """Return one point from `random_points`; `args` is ignored."""
  del args  # Unused.
  return random_points(1)[0]
|
| 1693 |
+
|
| 1694 |
+
|
| 1695 |
+
def sketch_isos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a triangle with apex a above the midpoint of base bc.

  Base length and height are each drawn from unif(0.5, 1.5); the triangle is
  then passed through `random_rfss`. `args` is ignored.
  """
  base = unif(0.5, 1.5)
  height = unif(0.5, 1.5)
  half_base = base / 2

  b = Point(-half_base, 0.0)
  c = Point(half_base, 0.0)
  a = Point(0.0, height)
  a, b, c = random_rfss(a, b, c)
  return a, b, c
|
| 1704 |
+
|
| 1705 |
+
|
| 1706 |
+
def sketch_line(args: tuple[gm.Point, ...]) -> Line:
  """Return the Line built from the two given points, in order."""
  first, second = args
  return Line(first, second)
|
| 1709 |
+
|
| 1710 |
+
|
| 1711 |
+
def sketch_cyclic(args: tuple[gm.Point, ...]) -> Circle:
  """Return the Circle built from the three given points."""
  first, second, third = args
  return Circle(p1=first, p2=second, p3=third)
|
| 1714 |
+
|
| 1715 |
+
|
| 1716 |
+
def sketch_hline(args: tuple[gm.Point, ...]) -> HalfLine:
  """Return the HalfLine built from the two given points, in order."""
  tail, head = args
  return HalfLine(tail, head)
|
| 1719 |
+
|
| 1720 |
+
|
| 1721 |
+
def sketch_midp(args: tuple[gm.Point, ...]) -> Point:
  """Return the midpoint of the two given points."""
  a, b = args
  total = a + b
  return total * 0.5
|
| 1724 |
+
|
| 1725 |
+
|
| 1726 |
+
def sketch_pentagon(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch five points on the unit circle, then randomly transform them.

  The first point is at angle 0. Each later angle advances by a random
  fraction (0.5 to 1.5 times the even share) of what is left of the full
  turn. `args` is ignored.
  """
  vertices = [Point(1.0, 0.0)]
  ang = 0.0

  # `remaining` counts the arcs still to be placed: 5, 4, 3, 2.
  for remaining in range(5, 1, -1):
    ang += (2 * np.pi - ang) / remaining * unif(0.5, 1.5)
    vertices.append(Point(np.cos(ang), np.sin(ang)))

  a, b, c, d, e = random_rfss(*vertices)
  return a, b, c, d, e
|
| 1738 |
+
|
| 1739 |
+
|
| 1740 |
+
def sketch_pline(args: tuple[gm.Point, ...]) -> Line:
  """Return the line through a parallel to line bc."""
  a, b, c = args
  reference = Line(b, c)
  return a.parallel_line(reference)
|
| 1743 |
+
|
| 1744 |
+
|
| 1745 |
+
def sketch_pmirror(args: tuple[gm.Point, ...]) -> Point:
  """Return b * 2 - a, the reflection of a through b."""
  a, b = args
  doubled_b = b * 2
  return doubled_b - a
|
| 1748 |
+
|
| 1749 |
+
|
| 1750 |
+
def sketch_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a random quadrangle.

  a-c lies on the x axis with length 1; b and d start on the y axis below and
  above it, are rotated by a random angle in [-pi/4, pi/4), and all four
  points then go through `random_rfss`. `args` is ignored.
  """
  m = unif(0.3, 0.7)
  # This draw is not used below; it is kept so the sequence of random draws
  # is unchanged.
  _ = unif(0.3, 0.7)

  a, c = Point(-m, 0.0), Point(1 - m, 0.0)
  b = Point(0.0, -unif(0.25, 0.75))
  d = Point(0.0, unif(0.25, 0.75))

  tilt = unif(-0.25 * np.pi, 0.25 * np.pi)
  sin, cos = np.sin(tilt), np.cos(tilt)
  b = b.rotate(sin, cos)
  d = d.rotate(sin, cos)

  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1766 |
+
|
| 1767 |
+
|
| 1768 |
+
def sketch_r_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a trapezoid with side ad perpendicular to sides ab and dc.

  a = (0, 1) and d = (0, 0); b and c sit at random x in unif(0.5, 1.5) on
  y = 1 and y = 0. The result goes through `random_rfss`. `args` is ignored.
  """
  a = Point(0.0, 1.0)
  d = Point(0.0, 0.0)
  # b's x coordinate is drawn before c's.
  b_x = unif(0.5, 1.5)
  c_x = unif(0.5, 1.5)
  b = Point(b_x, 1.0)
  c = Point(c_x, 0.0)
  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1775 |
+
|
| 1776 |
+
|
| 1777 |
+
def sketch_r_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a triangle with legs ab and ac along the axes from a = (0, 0).

  Each leg length is drawn from unif(0.5, 2.0); the result goes through
  `random_rfss`. `args` is ignored.
  """
  a = Point(0.0, 0.0)
  # The vertical leg is drawn before the horizontal one.
  leg_b = unif(0.5, 2.0)
  leg_c = unif(0.5, 2.0)
  b = Point(0.0, leg_b)
  c = Point(leg_c, 0.0)
  a, b, c = random_rfss(a, b, c)
  return a, b, c
|
| 1783 |
+
|
| 1784 |
+
|
| 1785 |
+
def sketch_rectangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch an axis-aligned rectangle of height 1, then randomly transform it.

  The width is drawn from unif(0.5, 2.0). `args` is ignored.
  """
  width = unif(0.5, 2.0)
  a = Point(0.0, 0.0)
  b = Point(0.0, 1.0)
  c = Point(width, 1.0)
  d = Point(width, 0.0)
  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1793 |
+
|
| 1794 |
+
|
| 1795 |
+
def sketch_reflect(args: tuple[gm.Point, ...]) -> Point:
  """Return the reflection of a through its foot on line bc."""
  a, b, c = args
  foot = a.foot(Line(b, c))
  return foot * 2 - a
|
| 1799 |
+
|
| 1800 |
+
|
| 1801 |
+
def sketch_risos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch the triangle (0, 0), (0, 1), (1, 0), then randomly transform it.

  `args` is ignored.
  """
  corner = Point(0.0, 0.0)
  up = Point(0.0, 1.0)
  right = Point(1.0, 0.0)
  a, b, c = random_rfss(corner, up, right)
  return a, b, c
|
| 1807 |
+
|
| 1808 |
+
|
| 1809 |
+
def sketch_rotaten90(args: tuple[gm.Point, ...]) -> Point:
  """Return a plus the vector b - a rotated with angle -pi/2."""
  a, b = args
  ang = -np.pi / 2
  sin, cos = np.sin(ang), np.cos(ang)
  return a + (b - a).rotate(sin, cos)
|
| 1813 |
+
|
| 1814 |
+
|
| 1815 |
+
def sketch_rotatep90(args: tuple[gm.Point, ...]) -> Point:
  """Return a plus the vector b - a rotated with angle +pi/2."""
  a, b = args
  ang = np.pi / 2
  sin, cos = np.sin(ang), np.cos(ang)
  return a + (b - a).rotate(sin, cos)
|
| 1819 |
+
|
| 1820 |
+
|
| 1821 |
+
def sketch_s_angle(args: tuple[gm.Point, ...]) -> HalfLine:
  """Return the half-line from b through b + (a - b) rotated by y degrees.

  `args` is (a, b, y), with y a number of degrees converted to radians before
  being passed to `rotatea`.
  """
  a, b, y = args
  radians = y / 180 * np.pi
  head = b + (a - b).rotatea(radians)
  return HalfLine(b, head)
|
| 1826 |
+
|
| 1827 |
+
|
| 1828 |
+
def sketch_segment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Return two points from `random_points`; `args` is ignored."""
  del args  # Unused.
  first, second = random_points(2)
  return first, second
|
| 1831 |
+
|
| 1832 |
+
|
| 1833 |
+
def sketch_shift(args: tuple[gm.Point, ...]) -> Point:
  """Return c translated by the vector from a to b."""
  a, b, c = args
  offset = b - a
  return c + offset
|
| 1836 |
+
|
| 1837 |
+
|
| 1838 |
+
def sketch_square(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Return (c, d) built from a and b by quarter-turn rotations.

  c is b plus (a - b) rotated by -pi/2 with `rotatea`; d is a plus (b - a)
  rotated by +pi/2.
  """
  a, b = args
  quarter_turn = np.pi / 2
  c = b + (a - b).rotatea(-quarter_turn)
  d = a + (b - a).rotatea(quarter_turn)
  return c, d
|
| 1843 |
+
|
| 1844 |
+
|
| 1845 |
+
def sketch_isquare(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch the unit square, then randomly transform it.

  `args` is ignored.
  """
  corners = (
      Point(0.0, 0.0),
      Point(1.0, 0.0),
      Point(1.0, 1.0),
      Point(0.0, 1.0),
  )
  a, b, c, d = random_rfss(*corners)
  return a, b, c, d
|
| 1852 |
+
|
| 1853 |
+
|
| 1854 |
+
def sketch_tline(args: tuple[gm.Point, ...]) -> Line:
  """Return the line through a perpendicular to line bc."""
  a, b, c = args
  reference = Line(b, c)
  return a.perpendicular_line(reference)
|
| 1857 |
+
|
| 1858 |
+
|
| 1859 |
+
def sketch_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a trapezoid with ab parallel to dc, then randomly transform it.

  d = (0, 0) and c = (1, 0); a and b share a random height, a starts at a
  random x in unif(0.2, 0.5) and b lies `base` to its right. `args` is
  ignored.
  """
  d = Point(0.0, 0.0)
  c = Point(1.0, 0.0)

  # Draw order: base, height, then a's x coordinate.
  base = unif(0.5, 2.0)
  height = unif(0.5, 2.0)
  a_x = unif(0.2, 0.5)
  a = Point(a_x, height)
  b = Point(a.x + base, height)
  a, b, c, d = random_rfss(a, b, c, d)
  return a, b, c, d
|
| 1869 |
+
|
| 1870 |
+
|
| 1871 |
+
def sketch_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a triangle on the fixed base (0, 0)-(1, 0).

  c is placed from a at a random distance in unif(0.5, 2.0) and a random
  angle in unif(0.2, 0.8) * pi. No `random_rfss` is applied here. `args` is
  ignored.
  """
  a = Point(0.0, 0.0)
  b = Point(1.0, 0.0)
  len_ac = unif(0.5, 2.0)
  ang_at_a = unif(0.2, 0.8) * np.pi
  c = head_from(a, ang_at_a, len_ac)
  return a, b, c
|
| 1878 |
+
|
| 1879 |
+
|
| 1880 |
+
def sketch_triangle12(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
  """Sketch a triangle with |ab| = 1 and |ac| = 2, then randomly transform it.

  b = (0, 0), c is on the x axis at unif(1.5, 2.5), and a is the first
  intersection of the radius-1 circle at b with the radius-2 circle at c.
  `args` is ignored.
  """
  b = Point(0.0, 0.0)
  c = Point(unif(1.5, 2.5), 0.0)
  unit_around_b = Circle(b, 1.0)
  double_around_c = Circle(c, 2.0)
  a, _ = circle_circle_intersection(unit_around_b, double_around_c)
  a, b, c = random_rfss(a, b, c)
  return a, b, c
|
| 1886 |
+
|
| 1887 |
+
|
| 1888 |
+
def sketch_trisect(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Sketch two trisectors of an angle.

  For args (a, b, c) the angle at b is split in three, and the two trisecting
  directions are intersected with line ac.

  Returns:
    The two intersection points on line ac. Their order is reversed when
    exactly one of the two angle normalizations below was applied.
  """
  a, b, c = args
  lo = ang_of(b, a)
  hi = ang_of(b, c)

  swaps = 0
  # Order the two directions.
  if lo > hi:
    lo, hi = hi, lo
    swaps += 1

  # Trisect the smaller of the two arcs between them.
  if hi - lo > np.pi:
    lo, hi = hi, lo + 2 * np.pi
    swaps += 1

  third = (hi - lo) / 3
  angx = lo + third
  angy = hi - third

  x_dir = b + Point(np.cos(angx), np.sin(angx))
  y_dir = b + Point(np.cos(angy), np.sin(angy))

  ac = Line(a, c)
  x = line_line_intersection(Line(b, x_dir), ac)
  y = line_line_intersection(Line(b, y_dir), ac)

  return (y, x) if swaps == 1 else (x, y)
|
| 1916 |
+
|
| 1917 |
+
|
| 1918 |
+
def sketch_trisegment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
  """Return the points one third and two thirds of the way from a to b."""
  a, b = args
  x = a + (b - a) * (1.0 / 3)
  y = a + (b - a) * (2.0 / 3)
  return x, y
|
external/alphageometry/numericals_test.py
ADDED
|
@@ -0,0 +1,313 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit testing for the geometry numericals code."""
|
| 17 |
+
|
| 18 |
+
import unittest
|
| 19 |
+
|
| 20 |
+
from absl.testing import absltest
|
| 21 |
+
import numericals as nm
|
| 22 |
+
|
| 23 |
+
# Short module-level aliases for the `numericals` (nm) names used by the tests
# below, so test bodies can write `Point(...)` instead of `nm.Point(...)`.
np = nm.np

unif = nm.unif
Point = nm.Point
Line = nm.Line
Circle = nm.Circle
HalfLine = nm.HalfLine

line_circle_intersection = nm.line_circle_intersection
line_line_intersection = nm.line_line_intersection

check_coll = nm.check_coll
check_eqangle = nm.check_eqangle

random_points = nm.random_points
ang_between = nm.ang_between
head_from = nm.head_from
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class NumericalTest(unittest.TestCase):
    """Checks that each `nm.sketch_*` routine yields the shape it promises.

    Most tests build random inputs (via `random_points` / `unif`), call one
    sketch function and verify a geometric invariant of the result. Every
    check goes through `self.assert*` so that failures are reported by the
    test runner and the checks are not removed when Python runs with -O.
    """

    def test_sketch_ieq_triangle(self):
        a, b, c = nm.sketch_ieq_triangle([])
        self.assertAlmostEqual(a.distance(b), b.distance(c))
        self.assertAlmostEqual(c.distance(a), b.distance(c))

    def test_sketch_2l1c(self):
        p = nm.Point(0.0, 0.0)
        pi = np.pi
        anga = unif(-0.4 * pi, 0.4 * pi)
        a = Point(np.cos(anga), np.sin(anga))
        angb = unif(0.6 * pi, 1.4 * pi)
        b = Point(np.cos(angb), np.sin(angb))

        angc = unif(anga + 0.05 * pi, angb - 0.05 * pi)
        c = Point(np.cos(angc), np.sin(angc)) * unif(0.2, 0.8)

        x, y, z, i = nm.sketch_2l1c([a, b, c, p])
        self.assertTrue(check_coll([x, c, a]))
        self.assertTrue(check_coll([y, c, b]))
        self.assertAlmostEqual(z.distance(p), 1.0)
        self.assertTrue(check_coll([p, i, z]))
        self.assertTrue(Line(i, x).is_perp(Line(c, a)))
        self.assertTrue(Line(i, y).is_perp(Line(c, b)))
        self.assertAlmostEqual(i.distance(x), i.distance(y))
        self.assertAlmostEqual(i.distance(x), i.distance(z))

    def test_sketch_3peq(self):
        a, b, c = random_points(3)
        x, y, z = nm.sketch_3peq([a, b, c])

        self.assertTrue(check_coll([a, b, x]))
        self.assertTrue(check_coll([a, c, y]))
        self.assertTrue(check_coll([b, c, z]))
        self.assertTrue(check_coll([x, y, z]))
        self.assertAlmostEqual(z.distance(x), z.distance(y))

    def test_sketch_aline(self):
        a, b, c, d, e = random_points(5)
        ex = nm.sketch_aline([a, b, c, d, e])
        self.assertIsInstance(ex, HalfLine)
        self.assertEqual(ex.tail, e)
        x = ex.head
        self.assertAlmostEqual(ang_between(b, a, c), ang_between(e, d, x))

    def test_sketch_amirror(self):
        a, b, c = random_points(3)
        bx = nm.sketch_amirror([a, b, c])
        self.assertIsInstance(bx, HalfLine)
        # Was a bare `assert`; assertEqual matches test_sketch_aline.
        self.assertEqual(bx.tail, b)
        x = bx.head

        ang1 = ang_between(b, a, c)
        ang2 = ang_between(b, c, x)
        self.assertAlmostEqual(ang1, ang2)

    def test_sketch_bisect(self):
        a, b, c = random_points(3)
        line = nm.sketch_bisect([a, b, c])
        self.assertAlmostEqual(b.distance(line), 0.0)

        l = a.perpendicular_line(line)
        x = line_line_intersection(l, Line(b, c))
        self.assertAlmostEqual(a.distance(line), x.distance(line))

        d, _ = line_circle_intersection(line, Circle(b, radius=1))
        ang1 = ang_between(b, a, d)
        ang2 = ang_between(b, d, c)
        self.assertAlmostEqual(ang1, ang2)

    def test_sketch_bline(self):
        a, b = random_points(2)
        l = nm.sketch_bline([a, b])
        self.assertTrue(Line(a, b).is_perp(l))
        self.assertAlmostEqual(a.distance(l), b.distance(l))

    def test_sketch_cc_tangent(self):
        o = Point(0.0, 0.0)
        w = Point(1.0, 0.0)

        ra = unif(0.0, 0.6)
        rb = unif(0.4, 1.0)

        a = unif(0.0, np.pi)
        b = unif(0.0, np.pi)

        a = o + ra * Point(np.cos(a), np.sin(a))
        b = w + rb * Point(np.sin(b), np.cos(b))

        x, y, z, t = nm.sketch_cc_tangent([o, a, w, b])
        xy = Line(x, y)
        zt = Line(z, t)
        self.assertAlmostEqual(o.distance(xy), o.distance(a))
        self.assertAlmostEqual(o.distance(zt), o.distance(a))
        self.assertAlmostEqual(w.distance(xy), w.distance(b))
        self.assertAlmostEqual(w.distance(zt), w.distance(b))

    def test_sketch_circle(self):
        a, b, c = random_points(3)
        circle = nm.sketch_circle([a, b, c])
        self.assertAlmostEqual(circle.center.distance(a), 0.0)
        self.assertAlmostEqual(circle.radius, b.distance(c))

    def test_sketch_e5128(self):
        b = Point(0.0, 0.0)
        c = Point(0.0, 1.0)
        ang = unif(-np.pi / 2, 3 * np.pi / 2)
        d = head_from(c, ang, 1.0)
        a = Point(unif(0.5, 2.0), 0.0)

        e, g = nm.sketch_e5128([a, b, c, d])
        ang1 = ang_between(a, b, d)
        ang2 = ang_between(e, a, g)
        self.assertAlmostEqual(ang1, ang2)

    def test_sketch_eq_quadrangle(self):
        a, b, c, d = nm.sketch_eq_quadrangle([])
        self.assertAlmostEqual(a.distance(d), c.distance(b))
        ac = Line(a, c)
        # The second argument is the failure message, as in the old assert.
        self.assertTrue(ac.diff_side(b, d), (ac(b), ac(d)))
        bd = Line(b, d)
        self.assertTrue(bd.diff_side(a, c), (bd(a), bd(c)))

    def test_sketch_eq_trapezoid(self):
        a, b, c, d = nm.sketch_eq_trapezoid([])
        self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
        self.assertAlmostEqual(a.distance(d), b.distance(c))

    def test_sketch_eqangle3(self):
        points = random_points(5)
        x = nm.sketch_eqangle3(points).sample_within(points)[0]
        a, b, d, e, f = points
        self.assertTrue(check_eqangle([x, a, x, b, d, e, d, f]))

    def test_sketch_eqangle2(self):
        a, b, c = random_points(3)
        x = nm.sketch_eqangle2([a, b, c])
        ang1 = ang_between(a, b, x)
        ang2 = ang_between(c, x, b)
        self.assertAlmostEqual(ang1, ang2)

    def test_sketch_edia_quadrangle(self):
        a, b, c, d = nm.sketch_eqdia_quadrangle([])
        self.assertTrue(Line(a, c).diff_side(b, d))
        self.assertTrue(Line(b, d).diff_side(a, c))
        self.assertAlmostEqual(a.distance(c), b.distance(d))

    def test_sketch_isos(self):
        a, b, c = nm.sketch_isos([])
        self.assertAlmostEqual(a.distance(b), a.distance(c))
        self.assertAlmostEqual(ang_between(b, a, c), ang_between(c, b, a))

    def test_sketch_quadrange(self):
        a, b, c, d = nm.sketch_quadrangle([])
        self.assertTrue(Line(a, c).diff_side(b, d))
        self.assertTrue(Line(b, d).diff_side(a, c))

    def test_sketch_r_trapezoid(self):
        a, b, c, d = nm.sketch_r_trapezoid([])
        self.assertTrue(Line(a, b).is_perp(Line(a, d)))
        self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
        self.assertTrue(Line(a, c).diff_side(b, d))
        self.assertTrue(Line(b, d).diff_side(a, c))

    def test_sketch_r_triangle(self):
        a, b, c = nm.sketch_r_triangle([])
        self.assertTrue(Line(a, b).is_perp(Line(a, c)))

    def test_sketch_rectangle(self):
        a, b, c, d = nm.sketch_rectangle([])
        self.assertTrue(Line(a, b).is_perp(Line(b, c)))
        self.assertTrue(Line(b, c).is_perp(Line(c, d)))
        self.assertTrue(Line(c, d).is_perp(Line(d, a)))

    def test_sketch_reflect(self):
        a, b, c = random_points(3)
        x = nm.sketch_reflect([a, b, c])
        self.assertTrue(Line(a, x).is_perp(Line(b, c)))
        self.assertAlmostEqual(x.distance(Line(b, c)), a.distance(Line(b, c)))

    def test_sketch_risos(self):
        a, b, c = nm.sketch_risos([])
        self.assertAlmostEqual(a.distance(b), a.distance(c))
        self.assertTrue(Line(a, b).is_perp(Line(a, c)))

    def test_sketch_rotaten90(self):
        a, b = random_points(2)
        x = nm.sketch_rotaten90([a, b])
        self.assertAlmostEqual(a.distance(x), a.distance(b))
        self.assertTrue(Line(a, x).is_perp(Line(a, b)))
        d = Point(0.0, 0.0)
        e = Point(0.0, 1.0)
        f = Point(1.0, 0.0)
        self.assertAlmostEqual(ang_between(d, e, f), ang_between(a, b, x))

    def test_sketch_rotatep90(self):
        a, b = random_points(2)
        x = nm.sketch_rotatep90([a, b])
        self.assertAlmostEqual(a.distance(x), a.distance(b))
        self.assertTrue(Line(a, x).is_perp(Line(a, b)))
        d = Point(0.0, 0.0)
        e = Point(0.0, 1.0)
        f = Point(1.0, 0.0)
        self.assertAlmostEqual(ang_between(d, f, e), ang_between(a, b, x))

    def test_sketch_s_angle(self):
        a, b = random_points(2)
        y = unif(0.0, np.pi)
        # sketch_s_angle is handed the angle converted from radians to degrees.
        bx = nm.sketch_s_angle([a, b, y / np.pi * 180])
        self.assertIsInstance(bx, HalfLine)
        self.assertEqual(bx.tail, b)
        x = bx.head

        d = Point(1.0, 0.0)
        e = Point(0.0, 0.0)
        f = Point(np.cos(y), np.sin(y))
        self.assertAlmostEqual(ang_between(e, d, f), ang_between(b, a, x))

    def test_sketch_shift(self):
        a, b, c = random_points(3)
        x = nm.sketch_shift([a, b, c])
        self.assertTrue((b - a).close(x - c))

    def test_sketch_square(self):
        a, b = random_points(2)
        c, d = nm.sketch_square([a, b])
        self.assertTrue(Line(a, b).is_perp(Line(b, c)))
        self.assertTrue(Line(b, c).is_perp(Line(c, d)))
        self.assertTrue(Line(c, d).is_perp(Line(d, a)))
        self.assertAlmostEqual(a.distance(b), b.distance(c))

    def test_sketch_isquare(self):
        a, b, c, d = nm.sketch_isquare([])
        self.assertTrue(Line(a, b).is_perp(Line(b, c)))
        self.assertTrue(Line(b, c).is_perp(Line(c, d)))
        self.assertTrue(Line(c, d).is_perp(Line(d, a)))
        self.assertAlmostEqual(a.distance(b), b.distance(c))

    def test_sketch_trapezoid(self):
        a, b, c, d = nm.sketch_trapezoid([])
        self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
        self.assertTrue(Line(a, c).diff_side(b, d))
        self.assertTrue(Line(b, d).diff_side(a, c))

    def test_sketch_triangle(self):
        a, b, c = nm.sketch_triangle([])
        self.assertFalse(check_coll([a, b, c]))

    def test_sketch_triangle12(self):
        a, b, c = nm.sketch_triangle12([])
        self.assertAlmostEqual(a.distance(b) * 2, a.distance(c))

    def test_sketch_trisect(self):
        a, b, c = random_points(3)
        x, y = nm.sketch_trisect([a, b, c])
        self.assertAlmostEqual(ang_between(b, a, x), ang_between(b, x, y))
        self.assertAlmostEqual(ang_between(b, x, y), ang_between(b, y, c))
        self.assertAlmostEqual(ang_between(b, a, x) * 3, ang_between(b, a, c))

    def test_sketch_trisegment(self):
        a, b = random_points(2)
        x, y = nm.sketch_trisegment([a, b])
        self.assertAlmostEqual(
            a.distance(x) + x.distance(y) + y.distance(b), a.distance(b)
        )
        self.assertAlmostEqual(a.distance(x), x.distance(y))
        self.assertAlmostEqual(x.distance(y), y.distance(b))
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# Hand control to absltest's runner when this file is executed as a script.
if __name__ == '__main__':
    absltest.main()
|
external/alphageometry/pretty.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Utilities for string manipulation in the DSL."""
|
| 17 |
+
|
| 18 |
+
# One-character DSL symbol -> predicate name.
# NOTE: both '/' and '%' map to 'eqratio', so the table is not one-to-one.
# map_symbol_inv inverts it with a dict comprehension, which keeps the last
# key seen for a repeated value ('%' for 'eqratio').
MAP_SYMBOL = {
    'T': 'perp',
    'P': 'para',
    'D': 'cong',
    'S': 'simtri',
    'I': 'circle',
    'M': 'midp',
    'O': 'cyclic',
    'C': 'coll',
    '^': 'eqangle',
    '/': 'eqratio',
    '%': 'eqratio',
    '=': 'contri',
    'X': 'collx',
    'A': 'acompute',
    'R': 'rcompute',
    'Q': 'fixc',
    'E': 'fixl',
    'V': 'fixb',
    'H': 'fixt',
    'Z': 'fixp',
    'Y': 'ind',
}
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def map_symbol(c: str) -> str:
    """Look up the predicate name registered for DSL symbol `c`.

    Raises:
      KeyError: if `c` is not a key of MAP_SYMBOL.
    """
    predicate = MAP_SYMBOL[c]
    return predicate
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def map_symbol_inv(c: str) -> str:
    """Look up the DSL symbol registered for predicate name `c`.

    The inverse table is rebuilt from MAP_SYMBOL on every call. When several
    symbols share a predicate name, the one listed last in MAP_SYMBOL wins.

    Raises:
      KeyError: if `c` is not a value of MAP_SYMBOL.
    """
    symbol_of = {}
    for symbol, predicate in MAP_SYMBOL.items():
        # Later entries overwrite earlier ones with the same predicate name.
        symbol_of[predicate] = symbol
    return symbol_of[c]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _gcd(x: int, y: int) -> int:
|
| 52 |
+
while y:
|
| 53 |
+
x, y = y, x % y
|
| 54 |
+
return x
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def simplify(n: int, d: int) -> tuple[int, int]:
    """Reduce the fraction n/d by the common divisor found by `_gcd`.

    Returns:
      (n // g, d // g) where g = _gcd(n, d).

    Raises:
      ZeroDivisionError: if both `n` and `d` are 0 (the divisor is then 0).
    """
    divisor = _gcd(n, d)
    numerator = n // divisor
    denominator = d // divisor
    return (numerator, denominator)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def pretty2r(a: str, b: str, c: str, d: str) -> str:
    """Order the four point names of a ratio and join them with spaces.

    If `b` also appears in the second pair, the first pair is flipped so the
    shared name comes first. If the first name then equals `d`, the second
    pair is flipped as well, so the shared name leads both pairs.
    """
    first, second = (b, a) if b in (c, d) else (a, b)
    third, fourth = (d, c) if first == d else (c, d)
    return f'{first} {second} {third} {fourth}'
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def pretty2a(a: str, b: str, c: str, d: str) -> str:
    """Order the four point names of an angle and join them with spaces.

    The first pair is flipped when `b` occurs in the second pair; afterwards
    the second pair is flipped when the leading name equals `d`. The shared
    point, if any, therefore ends up first in both pairs.
    """
    if b in (c, d):
        lead, tail = b, a
    else:
        lead, tail = a, b

    if lead == d:
        other_lead, other_tail = d, c
    else:
        other_lead, other_tail = c, d

    return f'{lead} {tail} {other_lead} {other_tail}'
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def pretty_angle(a: str, b: str, c: str, d: str) -> str:
    """Format the angle between lines ab and cd.

    The names are first reordered so that a point shared by both pairs leads
    each pair. With a shared point the result is the three-letter form
    (angle sign followed by the vertex in the middle); otherwise it is the
    generic form `(ab-cd)` after the angle sign.
    """
    p, q = (b, a) if b in (c, d) else (a, b)
    r, s = (d, c) if p == d else (c, d)

    if p != r:
        # No common vertex: fall back to the line-pair notation.
        return f'\u2220({p}{q}-{r}{s})'
    return f'\u2220{q}{p}{s}'
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def pretty_nl(name: str, args: list[str]) -> str:
    """Render predicate `name` applied to `args` as a human-readable phrase.

    Long predicate names and, for some predicates, their one-character symbols
    are accepted. Branches that unpack `args` raise ValueError when the number
    of arguments does not match. A name with no branch here falls through and
    the function returns None.
    """
    if name in ('aconst', 'rconst'):
        a, b, c, d, y = args
        if name == 'aconst':
            return f'{pretty_angle(a, b, c, d)} = {y}'
        return f'{a}{b}:{c}{d} = {y}'

    if name == 'acompute':
        a, b, c, d = args
        return f'{pretty_angle(a, b, c, d)}'

    if name in ('coll', 'C'):
        return ','.join(args) + ' are collinear'

    if name == 'collx':
        # Duplicates are dropped through a set, so the order is not fixed.
        return ','.join(set(args)) + ' are collinear'

    if name in ('cyclic', 'O'):
        return ','.join(args) + ' are concyclic'

    if name in ('midp', 'midpoint', 'M'):
        x, a, b = args
        return f'{x} is midpoint of {a}{b}'

    if name in ('eqangle', 'eqangle6', '^'):
        a, b, c, d, e, f, g, h = args
        left = pretty_angle(a, b, c, d)
        right = pretty_angle(e, f, g, h)
        return f'{left} = {right}'

    if name in ('eqratio', 'eqratio6', '/'):
        return '{}{}:{}{} = {}{}:{}{}'.format(*args)

    if name == 'eqratio3':
        # Six arguments are required; only the last of the final two is used.
        a, b, c, d, _, o = args
        return f'S {o} {a} {b} {o} {c} {d}'

    if name in ('cong', 'D'):
        a, b, c, d = args
        return f'{a}{b} = {c}{d}'

    if name in ('perp', 'T', 'para', 'P'):
        relation = '\u27c2' if name in ('perp', 'T') else '\u2225'
        if len(args) == 2:  # this is algebraic derivation.
            ab, cd = args  # ab = 'd( ... )'
            return f'{ab} {relation} {cd}'
        a, b, c, d = args
        return f'{a}{b} {relation} {c}{d}'

    if name in ('simtri2', 'simtri', 'simtri*'):
        a, b, c, x, y, z = args
        return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}'

    if name in ('contri2', 'contri', 'contri*'):
        a, b, c, x, y, z = args
        return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}'

    if name in ('circle', 'I'):
        o, a, b, c = args
        return f'{o} is the circumcenter of \\Delta {a}{b}{c}'

    if name == 'foot':
        a, b, c, d = args
        return f'{a} is the foot of {b} on {c}{d}'

    return None
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def pretty(txt: str) -> str:
    """Convert a predicate into its compact one-symbol DSL form.

    Args:
      txt: either a string 'name arg1 arg2 ...' (split on single spaces) or an
        already split sequence whose first item is the predicate name.

    Returns:
      The predicate rewritten with its symbol prefix (for example 'cong a b c d'
      becomes 'D a b c d'). Unrecognised names are returned as the space-joined
      input. Branches that unpack the arguments raise ValueError when the
      argument count does not match.
    """
    tokens = txt.split(' ') if isinstance(txt, str) else txt
    name, *args = tokens

    # Predicates rendered as "<symbol> <all args>" without an arity check.
    for predicate, symbol in (
        ('ind', 'Y'),
        ('coll', 'C'),
        ('collx', 'X'),
        ('cyclic', 'O'),
    ):
        if name == predicate:
            return symbol + ' ' + ' '.join(args)

    if name in ('fixc', 'fixl', 'fixb', 'fixt', 'fixp'):
        return map_symbol_inv(name) + ' ' + ' '.join(args)

    if name in ('acompute', 'rcompute'):
        _a, _b, _c, _d = args  # unpacked only to enforce exactly four args
        symbol = 'A' if name == 'acompute' else 'R'
        return symbol + ' ' + ' '.join(args)

    if name == 'aconst':
        a, b, c, d, y = args
        return f'^ {pretty2a(a, b, c, d)} {y}'

    if name == 'rconst':
        a, b, c, d, y = args
        return f'/ {pretty2r(a, b, c, d)} {y}'

    if name in ('midp', 'midpoint'):
        x, a, b = args
        return f'M {x} {a} {b}'

    if name == 'eqangle':
        a, b, c, d, e, f, g, h = args
        return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}'

    if name == 'eqratio':
        a, b, c, d, e, f, g, h = args
        return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}'

    if name == 'eqratio3':
        # Six arguments are required; only the last of the final two is used.
        a, b, c, d, _, o = args
        return f'S {o} {a} {b} {o} {c} {d}'

    if name == 'cong':
        a, b, c, d = args
        return f'D {a} {b} {c} {d}'

    if name in ('perp', 'para'):
        symbol = 'T' if name == 'perp' else 'P'
        if len(args) == 2:  # this is algebraic derivation.
            ab, cd = args  # ab = 'd( ... )'
            return f'{symbol} {ab} {cd}'
        a, b, c, d = args
        return f'{symbol} {a} {b} {c} {d}'

    if name in ('simtri2', 'simtri', 'simtri*'):
        a, b, c, x, y, z = args
        return f'S {a} {b} {c} {x} {y} {z}'

    if name in ('contri2', 'contri', 'contri*'):
        a, b, c, x, y, z = args
        return f'= {a} {b} {c} {x} {y} {z}'

    if name == 'circle':
        o, a, b, c = args
        return f'I {o} {a} {b} {c}'

    if name == 'foot':
        a, b, c, d = args
        return f'F {a} {b} {c} {d}'

    return ' '.join(tokens)
|
external/alphageometry/problem.py
ADDED
|
@@ -0,0 +1,1133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Implements objects to represent problems, theorems, proofs, traceback."""
|
| 17 |
+
|
| 18 |
+
from __future__ import annotations
|
| 19 |
+
|
| 20 |
+
from collections import defaultdict # pylint: disable=g-importing-member
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
import geometry as gm
|
| 24 |
+
import pretty as pt
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# pylint: disable=protected-access
|
| 28 |
+
# pylint: disable=unused-variable
|
| 29 |
+
# pylint: disable=unused-argument
|
| 30 |
+
# pylint: disable=unused-assignment
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def reshape(l: list[Any], n: int = 1) -> list[list[Any]]:
  """Group a flat list into consecutive runs of n items.

  Args:
    l: Flat list; its length must be a multiple of n.
    n: Number of items per group.

  Returns:
    A lazy zip iterator yielding one n-tuple per group, in input order.
  """
  assert len(l) % n == 0
  # The slice starting at `offset` with step n holds the offset-th member of
  # every group; zipping the n slices together re-assembles the groups.
  strides = [l[offset::n] for offset in range(n)]
  return zip(*strides)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def isint(x: str) -> bool:
  """Return True when int(x) succeeds, False when it rejects the value.

  Only the exceptions int() raises for an unconvertible value are treated as
  "not an integer"; anything else (e.g. KeyboardInterrupt) propagates instead
  of being silently reported as False.
  """
  try:
    int(x)
  except (TypeError, ValueError, OverflowError):
    return False
  return True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class Construction:
  """A single predicate: a name followed by its arguments."""

  def __init__(self, name: str, args: list[str]):
    self.name = name
    self.args = args

  @classmethod
  def from_txt(cls, data: str) -> Construction:
    """Parse a single-space separated 'name arg1 arg2 ...' string."""
    name, *args = data.split(' ')
    return Construction(name, args)

  def translate(self, mapping: dict[str, str]) -> Construction:
    """Return a copy whose non-integer arguments are renamed via mapping."""
    renamed = []
    for arg in self.args:
      # Integer-like arguments are kept verbatim; everything else must be a
      # key of mapping.
      renamed.append(arg if isint(arg) else mapping[arg])
    return Construction(self.name, renamed)

  def txt(self) -> str:
    """Serialize back to the 'name arg1 arg2 ...' form."""
    return ' '.join([self.name, *self.args])
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class Clause:
  """One construction step: new points and the predicates that define them."""

  def __init__(self, points: list[str], constructions: list[Construction]):
    self.points = []
    self.nums = []
    self.constructions = constructions

    for point in points:
      coords = None
      # A point written as 'name@x_y' carries explicit coordinates.
      if isinstance(point, str) and '@' in point:
        point, raw = point.split('@')
        x_txt, y_txt = raw.split('_')
        coords = (float(x_txt), float(y_txt))
      self.points.append(point)
      self.nums.append(coords)

  @classmethod
  def from_txt(cls, data: str) -> Clause:
    """Parse 'p1 p2 = pred args, pred args'; ' =' gives an empty clause."""
    if data == ' =':
      return Clause([], [])
    points_txt, constructions_txt = data.split(' = ')
    point_names = points_txt.split(' ')
    constructions = [
        Construction.from_txt(item) for item in constructions_txt.split(', ')
    ]
    return Clause(point_names, constructions)

  def translate(self, mapping: dict[str, str]) -> Clause:
    """Rename this clause's points, extending mapping in place.

    A point not yet in mapping gets the next name in the sequence
    a..z, a1..z1, a2.. chosen from the current size of mapping.
    """
    translated = []
    for point in self.points:
      index = len(mapping)  # zero-based position of the next fresh name
      if index < 26:
        fresh = chr(ord('a') + index)
      else:
        fresh = chr(ord('a') + index % 26) + str(index // 26)

      # Existing entries win; only unseen points receive the fresh name.
      translated.append(mapping.setdefault(point, fresh))
    constructions = [c.translate(mapping) for c in self.constructions]
    return Clause(translated, constructions)

  def add(self, name: str, args: list[str]) -> None:
    """Append one more predicate to this clause."""
    self.constructions.append(Construction(name, args))

  def txt(self) -> str:
    """Serialize to 'p1 p2 = pred args, pred args'."""
    lhs = ' '.join(self.points)
    rhs = ', '.join(c.txt() for c in self.constructions)
    return f'{lhs} = {rhs}'
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def _gcd(x: int, y: int) -> int:
  """Euclid's algorithm: replace (x, y) by (y, x % y) until y is zero.

  The result carries whatever sign the remainder chain ends on, so it may be
  negative for negative inputs.
  """
  if not y:
    return x
  return _gcd(y, x % y)
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def simplify(n: int, d: int) -> tuple[int, int]:
  """Divide numerator n and denominator d by the value _gcd(n, d) returns."""
  common = _gcd(n, d)
  numerator = n // common
  denominator = d // common
  return numerator, denominator
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def compare_fn(dep: Dependency) -> tuple[Dependency, str]:
  """Sort key for a dependency: the item itself, then its pretty string."""
  pretty_form = pt.pretty(dep)
  return dep, pretty_form
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def sort_deps(deps: list[Dependency]) -> list[Dependency]:
  """Return a new list holding deps ordered by compare_fn."""
  ordered = list(deps)
  ordered.sort(key=compare_fn)
  return ordered
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class Problem:
  """One problem to solve: construction clauses plus an optional goal.

  Attributes:
    url: Name/identifier line of the problem ('' when none was given).
    clauses: The construction clauses, in order.
    goal: The predicate to prove, or None when the text had no ' ? ' part.

  translate() additionally sets a `mapping` attribute (old point name -> new
  point name) on the Problem it returns.
  """

  @classmethod
  def from_txt_file(
      cls, fname: str, to_dict: bool = False, translate: bool = True
  ):
    """Load problems from a text file.

    Blank lines are ignored; the remaining lines are consumed in pairs of
    (name line, problem line).

    Args:
      fname: Path of the file to read.
      to_dict: If True, return a dict keyed by problem url instead of a list.
      translate: Forwarded to from_txt.

    Returns:
      A list of Problem objects, or a dict url -> Problem when to_dict is set.
    """
    with open(fname, 'r') as f:
      lines = f.read().split('\n')

    lines = [l for l in lines if l]
    data = [
        cls.from_txt(url + '\n' + problem, translate)
        for (url, problem) in reshape(lines, 2)
    ]
    if to_dict:
      return cls.to_dict(data)
    return data

  @classmethod
  def from_txt(cls, data: str, translate: bool = True) -> Problem:
    """Load a problem from a str object.

    Args:
      data: Either 'clauses ? goal' / 'clauses', or the same preceded by a
        name line and a single newline.
      translate: If True, return the problem with point names translated.

    Returns:
      The parsed (and optionally translated) Problem.
    """
    url = ''
    if '\n' in data:
      url, data = data.split('\n')

    if ' ? ' in data:
      clauses, goal = data.split(' ? ')
      goal = Construction.from_txt(goal)
    else:
      clauses, goal = data, None

    clauses = clauses.split('; ')
    problem = Problem(
        url=url, clauses=[Clause.from_txt(c) for c in clauses], goal=goal
    )
    if translate:
      return problem.translate()
    return problem

  @classmethod
  def to_dict(cls, data: list[Problem]) -> dict[str, Problem]:
    """Index problems by their url; later duplicates overwrite earlier ones."""
    return {p.url: p for p in data}

  def __init__(
      self, url: str, clauses: list[Clause], goal: Construction | None
  ):
    self.url = url
    self.clauses = clauses
    self.goal = goal

  def copy(self) -> Problem:
    """Return a Problem with its own clause list sharing the same clauses."""
    return Problem(self.url, list(self.clauses), self.goal)

  def translate(self) -> Problem:  # to single-char point names
    """Translate point names into alphabetical."""
    mapping = {}
    # Clause.translate extends `mapping` in place, so names assigned by an
    # earlier clause are reused by later clauses and by the goal.
    clauses = [clause.translate(mapping) for clause in self.clauses]

    goal = self.goal.translate(mapping) if self.goal else self.goal

    p = Problem(self.url, clauses, goal)
    p.mapping = mapping
    return p

  def txt(self) -> str:
    """Serialize to 'clause; clause ? goal', or just the clauses if no goal.

    The goal-less form is the same text from_txt accepts, so the clauses are
    no longer dropped when there is nothing to prove.
    """
    clauses_txt = '; '.join([c.txt() for c in self.clauses])
    if not self.goal:
      return clauses_txt
    return clauses_txt + ' ? ' + self.goal.txt()

  def setup_str_from_problem(self, definitions: dict[str, Definition]) -> str:
    """Construct the <theorem_premises> string from Problem object.

    Note: a construction written with fewer arguments than its definition is
    completed in place by prepending the clause's points to its args.

    Args:
      definitions: Definition objects indexed by construction name.

    Returns:
      '{S} <group> : <deps> ; ... ? <goal>' where every dependency is followed
      by a two-digit running reference number.
    """
    ref = 0

    string = []
    for clause in self.clauses:
      group = {}  # point -> tuple of points introduced together with it
      p2deps = defaultdict(list)  # point group -> hashed dependency tuples
      for c in clause.constructions:
        cdef = definitions[c.name]

        if len(c.args) != len(cdef.construction.args):
          assert len(c.args) + len(clause.points) == len(cdef.construction.args)
          c.args = clause.points + c.args

        # Formal argument names of the definition -> actual arguments.
        mapping = dict(zip(cdef.construction.args, c.args))
        for points, bs in cdef.basics:
          points = tuple(mapping[x] for x in points)
          for p in points:
            group[p] = points

          for b in bs:
            args = [mapping[a] for a in b.args]
            name = b.name
            if b.name in ['s_angle', 'aconst']:
              x, y, z, v = args
              name = 'aconst'
              v = int(v)

              # A negative value is flipped by swapping the two outer points.
              if v < 0:
                v = -v
                x, z = z, x

              m, n = simplify(int(v), 180)
              args = [y, z, y, x, f'{m}pi/{n}']

            p2deps[points].append(hashed_txt(name, args))

      for k, v in p2deps.items():
        p2deps[k] = sort_deps(v)

      points = clause.points
      while points:
        p = points[0]
        gr = group[p]
        # Everything in the same group as p is emitted together with it.
        points = [x for x in points if x not in gr]

        deps_str = []
        for dep in p2deps[gr]:
          ref_str = '{:02}'.format(ref)
          dep_str = pt.pretty(dep)

          if dep[0] == 'aconst':
            # Replace the last token of the pretty string by 'm. pi / n.'.
            m, n = map(int, dep[-1].split('pi/'))
            mn = f'{m}. pi / {n}.'
            dep_str = ' '.join(dep_str.split()[:-1] + [mn])

          deps_str.append(dep_str + ' ' + ref_str)
          ref += 1

        string.append(' '.join(gr) + ' : ' + ' '.join(deps_str))

    string = '{S} ' + ' ; '.join([s.strip() for s in string])
    goal = self.goal
    string += ' ? ' + pt.pretty([goal.name] + goal.args)
    return string
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
def parse_rely(s: str) -> dict[str, str]:
  """Parse a comma-separated list of 'a b : c d' entries.

  Every whitespace-separated name on the left of ':' is mapped to the list of
  whitespace-separated names on the right; names from the same entry share
  one list object. An empty input yields an empty dict.
  """
  if not s:
    return {}

  result = {}
  for entry in s.split(','):
    left, right = entry.strip().split(':')
    sources = right.strip().split()
    for target in left.strip().split():
      result[target] = sources
  return result
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
class Definition:
  """Definition of one construction statement.

  Besides the five constructor arguments, __init__ splits the construction's
  arguments into `args` (those that occur in some numeric predicate) and
  `points` (all the others), both in declaration order.
  """

  def __init__(
      self,
      construction: Construction,
      rely: dict[str, str],
      deps: Clause,
      basics: list[tuple[list[str], list[Construction]]],
      numerics: list[Construction],
  ):
    self.construction = construction
    self.rely = rely
    self.deps = deps
    self.basics = basics
    self.numerics = numerics

    # Every argument name mentioned by any numeric predicate.
    numeric_args = set()
    for numeric in numerics:
      numeric_args.update(numeric.args)

    declared = self.construction.args
    self.points = [p for p in declared if p not in numeric_args]
    self.args = [p for p in declared if p in numeric_args]

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Definition:
    """Read the file at fname and parse it with from_string."""
    with open(fname, 'r') as f:
      content = f.read()
    return cls.from_string(content, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Definition:
    """Parse definitions laid out as consecutive groups of six lines.

    Returns a list of Definition objects, or a dict keyed by construction
    name when to_dict is True.
    """
    groups = reshape(string.split('\n'), 6)
    parsed = [cls.from_txt('\n'.join(group)) for group in groups]
    if not to_dict:
      return parsed
    return cls.to_dict(parsed)

  @classmethod
  def to_dict(cls, data: list[Definition]) -> dict[str, Definition]:
    """Index definitions by the name of their construction."""
    return {d.construction.name: d for d in data}

  @classmethod
  def from_txt(cls, data: str) -> Definition:
    """Load one definition from its six-line text form.

    The lines are, in order: construction, rely, deps, basics, numerics and
    one line that is ignored.
    """
    construction, rely, deps, basics, numerics, _ = data.split('\n')

    # 'basics' is a ';'-separated list of levels, each written either as
    # 'points : pred, pred' or as just 'pred, pred'.
    levels = []
    for level_txt in basics.split(';') if basics else []:
      level_txt = level_txt.strip()
      points = []
      if ':' in level_txt:
        points_txt, level_txt = level_txt.split(':')
        points = points_txt.strip().split()

      body = level_txt.strip()
      predicates = []
      if body:
        predicates = [
            Construction.from_txt(item.strip()) for item in body.split(',')
        ]
      levels.append((points, predicates))

    numeric_items = numerics.split(', ') if numerics else []

    return Definition(
        construction=Construction.from_txt(construction),
        rely=parse_rely(rely),
        deps=Clause.from_txt(deps),
        basics=levels,
        numerics=[Construction.from_txt(item) for item in numeric_items],
    )
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
class Theorem:
  """A deduction rule: a list of premises implying exactly one conclusion."""

  def __init__(
      self, premise: list[Construction], conclusion: list[Construction]
  ):
    if len(conclusion) != 1:
      raise ValueError('Cannot have more than one conclusion')
    predicates = premise + conclusion
    self.name = '_'.join([p.name for p in predicates])
    self.premise = premise
    self.conclusion = conclusion
    self.is_arg_reduce = False

    assert len(self.conclusion) == 1
    con = self.conclusion[0]

    # Rules concluding one of these predicates are never flagged.
    never_reducing = (
        'eqratio3',
        'midp',
        'contri',
        'simtri',
        'contri2',
        'simtri2',
        'simtri*',
        'contri*',
    )
    if con.name in never_reducing:
      return

    # Flag the rule when the premises mention no more distinct arguments
    # than the conclusion does.
    prem_args = {arg for p in self.premise for arg in p.args}
    if len(prem_args) <= len(set(con.args)):
      self.is_arg_reduce = True

  @classmethod
  def from_txt_file(cls, fname: str, to_dict: bool = False) -> Theorem:
    """Read the file at fname and parse it with from_string."""
    with open(fname, 'r') as f:
      content = f.read()
    return cls.from_string(content, to_dict)

  @classmethod
  def from_string(cls, string: str, to_dict: bool = False) -> Theorem:
    """Load deduction rules from a str object, one rule per line.

    Empty lines and lines starting with '#' are skipped. Rules are numbered
    'r00', 'r01', ... in order of appearance. Returns the list of rules, or a
    dict keyed by rule_name when to_dict is True.
    """
    rules = [
        cls.from_txt(line)
        for line in string.split('\n')
        if line and not line.startswith('#')
    ]

    for index, rule in enumerate(rules):
      rule.rule_name = f'r{index:02}'

    if not to_dict:
      return rules

    by_rule_name = {}
    for rule in rules:
      # NOTE(review): this tests rule.name against a dict keyed by rule_name,
      # so it looks like it can only fire if a name equals an 'rNN' key -
      # confirm the intent before relying on the renaming.
      if rule.name in by_rule_name:
        rule.name += '_'
      by_rule_name[rule.rule_name] = rule
    return by_rule_name

  @classmethod
  def from_txt(cls, data: str) -> Theorem:
    """Parse 'premise, premise => conclusion' into a Theorem."""
    premise_txt, conclusion_txt = data.split(' => ')
    premise_items = premise_txt.split(', ')
    conclusion_items = conclusion_txt.split(', ')
    return Theorem(
        premise=[Construction.from_txt(item) for item in premise_items],
        conclusion=[Construction.from_txt(item) for item in conclusion_items],
    )

  def txt(self) -> str:
    """Serialize back to 'premise, premise => conclusion'."""
    lhs = ', '.join([p.txt() for p in self.premise])
    rhs = ', '.join([c.txt() for c in self.conclusion])
    return lhs + ' => ' + rhs

  def conclusion_name_args(
      self, mapping: dict[str, gm.Point]
  ) -> tuple[str, list[gm.Point]]:
    """Return the conclusion's name and its arguments looked up in mapping.

    Only the string keys of mapping are considered.
    """
    by_name = {}
    for key, value in mapping.items():
      if isinstance(key, str):
        by_name[key] = value
    con = self.conclusion[0]
    return con.name, [by_name[arg] for arg in con.args]
|
| 456 |
+
|
| 457 |
+
|
| 458 |
+
def why_eqratio(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
) -> list[Dependency]:
  """Explain why the ratio d1:d2 equals the ratio d3:d4.

  Every pair of candidates from gm.all_ratios(d1, d2) and
  gm.all_ratios(d3, d4) that gm.why_equal can justify is considered, and the
  one with the fewest combined reasons is kept.

  Returns:
    None if no candidate pair can be justified; the reasons for the ratio
    equality alone if the chosen candidates are d1..d4 themselves; otherwise
    a list with an 'eqratio' Dependency over the candidates' points (when it
    has reasons) followed by one 'cong' Dependency per side that has reasons.
  """
  candidates12 = list(gm.all_ratios(d1, d2, level))
  candidates34 = list(gm.all_ratios(d3, d4, level))

  best = None
  for ratio12, d1s, d2s in candidates12:
    for ratio34, d3s, d4s in candidates34:
      why_ratio = gm.why_equal(ratio12, ratio34, level)
      if why_ratio is None:
        continue
      rep1, rep2 = ratio12._l
      rep3, rep4 = ratio34._l
      why1 = gm.bfs_backtrack(d1, [rep1], d1s)
      why2 = gm.bfs_backtrack(d2, [rep2], d2s)
      why3 = gm.bfs_backtrack(d3, [rep3], d3s)
      why4 = gm.bfs_backtrack(d4, [rep4], d4s)
      combined = why_ratio + why1 + why2 + why3 + why4
      # Strict '<' keeps the first candidate among equally short ones.
      if best is None or len(combined) < len(best[0]):
        best = (combined, ratio12, ratio34, why_ratio, [why1, why2, why3, why4])

  if best is None:
    return None

  _, ratio12, ratio34, why_ratio, side_whys = best
  rep1, rep2 = ratio12._l
  rep3, rep4 = ratio34._l
  originals = [d1, d2, d3, d4]
  representatives = [rep1, rep2, rep3, rep4]

  if all(d == rep for d, rep in zip(originals, representatives)):
    return why_ratio

  (a_, b_), (c_, d_) = rep1._obj.points, rep2._obj.points
  (e_, f_), (g_, h_) = rep3._obj.points, rep4._obj.points
  deps = []
  if why_ratio:
    ratio_dep = Dependency(
        'eqratio', [a_, b_, c_, d_, e_, f_, g_, h_], '', level
    )
    ratio_dep.why = why_ratio
    deps.append(ratio_dep)

  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  original_pairs = [(a, b), (c, d), (e, f), (g, h)]
  representative_pairs = [(a_, b_), (c_, d_), (e_, f_), (g_, h_)]
  sides = zip(side_whys, original_pairs, representative_pairs)
  for side_why, (x, y), (x_, y_) in sides:
    if not side_why:
      continue
    side_dep = Dependency('cong', [x, y, x_, y_], '', level)
    side_dep.why = side_why
    deps.append(side_dep)

  return deps
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def why_eqangle(
    d1: gm.Direction,
    d2: gm.Direction,
    d3: gm.Direction,
    d4: gm.Direction,
    level: int,
    verbose: bool = False,
) -> list[Dependency]:
  """Explain why the angle (d1, d2) equals the angle (d3, d4).

  Every pair of candidates from gm.all_angles(d1, d2) and
  gm.all_angles(d3, d4) that gm.why_equal can justify is considered, and the
  one with the fewest combined reasons is kept. `verbose` is not used.

  Returns:
    None if no candidate pair can be justified. Otherwise a pair
    (directions, reasons) where directions are the four directions of the
    chosen candidates. If those are d1..d4 themselves, reasons is what
    gm.why_equal returned; else it is a list with an 'eqangle' Dependency
    over the candidates' points (when it has reasons) followed by one 'collx'
    or 'para' Dependency per side that has reasons.
  """
  candidates12 = list(gm.all_angles(d1, d2, level))
  candidates34 = list(gm.all_angles(d3, d4, level))

  best = None
  for angle12, d1s, d2s in candidates12:
    for angle34, d3s, d4s in candidates34:
      why_angle = gm.why_equal(angle12, angle34, level)
      if why_angle is None:
        continue
      rep1, rep2 = angle12._d
      rep3, rep4 = angle34._d
      why1 = gm.bfs_backtrack(d1, [rep1], d1s)
      why2 = gm.bfs_backtrack(d2, [rep2], d2s)
      why3 = gm.bfs_backtrack(d3, [rep3], d3s)
      why4 = gm.bfs_backtrack(d4, [rep4], d4s)
      combined = why_angle + why1 + why2 + why3 + why4
      # Strict '<' keeps the first candidate among equally short ones.
      if best is None or len(combined) < len(best[0]):
        best = (combined, angle12, angle34, [why1, why2, why3, why4])

  if best is None:
    return None

  _, angle12, angle34, side_whys = best
  # The equality reasons are queried again for the chosen pair.
  why_angle = gm.why_equal(angle12, angle34, level)
  rep1, rep2 = angle12._d
  rep3, rep4 = angle34._d
  originals = [d1, d2, d3, d4]
  representatives = [rep1, rep2, rep3, rep4]
  chosen = (rep1, rep2, rep3, rep4)

  if all(d == rep for d, rep in zip(originals, representatives)):
    return chosen, why_angle

  (a_, b_), (c_, d_) = rep1._obj.points, rep2._obj.points
  (e_, f_), (g_, h_) = rep3._obj.points, rep4._obj.points
  deps = []
  if why_angle:
    angle_dep = Dependency(
        'eqangle', [a_, b_, c_, d_, e_, f_, g_, h_], '', None
    )
    angle_dep.why = why_angle
    deps.append(angle_dep)

  (a, b), (c, d) = d1._obj.points, d2._obj.points
  (e, f), (g, h) = d3._obj.points, d4._obj.points
  sides = zip(
      side_whys,
      originals,
      [(a, b), (c, d), (e, f), (g, h)],
      representatives,
      [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
  )
  for side_why, direction, (x, y), rep, (x_, y_) in sides:
    line, rep_line = direction._obj, rep._obj
    if not side_why:
      continue
    # Same underlying object -> 'collx'; different objects -> 'para'.
    name = 'collx' if line == rep_line else 'para'
    side_dep = Dependency(name, [x_, y_, x, y], '', None)
    side_dep.why = side_why
    deps.append(side_dep)

  return chosen, deps
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
# Rule name that Dependency.populate requires of the dependency it copies.
CONSTRUCTION_RULE = 'c0'
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
class EmptyDependency:
  """A dependency without predicate yet: level, rule name and reasons only."""

  def __init__(self, level: int, rule_name: str):
    self.level = level
    self.rule_name = rule_name if rule_name else ''
    self.empty = True
    self.why = []
    self.trace = None

  def populate(self, name: str, args: list[gm.Point]) -> Dependency:
    """Build the Dependency (name, args) carrying this object's reasons.

    The result gets a copy of the `why` list and this object's trace stored
    under the attribute `trace2`.
    """
    filled = Dependency(name, args, self.rule_name, self.level)
    filled.trace2 = self.trace
    filled.why = list(self.why)
    return filled

  def copy(self) -> EmptyDependency:
    """Return a new EmptyDependency with its own copy of the `why` list."""
    duplicate = EmptyDependency(self.level, self.rule_name)
    duplicate.why = list(self.why)
    return duplicate

  def extend(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name: str,
      args: list[gm.Point],
  ) -> EmptyDependency:
    """Extend the dependency list by (name, args).

    Returns a fresh EmptyDependency whose reasons are this object populated
    as (name0, args0) followed by the resolved (name, args) dependency.
    """
    base = self.populate(name0, args0)
    extended = EmptyDependency(level=self.level, rule_name=None)
    extra = Dependency(name, args, None, extended.level)
    extended.why = [base, extra.why_me_or_cache(g, None)]
    return extended

  def extend_many(
      self,
      g: Any,
      name0: str,
      args0: list[gm.Point],
      name_args: list[tuple[str, list[gm.Point]]],
  ) -> EmptyDependency:
    """Extend the dependency list by many name_args.

    Returns self unchanged when name_args is empty; otherwise a fresh
    EmptyDependency whose reasons are this object populated as (name0, args0)
    followed by one resolved dependency per (name, args) pair.
    """
    if not name_args:
      return self

    reasons = [self.populate(name0, args0)]
    extended = EmptyDependency(level=self.level, rule_name=None)
    for name, args in name_args:
      extra = Dependency(name, args, None, extended.level)
      reasons.append(extra.why_me_or_cache(g, None))
    extended.why = reasons
    return extended
|
| 643 |
+
|
| 644 |
+
|
| 645 |
+
def maybe_make_equal_pairs(
    a: gm.Point,
    b: gm.Point,
    c: gm.Point,
    d: gm.Point,
    m: gm.Point,
    n: gm.Point,
    p: gm.Point,
    q: gm.Point,
    ab: gm.Line,
    mn: gm.Line,
    g: Any,
    level: int,
) -> list[Dependency]:
  """Make a-b:c-d==m-n:p-q in case a-b==m-n or c-d==p-q.

  Args:
    a, b, c, d, m, n, p, q: End points of the four pairs.
    ab: Object through a and b; a gm.Line selects 'para' reasoning, anything
      else selects 'cong'.
    mn: Object through m and n.
    g: Graph handed to Dependency.why_me to resolve the reasons.
    level: Level given to the created dependencies.

  Returns:
    An empty list when ab and mn differ, so callers can always extend their
    reason list with the result. Otherwise an optional 'collx' dependency
    over a, b, m, n (only for lines, and only when they are more than two
    distinct points) followed by the 'para'/'cong' dependency over c, d, p, q.
  """
  if ab != mn:
    return []

  why = []
  eqname = 'para' if isinstance(ab, gm.Line) else 'cong'
  colls = [a, b, m, n]
  if len(set(colls)) > 2 and eqname == 'para':
    dep = Dependency('collx', colls, None, level)
    dep.why_me(g, level)
    why.append(dep)

  dep = Dependency(eqname, [c, d, p, q], None, level)
  dep.why_me(g, level)
  why.append(dep)
  return why
|
| 674 |
+
|
| 675 |
+
|
| 676 |
+
class Dependency(Construction):
|
| 677 |
+
"""Dependency is a predicate that other predicates depend on."""
|
| 678 |
+
|
| 679 |
+
def __init__(
|
| 680 |
+
self, name: str, args: list[gm.Point], rule_name: str, level: int
|
| 681 |
+
):
|
| 682 |
+
super().__init__(name, args)
|
| 683 |
+
self.rule_name = rule_name or ''
|
| 684 |
+
self.level = level
|
| 685 |
+
self.why = []
|
| 686 |
+
|
| 687 |
+
self._stat = None
|
| 688 |
+
self.trace = None
|
| 689 |
+
|
| 690 |
+
def _find(self, dep_hashed: tuple[str, ...]) -> Dependency:
|
| 691 |
+
for w in self.why:
|
| 692 |
+
f = w._find(dep_hashed)
|
| 693 |
+
if f:
|
| 694 |
+
return f
|
| 695 |
+
if w.hashed() == dep_hashed:
|
| 696 |
+
return w
|
| 697 |
+
|
| 698 |
+
def remove_loop(self) -> Dependency:
|
| 699 |
+
f = self._find(self.hashed())
|
| 700 |
+
if f:
|
| 701 |
+
return f
|
| 702 |
+
return self
|
| 703 |
+
|
| 704 |
+
def copy(self) -> Dependency:
|
| 705 |
+
dep = Dependency(self.name, self.args, self.rule_name, self.level)
|
| 706 |
+
dep.trace = self.trace
|
| 707 |
+
dep.why = list(self.why)
|
| 708 |
+
return dep
|
| 709 |
+
|
| 710 |
+
def why_me_or_cache(self, g: Any, level: int) -> Dependency:
|
| 711 |
+
if self.hashed() in g.cache:
|
| 712 |
+
return g.cache[self.hashed()]
|
| 713 |
+
self.why_me(g, level)
|
| 714 |
+
return self
|
| 715 |
+
|
| 716 |
+
def populate(self, name: str, args: list[gm.Point]) -> Dependency:
|
| 717 |
+
assert self.rule_name == CONSTRUCTION_RULE, self.rule_name
|
| 718 |
+
dep = Dependency(self.name, self.args, self.rule_name, self.level)
|
| 719 |
+
dep.why = list(self.why)
|
| 720 |
+
return dep
|
| 721 |
+
|
| 722 |
+
def why_me(self, g: Any, level: int) -> None:
|
| 723 |
+
"""Figure out the dependencies predicates of self."""
|
| 724 |
+
name, args = self.name, self.args
|
| 725 |
+
|
| 726 |
+
hashed_me = hashed(name, args)
|
| 727 |
+
if hashed_me in g.cache:
|
| 728 |
+
dep = g.cache[hashed_me]
|
| 729 |
+
self.why = dep.why
|
| 730 |
+
self.rule_name = dep.rule_name
|
| 731 |
+
return
|
| 732 |
+
|
| 733 |
+
if self.name == 'para':
|
| 734 |
+
a, b, c, d = self.args
|
| 735 |
+
if {a, b} == {c, d}:
|
| 736 |
+
self.why = []
|
| 737 |
+
return
|
| 738 |
+
|
| 739 |
+
ab = g._get_line(a, b)
|
| 740 |
+
cd = g._get_line(c, d)
|
| 741 |
+
if ab == cd:
|
| 742 |
+
if {a, b} == {c, d}:
|
| 743 |
+
self.why = []
|
| 744 |
+
self.rule_name = ''
|
| 745 |
+
return
|
| 746 |
+
dep = Dependency('coll', list({a, b, c, d}), 't??', None)
|
| 747 |
+
self.why = [dep.why_me_or_cache(g, level)]
|
| 748 |
+
return
|
| 749 |
+
|
| 750 |
+
for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
|
| 751 |
+
x_, y_ = xy.points
|
| 752 |
+
if {x, y} == {x_, y_}:
|
| 753 |
+
continue
|
| 754 |
+
d = Dependency('collx', [x, y, x_, y_], None, level)
|
| 755 |
+
self.why += [d.why_me_or_cache(g, level)]
|
| 756 |
+
|
| 757 |
+
whypara = g.why_equal(ab, cd, None)
|
| 758 |
+
self.why += whypara
|
| 759 |
+
|
| 760 |
+
elif self.name == 'midp':
|
| 761 |
+
m, a, b = self.args
|
| 762 |
+
ma = g._get_segment(m, a)
|
| 763 |
+
mb = g._get_segment(m, b)
|
| 764 |
+
dep = Dependency('coll', [m, a, b], None, None).why_me_or_cache(g, None)
|
| 765 |
+
self.why = [dep] + g.why_equal(ma, mb, level)
|
| 766 |
+
|
| 767 |
+
elif self.name == 'perp':
|
| 768 |
+
a, b, c, d = self.args
|
| 769 |
+
ab = g._get_line(a, b)
|
| 770 |
+
cd = g._get_line(c, d)
|
| 771 |
+
for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
|
| 772 |
+
x_, y_ = xy.points
|
| 773 |
+
if {x, y} == {x_, y_}:
|
| 774 |
+
continue
|
| 775 |
+
d = Dependency('collx', [x, y, x_, y_], None, level)
|
| 776 |
+
self.why += [d.why_me_or_cache(g, level)]
|
| 777 |
+
|
| 778 |
+
_, why = why_eqangle(ab._val, cd._val, cd._val, ab._val, level)
|
| 779 |
+
a, b = ab.points
|
| 780 |
+
c, d = cd.points
|
| 781 |
+
|
| 782 |
+
if hashed(self.name, [a, b, c, d]) != self.hashed():
|
| 783 |
+
d = Dependency(self.name, [a, b, c, d], None, level)
|
| 784 |
+
d.why = why
|
| 785 |
+
why = [d]
|
| 786 |
+
|
| 787 |
+
self.why += why
|
| 788 |
+
|
| 789 |
+
elif self.name == 'cong':
|
| 790 |
+
a, b, c, d = self.args
|
| 791 |
+
ab = g._get_segment(a, b)
|
| 792 |
+
cd = g._get_segment(c, d)
|
| 793 |
+
|
| 794 |
+
self.why = g.why_equal(ab, cd, level)
|
| 795 |
+
|
| 796 |
+
elif self.name == 'coll':
|
| 797 |
+
_, why = gm.line_of_and_why(self.args, level)
|
| 798 |
+
self.why = why
|
| 799 |
+
|
| 800 |
+
elif self.name == 'collx':
|
| 801 |
+
if g.check_coll(self.args):
|
| 802 |
+
args = list(set(self.args))
|
| 803 |
+
hashed_me = hashed('coll', args)
|
| 804 |
+
if hashed_me in g.cache:
|
| 805 |
+
dep = g.cache[hashed_me]
|
| 806 |
+
self.why = [dep]
|
| 807 |
+
self.rule_name = ''
|
| 808 |
+
return
|
| 809 |
+
_, self.why = gm.line_of_and_why(args, level)
|
| 810 |
+
else:
|
| 811 |
+
self.name = 'para'
|
| 812 |
+
self.why_me(g, level)
|
| 813 |
+
|
| 814 |
+
elif self.name == 'cyclic':
|
| 815 |
+
_, why = gm.circle_of_and_why(self.args, level)
|
| 816 |
+
self.why = why
|
| 817 |
+
|
| 818 |
+
elif self.name == 'circle':
|
| 819 |
+
o, a, b, c = self.args
|
| 820 |
+
oa = g._get_segment(o, a)
|
| 821 |
+
ob = g._get_segment(o, b)
|
| 822 |
+
oc = g._get_segment(o, c)
|
| 823 |
+
self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, oc, level)
|
| 824 |
+
|
| 825 |
+
elif self.name in ['eqangle', 'eqangle6']:
|
| 826 |
+
a, b, c, d, m, n, p, q = self.args
|
| 827 |
+
|
| 828 |
+
ab, why1 = g.get_line_thru_pair_why(a, b)
|
| 829 |
+
cd, why2 = g.get_line_thru_pair_why(c, d)
|
| 830 |
+
mn, why3 = g.get_line_thru_pair_why(m, n)
|
| 831 |
+
pq, why4 = g.get_line_thru_pair_why(p, q)
|
| 832 |
+
|
| 833 |
+
if ab is None or cd is None or mn is None or pq is None:
|
| 834 |
+
if {a, b} == {m, n}:
|
| 835 |
+
d = Dependency('para', [c, d, p, q], None, level)
|
| 836 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 837 |
+
if {a, b} == {c, d}:
|
| 838 |
+
d = Dependency('para', [p, q, m, n], None, level)
|
| 839 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 840 |
+
if {c, d} == {p, q}:
|
| 841 |
+
d = Dependency('para', [a, b, m, n], None, level)
|
| 842 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 843 |
+
if {p, q} == {m, n}:
|
| 844 |
+
d = Dependency('para', [a, b, c, d], None, level)
|
| 845 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 846 |
+
return
|
| 847 |
+
|
| 848 |
+
for (x, y), xy, whyxy in zip(
|
| 849 |
+
[(a, b), (c, d), (m, n), (p, q)],
|
| 850 |
+
[ab, cd, mn, pq],
|
| 851 |
+
[why1, why2, why3, why4],
|
| 852 |
+
):
|
| 853 |
+
x_, y_ = xy.points
|
| 854 |
+
if {x, y} == {x_, y_}:
|
| 855 |
+
continue
|
| 856 |
+
d = Dependency('collx', [x, y, x_, y_], None, level)
|
| 857 |
+
d.why = whyxy
|
| 858 |
+
self.why += [d]
|
| 859 |
+
|
| 860 |
+
a, b = ab.points
|
| 861 |
+
c, d = cd.points
|
| 862 |
+
m, n = mn.points
|
| 863 |
+
p, q = pq.points
|
| 864 |
+
diff = hashed(self.name, [a, b, c, d, m, n, p, q]) != self.hashed()
|
| 865 |
+
|
| 866 |
+
whyeqangle = None
|
| 867 |
+
if ab._val and cd._val and mn._val and pq._val:
|
| 868 |
+
whyeqangle = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
|
| 869 |
+
|
| 870 |
+
if whyeqangle:
|
| 871 |
+
(dab, dcd, dmn, dpq), whyeqangle = whyeqangle
|
| 872 |
+
if diff:
|
| 873 |
+
d = Dependency('eqangle', [a, b, c, d, m, n, p, q], None, level)
|
| 874 |
+
d.why = whyeqangle
|
| 875 |
+
whyeqangle = [d]
|
| 876 |
+
self.why += whyeqangle
|
| 877 |
+
|
| 878 |
+
else:
|
| 879 |
+
if (ab == cd and mn == pq) or (ab == mn and cd == pq):
|
| 880 |
+
self.why += []
|
| 881 |
+
elif ab == mn:
|
| 882 |
+
self.why += maybe_make_equal_pairs(
|
| 883 |
+
a, b, c, d, m, n, p, q, ab, mn, g, level
|
| 884 |
+
)
|
| 885 |
+
elif cd == pq:
|
| 886 |
+
self.why += maybe_make_equal_pairs(
|
| 887 |
+
c, d, a, b, p, q, m, n, cd, pq, g, level
|
| 888 |
+
)
|
| 889 |
+
elif ab == cd:
|
| 890 |
+
self.why += maybe_make_equal_pairs(
|
| 891 |
+
a, b, m, n, c, d, p, q, ab, cd, g, level
|
| 892 |
+
)
|
| 893 |
+
elif mn == pq:
|
| 894 |
+
self.why += maybe_make_equal_pairs(
|
| 895 |
+
m, n, a, b, p, q, c, d, mn, pq, g, level
|
| 896 |
+
)
|
| 897 |
+
elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
|
| 898 |
+
dep1 = Dependency('para', [a, b, m, n], None, level)
|
| 899 |
+
dep1.why_me(g, level)
|
| 900 |
+
dep2 = Dependency('para', [c, d, p, q], None, level)
|
| 901 |
+
dep2.why_me(g, level)
|
| 902 |
+
self.why += [dep1, dep2]
|
| 903 |
+
elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
|
| 904 |
+
dep1 = Dependency('para', [a, b, c, d], None, level)
|
| 905 |
+
dep1.why_me(g, level)
|
| 906 |
+
dep2 = Dependency('para', [m, n, p, q], None, level)
|
| 907 |
+
dep2.why_me(g, level)
|
| 908 |
+
self.why += [dep1, dep2]
|
| 909 |
+
elif ab._val and cd._val and mn._val and pq._val:
|
| 910 |
+
self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
|
| 911 |
+
|
| 912 |
+
elif self.name in ['eqratio', 'eqratio6']:
|
| 913 |
+
a, b, c, d, m, n, p, q = self.args
|
| 914 |
+
ab = g._get_segment(a, b)
|
| 915 |
+
cd = g._get_segment(c, d)
|
| 916 |
+
mn = g._get_segment(m, n)
|
| 917 |
+
pq = g._get_segment(p, q)
|
| 918 |
+
|
| 919 |
+
if ab is None or cd is None or mn is None or pq is None:
|
| 920 |
+
if {a, b} == {m, n}:
|
| 921 |
+
d = Dependency('cong', [c, d, p, q], None, level)
|
| 922 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 923 |
+
if {a, b} == {c, d}:
|
| 924 |
+
d = Dependency('cong', [p, q, m, n], None, level)
|
| 925 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 926 |
+
if {c, d} == {p, q}:
|
| 927 |
+
d = Dependency('cong', [a, b, m, n], None, level)
|
| 928 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 929 |
+
if {p, q} == {m, n}:
|
| 930 |
+
d = Dependency('cong', [a, b, c, d], None, level)
|
| 931 |
+
self.why = [d.why_me_or_cache(g, level)]
|
| 932 |
+
return
|
| 933 |
+
|
| 934 |
+
if ab._val and cd._val and mn._val and pq._val:
|
| 935 |
+
self.why = why_eqratio(ab._val, cd._val, mn._val, pq._val, level)
|
| 936 |
+
|
| 937 |
+
if self.why is None:
|
| 938 |
+
self.why = []
|
| 939 |
+
if (ab == cd and mn == pq) or (ab == mn and cd == pq):
|
| 940 |
+
self.why = []
|
| 941 |
+
elif ab == mn:
|
| 942 |
+
self.why += maybe_make_equal_pairs(
|
| 943 |
+
a, b, c, d, m, n, p, q, ab, mn, g, level
|
| 944 |
+
)
|
| 945 |
+
elif cd == pq:
|
| 946 |
+
self.why += maybe_make_equal_pairs(
|
| 947 |
+
c, d, a, b, p, q, m, n, cd, pq, g, level
|
| 948 |
+
)
|
| 949 |
+
elif ab == cd:
|
| 950 |
+
self.why += maybe_make_equal_pairs(
|
| 951 |
+
a, b, m, n, c, d, p, q, ab, cd, g, level
|
| 952 |
+
)
|
| 953 |
+
elif mn == pq:
|
| 954 |
+
self.why += maybe_make_equal_pairs(
|
| 955 |
+
m, n, a, b, p, q, c, d, mn, pq, g, level
|
| 956 |
+
)
|
| 957 |
+
elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
|
| 958 |
+
dep1 = Dependency('cong', [a, b, m, n], None, level)
|
| 959 |
+
dep1.why_me(g, level)
|
| 960 |
+
dep2 = Dependency('cong', [c, d, p, q], None, level)
|
| 961 |
+
dep2.why_me(g, level)
|
| 962 |
+
self.why += [dep1, dep2]
|
| 963 |
+
elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
|
| 964 |
+
dep1 = Dependency('cong', [a, b, c, d], None, level)
|
| 965 |
+
dep1.why_me(g, level)
|
| 966 |
+
dep2 = Dependency('cong', [m, n, p, q], None, level)
|
| 967 |
+
dep2.why_me(g, level)
|
| 968 |
+
self.why += [dep1, dep2]
|
| 969 |
+
elif ab._val and cd._val and mn._val and pq._val:
|
| 970 |
+
self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
|
| 971 |
+
|
| 972 |
+
elif self.name in ['diff', 'npara', 'nperp', 'ncoll', 'sameside']:
|
| 973 |
+
self.why = []
|
| 974 |
+
|
| 975 |
+
elif self.name == 'simtri':
|
| 976 |
+
a, b, c, x, y, z = self.args
|
| 977 |
+
dep1 = Dependency('eqangle', [a, b, a, c, x, y, x, z], '', level)
|
| 978 |
+
dep1.why_me(g, level)
|
| 979 |
+
dep2 = Dependency('eqangle', [b, a, b, c, y, x, y, z], '', level)
|
| 980 |
+
dep2.why_me(g, level)
|
| 981 |
+
self.rule_name = 'r34'
|
| 982 |
+
self.why = [dep1, dep2]
|
| 983 |
+
|
| 984 |
+
elif self.name == 'contri':
|
| 985 |
+
a, b, c, x, y, z = self.args
|
| 986 |
+
dep1 = Dependency('cong', [a, b, x, y], '', level)
|
| 987 |
+
dep1.why_me(g, level)
|
| 988 |
+
dep2 = Dependency('cong', [b, c, y, z], '', level)
|
| 989 |
+
dep2.why_me(g, level)
|
| 990 |
+
dep3 = Dependency('cong', [c, a, z, x], '', level)
|
| 991 |
+
dep3.why_me(g, level)
|
| 992 |
+
self.rule_name = 'r32'
|
| 993 |
+
self.why = [dep1, dep2, dep3]
|
| 994 |
+
|
| 995 |
+
elif self.name == 'ind':
|
| 996 |
+
pass
|
| 997 |
+
|
| 998 |
+
elif self.name == 'aconst':
|
| 999 |
+
a, b, c, d, ang0 = self.args
|
| 1000 |
+
|
| 1001 |
+
measure = ang0._val
|
| 1002 |
+
|
| 1003 |
+
for ang in measure.neighbors(gm.Angle):
|
| 1004 |
+
if ang == ang0:
|
| 1005 |
+
continue
|
| 1006 |
+
d1, d2 = ang._d
|
| 1007 |
+
l1, l2 = d1._obj, d2._obj
|
| 1008 |
+
(a1, b1), (c1, d1) = l1.points, l2.points
|
| 1009 |
+
|
| 1010 |
+
if not g.check_para_or_coll([a, b, a1, b1]) or not g.check_para_or_coll(
|
| 1011 |
+
[c, d, c1, d1]
|
| 1012 |
+
):
|
| 1013 |
+
continue
|
| 1014 |
+
|
| 1015 |
+
self.why = []
|
| 1016 |
+
for args in [(a, b, a1, b1), (c, d, c1, d1)]:
|
| 1017 |
+
if g.check_coll(args):
|
| 1018 |
+
if len(set(args)) > 2:
|
| 1019 |
+
dep = Dependency('coll', args, None, None)
|
| 1020 |
+
self.why.append(dep.why_me_or_cache(g, level))
|
| 1021 |
+
else:
|
| 1022 |
+
dep = Dependency('para', args, None, None)
|
| 1023 |
+
self.why.append(dep.why_me_or_cache(g, level))
|
| 1024 |
+
|
| 1025 |
+
self.why += gm.why_equal(ang, ang0)
|
| 1026 |
+
break
|
| 1027 |
+
|
| 1028 |
+
elif self.name == 'rconst':
|
| 1029 |
+
a, b, c, d, rat0 = self.args
|
| 1030 |
+
|
| 1031 |
+
val = rat0._val
|
| 1032 |
+
|
| 1033 |
+
for rat in val.neighbors(gm.Ratio):
|
| 1034 |
+
if rat == rat0:
|
| 1035 |
+
continue
|
| 1036 |
+
l1, l2 = rat._l
|
| 1037 |
+
s1, s2 = l1._obj, l2._obj
|
| 1038 |
+
(a1, b1), (c1, d1) = list(s1.points), list(s2.points)
|
| 1039 |
+
|
| 1040 |
+
if not g.check_cong([a, b, a1, b1]) or not g.check_cong([c, d, c1, d1]):
|
| 1041 |
+
continue
|
| 1042 |
+
|
| 1043 |
+
self.why = []
|
| 1044 |
+
for args in [(a, b, a1, b1), (c, d, c1, d1)]:
|
| 1045 |
+
if len(set(args)) > 2:
|
| 1046 |
+
dep = Dependency('cong', args, None, None)
|
| 1047 |
+
self.why.append(dep.why_me_or_cache(g, level))
|
| 1048 |
+
|
| 1049 |
+
self.why += gm.why_equal(rat, rat0)
|
| 1050 |
+
break
|
| 1051 |
+
|
| 1052 |
+
else:
|
| 1053 |
+
raise ValueError('Not recognize', self.name)
|
| 1054 |
+
|
| 1055 |
+
def hashed(self, rename: bool = False) -> tuple[str, ...]:
  """Return the canonical hash tuple for this object's name and args.

  Delegates to the module-level `hashed` function (a bare name inside a
  method body resolves to module scope, not to this method).

  Args:
    rename: passed through unchanged as the `rename` keyword.

  Returns:
    Whatever the module-level `hashed` returns for (self.name, self.args).
  """
  predicate, points = self.name, self.args
  return hashed(predicate, points, rename=rename)
|
| 1057 |
+
|
| 1058 |
+
|
| 1059 |
+
def hashed(
    name: str, args: list[gm.Point], rename: bool = False
) -> tuple[str, ...]:
  """Convert args to their string labels and hash them with `hashed_txt`.

  Args:
    name: predicate name.
    args: predicate arguments. Each is read through `.name` (or `.new_name`
      when `rename` is true). For 's_angle' the final argument is converted
      with `str()` instead.
    rename: when true, use each argument's `new_name` rather than `name`.

  Returns:
    The tuple produced by `hashed_txt(name, labels)`.
  """
  # Pick the attribute once; every point-like argument is read through it.
  label_attr = 'new_name' if rename else 'name'

  if name == 's_angle':
    # The trailing argument is not read via .name/.new_name; it is stringified.
    points, value = args[:-1], args[-1]
    labels = [getattr(point, label_attr) for point in points]
    labels.append(str(value))
  else:
    labels = [getattr(point, label_attr) for point in args]

  return hashed_txt(name, labels)
|
| 1069 |
+
|
| 1070 |
+
|
| 1071 |
+
def hashed_txt(name: str, args: list[str]) -> tuple[str, ...]:
  """Return a canonical tuple for predicate `name` applied to `args`.

  Arguments are reordered per predicate family so that argument orderings the
  code treats as interchangeable map to the same tuple:

  * const / aconst / rconst: each of the two leading pairs is sorted; the
    trailing value stays last.
  * npara / nperp / para / cong / perp / collx: each pair is sorted, then the
    two pairs are sorted.
  * midp / midpoint: the first argument stays first, the other two are sorted.
  * coll / cyclic / ncoll / diff / triangle: duplicates dropped, rest sorted.
  * circle: the first argument stays first, the other three are sorted.
  * eqangle / eqratio (and the '6' variants, which are renamed to the plain
    names): pairs are sorted, then the two swap symmetries are normalised.
  * contri / simtri families: corresponding vertex pairs are sorted, then the
    two triangles are sorted.
  * eqratio3: the four leading arguments are normalised; the last argument is
    emitted twice.
  * sameside / s_angle: arguments are kept as given.

  Args:
    name: predicate name.
    args: argument labels; the required count depends on `name`.

  Returns:
    A tuple starting with the (possibly normalised) predicate name.

  Raises:
    ValueError: if `name` is not one of the predicates above, or if `args` has
      the wrong length for a fixed-arity predicate (unpacking fails).
  """
  if name in ('const', 'aconst', 'rconst'):
    p, q, r, s, value = args
    return (name, *sorted([p, q]), *sorted([r, s]), value)

  if name in ('npara', 'nperp', 'para', 'cong', 'perp', 'collx'):
    p, q, r, s = args
    low, high = sorted([tuple(sorted([p, q])), tuple(sorted([r, s]))])
    return (name, *low, *high)

  if name in ('midp', 'midpoint'):
    mid, end1, end2 = args
    return (name, mid, *sorted([end1, end2]))

  if name in ('coll', 'cyclic', 'ncoll', 'diff', 'triangle'):
    return (name, *sorted(set(args)))

  if name == 'circle':
    center, p, q, r = args
    return (name, center, *sorted([p, q, r]))

  if name in ('eqangle', 'eqratio', 'eqangle6', 'eqratio6'):
    a, b, c, d, e, f, g, h = args
    ab, cd, ef, gh = (
        tuple(sorted(pair)) for pair in ((a, b), (c, d), (e, f), (g, h))
    )
    # Symmetry 1: exchange the 1st pair with the 2nd and the 3rd with the 4th.
    if sorted(ab + ef) > sorted(cd + gh):
      ab, cd, ef, gh = cd, ab, gh, ef
    # Symmetry 2: exchange the first half with the second half.
    if ab + cd > ef + gh:
      ab, cd, ef, gh = ef, gh, ab, cd

    # The '6' variants share the hash of their plain counterparts.
    plain = {'eqangle6': 'eqangle', 'eqratio6': 'eqratio'}.get(name, name)
    return (plain, *ab, *cd, *ef, *gh)

  if name in ('contri', 'simtri', 'simtri2', 'contri2', 'contri*', 'simtri*'):
    a, b, c, x, y, z = args
    # Order the vertex correspondences, then split back into two triangles.
    matched = sorted(zip((a, b, c), (x, y, z)), key=sorted)
    left, right = zip(*matched)
    first, second = sorted([left, right], key=sorted)
    return (name, *first, *second)

  if name == 'eqratio3':
    # Only the last of the two trailing arguments is kept (emitted twice).
    a, b, c, d, _, o = args
    col1, col2 = sorted([(a, c), (b, d)], key=sorted)
    row1, row2 = sorted(zip(col1, col2), key=sorted)
    return (name, *row1, *row2, o, o)

  if name in ('sameside', 's_angle'):
    return (name, *args)

  raise ValueError(f'Not recognize {name} to hash.')
|
external/alphageometry/problem_test.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2023 DeepMind Technologies Limited
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# ==============================================================================
|
| 15 |
+
|
| 16 |
+
"""Unit tests for problem.py."""
|
| 17 |
+
import unittest
|
| 18 |
+
|
| 19 |
+
from absl.testing import absltest
|
| 20 |
+
import problem as pr
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ProblemTest(unittest.TestCase):
  """Checks the setup string built from an orthocenter problem statement."""

  # Problem text shared by both tests below.
  _ORTHOCENTER_TXT = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c'  # pylint: disable=line-too-long

  @classmethod
  def setUpClass(cls):
    super().setUpClass()
    # Load the definitions once; every test reads them from the class.
    cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)

  def _setup_str(self, translate):
    """Parse the shared problem text and return its setup string.

    The setup string is what gets fed into the LM, translating from
    constructive to constrained.
    """
    problem = pr.Problem.from_txt(self._ORTHOCENTER_TXT, translate=translate)
    return problem.setup_str_from_problem(ProblemTest.defs)

  def test_orthocenter_no_translate(self):
    # translate=False: point names are left unchanged.
    expected = '{S} a : ; b : ; c : ; h : T a b c h 00 T a c b h 01 ? T a h b c'
    self.assertEqual(self._setup_str(translate=False), expected)

  def test_orthocenter_translate(self):
    # translate=True: h is renamed to d to match training data distribution.
    expected = '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
    self.assertEqual(self._setup_str(translate=True), expected)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
# Run the tests through absltest when this file is executed as a script.
if __name__ == '__main__':
  absltest.main()
|
external/alphageometry/requirements.in
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow==2.13.0
|
| 2 |
+
numpy==1.23.5
|
| 3 |
+
scipy==1.10.0
|
| 4 |
+
matplotlib==3.7.0
|
| 5 |
+
gdown==4.7.1
|
| 6 |
+
jax==0.4.6
|
| 7 |
+
jaxlib==0.4.6
|
| 8 |
+
flax==0.5.3
|
| 9 |
+
gin-config==0.5.0
|
| 10 |
+
gin==0.1.6
|
| 11 |
+
t5==0.9.4
|
| 12 |
+
sentencepiece==0.1.99
|
| 13 |
+
absl-py==1.4.0
|
| 14 |
+
clu==0.0.7
|
| 15 |
+
optax==0.1.7
|
| 16 |
+
seqio==0.0.18
|
| 17 |
+
tensorflow-datasets==4.9.3
|
external/alphageometry/requirements.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|