ZAIDX11 commited on
Commit
6d92a68
·
verified ·
1 Parent(s): 85a5bf3

Add files using upload-large-folder tool

Browse files
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/__init__.py ADDED
@@ -0,0 +1,687 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from collections.abc import Callable
2
+ # Copyright 2024 The JAX Authors. All Rights Reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ # ==============================================================================
16
+
17
+ import contextlib
18
+ import ctypes
19
+ import dataclasses
20
+ import functools
21
+ import itertools
22
+ import os
23
+ import pathlib
24
+ import subprocess
25
+ import tempfile
26
+ import time
27
+ from typing import Any, Generic, Sequence, TypeVar
28
+
29
+ import jax
30
+ from jax._src import config
31
+ from jax._src import core as jax_core
32
+ from jax._src.interpreters import mlir
33
+ from jax._src.lib import xla_client
34
+ from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
35
+ from jaxlib.mlir import ir
36
+ from jaxlib.mlir.dialects import arith
37
+ from jaxlib.mlir.dialects import builtin
38
+ from jaxlib.mlir.dialects import func
39
+ from jaxlib.mlir.dialects import gpu
40
+ from jaxlib.mlir.dialects import llvm
41
+ from jaxlib.mlir.dialects import memref
42
+ from jaxlib.mlir.dialects import nvgpu
43
+ from jaxlib.mlir.dialects import nvvm
44
+ from jaxlib.mlir.passmanager import PassManager
45
+ import numpy as np
46
+
47
+ from . import dsl as mgpu
48
+ from . import profiler
49
+ from . import utils
50
+
51
+ # mypy: ignore-errors
52
+
53
+ # MLIR can't find libdevice unless we point it to the CUDA path
54
+ # TODO(apaszke): Unify with jax._src.lib.cuda_path
55
+ CUDA_ROOT = "/usr/local/cuda"
56
+ if os.environ.get("CUDA_ROOT") is None:
57
+ os.environ["CUDA_ROOT"] = CUDA_ROOT
58
+ else:
59
+ CUDA_ROOT = os.environ["CUDA_ROOT"]
60
+
61
+ PTXAS_PATH = os.path.join(CUDA_ROOT, "bin/ptxas")
62
+ NVDISASM_PATH = os.path.join(CUDA_ROOT, "bin/nvdisasm")
63
+
64
+ TMA_DESCRIPTOR_BYTES = 128
65
+ TMA_DESCRIPTOR_ALIGNMENT = 64
66
+
67
+
68
+ c = mgpu.c # This is too common to fully qualify.
69
+
70
+
71
+ RUNTIME_PATH = pathlib.Path(mosaic_gpu_lib._mosaic_gpu_ext.__file__).parent / "libmosaic_gpu_runtime.so"
72
+ if RUNTIME_PATH.exists():
73
+ # Set this so that the custom call can find it
74
+ os.environ["MOSAIC_GPU_RUNTIME_LIB_PATH"] = str(RUNTIME_PATH)
75
+
76
+
77
+ mosaic_gpu_p = jax.core.Primitive("mosaic_gpu_p")
78
+ mosaic_gpu_p.multiple_results = True
79
+
80
+
81
+ @mosaic_gpu_p.def_abstract_eval
82
+ def _mosaic_gpu_abstract_eval(*_, module, out_types, gmem_scratch_bytes):
83
+ del module, gmem_scratch_bytes # Unused.
84
+ return [jax._src.core.ShapedArray(t.shape, t.dtype) for t in out_types]
85
+
86
+ # TODO(apaszke): Implement a proper system for managing kernel lifetimes
87
+ kernel_idx = itertools.count()
88
+
89
+ def _mosaic_gpu_lowering_rule(ctx, *args, module, out_types, gmem_scratch_bytes):
90
+ del out_types # Unused.
91
+ idx_bytes = next(kernel_idx).to_bytes(8, byteorder="little")
92
+ op = mlir.custom_call(
93
+ "mosaic_gpu",
94
+ result_types=[
95
+ *(mlir.aval_to_ir_type(aval) for aval in ctx.avals_out),
96
+ mlir.aval_to_ir_type(
97
+ jax_core.ShapedArray((gmem_scratch_bytes,), np.uint8)
98
+ ),
99
+ ],
100
+ operands=args,
101
+ backend_config=idx_bytes
102
+ + module.operation.get_asm(binary=True, enable_debug_info=True),
103
+ )
104
+ return op.results[:-1] # Skip the scratch space.
105
+
106
+ mlir.register_lowering(mosaic_gpu_p, _mosaic_gpu_lowering_rule, "cuda")
107
+
108
+
109
+ @dataclasses.dataclass(frozen=True)
110
+ class MemRefTransform:
111
+ def apply(self, ref: ir.Value) -> ir.Value:
112
+ raise NotImplementedError("Subclasses should override this method")
113
+
114
+ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
115
+ raise NotImplementedError("Subclasses should override this method")
116
+
117
+ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
118
+ raise NotImplementedError("Subclasses should override this method")
119
+
120
+
121
+ @dataclasses.dataclass(frozen=True)
122
+ class TileTransform(MemRefTransform):
123
+ """Tiles a suffix of memref dimensions.
124
+
125
+ For example, given a memref of shape (5, 128, 128) and a tiling of (64, 32),
126
+ the shape of the result will be (5, 2, 4, 64, 32). The shape always ends with
127
+ the tile shape, and the size of tiled dimensions is divided by the tile size.
128
+ This is especially useful for swizzled WGMMA, which expect tiled layouts in
129
+ shared memory.
130
+ """
131
+ tiling: tuple[int, ...]
132
+
133
+ def apply(self, ref: ir.Value) -> ir.Value:
134
+ untiled_rank = ir.MemRefType(ref.type).rank
135
+ tiling_rank = len(self.tiling)
136
+ tiled_rank = untiled_rank + tiling_rank
137
+ for t, d in zip(self.tiling[::-1], range(untiled_rank)[::-1]):
138
+ ref = mgpu.memref_unfold(ref, d, (None, t))
139
+ permutation = (
140
+ *range(untiled_rank - tiling_rank),
141
+ *range(untiled_rank - tiling_rank, tiled_rank, 2),
142
+ *range(untiled_rank - tiling_rank + 1, tiled_rank, 2),
143
+ )
144
+ return mgpu.memref_transpose(ref, permutation)
145
+
146
+ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
147
+ index = ir.IndexType.get()
148
+ tiling_rank = len(self.tiling)
149
+ return (
150
+ *idx[:-tiling_rank],
151
+ *(
152
+ arith.divui(i, c(t, index))
153
+ for i, t in zip(idx[-tiling_rank:], self.tiling)
154
+ ),
155
+ *(
156
+ arith.remui(i, c(t, index))
157
+ for i, t in zip(idx[-tiling_rank:], self.tiling)
158
+ ),
159
+ )
160
+
161
+ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
162
+ # Note that this also checks that tiled dims are not squeezed. Their slice
163
+ # size would be 1 if so.
164
+ tiling_rank = len(self.tiling)
165
+ for size, tile_size in zip(shape[-tiling_rank:], self.tiling):
166
+ if size % tile_size:
167
+ raise ValueError(
168
+ f"Expected GMEM slice shape {shape} suffix to be a multiple"
169
+ f" of tiling {self.tiling}"
170
+ )
171
+ return (
172
+ *shape[:-tiling_rank],
173
+ *(s // t for s, t in zip(shape[-tiling_rank:], self.tiling)),
174
+ *self.tiling,
175
+ )
176
+
177
+
178
+ @dataclasses.dataclass(frozen=True)
179
+ class TransposeTransform(MemRefTransform):
180
+ """Transposes memref dimensions."""
181
+ permutation: tuple[int, ...]
182
+
183
+ def __post_init__(self):
184
+ if len(self.permutation) != len(set(self.permutation)):
185
+ raise ValueError("Permutation must be a permutation")
186
+
187
+ def apply(self, ref: ir.Value) -> ir.Value:
188
+ return mgpu.memref_transpose(ref, self.permutation)
189
+
190
+ def transform_index(self, idx: Sequence[ir.Value]) -> tuple[ir.Value, ...]:
191
+ return tuple(idx[p] for p in self.permutation)
192
+
193
+ def transform_shape(self, shape: Sequence[int]) -> tuple[int, ...]:
194
+ return tuple(shape[p] for p in self.permutation)
195
+
196
+
197
+ OnDeviceProfiler = profiler.OnDeviceProfiler
198
+
199
+
200
+ @dataclasses.dataclass()
201
+ class LaunchContext:
202
+ launch_op: gpu.LaunchOp
203
+ gmem_scratch_ptr: ir.Value
204
+ profiler: OnDeviceProfiler | None = None
205
+ next_scratch_offset: int = 0
206
+ host_scratch_init: list[Callable[[ir.Value], None]] = dataclasses.field(
207
+ default_factory=list, init=False
208
+ )
209
+ tma_descriptors: dict[
210
+ tuple[ir.Value, tuple[int, ...], int | None, tuple[MemRefTransform, ...]],
211
+ ir.Value,
212
+ ] = dataclasses.field(default_factory=dict, init=False)
213
+
214
+ @contextlib.contextmanager
215
+ def named_region(self, *args, **kwargs):
216
+ if self.profiler is not None:
217
+ with self.profiler.record(*args, **kwargs):
218
+ yield
219
+ else:
220
+ yield
221
+
222
+ def _alloc_scratch(
223
+ self,
224
+ size: int,
225
+ alignment: int | None = None,
226
+ host_init: Callable[[ir.Value], None] = lambda _: None,
227
+ device_init: Callable[[ir.Value], Any] = lambda x: x,
228
+ ) -> ir.Value:
229
+ """Allocates a GMEM scratch buffer.
230
+
231
+ The buffer is initialized on the host and then copied to GMEM before the
232
+ kernel launch.
233
+ """
234
+ i8 = ir.IntegerType.get_signless(8)
235
+ ptr_ty = ir.Type.parse("!llvm.ptr")
236
+ if alignment is None:
237
+ alignment = size
238
+ if self.next_scratch_offset % alignment:
239
+ raise NotImplementedError # TODO(apaszke): Pad to match alignment
240
+ alloc_base = self.next_scratch_offset
241
+ self.next_scratch_offset += size
242
+ def host_init_wrapped(host_ptr):
243
+ with ir.InsertionPoint(self.launch_op):
244
+ host_init(
245
+ llvm.getelementptr(ptr_ty, host_ptr, [], [alloc_base], i8)
246
+ )
247
+ self.host_scratch_init.append(host_init_wrapped)
248
+ with ir.InsertionPoint.at_block_begin(self.launch_op.body.blocks[0]):
249
+ return device_init(llvm.getelementptr(
250
+ ptr_ty, self.gmem_scratch_ptr, [], [alloc_base], i8
251
+ ))
252
+
253
+ def _get_tma_desc(
254
+ self,
255
+ ref,
256
+ gmem_transform: tuple[MemRefTransform, ...],
257
+ transformed_slice_shape: tuple[int, ...],
258
+ swizzle: int | None,
259
+ ):
260
+ index = ir.IndexType.get()
261
+ ref_ty = ir.MemRefType(ref.type)
262
+ tma_desc_key = (ref, transformed_slice_shape, swizzle, gmem_transform)
263
+ if (tma_desc := self.tma_descriptors.get(tma_desc_key, None)) is None:
264
+ swizzle_str = f"swizzle_{swizzle}b" if swizzle is not None else "none"
265
+ default_tensor_map_attrs = dict(
266
+ swizzle=swizzle_str, l2promo="none", oob="zero", interleave="none"
267
+ )
268
+ tensor_map_ty = utils.get_tensormap_descriptor(
269
+ tensor=(
270
+ f"memref<{'x'.join(map(str, transformed_slice_shape))}x{ref_ty.element_type}, 3>"
271
+ ),
272
+ **default_tensor_map_attrs,
273
+ )
274
+ with ir.InsertionPoint(self.launch_op):
275
+ for t in gmem_transform:
276
+ ref = t.apply(ref)
277
+ ref_ty = ir.MemRefType(ref.type)
278
+
279
+ i64 = ir.IntegerType.get_signless(64)
280
+ ptr_ty = ir.Type.parse("!llvm.ptr")
281
+ def init_tma_desc(host_ptr):
282
+ _, offset, *sizes_and_strides = memref.extract_strided_metadata(ref)
283
+ aligned_ptr_idx = memref.extract_aligned_pointer_as_index(ref)
284
+ as_i64 = lambda i: arith.index_cast(i64, i)
285
+ alloc_ptr = llvm.inttoptr(ptr_ty, as_i64(aligned_ptr_idx))
286
+ llvm_dyn = -2147483648 # TODO(apaszke): Improve the MLIR bindings...
287
+ base_ptr = llvm.getelementptr(
288
+ ptr_ty, alloc_ptr, [as_i64(offset)], [llvm_dyn], ref_ty.element_type,
289
+ )
290
+ rank = ref_ty.rank
291
+ assert rank * 2 == len(sizes_and_strides)
292
+ args = [
293
+ host_ptr,
294
+ base_ptr,
295
+ c(utils.bytewidth(ref_ty.element_type), i64),
296
+ c(rank, i64),
297
+ utils.pack_array([as_i64(i) for i in sizes_and_strides[:rank]]),
298
+ utils.pack_array([as_i64(i) for i in sizes_and_strides[rank:]]),
299
+ c(0 if swizzle is None else swizzle, i64),
300
+ utils.pack_array([c(v, i64) for v in transformed_slice_shape]),
301
+ ]
302
+ func.call([], "mosaic_gpu_init_tma_desc", args)
303
+ def cast_tma_desc(device_ptr):
304
+ # TODO(apaszke): Investigate why prefetching can cause launch failures
305
+ # nvvm.prefetch_tensormap(device_ptr)
306
+ return builtin.unrealized_conversion_cast(
307
+ [tensor_map_ty], [device_ptr]
308
+ )
309
+ tma_desc = self._alloc_scratch(
310
+ TMA_DESCRIPTOR_BYTES,
311
+ alignment=TMA_DESCRIPTOR_ALIGNMENT,
312
+ host_init=init_tma_desc,
313
+ device_init=cast_tma_desc,
314
+ )
315
+ self.tma_descriptors[tma_desc_key] = tma_desc
316
+ return tma_desc
317
+
318
+ def async_copy(
319
+ self,
320
+ *,
321
+ src_ref,
322
+ dst_ref,
323
+ gmem_slice: Any = (),
324
+ gmem_transform: MemRefTransform | tuple[MemRefTransform, ...] = (),
325
+ barrier: mgpu.Barrier | None = None,
326
+ swizzle: int | None = None,
327
+ arrive: bool | None = None,
328
+ uniform: bool = True,
329
+ ):
330
+ index = ir.IndexType.get()
331
+ smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
332
+ src_ref_ty = ir.MemRefType(src_ref.type)
333
+ dst_ref_ty = ir.MemRefType(dst_ref.type)
334
+ element_type = src_ref_ty.element_type
335
+ if element_type != dst_ref_ty.element_type:
336
+ raise ValueError(
337
+ f"Expected same element type, got {element_type} and"
338
+ f" {dst_ref_ty.element_type}"
339
+ )
340
+ if not isinstance(gmem_transform, tuple):
341
+ gmem_transform = (gmem_transform,)
342
+
343
+ if src_ref_ty.memory_space is None and dst_ref_ty.memory_space == smem:
344
+ gmem_ref, smem_ref = src_ref, dst_ref
345
+ if barrier is None:
346
+ raise ValueError("Barriers are required for GMEM -> SMEM copies")
347
+ if arrive is None:
348
+ arrive = True # Arrive by default
349
+ elif src_ref_ty.memory_space == smem and dst_ref_ty.memory_space is None:
350
+ gmem_ref, smem_ref = dst_ref, src_ref
351
+ if barrier is not None:
352
+ raise ValueError("Barriers are unsupported for SMEM -> GMEM copies")
353
+ if arrive is not None:
354
+ raise ValueError("arrive is unsupported for SMEM -> GMEM copies")
355
+ else:
356
+ raise ValueError("Only SMEM <-> GMEM copies supported")
357
+ # TODO(apaszke): This is a very approximate check. Improve it!
358
+ expected_name = "builtin.unrealized_conversion_cast"
359
+ if (
360
+ gmem_ref.owner is None
361
+ or gmem_ref.owner.opview.OPERATION_NAME != expected_name
362
+ ):
363
+ raise ValueError("GMEM reference in async_copy must be a kernel argument")
364
+
365
+ base_indices, slice_shape, is_squeezed = utils.parse_indices(
366
+ gmem_slice, ir.MemRefType(gmem_ref.type).shape
367
+ )
368
+ dyn_base_indices = tuple(
369
+ c(i, index) if not isinstance(i, ir.Value) else i for i in base_indices
370
+ )
371
+ slice_shape = tuple(slice_shape)
372
+ for t in gmem_transform:
373
+ dyn_base_indices = t.transform_index(dyn_base_indices)
374
+ slice_shape = t.transform_shape(slice_shape)
375
+ for dim, squeezed in enumerate(is_squeezed):
376
+ if squeezed:
377
+ smem_ref = mgpu.memref_unsqueeze(smem_ref, dim)
378
+ smem_ref_ty = ir.MemRefType(smem_ref.type)
379
+
380
+ if slice_shape != tuple(smem_ref_ty.shape):
381
+ raise ValueError(
382
+ "Expected the SMEM reference to have the same shape as the tiled"
383
+ f" slice: {tuple(smem_ref_ty.shape)} != {slice_shape}"
384
+ )
385
+ tma_desc = self._get_tma_desc(
386
+ gmem_ref, gmem_transform, slice_shape, swizzle,
387
+ )
388
+
389
+ # nvgpu TMA instructions expect reversed indices...
390
+ rev_dyn_based_indices = reversed(dyn_base_indices)
391
+
392
+ uniform_ctx = mgpu.single_thread if uniform else contextlib.nullcontext
393
+
394
+ if gmem_ref is src_ref:
395
+ with uniform_ctx():
396
+ assert barrier is not None # for pytype
397
+ barrier_group = barrier.barrier_array.value
398
+ barrier_idx = barrier.offset
399
+ if arrive:
400
+ slice_bytes = c(
401
+ np.prod(slice_shape) * mgpu.bytewidth(element_type), index
402
+ )
403
+ nvgpu.mbarrier_arrive_expect_tx(
404
+ barrier_group, slice_bytes, barrier_idx
405
+ )
406
+ nvgpu.tma_async_load(
407
+ smem_ref, barrier_group, tma_desc, rev_dyn_based_indices, barrier_idx
408
+ )
409
+ else:
410
+ with uniform_ctx():
411
+ nvgpu.tma_async_store(smem_ref, tma_desc, rev_dyn_based_indices)
412
+ nvvm.cp_async_bulk_commit_group()
413
+
414
+ def await_async_copy(
415
+ self, allow_groups: int, await_read_only: bool = False
416
+ ):
417
+ nvvm.cp_async_bulk_wait_group(allow_groups, read=await_read_only)
418
+ gpu.barrier() # Groups are supposedly tracked per-thread
419
+
420
+
421
+ # ShapeTrees currently can not contain unions.
422
+ ShapeTree = Any
423
+ RefTree = Any
424
+ T = TypeVar('T')
425
+
426
+
427
+ @dataclasses.dataclass(frozen=True)
428
+ class Union(Generic[T]):
429
+ members: Sequence[T]
430
+
431
+
432
+ def _count_buffer_bytes(shape_dtype: jax.ShapeDtypeStruct) -> int:
433
+ return np.prod(shape_dtype.shape) * np.dtype(shape_dtype.dtype).itemsize
434
+
435
+
436
+ def _construct_smem_reftree(
437
+ dynamic_smem: ir.Value, smem_buffers: ShapeTree) -> RefTree:
438
+ index = ir.IndexType.get()
439
+ smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
440
+ flat_ref_tys, smem_buffer_tree = jax.tree.flatten(smem_buffers)
441
+ smem_refs = []
442
+ dynamic_smem_offset = 0
443
+ for ref_ty in flat_ref_tys:
444
+ mlir_dtype = mlir.dtype_to_ir_type(ref_ty.dtype)
445
+ tile_smem = memref.view(
446
+ ir.MemRefType.get(ref_ty.shape, mlir_dtype, memory_space=smem),
447
+ dynamic_smem, c(dynamic_smem_offset, index), [],
448
+ )
449
+ dynamic_smem_offset += _count_buffer_bytes(ref_ty)
450
+ smem_refs.append(tile_smem)
451
+ return jax.tree.unflatten(smem_buffer_tree, smem_refs)
452
+
453
+
454
+ # TODO(apaszke): Inline this
455
+ @contextlib.contextmanager
456
+ def _launch(
457
+ token,
458
+ grid,
459
+ block,
460
+ gmem_scratch_ptr,
461
+ smem_buffers: ShapeTree | Union[ShapeTree],
462
+ profiler_spec: profiler.ProfilerSpec | None = None,
463
+ maybe_prof_buffer: ir.Value | None = None,
464
+ ):
465
+ if (profiler_spec is None) != (maybe_prof_buffer is None):
466
+ raise ValueError
467
+ index = ir.IndexType.get()
468
+ i32 = ir.IntegerType.get_signless(32)
469
+ i8 = ir.IntegerType.get_signless(8)
470
+ grid_vals = [c(i, index) for i in grid]
471
+ block_vals = [c(i, index) for i in block]
472
+
473
+ if isinstance(smem_buffers, Union):
474
+ smem_disjoint_live_buffers_collections = smem_buffers.members
475
+ compute_smem_bytes = max(
476
+ sum(_count_buffer_bytes(l) for l in jax.tree.leaves(s))
477
+ for s in smem_buffers.members)
478
+ else:
479
+ smem_disjoint_live_buffers_collections = [smem_buffers]
480
+ compute_smem_bytes = sum(
481
+ _count_buffer_bytes(l) for l in jax.tree.leaves(smem_buffers))
482
+
483
+ smem_bytes = compute_smem_bytes
484
+ if profiler_spec is not None:
485
+ smem_bytes += profiler_spec.smem_bytes(block=block)
486
+
487
+ # TODO(cperivol): Query the shared memory size programmatically.
488
+ if smem_bytes > 228 * 1024:
489
+ raise ValueError(f"Mosaic GPU kernel exceeds available shared memory {smem_bytes=} > 228000")
490
+ launch_op = gpu.LaunchOp(
491
+ token.type, [token], *grid_vals, *block_vals,
492
+ dynamicSharedMemorySize=c(smem_bytes, i32))
493
+ launch_op.body.blocks.append(*([index] * 12)) # Append an empty block
494
+ smem = ir.Attribute.parse("#gpu.address_space<workgroup>")
495
+ with ir.InsertionPoint(launch_op.body.blocks[0]):
496
+ dynamic_smem = gpu.dynamic_shared_memory(
497
+ ir.MemRefType.get(
498
+ (ir.ShapedType.get_dynamic_size(),), i8, memory_space=smem
499
+ )
500
+ )
501
+ smem_ref_trees = []
502
+
503
+ for smem_live_buffers_collection in smem_disjoint_live_buffers_collections:
504
+ smem_ref_tree = _construct_smem_reftree(
505
+ dynamic_smem, smem_live_buffers_collection)
506
+ smem_ref_trees.append(smem_ref_tree)
507
+
508
+ if profiler_spec:
509
+ prof_smem = memref.view(
510
+ ir.MemRefType.get(
511
+ (profiler_spec.smem_i32_elements(block=block),),
512
+ i32, memory_space=smem,
513
+ ),
514
+ dynamic_smem, c(compute_smem_bytes, index), [],
515
+ )
516
+ prof = profiler.OnDeviceProfiler(
517
+ profiler_spec, prof_smem, maybe_prof_buffer
518
+ )
519
+ else:
520
+ prof = None
521
+
522
+ if isinstance(smem_buffers, Union):
523
+ smem_ref_tree: Union[RefTree] = Union(smem_ref_trees)
524
+ else:
525
+ smem_ref_tree: RefTree = smem_ref_trees[0] if smem_ref_trees else []
526
+
527
+ yield LaunchContext(launch_op, gmem_scratch_ptr, prof), smem_ref_tree
528
+ if prof is not None:
529
+ prof.finalize(grid=grid, block=block)
530
+ gpu.terminator()
531
+
532
+
533
+ def _lower_as_gpu_kernel(
534
+ body,
535
+ grid: tuple[int, ...],
536
+ block: tuple[int, ...],
537
+ in_shapes: tuple[Any, ...],
538
+ out_shape,
539
+ smem_scratch_shape: ShapeTree | Union[ShapeTree],
540
+ prof_spec: profiler.ProfilerSpec | None = None,
541
+ ):
542
+ ptr_ty = ir.Type.parse("!llvm.ptr")
543
+ token_ty = ir.Type.parse("!gpu.async.token")
544
+ i8 = ir.IntegerType.get_signless(8)
545
+ i64 = ir.IntegerType.get_signless(64)
546
+
547
+ def _shape_to_ref_ty(shape: jax.ShapeDtypeStruct) -> ir.MemRefType:
548
+ return ir.MemRefType.get(shape.shape, mlir.dtype_to_ir_type(shape.dtype))
549
+
550
+ in_ref_tys = [_shape_to_ref_ty(t) for t in in_shapes]
551
+
552
+ unwrap_output_tuple = False
553
+ if isinstance(out_shape, list):
554
+ out_shape = tuple(out_shape)
555
+ elif not isinstance(out_shape, tuple):
556
+ out_shape = (out_shape,)
557
+ unwrap_output_tuple = True
558
+ out_ref_tys = [_shape_to_ref_ty(t) for t in out_shape]
559
+ if prof_spec is not None:
560
+ out_shape = (*out_shape, prof_spec.jax_buffer_type(grid, block))
561
+ out_ref_tys.append(prof_spec.mlir_buffer_type(grid, block))
562
+
563
+ module = ir.Module.create()
564
+ with ir.InsertionPoint(module.body):
565
+ _declare_runtime_functions()
566
+ gmem_scratch_bytes = 0
567
+ @func.FuncOp.from_py_func(ptr_ty, ptr_ty)
568
+ def main(token_ptr, buffers):
569
+ nonlocal gmem_scratch_bytes
570
+ token = builtin.unrealized_conversion_cast([token_ty], [token_ptr])
571
+ arg_refs = []
572
+ i = -1
573
+ for i, ref_ty in enumerate([*in_ref_tys, *out_ref_tys]):
574
+ ptr = llvm.LoadOp(ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i], ptr_ty))
575
+ arg_refs.append(utils.ptr_as_memref(ptr, ir.MemRefType(ref_ty)))
576
+ gmem_scratch_ptr = llvm.LoadOp(
577
+ ptr_ty, llvm.GEPOp(ptr_ty, buffers, [], [i + 1], ptr_ty)
578
+ )
579
+ in_refs = arg_refs[:len(in_ref_tys)]
580
+ out_refs = arg_refs[len(in_ref_tys):]
581
+ prof_buffer = out_refs.pop() if prof_spec is not None else None
582
+ with _launch(
583
+ token, grid, block, gmem_scratch_ptr, smem_scratch_shape,
584
+ prof_spec, prof_buffer
585
+ ) as (launch_ctx, smem_refs):
586
+ body(launch_ctx, *in_refs, *out_refs, smem_refs)
587
+ gmem_scratch_bytes = launch_ctx.next_scratch_offset
588
+ # Allocate and initialize the host buffer right before the launch.
589
+ # Note that we couldn't do that before, because we had to run the body
590
+ # to learn what the scratch contains.
591
+ with ir.InsertionPoint(launch_ctx.launch_op):
592
+ host_scratch_ptr = llvm.alloca(ptr_ty, c(gmem_scratch_bytes, i64), i8)
593
+ for init_callback in launch_ctx.host_scratch_init:
594
+ init_callback(host_scratch_ptr)
595
+ func.call(
596
+ [],
597
+ "mosaic_gpu_memcpy_async_h2d",
598
+ [
599
+ gmem_scratch_ptr,
600
+ host_scratch_ptr,
601
+ c(gmem_scratch_bytes, i64),
602
+ token_ptr,
603
+ ],
604
+ )
605
+ main.func_op.attributes["llvm.emit_c_interface"] = ir.UnitAttr.get()
606
+ module.operation.verify()
607
+
608
+ return module, out_shape, gmem_scratch_bytes, unwrap_output_tuple
609
+
610
+
611
+ def as_gpu_kernel(
612
+ body,
613
+ grid: tuple[int, ...],
614
+ block: tuple[int, ...],
615
+ in_shape,
616
+ out_shape,
617
+ smem_scratch_shape: ShapeTree | Union[ShapeTree],
618
+ prof_spec: profiler.ProfilerSpec | None = None,
619
+ ):
620
+ if isinstance(in_shape, list):
621
+ in_shape = tuple(in_shape)
622
+ elif not isinstance(in_shape, tuple):
623
+ in_shape = (in_shape,)
624
+
625
+ module, out_shape, gmem_scratch_bytes, unwrap_output_tuple = (
626
+ _lower_as_gpu_kernel(
627
+ body, grid, block, in_shape, out_shape, smem_scratch_shape, prof_spec
628
+ )
629
+ )
630
+
631
+ expected_arg_treedef = jax.tree.structure(in_shape)
632
+ def _check_args(*args):
633
+ arg_treedef = jax.tree.structure(args)
634
+ if arg_treedef != expected_arg_treedef:
635
+ raise ValueError(
636
+ f"Invalid argument structure: expected {expected_arg_treedef}, got"
637
+ f" {arg_treedef}, ({args=})"
638
+ )
639
+
640
+ def bind(*args):
641
+ return mosaic_gpu_p.bind(
642
+ *args,
643
+ out_types=out_shape,
644
+ module=module,
645
+ gmem_scratch_bytes=gmem_scratch_bytes,
646
+ )
647
+
648
+ if prof_spec is not None:
649
+ @jax.jit
650
+ def prof_kernel(*args):
651
+ _check_args(*args)
652
+ *results, prof_buffer = bind(*args)
653
+ def dump_profile(prof_buffer):
654
+ out_file = os.path.join(
655
+ os.getenv("TEST_UNDECLARED_OUTPUTS_DIR"),
656
+ f"{time.time_ns()}-trace.json",
657
+ )
658
+ try:
659
+ with open(out_file, "x") as f:
660
+ prof_spec.dump(prof_buffer, f, grid=grid, block=block)
661
+ except FileExistsError:
662
+ pass # TODO: Retry
663
+ jax.debug.callback(dump_profile, prof_buffer)
664
+ return results[0] if unwrap_output_tuple else results
665
+ return prof_kernel
666
+ else:
667
+ @jax.jit
668
+ def kernel(*args):
669
+ _check_args(*args)
670
+ results = bind(*args)
671
+ return results[0] if unwrap_output_tuple else results
672
+ return kernel
673
+
674
+
675
+ def _declare_runtime_functions():
676
+ """Declares the runtime functions that can be used by the generated code."""
677
+ ptr_ty = ir.Type.parse("!llvm.ptr")
678
+ i64 = ir.IntegerType.get_signless(64)
679
+ arg_tys = [ptr_ty, ptr_ty, i64, i64, ptr_ty, ptr_ty, i64, ptr_ty]
680
+ init_tma_desc_type = ir.FunctionType.get(arg_tys, [])
681
+ func.FuncOp(
682
+ "mosaic_gpu_init_tma_desc", init_tma_desc_type, visibility="private"
683
+ )
684
+ memcpy_async_type = ir.FunctionType.get([ptr_ty, ptr_ty, i64, ptr_ty], [])
685
+ func.FuncOp(
686
+ "mosaic_gpu_memcpy_async_h2d", memcpy_async_type, visibility="private"
687
+ )
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/fragmented_array.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The JAX Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utilities for code generator."""
16
+
17
+ import dataclasses
18
+
19
+ import jax
20
+ from jaxlib.mlir import ir
21
+ from jaxlib.mlir.dialects import arith
22
+ from jaxlib.mlir.dialects import gpu
23
+ from jaxlib.mlir.dialects import llvm
24
+ from jaxlib.mlir.dialects import math as mlir_math
25
+ from jaxlib.mlir.dialects import memref
26
+ from jaxlib.mlir.dialects import nvvm
27
+ from jaxlib.mlir.dialects import vector
28
+ import numpy as np
29
+
30
+ from . import dsl as mgpu
31
+ from . import utils
32
+
33
+ # mypy: ignore-errors
34
+
35
+ WARPGROUP_SIZE = utils.WARPGROUP_SIZE
36
+ c = utils.c
37
+
38
+
39
+ @dataclasses.dataclass(frozen=True)
40
+ class WGSplatFragLayout:
41
+ """A fragmented array where all the values are equal represented as a register per thread.
42
+
43
+ FragmentedArrays in this layout can be are always the result of a
44
+ splat, each thread in the warpgroup has a single copy of the value,
45
+ while the FragmentedArray pretends it has whatever shape the user
46
+ wants. This means we can trivially broadcast, reshape and do
47
+ elementwise operations with all other layouts.
48
+
49
+ Example:
50
+
51
+ To load a value in
52
+ ```
53
+ FragmentedArray.splat(memref.load(ref_1d, [1]), (10,20,2))
54
+ ```
55
+
56
+ A shape is always provided for sanity check reasons.
57
+
58
+ """
59
+
60
+ shape: tuple[int, ...] = ()
61
+
62
+ def can_broadcast_to(self, shape) -> bool:
63
+ """Check that the shape can be broadcast.
64
+
65
+ Only dimensions of size 1 can be broadcast. All other dimensions
66
+ must be the same as the argument shape.
67
+ """
68
+ return all(dim1 == dim2 or dim1 == 1 for dim1, dim2 in zip(self.shape[::-1], shape[::-1]))
69
+
70
+
71
+ @dataclasses.dataclass(frozen=True)
72
+ class WGMMAFragLayout:
73
+ """[m, n] matrix, where m % 64 == 0 == n % 8."""
74
+
75
+
76
+ @dataclasses.dataclass(frozen=True)
77
+ class WGMMARowFragLayout:
78
+ """[m] matrix, where m % 64 == 0."""
79
+
80
+
81
+ @dataclasses.dataclass(frozen=True)
82
+ class WGStridedFragLayout:
83
+ """Convert the array to 1D and then shard across threads."""
84
+
85
+ shape: tuple[int, ...]
86
+ vec_size: int
87
+
88
+ def __post_init__(self):
89
+ if np.prod(self.shape) % (self.vec_size * WARPGROUP_SIZE) != 0:
90
+ raise ValueError((self, WARPGROUP_SIZE))
91
+
92
+ @classmethod
93
+ def from_memref_type(cls, memref_ty: ir.Type):
94
+ if not ir.MemRefType.isinstance(memref_ty):
95
+ raise TypeError(memref_ty)
96
+
97
+ memref_type = ir.MemRefType(memref_ty)
98
+ bw = mgpu.bytewidth(memref_type.element_type)
99
+ assert 8 % bw == 0 and 8 // bw != 0, bw
100
+ if np.prod(memref_type.shape) % WARPGROUP_SIZE != 0:
101
+ raise ValueError(
102
+ "Ref must have a number of elements that is a multiple of"
103
+ f" {WARPGROUP_SIZE}"
104
+ )
105
+ max_vec_size = np.prod(memref_type.shape) // WARPGROUP_SIZE
106
+ return cls(
107
+ shape=tuple(memref_type.shape), vec_size=min(8 // bw, max_vec_size)
108
+ )
109
+
110
+ def thread_vec_idxs(self):
111
+ """The indexes to be used for vector load/store WGStridedFragLayout.
112
+
113
+ Yields:
114
+ The indices of the vector that correspond to the current thread.
115
+ """
116
+ index = ir.IndexType.get()
117
+ cardinality = np.prod(self.shape)
118
+ assert cardinality % (WARPGROUP_SIZE * self.vec_size) == 0
119
+ reg_num = cardinality // (WARPGROUP_SIZE * self.vec_size)
120
+ tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE, index))
121
+ off = arith.muli(tidx, c(self.vec_size, tidx.type))
122
+ for i in range(reg_num):
123
+ yield [arith.addi(off, c(i * WARPGROUP_SIZE * self.vec_size, tidx.type))]
124
+
125
+
126
+ FragmentedLayout = WGSplatFragLayout | WGStridedFragLayout | WGMMAFragLayout | WGMMARowFragLayout
127
+
128
+
129
+ WGMMA_LAYOUT = WGMMAFragLayout()
130
+ WGMMA_ROW_LAYOUT = WGMMARowFragLayout()
131
+
132
+
133
+ @jax.tree_util.register_pytree_node_class
134
+ class FragmentedArray:
135
+ registers: np.ndarray # of ir.Value, see checks in init for shapes.
136
+ layout: FragmentedLayout
137
+
138
+ def __init__(self, *, _registers: np.ndarray, _layout: FragmentedLayout):
139
+ self.registers = _registers
140
+ self.layout = _layout
141
+
142
+ match self.layout:
143
+ # Registers are [m_tiles, n_tiles, 2 rows, 1 cols] in WGMMA layout
144
+ # Each element is a vector<2xdtype>
145
+ case WGMMAFragLayout():
146
+ if self.registers.ndim != 4 or self.registers.shape[2:] != (2, 1):
147
+ raise ValueError("Invalid register array shape")
148
+
149
+ # Registers are [m_tiles, 2 rows] in WGMMA_ROW layout
150
+ # Each element is a dtype scalar
151
+ case WGMMARowFragLayout():
152
+ if self.registers.ndim != 2 or self.registers.shape[-1] != 2:
153
+ raise ValueError("Invalid register array shape")
154
+
155
+ # Registers are flat
156
+ case WGStridedFragLayout(shape):
157
+ (reg_size,) = ir.VectorType(_registers.flat[0].type).shape
158
+ if np.prod(shape) != np.prod(_registers.shape) * WARPGROUP_SIZE * reg_size:
159
+ raise ValueError((reg_size, shape, _registers.shape, WARPGROUP_SIZE), _registers.flat[0].type)
160
+
161
+ # Just a single register
162
+ case WGSplatFragLayout():
163
+ if _registers.size != 1:
164
+ raise ValueError(f"WGStridedFragLayout requires a single value {_registers.shape} ({_registers.size})")
165
+
166
+ case _:
167
+ raise NotImplementedError
168
+
169
+ @classmethod
170
+ def load_strided(cls, ref: ir.Value):
171
+ if not ir.MemRefType.isinstance(ref.type):
172
+ raise TypeError(ref.type)
173
+
174
+ ref_ty = ir.MemRefType(ref.type)
175
+ ref_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
176
+ layout = WGStridedFragLayout.from_memref_type(ref_ty)
177
+ vec_ty = ir.VectorType.get((layout.vec_size,), ref_ty.element_type)
178
+ vecs = [vector.load(vec_ty, ref_1d, vec_idx) for vec_idx in layout.thread_vec_idxs()]
179
+ return cls(_registers=np.array(vecs), _layout=layout)
180
+
181
+ @classmethod
182
+ def splat(cls, value, shape, layout=None):
183
+ layout = layout or WGSplatFragLayout(shape)
184
+ match layout:
185
+ case WGMMARowFragLayout():
186
+ if len(shape) != 1:
187
+ raise ValueError
188
+ if shape[0] % 64:
189
+ raise ValueError
190
+ reg_shape = (shape[0] // 64, 2)
191
+ case WGMMAFragLayout():
192
+ if len(shape) != 2:
193
+ raise ValueError
194
+ if shape[0] % 64 or shape[1] % 8:
195
+ raise ValueError
196
+ reg_shape = (shape[0] // 64, shape[1] // 8, 2, 1)
197
+ value = vector.splat(ir.VectorType.get((2,), value.type), value)
198
+ case WGStridedFragLayout(vec_size=vec_size):
199
+ assert shape == layout.shape
200
+ elems = np.prod(shape)
201
+ reg_shape = (elems // (WARPGROUP_SIZE * vec_size),)
202
+ value = vector.splat(ir.VectorType.get((vec_size,), value.type), value)
203
+ case WGSplatFragLayout():
204
+ assert shape == layout.shape
205
+ reg_shape = ()
206
+ case _:
207
+ raise NotImplementedError(layout)
208
+
209
+ return cls(
210
+ _registers=np.full(reg_shape, value, dtype=object),
211
+ _layout=layout,
212
+ )
213
+
214
+ @property
215
+ def shape(self):
216
+ match self.layout:
217
+ case WGMMAFragLayout():
218
+ row_tiles, col_tiles = self.registers.shape[:2]
219
+ return (row_tiles * 64, col_tiles * 8)
220
+ case WGMMARowFragLayout():
221
+ row_tiles = self.registers.shape[0]
222
+ return (row_tiles * 64,)
223
+ case WGStridedFragLayout(shape):
224
+ return shape
225
+ case WGSplatFragLayout(shape=shape):
226
+ return shape
227
+
228
+ @property
229
+ def mlir_dtype(self):
230
+ reg_ty = self.registers.flat[0].type
231
+ match self.layout:
232
+ case WGMMAFragLayout() | WGStridedFragLayout():
233
+ return ir.VectorType(reg_ty).element_type
234
+ case WGMMARowFragLayout() | WGSplatFragLayout():
235
+ return reg_ty
236
+
237
+ def _pointwise(self, op, *other):
238
+ other_arrs = []
239
+ for o in other:
240
+ if not isinstance(o, FragmentedArray):
241
+ if not isinstance(o, ir.Value):
242
+ raise NotImplementedError(o)
243
+
244
+ o = FragmentedArray.splat(o, shape=self.shape, layout=self.layout)
245
+
246
+ if isinstance(o.layout, WGSplatFragLayout):
247
+ if not o.layout.can_broadcast_to(self.shape):
248
+ raise ValueError("Can't broadcast shape.")
249
+ o = FragmentedArray.splat(o.registers.flat[0], shape=self.shape, layout=self.layout)
250
+ else:
251
+ if self.layout != o.layout:
252
+ raise ValueError("Incompatible FragmentedArray layouts")
253
+ if self.registers.shape != o.registers.shape:
254
+ raise ValueError("Incompatible FragmentedArray shapes")
255
+
256
+ other_arrs.append(o)
257
+ new_regs = np.empty_like(self.registers)
258
+
259
+ for idx, reg in np.ndenumerate(self.registers):
260
+ new_regs[idx] = op(reg, *(o.registers[idx] for o in other_arrs))
261
+ return FragmentedArray(_registers=new_regs, _layout=self.layout)
262
+
263
+ def __add__(self, other):
264
+ if ir.FloatType.isinstance(self.mlir_dtype):
265
+ return self._pointwise(arith.addf, other)
266
+ elif ir.IntegerType.isinstance(self.mlir_dtype):
267
+ return self._pointwise(arith.addi, other)
268
+ else:
269
+ raise NotImplementedError(self.mlir_dtype)
270
+
271
+ def __mul__(self, other):
272
+ if ir.FloatType.isinstance(self.mlir_dtype):
273
+ return self._pointwise(arith.mulf, other)
274
+ elif ir.IntegerType.isinstance(self.mlir_dtype):
275
+ return self._pointwise(arith.muli, other)
276
+ else:
277
+ raise NotImplementedError(self.mlir_dtype)
278
+
279
+ def __sub__(self, other):
280
+ if not ir.FloatType.isinstance(self.mlir_dtype):
281
+ raise NotImplementedError
282
+ return self._pointwise(arith.subf, other)
283
+
284
+ def __truediv__(self, other):
285
+ if not ir.FloatType.isinstance(self.mlir_dtype):
286
+ raise NotImplementedError
287
+ return self._pointwise(arith.divf, other)
288
+
289
+ def max(self, other):
290
+ if not ir.FloatType.isinstance(self.mlir_dtype):
291
+ raise NotImplementedError
292
+ return self._pointwise(arith.maximumf, other)
293
+
294
+ def exp(self, approx: bool = False):
295
+ if not ir.FloatType.isinstance(self.mlir_dtype):
296
+ raise NotImplementedError
297
+ def fast_exp(x):
298
+ f32 = ir.F32Type.get()
299
+ if self.mlir_dtype != f32:
300
+ raise NotImplementedError
301
+ log2e = arith.constant(f32, ir.FloatAttr.get(f32, 1.4426950408889634))
302
+ if x.type == f32:
303
+ scaled = arith.mulf(x, log2e)
304
+ return llvm.inline_asm(
305
+ f32, [scaled], "ex2.approx.f32 $0,$1;", "=f,f", asm_dialect=0
306
+ )
307
+ elif ir.VectorType.isinstance(x.type):
308
+ index = ir.IndexType.get()
309
+ result = llvm.mlir_undef(x.type)
310
+ for i in range(2):
311
+ v = vector.extractelement(x, position=c(i, index))
312
+ vr = fast_exp(v)
313
+ result = vector.insertelement(vr, result, position=c(i, index))
314
+ return result
315
+ else:
316
+ raise NotImplementedError(x.type)
317
+ return self._pointwise(fast_exp if approx else mlir_math.exp)
318
+
319
+ def rsqrt(self):
320
+ return self._pointwise(mlir_math.rsqrt)
321
+
322
+ def __and__(self, other):
323
+ if not ir.IntegerType.isinstance(self.mlir_dtype):
324
+ raise ValueError(
325
+ "Bitwise operations only defined for integer types, not"
326
+ f" {self.mlir_dtype}"
327
+ )
328
+
329
+ return self._pointwise(arith.andi, other)
330
+
331
+ def bitcast(self, elt: ir.Type):
332
+ reg_type = self.registers.flat[0].type
333
+ if ir.VectorType.isinstance(reg_type):
334
+ reg_shape = ir.VectorType(reg_type).shape
335
+ ty = ir.VectorType.get(reg_shape, elt)
336
+ else:
337
+ ty = elt
338
+
339
+ return self._pointwise(lambda x: arith.bitcast(ty, x))
340
+
341
+ def __getitem__(self, idx):
342
+ if self.layout != WGMMA_LAYOUT:
343
+ raise NotImplementedError("Only WGMMA layouts support slicing")
344
+ base_idx, slice_shape, is_squeezed = utils.parse_indices(idx, self.shape)
345
+ if any(is_squeezed):
346
+ raise NotImplementedError("Only slicing implemented")
347
+ if (
348
+ base_idx[0] % 64
349
+ or slice_shape[0] % 64
350
+ or base_idx[1] % 8
351
+ or slice_shape[1] % 8
352
+ ):
353
+ raise NotImplementedError("Only tile aligned slicing supported")
354
+ base_idx[0] //= 64
355
+ slice_shape[0] //= 64
356
+ base_idx[1] //= 8
357
+ slice_shape[1] //= 8
358
+ new_regs = self.registers[
359
+ base_idx[0] : base_idx[0] + slice_shape[0],
360
+ base_idx[1] : base_idx[1] + slice_shape[1],
361
+ ]
362
+ return FragmentedArray(_registers=new_regs, _layout=self.layout)
363
+
364
+ # TODO(apaszke): Support JAX dtypes here as well?
365
+ def astype(self, new_dtype: ir.Type):
366
+ cur_dtype = self.mlir_dtype
367
+ if cur_dtype == new_dtype:
368
+ return self
369
+ from_float = ir.FloatType.isinstance(cur_dtype)
370
+ to_float = ir.FloatType.isinstance(new_dtype)
371
+ from_integer = ir.IntegerType.isinstance(cur_dtype)
372
+ to_integer = ir.IntegerType.isinstance(new_dtype)
373
+ if from_float and to_float:
374
+ if ir.FloatType(cur_dtype).width > ir.FloatType(new_dtype).width:
375
+ convert = arith.truncf
376
+ else:
377
+ convert = arith.extf
378
+ elif from_integer and to_integer:
379
+ if ir.IntegerType(cur_dtype).width > ir.IntegerType(new_dtype).width:
380
+ convert = arith.trunci
381
+ else:
382
+ convert = arith.extsi
383
+ elif from_integer and to_float:
384
+ convert = arith.sitofp
385
+ elif from_float and to_integer:
386
+ convert = arith.fptosi
387
+ new_registers = np.empty_like(self.registers)
388
+ match self.layout:
389
+ case WGMMAFragLayout():
390
+ new_reg_ty = ir.VectorType.get((2,), new_dtype)
391
+ case WGStridedFragLayout(vec_size=vec_size):
392
+ new_reg_ty = ir.VectorType.get((vec_size,), new_dtype)
393
+ case WGMMARowFragLayout() | WGSplatFragLayout():
394
+ new_reg_ty = new_dtype
395
+ case _:
396
+ raise NotImplementedError(f"Unsupported layout {self.layout}")
397
+ for idx, reg in np.ndenumerate(self.registers):
398
+ new_registers[idx] = convert(new_reg_ty, reg)
399
+ return FragmentedArray(_registers=new_registers, _layout=self.layout)
400
+
401
+ def reduce_sum(self, scratch) -> ir.Value:
402
+ index = ir.IndexType.get()
403
+ if not isinstance(self.layout, WGStridedFragLayout):
404
+ raise NotImplementedError(f"Unsupported layout {self.layout}")
405
+ result = c(0, self.mlir_dtype)
406
+ for reg in self.registers:
407
+ result = arith.addf(
408
+ result,
409
+ vector.reduction(self.mlir_dtype, vector.CombiningKind.ADD, reg),
410
+ )
411
+ scratch_ty = ir.MemRefType(scratch.type)
412
+ if scratch_ty.element_type != self.mlir_dtype or scratch_ty.shape != [4]:
413
+ raise ValueError(f"Expected shape={(4,)}, {self.mlir_dtype} (got {scratch_ty})")
414
+
415
+ if ir.FloatType.isinstance(self.mlir_dtype):
416
+ op = arith.addf
417
+ elif ir.IntegerType.isinstance(self.mlir_dtype):
418
+ op = arith.addi
419
+ else:
420
+ raise NotImplementedError(self.mlir_dtype)
421
+
422
+ warp_result = utils.warp_tree_reduce(result, op, 32)
423
+ warp_id = arith.divui(gpu.thread_id(gpu.Dimension.x), c(32, index))
424
+ memref.store(warp_result, scratch, [warp_id])
425
+ utils.commit_shared()
426
+ zero_index = c(0, index)
427
+ with mgpu.single_thread():
428
+ scratch_vec = vector.load(
429
+ ir.VectorType.get((4,), self.mlir_dtype),
430
+ scratch,
431
+ [zero_index],
432
+ )
433
+ scratch_sum = vector.reduction(
434
+ self.mlir_dtype, vector.CombiningKind.ADD, scratch_vec
435
+ )
436
+ memref.store(scratch_sum, scratch, [zero_index])
437
+ utils.commit_shared()
438
+ return memref.load(scratch, [zero_index])
439
+
440
+ def reduce(self, op, axis):
441
+ if self.layout != WGMMA_LAYOUT:
442
+ raise NotImplementedError(self.layout)
443
+ if axis != 1:
444
+ raise NotImplementedError
445
+ index = ir.IndexType.get()
446
+ i32 = ir.IntegerType.get_signless(32)
447
+ new_regs = np.empty(self.registers.shape[::2], dtype=object)
448
+ assert self.registers.shape[-1] == 1
449
+ for row_tile, row_subtile in np.ndindex(new_regs.shape):
450
+ # Reduce the registers owned by the current thread over n tiles
451
+ thread_result_vec = self.registers[row_tile, 0, row_subtile, 0]
452
+ for n_tile in range(1, self.registers.shape[1]):
453
+ thread_result_vec = op(
454
+ thread_result_vec, self.registers[row_tile, n_tile, row_subtile, 0]
455
+ )
456
+ thread_result = op(
457
+ vector.extractelement(thread_result_vec, position=c(0, index)),
458
+ vector.extractelement(thread_result_vec, position=c(1, index)),
459
+ )
460
+ # Do a shuffle to reduce in groups of 4 consecutive threads.
461
+ result = thread_result
462
+ for i in (1, 2):
463
+ other_result = nvvm.shfl_sync(
464
+ result.type,
465
+ c(0xFFFFFFFF, i32),
466
+ result,
467
+ c(i, i32),
468
+ c(0x1F, i32),
469
+ nvvm.ShflKind.bfly,
470
+ )
471
+ result = op(result, other_result)
472
+ new_regs[row_tile, row_subtile] = result
473
+ return FragmentedArray(_registers=new_regs, _layout=WGMMA_ROW_LAYOUT)
474
+
475
+ def broadcast(self, shape):
476
+ if not isinstance(self.layout, WGSplatFragLayout):
477
+ raise NotImplementedError(self.layout)
478
+
479
+ if self.shape == shape:
480
+ return self
481
+
482
+ if not self.layout.can_broadcast_to(shape):
483
+ raise ValueError(f"Can't broadcast {self.shape} to {shape}")
484
+
485
+ return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape))
486
+
487
+ def reshape(self, shape):
488
+ if self.shape == shape:
489
+ return self
490
+
491
+ if not isinstance(self.layout, WGSplatFragLayout):
492
+ raise NotImplementedError(self.layout)
493
+
494
+ if np.prod(shape) != np.prod(self.shape):
495
+ raise ValueError(f"Can't reshape {self.shape} to {shape}")
496
+
497
+ return FragmentedArray(_registers=self.registers, _layout=WGSplatFragLayout(shape))
498
+
499
+ def broadcast_minor(self, n):
500
+ if self.layout != WGMMA_ROW_LAYOUT:
501
+ raise NotImplementedError
502
+ num_row_tiles = self.registers.shape[0]
503
+ num_col_tiles, rem = divmod(n, 8)
504
+ if rem:
505
+ raise ValueError("Number of columns must be divisible by 8")
506
+ new_regs = np.empty((num_row_tiles, num_col_tiles, 2, 1), dtype=object)
507
+ dtype = self.mlir_dtype
508
+ for (row_tile, row_subtile), reg in np.ndenumerate(self.registers):
509
+ new_regs[row_tile, :, row_subtile, :] = vector.splat(
510
+ ir.VectorType.get((2,), dtype), reg
511
+ )
512
+ return FragmentedArray(_registers=new_regs, _layout=WGMMA_LAYOUT)
513
+
514
+ def store_untiled(self, ref: ir.Value):
515
+ if not ir.MemRefType.isinstance(ref.type):
516
+ raise ValueError(ref)
517
+
518
+ match self.layout:
519
+ case WGMMAFragLayout():
520
+ self._store_untiled_wgmma(ref)
521
+ case WGStridedFragLayout():
522
+ self._store_untiled_wg_strided(ref)
523
+ case _:
524
+ raise NotImplementedError(self.layout)
525
+
526
+ def _store_untiled_wg_strided(self, ref: ir.Value):
527
+ ref_ty = ir.MemRefType(ref.type)
528
+ ref_shape = tuple(ref_ty.shape)
529
+ if ref_shape != self.shape:
530
+ raise ValueError((ref_shape, self.shape))
531
+ smem_1d = mgpu.memref_fold(ref, 0, len(ref_ty.shape))
532
+ for idx, reg in zip(self.layout.thread_vec_idxs(), self.registers.flat):
533
+ vector.store(reg, smem_1d, idx)
534
+
535
+ def _store_untiled_wgmma(self, ref: ir.Value):
536
+ """Stores accumulator to a 2D memref. Not optimized at the moment."""
537
+ assert self.layout == WGMMA_LAYOUT
538
+ index = ir.IndexType.get()
539
+ m, n = self.shape
540
+ ref_ty = ir.MemRefType(ref.type)
541
+ if ref_ty.shape != [m, n]:
542
+ raise ValueError(ref.type, (m, n))
543
+
544
+ def c(x):
545
+ return arith.ConstantOp(index, ir.IntegerAttr.get(index, x))
546
+
547
+ tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
548
+ lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
549
+ warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
550
+ row_base = arith.addi(
551
+ arith.divui(lane_id, c(4)), arith.muli(warp_id, c(16))
552
+ )
553
+ col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
554
+ it = np.ndenumerate(self.registers)
555
+ for (row_tile, col_tile, row_idx, col_zero), elem in it:
556
+ del col_zero
557
+ row = arith.addi(row_base, c(row_tile * 64 + row_idx * 8))
558
+ for col_idx in range(2):
559
+ value = vector.extractelement(elem, position=c(col_idx))
560
+ col = arith.addi(col_base, c(col_tile * 8 + col_idx))
561
+ memref.store(value, ref, [row, col])
562
+
563
+ def store_tiled(self, ref, swizzle: int | None):
564
+ if self.layout != WGMMA_LAYOUT:
565
+ raise NotImplementedError
566
+ dtype = self.mlir_dtype
567
+ bw = mgpu.bytewidth(dtype)
568
+ m, n = self.shape
569
+ assert m % 64 == 0 # This is implied by the layout.
570
+ cols_per_tile = 128 // bw
571
+ expected_shape = [m // 64, n // cols_per_tile, 64, cols_per_tile]
572
+ if ir.MemRefType(ref.type).shape != expected_shape:
573
+ raise ValueError(ref.type, (m, n))
574
+ for get, _, idxs in self.transfer_tiled(self.shape, dtype, swizzle):
575
+ vector.store(get(self.registers), ref, idxs)
576
+
577
+ @classmethod
578
+ def load_tiled(cls, ref, swizzle: int | None):
579
+ ref_ty = ir.MemRefType(ref.type)
580
+ dtype = ref_ty.element_type
581
+ bw = mgpu.bytewidth(dtype)
582
+ m_tiles, n_tiles, m_tile_size, n_tile_size = ref_ty.shape
583
+ if m_tile_size != 64 or n_tile_size != (128 // bw):
584
+ raise ValueError
585
+ m, n = m_tiles * m_tile_size, n_tiles * n_tile_size
586
+ assert m % 64 == 0 # This is implied by the layout.
587
+ registers = np.full(
588
+ (m_tiles, n // 8, 2, 1),
589
+ vector.splat(ir.VectorType.get((2,), dtype), c(0, dtype)),
590
+ dtype=object,
591
+ )
592
+ for _, update, idxs in cls.transfer_tiled((m, n), dtype, swizzle):
593
+ update(registers, vector.load(ir.VectorType.get((2,), dtype), ref, idxs))
594
+ return cls(_registers=registers, _layout=WGMMA_LAYOUT)
595
+
596
+ @staticmethod
597
+ def transfer_tiled(shape, dtype, swizzle: int | None):
598
+ bw = mgpu.bytewidth(dtype)
599
+ m, n = shape
600
+ if n % 32 != 0:
601
+ raise NotImplementedError
602
+ cols_per_tile = 128 // bw
603
+ if swizzle != 128:
604
+ raise NotImplementedError("Only 128B swizzle supported")
605
+
606
+ c = arith.ConstantOp.create_index
607
+ tidx = arith.remui(gpu.thread_id(gpu.Dimension.x), c(WARPGROUP_SIZE))
608
+ lane_id = arith.remui(tidx, c(32)) # {0, 1, ..., 31}
609
+ warp_id = arith.divui(tidx, c(32)) # {0, 1, 2, 3}
610
+ sub_row_base = arith.divui(lane_id, c(4)) # {0, 1, ..., 7}
611
+ if bw > 2: # Stagger is only necessary for values larger than 16bit.
612
+ is_even_row = arith.cmpi(
613
+ arith.CmpIPredicate.eq, arith.remui(sub_row_base, c(2)), c(0)
614
+ )
615
+ else:
616
+ # We rely on canonicalization to clean up the selects.
617
+ i1 = ir.IntegerType.get_signless(1)
618
+ is_even_row = arith.constant(i1, ir.BoolAttr.get(True))
619
+ row_base = arith.addi(sub_row_base, arith.muli(warp_id, c(16)))
620
+ col_base = arith.muli(arith.remui(lane_id, c(4)), c(2)) # {0, 2, 4, 6}
621
+ # The swizzle pattern is constant for a given thread.
622
+ col_swizzle_bits = arith.muli(sub_row_base, c(16 // bw))
623
+ for row_group in range(m // 64):
624
+ for col_group in range(n // cols_per_tile):
625
+ for row_subidx in range(2):
626
+ row = arith.addi(row_base, c(row_subidx * 8))
627
+ for col_subidx in range(cols_per_tile // 8):
628
+ # We stagger the even and odd rows a little to avoid bank conflicts.
629
+ # It seems that the STS.64 is 2x faster (and the hardware reports no
630
+ # conflicts) when the conflicts are split between half-warps, as
631
+ # opposed to having them within the half-warp. This requires a
632
+ # little more work for the selects, but is ultimately worth it.
633
+ col_subidx_even = col_subidx
634
+ col_subidx_odd = col_subidx ^ 2
635
+ col_off = arith.select(
636
+ is_even_row, c(col_subidx_even * 8), c(col_subidx_odd * 8)
637
+ )
638
+ col = arith.addi(col_base, col_off)
639
+ col = arith.xori(col, col_swizzle_bits)
640
+ reg_idx_even = col_subidx_even + col_group * (cols_per_tile // 8)
641
+ reg_idx_odd = col_subidx_odd + col_group * (cols_per_tile // 8)
642
+ even_idx = row_group, reg_idx_even, row_subidx, 0
643
+ odd_idx = row_group, reg_idx_odd, row_subidx, 0
644
+ idx = c(row_group), c(col_group), row, col
645
+ def get_register(regs, even_idx=even_idx, odd_idx=odd_idx):
646
+ value_even = regs[even_idx]
647
+ value_odd = regs[odd_idx]
648
+ return arith.select(is_even_row, value_even, value_odd)
649
+ def update_registers(regs, new, even_idx=even_idx, odd_idx=odd_idx):
650
+ regs[even_idx] = arith.select(is_even_row, new, regs[even_idx])
651
+ regs[odd_idx] = arith.select(is_even_row, regs[odd_idx], new)
652
+ yield get_register, update_registers, idx
653
+
654
+ def tree_flatten(self):
655
+ return list(self.registers.flat), (self.layout, self.registers.shape)
656
+
657
+ @classmethod
658
+ def tree_unflatten(cls, aux, flat_registers):
659
+ layout, reg_shape = aux
660
+ registers = np.asarray(flat_registers, dtype=object).reshape(reg_shape)
661
+ return cls(_registers=registers, _layout=layout)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/profiler.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The JAX Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import contextlib
17
+ import ctypes
18
+ import functools
19
+ import json
20
+ import math
21
+
22
+ import jax
23
+ from jax._src.interpreters import mlir
24
+ from jax._src.lib import mosaic_gpu as mosaic_gpu_lib
25
+ from jax._src.lib import xla_client
26
+ import jax.numpy as jnp
27
+ from jaxlib.mlir import ir
28
+ from jaxlib.mlir.dialects import arith
29
+ from jaxlib.mlir.dialects import gpu
30
+ from jaxlib.mlir.dialects import memref
31
+ from jaxlib.mlir.dialects import scf
32
+ import numpy as np
33
+
34
+ from .utils import * # noqa: F403
35
+
36
+ # ruff: noqa: F405
37
+ # mypy: ignore-errors
38
+
39
+ xla_client.register_custom_call_target(
40
+ "mosaic_gpu_record_event",
41
+ mosaic_gpu_lib._mosaic_gpu_ext._record_event_capsule(),
42
+ platform="CUDA",
43
+ )
44
+
45
+ record_event_p = jax.core.Primitive("record_event")
46
+ record_event_p.multiple_results = True
47
+
48
+ @record_event_p.def_abstract_eval
49
+ def _record_event_abstract_eval(*args, event):
50
+ del event # Unused.
51
+ return args
52
+
53
+ @functools.partial(mlir.register_lowering, record_event_p, platform="cuda")
54
+ def _record_event_lowering_rule(ctx, *args, event):
55
+ ptr_bytes = ctypes.cast(event, ctypes.c_void_p).value.to_bytes(
56
+ 8, byteorder="little"
57
+ ) # pytype: disable=attribute-error
58
+ op = mlir.custom_call(
59
+ "mosaic_gpu_record_event",
60
+ result_types=[mlir.aval_to_ir_type(aval) for aval in ctx.avals_out],
61
+ operands=args,
62
+ backend_config=ptr_bytes,
63
+ operand_output_aliases={i: i for i in range(len(args))},
64
+ )
65
+ return op.results
66
+
67
+ def _record_event(args, event):
68
+ flat_args, treedef = jax.tree.flatten(args)
69
+ return jax.tree.unflatten(
70
+ treedef, record_event_p.bind(*flat_args, event=event)
71
+ )
72
+
73
+ def measure(f, *args):
74
+ # TODO(apaszke): Raise if this is called under jit.
75
+ start_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
76
+ end_event = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_create()
77
+ try:
78
+ @jax.jit
79
+ def run(*args):
80
+ return _record_event(f(*_record_event(args, start_event)), end_event)
81
+ jax.block_until_ready(run(*args)) # Warmup.
82
+ results = jax.block_until_ready(run(*args))
83
+ elapsed = mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_elapsed(
84
+ start_event, end_event
85
+ )
86
+ finally:
87
+ mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(start_event)
88
+ mosaic_gpu_lib._mosaic_gpu_ext._gpu_event_destroy(end_event)
89
+ return results, elapsed
90
+
91
+
92
+ class ProfilerSpec:
93
+ ENTER = 0
94
+ EXIT = 1 << 31
95
+
96
+ def __init__(self, entries_per_warpgroup: int):
97
+ self.entries_per_warpgroup = entries_per_warpgroup
98
+ self.interned_names = {}
99
+
100
+ def _num_warpgroups(
101
+ self, grid: tuple[int, ...], block: tuple[int, ...]
102
+ ) -> int:
103
+ if math.prod(block) % WARPGROUP_SIZE:
104
+ raise ValueError("Block size is not a multiple of warpgroup size")
105
+ return math.prod(grid) * math.prod(block) // WARPGROUP_SIZE
106
+
107
+ def mlir_buffer_type(
108
+ self, grid: tuple[int, ...], block: tuple[int, ...]
109
+ ) -> ir.Type:
110
+ return ir.MemRefType.get(
111
+ (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
112
+ ir.IntegerType.get_signless(32),
113
+ )
114
+
115
+ def jax_buffer_type(
116
+ self, grid: tuple[int, ...], block: tuple[int, ...]
117
+ ) -> ir.Type:
118
+ return jax.ShapeDtypeStruct(
119
+ (self._num_warpgroups(grid, block) * self.entries_per_warpgroup,),
120
+ jnp.uint32,
121
+ )
122
+
123
+ def smem_i32_elements(self, block: tuple[int, ...]):
124
+ num_warpgroups = self._num_warpgroups((), block)
125
+ return int(num_warpgroups * self.entries_per_warpgroup)
126
+
127
+ def smem_bytes(self, block: tuple[int, ...]):
128
+ bytes_per_entry = 4
129
+ return self.smem_i32_elements(block) * bytes_per_entry
130
+
131
+ def intern_name(self, name: str) -> int:
132
+ if name_id := self.interned_names.get(name, None):
133
+ return name_id
134
+ name_id = self.interned_names[name] = len(self.interned_names)
135
+ if name_id & self.EXIT:
136
+ raise RuntimeError("Allocated too many names")
137
+ return name_id
138
+
139
+ def dump(self, buffer, f, grid: tuple[int, ...], block: tuple[int, ...]):
140
+ buffer = np.asarray(buffer)
141
+ num_blocks = math.prod(grid)
142
+ warpgroups_per_block = self._num_warpgroups((), block)
143
+ entries = buffer.reshape(
144
+ num_blocks, warpgroups_per_block, self.entries_per_warpgroup
145
+ )
146
+ start_times = entries[..., :2].astype(np.int64)
147
+ start_times = (start_times[..., 0] << 32) + start_times[..., 1]
148
+ start_times -= start_times.min() # Normalize
149
+ entries_used = entries[..., 2]
150
+ if np.any(entries_used > self.entries_per_warpgroup - 2):
151
+ raise RuntimeError("Insufficient space to capture a full trace")
152
+ traces = entries[..., 3:]
153
+ unintern = {v: k for k, v in self.interned_names.items()}
154
+ events = []
155
+ for block_idx, wg_idx in np.ndindex(num_blocks, warpgroups_per_block):
156
+ valid_entries = entries_used[block_idx, wg_idx] - 3
157
+ local_clock_offset = None
158
+ assert valid_entries % 2 == 0, valid_entries
159
+ start_time = start_times[block_idx, wg_idx]
160
+ block_events = []
161
+ for i in range(0, valid_entries, 2):
162
+ tag = traces[block_idx, wg_idx, i]
163
+ time = traces[block_idx, wg_idx, i + 1]
164
+ if local_clock_offset is None:
165
+ local_clock_offset = time
166
+ time -= local_clock_offset
167
+ time -= i * 6 # Account for the overhead of profiling.
168
+ if time < 0:
169
+ break # Detect a timer wraparound
170
+ name_id = tag
171
+ begin = True
172
+ if name_id & ProfilerSpec.EXIT:
173
+ name_id = name_id ^ ProfilerSpec.EXIT
174
+ begin = False
175
+ name = unintern[name_id]
176
+ block_events.append({
177
+ "name": name,
178
+ "ph": "B" if begin else "E",
179
+ "ts": float(start_time + time) / 1e3,
180
+ "pid": 1 + block_idx,
181
+ "tid": 1 + wg_idx,
182
+ })
183
+ else: # If we didn't break
184
+ events.extend(block_events)
185
+ return json.dump({"displayTimeUnit": "ns", "traceEvents": events}, f)
186
+
187
+
188
+ class OnDeviceProfiler:
189
+
190
+ def __init__(self, spec: ProfilerSpec, smem_buffer: ir.Value, gmem_buffer: ir.Value):
191
+ self.spec = spec
192
+ # self.should_store = gpu.thread_id(gpu.Dimension.x)
193
+ i32 = ir.IntegerType.get_signless(32)
194
+ index = ir.IndexType.get()
195
+ self.entries_per_wg = spec.entries_per_warpgroup
196
+ wg_idx = warpgroup_idx(sync=False)
197
+ self.smem_buffer = memref_slice(
198
+ smem_buffer,
199
+ ds(
200
+ arith.index_cast(
201
+ index, arith.muli(wg_idx, c(self.entries_per_wg, i32))
202
+ ),
203
+ self.entries_per_wg,
204
+ ),
205
+ )
206
+ self.gmem_buffer = gmem_buffer
207
+ # Hopefully mem2reg will remove the allocation.
208
+ self.offset = memref.alloca(ir.MemRefType.get((), i32), [], [])
209
+ memref.store(c(0, i32), self.offset, [])
210
+
211
+ @contextlib.contextmanager
212
+ def record(self, name: str):
213
+ i32 = ir.IntegerType.get_signless(32)
214
+ index = ir.IndexType.get()
215
+ name_id = self.spec.intern_name(name)
216
+ def store(modifier):
217
+ cur = arith.index_cast(index, memref.load(self.offset, []))
218
+ # TODO(apaszke): Clamp indices
219
+ # bound = arith.subi(self.entries_per_block, c(2, index))
220
+ # cur = arith.select(
221
+ # arith.cmpi(arith.CmpIPredicate.ult, cur, bound), cur, bound
222
+ # )
223
+ memref.store(c(modifier | name_id, i32), self.smem_buffer, [cur])
224
+ memref.store(
225
+ clock(), self.smem_buffer, [arith.addi(cur, c(1, cur.type))]
226
+ )
227
+ memref.store(
228
+ arith.index_cast(i32, arith.addi(cur, c(2, cur.type))),
229
+ self.offset,
230
+ [],
231
+ )
232
+ store(ProfilerSpec.ENTER)
233
+ yield
234
+ store(ProfilerSpec.EXIT)
235
+
236
+ def finalize(self, grid: tuple[int, ...], block: tuple[int, ...]):
237
+ index = ir.IndexType.get()
238
+ i32 = ir.IntegerType.get_signless(32)
239
+
240
+ gpu.barrier() # Make sure all warpgroups are done.
241
+
242
+ block_idx = c(0, index)
243
+ for dim in gpu.Dimension: # pytype: disable=wrong-arg-types
244
+ block_idx = arith.addi(
245
+ arith.muli(block_idx, gpu.grid_dim(dim)), gpu.block_id(dim)
246
+ )
247
+ wg_idx = warpgroup_idx(sync=False)
248
+ wg_per_block = math.prod(block) // WARPGROUP_SIZE
249
+ global_wg_idx = arith.addi(
250
+ arith.muli(block_idx, c(wg_per_block, index)),
251
+ arith.index_cast(index, wg_idx),
252
+ )
253
+ start_offset = arith.muli(global_wg_idx, c(self.entries_per_wg, index))
254
+ wg_gmem_buffer = memref.subview(
255
+ self.gmem_buffer, [start_offset], [self.entries_per_wg], [1],
256
+ result_type=ir.Type.parse(
257
+ f"memref<{self.entries_per_wg}xi32, strided<[1], offset: ?>>"
258
+ ),
259
+ )
260
+ thread_in_wg = arith.remui(thread_idx(), c(128, i32))
261
+ if_first = scf.IfOp(
262
+ arith.cmpi(arith.CmpIPredicate.eq, thread_in_wg, c(0, i32))
263
+ )
264
+ with ir.InsertionPoint(if_first.then_block):
265
+ # TODO(apaszke): Either use globaltimer or delete
266
+ # memref.store(globaltimer("high"), block_gmem_buffer, [c(0, index)])
267
+ # memref.store(globaltimer("low"), block_gmem_buffer, [c(1, index)])
268
+ memref.store(c(0, i32), wg_gmem_buffer, [c(0, index)])
269
+ memref.store(c(0, i32), wg_gmem_buffer, [c(1, index)])
270
+ memref.store(
271
+ arith.addi(memref.load(self.offset, []), c(3, i32)),
272
+ wg_gmem_buffer,
273
+ [c(2, index)],
274
+ )
275
+
276
+ for_op = scf.ForOp(
277
+ c(0, index),
278
+ c(self.entries_per_wg - 3, index),
279
+ c(1, index),
280
+ )
281
+ with ir.InsertionPoint(for_op.body):
282
+ x = memref.load(self.smem_buffer, [for_op.induction_variable])
283
+ memref.store(
284
+ x,
285
+ wg_gmem_buffer,
286
+ [arith.addi(for_op.induction_variable, c(3, index))],
287
+ )
288
+ scf.yield_([])
289
+ scf.yield_([])
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/utils.py ADDED
@@ -0,0 +1,699 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The JAX Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+ """Utilities for code generator."""
16
+
17
+ from collections.abc import Iterator
18
+ import contextlib
19
+ import dataclasses
20
+ from typing import Any, Literal, Sequence
21
+
22
+ import jax
23
+ from jaxlib.mlir import ir
24
+ from jaxlib.mlir.dialects import arith
25
+ from jaxlib.mlir.dialects import builtin
26
+ from jaxlib.mlir.dialects import gpu
27
+ from jaxlib.mlir.dialects import llvm
28
+ from jaxlib.mlir.dialects import memref
29
+ from jaxlib.mlir.dialects import nvgpu
30
+ from jaxlib.mlir.dialects import nvvm
31
+ from jaxlib.mlir.dialects import scf
32
+ from jaxlib.mlir.dialects import vector
33
+ import numpy as np
34
+
35
+ # mypy: ignore-errors
36
+
37
+ WARPGROUP_SIZE: int = 128
38
+ DYNAMIC = -9223372036854775808
39
+
40
+ # pylint: disable=line-too-long, wildcard-import, missing-function-docstring, bad-continuation, g-bad-todo, protected-access, g-explicit-length-test, missing-class-docstring, g-doc-return-or-yield, g-inconsistent-quotes
41
+
42
+
43
+ def ptr_as_memref(ptr, memref_ty: ir.MemRefType):
44
+ if len(memref_ty.shape) == 0:
45
+ raise NotImplementedError
46
+ i64 = ir.IntegerType.get_signless(64)
47
+ rank = len(memref_ty.shape)
48
+ desc_ty = ir.Type.parse(
49
+ f"!llvm.struct<(ptr, ptr, i64, array<{rank} x i64>, array<{rank} x i64>)>"
50
+ )
51
+ desc = llvm.UndefOp(desc_ty)
52
+ desc = llvm.InsertValueOp(desc, ptr, [0]) # Allocation
53
+ desc = llvm.InsertValueOp(desc, ptr, [1]) # Aligned Base
54
+ desc = llvm.InsertValueOp(
55
+ desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, 0)), [2]
56
+ )
57
+ for i, s in enumerate(memref_ty.shape):
58
+ desc = llvm.InsertValueOp(
59
+ desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [3, i]
60
+ )
61
+ for i, s in enumerate(get_contiguous_strides(memref_ty.shape)):
62
+ desc = llvm.InsertValueOp(
63
+ desc, llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, s)), [4, i]
64
+ )
65
+ return builtin.unrealized_conversion_cast([memref_ty], [desc])
66
+
67
+
68
+ def pack_array(values):
69
+ if not values:
70
+ raise ValueError("Empty array")
71
+ elem_ty = values[0].type
72
+ i64 = ir.IntegerType.get_signless(64)
73
+ ptr_ty = ir.Type.parse("!llvm.ptr")
74
+ arr_ptr = llvm.alloca(ptr_ty, c(len(values), i64), elem_ty)
75
+ for i, v in enumerate(values):
76
+ elem_ptr = llvm.getelementptr(ptr_ty, arr_ptr, [], [i], elem_ty)
77
+ llvm.store(v, elem_ptr)
78
+ return arr_ptr
79
+
80
+
81
+ def get_contiguous_strides(xs):
82
+ strides_ret = []
83
+ stride = 1
84
+ for x in xs[::-1]:
85
+ strides_ret.append(stride)
86
+ stride *= x
87
+ return strides_ret[::-1]
88
+
89
+
90
+ def c(val: int | float, ty):
91
+ if ir.IntegerType.isinstance(ty) or ir.IndexType.isinstance(ty):
92
+ if not isinstance(val, (int, np.integer)):
93
+ raise TypeError(type(val))
94
+ attr = ir.IntegerAttr.get(ty, val)
95
+ elif ir.FloatType.isinstance(ty):
96
+ attr = ir.FloatAttr.get(ty, val)
97
+ elif ir.VectorType.isinstance(ty):
98
+ return vector.splat(ty, c(val, ir.VectorType(ty).element_type))
99
+ else:
100
+ raise NotImplementedError(ty)
101
+ return arith.constant(ty, attr)
102
+
103
+
104
+ def get_tensormap_descriptor(**attrs):
105
+ return ir.Type.parse(
106
+ f"!nvgpu.tensormap.descriptor<{', '.join(k + '=' + v for k, v in attrs.items())}>"
107
+ )
108
+
109
+
110
+ def debug_print(fmt, *args, uniform=True):
111
+ type_formats = []
112
+ new_args = []
113
+ for arg in args:
114
+ ty_format = None
115
+ if ir.IndexType.isinstance(arg.type):
116
+ ty_format = "%llu"
117
+ if ir.IntegerType.isinstance(arg.type):
118
+ width = ir.IntegerType(arg.type).width
119
+ if width == 64:
120
+ ty_format = "%llu"
121
+ elif width == 1:
122
+ ty_format = "%llu"
123
+ arg = arith.extui(ir.IntegerType.get_signless(64), arg)
124
+ if ir.F32Type.isinstance(arg.type):
125
+ ty_format = "%f"
126
+ if ir.F16Type.isinstance(arg.type):
127
+ ty_format = "%f"
128
+ arg = arith.extf(ir.F32Type.get(), arg)
129
+ if ty_format is None:
130
+ raise NotImplementedError(arg.type)
131
+ type_formats.append(ty_format)
132
+ new_args.append(arg)
133
+ ctx = single_thread if uniform else contextlib.nullcontext
134
+ with ctx():
135
+ gpu.printf(fmt.format(*type_formats) + "\n", new_args)
136
+
137
+
138
+ @dataclasses.dataclass(frozen=True)
139
+ class ForResult:
140
+ op: scf.ForOp
141
+ results: tuple[Any, ...]
142
+
143
+ @property
144
+ def result(self):
145
+ if len(self.results) != 1:
146
+ raise ValueError
147
+ return self.results[0]
148
+
149
+
150
+ def fori(bound, carrys):
151
+ unwrap = False
152
+ if not isinstance(carrys, (list, tuple)):
153
+ carrys = [carrys]
154
+ unwrap = True
155
+ flat_carrys, carry_treedef = jax.tree.flatten(carrys)
156
+
157
+ def wrapper(f):
158
+ index = ir.IndexType.get()
159
+ c0 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 0))
160
+ c1 = arith.ConstantOp(index, ir.IntegerAttr.get(index, 1))
161
+ for_op = scf.ForOp(c0, bound, c1, flat_carrys)
162
+ with ir.InsertionPoint(for_op.body):
163
+ i = for_op.induction_variable
164
+ inner_carrys = jax.tree.unflatten(carry_treedef, for_op.inner_iter_args)
165
+ if unwrap:
166
+ [inner_carrys] = inner_carrys
167
+ new_carrys = f(i, inner_carrys)
168
+ if unwrap:
169
+ new_carrys = [new_carrys]
170
+ new_flat_carrys, new_carry_treedef = jax.tree.flatten(new_carrys)
171
+ if new_carry_treedef != carry_treedef:
172
+ raise ValueError(new_carry_treedef, carry_treedef)
173
+ scf.YieldOp(new_flat_carrys)
174
+ final_flat_carrys = for_op.results
175
+ return ForResult(
176
+ for_op, jax.tree.unflatten(carry_treedef, final_flat_carrys)
177
+ )
178
+
179
+ return wrapper
180
+
181
+
182
+ def thread_idx():
183
+ i32 = ir.IntegerType.get_signless(32)
184
+ as_i32 = lambda x: arith.index_cast(i32, x)
185
+ tidx = as_i32(gpu.thread_id(gpu.Dimension.x))
186
+ stride = as_i32(gpu.block_dim(gpu.Dimension.x))
187
+ for dim in (gpu.Dimension.y, gpu.Dimension.z):
188
+ tidx = arith.addi(tidx, arith.muli(as_i32(gpu.thread_id(dim)), stride))
189
+ stride = arith.muli(stride, as_i32(gpu.block_dim(dim)))
190
+ return tidx
191
+
192
+
193
+ def _warp_bcast(val, lane_idx=0):
194
+ i32 = ir.IntegerType.get_signless(32)
195
+ mask = c(0xFFFFFFFF, i32)
196
+ return nvvm.shfl_sync(
197
+ val.type, mask, val, c(lane_idx, i32), c(0x1F, i32), nvvm.ShflKind.idx
198
+ )
199
+
200
+
201
+ def warp_idx(sync=True):
202
+ i32 = ir.IntegerType.get_signless(32)
203
+ warp_idx = arith.shrui(thread_idx(), c(5, i32))
204
+ # Performing a warp broadcast improves performance as compiler understands
205
+ # that the value is uniform across the warp.
206
+ return _warp_bcast(warp_idx) if sync else warp_idx
207
+
208
+
209
+ def warpgroup_idx(sync=True):
210
+ i32 = ir.IntegerType.get_signless(32)
211
+ wg_idx = arith.shrui(thread_idx(), c(7, i32))
212
+ # Performing a warp broadcast improves performance as compiler understands
213
+ # that the value is uniform across the warp.
214
+ return _warp_bcast(wg_idx) if sync else wg_idx
215
+
216
+
217
+ # True withon `once()` contexts.
218
+ _ONCE_REGION_ACTIVE = False
219
+
220
+
221
+ @contextlib.contextmanager
222
+ def single_thread():
223
+ """Runs the context only from a single thread."""
224
+ global _ONCE_REGION_ACTIVE
225
+
226
+ if _ONCE_REGION_ACTIVE:
227
+ yield
228
+ return
229
+
230
+ warp = warp_idx()
231
+ first_warp = arith.cmpi(arith.CmpIPredicate.eq, warp, c(0, warp.type))
232
+ elected = nvvm.elect_sync(ir.IntegerType.get_signless(1))
233
+ should_run = arith.andi(first_warp, elected)
234
+ if_op = scf.IfOp(should_run)
235
+ _ONCE_REGION_ACTIVE = True
236
+ try:
237
+ with ir.InsertionPoint(if_op.then_block):
238
+ yield
239
+ scf.YieldOp([])
240
+ finally:
241
+ _ONCE_REGION_ACTIVE = False
242
+
243
+
244
+ def clock():
245
+ i32 = ir.IntegerType.get_signless(32)
246
+ return llvm.inline_asm(
247
+ i32, [], "mov.u32 $0,%clock;", "=r", asm_dialect=0, has_side_effects=True
248
+ )
249
+
250
+
251
+ def globaltimer(kind: Literal["low", "high"] | None = None):
252
+ if kind is None:
253
+ i64 = ir.IntegerType.get_signless(64)
254
+ return llvm.inline_asm(
255
+ i64, [], "mov.u32 $0,%globaltimer;",
256
+ "=l", asm_dialect=0, has_side_effects=True,
257
+ )
258
+ i32 = ir.IntegerType.get_signless(32)
259
+ return llvm.inline_asm(
260
+ i32, [], f"mov.u32 $0,%globaltimer_{kind[:2]};",
261
+ "=r", asm_dialect=0, has_side_effects=True,
262
+ )
263
+
264
+
265
+ def bytewidth(ty: ir.Type):
266
+ if ir.IntegerType.isinstance(ty):
267
+ return ir.IntegerType(ty).width // 8
268
+ if ir.FloatType.isinstance(ty):
269
+ return ir.FloatType(ty).width // 8
270
+ raise NotImplementedError(ty)
271
+
272
+
273
+ @dataclasses.dataclass(frozen=True)
274
+ class DynamicSlice:
275
+ base: ir.Value | int
276
+ length: int
277
+
278
+
279
+ ds = DynamicSlice
280
+
281
+
282
+ def memref_slice(ref: ir.Value, index) -> ir.Value:
283
+ ref_ty = ir.MemRefType(ref.type)
284
+ base_indices, slice_shape, is_squeezed = parse_indices(index, ref_ty.shape)
285
+
286
+ memref_strides, offset = ref_ty.get_strides_and_offset()
287
+ new_offset = offset
288
+ for idx, stride in zip(base_indices, memref_strides):
289
+ if isinstance(idx, int):
290
+ new_offset += idx * stride
291
+ else:
292
+ new_offset = ir.ShapedType.get_dynamic_stride_or_offset()
293
+ break
294
+ new_strides = [
295
+ s for s, squeeze in zip(memref_strides, is_squeezed) if not squeeze
296
+ ]
297
+ new_shape = [s for s, squeeze in zip(slice_shape, is_squeezed) if not squeeze]
298
+ new_layout = ir.StridedLayoutAttr.get(new_offset, new_strides)
299
+
300
+ ref_slice = memref.subview(
301
+ ref, base_indices, slice_shape, [1] * len(ref_ty.shape),
302
+ result_type=ir.MemRefType.get(
303
+ new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
304
+ ),
305
+ )
306
+ return ref_slice
307
+
308
+
309
+ def _is_contiguous_shape_slice(
310
+ ref_ty: ir.MemRefType, dim_slice: slice | None = slice(None)
311
+ ):
312
+ # If it's not a strided layout then we are definitely contiguous.
313
+ if not ir.StridedLayoutAttr.isinstance(ref_ty.layout):
314
+ return True
315
+
316
+ strides = ir.StridedLayoutAttr(ref_ty.layout).strides[dim_slice]
317
+ shape = ref_ty.shape[dim_slice]
318
+
319
+ # Check that each dimension fits exactly it the immediately larger stride.
320
+ ss = sorted(zip(strides, shape), key=lambda x: x[0], reverse=True)
321
+ for (prev_stride, _), (stride, shape) in zip(ss, ss[1:]):
322
+ if stride * shape != prev_stride:
323
+ return False
324
+
325
+ return True
326
+
327
+
328
+ def memref_fold(ref: ir.Value, dim, fold_rank) -> ir.Value:
329
+ ref_ty = ir.MemRefType(ref.type)
330
+ new_shape = list(ref_ty.shape)
331
+ new_shape[dim : dim + fold_rank] = [np.prod(new_shape[dim : dim + fold_rank])]
332
+ identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
333
+ contig_strided_1d = ir.Attribute.parse("strided<[1]>")
334
+ # Not sure why but MLIR expects the strided 1D layout to disappear in this op.
335
+ if ref_ty.layout == identity or ref_ty.layout == contig_strided_1d:
336
+ new_layout = ir.AffineMapAttr.get(
337
+ ir.AffineMap.get_identity(ref_ty.rank - fold_rank + 1)
338
+ )
339
+ elif _is_contiguous_shape_slice(ref_ty, slice(dim, dim + fold_rank)):
340
+ new_strides, offset = ref_ty.get_strides_and_offset()
341
+ new_strides[dim : dim + fold_rank] = [new_strides[dim + fold_rank - 1]]
342
+ new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
343
+ else:
344
+ raise NotImplementedError(
345
+ f"strides={ref_ty.get_strides_and_offset()[0]}, {ref_ty.shape=},"
346
+ f" {dim=}, {fold_rank=}"
347
+ )
348
+
349
+ new_ty = ir.MemRefType.get(
350
+ new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
351
+ )
352
+ assoc = [[d] for d in range(dim)]
353
+ assoc.append([dim + i for i in range(fold_rank)])
354
+ assoc.extend([d] for d in range(dim + fold_rank, ref_ty.rank))
355
+ assert len(assoc) == new_ty.rank
356
+ return memref.collapse_shape(new_ty, ref, assoc)
357
+
358
+
359
+ def memref_unfold(ref: ir.Value, dim, factors) -> ir.Value:
360
+ """Unfolds dim into two dimensions, the size of leading one given be major_factor."""
361
+ ref_ty = ir.MemRefType(ref.type)
362
+ new_shape = list(ref_ty.shape)
363
+ if sum(f is None for f in factors) > 1:
364
+ raise ValueError("Can only infer one dimension")
365
+ known_factor_prod = np.prod([f for f in factors if f is not None])
366
+ if new_shape[dim] % known_factor_prod:
367
+ raise ValueError("Non-divisible unfold:", new_shape[dim], factors)
368
+ factors = tuple(
369
+ new_shape[dim] // known_factor_prod if f is None else f for f in factors
370
+ )
371
+ new_shape[dim : dim + 1] = factors
372
+ identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
373
+ if ref_ty.layout == identity:
374
+ new_layout = ir.AffineMapAttr.get(
375
+ ir.AffineMap.get_identity(ref_ty.rank + len(factors) - 1)
376
+ )
377
+ else:
378
+ new_strides, offset = ref_ty.get_strides_and_offset()
379
+ prev_stride = new_strides[dim]
380
+ inserted_strides = []
381
+ for f in reversed(factors):
382
+ inserted_strides.append(prev_stride)
383
+ prev_stride *= f
384
+ new_strides[dim : dim + 1] = reversed(inserted_strides)
385
+ new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
386
+ new_ty = ir.MemRefType.get(
387
+ new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
388
+ )
389
+ if dim == ref_ty.rank:
390
+ assoc = [[d] for d in range(ref_ty.rank)]
391
+ assoc[-1].extend(range(ref_ty.rank, ref_ty.rank + len(factors) - 1))
392
+ else:
393
+ assoc = [[d] for d in range(dim)]
394
+ assoc.append(list(range(dim, dim + len(factors))))
395
+ assoc.extend([d + len(factors) - 1] for d in range(dim + 1, ref_ty.rank))
396
+ assert len(assoc) == ref_ty.rank
397
+ return memref.expand_shape(new_ty, ref, assoc, [], new_ty.shape)
398
+
399
+
400
+ def memref_unsqueeze(ref: ir.Value, dim) -> ir.Value:
401
+ """Inserts a singleton dimension."""
402
+ ref_ty = ir.MemRefType(ref.type)
403
+ if dim == ref_ty.rank:
404
+ new_shape = list(ref_ty.shape)
405
+ new_shape.append(1)
406
+ identity = ir.AffineMapAttr.get(ir.AffineMap.get_identity(ref_ty.rank))
407
+ if ref_ty.layout == identity:
408
+ new_layout = ir.AffineMapAttr.get(
409
+ ir.AffineMap.get_identity(ref_ty.rank + 1)
410
+ )
411
+ else:
412
+ new_strides, offset = ref_ty.get_strides_and_offset()
413
+ new_strides.append(1)
414
+ new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
415
+ new_ty = ir.MemRefType.get(
416
+ new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
417
+ )
418
+ assoc = [[d] for d in range(ref_ty.rank)]
419
+ assoc[-1].append(ref_ty.rank)
420
+ return memref.expand_shape(new_ty, ref, assoc, [], new_ty.shape)
421
+ else:
422
+ return memref_unfold(ref, dim, (1, None))
423
+
424
+
425
+ def memref_transpose(ref: ir.Value, permutation: Sequence[int]) -> ir.Value:
426
+ ref_ty = ir.MemRefType(ref.type)
427
+ strides, offset = ref_ty.get_strides_and_offset()
428
+ new_strides = [strides[p] for p in permutation]
429
+ new_shape = [ref_ty.shape[p] for p in permutation]
430
+ new_layout = ir.StridedLayoutAttr.get(offset, new_strides)
431
+ new_ty = ir.MemRefType.get(
432
+ new_shape, ref_ty.element_type, new_layout, ref_ty.memory_space
433
+ )
434
+ return memref.transpose(
435
+ new_ty, ref, ir.AffineMap.get_permutation(permutation)
436
+ )
437
+
438
+
439
+ def parse_indices(
440
+ index, shape: tuple[int, ...]
441
+ ) -> tuple[list[ir.Value | int], list[int], list[bool]]:
442
+ if not isinstance(index, tuple):
443
+ index = (index,)
444
+ if trailing_dims := len(shape) - len(index):
445
+ index += (slice(None),) * trailing_dims
446
+ base_indices = []
447
+ slice_shape = []
448
+ is_squeezed = []
449
+ for idx, bound in zip(index, shape):
450
+ if isinstance(idx, (ir.Operation, ir.OpView)):
451
+ idx = idx.result
452
+ if isinstance(idx, int):
453
+ base_indices.append(idx)
454
+ slice_shape.append(1)
455
+ is_squeezed.append(True)
456
+ elif isinstance(idx, slice):
457
+ if idx.step is not None:
458
+ raise NotImplementedError("Strided slices not implemented")
459
+ base_indices.append(idx.start or 0)
460
+ slice_shape.append((idx.stop or bound) - (idx.start or 0))
461
+ is_squeezed.append(False)
462
+ elif isinstance(idx, DynamicSlice):
463
+ base_indices.append(idx.base)
464
+ slice_shape.append(idx.length)
465
+ is_squeezed.append(False)
466
+ elif isinstance(idx, ir.Value):
467
+ if not ir.IndexType.isinstance(idx.type):
468
+ raise ValueError("Expected an index-typed index")
469
+ base_indices.append(idx)
470
+ slice_shape.append(1)
471
+ is_squeezed.append(True)
472
+ else:
473
+ raise NotImplementedError(type(idx))
474
+ assert len(base_indices) == len(slice_shape) == len(is_squeezed) == len(shape)
475
+ return base_indices, slice_shape, is_squeezed
476
+
477
+
478
+ def commit_shared():
479
+ gpu.barrier()
480
+ nvvm.fence_proxy(
481
+ nvvm.ProxyKind.async_shared, space=nvvm.SharedSpace.shared_cta
482
+ )
483
+
484
+
485
+ class BarrierArray:
486
+
487
+ def __init__(self, num_barriers: int, arrival_count: int = 1):
488
+ barrier_group_ty = ir.Type.parse(
489
+ "!nvgpu.mbarrier.group<memorySpace=#gpu.address_space<workgroup>,"
490
+ f" num_barriers={num_barriers}>"
491
+ )
492
+
493
+ self.num_barriers = num_barriers
494
+ self.value = nvgpu.mbarrier_create(barrier_group_ty)
495
+ self.num_barriers = num_barriers
496
+ index = ir.IndexType.get()
497
+ if num_barriers > 32:
498
+ raise NotImplementedError("Only up to 32 barriers per group supported")
499
+ i32 = ir.IntegerType.get_signless(32)
500
+ self.phases = memref.alloca(ir.MemRefType.get((), i32), [], [])
501
+ memref.store(c(0, i32), self.phases, [])
502
+ with single_thread():
503
+ for i in range(num_barriers):
504
+ nvgpu.mbarrier_init(self.value, c(arrival_count, index), c(i, index))
505
+ gpu.barrier()
506
+
507
+ def __iter__(self) -> Iterator["Barrier"]:
508
+ for offset in range(self.num_barriers):
509
+ yield self[offset]
510
+
511
+ def __getitem__(self, offset: ir.Value | int):
512
+ if isinstance(offset, int):
513
+ offset = c(offset, ir.IndexType.get())
514
+ return Barrier(self, offset)
515
+
516
+
517
+ @dataclasses.dataclass(frozen=True)
518
+ class Barrier:
519
+ barrier_array: BarrierArray
520
+ offset: ir.Value
521
+
522
+ def wait_parity(self, parity):
523
+ index = ir.IndexType.get()
524
+ nvgpu.mbarrier_try_wait_parity(
525
+ self.barrier_array.value, parity, c(10000000, index), self.offset,
526
+ )
527
+
528
+ def wait(self):
529
+ i32 = ir.IntegerType.get_signless(32)
530
+ parities = memref.load(self.barrier_array.phases, [])
531
+ offset_i32 = arith.index_castui(i32, self.offset)
532
+ bitmask = arith.shli(c(1, i32), offset_i32)
533
+ parity = arith.cmpi(
534
+ arith.CmpIPredicate.ne, arith.andi(parities, bitmask), c(0, i32)
535
+ )
536
+ new_parities = arith.xori(parities, bitmask)
537
+ memref.store(new_parities, self.barrier_array.phases, [])
538
+ self.wait_parity(parity)
539
+
540
+ def arrive(self):
541
+ token_ty = ir.Type.parse("!nvgpu.mbarrier.token")
542
+ nvgpu.mbarrier_arrive(token_ty, self.barrier_array.value, self.offset)
543
+
544
+
545
+ class Partition:
546
+ source_bounds: tuple[int, ...]
547
+ target_bounds: tuple[int, ...]
548
+ partition: tuple[int | None, ...]
549
+ base_offset: tuple[ir.Value, ...] | None
550
+
551
+ def __init__(
552
+ self,
553
+ elements: tuple[int, ...],
554
+ *,
555
+ partition: tuple[int | None, ...],
556
+ base_offset: tuple[ir.Value, ...] | None = None,
557
+ num_chunks: tuple[int, ...] | None = None,
558
+ chunk_size: tuple[int, ...] | None = None,
559
+ ):
560
+ self.target_bounds = elements
561
+ self.partition = partition
562
+ self.base_offset = base_offset
563
+ if len(self.target_bounds) != len(self.partition):
564
+ raise ValueError
565
+ if num_chunks is None == chunk_size is None:
566
+ raise ValueError(
567
+ "Exactly one of num_chunks and chunk_size must be specified"
568
+ )
569
+ if num_chunks is not None:
570
+ self.source_bounds = num_chunks
571
+ else:
572
+ if len(chunk_size) != len(self.target_bounds):
573
+ raise ValueError
574
+ source_bounds = []
575
+ for els, chunk in zip(elements, chunk_size):
576
+ if els % chunk:
577
+ raise ValueError("Non-divisible partition", elements, chunk_size)
578
+ source_bounds.append(els // chunk)
579
+ self.source_bounds = tuple(source_bounds)
580
+
581
+ seen_dims = set()
582
+ for p in self.partition:
583
+ if p is None:
584
+ continue
585
+ if not (0 <= p < len(self.source_bounds)):
586
+ raise ValueError
587
+ if p in seen_dims:
588
+ raise ValueError
589
+ seen_dims.add(p)
590
+ for tb, p in zip(self.target_bounds, self.partition):
591
+ if p is not None and tb % self.source_bounds[p]:
592
+ raise ValueError("Non-divisible partitioning")
593
+
594
+ @property
595
+ def num_chunks(self) -> tuple[int, ...]:
596
+ return self.source_bounds
597
+
598
+ @property
599
+ def target_block_shape(self):
600
+ return tuple(tb if p is None else tb // self.source_bounds[p]
601
+ for tb, p in zip(self.target_bounds, self.partition))
602
+
603
+ def get_base(self, *source_coords: ir.Value | int) -> list[ir.Value]:
604
+ coords = []
605
+ index = ir.IndexType.get()
606
+ for i, (tbs, p) in enumerate(zip(self.target_block_shape, self.partition)):
607
+ if p is None:
608
+ dim_base = c(0, index)
609
+ else:
610
+ dim_base = arith.muli(c(tbs, index), source_coords[p])
611
+ if self.base_offset is not None:
612
+ dim_base = arith.addi(self.base_offset[i], dim_base)
613
+ coords.append(dim_base)
614
+ return coords
615
+
616
+
617
+ class Partition1D:
618
+ partition: Partition
619
+
620
+ def __init__(
621
+ self,
622
+ elements: int,
623
+ *,
624
+ base_offset: ir.Value | None = None,
625
+ num_chunks: int | None = None,
626
+ chunk_size: int | None = None,
627
+ ):
628
+ self.base_offset = base_offset
629
+ if num_chunks is None == chunk_size is None:
630
+ raise ValueError(
631
+ "Exactly one of num_chunks and chunk_size must be specified"
632
+ )
633
+ common_kwargs = dict(elements=(elements,), partition=(0,))
634
+ if base_offset is not None:
635
+ common_kwargs["base_offset"] = (base_offset,)
636
+ if num_chunks is not None:
637
+ self.partition = Partition(num_chunks=(num_chunks,), **common_kwargs)
638
+ else:
639
+ self.partition = Partition(chunk_size=(chunk_size,), **common_kwargs)
640
+
641
+ @property
642
+ def num_chunks(self) -> int:
643
+ return self.partition.source_bounds[0]
644
+
645
+ def get_base(self, source_coords: ir.Value) -> ir.Value:
646
+ return self.partition.get_base(source_coords)[0]
647
+
648
+ def refine(
649
+ self,
650
+ *,
651
+ chunk: ir.Value | None = None,
652
+ num_chunks: int | None = None,
653
+ chunk_size: int | None = None,
654
+ ):
655
+ return Partition1D(
656
+ self.partition.target_block_shape[0],
657
+ num_chunks=num_chunks,
658
+ chunk_size=chunk_size,
659
+ base_offset=self.get_base(chunk) if chunk is not None else None,
660
+ )
661
+
662
+
663
+ def tile_shape(shape, tiling):
664
+ if len(tiling) > len(shape):
665
+ raise ValueError
666
+ if not tiling:
667
+ return shape
668
+ tiling_rank = len(tiling)
669
+ for s, t in zip(shape[-tiling_rank:], tiling):
670
+ if s % t:
671
+ raise ValueError("Non-divisible tiling:", shape, tiling)
672
+ return (
673
+ *shape[:-tiling_rank],
674
+ *(s // t for s, t in zip(shape[-tiling_rank:], tiling)),
675
+ *tiling,
676
+ )
677
+
678
+
679
+ def warp_tree_reduce(value, op, group_size):
680
+ """Reduce a value across the warpgroup."""
681
+ assert 32 % group_size == 0 and group_size <= 32
682
+ i32 = ir.IntegerType.get_signless(32)
683
+ result = value
684
+ iters = np.log2(group_size)
685
+ if not iters.is_integer():
686
+ raise ValueError(f"Warp reduction group size should be a power of 2 (got {group_size})")
687
+ iters = int(iters)
688
+ for i in range(iters):
689
+ other_result = nvvm.shfl_sync(
690
+ result.type,
691
+ c(0xFFFFFFFF, i32),
692
+ result,
693
+ c(1 << i, i32),
694
+ c(0x1F, i32),
695
+ nvvm.ShflKind.bfly,
696
+ )
697
+ result = op(result, other_result)
698
+
699
+ return result
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/mosaic/gpu/wgmma.py ADDED
@@ -0,0 +1,518 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 The JAX Authors. All Rights Reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ import dataclasses
17
+ import enum
18
+ import functools
19
+ import itertools
20
+
21
+ import jax
22
+ from jaxlib.mlir import ir
23
+ from jaxlib.mlir.dialects import arith
24
+ from jaxlib.mlir.dialects import builtin
25
+ from jaxlib.mlir.dialects import llvm
26
+ from jaxlib.mlir.dialects import nvvm
27
+ from jaxlib.mlir.dialects import vector
28
+ import numpy as np
29
+
30
+ from . import dsl as mgpu
31
+
32
+ # mypy: ignore-errors
33
+
34
+ c = mgpu.c
35
+ bytewidth = mgpu.bytewidth
36
+
37
+
38
+ @jax.tree_util.register_pytree_node_class
39
+ @dataclasses.dataclass
40
+ class WGMMAAccumulator:
41
+ """A FragmentedArray that has is synchronized with the async proxy.
42
+
43
+ This implies that it requires no additional synchronization when passed in
44
+ as a WGMMA accumulator. In particular, when created from a
45
+ FragmentedArray, the necessary synchronization is inserted at construction.
46
+ """
47
+ value: mgpu.FragmentedArray
48
+
49
+ def __init__(self, *, _value: mgpu.FragmentedArray, _sync: bool = True):
50
+ if _value.layout != mgpu.WGMMA_LAYOUT:
51
+ raise ValueError("Only WGMMA layouts supported in WGMMAAccumulator")
52
+ self.value = _value
53
+ if _sync:
54
+ self._value = wgmma_fence(_value)
55
+
56
+ @classmethod
57
+ def zero(cls, m, n, dtype=None):
58
+ if m % 64 or n % 8:
59
+ raise ValueError
60
+ f32 = ir.F32Type.get()
61
+ if dtype is None:
62
+ dtype = f32
63
+ zero = arith.constant(dtype, ir.FloatAttr.get(dtype, 0.0))
64
+ return cls(
65
+ _value=mgpu.FragmentedArray.splat(zero, (m, n), mgpu.WGMMA_LAYOUT)
66
+ )
67
+
68
+ @classmethod
69
+ def from_registers(cls, registers):
70
+ return cls(_value=registers)
71
+
72
+ def tree_flatten(self):
73
+ return (self.value,), ()
74
+
75
+ @classmethod
76
+ def tree_unflatten(cls, aux, value):
77
+ del aux
78
+ return cls(_value=value[0], _sync=False)
79
+
80
+
81
+ def wgmma_encode(x: int):
82
+ result = (x & 0x3FFFF) >> 4
83
+ if result << 4 != x:
84
+ raise ValueError("Cannot encode value in a WGMMA descriptor")
85
+ return result
86
+
87
+
88
+ def llvm_mul(x, y):
89
+ return llvm.mul(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
90
+
91
+
92
+ def llvm_add(x, y):
93
+ return llvm.add(x, y, overflow_flags=llvm.IntegerOverflowFlags.none)
94
+
95
+
96
+ def get_memref_base(memref_arg, memory_space=None):
97
+ i64 = ir.IntegerType.get_signless(64)
98
+ memref_ty = ir.MemRefType(memref_arg.type)
99
+ if len(memref_ty.shape) == 0:
100
+ raise NotImplementedError
101
+ elem_bytewidth = bytewidth(memref_ty.element_type)
102
+ rank = len(memref_ty.shape)
103
+ # TODO: Read out memory space from memref
104
+ space = "" if memory_space is None else "<" + str(memory_space) + ">"
105
+ ptr_ty = ir.Type.parse("!llvm.ptr" + space)
106
+ desc_ty = ir.Type.parse(
107
+ f"!llvm.struct<({ptr_ty}, {ptr_ty}, i64, array<{rank} x i64>,"
108
+ f" array<{rank} x i64>)>"
109
+ )
110
+ desc = builtin.UnrealizedConversionCastOp([desc_ty], [memref_arg])
111
+ aligned_ptr = llvm.extractvalue(ptr_ty, desc, [1])
112
+ offset_elems = llvm.extractvalue(i64, desc, [2])
113
+ offset_bytes = llvm_mul(offset_elems, c(elem_bytewidth, i64))
114
+ return llvm.inttoptr(
115
+ ptr_ty, llvm_add(llvm.ptrtoint(i64, aligned_ptr), offset_bytes)
116
+ )
117
+
118
+
119
+ def create_descriptor(
120
+ memref_arg,
121
+ leading_byte_offset: int,
122
+ stride_byte_offset: int,
123
+ swizzle: int | None,
124
+ memory_space: int | None = None,
125
+ nvgpu_type=None,
126
+ ):
127
+ i64 = ir.IntegerType.get_signless(64)
128
+ ptr_val = llvm.ptrtoint(i64, get_memref_base(memref_arg, memory_space))
129
+ if swizzle is None:
130
+ swizzle_encoding = 0
131
+ elif swizzle == 128:
132
+ swizzle_encoding = 1
133
+ else:
134
+ raise NotImplementedError(swizzle)
135
+ encoded_base_addr = llvm.LShrOp(
136
+ llvm.AndOp(ptr_val, c(0x3FFFF, i64)), c(4, i64)
137
+ )
138
+ desc_const = (
139
+ (wgmma_encode(leading_byte_offset) << 16)
140
+ | (wgmma_encode(stride_byte_offset) << 32)
141
+ |
142
+ # We ignore the offset
143
+ (swizzle_encoding << 62)
144
+ )
145
+ desc = llvm.OrOp(encoded_base_addr, c(desc_const, i64))
146
+ if nvgpu_type is not None:
147
+ desc = builtin.UnrealizedConversionCastOp([nvgpu_type], [desc])
148
+ return desc.result
149
+
150
+
151
+ def _unpack_i32(vec_ty, r):
152
+ i32 = ir.IntegerType.get_signless(32)
153
+ return vector.bitcast(
154
+ vec_ty, vector.splat(ir.VectorType.get((1,), i32), r)
155
+ )
156
+
157
+
158
+ def _supported_wgmma_types(dtype, abtype) -> bool:
159
+ input_types_are = lambda ty: ty.isinstance(abtype)
160
+ if ir.F32Type.isinstance(dtype):
161
+ return any(input_types_are(ty) for ty in (ir.FloatTF32Type, ir.BF16Type, ir.F16Type))
162
+ elif ir.F16Type.isinstance(dtype):
163
+ return input_types_are(ir.F16Type)
164
+ else:
165
+ return False
166
+
167
+
168
+ def wgmma_m64k128B(
169
+ acc: np.ndarray, # of register Values
170
+ a,
171
+ b_descriptor: ir.Value,
172
+ a_transpose: bool | None,
173
+ b_transpose: bool,
174
+ a_k_stride: int | None,
175
+ b_k_stride: int,
176
+ n: int,
177
+ element_type: ir.Type,
178
+ ):
179
+ out_ty = ir.VectorType(acc.flat[0].type).element_type
180
+ if not _supported_wgmma_types(out_ty, element_type):
181
+ raise ValueError(f"Usupported wgmma types {(out_ty, element_type)=}")
182
+
183
+ f16 = ir.F16Type.get()
184
+ i32 = ir.IntegerType.get_signless(32)
185
+ i64 = ir.IntegerType.get_signless(64)
186
+ index = ir.IndexType.get()
187
+ if b_k_stride % 16:
188
+ raise ValueError
189
+ if n % (128 // bytewidth(element_type)):
190
+ raise ValueError
191
+ # Only 16-bit types support transposes
192
+ supports_transpose = bytewidth(element_type) == 2
193
+ if not supports_transpose and (a_transpose or b_transpose):
194
+ raise ValueError("Only f16 WGMMA supports transposes")
195
+ if a_in_regs := isinstance(a, mgpu.FragmentedArray):
196
+ if a.mlir_dtype != ir.F16Type.get() and a.mlir_dtype != ir.BF16Type.get():
197
+ raise ValueError(f"Unsupported A register array dtype: {a.mlir_dtype}")
198
+ if a.layout != mgpu.WGMMA_LAYOUT or a.shape != (64, 64):
199
+ raise ValueError("Unsupported A register array layout")
200
+ if a_k_stride is not None or a_transpose is not None:
201
+ raise ValueError("Unsupported WGMMA features with A in registers")
202
+ else:
203
+ if a_k_stride is None or a_k_stride % 16:
204
+ raise ValueError
205
+ if a_transpose is None:
206
+ raise ValueError
207
+
208
+ if ir.F32Type.isinstance(out_ty):
209
+ num_acc_regs = n // 2
210
+ out_ty_field = out_ty
211
+ acc_regs = [ # pylint: disable=g-complex-comprehension
212
+ vector.extractelement(reg, position=c(pos, index))
213
+ for reg in acc.flat
214
+ for pos in range(2)
215
+ ]
216
+ to_acc_vec_regs = functools.partial(_as_fragmented_reg_ndarray, dtype=out_ty, shape=acc.shape)
217
+ acc_constraint = "f"
218
+ elif ir.F16Type.isinstance(out_ty):
219
+ num_acc_regs = n // 4
220
+ out_ty_field = i32
221
+ acc_regs = [_as_i32_reg(reg) for reg in acc.flat]
222
+ vec_ty = ir.VectorType(acc.flat[0].type)
223
+ to_acc_vec_regs = lambda regs : np.array([_unpack_i32(vec_ty, reg) for reg in regs]).reshape(acc.shape)
224
+ acc_constraint = "r"
225
+ else:
226
+ raise ValueError(f"WGMMA instruciton only supports f32 and f16 out (got {out_ty})")
227
+
228
+ num_imm_regs = 4 if supports_transpose else 2
229
+
230
+ if a_in_regs:
231
+ a_reg_constraints = ["r"] * 4 # 4x f16x2 registers
232
+ num_imm_regs -= 1 # transpose not supported for a in registers
233
+ else:
234
+ a_reg_constraints = ["l"] # descriptor
235
+ # Reference for i/o aliasing: https://gcc.gnu.org/onlinedocs/gcc/Extended-Asm.html
236
+ # Seems like it's not actually documented in LLVM IR docs.
237
+ reg_constraints_list = (
238
+ [f"={acc_constraint}"] * num_acc_regs # accumulator registers
239
+ + [str(i) for i in range(num_acc_regs)] # we alias outputs as inputs, too.
240
+ + a_reg_constraints # a descriptor / registers
241
+ + ["l"] * 1 # b descriptor
242
+ + ["n"] * (1 + num_imm_regs) # literal constants
243
+ )
244
+ reg_constraints = ",".join(reg_constraints_list)
245
+
246
+ reg_count = itertools.count()
247
+
248
+ def take_regs(n):
249
+ return (f"${i}" for i in itertools.islice(reg_count, n))
250
+
251
+ acc_reg_vector = "{" + ",".join(take_regs(num_acc_regs)) + "}"
252
+ for _ in take_regs(num_acc_regs): # Ignore next entries: aliasing.
253
+ pass
254
+ if a_in_regs:
255
+ a_regs = "{" + ",".join(take_regs(len(a_reg_constraints))) + "}"
256
+ else:
257
+ a_regs, = take_regs(1)
258
+ b_desc_reg, use_out_reg = take_regs(2)
259
+ imm_regs = ", ".join(take_regs(num_imm_regs)) # Immediate regs (scale, ...).
260
+ assert next(reg_count) == len(reg_constraints_list)
261
+ el_ty = element_type
262
+ k_instr = 32 // bytewidth(element_type)
263
+ wgmma_instr = (
264
+ f"wgmma.mma_async.sync.aligned.m64n{n}k{k_instr}.{out_ty}.{el_ty}.{el_ty} "
265
+ f"{acc_reg_vector}, {a_regs}, {b_desc_reg}, p, {imm_regs};"
266
+ )
267
+ ptx = f"{{ .reg .pred p; setp.ne.b32 p, {use_out_reg}, 0; {wgmma_instr} }}\n"
268
+
269
+ def lc(x):
270
+ return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
271
+
272
+ use_out = scale_a = scale_b = lc(1)
273
+ imms = [use_out, scale_a, scale_b]
274
+ if supports_transpose and a_transpose is not None:
275
+ imms += [lc(int(a_transpose)), lc(int(b_transpose))]
276
+ elif supports_transpose:
277
+ imms += [lc(int(b_transpose))]
278
+ if acc.ndim != 4 or acc.shape[0] != 1 or acc.shape[2:] != (2, 1):
279
+ raise ValueError(acc.shape)
280
+ acc_struct_type = ir.Type.parse(
281
+ f"!llvm.struct<({','.join(str(out_ty_field) for _ in acc_regs)})>"
282
+ )
283
+ for i in range(4):
284
+ # Slice out the relevant part of A or advance the A descriptor.
285
+ if a_in_regs:
286
+ a_slice = a[:, (i * 16) : ((i + 1) * 16)]
287
+ a_args = [_as_i32_reg(v) for v in a_slice.registers.flat]
288
+ else:
289
+ if i > 0:
290
+ a = llvm_add(
291
+ a,
292
+ llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, a_k_stride >> 4)),
293
+ )
294
+ a_args = [a]
295
+ # Advance the B descriptor.
296
+ if i > 0:
297
+ b_descriptor = llvm_add(
298
+ b_descriptor,
299
+ llvm.ConstantOp(i64, ir.IntegerAttr.get(i64, b_k_stride >> 4)),
300
+ )
301
+ assert len(a_args) == len(a_reg_constraints)
302
+ acc_struct = llvm.inline_asm(
303
+ acc_struct_type,
304
+ [*acc_regs, *a_args, b_descriptor, *imms],
305
+ ptx,
306
+ reg_constraints,
307
+ asm_dialect=0,
308
+ has_side_effects=True,
309
+ )
310
+ acc_regs = [
311
+ llvm.extractvalue(out_ty_field, acc_struct, [i]) for i in range(len(acc_regs))
312
+ ]
313
+ return to_acc_vec_regs(acc_regs)
314
+
315
+
316
+ class WGMMALayout(enum.Enum):
317
+ ROW_MAJOR = enum.auto()
318
+ COL_MAJOR = enum.auto()
319
+
320
+
321
+ # TODO(apaszke): Remove WGMMALayout. Make input shapes logical and infer
322
+ # transpositions from memref strides.
323
+ def wgmma(
324
+ acc: WGMMAAccumulator,
325
+ a,
326
+ b,
327
+ *,
328
+ # Order only applies within each tile!
329
+ a_order: WGMMALayout | None = None,
330
+ b_order: WGMMALayout = WGMMALayout.ROW_MAJOR,
331
+ ):
332
+ if a_in_regs := isinstance(a, mgpu.FragmentedArray):
333
+ a_element_type = a.mlir_dtype
334
+ a_shape = a.shape
335
+ else:
336
+ a_ty = ir.MemRefType(a.type)
337
+ a_element_type = a_ty.element_type
338
+ a_shape = a_ty.shape
339
+ b_ty = ir.MemRefType(b.type)
340
+ supported_types = {ir.F16Type.get(), ir.BF16Type.get(), ir.F32Type.get()}
341
+ if a_element_type not in supported_types:
342
+ raise ValueError(a_element_type)
343
+ if b_ty.element_type not in supported_types:
344
+ raise ValueError(b_ty.element_type)
345
+ if (element_type := a_element_type) != b_ty.element_type:
346
+ raise ValueError
347
+ element_bytewidth = bytewidth(element_type)
348
+ kn_tile = 128 // element_bytewidth
349
+
350
+ groups_k, groups_n = b_ty.shape[:2]
351
+ if b_ty.shape[2:] != [kn_tile, kn_tile]:
352
+ raise ValueError(b_ty.shape)
353
+
354
+ if a_in_regs:
355
+ if a_element_type != ir.F16Type.get() and a_element_type != ir.BF16Type.get():
356
+ raise ValueError(a_element_type)
357
+ if a_shape[0] % 64 or a_shape[1] % kn_tile:
358
+ raise ValueError(a_shape)
359
+ if a_shape[1] // kn_tile != groups_k:
360
+ raise ValueError(a_shape[1] // kn_tile, groups_k)
361
+ groups_m = a_shape[0] // 64
362
+ if a_order is not None:
363
+ raise ValueError(
364
+ "a_order can only be specified when A is in shared memory"
365
+ )
366
+ else:
367
+ groups_m = a_shape[0]
368
+ if a_shape[1] != groups_k:
369
+ raise ValueError(a_shape[1], groups_k)
370
+ if a_shape[2:] != [64, kn_tile]:
371
+ raise ValueError(a_shape)
372
+ if a_order is None:
373
+ a_order = WGMMALayout.ROW_MAJOR
374
+
375
+ row_major = WGMMALayout.ROW_MAJOR
376
+ col_major = WGMMALayout.COL_MAJOR
377
+ a_desc_fields = dict(
378
+ leading_byte_offset=((1 if a_order == row_major else 512) << 4),
379
+ stride_byte_offset=(64 << 4),
380
+ swizzle=128,
381
+ memory_space=3,
382
+ )
383
+ b_desc_fields = dict(
384
+ leading_byte_offset=((512 if b_order == row_major else 1) << 4),
385
+ stride_byte_offset=(64 << 4),
386
+ swizzle=128,
387
+ memory_space=3,
388
+ )
389
+ wgmma_params = dict(
390
+ a_transpose=a_order == col_major,
391
+ b_transpose=b_order == row_major,
392
+ a_k_stride=(2 if a_order == row_major else 128) * 16,
393
+ b_k_stride=(128 if b_order == row_major else 2) * 16,
394
+ n=(groups_n * kn_tile),
395
+ element_type=ir.FloatTF32Type.get()
396
+ if ir.F32Type.isinstance(element_type)
397
+ else element_type,
398
+ )
399
+ if a_in_regs:
400
+ wgmma_params["a_k_stride"] = wgmma_params["a_transpose"] = None
401
+
402
+ if a_in_regs:
403
+ a = wgmma_fence(a) # Make sure the registers are ready.
404
+ a_m_byte_stride = a_k_byte_stride = a_desc_base = None # Silence pytype.
405
+ else:
406
+ a_desc_base = create_descriptor(a, **a_desc_fields)
407
+ a_strides, _ = ir.MemRefType(a.type).get_strides_and_offset()
408
+ a_byte_strides = [s * element_bytewidth for s in a_strides]
409
+ a_m_byte_stride, a_k_byte_stride = a_byte_strides[:2]
410
+ if a_byte_strides[2:] != [128, element_bytewidth]:
411
+ raise ValueError(a_byte_strides)
412
+ b_desc_base = create_descriptor(b, **b_desc_fields)
413
+ b_strides, _ = b_ty.get_strides_and_offset()
414
+ b_byte_strides = [s * element_bytewidth for s in b_strides]
415
+ b_k_byte_stride = b_byte_strides[0]
416
+ if b_byte_strides[1:] != [128 * kn_tile, 128, element_bytewidth]:
417
+ raise ValueError(b_byte_strides)
418
+
419
+ i64 = ir.IntegerType.get_signless(64)
420
+ new_acc_regs = acc.value.registers.copy()
421
+ for mi in range(groups_m):
422
+ for ki in range(groups_k):
423
+ if a_in_regs:
424
+ a_mk = a[mi * 64 : (mi + 1) * 64, ki * kn_tile : (ki + 1) * kn_tile]
425
+ else:
426
+ a_mk = llvm_add(
427
+ a_desc_base,
428
+ c(wgmma_encode(mi * a_m_byte_stride + ki * a_k_byte_stride), i64),
429
+ )
430
+ b_k = llvm_add(b_desc_base, c(wgmma_encode(ki * b_k_byte_stride), i64))
431
+ new_acc_regs[mi : mi + 1] = wgmma_m64k128B(
432
+ new_acc_regs[mi : mi + 1], a_mk, b_k, **wgmma_params
433
+ )
434
+ return WGMMAAccumulator(
435
+ _value=mgpu.FragmentedArray(
436
+ _registers=new_acc_regs, _layout=mgpu.WGMMA_LAYOUT
437
+ ),
438
+ _sync=False,
439
+ )
440
+
441
+
442
+ def wgmma_fence(array: mgpu.FragmentedArray):
443
+ """Fences the array construction from WGMMA instructions.
444
+
445
+ This is a little workaround to force LLVM to initialize the PTX registers
446
+ before the wgmma.fence.sync.aligned instruction. Otherwise, LLVM treats
447
+ in-register computation as pure and can move it after the fence, which is
448
+ explicitly disallowed by the PTX programming model.
449
+ """
450
+ i32 = ir.IntegerType.get_signless(32)
451
+ index = ir.IndexType.get()
452
+ dtype = array.mlir_dtype
453
+ src_vec_ty = ir.VectorType(array.registers.flat[0].type)
454
+ assert src_vec_ty.shape == [2]
455
+
456
+ if dtype == ir.F32Type.get():
457
+ regs = [ # pylint: disable=g-complex-comprehension
458
+ vector.extractelement(reg, position=c(pos, index))
459
+ for reg in array.registers.flat
460
+ for pos in range(2)
461
+ ]
462
+ reg_dtype = dtype
463
+ reg_constraints_list = ["=f"] * len(regs) + ["f"] * len(regs)
464
+ ptx_lines = [f"mov.f32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
465
+ elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
466
+ regs = [_as_i32_reg(reg) for reg in array.registers.flat]
467
+ reg_dtype = i32
468
+ reg_constraints_list = ["=r"] * len(regs) + ["r"] * len(regs)
469
+ ptx_lines = [f"mov.b32 ${i}, ${len(regs)+i}" for i in range(len(regs))]
470
+ else:
471
+ raise NotImplementedError(dtype)
472
+ reg_constraints = ",".join(reg_constraints_list)
473
+ # Copy over the registers. ptxas should be able to remove the moves.
474
+ ptx_lines.append("wgmma.fence.sync.aligned")
475
+ ptx = ";\n".join(ptx_lines) + ";\n"
476
+ dtype_str = str(reg_dtype)
477
+ struct_ty = ir.Type.parse(
478
+ f"!llvm.struct<({','.join(dtype_str for _ in regs)})>"
479
+ )
480
+ acc_struct = llvm.inline_asm(
481
+ struct_ty, regs, ptx, reg_constraints,
482
+ asm_dialect=0, has_side_effects=True,
483
+ )
484
+ regs = [
485
+ llvm.extractvalue(reg_dtype, acc_struct, [i]) for i in range(len(regs))
486
+ ]
487
+ if dtype == ir.F32Type.get():
488
+ registers = _as_fragmented_reg_ndarray(
489
+ regs, array.mlir_dtype, array.registers.shape
490
+ )
491
+ elif dtype == ir.F16Type.get() or dtype == ir.BF16Type.get():
492
+ regs = [_unpack_i32(src_vec_ty, r) for r in regs]
493
+ registers = np.asarray(regs, dtype=object).reshape(array.registers.shape)
494
+ else:
495
+ raise NotImplementedError(dtype)
496
+ return mgpu.FragmentedArray(_registers=registers, _layout=array.layout)
497
+
498
+
499
+ def _as_fragmented_reg_ndarray(flat_regs, dtype: ir.Type, shape: tuple[int, ...]):
500
+ vec_regs = []
501
+ for first, second in zip(flat_regs[::2], flat_regs[1::2]):
502
+ vec = llvm.mlir_undef(ir.VectorType.get((2,), dtype))
503
+ vec = llvm.insertelement(vec, first, position=_lc(0))
504
+ vec = llvm.insertelement(vec, second, position=_lc(1))
505
+ vec_regs.append(vec)
506
+ return np.asarray(vec_regs, dtype=object).reshape(shape)
507
+
508
+
509
+ def _as_i32_reg(v):
510
+ i32 = ir.IntegerType.get_signless(32)
511
+ return llvm.extractelement(
512
+ vector.bitcast(ir.VectorType.get((1,), i32), v), _lc(0)
513
+ )
514
+
515
+
516
+ def _lc(x):
517
+ i32 = ir.IntegerType.get_signless(32)
518
+ return llvm.ConstantOp(i32, ir.IntegerAttr.get(i32, x)).result
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/__init__.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module for pallas, a JAX extension for custom kernels."""
16
+
17
+ from jax._src import pallas
18
+ from jax._src.pallas.core import BlockSpec
19
+ from jax._src.pallas.core import no_block_spec
20
+ from jax._src.pallas.core import Unblocked
21
+ from jax._src.pallas.core import unblocked
22
+ from jax._src.pallas.pallas_call import pallas_call
23
+ from jax._src.pallas.pallas_call import pallas_call_p
24
+ from jax._src.pallas.primitives import atomic_add
25
+ from jax._src.pallas.primitives import atomic_and
26
+ from jax._src.pallas.primitives import atomic_cas
27
+ from jax._src.pallas.primitives import atomic_max
28
+ from jax._src.pallas.primitives import atomic_min
29
+ from jax._src.pallas.primitives import atomic_or
30
+ from jax._src.pallas.primitives import atomic_xchg
31
+ from jax._src.pallas.primitives import atomic_xor
32
+ from jax._src.pallas.primitives import debug_print
33
+ from jax._src.pallas.primitives import dot
34
+ from jax._src.pallas.primitives import load
35
+ from jax._src.pallas.primitives import max_contiguous
36
+ from jax._src.pallas.primitives import multiple_of
37
+ from jax._src.pallas.primitives import num_programs
38
+ from jax._src.pallas.primitives import program_id
39
+ from jax._src.pallas.primitives import store
40
+ from jax._src.pallas.primitives import swap
41
+ from jax._src.pallas.utils import cdiv
42
+ from jax._src.pallas.utils import next_power_of_2
43
+ from jax._src.pallas.utils import strides_from_shape
44
+ from jax._src.pallas.utils import when
45
+ from jax._src.state.indexing import ds
46
+ from jax._src.state.indexing import dslice
47
+ from jax._src.state.indexing import Slice
48
+ from jax._src.state.primitives import broadcast_to
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/gpu.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Triton-specific Pallas APIs."""
16
+
17
+ from jax._src.pallas.triton.primitives import approx_tanh
18
+ from jax._src.pallas.triton.primitives import elementwise_inline_asm
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # All files within ops should be treated as user code.
16
+ import os
17
+ import jax._src.source_info_util
18
+ jax._src.source_info_util.register_inclusion(os.path.dirname(__file__))
19
+ del os, jax
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/ops/gpu/attention.py ADDED
@@ -0,0 +1,573 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Module containing fused attention forward and backward pass."""
16
+ from __future__ import annotations
17
+
18
+ import functools
19
+ from typing import Any, Optional
20
+
21
+ import jax
22
+ from jax import lax
23
+ from jax.experimental import pallas as pl
24
+ import jax.numpy as jnp
25
+ import numpy as np
26
+
27
+ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max)
28
+
29
+
30
+ def mha_forward_kernel(
31
+ q_ref,
32
+ k_ref,
33
+ v_ref, # Input arrays
34
+ segment_ids_ref: jax.Array | None, # segment_id arrays
35
+ o_ref: Any, # Output
36
+ *residual_refs: Any, # Residual outputs
37
+ num_heads: int,
38
+ sm_scale: float,
39
+ causal: bool,
40
+ block_q: int,
41
+ block_d: int,
42
+ block_k: int,
43
+ ):
44
+ seq_len = q_ref.shape[0]
45
+ start_q = pl.program_id(0)
46
+
47
+ # o is the buffer where we accumulate the output on sram.
48
+ # m_i and l_i (see FlashAttention paper) are updated during the k,v loop.
49
+ m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf')
50
+ l_i = jnp.zeros(block_q, dtype=jnp.float32)
51
+ # acc is the buffer where we accumulate the output on sram.
52
+ o = jnp.zeros((block_q, block_d), dtype=jnp.float32)
53
+
54
+ # Load q: it will stay in L1 throughout. Indices form a matrix because we
55
+ # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index.
56
+ # q tile has shape [block_q, block_d], block_d == head_dim.
57
+ curr_q_slice = pl.dslice(start_q * block_q, block_q)
58
+ q = pl.load(q_ref, (curr_q_slice, pl.dslice(None)))
59
+ q_segment_ids = (
60
+ None
61
+ if segment_ids_ref is None
62
+ else pl.load(segment_ids_ref, (curr_q_slice,))
63
+ )
64
+ # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size
65
+ # (Bc == block_k here), and fast over blocks of q (size Br == block_q here).
66
+ # Here we only loop over blocks of kv to process entire seq_len, the loop over
67
+ # blocks of q is carried out by the grid.
68
+ def body(start_k, carry):
69
+ o_prev, m_prev, l_prev = carry
70
+ curr_k_slice = pl.dslice(start_k * block_k, block_k)
71
+
72
+ k = pl.load(k_ref, (curr_k_slice, slice(None)))
73
+ kv_segment_ids = (
74
+ None
75
+ if segment_ids_ref is None
76
+ else pl.load(segment_ids_ref, (curr_k_slice,))
77
+ )
78
+ qk = pl.dot(q, k.T) # [block_q, block_k]
79
+ if sm_scale != 1.:
80
+ qk *= sm_scale # [block_q, block_k]
81
+
82
+ # Avoids Triton crash.
83
+ # if num_heads > 2:
84
+ # qk = qk.astype(q_ref.dtype)
85
+ # qk = qk.astype(jnp.float32)
86
+
87
+ if causal or segment_ids_ref is not None:
88
+ mask = None
89
+ if segment_ids_ref is not None:
90
+ mask = segment_mask(q_segment_ids, kv_segment_ids)
91
+ if causal:
92
+ span_q = start_q * block_q + jnp.arange(block_q)
93
+ span_k = start_k * block_k + jnp.arange(block_k)
94
+ causal_mask = span_q[:, None] >= span_k[None, :]
95
+ mask = (
96
+ causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
97
+ )
98
+ # Apply mask to qk.
99
+ qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
100
+
101
+ m_curr = qk.max(axis=-1)
102
+ m_next = jnp.maximum(m_prev, m_curr)
103
+ correction = jnp.exp(m_prev - m_next)
104
+ l_prev_corr = correction * l_prev
105
+ s_curr = jnp.exp(
106
+ qk - m_next[:, None]
107
+ ) # Use m_next instead of m_curr to avoid a correction on l_curr
108
+ l_curr = s_curr.sum(axis=-1)
109
+ l_next = l_prev_corr + l_curr
110
+ l_next_rcp = 1. / l_next
111
+ s_curr = s_curr * l_next_rcp[:, None]
112
+ o_prev_corr = (l_prev_corr * l_next_rcp)[:, None] * o_prev
113
+ v = pl.load(v_ref, (curr_k_slice, pl.dslice(block_d)))
114
+ o_curr = pl.dot(s_curr.astype(v.dtype), v)
115
+
116
+ o_next = o_prev_corr + o_curr
117
+ return o_next, m_next, l_next
118
+ if causal:
119
+ # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q)
120
+ upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k)
121
+ else:
122
+ upper_bound = pl.cdiv(seq_len, block_k) # type: ignore
123
+ o, m_i, l_i = lax.fori_loop(0, upper_bound, body, (o, m_i, l_i))
124
+
125
+ if residual_refs:
126
+ l_ref, m_ref = residual_refs
127
+ pl.store(l_ref, (curr_q_slice,), l_i)
128
+ pl.store(m_ref, (curr_q_slice,), m_i)
129
+ # Write output to dram.
130
+ o = o.astype(o_ref.dtype)
131
+ pl.store(o_ref, (curr_q_slice, pl.dslice(None)), o)
132
+
133
+
134
+ def segment_mask(
135
+ q_segment_ids: jax.Array,
136
+ kv_segment_ids: jax.Array,
137
+ ):
138
+ # [B, T, 1] or [T, 1]
139
+ q_segment_ids = jnp.expand_dims(q_segment_ids, axis=-1)
140
+ # [B, 1, S] or [1, S]
141
+ if kv_segment_ids.ndim == 1:
142
+ kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=0)
143
+ else:
144
+ kv_segment_ids = jnp.expand_dims(kv_segment_ids, axis=1)
145
+ return jnp.equal(q_segment_ids, kv_segment_ids).astype(jnp.bool_)
146
+
147
+
148
+ @functools.partial(
149
+ jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13]
150
+ )
151
+ @functools.partial(
152
+ jax.jit,
153
+ static_argnames=[
154
+ "sm_scale",
155
+ "causal",
156
+ "block_q",
157
+ "block_k",
158
+ "backward_pass_impl",
159
+ "num_warps",
160
+ "num_stages",
161
+ "grid",
162
+ "interpret",
163
+ "debug",
164
+ ],
165
+ )
166
+ def mha(
167
+ q,
168
+ k,
169
+ v,
170
+ segment_ids: jnp.ndarray | None,
171
+ sm_scale: float = 1.0,
172
+ causal: bool = False,
173
+ block_q: int = 128,
174
+ block_k: int = 128,
175
+ backward_pass_impl: str = "triton",
176
+ num_warps: int | None = None,
177
+ num_stages: int = 2,
178
+ grid: tuple[int, ...] | None = None,
179
+ interpret: bool = False,
180
+ debug: bool = False,
181
+ ):
182
+ del backward_pass_impl
183
+ batch_size, seq_len, num_heads, head_dim = q.shape
184
+ block_q = min(block_q, seq_len)
185
+ block_k = min(block_k, seq_len)
186
+ # Heuristics.
187
+ grid_ = grid
188
+ if grid_ is None:
189
+ grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
190
+
191
+ num_warps_ = num_warps
192
+ if num_warps_ is None:
193
+ num_warps_ = 4 if head_dim <= 64 else 8
194
+ kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
195
+ sm_scale=sm_scale, block_q=block_q,
196
+ block_k=block_k, block_d=head_dim,
197
+ causal=causal)
198
+
199
+ in_specs = [
200
+ pl.BlockSpec(
201
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
202
+ ),
203
+ pl.BlockSpec(
204
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
205
+ ),
206
+ pl.BlockSpec(
207
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
208
+ ),
209
+ ]
210
+ in_specs.append(
211
+ None # type: ignore[arg-type]
212
+ if segment_ids is None
213
+ else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len))
214
+ )
215
+ out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
216
+ return pl.pallas_call(
217
+ kernel,
218
+ grid=grid_,
219
+ in_specs=in_specs,
220
+ out_specs=pl.BlockSpec(
221
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
222
+ ),
223
+ compiler_params=dict(
224
+ triton=dict(num_warps=num_warps_, num_stages=num_stages)
225
+ ),
226
+ out_shape=out_shape,
227
+ debug=debug,
228
+ interpret=interpret,
229
+ name="mha_forward",
230
+ )(q, k, v, segment_ids)
231
+
232
+
233
+ def _mha_forward(
234
+ q,
235
+ k,
236
+ v,
237
+ segment_ids: jax.Array | None,
238
+ sm_scale: float,
239
+ causal: bool,
240
+ block_q: int,
241
+ block_k: int,
242
+ backward_pass_impl: str,
243
+ num_warps: int | None,
244
+ num_stages: int,
245
+ grid: Any,
246
+ interpret: bool,
247
+ debug: bool,
248
+ ):
249
+ del backward_pass_impl
250
+ batch_size, seq_len, num_heads, head_dim = q.shape
251
+ block_q = min(block_q, seq_len)
252
+ block_k = min(block_k, seq_len)
253
+ # Heuristics.
254
+ grid_ = grid
255
+ if grid_ is None:
256
+ grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
257
+
258
+ num_warps_ = num_warps
259
+ if num_warps_ is None:
260
+ num_warps_ = 4 if head_dim <= 64 else 8
261
+ kernel = functools.partial(mha_forward_kernel, num_heads=num_heads,
262
+ sm_scale=sm_scale, causal=causal, block_q=block_q,
263
+ block_k=block_k, block_d=head_dim)
264
+ out_shape = [
265
+ jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out
266
+ jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # l
267
+ dtype=jnp.float32),
268
+ jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m
269
+ dtype=jnp.float32)
270
+ ]
271
+ in_specs = [
272
+ pl.BlockSpec(
273
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
274
+ ),
275
+ pl.BlockSpec(
276
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
277
+ ),
278
+ pl.BlockSpec(
279
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
280
+ ),
281
+ ]
282
+ in_specs.append(
283
+ None # type: ignore[arg-type]
284
+ if segment_ids is None
285
+ else pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len))
286
+ )
287
+ out, l, m = pl.pallas_call(
288
+ kernel,
289
+ grid=grid_,
290
+ in_specs=in_specs,
291
+ out_specs=[
292
+ pl.BlockSpec(
293
+ lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
294
+ ),
295
+ pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
296
+ pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
297
+ ],
298
+ compiler_params=dict(
299
+ triton=dict(num_warps=num_warps_, num_stages=num_stages)
300
+ ),
301
+ out_shape=out_shape,
302
+ debug=debug,
303
+ interpret=interpret,
304
+ name="mha_forward",
305
+ )(q, k, v, segment_ids)
306
+ return out, (q, k, v, segment_ids, out, l, m)
307
+
308
+
309
+ def _preprocess_backward_kernel(out_ref, dout_ref, l_ref,
310
+ new_dout_ref, delta_ref, *,
311
+ block_q: int):
312
+ pid_m = pl.program_id(0)
313
+
314
+ off_m = pl.ds(pid_m * block_q, block_q)
315
+ # load
316
+ o = pl.load(out_ref, (off_m, slice(None))).astype(jnp.float32)
317
+ do = pl.load(dout_ref, (off_m, slice(None))).astype(jnp.float32)
318
+ denom = pl.load(l_ref, (off_m,)).astype(jnp.float32)
319
+ # compute
320
+ do = do / denom[:, None]
321
+ delta = jnp.sum(o * do, axis=1)
322
+ # write-back
323
+ pl.store(new_dout_ref, (off_m, slice(None)),
324
+ do.astype(new_dout_ref.dtype))
325
+ pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype))
326
+
327
+ @jax.named_scope("preprocess_backward")
328
+ def _preprocess_backward(out, do, l, block_q: int,
329
+ debug: bool, interpret: bool):
330
+ batch_size, seq_len, num_heads, head_dim = out.shape
331
+ out_shape = [
332
+ jax.ShapeDtypeStruct(do.shape, do.dtype),
333
+ jax.ShapeDtypeStruct(l.shape, l.dtype),
334
+ ]
335
+ do_scaled, delta = pl.pallas_call(
336
+ functools.partial(_preprocess_backward_kernel, block_q=block_q),
337
+ grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads),
338
+ in_specs=[
339
+ pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
340
+ pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
341
+ pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
342
+ ],
343
+ out_specs=[
344
+ pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)),
345
+ pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)),
346
+ ],
347
+ compiler_params=dict(
348
+ triton=dict(num_warps=4, num_stages=3)
349
+ ),
350
+ out_shape=out_shape,
351
+ debug=debug,
352
+ interpret=interpret,
353
+ name="mha_preprocess_backward")(out, do, l)
354
+ return do_scaled, delta
355
+
356
+
357
+ def mha_backward_kernel(
358
+ # Inputs
359
+ q_ref,
360
+ k_ref,
361
+ v_ref,
362
+ segment_ids_ref: jax.Array | None,
363
+ out_ref,
364
+ do_scaled_ref,
365
+ l_ref,
366
+ m_ref,
367
+ delta_ref,
368
+ _,
369
+ # Outputs
370
+ dq_ref,
371
+ dk_ref,
372
+ dv_ref,
373
+ *,
374
+ sm_scale: float,
375
+ causal: bool,
376
+ block_q: int,
377
+ block_d: int,
378
+ block_k: int,
379
+ ):
380
+ del out_ref, l_ref # Not needed
381
+ seq_len = q_ref.shape[0]
382
+
383
+ def outer_loop(start_k, _):
384
+
385
+ dv = jnp.zeros([block_k, block_d], dtype=jnp.float32)
386
+ dk = jnp.zeros([block_k, block_d], dtype=jnp.float32)
387
+ k = pl.load(k_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
388
+ v = pl.load(v_ref, (pl.ds(start_k * block_k, block_k), slice(None)))
389
+ span_k = start_k * block_k + jnp.arange(block_k)
390
+ kv_segment_ids = (
391
+ None
392
+ if segment_ids_ref is None
393
+ else pl.load(segment_ids_ref, (pl.ds(start_k * block_k, block_k),))
394
+ )
395
+
396
+ def inner_loop(start_q, carry):
397
+ dv, dk = carry
398
+ q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
399
+ qk = pl.dot(q, k.T)
400
+ qk = qk.astype(q_ref.dtype)
401
+ qk = qk.astype(jnp.float32)
402
+ if sm_scale != 1.0:
403
+ qk *= sm_scale
404
+
405
+ q_segment_ids = (
406
+ None
407
+ if segment_ids_ref is None
408
+ else pl.load(segment_ids_ref, (pl.ds(start_q * block_q, block_q),))
409
+ )
410
+
411
+ if causal or segment_ids_ref is not None:
412
+ mask = None
413
+ if segment_ids_ref is not None:
414
+ mask = segment_mask(q_segment_ids, kv_segment_ids)
415
+
416
+ if causal:
417
+ span_q = start_q * block_q + jnp.arange(block_q)
418
+ causal_mask = span_q[:, None] >= span_k[None, :]
419
+ mask = (
420
+ causal_mask
421
+ if mask is None
422
+ else jnp.logical_and(mask, causal_mask)
423
+ )
424
+ qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE)
425
+
426
+ m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),))
427
+ p = jnp.exp(qk - m[:, None])
428
+ do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None)))
429
+ dv = dv + pl.dot(p.astype(do.dtype).T, do)
430
+ di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),))
431
+ dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None]
432
+ dp = dp + pl.dot(do, v.T)
433
+ ds = p * dp
434
+ if sm_scale != 1.0:
435
+ ds = ds * sm_scale
436
+ dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q)
437
+ dq = pl.load(dq_ref, (pl.ds(start_q * block_q, block_q),
438
+ slice(None)), eviction_policy="evict_last")
439
+ dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype)
440
+ pl.store(dq_ref, (pl.ds(start_q * block_q, block_q),
441
+ slice(None)), dq, eviction_policy="evict_last")
442
+ return dv, dk
443
+ if causal:
444
+ lower_bound = lax.div(start_k * block_k, block_q)
445
+ else:
446
+ lower_bound = 0
447
+ dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop,
448
+ (dv, dk))
449
+ pl.store(dv_ref, (pl.ds(start_k * block_k, block_k),
450
+ slice(None)), dv.astype(dv_ref.dtype))
451
+ pl.store(dk_ref, (pl.ds(start_k * block_k, block_k),
452
+ slice(None)), dk.astype(dk_ref.dtype))
453
+ lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None)
454
+
455
+
456
+ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
457
+ backward_pass_impl: str, num_warps: int | None,
458
+ num_stages: int, grid: Any, interpret: bool,
459
+ debug: bool, res, do):
460
+ del num_warps, num_stages, grid
461
+ q, k, v, segment_ids, out, l, m = res
462
+
463
+ if backward_pass_impl == "xla":
464
+ return jax.vjp(
465
+ functools.partial(mha_reference, sm_scale=sm_scale, causal=causal),
466
+ q,
467
+ k,
468
+ v,
469
+ segment_ids,
470
+ )[1](do)
471
+ elif backward_pass_impl == "triton":
472
+ batch_size, seq_len, num_heads, head_dim = q.shape
473
+ block_q = min(block_q, seq_len)
474
+ block_k = min(block_k, seq_len)
475
+ do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret)
476
+ # We accumulate into dq so we need to initialize it to zeros.
477
+ dq = jnp.zeros(q.shape, jnp.float32)
478
+ out_shapes = [
479
+ jax.ShapeDtypeStruct(dq.shape, dq.dtype),
480
+ jax.ShapeDtypeStruct(k.shape, k.dtype),
481
+ jax.ShapeDtypeStruct(v.shape, v.dtype),
482
+ ]
483
+
484
+ in_specs = [
485
+ pl.BlockSpec(
486
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
487
+ ),
488
+ pl.BlockSpec(
489
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
490
+ ),
491
+ pl.BlockSpec(
492
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
493
+ ),
494
+ pl.BlockSpec(
495
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
496
+ ),
497
+ pl.BlockSpec(
498
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
499
+ ),
500
+ pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
501
+ pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
502
+ pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)),
503
+ pl.BlockSpec(
504
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
505
+ ),
506
+ ]
507
+ if segment_ids is None:
508
+ in_specs.insert(3, None) # type: ignore[arg-type]
509
+ input_output_aliases = {8: 0}
510
+ else:
511
+ in_specs.insert(3, pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len)))
512
+ input_output_aliases = {9: 0}
513
+ grid = (batch_size, num_heads)
514
+ # TODO(sharadmv): figure out why num_warps=8 doesn't work!
515
+ num_warps = 8
516
+ dq, dk, dv = pl.pallas_call(
517
+ functools.partial(
518
+ mha_backward_kernel,
519
+ block_q=block_q,
520
+ block_d=head_dim,
521
+ block_k=block_k,
522
+ sm_scale=sm_scale,
523
+ causal=causal,
524
+ ),
525
+ grid=grid,
526
+ out_shape=out_shapes,
527
+ in_specs=in_specs,
528
+ out_specs=[
529
+ pl.BlockSpec(
530
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
531
+ ),
532
+ pl.BlockSpec(
533
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
534
+ ),
535
+ pl.BlockSpec(
536
+ lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)
537
+ ),
538
+ ],
539
+ name="mha_backward",
540
+ debug=debug,
541
+ interpret=interpret,
542
+ compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)),
543
+ input_output_aliases=input_output_aliases,
544
+ )(q, k, v, segment_ids, out, do_scaled, l, m, delta, dq)
545
+ else:
546
+ raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}")
547
+ return dq.astype(q.dtype), dk, dv, None
548
+ mha.defvjp(_mha_forward, _mha_backward)
549
+
550
+
551
+ @functools.partial(jax.jit, static_argnames=['sm_scale', 'causal'])
552
+ def mha_reference(
553
+ q,
554
+ k,
555
+ v,
556
+ segment_ids: jnp.ndarray | None,
557
+ sm_scale=1.0,
558
+ causal: bool = False,
559
+ ):
560
+ q_seq_len = q.shape[1]
561
+ kv_seq_len = k.shape[1]
562
+ logits = jnp.einsum('bqhc,bkhc->bhqk', q, k).astype(jnp.float32)
563
+ mask = None
564
+ if segment_ids is not None:
565
+ mask = jnp.expand_dims(segment_mask(segment_ids, segment_ids), 1)
566
+ mask = jnp.broadcast_to(mask, logits.shape)
567
+ if causal:
568
+ causal_mask = jnp.tril(jnp.ones((1, 1, q_seq_len, kv_seq_len), dtype=bool))
569
+ causal_mask = jnp.broadcast_to(causal_mask, logits.shape)
570
+ mask = causal_mask if mask is None else jnp.logical_and(mask, causal_mask)
571
+ logits = logits if mask is None else jnp.where(mask, logits, float("-inf"))
572
+ weights = jax.nn.softmax(logits * sm_scale).astype(q.dtype)
573
+ return jnp.einsum('bhqk,bkhc->bqhc', weights, v)
external/alphageometry/.venv-ag/Lib/site-packages/jax/experimental/pallas/tpu.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 The JAX Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """Mosaic-specific Pallas APIs."""
16
+
17
+ from jax._src.pallas.mosaic import core
18
+ from jax._src.pallas.mosaic.core import dma_semaphore
19
+ from jax._src.pallas.mosaic.core import PrefetchScalarGridSpec
20
+ from jax._src.pallas.mosaic.core import semaphore
21
+ from jax._src.pallas.mosaic.core import SemaphoreType
22
+ from jax._src.pallas.mosaic.core import TPUMemorySpace
23
+ from jax._src.pallas.mosaic.kernel_regeneration_util import encode_kernel_regeneration_metadata
24
+ from jax._src.pallas.mosaic.kernel_regeneration_util import extract_kernel_regeneration_metadata
25
+ from jax._src.pallas.mosaic.lowering import LoweringException
26
+ from jax._src.pallas.mosaic.pipeline import BufferedRef
27
+ from jax._src.pallas.mosaic.pipeline import emit_pipeline
28
+ from jax._src.pallas.mosaic.pipeline import emit_pipeline_with_allocations
29
+ from jax._src.pallas.mosaic.pipeline import get_pipeline_schedule
30
+ from jax._src.pallas.mosaic.pipeline import make_pipeline_allocations
31
+ from jax._src.pallas.mosaic.primitives import async_copy
32
+ from jax._src.pallas.mosaic.primitives import async_remote_copy
33
+ from jax._src.pallas.mosaic.primitives import bitcast
34
+ from jax._src.pallas.mosaic.primitives import delay
35
+ from jax._src.pallas.mosaic.primitives import device_id
36
+ from jax._src.pallas.mosaic.primitives import DeviceIdType
37
+ from jax._src.pallas.mosaic.primitives import get_barrier_semaphore
38
+ from jax._src.pallas.mosaic.primitives import make_async_copy
39
+ from jax._src.pallas.mosaic.primitives import make_async_remote_copy
40
+ from jax._src.pallas.mosaic.primitives import repeat
41
+ from jax._src.pallas.mosaic.primitives import roll
42
+ from jax._src.pallas.mosaic.primitives import run_scoped
43
+ from jax._src.pallas.mosaic.primitives import semaphore_read
44
+ from jax._src.pallas.mosaic.primitives import semaphore_signal
45
+ from jax._src.pallas.mosaic.primitives import semaphore_wait
46
+ from jax._src.pallas.mosaic.primitives import prng_seed
47
+ from jax._src.pallas.mosaic.primitives import prng_random_bits
48
+ from jax._src.tpu_custom_call import CostEstimate
49
+
50
+ ANY = TPUMemorySpace.ANY
51
+ CMEM = TPUMemorySpace.CMEM
52
+ SMEM = TPUMemorySpace.SMEM
53
+ VMEM = TPUMemorySpace.VMEM
external/alphageometry/README.md ADDED
@@ -0,0 +1,447 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # Solving Olympiad Geometry without Human Demonstrations
3
+
4
+
5
+ This repository contains the code necessary to
6
+ reproduce DDAR and AlphaGeometry,
7
+ the two geometry theorem provers
8
+ introduced in the [Nature 2024](https://www.nature.com/articles/s41586-023-06747-5) paper:
9
+
10
+ *<center>"Solving Olympiad Geometry without Human Demonstrations".</center>*
11
+
12
+
13
+ </br>
14
+
15
+
16
+ <center>
17
+ <img alt="fig1" width="800px" src="fig1.svg">
18
+ </center>
19
+
20
+
21
+ ## Dependencies
22
+
23
+ For the instructions presented below,
24
+ we use Python 3.10.9, and dependencies with their exact
25
+ version numbers listed in `requirements.txt`.
26
+
27
+ Our code depends on `meliad`, which is
28
+ not a registered package with `pip`. See instructions below
29
+ for how to manually install `meliad`.
30
+
31
+ Note that one can still run the DDAR solver
32
+ without the `meliad` and `sentencepiece` dependencies.
33
+
34
+ ## Run the instructions
35
+
36
+ All instructions in this `README.md` can be run in one go by:
37
+
38
+ ```
39
+ bash run.sh
40
+ ```
41
+
42
+ Below, we explain these instructions step-by-step.
43
+
44
+ ## Install dependencies, download weights and vocabulary.
45
+
46
+ Installation is done in a virtual environment:
47
+
48
+ ```
49
+ virtualenv -p python3 .
50
+ source ./bin/activate
51
+ pip install --require-hashes -r requirements.txt
52
+ ```
53
+
54
+ Download weights and vocabulary:
55
+
56
+ ```
57
+ bash download.sh
58
+ DATA=ag_ckpt_vocab
59
+ ```
60
+
61
+ Finally, install `meliad` separately as it is not
62
+ registered with `pip`:
63
+
64
+ ```
65
+ MELIAD_PATH=meliad_lib/meliad
66
+ mkdir -p $MELIAD_PATH
67
+ git clone https://github.com/google-research/meliad $MELIAD_PATH
68
+ export PYTHONPATH=$PYTHONPATH:$MELIAD_PATH
69
+ ```
70
+
71
+ ## Set up common flags
72
+
73
+ Before running the python scripts,
74
+ let us first prepare some commonly used flags.
75
+ The symbolic engine needs definitions and deduction rules to operate.
76
+ These definitions and rules are provided in two text files
77
+ `defs.txt` and `rules.txt`.
78
+
79
+ ```shell
80
+ DDAR_ARGS=(
81
+ --defs_file=$(pwd)/defs.txt \
82
+ --rules_file=$(pwd)/rules.txt \
83
+ );
84
+ ```
85
+
86
+ Next, we define the flags relevant to the proof search.
87
+ To reproduce the simple examples below,
88
+ we use lightweight values for the proof search parameters:
89
+
90
+ ```shell
91
+ BATCH_SIZE=2
92
+ BEAM_SIZE=2
93
+ DEPTH=2
94
+
95
+ SEARCH_ARGS=(
96
+ --beam_size=$BEAM_SIZE
97
+ --search_depth=$DEPTH
98
+ )
99
+ ```
100
+
101
+ NOTE: The results in our paper can be obtained by setting
102
+ `BATCH_SIZE=32`, `BEAM_SIZE=512`, `DEPTH=16`
103
+ as described in section Methods.
104
+ To stay under IMO time limits, 4 V100-GPUs and 250 CPU workers
105
+ are needed as shown in Extended Data - Figure 1.
106
+ Note that we also strip away other memory/speed optimizations
107
+ due to internal dependencies and to promote code clarity.
108
+
109
+ Assume the downloaded checkpoint and vocabulary is placed in `DATA`,
110
+ and the installed `meliad` source code is at `MELIAD_PATH`.
111
+ We make use of the `gin` library to manage model configurations,
112
+ following `meliad` conventions. We now define the flags relevant to the
113
+ language model:
114
+
115
+ ```shell
116
+ LM_ARGS=(
117
+ --ckpt_path=$DATA \
118
+ --vocab_path=$DATA/geometry.757.model
119
+ --gin_search_paths=$MELIAD_PATH/transformer/configs,$(pwd) \
120
+ --gin_file=base_htrans.gin \
121
+ --gin_file=size/medium_150M.gin \
122
+ --gin_file=options/positions_t5.gin \
123
+ --gin_file=options/lr_cosine_decay.gin \
124
+ --gin_file=options/seq_1024_nocache.gin \
125
+ --gin_file=geometry_150M_generate.gin \
126
+ --gin_param=DecoderOnlyLanguageModelGenerate.output_token_losses=True \
127
+ --gin_param=TransformerTaskConfig.batch_size=$BATCH_SIZE \
128
+ --gin_param=TransformerTaskConfig.sequence_length=128 \
129
+ --gin_param=Trainer.restore_state_variables=False
130
+ );
131
+ ```
132
+
133
+ TIP: Note that you can still run the DDAR solver
134
+ without defining `SEARCH_ARGS` and `LM_ARGS`.
135
+ In such case, simply disable the import of the `lm_inference` module
136
+ inside `alphageometry.py`.
137
+
138
+ ## Run DDAR
139
+
140
+ The script loads a problem by reading a list of problems
141
+ from a text file and solves the specific problem in the list according
142
+ to its name. We pass these two pieces of information through the flags
143
+ `--problems_file` and `--problem_name`.
144
+ We use `--mode=ddar` to indicate that we want to use the DDAR solver.
145
+
146
+ Below we showed this solver solving IMO 2000 P1:
147
+
148
+ ```shell
149
+ python -m alphageometry \
150
+ --alsologtostderr \
151
+ --problems_file=$(pwd)/imo_ag_30.txt \
152
+ --problem_name=translated_imo_2000_p1 \
153
+ --mode=ddar \
154
+ "${DDAR_ARGS[@]}"
155
+ ```
156
+
157
+ Expect the following output
158
+
159
+ ```shell
160
+ graph.py:468] translated_imo_2000_p1
161
+ graph.py:469] a b = segment a b; g1 = on_tline g1 a a b; g2 = on_tline g2 b b a; m = on_circle m g1 a, on_circle m g2 b; n = on_circle n g1 a, on_circle n g2 b; c = on_pline c m a b, on_circle c g1 a; d = on_pline d m a b, on_circle d g2 b; e = on_line e a c, on_line e b d; p = on_line p a n, on_line p c d; q = on_line q b n, on_line q c d ? cong e p e q
162
+ ddar.py:41] Depth 1/1000 time = 1.7772269248962402
163
+ ddar.py:41] Depth 2/1000 time = 5.63526177406311
164
+ ddar.py:41] Depth 3/1000 time = 6.883412837982178
165
+ ddar.py:41] Depth 4/1000 time = 10.275688409805298
166
+ ddar.py:41] Depth 5/1000 time = 12.048273086547852
167
+ alphageometry.py:190]
168
+ ==========================
169
+ * From theorem premises:
170
+ A B G1 G2 M N C D E P Q : Points
171
+ AG_1 ⟂ AB [00]
172
+ BA ⟂ G_2B [01]
173
+ G_2M = G_2B [02]
174
+ G_1M = G_1A [03]
175
+
176
+ ...
177
+ [log omitted]
178
+ ...
179
+
180
+ 036. ∠QEB = ∠(QP-EA) [46] & ∠(BE-QP) = ∠AEP [55] ⇒ ∠EQP = ∠QPE [56]
181
+ 037. ∠PQE = ∠EPQ [56] ⇒ EP = EQ
182
+
183
+ ==========================
184
+ ```
185
+
186
+ The output first includes a list of relevant premises that it uses,
187
+ and then proof steps that gradually build up the proof.
188
+ All predicates are numbered to track how they are derived
189
+ from the premises, and to show that the proof is fully justified.
190
+
191
+ TIP: Additionally passing the flag `--out_file=path/to/output/text/file.txt`
192
+ will write the proof to a text file.
193
+
194
+ Running on all problems in `imo_ag_30.txt` will yield solutions to
195
+ 14 of them, as reported in Table 1 in our paper.
196
+
197
+ ## Run AlphaGeometry:
198
+
199
+ As a simple example, we load `--problem_name=orthocenter`
200
+ from `--problem_file=examples.txt`.
201
+ This time, we pass `--mode=alphageometry` to use the AlphaGeometry solver
202
+ and pass the `SEARCH_ARGS` and `LM_ARGS` flags.
203
+
204
+ ```shell
205
+ python -m alphageometry \
206
+ --alsologtostderr \
207
+ --problems_file=$(pwd)/examples.txt \
208
+ --problem_name=orthocenter \
209
+ --mode=alphageometry \
210
+ "${DDAR_ARGS[@]}" \
211
+ "${SEARCH_ARGS[@]}" \
212
+ "${LM_ARGS[@]}"
213
+ ```
214
+
215
+ Expect the following output:
216
+
217
+ ```shell
218
+ ...
219
+ [log omitted]
220
+ ...
221
+ training_loop.py:725] Total parameters: 152072288
222
+ training_loop.py:739] Total state size: 0
223
+ training_loop.py:492] Training loop: creating task for mode beam_search
224
+
225
+ graph.py:468] orthocenter
226
+ graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b ? perp a d b c
227
+ ddar.py:41] Depth 1/1000 time = 0.009987592697143555 branch = 4
228
+ ddar.py:41] Depth 2/1000 time = 0.00672602653503418 branch = 0
229
+ alphageometry.py:221] DD+AR failed to solve the problem.
230
+ alphageometry.py:457] Depth 0. There are 1 nodes to expand:
231
+ alphageometry.py:460] {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
232
+ alphageometry.py:465] Decoding from {S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c {F1} x00
233
+ ...
234
+ [log omitted]
235
+ ...
236
+ alphageometry.py:470] LM output (score=-1.102287): "e : C a c e 02 C b d e 03 ;"
237
+ alphageometry.py:471] Translation: "e = on_line e a c, on_line e b d"
238
+
239
+ alphageometry.py:480] Solving: "a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c"
240
+ graph.py:468]
241
+ graph.py:469] a b c = triangle a b c; d = on_tline d b a c, on_tline d c a b; e = on_line e a c, on_line e b d ? perp a d b c
242
+ ddar.py:41] Depth 1/1000 time = 0.021120786666870117
243
+ ddar.py:41] Depth 2/1000 time = 0.033370018005371094
244
+ ddar.py:41] Depth 3/1000 time = 0.04297471046447754
245
+ alphageometry.py:140]
246
+ ==========================
247
+ * From theorem premises:
248
+ A B C D : Points
249
+ BD ⟂ AC [00]
250
+ CD ⟂ AB [01]
251
+
252
+ * Auxiliary Constructions:
253
+ E : Points
254
+ E,B,D are collinear [02]
255
+ E,C,A are collinear [03]
256
+
257
+ * Proof steps:
258
+ 001. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEA = ∠CED [04]
259
+ 002. E,B,D are collinear [02] & E,C,A are collinear [03] & BD ⟂ AC [00] ⇒ ∠BEC = ∠AED [05]
260
+ 003. A,E,C are collinear [03] & E,B,D are collinear [02] & AC ⟂ BD [00] ⇒ EC ⟂ EB [06]
261
+ 004. EC ⟂ EB [06] & CD ⟂ AB [01] ⇒ ∠(EC-BA) = ∠(EB-CD) [07]
262
+ 005. E,C,A are collinear [03] & E,B,D are collinear [02] & ∠(EC-BA) = ∠(EB-CD) [07] ⇒ ∠BAE = ∠CDE [08]
263
+ 006. ∠BEA = ∠CED [04] & ∠BAE = ∠CDE [08] (Similar Triangles)⇒ EB:EC = EA:ED [09]
264
+ 007. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠BCE = ∠ADE [10]
265
+ 008. EB:EC = EA:ED [09] & ∠BEC = ∠AED [05] (Similar Triangles)⇒ ∠EBC = ∠EAD [11]
266
+ 009. ∠BCE = ∠ADE [10] & E,C,A are collinear [03] & E,B,D are collinear [02] & ∠EBC = ∠EAD [11] ⇒ AD ⟂ BC
267
+ ==========================
268
+
269
+ alphageometry.py:505] Solved.
270
+ ```
271
+
272
+ NOTE: Point `H` is automatically renamed to `D`,
273
+ as the LM is trained on synthetic problems
274
+ where the points are named alphabetically, and so it expects
275
+ the same during test time.
276
+
277
+ NOTE: In this implementation of AlphaGeometry,
278
+ we removed all optimizations that are dependent on
279
+ internal infrastructure, e.g.,
280
+ parallelized model inference on multi GPUs,
281
+ parallelized DDAR on multiple CPUs,
282
+ parallel execution of LM and DDAR,
283
+ shared pool of CPU workers across different problems, etc.
284
+ We also removed some memory/speed optimizations and code
285
+ abstractions in favor of code clarity.
286
+
287
+ As can be seen in the output, initially DDAR failed to solve the problem.
288
+ The LM proposes two auxiliary constructions (because `BATCH_SIZE=2`):
289
+
290
+ * `e = eqdistance e c a b, eqdistance e b a c`, i.e.,
291
+ construct `E` as the intersection of circle (center=C, radius=AB) and
292
+ circle (center=B, radius=AC). This construction has a score of `-1.186`.
293
+ * `e = on_line e a c, on_line e b d`, i.e.,
294
+ `E` is the intersection of `AC` and `BD`.
295
+ This construction has a higher score (`-1.102287`) than the previous.
296
+
297
+ Since the second construction has a higher score, DDAR attempted the second
298
+ construction first and found the solution right away.
299
+ The proof search therefore terminates and there is no second iteration.
300
+
301
+ ## Results
302
+
303
+ Before attempting to reproduce the AlphaGeometry numbers in our paper,
304
+ please make sure to pass all tests in the prepared test suite:
305
+
306
+ ```
307
+ bash run_tests.sh
308
+ ```
309
+
310
+ NOTE: [Issues#14](https://github.com/google-deepmind/alphageometry/issues/14) reports that although the top beam decodes are still the same, the LM is not giving the same score for different users.
311
+
312
+ Then, pass the corresponding values for `--problem_file` (column)
313
+ and `--mode` (row), and
314
+ iterate on all problems to obtain the following results:
315
+
316
+ <center>
317
+
318
+ <b>Number of solved problems:</b>
319
+
320
+ | | `imo_ag_30.txt` | `jgex_ag_231.txt` |
321
+ |----------|------------------|-------------------|
322
+ | `ddar` | 14 | 198 |
323
+ | `alphageometry` | 25 | 228 |
324
+
325
+ </center>
326
+
327
+ ## Source code description
328
+
329
+ Files in this repository include python modules/scripts to run the solvers and
330
+ resource files necessary for the script to execute. We listed below
331
+ each of them and their description.
332
+
333
+ | File name | Description |
334
+ |------------------------|------------------------------------------------------------------------------------|
335
+ | `geometry.py` | Implements nodes (Point, Line, Circle, etc) in the proof state graph. |
336
+ | `numericals.py` | Implements the numerical engine in the dynamic geometry environment. |
337
+ | `graph_utils.py` | Implements utilities for the proof state graph. |
338
+ | `graph.py` | Implements the proof state graph. |
339
+ | `problem.py` | Implements the classes that represent the problem premises, conclusion, DAG nodes. |
340
+ | `dd.py` | Implements DD and its traceback. |
341
+ | `ar.py` | Implements AR and its traceback. |
342
+ | `trace_back.py` | Implements the recursive traceback and dependency difference algorithm. |
343
+ | `ddar.py` | Implements the combination DD+AR. |
344
+ | `beam_search.py` | Implements beam decoding of a language model in JAX. |
345
+ | `models.py` | Implements the transformer model. |
346
+ | `transformer_layer.py` | Implements the transformer layer. |
347
+ | `decoder_stack.py` | Implements the transformer decoder stack. |
348
+ | `lm_inference.py` | Implements an interface to a trained LM to perform decoding. |
349
+ | `alphageometry.py` | Main script that loads problems, calls DD+AR or AlphaGeometry solver, and prints solutions. |
350
+ | `pretty.py` | Pretty formating the solutions output by solvers. |
351
+ | `*_test.py` | Tests for the corresponding module. |
352
+ | `download.sh` | Script to download model checkpoints and LM |
353
+ | `run.sh` | Script to execute instructions in README. |
354
+ | `run_tests.sh` | Script to execute the test suite. |
355
+
356
+
357
+ Resource files:
358
+
359
+ | Resource file name | Description |
360
+ |------------------------|------------------------------------------------------------------------------------|
361
+ | `defs.txt` | Definitions of different geometric construction actions. |
362
+ | `rules.txt` | Deduction rules for DD. |
363
+ | `geometry_150M_generate.gin`| Gin config of the LM implemented in meliad. |
364
+ | `imo_ag_30.txt` | Problems in IMO-AG-30. |
365
+ | `jgex_ag_231.txt` | Problems in JGEX-AG-231. |
366
+
367
+
368
+
369
+ ## Citing this work
370
+
371
+ ```bibtex
372
+ @Article{AlphaGeometryTrinh2024,
373
+ author = {Trinh, Trieu and Wu, Yuhuai and Le, Quoc and He, He and Luong, Thang},
374
+ journal = {Nature},
375
+ title = {Solving Olympiad Geometry without Human Demonstrations},
376
+ year = {2024},
377
+ doi = {10.1038/s41586-023-06747-5}
378
+ }
379
+ ```
380
+
381
+ ## Acknowledgements
382
+
383
+ This research is a collaboration between the Google Brain team
384
+ (now Google Deepmind) and
385
+ the Computer Science Department of New York University.
386
+ We thank Rif A. Saurous, Denny Zhou, Christian Szegedy, Delesley Hutchins,
387
+ Thomas Kipf, Hieu Pham, Petar Veličković, Debidatta Dwibedi,
388
+ Kyunghyun Cho, Lerrel Pinto, Alfredo Canziani,
389
+ Thomas Wies, He He’s research group,
390
+ Evan Chen (the USA’s IMO team coach),
391
+ Mirek Olsak, Patrik Bak,
392
+ and all three Nature's referees for their help and support.
393
+
394
+ The code of AlphaGeometry communicates with and/or references the following
395
+ separate libraries and packages:
396
+
397
+ * [Abseil](https://github.com/abseil/abseil-py)
398
+ * [JAX](https://github.com/google/jax/)
399
+ * [matplotlib](https://matplotlib.org/)
400
+ * [NumPy](https://numpy.org)
401
+ * [SciPy](https://scipy.org)
402
+ * [TensorFlow](https://github.com/tensorflow/tensorflow)
403
+ * [Meliad](https://github.com/google-research/meliad)
404
+ * [Flax](https://github.com/google/flax)
405
+ * [Gin](https://github.com/google/gin-config)
406
+ * [T5](https://github.com/google-research/text-to-text-transfer-transformer)
407
+ * [SentencePiece](https://github.com/google/sentencepiece)
408
+
409
+
410
+
411
+ We thank all their contributors and maintainers!
412
+
413
+
414
+ ## Disclaimer
415
+
416
+ This is not an officially supported Google product.
417
+
418
+ This research code is provided "as-is" to the broader research community.
419
+ Google does not promise to maintain or otherwise support this code in any way.
420
+
421
+ ## Code License
422
+
423
+ Copyright 2023 DeepMind Technologies Limited
424
+
425
+ All software is licensed under the Apache License, Version 2.0 (Apache 2.0);
426
+ you may not use this file except in compliance with the Apache 2.0 license.
427
+ You may obtain a copy of the Apache 2.0 license at:
428
+ https://www.apache.org/licenses/LICENSE-2.0
429
+
430
+ All other materials are licensed under the Creative Commons Attribution 4.0
431
+ International License (CC-BY). You may obtain a copy of the CC-BY license at:
432
+ https://creativecommons.org/licenses/by/4.0/legalcode
433
+
434
+ Unless required by applicable law or agreed to in writing, all software and
435
+ materials distributed here under the Apache 2.0 or CC-BY licenses are
436
+ distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
437
+ either express or implied. See the licenses for the specific language governing
438
+ permissions and limitations under those licenses.
439
+
440
+ ## Model Parameters License
441
+
442
+ The AlphaGeometry checkpoints and vocabulary are made available
443
+ under the terms of the Creative Commons Attribution 4.0
444
+ International (CC BY 4.0) license.
445
+ You can find details at:
446
+ https://creativecommons.org/licenses/by/4.0/legalcode
447
+
external/alphageometry/lm_inference_test.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Unit tests for lm_inference.py."""
17
+ import os
18
+ import unittest
19
+
20
+ from absl import flags
21
+ from absl.testing import absltest
22
+ import lm_inference as lm
23
+
24
+
25
+ _DATA_PATH = flags.DEFINE_string('data_path', '', 'path to ckpt and vocab.')
26
+ _MELIAD_PATH = flags.DEFINE_string(
27
+ 'meliad_path', '', 'path to meliad repository.'
28
+ ) # pylint: disable=line-too-long
29
+
30
+
31
+ class LmInferenceTest(unittest.TestCase):
32
+
33
+ @classmethod
34
+ def setUpClass(cls):
35
+ super().setUpClass()
36
+ gin_file = [
37
+ 'base_htrans.gin',
38
+ 'size/medium_150M.gin',
39
+ 'options/positions_t5.gin',
40
+ 'options/lr_cosine_decay.gin',
41
+ 'options/seq_1024_nocache.gin',
42
+ 'geometry_150M_generate.gin',
43
+ ]
44
+
45
+ gin_param = [
46
+ 'DecoderOnlyLanguageModelGenerate.output_token_losses=True',
47
+ 'TransformerTaskConfig.batch_size=2',
48
+ 'TransformerTaskConfig.sequence_length=128',
49
+ 'Trainer.restore_state_variables=False',
50
+ ]
51
+
52
+ gin_search_paths = [
53
+ os.path.join(_MELIAD_PATH.value, 'transformer/configs'),
54
+ os.getcwd(),
55
+ ]
56
+
57
+ vocab_path = os.path.join(_DATA_PATH.value, 'geometry.757.model')
58
+
59
+ lm.parse_gin_configuration(gin_file, gin_param, gin_paths=gin_search_paths)
60
+
61
+ cls.loaded_lm = lm.LanguageModelInference(
62
+ vocab_path, _DATA_PATH.value, mode='beam_search'
63
+ )
64
+
65
+ def test_lm_decode(self):
66
+ outputs = LmInferenceTest.loaded_lm.beam_decode(
67
+ '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
68
+ ' {F1} x00',
69
+ eos_tokens=[';'],
70
+ )
71
+ self.assertEqual(
72
+ outputs['seqs_str'],
73
+ ['e : D a b c e 02 D a c b e 03 ;', 'e : C a c e 02 C b d e 03 ;'],
74
+ )
75
+
76
+ def test_lm_score_may_fail_numerically_for_external_meliad(self):
77
+ outputs = LmInferenceTest.loaded_lm.beam_decode(
78
+ '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c'
79
+ ' {F1} x00',
80
+ eos_tokens=[';'],
81
+ )
82
+ self.assertEqual(
83
+ outputs['scores'],
84
+ [-1.18607294559478759765625, -1.10228693485260009765625],
85
+ )
86
+
87
+
88
+ if __name__ == '__main__':
89
+ absltest.main()
external/alphageometry/models.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Transformer language model generate mode."""
17
+
18
+ from typing import Any, Tuple
19
+ import beam_search
20
+ import decoder_stack
21
+ import gin
22
+ import jax
23
+ import jax.numpy as jnp
24
+ from transformer import models
25
+
26
+
27
+ @gin.configurable
28
+ class DecoderOnlyLanguageModelGenerate(models.DecoderOnlyLanguageModel):
29
+ """Decoder only language modeling in inference mode."""
30
+
31
+ decoder_factory = decoder_stack.DecoderStackGenerate
32
+
33
+ num_heads: int = gin.REQUIRED
34
+ head_size: int = gin.REQUIRED
35
+
36
+ def get_fake_input(self) -> dict[str, Any]:
37
+ fake_input_dict = super().get_fake_input()
38
+ b = self.task_config.batch_size
39
+ n = self.num_heads
40
+ h = self.head_size
41
+ fake_input_dict.update({
42
+ 'dstate': tuple(
43
+ [{
44
+ 'current_index': jnp.array([0] * b, dtype=jnp.int32),
45
+ 'keys': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
46
+ 'values': jnp.zeros((b, 2048, n, h), dtype=jnp.bfloat16),
47
+ 'recurrent_kvq': None,
48
+ 'relative_position_bias': jnp.zeros(
49
+ (b, n, 1, 1024), dtype=jnp.bfloat16
50
+ ),
51
+ }]
52
+ * 12
53
+ ),
54
+ 'eos': jnp.zeros([1024], dtype=jnp.bfloat16),
55
+ 'mask': jnp.ones([1024], dtype=jnp.bfloat16),
56
+ 'length': 1,
57
+ 'temperature': 1.0,
58
+ })
59
+ return fake_input_dict
60
+
61
+ def __call__(self, inputs: ...) -> tuple[Any, dict[str, Any]]:
62
+ # Make sure this code is not used on untested cases.
63
+ if self.mode not in ['init', 'beam_search']:
64
+ raise ValueError(f'{type(self)} cannot do mode {self.mode}')
65
+ if self.decoder.supports_generate():
66
+ raise ValueError(f'{type(self)}.decoder cannot supports_generate()')
67
+
68
+ self.decoder(
69
+ input_tokens=inputs['targets'][:, 0:1],
70
+ target_tokens=None,
71
+ start_of_sequence=inputs['start_of_sequence'],
72
+ )
73
+
74
+ b = inputs['targets'].shape[0]
75
+ no_start_of_seq = jnp.array([False] * b, dtype=jnp.bool_)
76
+
77
+ # This fn is used in both beam_search or topk_sampling.
78
+ def tokens_to_logits_fn(
79
+ input_token: jnp.ndarray, dstate: tuple[dict[str, jnp.ndarray], ...]
80
+ ) -> tuple[jnp.ndarray, tuple[dict[str, jnp.ndarray], ...]]:
81
+ (logits, dstate, _) = self.decoder(
82
+ input_tokens=input_token,
83
+ target_tokens=None,
84
+ start_of_sequence=no_start_of_seq,
85
+ decoder_state=dstate,
86
+ )
87
+ return logits[:, -1, :], dstate
88
+
89
+ last_token = jax.lax.dynamic_slice_in_dim(
90
+ inputs['targets'], inputs['length'] - 1, 1, axis=1
91
+ )
92
+
93
+ # last token is used to seed beam_search
94
+ inputs['targets'] = inputs['targets'][:, 0:-1]
95
+ dstate = jax.lax.cond(
96
+ inputs['start_of_sequence'][0],
97
+ lambda: self.generate(inputs)[0],
98
+ lambda: inputs['dstate'],
99
+ )
100
+
101
+ # Then we run beam search, init with last_token & dstate.
102
+ finished_seqs, finished_scores, dstate = beam_search.beam_search_flat(
103
+ last_token,
104
+ dstate,
105
+ tokens_to_logits_fn,
106
+ max_decode_len=512,
107
+ eos=inputs['eos'].reshape((1, 1, -1)),
108
+ mask=inputs['mask'].reshape((1, 1, -1)),
109
+ )
110
+
111
+ return 0.0, {
112
+ 'finished_seqs': finished_seqs,
113
+ 'finished_scores': finished_scores,
114
+ 'dstate': dstate,
115
+ }
116
+
117
+ def generate(
118
+ self, inputs: ...
119
+ ) -> tuple[tuple[dict[str, jnp.ndarray, ...], ...], jnp.ndarray]:
120
+ """Generate an output sequence.
121
+
122
+ Args:
123
+ inputs: the same as argument to _call_.
124
+
125
+ Returns:
126
+ An array of generated tokens of shape (batch_size, sequence_length).
127
+ """
128
+ input_tokens = inputs['targets'] # [b,seq_len]
129
+ start_of_sequence = inputs['start_of_sequence'] # [b]
130
+ target_tokens = jnp.pad(input_tokens[:, 1:], [(0, 0), (0, 1)])
131
+ batch_size = target_tokens.shape[0]
132
+
133
+ # Assuming all sequences start at the same time.
134
+ start0 = inputs['start_of_sequence'][0]
135
+ dstate = jax.lax.cond(
136
+ start0,
137
+ lambda: self.decoder.init_decoder_state_vanilla( # pylint: disable=g-long-lambda
138
+ 1024, start_of_sequence
139
+ ),
140
+ lambda: inputs['dstate'],
141
+ )
142
+
143
+ first_token = input_tokens[:, 0:1]
144
+ no_start_of_seq = jnp.array([False] * batch_size, dtype=jnp.bool_)
145
+ temperature = 1
146
+ if 'temperature' in inputs:
147
+ temperature = inputs['temperature']
148
+
149
+ num_steps = inputs['length']
150
+ if self.mode == 'beam_search':
151
+ num_steps -= 1
152
+
153
+ def cond_fn(scan_state) -> jnp.bool_:
154
+ _, _, i, _ = scan_state
155
+ return i < num_steps
156
+
157
+ def loop_fn(scan_state: Any) -> Tuple[Any, Any, Any, Any]:
158
+ (dstate, input_token, i, _) = scan_state
159
+
160
+ (logits, dstate, _) = self.decoder(
161
+ input_tokens=input_token,
162
+ target_tokens=None,
163
+ start_of_sequence=no_start_of_seq,
164
+ decoder_state=dstate,
165
+ )
166
+
167
+ logits = logits / temperature
168
+ output_token = jax.lax.dynamic_slice_in_dim(target_tokens, i, 1, axis=1)
169
+
170
+ return (dstate, output_token, i + 1, logits)
171
+
172
+ # Scan over the sequence length.
173
+ dummy_logits = jnp.zeros((batch_size, 1, 1024))
174
+ initial_scan_state = (dstate, first_token, 0, dummy_logits)
175
+ dstate, _, _, logits = jax.lax.while_loop(
176
+ cond_fn, loop_fn, initial_scan_state
177
+ )
178
+ return dstate, logits
external/alphageometry/numericals.py ADDED
@@ -0,0 +1,1921 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Numerical representation of geometry."""
17
+ from __future__ import annotations
18
+
19
+ import math
20
+ from typing import Any, Optional, Union
21
+
22
+ import geometry as gm
23
+ import matplotlib
24
+ from matplotlib import pyplot as plt
25
+ import matplotlib.colors as mcolors
26
+ import numpy as np
27
+ from numpy.random import uniform as unif # pylint: disable=g-importing-member
28
+
29
+
30
+ matplotlib.use('TkAgg')
31
+
32
+
33
+ ATOM = 1e-12
34
+
35
+
36
+ # Some variables are there for better code reading.
37
+ # pylint: disable=unused-assignment
38
+ # pylint: disable=unused-argument
39
+ # pylint: disable=unused-variable
40
+
41
+ # Naming in geometry is a little different
42
+ # we stick to geometry naming to better read the code.
43
+ # pylint: disable=invalid-name
44
+
45
+
46
+ class Point:
47
+ """Numerical point."""
48
+
49
+ def __init__(self, x, y):
50
+ self.x = x
51
+ self.y = y
52
+
53
+ def __lt__(self, other: Point) -> bool:
54
+ return (self.x, self.y) < (other.x, other.y)
55
+
56
+ def __gt__(self, other: Point) -> bool:
57
+ return (self.x, self.y) > (other.x, other.y)
58
+
59
+ def __add__(self, p: Point) -> Point:
60
+ return Point(self.x + p.x, self.y + p.y)
61
+
62
+ def __sub__(self, p: Point) -> Point:
63
+ return Point(self.x - p.x, self.y - p.y)
64
+
65
+ def __mul__(self, f: float) -> Point:
66
+ return Point(self.x * f, self.y * f)
67
+
68
+ def __rmul__(self, f: float) -> Point:
69
+ return self * f
70
+
71
+ def __truediv__(self, f: float) -> Point:
72
+ return Point(self.x / f, self.y / f)
73
+
74
+ def __floordiv__(self, f: float) -> Point:
75
+ div = self / f # true div
76
+ return Point(int(div.x), int(div.y))
77
+
78
+ def __str__(self) -> str:
79
+ return 'P({},{})'.format(self.x, self.y)
80
+
81
+ def close(self, point: Point, tol: float = 1e-12) -> bool:
82
+ return abs(self.x - point.x) < tol and abs(self.y - point.y) < tol
83
+
84
+ def midpoint(self, p: Point) -> Point:
85
+ return Point(0.5 * (self.x + p.x), 0.5 * (self.y + p.y))
86
+
87
+ def distance(self, p: Union[Point, Line, Circle]) -> float:
88
+ if isinstance(p, Line):
89
+ return p.distance(self)
90
+ if isinstance(p, Circle):
91
+ return abs(p.radius - self.distance(p.center))
92
+ dx = self.x - p.x
93
+ dy = self.y - p.y
94
+ return np.sqrt(dx * dx + dy * dy)
95
+
96
+ def distance2(self, p: Point) -> float:
97
+ if isinstance(p, Line):
98
+ return p.distance(self)
99
+ dx = self.x - p.x
100
+ dy = self.y - p.y
101
+ return dx * dx + dy * dy
102
+
103
+ def rotatea(self, ang: float) -> Point:
104
+ sinb, cosb = np.sin(ang), np.cos(ang)
105
+ return self.rotate(sinb, cosb)
106
+
107
+ def rotate(self, sinb: float, cosb: float) -> Point:
108
+ x, y = self.x, self.y
109
+ return Point(x * cosb - y * sinb, x * sinb + y * cosb)
110
+
111
+ def flip(self) -> Point:
112
+ return Point(-self.x, self.y)
113
+
114
+ def perpendicular_line(self, line: Line) -> Line:
115
+ return line.perpendicular_line(self)
116
+
117
+ def foot(self, line: Line) -> Point:
118
+ if isinstance(line, Line):
119
+ l = line.perpendicular_line(self)
120
+ return line_line_intersection(l, line)
121
+ elif isinstance(line, Circle):
122
+ c, r = line.center, line.radius
123
+ return c + (self - c) * r / self.distance(c)
124
+ raise ValueError('Dropping foot to weird type {}'.format(type(line)))
125
+
126
+ def parallel_line(self, line: Line) -> Line:
127
+ return line.parallel_line(self)
128
+
129
+ def norm(self) -> float:
130
+ return np.sqrt(self.x**2 + self.y**2)
131
+
132
+ def cos(self, other: Point) -> float:
133
+ x, y = self.x, self.y
134
+ a, b = other.x, other.y
135
+ return (x * a + y * b) / self.norm() / other.norm()
136
+
137
+ def dot(self, other: Point) -> float:
138
+ return self.x * other.x + self.y * other.y
139
+
140
+ def sign(self, line: Line) -> int:
141
+ return line.sign(self)
142
+
143
+ def is_same(self, other: Point) -> bool:
144
+ return self.distance(other) <= ATOM
145
+
146
+
147
+ class Line:
148
+ """Numerical line."""
149
+
150
+ def __init__(
151
+ self,
152
+ p1: Point = None,
153
+ p2: Point = None,
154
+ coefficients: tuple[int, int, int] = None,
155
+ ):
156
+ if p1 is None and p2 is None and coefficients is None:
157
+ self.coefficients = None, None, None
158
+ return
159
+
160
+ a, b, c = coefficients or (
161
+ p1.y - p2.y,
162
+ p2.x - p1.x,
163
+ p1.x * p2.y - p2.x * p1.y,
164
+ )
165
+
166
+ # Make sure a is always positive (or always negative for that matter)
167
+ # With a == 0, Assuming a = +epsilon > 0
168
+ # Then b such that ax + by = 0 with y>0 should be negative.
169
+ if a < 0.0 or a == 0.0 and b > 0.0:
170
+ a, b, c = -a, -b, -c
171
+
172
+ self.coefficients = a, b, c
173
+
174
+ def parallel_line(self, p: Point) -> Line:
175
+ a, b, _ = self.coefficients
176
+ return Line(coefficients=(a, b, -a * p.x - b * p.y)) # pylint: disable=invalid-unary-operand-type
177
+
178
+ def perpendicular_line(self, p: Point) -> Line:
179
+ a, b, _ = self.coefficients
180
+ return Line(p, p + Point(a, b))
181
+
182
+ def greater_than(self, other: Line) -> bool:
183
+ a, b, _ = self.coefficients
184
+ x, y, _ = other.coefficients
185
+ # b/a > y/x
186
+ return b * x > a * y
187
+
188
+ def __gt__(self, other: Line) -> bool:
189
+ return self.greater_than(other)
190
+
191
+ def __lt__(self, other: Line) -> bool:
192
+ return other.greater_than(self)
193
+
194
+ def same(self, other: Line) -> bool:
195
+ a, b, c = self.coefficients
196
+ x, y, z = other.coefficients
197
+ return close_enough(a * y, b * x) and close_enough(b * z, c * y)
198
+
199
+ def equal(self, other: Line) -> bool:
200
+ a, b, _ = self.coefficients
201
+ x, y, _ = other.coefficients
202
+ # b/a == y/x
203
+ return b * x == a * y
204
+
205
+ def less_than(self, other: Line) -> bool:
206
+ a, b, _ = self.coefficients
207
+ x, y, _ = other.coefficients
208
+ # b/a > y/x
209
+ return b * x < a * y
210
+
211
+ def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]:
212
+ if isinstance(obj, Line):
213
+ return line_line_intersection(self, obj)
214
+ if isinstance(obj, Circle):
215
+ return line_circle_intersection(self, obj)
216
+
217
+ def distance(self, p: Point) -> float:
218
+ a, b, c = self.coefficients
219
+ return abs(self(p.x, p.y)) / math.sqrt(a * a + b * b)
220
+
221
+ def __call__(self, x: Point, y: Point = None) -> float:
222
+ if isinstance(x, Point) and y is None:
223
+ return self(x.x, x.y)
224
+ a, b, c = self.coefficients
225
+ return x * a + y * b + c
226
+
227
+ def is_parallel(self, other: Line) -> bool:
228
+ a, b, _ = self.coefficients
229
+ x, y, _ = other.coefficients
230
+ return abs(a * y - b * x) < ATOM
231
+
232
+ def is_perp(self, other: Line) -> bool:
233
+ a, b, _ = self.coefficients
234
+ x, y, _ = other.coefficients
235
+ return abs(a * x + b * y) < ATOM
236
+
237
+ def cross(self, other: Line) -> float:
238
+ a, b, _ = self.coefficients
239
+ x, y, _ = other.coefficients
240
+ return a * y - b * x
241
+
242
+ def dot(self, other: Line) -> float:
243
+ a, b, _ = self.coefficients
244
+ x, y, _ = other.coefficients
245
+ return a * x + b * y
246
+
247
+ def point_at(self, x: float = None, y: float = None) -> Optional[Point]:
248
+ """Get a point on line closest to (x, y)."""
249
+ a, b, c = self.coefficients
250
+ # ax + by + c = 0
251
+ if x is None and y is not None:
252
+ if a != 0:
253
+ return Point((-c - b * y) / a, y) # pylint: disable=invalid-unary-operand-type
254
+ else:
255
+ return None
256
+ elif x is not None and y is None:
257
+ if b != 0:
258
+ return Point(x, (-c - a * x) / b) # pylint: disable=invalid-unary-operand-type
259
+ else:
260
+ return None
261
+ elif x is not None and y is not None:
262
+ if a * x + b * y + c == 0.0:
263
+ return Point(x, y)
264
+ return None
265
+
266
+ def diff_side(self, p1: Point, p2: Point) -> Optional[bool]:
267
+ d1 = self(p1.x, p1.y)
268
+ d2 = self(p2.x, p2.y)
269
+ if d1 == 0 or d2 == 0:
270
+ return None
271
+ return d1 * d2 < 0
272
+
273
+ def same_side(self, p1: Point, p2: Point) -> Optional[bool]:
274
+ d1 = self(p1.x, p1.y)
275
+ d2 = self(p2.x, p2.y)
276
+ if d1 == 0 or d2 == 0:
277
+ return None
278
+ return d1 * d2 > 0
279
+
280
+ def sign(self, point: Point) -> int:
281
+ s = self(point.x, point.y)
282
+ if s > 0:
283
+ return 1
284
+ elif s < 0:
285
+ return -1
286
+ return 0
287
+
288
+ def is_same(self, other: Line) -> bool:
289
+ a, b, c = self.coefficients
290
+ x, y, z = other.coefficients
291
+ return abs(a * y - b * x) <= ATOM and abs(b * z - c * y) <= ATOM
292
+
293
+ def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
294
+ """Sample a point within the boundary of points."""
295
+ center = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
296
+ radius = max([p.distance(center) for p in points])
297
+ if close_enough(center.distance(self), radius):
298
+ center = center.foot(self)
299
+ a, b = line_circle_intersection(self, Circle(center.foot(self), radius))
300
+
301
+ result = None
302
+ best = -1.0
303
+ for _ in range(n):
304
+ rand = unif(0.0, 1.0)
305
+ x = a + (b - a) * rand
306
+ mind = min([x.distance(p) for p in points])
307
+ if mind > best:
308
+ best = mind
309
+ result = x
310
+
311
+ return [result]
312
+
313
+
314
+ class InvalidLineIntersectError(Exception):
315
+ pass
316
+
317
+
318
+ class HalfLine(Line):
319
+ """Numerical ray."""
320
+
321
+ def __init__(self, tail: Point, head: Point): # pylint: disable=super-init-not-called
322
+ self.line = Line(tail, head)
323
+ self.coefficients = self.line.coefficients
324
+ self.tail = tail
325
+ self.head = head
326
+
327
+ def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point:
328
+ if isinstance(obj, (HalfLine, Line)):
329
+ return line_line_intersection(self.line, obj)
330
+
331
+ exclude = [self.tail]
332
+ if isinstance(obj, HoleCircle):
333
+ exclude += [obj.hole]
334
+
335
+ a, b = line_circle_intersection(self.line, obj)
336
+ if any([a.close(x) for x in exclude]):
337
+ return b
338
+ if any([b.close(x) for x in exclude]):
339
+ return a
340
+
341
+ v = self.head - self.tail
342
+ va = a - self.tail
343
+ vb = b - self.tail
344
+ if v.dot(va) > 0:
345
+ return a
346
+ if v.dot(vb) > 0:
347
+ return b
348
+ raise InvalidLineIntersectError()
349
+
350
+ def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
351
+ center = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
352
+ radius = max([p.distance(center) for p in points])
353
+ if close_enough(center.distance(self.line), radius):
354
+ center = center.foot(self)
355
+ a, b = line_circle_intersection(self, Circle(center.foot(self), radius))
356
+
357
+ if (a - self.tail).dot(self.head - self.tail) > 0:
358
+ a, b = self.tail, a
359
+ else:
360
+ a, b = self.tail, b # pylint: disable=self-assigning-variable
361
+
362
+ result = None
363
+ best = -1.0
364
+ for _ in range(n):
365
+ x = a + (b - a) * unif(0.0, 1.0)
366
+ mind = min([x.distance(p) for p in points])
367
+ if mind > best:
368
+ best = mind
369
+ result = x
370
+
371
+ return [result]
372
+
373
+
374
+ def _perpendicular_bisector(p1: Point, p2: Point) -> Line:
375
+ midpoint = (p1 + p2) * 0.5
376
+ return Line(midpoint, midpoint + Point(p2.y - p1.y, p1.x - p2.x))
377
+
378
+
379
+ def same_sign(
380
+ a: Point, b: Point, c: Point, d: Point, e: Point, f: Point
381
+ ) -> bool:
382
+ a, b, c, d, e, f = map(lambda p: p.sym, [a, b, c, d, e, f])
383
+ ab, cb = a - b, c - b
384
+ de, fe = d - e, f - e
385
+ return (ab.x * cb.y - ab.y * cb.x) * (de.x * fe.y - de.y * fe.x) > 0
386
+
387
+
388
+ class Circle:
389
+ """Numerical circle."""
390
+
391
+ def __init__(
392
+ self,
393
+ center: Optional[Point] = None,
394
+ radius: Optional[float] = None,
395
+ p1: Optional[Point] = None,
396
+ p2: Optional[Point] = None,
397
+ p3: Optional[Point] = None,
398
+ ):
399
+ if not center:
400
+ if not (p1 and p2 and p3):
401
+ self.center = self.radius = self.r2 = None
402
+ return
403
+ # raise ValueError('Circle without center need p1 p2 p3')
404
+
405
+ l12 = _perpendicular_bisector(p1, p2)
406
+ l23 = _perpendicular_bisector(p2, p3)
407
+ center = line_line_intersection(l12, l23)
408
+
409
+ self.center = center
410
+ self.a, self.b = center.x, center.y
411
+
412
+ if not radius:
413
+ if not (p1 or p2 or p3):
414
+ raise ValueError('Circle needs radius or p1 or p2 or p3')
415
+ p = p1 or p2 or p3
416
+ self.r2 = (self.a - p.x) ** 2 + (self.b - p.y) ** 2
417
+ self.radius = math.sqrt(self.r2)
418
+ else:
419
+ self.radius = radius
420
+ self.r2 = radius * radius
421
+
422
+ def intersect(self, obj: Union[Line, Circle]) -> tuple[Point, ...]:
423
+ if isinstance(obj, Line):
424
+ return obj.intersect(self)
425
+ if isinstance(obj, Circle):
426
+ return circle_circle_intersection(self, obj)
427
+
428
+ def sample_within(self, points: list[Point], n: int = 5) -> list[Point]:
429
+ """Sample a point within the boundary of points."""
430
+ result = None
431
+ best = -1.0
432
+ for _ in range(n):
433
+ ang = unif(0.0, 2.0) * np.pi
434
+ x = self.center + Point(np.cos(ang), np.sin(ang)) * self.radius
435
+ mind = min([x.distance(p) for p in points])
436
+ if mind > best:
437
+ best = mind
438
+ result = x
439
+
440
+ return [result]
441
+
442
+
443
+ class HoleCircle(Circle):
444
+ """Numerical circle with a missing point."""
445
+
446
+ def __init__(self, center: Point, radius: float, hole: Point):
447
+ super().__init__(center, radius)
448
+ self.hole = hole
449
+
450
+ def intersect(self, obj: Union[Line, HalfLine, Circle, HoleCircle]) -> Point:
451
+ if isinstance(obj, Line):
452
+ a, b = line_circle_intersection(obj, self)
453
+ if a.close(self.hole):
454
+ return b
455
+ return a
456
+ if isinstance(obj, HalfLine):
457
+ return obj.intersect(self)
458
+ if isinstance(obj, Circle):
459
+ a, b = circle_circle_intersection(obj, self)
460
+ if a.close(self.hole):
461
+ return b
462
+ return a
463
+ if isinstance(obj, HoleCircle):
464
+ a, b = circle_circle_intersection(obj, self)
465
+ if a.close(self.hole) or a.close(obj.hole):
466
+ return b
467
+ return a
468
+
469
+
470
+ def solve_quad(a: float, b: float, c: float) -> tuple[float, float]:
471
+ """Solve a x^2 + bx + c = 0."""
472
+ a = 2 * a
473
+ d = b * b - 2 * a * c
474
+ if d < 0:
475
+ return None # the caller should expect this result.
476
+
477
+ y = math.sqrt(d)
478
+ return (-b - y) / a, (-b + y) / a
479
+
480
+
481
+ def circle_circle_intersection(c1: Circle, c2: Circle) -> tuple[Point, Point]:
482
+ """Returns a pair of Points as intersections of c1 and c2."""
483
+ # circle 1: (x0, y0), radius r0
484
+ # circle 2: (x1, y1), radius r1
485
+ x0, y0, r0 = c1.a, c1.b, c1.radius
486
+ x1, y1, r1 = c2.a, c2.b, c2.radius
487
+
488
+ d = math.sqrt((x1 - x0) ** 2 + (y1 - y0) ** 2)
489
+ if d == 0:
490
+ raise InvalidQuadSolveError()
491
+
492
+ a = (r0**2 - r1**2 + d**2) / (2 * d)
493
+ h = r0**2 - a**2
494
+ if h < 0:
495
+ raise InvalidQuadSolveError()
496
+ h = np.sqrt(h)
497
+ x2 = x0 + a * (x1 - x0) / d
498
+ y2 = y0 + a * (y1 - y0) / d
499
+ x3 = x2 + h * (y1 - y0) / d
500
+ y3 = y2 - h * (x1 - x0) / d
501
+ x4 = x2 - h * (y1 - y0) / d
502
+ y4 = y2 + h * (x1 - x0) / d
503
+
504
+ return Point(x3, y3), Point(x4, y4)
505
+
506
+
507
+ class InvalidQuadSolveError(Exception):
508
+ pass
509
+
510
+
511
+ def line_circle_intersection(line: Line, circle: Circle) -> tuple[Point, Point]:
512
+ """Returns a pair of points as intersections of line and circle."""
513
+ a, b, c = line.coefficients
514
+ r = float(circle.radius)
515
+ center = circle.center
516
+ p, q = center.x, center.y
517
+
518
+ if b == 0:
519
+ x = -c / a
520
+ x_p = x - p
521
+ x_p2 = x_p * x_p
522
+ y = solve_quad(1, -2 * q, q * q + x_p2 - r * r)
523
+ if y is None:
524
+ raise InvalidQuadSolveError()
525
+ y1, y2 = y
526
+ return (Point(x, y1), Point(x, y2))
527
+
528
+ if a == 0:
529
+ y = -c / b
530
+ y_q = y - q
531
+ y_q2 = y_q * y_q
532
+ x = solve_quad(1, -2 * p, p * p + y_q2 - r * r)
533
+ if x is None:
534
+ raise InvalidQuadSolveError()
535
+ x1, x2 = x
536
+ return (Point(x1, y), Point(x2, y))
537
+
538
+ c_ap = c + a * p
539
+ a2 = a * a
540
+ y = solve_quad(
541
+ a2 + b * b, 2 * (b * c_ap - a2 * q), c_ap * c_ap + a2 * (q * q - r * r)
542
+ )
543
+ if y is None:
544
+ raise InvalidQuadSolveError()
545
+ y1, y2 = y
546
+
547
+ return Point(-(b * y1 + c) / a, y1), Point(-(b * y2 + c) / a, y2)
548
+
549
+
550
+ def _check_between(a: Point, b: Point, c: Point) -> bool:
551
+ """Whether a is between b & c."""
552
+ return (a - b).dot(c - b) > 0 and (a - c).dot(b - c) > 0
553
+
554
+
555
+ def circle_segment_intersect(
556
+ circle: Circle, p1: Point, p2: Point
557
+ ) -> list[Point]:
558
+ l = Line(p1, p2)
559
+ px, py = line_circle_intersection(l, circle)
560
+
561
+ result = []
562
+ if _check_between(px, p1, p2):
563
+ result.append(px)
564
+ if _check_between(py, p1, p2):
565
+ result.append(py)
566
+ return result
567
+
568
+
569
+ def line_segment_intersection(l: Line, A: Point, B: Point) -> Point: # pylint: disable=invalid-name
570
+ a, b, c = l.coefficients
571
+ x1, y1, x2, y2 = A.x, A.y, B.x, B.y
572
+ dx, dy = x2 - x1, y2 - y1
573
+ alpha = (-c - a * x1 - b * y1) / (a * dx + b * dy)
574
+ return Point(x1 + alpha * dx, y1 + alpha * dy)
575
+
576
+
577
+ def line_line_intersection(l1: Line, l2: Line) -> Point:
578
+ a1, b1, c1 = l1.coefficients
579
+ a2, b2, c2 = l2.coefficients
580
+ # a1x + b1y + c1 = 0
581
+ # a2x + b2y + c2 = 0
582
+ d = a1 * b2 - a2 * b1
583
+ if d == 0:
584
+ raise InvalidLineIntersectError
585
+ return Point((c2 * b1 - c1 * b2) / d, (c1 * a2 - c2 * a1) / d)
586
+
587
+
588
+ def check_too_close(
589
+ newpoints: list[Point], points: list[Point], tol: int = 0.1
590
+ ) -> bool:
591
+ if not points:
592
+ return False
593
+ avg = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
594
+ mindist = min([p.distance(avg) for p in points])
595
+ for p0 in newpoints:
596
+ for p1 in points:
597
+ if p0.distance(p1) < tol * mindist:
598
+ return True
599
+ return False
600
+
601
+
602
+ def check_too_far(
603
+ newpoints: list[Point], points: list[Point], tol: int = 4
604
+ ) -> bool:
605
+ if len(points) < 2:
606
+ return False
607
+ avg = sum(points, Point(0.0, 0.0)) * 1.0 / len(points)
608
+ maxdist = max([p.distance(avg) for p in points])
609
+ for p in newpoints:
610
+ if p.distance(avg) > maxdist * tol:
611
+ return True
612
+ return False
613
+
614
+
615
+ def check_aconst(args: list[Point]) -> bool:
616
+ a, b, c, d, num, den = args
617
+ d = d + a - c
618
+ ang = ang_between(a, b, d)
619
+ if ang < 0:
620
+ ang += np.pi
621
+ return close_enough(ang, num * np.pi / den)
622
+
623
+
624
+ def check(name: str, args: list[Union[gm.Point, Point]]) -> bool:
625
+ """Numerical check."""
626
+ if name == 'eqangle6':
627
+ name = 'eqangle'
628
+ elif name == 'eqratio6':
629
+ name = 'eqratio'
630
+ elif name in ['simtri2', 'simtri*']:
631
+ name = 'simtri'
632
+ elif name in ['contri2', 'contri*']:
633
+ name = 'contri'
634
+ elif name == 'para':
635
+ name = 'para_or_coll'
636
+ elif name == 'on_line':
637
+ name = 'coll'
638
+ elif name in ['rcompute', 'acompute']:
639
+ return True
640
+ elif name in ['fixl', 'fixc', 'fixb', 'fixt', 'fixp']:
641
+ return True
642
+
643
+ fn_name = 'check_' + name
644
+ if fn_name not in globals():
645
+ return None
646
+
647
+ fun = globals()['check_' + name]
648
+ args = [p.num if isinstance(p, gm.Point) else p for p in args]
649
+ return fun(args)
650
+
651
+
652
+ def check_circle(points: list[Point]) -> bool:
653
+ if len(points) != 4:
654
+ return False
655
+ o, a, b, c = points
656
+ oa, ob, oc = o.distance(a), o.distance(b), o.distance(c)
657
+ return close_enough(oa, ob) and close_enough(ob, oc)
658
+
659
+
660
+ def check_coll(points: list[Point]) -> bool:
661
+ a, b = points[:2]
662
+ l = Line(a, b)
663
+ for p in points[2:]:
664
+ if abs(l(p.x, p.y)) > ATOM:
665
+ return False
666
+ return True
667
+
668
+
669
+ def check_ncoll(points: list[Point]) -> bool:
670
+ return not check_coll(points)
671
+
672
+
673
+ def check_sameside(points: list[Point]) -> bool:
674
+ b, a, c, y, x, z = points
675
+ # whether b is to the same side of a & c as y is to x & z
676
+ ba = b - a
677
+ bc = b - c
678
+ yx = y - x
679
+ yz = y - z
680
+ return ba.dot(bc) * yx.dot(yz) > 0
681
+
682
+
683
+ def check_para_or_coll(points: list[Point]) -> bool:
684
+ return check_para(points) or check_coll(points)
685
+
686
+
687
+ def check_para(points: list[Point]) -> bool:
688
+ a, b, c, d = points
689
+ ab = Line(a, b)
690
+ cd = Line(c, d)
691
+ if ab.same(cd):
692
+ return False
693
+ return ab.is_parallel(cd)
694
+
695
+
696
+ def check_perp(points: list[Point]) -> bool:
697
+ a, b, c, d = points
698
+ ab = Line(a, b)
699
+ cd = Line(c, d)
700
+ return ab.is_perp(cd)
701
+
702
+
703
+ def check_cyclic(points: list[Point]) -> bool:
704
+ points = list(set(points))
705
+ (a, b, c), *ps = points
706
+ circle = Circle(p1=a, p2=b, p3=c)
707
+ for d in ps:
708
+ if not close_enough(d.distance(circle.center), circle.radius):
709
+ return False
710
+ return True
711
+
712
+
713
+ def bring_together(
714
+ a: Point, b: Point, c: Point, d: Point
715
+ ) -> tuple[Point, Point, Point, Point]:
716
+ ab = Line(a, b)
717
+ cd = Line(c, d)
718
+ x = line_line_intersection(ab, cd)
719
+ unit = Circle(center=x, radius=1.0)
720
+ y, _ = line_circle_intersection(ab, unit)
721
+ z, _ = line_circle_intersection(cd, unit)
722
+ return x, y, x, z
723
+
724
+
725
+ def same_clock(
726
+ a: Point, b: Point, c: Point, d: Point, e: Point, f: Point
727
+ ) -> bool:
728
+ ba = b - a
729
+ cb = c - b
730
+ ed = e - d
731
+ fe = f - e
732
+ return (ba.x * cb.y - ba.y * cb.x) * (ed.x * fe.y - ed.y * fe.x) > 0
733
+
734
+
735
+ def check_const_angle(points: list[Point]) -> bool:
736
+ """Check if the angle is equal to the given constant."""
737
+ a, b, c, d, m, n = points
738
+ a, b, c, d = bring_together(a, b, c, d)
739
+ ba = b - a
740
+ dc = d - c
741
+
742
+ a3 = np.arctan2(ba.y, ba.x)
743
+ a4 = np.arctan2(dc.y, dc.x)
744
+ y = a3 - a4
745
+
746
+ return close_enough(m / n % 1, y / np.pi % 1)
747
+
748
+
749
+ def check_eqangle(points: list[Point]) -> bool:
750
+ """Check if 8 points make 2 equal angles."""
751
+ a, b, c, d, e, f, g, h = points
752
+
753
+ ab = Line(a, b)
754
+ cd = Line(c, d)
755
+ ef = Line(e, f)
756
+ gh = Line(g, h)
757
+
758
+ if ab.is_parallel(cd):
759
+ return ef.is_parallel(gh)
760
+ if ef.is_parallel(gh):
761
+ return ab.is_parallel(cd)
762
+
763
+ a, b, c, d = bring_together(a, b, c, d)
764
+ e, f, g, h = bring_together(e, f, g, h)
765
+
766
+ ba = b - a
767
+ dc = d - c
768
+ fe = f - e
769
+ hg = h - g
770
+
771
+ sameclock = (ba.x * dc.y - ba.y * dc.x) * (fe.x * hg.y - fe.y * hg.x) > 0
772
+ if not sameclock:
773
+ ba = ba * -1.0
774
+
775
+ a1 = np.arctan2(fe.y, fe.x)
776
+ a2 = np.arctan2(hg.y, hg.x)
777
+ x = a1 - a2
778
+
779
+ a3 = np.arctan2(ba.y, ba.x)
780
+ a4 = np.arctan2(dc.y, dc.x)
781
+ y = a3 - a4
782
+
783
+ xy = (x - y) % (2 * np.pi)
784
+ return close_enough(xy, 0, tol=1e-11) or close_enough(
785
+ xy, 2 * np.pi, tol=1e-11
786
+ )
787
+
788
+
789
+ def check_eqratio(points: list[Point]) -> bool:
790
+ a, b, c, d, e, f, g, h = points
791
+ ab = a.distance(b)
792
+ cd = c.distance(d)
793
+ ef = e.distance(f)
794
+ gh = g.distance(h)
795
+ return close_enough(ab * gh, cd * ef)
796
+
797
+
798
+ def check_cong(points: list[Point]) -> bool:
799
+ a, b, c, d = points
800
+ return close_enough(a.distance(b), c.distance(d))
801
+
802
+
803
+ def check_midp(points: list[Point]) -> bool:
804
+ a, b, c = points
805
+ return check_coll(points) and close_enough(a.distance(b), a.distance(c))
806
+
807
+
808
+ def check_simtri(points: list[Point]) -> bool:
809
+ """Check if 6 points make a pair of similar triangles."""
810
+ a, b, c, x, y, z = points
811
+ ab = a.distance(b)
812
+ bc = b.distance(c)
813
+ ca = c.distance(a)
814
+ xy = x.distance(y)
815
+ yz = y.distance(z)
816
+ zx = z.distance(x)
817
+ tol = 1e-9
818
+ return close_enough(ab * yz, bc * xy, tol) and close_enough(
819
+ bc * zx, ca * yz, tol
820
+ )
821
+
822
+
823
+ def check_contri(points: list[Point]) -> bool:
824
+ a, b, c, x, y, z = points
825
+ ab = a.distance(b)
826
+ bc = b.distance(c)
827
+ ca = c.distance(a)
828
+ xy = x.distance(y)
829
+ yz = y.distance(z)
830
+ zx = z.distance(x)
831
+ tol = 1e-9
832
+ return (
833
+ close_enough(ab, xy, tol)
834
+ and close_enough(bc, yz, tol)
835
+ and close_enough(ca, zx, tol)
836
+ )
837
+
838
+
839
+ def check_ratio(points: list[Point]) -> bool:
840
+ a, b, c, d, m, n = points
841
+ ab = a.distance(b)
842
+ cd = c.distance(d)
843
+ return close_enough(ab * n, cd * m)
844
+
845
+
846
+ def draw_angle(
847
+ ax: matplotlib.axes.Axes,
848
+ head: Point,
849
+ p1: Point,
850
+ p2: Point,
851
+ color: Any = 'red',
852
+ alpha: float = 0.5,
853
+ frac: float = 1.0,
854
+ ) -> None:
855
+ """Draw an angle on plt ax."""
856
+ d1 = p1 - head
857
+ d2 = p2 - head
858
+
859
+ a1 = np.arctan2(float(d1.y), float(d1.x))
860
+ a2 = np.arctan2(float(d2.y), float(d2.x))
861
+ a1, a2 = a1 * 180 / np.pi, a2 * 180 / np.pi
862
+ a1, a2 = a1 % 360, a2 % 360
863
+
864
+ if a1 > a2:
865
+ a1, a2 = a2, a1
866
+
867
+ if a2 - a1 > 180:
868
+ a1, a2 = a2, a1
869
+
870
+ b1, b2 = a1, a2
871
+ if b1 > b2:
872
+ b2 += 360
873
+ d = b2 - b1
874
+ # if d >= 90:
875
+ # return
876
+
877
+ scale = min(2.0, 90 / d)
878
+ scale = max(scale, 0.4)
879
+ fov = matplotlib.patches.Wedge(
880
+ (float(head.x), float(head.y)),
881
+ unif(0.075, 0.125) * scale * frac,
882
+ a1,
883
+ a2,
884
+ color=color,
885
+ alpha=alpha,
886
+ )
887
+ ax.add_artist(fov)
888
+
889
+
890
+ def naming_position(
891
+ ax: matplotlib.axes.Axes, p: Point, lines: list[Line], circles: list[Circle]
892
+ ) -> tuple[float, float]:
893
+ """Figure out a good naming position on the drawing."""
894
+ _ = ax
895
+ r = 0.08
896
+ c = Circle(center=p, radius=r)
897
+ avoid = []
898
+ for p1, p2 in lines:
899
+ try:
900
+ avoid.extend(circle_segment_intersect(c, p1, p2))
901
+ except InvalidQuadSolveError:
902
+ continue
903
+ for x in circles:
904
+ try:
905
+ avoid.extend(circle_circle_intersection(c, x))
906
+ except InvalidQuadSolveError:
907
+ continue
908
+
909
+ if not avoid:
910
+ return [p.x + 0.01, p.y + 0.01]
911
+
912
+ angs = sorted([ang_of(p, a) for a in avoid])
913
+ angs += [angs[0] + 2 * np.pi]
914
+ angs = [(angs[i + 1] - a, a) for i, a in enumerate(angs[:-1])]
915
+
916
+ d, a = max(angs)
917
+ ang = a + d / 2
918
+
919
+ name_pos = p + Point(np.cos(ang), np.sin(ang)) * r
920
+
921
+ x, y = (name_pos.x - r / 1.5, name_pos.y - r / 1.5)
922
+ return x, y
923
+
924
+
925
+ def draw_point(
926
+ ax: matplotlib.axes.Axes,
927
+ p: Point,
928
+ name: str,
929
+ lines: list[Line],
930
+ circles: list[Circle],
931
+ color: Any = 'white',
932
+ size: float = 15,
933
+ ) -> None:
934
+ """draw a point."""
935
+ ax.scatter(p.x, p.y, color=color, s=size)
936
+
937
+ if color == 'white':
938
+ color = 'lightgreen'
939
+ else:
940
+ color = 'grey'
941
+
942
+ name = name.upper()
943
+ if len(name) > 1:
944
+ name = name[0] + '_' + name[1:]
945
+
946
+ ax.annotate(
947
+ name, naming_position(ax, p, lines, circles), color=color, fontsize=15
948
+ )
949
+
950
+
951
+ def _draw_line(
952
+ ax: matplotlib.axes.Axes,
953
+ p1: Point,
954
+ p2: Point,
955
+ color: Any = 'white',
956
+ lw: float = 1.2,
957
+ alpha: float = 0.8,
958
+ ) -> None:
959
+ """Draw a line in matplotlib."""
960
+ ls = '-'
961
+ if color == '--':
962
+ color = 'black'
963
+ ls = '--'
964
+
965
+ lx, ly = (p1.x, p2.x), (p1.y, p2.y)
966
+ ax.plot(lx, ly, color=color, lw=lw, alpha=alpha, ls=ls)
967
+
968
+
969
+ def draw_line(
970
+ ax: matplotlib.axes.Axes, line: Line, color: Any = 'white'
971
+ ) -> tuple[Point, Point]:
972
+ """Draw a line."""
973
+ points = line.neighbors(gm.Point)
974
+ if len(points) <= 1:
975
+ return
976
+
977
+ points = [p.num for p in points]
978
+ p1, p2 = points[:2]
979
+
980
+ pmin, pmax = (p1, 0.0), (p2, (p2 - p1).dot(p2 - p1))
981
+
982
+ for p in points[2:]:
983
+ v = (p - p1).dot(p2 - p1)
984
+ if v < pmin[1]:
985
+ pmin = p, v
986
+ if v > pmax[1]:
987
+ pmax = p, v
988
+
989
+ p1, p2 = pmin[0], pmax[0]
990
+ _draw_line(ax, p1, p2, color=color)
991
+ return p1, p2
992
+
993
+
994
+ def _draw_circle(
995
+ ax: matplotlib.axes.Axes, c: Circle, color: Any = 'cyan', lw: float = 1.2
996
+ ) -> None:
997
+ ls = '-'
998
+ if color == '--':
999
+ color = 'black'
1000
+ ls = '--'
1001
+
1002
+ ax.add_patch(
1003
+ plt.Circle(
1004
+ (c.center.x, c.center.y),
1005
+ c.radius,
1006
+ color=color,
1007
+ alpha=0.8,
1008
+ fill=False,
1009
+ lw=lw,
1010
+ ls=ls,
1011
+ )
1012
+ )
1013
+
1014
+
1015
+ def draw_circle(
1016
+ ax: matplotlib.axes.Axes, circle: Circle, color: Any = 'cyan'
1017
+ ) -> Circle:
1018
+ """Draw a circle."""
1019
+ if circle.num is not None:
1020
+ circle = circle.num
1021
+ else:
1022
+ points = circle.neighbors(gm.Point)
1023
+ if len(points) <= 2:
1024
+ return
1025
+ points = [p.num for p in points]
1026
+ p1, p2, p3 = points[:3]
1027
+ circle = Circle(p1=p1, p2=p2, p3=p3)
1028
+
1029
+ _draw_circle(ax, circle, color)
1030
+ return circle
1031
+
1032
+
1033
+ def mark_segment(
1034
+ ax: matplotlib.axes.Axes, p1: Point, p2: Point, color: Any, alpha: float
1035
+ ) -> None:
1036
+ _ = alpha
1037
+ x, y = (p1.x + p2.x) / 2, (p1.y + p2.y) / 2
1038
+ ax.scatter(x, y, color=color, alpha=1.0, marker='o', s=50)
1039
+
1040
+
1041
+ def highlight_angle(
1042
+ ax: matplotlib.axes.Axes,
1043
+ a: Point,
1044
+ b: Point,
1045
+ c: Point,
1046
+ d: Point,
1047
+ color: Any,
1048
+ alpha: float,
1049
+ ) -> None:
1050
+ """Highlight an angle between ab and cd with (color, alpha)."""
1051
+ try:
1052
+ a, b, c, d = bring_together(a, b, c, d)
1053
+ except: # pylint: disable=bare-except
1054
+ return
1055
+ draw_angle(ax, a, b, d, color=color, alpha=alpha, frac=1.0)
1056
+
1057
+
1058
+ def highlight(
1059
+ ax: matplotlib.axes.Axes,
1060
+ name: str,
1061
+ args: list[gm.Point],
1062
+ lcolor: Any,
1063
+ color1: Any,
1064
+ color2: Any,
1065
+ ) -> None:
1066
+ """Draw highlights."""
1067
+ args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
1068
+
1069
+ if name == 'cyclic':
1070
+ a, b, c, d = args
1071
+ _draw_circle(ax, Circle(p1=a, p2=b, p3=c), color=color1, lw=2.0)
1072
+ if name == 'coll':
1073
+ a, b, c = args
1074
+ a, b = max(a, b, c), min(a, b, c)
1075
+ _draw_line(ax, a, b, color=color1, lw=2.0)
1076
+ if name == 'para':
1077
+ a, b, c, d = args
1078
+ _draw_line(ax, a, b, color=color1, lw=2.0)
1079
+ _draw_line(ax, c, d, color=color2, lw=2.0)
1080
+ if name == 'eqangle':
1081
+ a, b, c, d, e, f, g, h = args
1082
+
1083
+ x = line_line_intersection(Line(a, b), Line(c, d))
1084
+ if b.distance(x) > a.distance(x):
1085
+ a, b = b, a
1086
+ if d.distance(x) > c.distance(x):
1087
+ c, d = d, c
1088
+ a, b, d = x, a, c
1089
+
1090
+ y = line_line_intersection(Line(e, f), Line(g, h))
1091
+ if f.distance(y) > e.distance(y):
1092
+ e, f = f, e
1093
+ if h.distance(y) > g.distance(y):
1094
+ g, h = h, g
1095
+ e, f, h = y, e, g
1096
+
1097
+ _draw_line(ax, a, b, color=lcolor, lw=2.0)
1098
+ _draw_line(ax, a, d, color=lcolor, lw=2.0)
1099
+ _draw_line(ax, e, f, color=lcolor, lw=2.0)
1100
+ _draw_line(ax, e, h, color=lcolor, lw=2.0)
1101
+ if color1 == '--':
1102
+ color1 = 'red'
1103
+ draw_angle(ax, a, b, d, color=color1, alpha=0.5)
1104
+ if color2 == '--':
1105
+ color2 = 'red'
1106
+ draw_angle(ax, e, f, h, color=color2, alpha=0.5)
1107
+ if name == 'perp':
1108
+ a, b, c, d = args
1109
+ _draw_line(ax, a, b, color=color1, lw=2.0)
1110
+ _draw_line(ax, c, d, color=color1, lw=2.0)
1111
+ if name == 'ratio':
1112
+ a, b, c, d, m, n = args
1113
+ _draw_line(ax, a, b, color=color1, lw=2.0)
1114
+ _draw_line(ax, c, d, color=color2, lw=2.0)
1115
+ if name == 'cong':
1116
+ a, b, c, d = args
1117
+ _draw_line(ax, a, b, color=color1, lw=2.0)
1118
+ _draw_line(ax, c, d, color=color2, lw=2.0)
1119
+ if name == 'midp':
1120
+ m, a, b = args
1121
+ _draw_line(ax, a, m, color=color1, lw=2.0, alpha=0.5)
1122
+ _draw_line(ax, b, m, color=color2, lw=2.0, alpha=0.5)
1123
+ if name == 'eqratio':
1124
+ a, b, c, d, m, n, p, q = args
1125
+ _draw_line(ax, a, b, color=color1, lw=2.0, alpha=0.5)
1126
+ _draw_line(ax, c, d, color=color2, lw=2.0, alpha=0.5)
1127
+ _draw_line(ax, m, n, color=color1, lw=2.0, alpha=0.5)
1128
+ _draw_line(ax, p, q, color=color2, lw=2.0, alpha=0.5)
1129
+
1130
+
1131
+ HCOLORS = None
1132
+
1133
+
1134
+ def _draw(
1135
+ ax: matplotlib.axes.Axes,
1136
+ points: list[gm.Point],
1137
+ lines: list[gm.Line],
1138
+ circles: list[gm.Circle],
1139
+ goal: Any,
1140
+ equals: list[tuple[Any, Any]],
1141
+ highlights: list[tuple[str, list[gm.Point]]],
1142
+ ):
1143
+ """Draw everything."""
1144
+ colors = ['red', 'green', 'blue', 'orange', 'magenta', 'purple']
1145
+ pcolor = 'black'
1146
+ lcolor = 'black'
1147
+ ccolor = 'grey'
1148
+ if get_theme() == 'dark':
1149
+ pcolor, lcolor, ccolor = 'white', 'white', 'cyan'
1150
+ elif get_theme() == 'light':
1151
+ pcolor, lcolor, ccolor = 'black', 'black', 'blue'
1152
+ elif get_theme() == 'grey':
1153
+ pcolor, lcolor, ccolor = 'black', 'black', 'grey'
1154
+ colors = ['grey']
1155
+
1156
+ line_boundaries = []
1157
+ for l in lines:
1158
+ p1, p2 = draw_line(ax, l, color=lcolor)
1159
+ line_boundaries.append((p1, p2))
1160
+ circles = [draw_circle(ax, c, color=ccolor) for c in circles]
1161
+
1162
+ for p in points:
1163
+ draw_point(ax, p.num, p.name, line_boundaries, circles, color=pcolor)
1164
+
1165
+ if equals:
1166
+ for i, segs in enumerate(equals['segments']):
1167
+ color = colors[i % len(colors)]
1168
+ for a, b in segs:
1169
+ mark_segment(ax, a, b, color, 0.5)
1170
+
1171
+ for i, angs in enumerate(equals['angles']):
1172
+ color = colors[i % len(colors)]
1173
+ for a, b, c, d in angs:
1174
+ highlight_angle(ax, a, b, c, d, color, 0.5)
1175
+
1176
+ if highlights:
1177
+ global HCOLORS
1178
+ if HCOLORS is None:
1179
+ HCOLORS = [k for k in mcolors.TABLEAU_COLORS.keys() if 'red' not in k]
1180
+
1181
+ for i, (name, args) in enumerate(highlights):
1182
+ color_i = HCOLORS[i % len(HCOLORS)]
1183
+ highlight(ax, name, args, 'black', color_i, color_i)
1184
+
1185
+ if goal:
1186
+ name, args = goal
1187
+ lcolor = color1 = color2 = 'red'
1188
+ highlight(ax, name, args, lcolor, color1, color2)
1189
+
1190
+
1191
+ THEME = 'dark'
1192
+
1193
+
1194
+ def set_theme(theme) -> None:
1195
+ global THEME
1196
+ THEME = theme
1197
+
1198
+
1199
+ def get_theme() -> str:
1200
+ return THEME
1201
+
1202
+
1203
+ def draw(
1204
+ points: list[gm.Point],
1205
+ lines: list[gm.Line],
1206
+ circles: list[gm.Circle],
1207
+ segments: list[gm.Segment],
1208
+ goal: Any = None,
1209
+ highlights: list[tuple[str, list[gm.Point]]] = None,
1210
+ equals: list[tuple[Any, Any]] = None,
1211
+ block: bool = True,
1212
+ save_to: str = None,
1213
+ theme: str = 'dark',
1214
+ ) -> None:
1215
+ """Draw everything on the same canvas."""
1216
+ plt.close()
1217
+ imsize = 512 / 100
1218
+ fig, ax = plt.subplots(figsize=(imsize, imsize), dpi=100)
1219
+
1220
+ set_theme(theme)
1221
+
1222
+ if get_theme() == 'dark':
1223
+ ax.set_facecolor((0.0, 0.0, 0.0))
1224
+ else:
1225
+ ax.set_facecolor((1.0, 1.0, 1.0))
1226
+
1227
+ _draw(ax, points, lines, circles, goal, equals, highlights)
1228
+
1229
+ plt.axis('equal')
1230
+ fig.subplots_adjust(left=0, right=1, top=1, bottom=0, wspace=0, hspace=0)
1231
+ if points:
1232
+ xmin = min([p.num.x for p in points])
1233
+ xmax = max([p.num.x for p in points])
1234
+ ymin = min([p.num.y for p in points])
1235
+ ymax = max([p.num.y for p in points])
1236
+ plt.margins((xmax - xmin) * 0.1, (ymax - ymin) * 0.1)
1237
+
1238
+ plt.show(block=block)
1239
+
1240
+
1241
+ def close_enough(a: float, b: float, tol: float = 1e-12) -> bool:
1242
+ return abs(a - b) < tol
1243
+
1244
+
1245
+ def assert_close_enough(a: float, b: float, tol: float = 1e-12) -> None:
1246
+ assert close_enough(a, b, tol), f'|{a}-{b}| = {abs(a-b)} >= {tol}'
1247
+
1248
+
1249
+ def ang_of(tail: Point, head: Point) -> float:
1250
+ vector = head - tail
1251
+ arctan = np.arctan2(vector.y, vector.x) % (2 * np.pi)
1252
+ return arctan
1253
+
1254
+
1255
+ def ang_between(tail: Point, head1: Point, head2: Point) -> float:
1256
+ ang1 = ang_of(tail, head1)
1257
+ ang2 = ang_of(tail, head2)
1258
+ diff = ang1 - ang2
1259
+ # return diff % (2*np.pi)
1260
+ if diff > np.pi:
1261
+ return diff - 2 * np.pi
1262
+ if diff < -np.pi:
1263
+ return 2 * np.pi + diff
1264
+ return diff
1265
+
1266
+
1267
+ def head_from(tail: Point, ang: float, length: float = 1) -> Point:
1268
+ vector = Point(np.cos(ang) * length, np.sin(ang) * length)
1269
+ return tail + vector
1270
+
1271
+
1272
+ def random_points(n: int = 3) -> list[Point]:
1273
+ return [Point(unif(-1, 1), unif(-1, 1)) for _ in range(n)]
1274
+
1275
+
1276
+ def random_rfss(*points: list[Point]) -> list[Point]:
1277
+ """Random rotate-flip-scale-shift a point cloud."""
1278
+ # center point cloud.
1279
+ average = sum(points, Point(0.0, 0.0)) * (1.0 / len(points))
1280
+ points = [p - average for p in points]
1281
+
1282
+ # rotate
1283
+ ang = unif(0.0, 2 * np.pi)
1284
+ sin, cos = np.sin(ang), np.cos(ang)
1285
+ # scale and shift
1286
+ scale = unif(0.5, 2.0)
1287
+ shift = Point(unif(-1, 1), unif(-1, 1))
1288
+ points = [p.rotate(sin, cos) * scale + shift for p in points]
1289
+
1290
+ # randomly flip
1291
+ if np.random.rand() < 0.5:
1292
+ points = [p.flip() for p in points]
1293
+
1294
+ return points
1295
+
1296
+
1297
+ def reduce(
1298
+ objs: list[Union[Point, Line, Circle, HalfLine, HoleCircle]],
1299
+ existing_points: list[Point],
1300
+ ) -> list[Point]:
1301
+ """Reduce intersecting objects into one point of intersections."""
1302
+ if all(isinstance(o, Point) for o in objs):
1303
+ return objs
1304
+
1305
+ elif len(objs) == 1:
1306
+ return objs[0].sample_within(existing_points)
1307
+
1308
+ elif len(objs) == 2:
1309
+ a, b = objs
1310
+ result = a.intersect(b)
1311
+ if isinstance(result, Point):
1312
+ return [result]
1313
+ a, b = result
1314
+ a_close = any([a.close(x) for x in existing_points])
1315
+ if a_close:
1316
+ return [b]
1317
+ b_close = any([b.close(x) for x in existing_points])
1318
+ if b_close:
1319
+ return [a]
1320
+ return [np.random.choice([a, b])]
1321
+
1322
+ else:
1323
+ raise ValueError(f'Cannot reduce {objs}')
1324
+
1325
+
1326
+ def sketch(
1327
+ name: str, args: list[Union[Point, gm.Point]]
1328
+ ) -> list[Union[Point, Line, Circle, HalfLine, HoleCircle]]:
1329
+ fun = globals()['sketch_' + name]
1330
+ args = [p.num if isinstance(p, gm.Point) else p for p in args]
1331
+ out = fun(args)
1332
+
1333
+ # out can be one or multiple {Point/Line/HalfLine}
1334
+ if isinstance(out, (tuple, list)):
1335
+ return list(out)
1336
+ return [out]
1337
+
1338
+
1339
+ def sketch_on_opline(args: tuple[gm.Point, ...]) -> HalfLine:
1340
+ a, b = args
1341
+ return HalfLine(a, a + a - b)
1342
+
1343
+
1344
+ def sketch_on_hline(args: tuple[gm.Point, ...]) -> HalfLine:
1345
+ a, b = args
1346
+ return HalfLine(a, b)
1347
+
1348
+
1349
+ def sketch_ieq_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1350
+ a = Point(0.0, 0.0)
1351
+ b = Point(1.0, 0.0)
1352
+
1353
+ c, _ = Circle(a, p1=b).intersect(Circle(b, p1=a))
1354
+ return a, b, c
1355
+
1356
+
1357
+ def sketch_incenter2(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1358
+ a, b, c = args
1359
+ l1 = sketch_bisect([b, a, c])
1360
+ l2 = sketch_bisect([a, b, c])
1361
+ i = line_line_intersection(l1, l2)
1362
+ x = i.foot(Line(b, c))
1363
+ y = i.foot(Line(c, a))
1364
+ z = i.foot(Line(a, b))
1365
+ return x, y, z, i
1366
+
1367
+
1368
+ def sketch_excenter2(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1369
+ a, b, c = args
1370
+ l1 = sketch_bisect([b, a, c])
1371
+ l2 = sketch_exbisect([a, b, c])
1372
+ i = line_line_intersection(l1, l2)
1373
+ x = i.foot(Line(b, c))
1374
+ y = i.foot(Line(c, a))
1375
+ z = i.foot(Line(a, b))
1376
+ return x, y, z, i
1377
+
1378
+
1379
+ def sketch_centroid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1380
+ a, b, c = args
1381
+ x = (b + c) * 0.5
1382
+ y = (c + a) * 0.5
1383
+ z = (a + b) * 0.5
1384
+ i = line_line_intersection(Line(a, x), Line(b, y))
1385
+ return x, y, z, i
1386
+
1387
+
1388
+ def sketch_ninepoints(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1389
+ a, b, c = args
1390
+ x = (b + c) * 0.5
1391
+ y = (c + a) * 0.5
1392
+ z = (a + b) * 0.5
1393
+ c = Circle(p1=x, p2=y, p3=z)
1394
+ return x, y, z, c.center
1395
+
1396
+
1397
+ def sketch_2l1c(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1398
+ """Sketch a circle touching two lines and another circle."""
1399
+ a, b, c, p = args
1400
+ bc, ac = Line(b, c), Line(a, c)
1401
+ circle = Circle(p, p1=a)
1402
+
1403
+ d, d_ = line_circle_intersection(p.perpendicular_line(bc), circle)
1404
+ if bc.diff_side(d_, a):
1405
+ d = d_
1406
+
1407
+ e, e_ = line_circle_intersection(p.perpendicular_line(ac), circle)
1408
+ if ac.diff_side(e_, b):
1409
+ e = e_
1410
+
1411
+ df = d.perpendicular_line(Line(p, d))
1412
+ ef = e.perpendicular_line(Line(p, e))
1413
+ f = line_line_intersection(df, ef)
1414
+
1415
+ g, g_ = line_circle_intersection(Line(c, f), circle)
1416
+ if bc.same_side(g_, a):
1417
+ g = g_
1418
+
1419
+ b_ = c + (b - c) / b.distance(c)
1420
+ a_ = c + (a - c) / a.distance(c)
1421
+ m = (a_ + b_) * 0.5
1422
+ x = line_line_intersection(Line(c, m), Line(p, g))
1423
+ return x.foot(ac), x.foot(bc), g, x
1424
+
1425
+
1426
+ def sketch_3peq(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1427
+ a, b, c = args
1428
+ ab, bc, ca = Line(a, b), Line(b, c), Line(c, a)
1429
+
1430
+ z = b + (c - b) * np.random.uniform(-0.5, 1.5)
1431
+
1432
+ z_ = z * 2 - c
1433
+ l = z_.parallel_line(ca)
1434
+ x = line_line_intersection(l, ab)
1435
+ y = z * 2 - x
1436
+ return x, y, z
1437
+
1438
+
1439
+ def try_to_sketch_intersect(
1440
+ name1: str,
1441
+ args1: list[Union[gm.Point, Point]],
1442
+ name2: str,
1443
+ args2: list[Union[gm.Point, Point]],
1444
+ existing_points: list[Point],
1445
+ ) -> Optional[Point]:
1446
+ """Try to sketch an intersection between two objects."""
1447
+ obj1 = sketch(name1, args1)[0]
1448
+ obj2 = sketch(name2, args2)[0]
1449
+
1450
+ if isinstance(obj1, Line) and isinstance(obj2, Line):
1451
+ fn = line_line_intersection
1452
+ elif isinstance(obj1, Circle) and isinstance(obj2, Circle):
1453
+ fn = circle_circle_intersection
1454
+ else:
1455
+ fn = line_circle_intersection
1456
+ if isinstance(obj2, Line) and isinstance(obj1, Circle):
1457
+ obj1, obj2 = obj2, obj1
1458
+
1459
+ try:
1460
+ x = fn(obj1, obj2)
1461
+ except: # pylint: disable=bare-except
1462
+ return None
1463
+
1464
+ if isinstance(x, Point):
1465
+ return x
1466
+
1467
+ x1, x2 = x
1468
+
1469
+ close1 = check_too_close([x1], existing_points)
1470
+ far1 = check_too_far([x1], existing_points)
1471
+ if not close1 and not far1:
1472
+ return x1
1473
+ close2 = check_too_close([x2], existing_points)
1474
+ far2 = check_too_far([x2], existing_points)
1475
+ if not close2 and not far2:
1476
+ return x2
1477
+
1478
+ return None
1479
+
1480
+
1481
+ def sketch_acircle(args: tuple[gm.Point, ...]) -> Circle:
1482
+ a, b, c, d, f = args
1483
+ de = sketch_aline([c, a, b, f, d])
1484
+ fe = sketch_aline([a, c, b, d, f])
1485
+ e = line_line_intersection(de, fe)
1486
+ return Circle(p1=d, p2=e, p3=f)
1487
+
1488
+
1489
+ def sketch_aline(args: tuple[gm.Point, ...]) -> HalfLine:
1490
+ """Sketch the construction aline."""
1491
+ A, B, C, D, E = args
1492
+ ab = A - B
1493
+ cb = C - B
1494
+ de = D - E
1495
+
1496
+ dab = A.distance(B)
1497
+ ang_ab = np.arctan2(ab.y / dab, ab.x / dab)
1498
+
1499
+ dcb = C.distance(B)
1500
+ ang_bc = np.arctan2(cb.y / dcb, cb.x / dcb)
1501
+
1502
+ dde = D.distance(E)
1503
+ ang_de = np.arctan2(de.y / dde, de.x / dde)
1504
+
1505
+ ang_ex = ang_de + ang_bc - ang_ab
1506
+ X = E + Point(np.cos(ang_ex), np.sin(ang_ex))
1507
+ return HalfLine(E, X)
1508
+
1509
+
1510
+ def sketch_amirror(args: tuple[gm.Point, ...]) -> HalfLine:
1511
+ """Sketch the angle mirror."""
1512
+ A, B, C = args # pylint: disable=invalid-name
1513
+ ab = A - B
1514
+ cb = C - B
1515
+
1516
+ dab = A.distance(B)
1517
+ ang_ab = np.arctan2(ab.y / dab, ab.x / dab)
1518
+ dcb = C.distance(B)
1519
+ ang_bc = np.arctan2(cb.y / dcb, cb.x / dcb)
1520
+
1521
+ ang_bx = 2 * ang_bc - ang_ab
1522
+ X = B + Point(np.cos(ang_bx), np.sin(ang_bx)) # pylint: disable=invalid-name
1523
+ return HalfLine(B, X)
1524
+
1525
+
1526
+ def sketch_bisect(args: tuple[gm.Point, ...]) -> Line:
1527
+ a, b, c = args
1528
+ ab = a.distance(b)
1529
+ bc = b.distance(c)
1530
+ x = b + (c - b) * (ab / bc)
1531
+ m = (a + x) * 0.5
1532
+ return Line(b, m)
1533
+
1534
+
1535
+ def sketch_exbisect(args: tuple[gm.Point, ...]) -> Line:
1536
+ a, b, c = args
1537
+ return sketch_bisect(args).perpendicular_line(b)
1538
+
1539
+
1540
+ def sketch_bline(args: tuple[gm.Point, ...]) -> Line:
1541
+ a, b = args
1542
+ m = (a + b) * 0.5
1543
+ return m.perpendicular_line(Line(a, b))
1544
+
1545
+
1546
+ def sketch_dia(args: tuple[gm.Point, ...]) -> Circle:
1547
+ a, b = args
1548
+ return Circle((a + b) * 0.5, p1=a)
1549
+
1550
+
1551
+ def sketch_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1552
+ a, o, b = args
1553
+ dia = sketch_dia([a, o])
1554
+ return circle_circle_intersection(Circle(o, p1=b), dia)
1555
+
1556
+
1557
+ def sketch_circle(args: tuple[gm.Point, ...]) -> Circle:
1558
+ a, b, c = args
1559
+ return Circle(center=a, radius=b.distance(c))
1560
+
1561
+
1562
+ def sketch_cc_tangent(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1563
+ """Sketch tangents to two circles."""
1564
+ o, a, w, b = args
1565
+ ra, rb = o.distance(a), w.distance(b)
1566
+
1567
+ ow = Line(o, w)
1568
+ if close_enough(ra, rb):
1569
+ oo = ow.perpendicular_line(o)
1570
+ oa = Circle(o, ra)
1571
+ x, z = line_circle_intersection(oo, oa)
1572
+ y = x + w - o
1573
+ t = z + w - o
1574
+ return x, y, z, t
1575
+
1576
+ swap = rb > ra
1577
+ if swap:
1578
+ o, a, w, b = w, b, o, a
1579
+ ra, rb = rb, ra
1580
+
1581
+ oa = Circle(o, ra)
1582
+ q = o + (w - o) * ra / (ra - rb)
1583
+
1584
+ x, z = circle_circle_intersection(sketch_dia([o, q]), oa)
1585
+ y = w.foot(Line(x, q))
1586
+ t = w.foot(Line(z, q))
1587
+
1588
+ if swap:
1589
+ x, y, z, t = y, x, t, z
1590
+
1591
+ return x, y, z, t
1592
+
1593
+
1594
+ def sketch_hcircle(args: tuple[gm.Point, ...]) -> HoleCircle:
1595
+ a, b = args
1596
+ return HoleCircle(center=a, radius=a.distance(b), hole=b)
1597
+
1598
+
1599
+ def sketch_e5128(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1600
+ a, b, c, d = args
1601
+ ad = Line(a, d)
1602
+
1603
+ g = (a + b) * 0.5
1604
+ de = Line(d, g)
1605
+
1606
+ e, f = line_circle_intersection(de, Circle(c, p1=b))
1607
+
1608
+ if e.distance(d) < f.distance(d):
1609
+ e = f
1610
+ return e, g
1611
+
1612
+
1613
+ def sketch_eq_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1614
+ """Sketch quadrangle with two equal opposite sides."""
1615
+ a = Point(0.0, 0.0)
1616
+ b = Point(1.0, 0.0)
1617
+
1618
+ length = np.random.uniform(0.5, 2.0)
1619
+ ang = np.random.uniform(np.pi / 3, np.pi * 2 / 3)
1620
+ d = head_from(a, ang, length)
1621
+
1622
+ ang = ang_of(b, d)
1623
+ ang = np.random.uniform(ang / 10, ang / 9)
1624
+ c = head_from(b, ang, length)
1625
+ a, b, c, d = random_rfss(a, b, c, d)
1626
+ return a, b, c, d
1627
+
1628
+
1629
+ def sketch_eq_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1630
+ a = Point(0.0, 0.0)
1631
+ b = Point(1.0, 0.0)
1632
+ l = unif(0.5, 2.0)
1633
+
1634
+ height = unif(0.5, 2.0)
1635
+ c = Point(0.5 + l / 2.0, height)
1636
+ d = Point(0.5 - l / 2.0, height)
1637
+
1638
+ a, b, c, d = random_rfss(a, b, c, d)
1639
+ return a, b, c, d
1640
+
1641
+
1642
+ def sketch_eqangle2(args: tuple[gm.Point, ...]) -> Point:
1643
+ """Sketch the def eqangle2."""
1644
+ a, b, c = args
1645
+
1646
+ d = c * 2 - b
1647
+
1648
+ ba = b.distance(a)
1649
+ bc = b.distance(c)
1650
+ l = ba * ba / bc
1651
+
1652
+ if unif(0.0, 1.0) < 0.5:
1653
+ be = min(l, bc)
1654
+ be = unif(be * 0.1, be * 0.9)
1655
+ else:
1656
+ be = max(l, bc)
1657
+ be = unif(be * 1.1, be * 1.5)
1658
+
1659
+ e = b + (c - b) * (be / bc)
1660
+ y = b + (a - b) * (be / l)
1661
+ return line_line_intersection(Line(c, y), Line(a, e))
1662
+
1663
+
1664
+ def sketch_eqangle3(args: tuple[gm.Point, ...]) -> Circle:
1665
+ a, b, d, e, f = args
1666
+ de = d.distance(e)
1667
+ ef = e.distance(f)
1668
+ ab = b.distance(a)
1669
+ ang_ax = ang_of(a, b) + ang_between(e, d, f)
1670
+ x = head_from(a, ang_ax, length=de / ef * ab)
1671
+ return Circle(p1=a, p2=b, p3=x)
1672
+
1673
+
1674
+ def sketch_eqdia_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1675
+ """Sketch quadrangle with two equal diagonals."""
1676
+ m = unif(0.3, 0.7)
1677
+ n = unif(0.3, 0.7)
1678
+ a = Point(-m, 0.0)
1679
+ c = Point(1 - m, 0.0)
1680
+ b = Point(0.0, -n)
1681
+ d = Point(0.0, 1 - n)
1682
+
1683
+ ang = unif(-0.25 * np.pi, 0.25 * np.pi)
1684
+ sin, cos = np.sin(ang), np.cos(ang)
1685
+ b = b.rotate(sin, cos)
1686
+ d = d.rotate(sin, cos)
1687
+ a, b, c, d = random_rfss(a, b, c, d)
1688
+ return a, b, c, d
1689
+
1690
+
1691
+ def sketch_free(args: tuple[gm.Point, ...]) -> Point:
1692
+ return random_points(1)[0]
1693
+
1694
+
1695
+ def sketch_isos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1696
+ base = unif(0.5, 1.5)
1697
+ height = unif(0.5, 1.5)
1698
+
1699
+ b = Point(-base / 2, 0.0)
1700
+ c = Point(base / 2, 0.0)
1701
+ a = Point(0.0, height)
1702
+ a, b, c = random_rfss(a, b, c)
1703
+ return a, b, c
1704
+
1705
+
1706
+ def sketch_line(args: tuple[gm.Point, ...]) -> Line:
1707
+ a, b = args
1708
+ return Line(a, b)
1709
+
1710
+
1711
+ def sketch_cyclic(args: tuple[gm.Point, ...]) -> Circle:
1712
+ a, b, c = args
1713
+ return Circle(p1=a, p2=b, p3=c)
1714
+
1715
+
1716
+ def sketch_hline(args: tuple[gm.Point, ...]) -> HalfLine:
1717
+ a, b = args
1718
+ return HalfLine(a, b)
1719
+
1720
+
1721
+ def sketch_midp(args: tuple[gm.Point, ...]) -> Point:
1722
+ a, b = args
1723
+ return (a + b) * 0.5
1724
+
1725
+
1726
+ def sketch_pentagon(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1727
+ points = [Point(1.0, 0.0)]
1728
+ ang = 0.0
1729
+
1730
+ for i in range(4):
1731
+ ang += (2 * np.pi - ang) / (5 - i) * unif(0.5, 1.5)
1732
+ point = Point(np.cos(ang), np.sin(ang))
1733
+ points.append(point)
1734
+
1735
+ a, b, c, d, e = points # pylint: disable=unbalanced-tuple-unpacking
1736
+ a, b, c, d, e = random_rfss(a, b, c, d, e)
1737
+ return a, b, c, d, e
1738
+
1739
+
1740
+ def sketch_pline(args: tuple[gm.Point, ...]) -> Line:
1741
+ a, b, c = args
1742
+ return a.parallel_line(Line(b, c))
1743
+
1744
+
1745
+ def sketch_pmirror(args: tuple[gm.Point, ...]) -> Point:
1746
+ a, b = args
1747
+ return b * 2 - a
1748
+
1749
+
1750
+ def sketch_quadrangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1751
+ """Sketch a random quadrangle."""
1752
+ m = unif(0.3, 0.7)
1753
+ n = unif(0.3, 0.7)
1754
+
1755
+ a = Point(-m, 0.0)
1756
+ c = Point(1 - m, 0.0)
1757
+ b = Point(0.0, -unif(0.25, 0.75))
1758
+ d = Point(0.0, unif(0.25, 0.75))
1759
+
1760
+ ang = unif(-0.25 * np.pi, 0.25 * np.pi)
1761
+ sin, cos = np.sin(ang), np.cos(ang)
1762
+ b = b.rotate(sin, cos)
1763
+ d = d.rotate(sin, cos)
1764
+ a, b, c, d = random_rfss(a, b, c, d)
1765
+ return a, b, c, d
1766
+
1767
+
1768
+ def sketch_r_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1769
+ a = Point(0.0, 1.0)
1770
+ d = Point(0.0, 0.0)
1771
+ b = Point(unif(0.5, 1.5), 1.0)
1772
+ c = Point(unif(0.5, 1.5), 0.0)
1773
+ a, b, c, d = random_rfss(a, b, c, d)
1774
+ return a, b, c, d
1775
+
1776
+
1777
+ def sketch_r_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1778
+ a = Point(0.0, 0.0)
1779
+ b = Point(0.0, unif(0.5, 2.0))
1780
+ c = Point(unif(0.5, 2.0), 0.0)
1781
+ a, b, c = random_rfss(a, b, c)
1782
+ return a, b, c
1783
+
1784
+
1785
+ def sketch_rectangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1786
+ a = Point(0.0, 0.0)
1787
+ b = Point(0.0, 1.0)
1788
+ l = unif(0.5, 2.0)
1789
+ c = Point(l, 1.0)
1790
+ d = Point(l, 0.0)
1791
+ a, b, c, d = random_rfss(a, b, c, d)
1792
+ return a, b, c, d
1793
+
1794
+
1795
+ def sketch_reflect(args: tuple[gm.Point, ...]) -> Point:
1796
+ a, b, c = args
1797
+ m = a.foot(Line(b, c))
1798
+ return m * 2 - a
1799
+
1800
+
1801
+ def sketch_risos(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1802
+ a = Point(0.0, 0.0)
1803
+ b = Point(0.0, 1.0)
1804
+ c = Point(1.0, 0.0)
1805
+ a, b, c = random_rfss(a, b, c)
1806
+ return a, b, c
1807
+
1808
+
1809
+ def sketch_rotaten90(args: tuple[gm.Point, ...]) -> Point:
1810
+ a, b = args
1811
+ ang = -np.pi / 2
1812
+ return a + (b - a).rotate(np.sin(ang), np.cos(ang))
1813
+
1814
+
1815
+ def sketch_rotatep90(args: tuple[gm.Point, ...]) -> Point:
1816
+ a, b = args
1817
+ ang = np.pi / 2
1818
+ return a + (b - a).rotate(np.sin(ang), np.cos(ang))
1819
+
1820
+
1821
+ def sketch_s_angle(args: tuple[gm.Point, ...]) -> HalfLine:
1822
+ a, b, y = args
1823
+ ang = y / 180 * np.pi
1824
+ x = b + (a - b).rotatea(ang)
1825
+ return HalfLine(b, x)
1826
+
1827
+
1828
+ def sketch_segment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1829
+ a, b = random_points(2)
1830
+ return a, b
1831
+
1832
+
1833
+ def sketch_shift(args: tuple[gm.Point, ...]) -> Point:
1834
+ a, b, c = args
1835
+ return c + (b - a)
1836
+
1837
+
1838
+ def sketch_square(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1839
+ a, b = args
1840
+ c = b + (a - b).rotatea(-np.pi / 2)
1841
+ d = a + (b - a).rotatea(np.pi / 2)
1842
+ return c, d
1843
+
1844
+
1845
+ def sketch_isquare(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1846
+ a = Point(0.0, 0.0)
1847
+ b = Point(1.0, 0.0)
1848
+ c = Point(1.0, 1.0)
1849
+ d = Point(0.0, 1.0)
1850
+ a, b, c, d = random_rfss(a, b, c, d)
1851
+ return a, b, c, d
1852
+
1853
+
1854
+ def sketch_tline(args: tuple[gm.Point, ...]) -> Line:
1855
+ a, b, c = args
1856
+ return a.perpendicular_line(Line(b, c))
1857
+
1858
+
1859
+ def sketch_trapezoid(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1860
+ d = Point(0.0, 0.0)
1861
+ c = Point(1.0, 0.0)
1862
+
1863
+ base = unif(0.5, 2.0)
1864
+ height = unif(0.5, 2.0)
1865
+ a = Point(unif(0.2, 0.5), height)
1866
+ b = Point(a.x + base, height)
1867
+ a, b, c, d = random_rfss(a, b, c, d)
1868
+ return a, b, c, d
1869
+
1870
+
1871
+ def sketch_triangle(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1872
+ a = Point(0.0, 0.0)
1873
+ b = Point(1.0, 0.0)
1874
+ ac = unif(0.5, 2.0)
1875
+ ang = unif(0.2, 0.8) * np.pi
1876
+ c = head_from(a, ang, ac)
1877
+ return a, b, c
1878
+
1879
+
1880
+ def sketch_triangle12(args: tuple[gm.Point, ...]) -> tuple[Point, ...]:
1881
+ b = Point(0.0, 0.0)
1882
+ c = Point(unif(1.5, 2.5), 0.0)
1883
+ a, _ = circle_circle_intersection(Circle(b, 1.0), Circle(c, 2.0))
1884
+ a, b, c = random_rfss(a, b, c)
1885
+ return a, b, c
1886
+
1887
+
1888
+ def sketch_trisect(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1889
+ """Sketch two trisectors of an angle."""
1890
+ a, b, c = args
1891
+ ang1 = ang_of(b, a)
1892
+ ang2 = ang_of(b, c)
1893
+
1894
+ swap = 0
1895
+ if ang1 > ang2:
1896
+ ang1, ang2 = ang2, ang1
1897
+ swap += 1
1898
+
1899
+ if ang2 - ang1 > np.pi:
1900
+ ang1, ang2 = ang2, ang1 + 2 * np.pi
1901
+ swap += 1
1902
+
1903
+ angx = ang1 + (ang2 - ang1) / 3
1904
+ angy = ang2 - (ang2 - ang1) / 3
1905
+
1906
+ x = b + Point(np.cos(angx), np.sin(angx))
1907
+ y = b + Point(np.cos(angy), np.sin(angy))
1908
+
1909
+ ac = Line(a, c)
1910
+ x = line_line_intersection(Line(b, x), ac)
1911
+ y = line_line_intersection(Line(b, y), ac)
1912
+
1913
+ if swap == 1:
1914
+ return y, x
1915
+ return x, y
1916
+
1917
+
1918
+ def sketch_trisegment(args: tuple[gm.Point, ...]) -> tuple[Point, Point]:
1919
+ a, b = args
1920
+ x, y = a + (b - a) * (1.0 / 3), a + (b - a) * (2.0 / 3)
1921
+ return x, y
external/alphageometry/numericals_test.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Unit testing for the geometry numericals code."""
17
+
18
+ import unittest
19
+
20
+ from absl.testing import absltest
21
+ import numericals as nm
22
+
23
+ np = nm.np
24
+
25
+ unif = nm.unif
26
+ Point = nm.Point
27
+ Line = nm.Line
28
+ Circle = nm.Circle
29
+ HalfLine = nm.HalfLine
30
+
31
+ line_circle_intersection = nm.line_circle_intersection
32
+ line_line_intersection = nm.line_line_intersection
33
+
34
+ check_coll = nm.check_coll
35
+ check_eqangle = nm.check_eqangle
36
+
37
+ random_points = nm.random_points
38
+ ang_between = nm.ang_between
39
+ head_from = nm.head_from
40
+
41
+
42
+ class NumericalTest(unittest.TestCase):
43
+
44
+ def test_sketch_ieq_triangle(self):
45
+ a, b, c = nm.sketch_ieq_triangle([])
46
+ self.assertAlmostEqual(a.distance(b), b.distance(c))
47
+ self.assertAlmostEqual(c.distance(a), b.distance(c))
48
+
49
+ def test_sketch_2l1c(self):
50
+ p = nm.Point(0.0, 0.0)
51
+ pi = np.pi
52
+ anga = unif(-0.4 * pi, 0.4 * pi)
53
+ a = Point(np.cos(anga), np.sin(anga))
54
+ angb = unif(0.6 * pi, 1.4 * pi)
55
+ b = Point(np.cos(angb), np.sin(angb))
56
+
57
+ angc = unif(anga + 0.05 * pi, angb - 0.05 * pi)
58
+ c = Point(np.cos(angc), np.sin(angc)) * unif(0.2, 0.8)
59
+
60
+ x, y, z, i = nm.sketch_2l1c([a, b, c, p])
61
+ self.assertTrue(check_coll([x, c, a]))
62
+ self.assertTrue(check_coll([y, c, b]))
63
+ self.assertAlmostEqual(z.distance(p), 1.0)
64
+ self.assertTrue(check_coll([p, i, z]))
65
+ self.assertTrue(Line(i, x).is_perp(Line(c, a)))
66
+ self.assertTrue(Line(i, y).is_perp(Line(c, b)))
67
+ self.assertAlmostEqual(i.distance(x), i.distance(y))
68
+ self.assertAlmostEqual(i.distance(x), i.distance(z))
69
+
70
+ def test_sketch_3peq(self):
71
+ a, b, c = random_points(3)
72
+ x, y, z = nm.sketch_3peq([a, b, c])
73
+
74
+ self.assertTrue(check_coll([a, b, x]))
75
+ self.assertTrue(check_coll([a, c, y]))
76
+ self.assertTrue(check_coll([b, c, z]))
77
+ self.assertTrue(check_coll([x, y, z]))
78
+ self.assertAlmostEqual(z.distance(x), z.distance(y))
79
+
80
+ def test_sketch_aline(self):
81
+ a, b, c, d, e = random_points(5)
82
+ ex = nm.sketch_aline([a, b, c, d, e])
83
+ self.assertIsInstance(ex, HalfLine)
84
+ self.assertEqual(ex.tail, e)
85
+ x = ex.head
86
+ self.assertAlmostEqual(ang_between(b, a, c), ang_between(e, d, x))
87
+
88
+ def test_sketch_amirror(self):
89
+ a, b, c = random_points(3)
90
+ bx = nm.sketch_amirror([a, b, c])
91
+ self.assertIsInstance(bx, HalfLine)
92
+ assert bx.tail == b
93
+ x = bx.head
94
+
95
+ ang1 = ang_between(b, a, c)
96
+ ang2 = ang_between(b, c, x)
97
+ self.assertAlmostEqual(ang1, ang2)
98
+
99
+ def test_sketch_bisect(self):
100
+ a, b, c = random_points(3)
101
+ line = nm.sketch_bisect([a, b, c])
102
+ self.assertAlmostEqual(b.distance(line), 0.0)
103
+
104
+ l = a.perpendicular_line(line)
105
+ x = line_line_intersection(l, Line(b, c))
106
+ self.assertAlmostEqual(a.distance(line), x.distance(line))
107
+
108
+ d, _ = line_circle_intersection(line, Circle(b, radius=1))
109
+ ang1 = ang_between(b, a, d)
110
+ ang2 = ang_between(b, d, c)
111
+ self.assertAlmostEqual(ang1, ang2)
112
+
113
+ def test_sketch_bline(self):
114
+ a, b = random_points(2)
115
+ l = nm.sketch_bline([a, b])
116
+ self.assertTrue(Line(a, b).is_perp(l))
117
+ self.assertAlmostEqual(a.distance(l), b.distance(l))
118
+
119
+ def test_sketch_cc_tangent(self):
120
+ o = Point(0.0, 0.0)
121
+ w = Point(1.0, 0.0)
122
+
123
+ ra = unif(0.0, 0.6)
124
+ rb = unif(0.4, 1.0)
125
+
126
+ a = unif(0.0, np.pi)
127
+ b = unif(0.0, np.pi)
128
+
129
+ a = o + ra * Point(np.cos(a), np.sin(a))
130
+ b = w + rb * Point(np.sin(b), np.cos(b))
131
+
132
+ x, y, z, t = nm.sketch_cc_tangent([o, a, w, b])
133
+ xy = Line(x, y)
134
+ zt = Line(z, t)
135
+ self.assertAlmostEqual(o.distance(xy), o.distance(a))
136
+ self.assertAlmostEqual(o.distance(zt), o.distance(a))
137
+ self.assertAlmostEqual(w.distance(xy), w.distance(b))
138
+ self.assertAlmostEqual(w.distance(zt), w.distance(b))
139
+
140
+ def test_sketch_circle(self):
141
+ a, b, c = random_points(3)
142
+ circle = nm.sketch_circle([a, b, c])
143
+ self.assertAlmostEqual(circle.center.distance(a), 0.0)
144
+ self.assertAlmostEqual(circle.radius, b.distance(c))
145
+
146
+ def test_sketch_e5128(self):
147
+ b = Point(0.0, 0.0)
148
+ c = Point(0.0, 1.0)
149
+ ang = unif(-np.pi / 2, 3 * np.pi / 2)
150
+ d = head_from(c, ang, 1.0)
151
+ a = Point(unif(0.5, 2.0), 0.0)
152
+
153
+ e, g = nm.sketch_e5128([a, b, c, d])
154
+ ang1 = ang_between(a, b, d)
155
+ ang2 = ang_between(e, a, g)
156
+ self.assertAlmostEqual(ang1, ang2)
157
+
158
+ def test_sketch_eq_quadrangle(self):
159
+ a, b, c, d = nm.sketch_eq_quadrangle([])
160
+ self.assertAlmostEqual(a.distance(d), c.distance(b))
161
+ ac = Line(a, c)
162
+ assert ac.diff_side(b, d), (ac(b), ac(d))
163
+ bd = Line(b, d)
164
+ assert bd.diff_side(a, c), (bd(a), bd(c))
165
+
166
+ def test_sketch_eq_trapezoid(self):
167
+ a, b, c, d = nm.sketch_eq_trapezoid([])
168
+ assert Line(a, b).is_parallel(Line(c, d))
169
+ self.assertAlmostEqual(a.distance(d), b.distance(c))
170
+
171
+ def test_sketch_eqangle3(self):
172
+ points = random_points(5)
173
+ x = nm.sketch_eqangle3(points).sample_within(points)[0]
174
+ a, b, d, e, f = points
175
+ self.assertTrue(check_eqangle([x, a, x, b, d, e, d, f]))
176
+
177
+ def test_sketch_eqangle2(self):
178
+ a, b, c = random_points(3)
179
+ x = nm.sketch_eqangle2([a, b, c])
180
+ ang1 = ang_between(a, b, x)
181
+ ang2 = ang_between(c, x, b)
182
+ self.assertAlmostEqual(ang1, ang2)
183
+
184
+ def test_sketch_edia_quadrangle(self):
185
+ a, b, c, d = nm.sketch_eqdia_quadrangle([])
186
+ assert Line(a, c).diff_side(b, d)
187
+ assert Line(b, d).diff_side(a, c)
188
+ self.assertAlmostEqual(a.distance(c), b.distance(d))
189
+
190
+ def test_sketch_isos(self):
191
+ a, b, c = nm.sketch_isos([])
192
+ self.assertAlmostEqual(a.distance(b), a.distance(c))
193
+ self.assertAlmostEqual(ang_between(b, a, c), ang_between(c, b, a))
194
+
195
+ def test_sketch_quadrange(self):
196
+ a, b, c, d = nm.sketch_quadrangle([])
197
+ self.assertTrue(Line(a, c).diff_side(b, d))
198
+ self.assertTrue(Line(b, d).diff_side(a, c))
199
+
200
+ def test_sketch_r_trapezoid(self):
201
+ a, b, c, d = nm.sketch_r_trapezoid([])
202
+ self.assertTrue(Line(a, b).is_perp(Line(a, d)))
203
+ self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
204
+ self.assertTrue(Line(a, c).diff_side(b, d))
205
+ self.assertTrue(Line(b, d).diff_side(a, c))
206
+
207
+ def test_sketch_r_triangle(self):
208
+ a, b, c = nm.sketch_r_triangle([])
209
+ self.assertTrue(Line(a, b).is_perp(Line(a, c)))
210
+
211
+ def test_sketch_rectangle(self):
212
+ a, b, c, d = nm.sketch_rectangle([])
213
+ self.assertTrue(Line(a, b).is_perp(Line(b, c)))
214
+ self.assertTrue(Line(b, c).is_perp(Line(c, d)))
215
+ self.assertTrue(Line(c, d).is_perp(Line(d, a)))
216
+
217
+ def test_sketch_reflect(self):
218
+ a, b, c = random_points(3)
219
+ x = nm.sketch_reflect([a, b, c])
220
+ self.assertTrue(Line(a, x).is_perp(Line(b, c)))
221
+ self.assertAlmostEqual(x.distance(Line(b, c)), a.distance(Line(b, c)))
222
+
223
+ def test_sketch_risos(self):
224
+ a, b, c = nm.sketch_risos([])
225
+ self.assertAlmostEqual(a.distance(b), a.distance(c))
226
+ self.assertTrue(Line(a, b).is_perp(Line(a, c)))
227
+
228
+ def test_sketch_rotaten90(self):
229
+ a, b = random_points(2)
230
+ x = nm.sketch_rotaten90([a, b])
231
+ self.assertAlmostEqual(a.distance(x), a.distance(b))
232
+ self.assertTrue(Line(a, x).is_perp(Line(a, b)))
233
+ d = Point(0.0, 0.0)
234
+ e = Point(0.0, 1.0)
235
+ f = Point(1.0, 0.0)
236
+ self.assertAlmostEqual(ang_between(d, e, f), ang_between(a, b, x))
237
+
238
+ def test_sketch_rotatep90(self):
239
+ a, b = random_points(2)
240
+ x = nm.sketch_rotatep90([a, b])
241
+ self.assertAlmostEqual(a.distance(x), a.distance(b))
242
+ self.assertTrue(Line(a, x).is_perp(Line(a, b)))
243
+ d = Point(0.0, 0.0)
244
+ e = Point(0.0, 1.0)
245
+ f = Point(1.0, 0.0)
246
+ self.assertAlmostEqual(ang_between(d, f, e), ang_between(a, b, x))
247
+
248
+ def test_sketch_s_angle(self):
249
+ a, b = random_points(2)
250
+ y = unif(0.0, np.pi)
251
+ bx = nm.sketch_s_angle([a, b, y / np.pi * 180])
252
+ self.assertIsInstance(bx, HalfLine)
253
+ self.assertEqual(bx.tail, b)
254
+ x = bx.head
255
+
256
+ d = Point(1.0, 0.0)
257
+ e = Point(0.0, 0.0)
258
+ f = Point(np.cos(y), np.sin(y))
259
+ self.assertAlmostEqual(ang_between(e, d, f), ang_between(b, a, x))
260
+
261
+ def test_sketch_shift(self):
262
+ a, b, c = random_points(3)
263
+ x = nm.sketch_shift([a, b, c])
264
+ self.assertTrue((b - a).close(x - c))
265
+
266
+ def test_sketch_square(self):
267
+ a, b = random_points(2)
268
+ c, d = nm.sketch_square([a, b])
269
+ self.assertTrue(Line(a, b).is_perp(Line(b, c)))
270
+ self.assertTrue(Line(b, c).is_perp(Line(c, d)))
271
+ self.assertTrue(Line(c, d).is_perp(Line(d, a)))
272
+ self.assertAlmostEqual(a.distance(b), b.distance(c))
273
+
274
+ def test_sketch_isquare(self):
275
+ a, b, c, d = nm.sketch_isquare([])
276
+ self.assertTrue(Line(a, b).is_perp(Line(b, c)))
277
+ self.assertTrue(Line(b, c).is_perp(Line(c, d)))
278
+ self.assertTrue(Line(c, d).is_perp(Line(d, a)))
279
+ self.assertAlmostEqual(a.distance(b), b.distance(c))
280
+
281
+ def test_sketch_trapezoid(self):
282
+ a, b, c, d = nm.sketch_trapezoid([])
283
+ self.assertTrue(Line(a, b).is_parallel(Line(c, d)))
284
+ self.assertTrue(Line(a, c).diff_side(b, d))
285
+ self.assertTrue(Line(b, d).diff_side(a, c))
286
+
287
+ def test_sketch_triangle(self):
288
+ a, b, c = nm.sketch_triangle([])
289
+ self.assertFalse(check_coll([a, b, c]))
290
+
291
+ def test_sketch_triangle12(self):
292
+ a, b, c = nm.sketch_triangle12([])
293
+ self.assertAlmostEqual(a.distance(b) * 2, a.distance(c))
294
+
295
+ def test_sketch_trisect(self):
296
+ a, b, c = random_points(3)
297
+ x, y = nm.sketch_trisect([a, b, c])
298
+ self.assertAlmostEqual(ang_between(b, a, x), ang_between(b, x, y))
299
+ self.assertAlmostEqual(ang_between(b, x, y), ang_between(b, y, c))
300
+ self.assertAlmostEqual(ang_between(b, a, x) * 3, ang_between(b, a, c))
301
+
302
+ def test_sketch_trisegment(self):
303
+ a, b = random_points(2)
304
+ x, y = nm.sketch_trisegment([a, b])
305
+ self.assertAlmostEqual(
306
+ a.distance(x) + x.distance(y) + y.distance(b), a.distance(b)
307
+ )
308
+ self.assertAlmostEqual(a.distance(x), x.distance(y))
309
+ self.assertAlmostEqual(x.distance(y), y.distance(b))
310
+
311
+
312
+ if __name__ == '__main__':
313
+ absltest.main()
external/alphageometry/pretty.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Utilities for string manipulation in the DSL."""
17
+
18
+ MAP_SYMBOL = {
19
+ 'T': 'perp',
20
+ 'P': 'para',
21
+ 'D': 'cong',
22
+ 'S': 'simtri',
23
+ 'I': 'circle',
24
+ 'M': 'midp',
25
+ 'O': 'cyclic',
26
+ 'C': 'coll',
27
+ '^': 'eqangle',
28
+ '/': 'eqratio',
29
+ '%': 'eqratio',
30
+ '=': 'contri',
31
+ 'X': 'collx',
32
+ 'A': 'acompute',
33
+ 'R': 'rcompute',
34
+ 'Q': 'fixc',
35
+ 'E': 'fixl',
36
+ 'V': 'fixb',
37
+ 'H': 'fixt',
38
+ 'Z': 'fixp',
39
+ 'Y': 'ind',
40
+ }
41
+
42
+
43
+ def map_symbol(c: str) -> str:
44
+ return MAP_SYMBOL[c]
45
+
46
+
47
+ def map_symbol_inv(c: str) -> str:
48
+ return {v: k for k, v in MAP_SYMBOL.items()}[c]
49
+
50
+
51
+ def _gcd(x: int, y: int) -> int:
52
+ while y:
53
+ x, y = y, x % y
54
+ return x
55
+
56
+
57
+ def simplify(n: int, d: int) -> tuple[int, int]:
58
+ g = _gcd(n, d)
59
+ return (n // g, d // g)
60
+
61
+
62
+ def pretty2r(a: str, b: str, c: str, d: str) -> str:
63
+ if b in (c, d):
64
+ a, b = b, a
65
+
66
+ if a == d:
67
+ c, d = d, c
68
+
69
+ return f'{a} {b} {c} {d}'
70
+
71
+
72
+ def pretty2a(a: str, b: str, c: str, d: str) -> str:
73
+ if b in (c, d):
74
+ a, b = b, a
75
+
76
+ if a == d:
77
+ c, d = d, c
78
+
79
+ return f'{a} {b} {c} {d}'
80
+
81
+
82
+ def pretty_angle(a: str, b: str, c: str, d: str) -> str:
83
+ if b in (c, d):
84
+ a, b = b, a
85
+ if a == d:
86
+ c, d = d, c
87
+
88
+ if a == c:
89
+ return f'\u2220{b}{a}{d}'
90
+ return f'\u2220({a}{b}-{c}{d})'
91
+
92
+
93
+ def pretty_nl(name: str, args: list[str]) -> str:
94
+ """Natural lang formatting a predicate."""
95
+ if name == 'aconst':
96
+ a, b, c, d, y = args
97
+ return f'{pretty_angle(a, b, c, d)} = {y}'
98
+ if name == 'rconst':
99
+ a, b, c, d, y = args
100
+ return f'{a}{b}:{c}{d} = {y}'
101
+ if name == 'acompute':
102
+ a, b, c, d = args
103
+ return f'{pretty_angle(a, b, c, d)}'
104
+ if name in ['coll', 'C']:
105
+ return '' + ','.join(args) + ' are collinear'
106
+ if name == 'collx':
107
+ return '' + ','.join(list(set(args))) + ' are collinear'
108
+ if name in ['cyclic', 'O']:
109
+ return '' + ','.join(args) + ' are concyclic'
110
+ if name in ['midp', 'midpoint', 'M']:
111
+ x, a, b = args
112
+ return f'{x} is midpoint of {a}{b}'
113
+ if name in ['eqangle', 'eqangle6', '^']:
114
+ a, b, c, d, e, f, g, h = args
115
+ return f'{pretty_angle(a, b, c, d)} = {pretty_angle(e, f, g, h)}'
116
+ if name in ['eqratio', 'eqratio6', '/']:
117
+ return '{}{}:{}{} = {}{}:{}{}'.format(*args)
118
+ if name == 'eqratio3':
119
+ a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
120
+ return f'S {o} {a} {b} {o} {c} {d}'
121
+ if name in ['cong', 'D']:
122
+ a, b, c, d = args
123
+ return f'{a}{b} = {c}{d}'
124
+ if name in ['perp', 'T']:
125
+ if len(args) == 2: # this is algebraic derivation.
126
+ ab, cd = args # ab = 'd( ... )'
127
+ return f'{ab} \u27c2 {cd}'
128
+ a, b, c, d = args
129
+ return f'{a}{b} \u27c2 {c}{d}'
130
+ if name in ['para', 'P']:
131
+ if len(args) == 2: # this is algebraic derivation.
132
+ ab, cd = args # ab = 'd( ... )'
133
+ return f'{ab} \u2225 {cd}'
134
+ a, b, c, d = args
135
+ return f'{a}{b} \u2225 {c}{d}'
136
+ if name in ['simtri2', 'simtri', 'simtri*']:
137
+ a, b, c, x, y, z = args
138
+ return f'\u0394{a}{b}{c} is similar to \u0394{x}{y}{z}'
139
+ if name in ['contri2', 'contri', 'contri*']:
140
+ a, b, c, x, y, z = args
141
+ return f'\u0394{a}{b}{c} is congruent to \u0394{x}{y}{z}'
142
+ if name in ['circle', 'I']:
143
+ o, a, b, c = args
144
+ return f'{o} is the circumcenter of \\Delta {a}{b}{c}'
145
+ if name == 'foot':
146
+ a, b, c, d = args
147
+ return f'{a} is the foot of {b} on {c}{d}'
148
+
149
+
150
+ def pretty(txt: str) -> str:
151
+ """Pretty formating a predicate string."""
152
+ if isinstance(txt, str):
153
+ txt = txt.split(' ')
154
+ name, *args = txt
155
+ if name == 'ind':
156
+ return 'Y ' + ' '.join(args)
157
+ if name in ['fixc', 'fixl', 'fixb', 'fixt', 'fixp']:
158
+ return map_symbol_inv(name) + ' ' + ' '.join(args)
159
+ if name == 'acompute':
160
+ a, b, c, d = args
161
+ return 'A ' + ' '.join(args)
162
+ if name == 'rcompute':
163
+ a, b, c, d = args
164
+ return 'R ' + ' '.join(args)
165
+ if name == 'aconst':
166
+ a, b, c, d, y = args
167
+ return f'^ {pretty2a(a, b, c, d)} {y}'
168
+ if name == 'rconst':
169
+ a, b, c, d, y = args
170
+ return f'/ {pretty2r(a, b, c, d)} {y}'
171
+ if name == 'coll':
172
+ return 'C ' + ' '.join(args)
173
+ if name == 'collx':
174
+ return 'X ' + ' '.join(args)
175
+ if name == 'cyclic':
176
+ return 'O ' + ' '.join(args)
177
+ if name in ['midp', 'midpoint']:
178
+ x, a, b = args
179
+ return f'M {x} {a} {b}'
180
+ if name == 'eqangle':
181
+ a, b, c, d, e, f, g, h = args
182
+ return f'^ {pretty2a(a, b, c, d)} {pretty2a(e, f, g, h)}'
183
+ if name == 'eqratio':
184
+ a, b, c, d, e, f, g, h = args
185
+ return f'/ {pretty2r(a, b, c, d)} {pretty2r(e, f, g, h)}'
186
+ if name == 'eqratio3':
187
+ a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
188
+ return f'S {o} {a} {b} {o} {c} {d}'
189
+ if name == 'cong':
190
+ a, b, c, d = args
191
+ return f'D {a} {b} {c} {d}'
192
+ if name == 'perp':
193
+ if len(args) == 2: # this is algebraic derivation.
194
+ ab, cd = args # ab = 'd( ... )'
195
+ return f'T {ab} {cd}'
196
+ a, b, c, d = args
197
+ return f'T {a} {b} {c} {d}'
198
+ if name == 'para':
199
+ if len(args) == 2: # this is algebraic derivation.
200
+ ab, cd = args # ab = 'd( ... )'
201
+ return f'P {ab} {cd}'
202
+ a, b, c, d = args
203
+ return f'P {a} {b} {c} {d}'
204
+ if name in ['simtri2', 'simtri', 'simtri*']:
205
+ a, b, c, x, y, z = args
206
+ return f'S {a} {b} {c} {x} {y} {z}'
207
+ if name in ['contri2', 'contri', 'contri*']:
208
+ a, b, c, x, y, z = args
209
+ return f'= {a} {b} {c} {x} {y} {z}'
210
+ if name == 'circle':
211
+ o, a, b, c = args
212
+ return f'I {o} {a} {b} {c}'
213
+ if name == 'foot':
214
+ a, b, c, d = args
215
+ return f'F {a} {b} {c} {d}'
216
+ return ' '.join(txt)
external/alphageometry/problem.py ADDED
@@ -0,0 +1,1133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Implements objects to represent problems, theorems, proofs, traceback."""
17
+
18
+ from __future__ import annotations
19
+
20
+ from collections import defaultdict # pylint: disable=g-importing-member
21
+ from typing import Any
22
+
23
+ import geometry as gm
24
+ import pretty as pt
25
+
26
+
27
+ # pylint: disable=protected-access
28
+ # pylint: disable=unused-variable
29
+ # pylint: disable=unused-argument
30
+ # pylint: disable=unused-assignment
31
+
32
+
33
+ def reshape(l: list[Any], n: int = 1) -> list[list[Any]]:
34
+ assert len(l) % n == 0
35
+ columns = [[] for i in range(n)]
36
+ for i, x in enumerate(l):
37
+ columns[i % n].append(x)
38
+ return zip(*columns)
39
+
40
+
41
+ def isint(x: str) -> bool:
42
+ try:
43
+ int(x)
44
+ return True
45
+ except: # pylint: disable=bare-except
46
+ return False
47
+
48
+
49
+ class Construction:
50
+ """One predicate."""
51
+
52
+ @classmethod
53
+ def from_txt(cls, data: str) -> Construction:
54
+ data = data.split(' ')
55
+ return Construction(data[0], data[1:])
56
+
57
+ def __init__(self, name: str, args: list[str]):
58
+ self.name = name
59
+ self.args = args
60
+
61
+ def translate(self, mapping: dict[str, str]) -> Construction:
62
+ args = [a if isint(a) else mapping[a] for a in self.args]
63
+ return Construction(self.name, args)
64
+
65
+ def txt(self) -> str:
66
+ return ' '.join([self.name] + list(self.args))
67
+
68
+
69
+ class Clause:
70
+ """One construction (>= 1 predicate)."""
71
+
72
+ @classmethod
73
+ def from_txt(cls, data: str) -> Clause:
74
+ if data == ' =':
75
+ return Clause([], [])
76
+ points, constructions = data.split(' = ')
77
+ return Clause(
78
+ points.split(' '),
79
+ [Construction.from_txt(c) for c in constructions.split(', ')],
80
+ )
81
+
82
+ def __init__(self, points: list[str], constructions: list[Construction]):
83
+ self.points = []
84
+ self.nums = []
85
+
86
+ for p in points:
87
+ num = None
88
+ if isinstance(p, str) and '@' in p:
89
+ p, num = p.split('@')
90
+ x, y = num.split('_')
91
+ num = float(x), float(y)
92
+ self.points.append(p)
93
+ self.nums.append(num)
94
+
95
+ self.constructions = constructions
96
+
97
+ def translate(self, mapping: dict[str, str]) -> Clause:
98
+ points0 = []
99
+ for p in self.points:
100
+ pcount = len(mapping) + 1
101
+ name = chr(96 + pcount)
102
+ if name > 'z': # pcount = 26 -> name = 'z'
103
+ name = chr(97 + (pcount - 1) % 26) + str((pcount - 1) // 26)
104
+
105
+ p0 = mapping.get(p, name)
106
+ mapping[p] = p0
107
+ points0.append(p0)
108
+ return Clause(points0, [c.translate(mapping) for c in self.constructions])
109
+
110
+ def add(self, name: str, args: list[str]) -> None:
111
+ self.constructions.append(Construction(name, args))
112
+
113
+ def txt(self) -> str:
114
+ return (
115
+ ' '.join(self.points)
116
+ + ' = '
117
+ + ', '.join(c.txt() for c in self.constructions)
118
+ )
119
+
120
+
121
+ def _gcd(x: int, y: int) -> int:
122
+ while y:
123
+ x, y = y, x % y
124
+ return x
125
+
126
+
127
+ def simplify(n: int, d: int) -> tuple[int, int]:
128
+ g = _gcd(n, d)
129
+ return (n // g, d // g)
130
+
131
+
132
+ def compare_fn(dep: Dependency) -> tuple[Dependency, str]:
133
+ return (dep, pt.pretty(dep))
134
+
135
+
136
+ def sort_deps(deps: list[Dependency]) -> list[Dependency]:
137
+ return sorted(deps, key=compare_fn)
138
+
139
+
140
+ class Problem:
141
+ """Describe one problem to solve."""
142
+
143
+ @classmethod
144
+ def from_txt_file(
145
+ cls, fname: str, to_dict: bool = False, translate: bool = True
146
+ ):
147
+ """Load a problem from a text file."""
148
+ with open(fname, 'r') as f:
149
+ lines = f.read().split('\n')
150
+
151
+ lines = [l for l in lines if l]
152
+ data = [
153
+ cls.from_txt(url + '\n' + problem, translate)
154
+ for (url, problem) in reshape(lines, 2)
155
+ ]
156
+ if to_dict:
157
+ return cls.to_dict(data)
158
+ return data
159
+
160
+ @classmethod
161
+ def from_txt(cls, data: str, translate: bool = True) -> Problem:
162
+ """Load a problem from a str object."""
163
+ url = ''
164
+ if '\n' in data:
165
+ url, data = data.split('\n')
166
+
167
+ if ' ? ' in data:
168
+ clauses, goal = data.split(' ? ')
169
+ goal = Construction.from_txt(goal)
170
+ else:
171
+ clauses, goal = data, None
172
+
173
+ clauses = clauses.split('; ')
174
+ problem = Problem(
175
+ url=url, clauses=[Clause.from_txt(c) for c in clauses], goal=goal
176
+ )
177
+ if translate:
178
+ return problem.translate()
179
+ return problem
180
+
181
+ @classmethod
182
+ def to_dict(cls, data: list[Problem]) -> dict[str, Problem]:
183
+ return {p.url: p for p in data}
184
+
185
+ def __init__(self, url: str, clauses: list[Clause], goal: Construction):
186
+ self.url = url
187
+ self.clauses = clauses
188
+ self.goal = goal
189
+
190
+ def copy(self) -> Problem:
191
+ return Problem(self.url, list(self.clauses), self.goal)
192
+
193
+ def translate(self) -> Problem: # to single-char point names
194
+ """Translate point names into alphabetical."""
195
+ mapping = {}
196
+ clauses = []
197
+
198
+ for clause in self.clauses:
199
+ clauses.append(clause.translate(mapping))
200
+
201
+ if self.goal:
202
+ goal = self.goal.translate(mapping)
203
+ else:
204
+ goal = self.goal
205
+
206
+ p = Problem(self.url, clauses, goal)
207
+ p.mapping = mapping
208
+ return p
209
+
210
+ def txt(self) -> str:
211
+ return (
212
+ '; '.join([c.txt() for c in self.clauses]) + ' ? ' + self.goal.txt()
213
+ if self.goal
214
+ else ''
215
+ )
216
+
217
+ def setup_str_from_problem(self, definitions: list[Definition]) -> str:
218
+ """Construct the <theorem_premises> string from Problem object."""
219
+ ref = 0
220
+
221
+ string = []
222
+ for clause in self.clauses:
223
+ group = {}
224
+ p2deps = defaultdict(list)
225
+ for c in clause.constructions:
226
+ cdef = definitions[c.name]
227
+
228
+ if len(c.args) != len(cdef.construction.args):
229
+ assert len(c.args) + len(clause.points) == len(cdef.construction.args)
230
+ c.args = clause.points + c.args
231
+
232
+ mapping = dict(zip(cdef.construction.args, c.args))
233
+ for points, bs in cdef.basics:
234
+ points = tuple([mapping[x] for x in points])
235
+ for p in points:
236
+ group[p] = points
237
+
238
+ for b in bs:
239
+ args = [mapping[a] for a in b.args]
240
+ name = b.name
241
+ if b.name in ['s_angle', 'aconst']:
242
+ x, y, z, v = args
243
+ name = 'aconst'
244
+ v = int(v)
245
+
246
+ if v < 0:
247
+ v = -v
248
+ x, z = z, x
249
+
250
+ m, n = simplify(int(v), 180)
251
+ args = [y, z, y, x, f'{m}pi/{n}']
252
+
253
+ p2deps[points].append(hashed_txt(name, args))
254
+
255
+ for k, v in p2deps.items():
256
+ p2deps[k] = sort_deps(v)
257
+
258
+ points = clause.points
259
+ while points:
260
+ p = points[0]
261
+ gr = group[p]
262
+ points = [x for x in points if x not in gr]
263
+
264
+ deps_str = []
265
+ for dep in p2deps[gr]:
266
+ ref_str = '{:02}'.format(ref)
267
+ dep_str = pt.pretty(dep)
268
+
269
+ if dep[0] == 'aconst':
270
+ m, n = map(int, dep[-1].split('pi/'))
271
+ mn = f'{m}. pi / {n}.'
272
+ dep_str = ' '.join(dep_str.split()[:-1] + [mn])
273
+
274
+ deps_str.append(dep_str + ' ' + ref_str)
275
+ ref += 1
276
+
277
+ string.append(' '.join(gr) + ' : ' + ' '.join(deps_str))
278
+
279
+ string = '{S} ' + ' ; '.join([s.strip() for s in string])
280
+ goal = self.goal
281
+ string += ' ? ' + pt.pretty([goal.name] + goal.args)
282
+ return string
283
+
284
+
285
+ def parse_rely(s: str) -> dict[str, str]:
286
+ result = {}
287
+ if not s:
288
+ return result
289
+ s = [x.strip() for x in s.split(',')]
290
+ for x in s:
291
+ a, b = x.split(':')
292
+ a, b = a.strip().split(), b.strip().split()
293
+ result.update({m: b for m in a})
294
+ return result
295
+
296
+
297
+ class Definition:
298
+ """Definitions of construction statements."""
299
+
300
+ @classmethod
301
+ def from_txt_file(cls, fname: str, to_dict: bool = False) -> Definition:
302
+ with open(fname, 'r') as f:
303
+ lines = f.read()
304
+ return cls.from_string(lines, to_dict)
305
+
306
+ @classmethod
307
+ def from_string(cls, string: str, to_dict: bool = False) -> Definition:
308
+ lines = string.split('\n')
309
+ data = [cls.from_txt('\n'.join(group)) for group in reshape(lines, 6)]
310
+ if to_dict:
311
+ return cls.to_dict(data)
312
+ return data
313
+
314
+ @classmethod
315
+ def to_dict(cls, data: list[Definition]) -> dict[str, Definition]:
316
+ return {d.construction.name: d for d in data}
317
+
318
+ @classmethod
319
+ def from_txt(cls, data: str) -> Definition:
320
+ """Load definitions from a str object."""
321
+ construction, rely, deps, basics, numerics, _ = data.split('\n')
322
+ basics = [] if not basics else [b.strip() for b in basics.split(';')]
323
+
324
+ levels = []
325
+ for bs in basics:
326
+ if ':' in bs:
327
+ points, bs = bs.split(':')
328
+ points = points.strip().split()
329
+ else:
330
+ points = []
331
+ if bs.strip():
332
+ bs = [Construction.from_txt(b.strip()) for b in bs.strip().split(',')]
333
+ else:
334
+ bs = []
335
+ levels.append((points, bs))
336
+
337
+ numerics = [] if not numerics else numerics.split(', ')
338
+
339
+ return Definition(
340
+ construction=Construction.from_txt(construction),
341
+ rely=parse_rely(rely),
342
+ deps=Clause.from_txt(deps),
343
+ basics=levels,
344
+ numerics=[Construction.from_txt(c) for c in numerics],
345
+ )
346
+
347
+ def __init__(
348
+ self,
349
+ construction: Construction,
350
+ rely: dict[str, str],
351
+ deps: Clause,
352
+ basics: list[tuple[list[str], list[Construction]]],
353
+ numerics: list[Construction],
354
+ ):
355
+ self.construction = construction
356
+ self.rely = rely
357
+ self.deps = deps
358
+ self.basics = basics
359
+ self.numerics = numerics
360
+
361
+ args = set()
362
+ for num in numerics:
363
+ args.update(num.args)
364
+
365
+ self.points = []
366
+ self.args = []
367
+ for p in self.construction.args:
368
+ if p in args:
369
+ self.args.append(p)
370
+ else:
371
+ self.points.append(p)
372
+
373
+
374
+ class Theorem:
375
+ """Deduction rule."""
376
+
377
+ @classmethod
378
+ def from_txt_file(cls, fname: str, to_dict: bool = False) -> Theorem:
379
+ with open(fname, 'r') as f:
380
+ theorems = f.read()
381
+ return cls.from_string(theorems, to_dict)
382
+
383
+ @classmethod
384
+ def from_string(cls, string: str, to_dict: bool = False) -> Theorem:
385
+ """Load deduction rule from a str object."""
386
+ theorems = string.split('\n')
387
+ theorems = [l for l in theorems if l and not l.startswith('#')]
388
+ theorems = [cls.from_txt(l) for l in theorems]
389
+
390
+ for i, th in enumerate(theorems):
391
+ th.rule_name = 'r{:02}'.format(i)
392
+
393
+ if to_dict:
394
+ result = {}
395
+ for t in theorems:
396
+ if t.name in result:
397
+ t.name += '_'
398
+ result[t.rule_name] = t
399
+
400
+ return result
401
+
402
+ return theorems
403
+
404
+ @classmethod
405
+ def from_txt(cls, data: str) -> Theorem:
406
+ premises, conclusion = data.split(' => ')
407
+ premises = premises.split(', ')
408
+ conclusion = conclusion.split(', ')
409
+ return Theorem(
410
+ premise=[Construction.from_txt(p) for p in premises],
411
+ conclusion=[Construction.from_txt(c) for c in conclusion],
412
+ )
413
+
414
+ def __init__(
415
+ self, premise: list[Construction], conclusion: list[Construction]
416
+ ):
417
+ if len(conclusion) != 1:
418
+ raise ValueError('Cannot have more than one conclusion')
419
+ self.name = '_'.join([p.name for p in premise + conclusion])
420
+ self.premise = premise
421
+ self.conclusion = conclusion
422
+ self.is_arg_reduce = False
423
+
424
+ assert len(self.conclusion) == 1
425
+ con = self.conclusion[0]
426
+
427
+ if con.name in [
428
+ 'eqratio3',
429
+ 'midp',
430
+ 'contri',
431
+ 'simtri',
432
+ 'contri2',
433
+ 'simtri2',
434
+ 'simtri*',
435
+ 'contri*',
436
+ ]:
437
+ return
438
+
439
+ prem_args = set(sum([p.args for p in self.premise], []))
440
+ con_args = set(con.args)
441
+ if len(prem_args) <= len(con_args):
442
+ self.is_arg_reduce = True
443
+
444
+ def txt(self) -> str:
445
+ premise_txt = ', '.join([clause.txt() for clause in self.premise])
446
+ conclusion_txt = ', '.join([clause.txt() for clause in self.conclusion])
447
+ return f'{premise_txt} => {conclusion_txt}'
448
+
449
+ def conclusion_name_args(
450
+ self, mapping: dict[str, gm.Point]
451
+ ) -> tuple[str, list[gm.Point]]:
452
+ mapping = {arg: p for arg, p in mapping.items() if isinstance(arg, str)}
453
+ c = self.conclusion[0]
454
+ args = [mapping[a] for a in c.args]
455
+ return c.name, args
456
+
457
+
458
+ def why_eqratio(
459
+ d1: gm.Direction,
460
+ d2: gm.Direction,
461
+ d3: gm.Direction,
462
+ d4: gm.Direction,
463
+ level: int,
464
+ ) -> list[Dependency]:
465
+ """Why two ratios are equal, returns a Dependency objects."""
466
+ all12 = list(gm.all_ratios(d1, d2, level))
467
+ all34 = list(gm.all_ratios(d3, d4, level))
468
+
469
+ min_why = None
470
+ for ang12, d1s, d2s in all12:
471
+ for ang34, d3s, d4s in all34:
472
+ why0 = gm.why_equal(ang12, ang34, level)
473
+ if why0 is None:
474
+ continue
475
+ d1_, d2_ = ang12._l
476
+ d3_, d4_ = ang34._l
477
+ why1 = gm.bfs_backtrack(d1, [d1_], d1s)
478
+ why2 = gm.bfs_backtrack(d2, [d2_], d2s)
479
+ why3 = gm.bfs_backtrack(d3, [d3_], d3s)
480
+ why4 = gm.bfs_backtrack(d4, [d4_], d4s)
481
+ why = why0 + why1 + why2 + why3 + why4
482
+ if min_why is None or len(why) < len(min_why[0]):
483
+ min_why = why, ang12, ang34, why0, why1, why2, why3, why4
484
+
485
+ if min_why is None:
486
+ return None
487
+
488
+ _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
489
+ d1_, d2_ = ang12._l
490
+ d3_, d4_ = ang34._l
491
+
492
+ if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
493
+ return why0
494
+
495
+ (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
496
+ (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
497
+ deps = []
498
+ if why0:
499
+ dep = Dependency('eqratio', [a_, b_, c_, d_, e_, f_, g_, h_], '', level)
500
+ dep.why = why0
501
+ deps.append(dep)
502
+
503
+ (a, b), (c, d) = d1._obj.points, d2._obj.points
504
+ (e, f), (g, h) = d3._obj.points, d4._obj.points
505
+ for why, (x, y), (x_, y_) in zip(
506
+ [why1, why2, why3, why4],
507
+ [(a, b), (c, d), (e, f), (g, h)],
508
+ [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
509
+ ):
510
+ if why:
511
+ dep = Dependency('cong', [x, y, x_, y_], '', level)
512
+ dep.why = why
513
+ deps.append(dep)
514
+
515
+ return deps
516
+
517
+
518
+ def why_eqangle(
519
+ d1: gm.Direction,
520
+ d2: gm.Direction,
521
+ d3: gm.Direction,
522
+ d4: gm.Direction,
523
+ level: int,
524
+ verbose: bool = False,
525
+ ) -> list[Dependency]:
526
+ """Why two angles are equal, returns a Dependency objects."""
527
+ all12 = list(gm.all_angles(d1, d2, level))
528
+ all34 = list(gm.all_angles(d3, d4, level))
529
+
530
+ min_why = None
531
+ for ang12, d1s, d2s in all12:
532
+ for ang34, d3s, d4s in all34:
533
+ why0 = gm.why_equal(ang12, ang34, level)
534
+ if why0 is None:
535
+ continue
536
+ d1_, d2_ = ang12._d
537
+ d3_, d4_ = ang34._d
538
+ why1 = gm.bfs_backtrack(d1, [d1_], d1s)
539
+ why2 = gm.bfs_backtrack(d2, [d2_], d2s)
540
+ why3 = gm.bfs_backtrack(d3, [d3_], d3s)
541
+ why4 = gm.bfs_backtrack(d4, [d4_], d4s)
542
+ why = why0 + why1 + why2 + why3 + why4
543
+ if min_why is None or len(why) < len(min_why[0]):
544
+ min_why = why, ang12, ang34, why0, why1, why2, why3, why4
545
+
546
+ if min_why is None:
547
+ return None
548
+
549
+ _, ang12, ang34, why0, why1, why2, why3, why4 = min_why
550
+ why0 = gm.why_equal(ang12, ang34, level)
551
+ d1_, d2_ = ang12._d
552
+ d3_, d4_ = ang34._d
553
+
554
+ if d1 == d1_ and d2 == d2_ and d3 == d3_ and d4 == d4_:
555
+ return (d1_, d2_, d3_, d4_), why0
556
+
557
+ (a_, b_), (c_, d_) = d1_._obj.points, d2_._obj.points
558
+ (e_, f_), (g_, h_) = d3_._obj.points, d4_._obj.points
559
+ deps = []
560
+ if why0:
561
+ dep = Dependency('eqangle', [a_, b_, c_, d_, e_, f_, g_, h_], '', None)
562
+ dep.why = why0
563
+ deps.append(dep)
564
+
565
+ (a, b), (c, d) = d1._obj.points, d2._obj.points
566
+ (e, f), (g, h) = d3._obj.points, d4._obj.points
567
+ for why, d_xy, (x, y), d_xy_, (x_, y_) in zip(
568
+ [why1, why2, why3, why4],
569
+ [d1, d2, d3, d4],
570
+ [(a, b), (c, d), (e, f), (g, h)],
571
+ [d1_, d2_, d3_, d4_],
572
+ [(a_, b_), (c_, d_), (e_, f_), (g_, h_)],
573
+ ):
574
+ xy, xy_ = d_xy._obj, d_xy_._obj
575
+ if why:
576
+ if xy == xy_:
577
+ name = 'collx'
578
+ else:
579
+ name = 'para'
580
+ dep = Dependency(name, [x_, y_, x, y], '', None)
581
+ dep.why = why
582
+ deps.append(dep)
583
+
584
+ return (d1_, d2_, d3_, d4_), deps
585
+
586
+
587
+ CONSTRUCTION_RULE = 'c0'
588
+
589
+
590
+ class EmptyDependency:
591
+ """Empty dependency predicate ready to get filled up."""
592
+
593
+ def __init__(self, level: int, rule_name: str):
594
+ self.level = level
595
+ self.rule_name = rule_name or ''
596
+ self.empty = True
597
+ self.why = []
598
+ self.trace = None
599
+
600
+ def populate(self, name: str, args: list[gm.Point]) -> Dependency:
601
+ dep = Dependency(name, args, self.rule_name, self.level)
602
+ dep.trace2 = self.trace
603
+ dep.why = list(self.why)
604
+ return dep
605
+
606
+ def copy(self) -> EmptyDependency:
607
+ other = EmptyDependency(self.level, self.rule_name)
608
+ other.why = list(self.why)
609
+ return other
610
+
611
+ def extend(
612
+ self,
613
+ g: Any,
614
+ name0: str,
615
+ args0: list[gm.Point],
616
+ name: str,
617
+ args: list[gm.Point],
618
+ ) -> EmptyDependency:
619
+ """Extend the dependency list by (name, args)."""
620
+ dep0 = self.populate(name0, args0)
621
+ deps = EmptyDependency(level=self.level, rule_name=None)
622
+ dep = Dependency(name, args, None, deps.level)
623
+ deps.why = [dep0, dep.why_me_or_cache(g, None)]
624
+ return deps
625
+
626
+ def extend_many(
627
+ self,
628
+ g: Any,
629
+ name0: str,
630
+ args0: list[gm.Point],
631
+ name_args: list[tuple[str, list[gm.Point]]],
632
+ ) -> EmptyDependency:
633
+ """Extend the dependency list by many name_args."""
634
+ if not name_args:
635
+ return self
636
+ dep0 = self.populate(name0, args0)
637
+ deps = EmptyDependency(level=self.level, rule_name=None)
638
+ deps.why = [dep0]
639
+ for name, args in name_args:
640
+ dep = Dependency(name, args, None, deps.level)
641
+ deps.why += [dep.why_me_or_cache(g, None)]
642
+ return deps
643
+
644
+
645
+ def maybe_make_equal_pairs(
646
+ a: gm.Point,
647
+ b: gm.Point,
648
+ c: gm.Point,
649
+ d: gm.Point,
650
+ m: gm.Point,
651
+ n: gm.Point,
652
+ p: gm.Point,
653
+ q: gm.Point,
654
+ ab: gm.Line,
655
+ mn: gm.Line,
656
+ g: Any,
657
+ level: int,
658
+ ) -> list[Dependency]:
659
+ """Make a-b:c-d==m-n:p-q in case a-b==m-n or c-d==p-q."""
660
+ if ab != mn:
661
+ return
662
+ why = []
663
+ eqname = 'para' if isinstance(ab, gm.Line) else 'cong'
664
+ colls = [a, b, m, n]
665
+ if len(set(colls)) > 2 and eqname == 'para':
666
+ dep = Dependency('collx', colls, None, level)
667
+ dep.why_me(g, level)
668
+ why += [dep]
669
+
670
+ dep = Dependency(eqname, [c, d, p, q], None, level)
671
+ dep.why_me(g, level)
672
+ why += [dep]
673
+ return why
674
+
675
+
676
+ class Dependency(Construction):
677
+ """Dependency is a predicate that other predicates depend on."""
678
+
679
+ def __init__(
680
+ self, name: str, args: list[gm.Point], rule_name: str, level: int
681
+ ):
682
+ super().__init__(name, args)
683
+ self.rule_name = rule_name or ''
684
+ self.level = level
685
+ self.why = []
686
+
687
+ self._stat = None
688
+ self.trace = None
689
+
690
+ def _find(self, dep_hashed: tuple[str, ...]) -> Dependency:
691
+ for w in self.why:
692
+ f = w._find(dep_hashed)
693
+ if f:
694
+ return f
695
+ if w.hashed() == dep_hashed:
696
+ return w
697
+
698
+ def remove_loop(self) -> Dependency:
699
+ f = self._find(self.hashed())
700
+ if f:
701
+ return f
702
+ return self
703
+
704
+ def copy(self) -> Dependency:
705
+ dep = Dependency(self.name, self.args, self.rule_name, self.level)
706
+ dep.trace = self.trace
707
+ dep.why = list(self.why)
708
+ return dep
709
+
710
+ def why_me_or_cache(self, g: Any, level: int) -> Dependency:
711
+ if self.hashed() in g.cache:
712
+ return g.cache[self.hashed()]
713
+ self.why_me(g, level)
714
+ return self
715
+
716
+ def populate(self, name: str, args: list[gm.Point]) -> Dependency:
717
+ assert self.rule_name == CONSTRUCTION_RULE, self.rule_name
718
+ dep = Dependency(self.name, self.args, self.rule_name, self.level)
719
+ dep.why = list(self.why)
720
+ return dep
721
+
722
+ def why_me(self, g: Any, level: int) -> None:
723
+ """Figure out the dependencies predicates of self."""
724
+ name, args = self.name, self.args
725
+
726
+ hashed_me = hashed(name, args)
727
+ if hashed_me in g.cache:
728
+ dep = g.cache[hashed_me]
729
+ self.why = dep.why
730
+ self.rule_name = dep.rule_name
731
+ return
732
+
733
+ if self.name == 'para':
734
+ a, b, c, d = self.args
735
+ if {a, b} == {c, d}:
736
+ self.why = []
737
+ return
738
+
739
+ ab = g._get_line(a, b)
740
+ cd = g._get_line(c, d)
741
+ if ab == cd:
742
+ if {a, b} == {c, d}:
743
+ self.why = []
744
+ self.rule_name = ''
745
+ return
746
+ dep = Dependency('coll', list({a, b, c, d}), 't??', None)
747
+ self.why = [dep.why_me_or_cache(g, level)]
748
+ return
749
+
750
+ for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
751
+ x_, y_ = xy.points
752
+ if {x, y} == {x_, y_}:
753
+ continue
754
+ d = Dependency('collx', [x, y, x_, y_], None, level)
755
+ self.why += [d.why_me_or_cache(g, level)]
756
+
757
+ whypara = g.why_equal(ab, cd, None)
758
+ self.why += whypara
759
+
760
+ elif self.name == 'midp':
761
+ m, a, b = self.args
762
+ ma = g._get_segment(m, a)
763
+ mb = g._get_segment(m, b)
764
+ dep = Dependency('coll', [m, a, b], None, None).why_me_or_cache(g, None)
765
+ self.why = [dep] + g.why_equal(ma, mb, level)
766
+
767
+ elif self.name == 'perp':
768
+ a, b, c, d = self.args
769
+ ab = g._get_line(a, b)
770
+ cd = g._get_line(c, d)
771
+ for (x, y), xy in zip([(a, b), (c, d)], [ab, cd]):
772
+ x_, y_ = xy.points
773
+ if {x, y} == {x_, y_}:
774
+ continue
775
+ d = Dependency('collx', [x, y, x_, y_], None, level)
776
+ self.why += [d.why_me_or_cache(g, level)]
777
+
778
+ _, why = why_eqangle(ab._val, cd._val, cd._val, ab._val, level)
779
+ a, b = ab.points
780
+ c, d = cd.points
781
+
782
+ if hashed(self.name, [a, b, c, d]) != self.hashed():
783
+ d = Dependency(self.name, [a, b, c, d], None, level)
784
+ d.why = why
785
+ why = [d]
786
+
787
+ self.why += why
788
+
789
+ elif self.name == 'cong':
790
+ a, b, c, d = self.args
791
+ ab = g._get_segment(a, b)
792
+ cd = g._get_segment(c, d)
793
+
794
+ self.why = g.why_equal(ab, cd, level)
795
+
796
+ elif self.name == 'coll':
797
+ _, why = gm.line_of_and_why(self.args, level)
798
+ self.why = why
799
+
800
+ elif self.name == 'collx':
801
+ if g.check_coll(self.args):
802
+ args = list(set(self.args))
803
+ hashed_me = hashed('coll', args)
804
+ if hashed_me in g.cache:
805
+ dep = g.cache[hashed_me]
806
+ self.why = [dep]
807
+ self.rule_name = ''
808
+ return
809
+ _, self.why = gm.line_of_and_why(args, level)
810
+ else:
811
+ self.name = 'para'
812
+ self.why_me(g, level)
813
+
814
+ elif self.name == 'cyclic':
815
+ _, why = gm.circle_of_and_why(self.args, level)
816
+ self.why = why
817
+
818
+ elif self.name == 'circle':
819
+ o, a, b, c = self.args
820
+ oa = g._get_segment(o, a)
821
+ ob = g._get_segment(o, b)
822
+ oc = g._get_segment(o, c)
823
+ self.why = g.why_equal(oa, ob, level) + g.why_equal(oa, oc, level)
824
+
825
+ elif self.name in ['eqangle', 'eqangle6']:
826
+ a, b, c, d, m, n, p, q = self.args
827
+
828
+ ab, why1 = g.get_line_thru_pair_why(a, b)
829
+ cd, why2 = g.get_line_thru_pair_why(c, d)
830
+ mn, why3 = g.get_line_thru_pair_why(m, n)
831
+ pq, why4 = g.get_line_thru_pair_why(p, q)
832
+
833
+ if ab is None or cd is None or mn is None or pq is None:
834
+ if {a, b} == {m, n}:
835
+ d = Dependency('para', [c, d, p, q], None, level)
836
+ self.why = [d.why_me_or_cache(g, level)]
837
+ if {a, b} == {c, d}:
838
+ d = Dependency('para', [p, q, m, n], None, level)
839
+ self.why = [d.why_me_or_cache(g, level)]
840
+ if {c, d} == {p, q}:
841
+ d = Dependency('para', [a, b, m, n], None, level)
842
+ self.why = [d.why_me_or_cache(g, level)]
843
+ if {p, q} == {m, n}:
844
+ d = Dependency('para', [a, b, c, d], None, level)
845
+ self.why = [d.why_me_or_cache(g, level)]
846
+ return
847
+
848
+ for (x, y), xy, whyxy in zip(
849
+ [(a, b), (c, d), (m, n), (p, q)],
850
+ [ab, cd, mn, pq],
851
+ [why1, why2, why3, why4],
852
+ ):
853
+ x_, y_ = xy.points
854
+ if {x, y} == {x_, y_}:
855
+ continue
856
+ d = Dependency('collx', [x, y, x_, y_], None, level)
857
+ d.why = whyxy
858
+ self.why += [d]
859
+
860
+ a, b = ab.points
861
+ c, d = cd.points
862
+ m, n = mn.points
863
+ p, q = pq.points
864
+ diff = hashed(self.name, [a, b, c, d, m, n, p, q]) != self.hashed()
865
+
866
+ whyeqangle = None
867
+ if ab._val and cd._val and mn._val and pq._val:
868
+ whyeqangle = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
869
+
870
+ if whyeqangle:
871
+ (dab, dcd, dmn, dpq), whyeqangle = whyeqangle
872
+ if diff:
873
+ d = Dependency('eqangle', [a, b, c, d, m, n, p, q], None, level)
874
+ d.why = whyeqangle
875
+ whyeqangle = [d]
876
+ self.why += whyeqangle
877
+
878
+ else:
879
+ if (ab == cd and mn == pq) or (ab == mn and cd == pq):
880
+ self.why += []
881
+ elif ab == mn:
882
+ self.why += maybe_make_equal_pairs(
883
+ a, b, c, d, m, n, p, q, ab, mn, g, level
884
+ )
885
+ elif cd == pq:
886
+ self.why += maybe_make_equal_pairs(
887
+ c, d, a, b, p, q, m, n, cd, pq, g, level
888
+ )
889
+ elif ab == cd:
890
+ self.why += maybe_make_equal_pairs(
891
+ a, b, m, n, c, d, p, q, ab, cd, g, level
892
+ )
893
+ elif mn == pq:
894
+ self.why += maybe_make_equal_pairs(
895
+ m, n, a, b, p, q, c, d, mn, pq, g, level
896
+ )
897
+ elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
898
+ dep1 = Dependency('para', [a, b, m, n], None, level)
899
+ dep1.why_me(g, level)
900
+ dep2 = Dependency('para', [c, d, p, q], None, level)
901
+ dep2.why_me(g, level)
902
+ self.why += [dep1, dep2]
903
+ elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
904
+ dep1 = Dependency('para', [a, b, c, d], None, level)
905
+ dep1.why_me(g, level)
906
+ dep2 = Dependency('para', [m, n, p, q], None, level)
907
+ dep2.why_me(g, level)
908
+ self.why += [dep1, dep2]
909
+ elif ab._val and cd._val and mn._val and pq._val:
910
+ self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
911
+
912
+ elif self.name in ['eqratio', 'eqratio6']:
913
+ a, b, c, d, m, n, p, q = self.args
914
+ ab = g._get_segment(a, b)
915
+ cd = g._get_segment(c, d)
916
+ mn = g._get_segment(m, n)
917
+ pq = g._get_segment(p, q)
918
+
919
+ if ab is None or cd is None or mn is None or pq is None:
920
+ if {a, b} == {m, n}:
921
+ d = Dependency('cong', [c, d, p, q], None, level)
922
+ self.why = [d.why_me_or_cache(g, level)]
923
+ if {a, b} == {c, d}:
924
+ d = Dependency('cong', [p, q, m, n], None, level)
925
+ self.why = [d.why_me_or_cache(g, level)]
926
+ if {c, d} == {p, q}:
927
+ d = Dependency('cong', [a, b, m, n], None, level)
928
+ self.why = [d.why_me_or_cache(g, level)]
929
+ if {p, q} == {m, n}:
930
+ d = Dependency('cong', [a, b, c, d], None, level)
931
+ self.why = [d.why_me_or_cache(g, level)]
932
+ return
933
+
934
+ if ab._val and cd._val and mn._val and pq._val:
935
+ self.why = why_eqratio(ab._val, cd._val, mn._val, pq._val, level)
936
+
937
+ if self.why is None:
938
+ self.why = []
939
+ if (ab == cd and mn == pq) or (ab == mn and cd == pq):
940
+ self.why = []
941
+ elif ab == mn:
942
+ self.why += maybe_make_equal_pairs(
943
+ a, b, c, d, m, n, p, q, ab, mn, g, level
944
+ )
945
+ elif cd == pq:
946
+ self.why += maybe_make_equal_pairs(
947
+ c, d, a, b, p, q, m, n, cd, pq, g, level
948
+ )
949
+ elif ab == cd:
950
+ self.why += maybe_make_equal_pairs(
951
+ a, b, m, n, c, d, p, q, ab, cd, g, level
952
+ )
953
+ elif mn == pq:
954
+ self.why += maybe_make_equal_pairs(
955
+ m, n, a, b, p, q, c, d, mn, pq, g, level
956
+ )
957
+ elif g.is_equal(ab, mn) or g.is_equal(cd, pq):
958
+ dep1 = Dependency('cong', [a, b, m, n], None, level)
959
+ dep1.why_me(g, level)
960
+ dep2 = Dependency('cong', [c, d, p, q], None, level)
961
+ dep2.why_me(g, level)
962
+ self.why += [dep1, dep2]
963
+ elif g.is_equal(ab, cd) or g.is_equal(mn, pq):
964
+ dep1 = Dependency('cong', [a, b, c, d], None, level)
965
+ dep1.why_me(g, level)
966
+ dep2 = Dependency('cong', [m, n, p, q], None, level)
967
+ dep2.why_me(g, level)
968
+ self.why += [dep1, dep2]
969
+ elif ab._val and cd._val and mn._val and pq._val:
970
+ self.why = why_eqangle(ab._val, cd._val, mn._val, pq._val, level)
971
+
972
+ elif self.name in ['diff', 'npara', 'nperp', 'ncoll', 'sameside']:
973
+ self.why = []
974
+
975
+ elif self.name == 'simtri':
976
+ a, b, c, x, y, z = self.args
977
+ dep1 = Dependency('eqangle', [a, b, a, c, x, y, x, z], '', level)
978
+ dep1.why_me(g, level)
979
+ dep2 = Dependency('eqangle', [b, a, b, c, y, x, y, z], '', level)
980
+ dep2.why_me(g, level)
981
+ self.rule_name = 'r34'
982
+ self.why = [dep1, dep2]
983
+
984
+ elif self.name == 'contri':
985
+ a, b, c, x, y, z = self.args
986
+ dep1 = Dependency('cong', [a, b, x, y], '', level)
987
+ dep1.why_me(g, level)
988
+ dep2 = Dependency('cong', [b, c, y, z], '', level)
989
+ dep2.why_me(g, level)
990
+ dep3 = Dependency('cong', [c, a, z, x], '', level)
991
+ dep3.why_me(g, level)
992
+ self.rule_name = 'r32'
993
+ self.why = [dep1, dep2, dep3]
994
+
995
+ elif self.name == 'ind':
996
+ pass
997
+
998
+ elif self.name == 'aconst':
999
+ a, b, c, d, ang0 = self.args
1000
+
1001
+ measure = ang0._val
1002
+
1003
+ for ang in measure.neighbors(gm.Angle):
1004
+ if ang == ang0:
1005
+ continue
1006
+ d1, d2 = ang._d
1007
+ l1, l2 = d1._obj, d2._obj
1008
+ (a1, b1), (c1, d1) = l1.points, l2.points
1009
+
1010
+ if not g.check_para_or_coll([a, b, a1, b1]) or not g.check_para_or_coll(
1011
+ [c, d, c1, d1]
1012
+ ):
1013
+ continue
1014
+
1015
+ self.why = []
1016
+ for args in [(a, b, a1, b1), (c, d, c1, d1)]:
1017
+ if g.check_coll(args):
1018
+ if len(set(args)) > 2:
1019
+ dep = Dependency('coll', args, None, None)
1020
+ self.why.append(dep.why_me_or_cache(g, level))
1021
+ else:
1022
+ dep = Dependency('para', args, None, None)
1023
+ self.why.append(dep.why_me_or_cache(g, level))
1024
+
1025
+ self.why += gm.why_equal(ang, ang0)
1026
+ break
1027
+
1028
+ elif self.name == 'rconst':
1029
+ a, b, c, d, rat0 = self.args
1030
+
1031
+ val = rat0._val
1032
+
1033
+ for rat in val.neighbors(gm.Ratio):
1034
+ if rat == rat0:
1035
+ continue
1036
+ l1, l2 = rat._l
1037
+ s1, s2 = l1._obj, l2._obj
1038
+ (a1, b1), (c1, d1) = list(s1.points), list(s2.points)
1039
+
1040
+ if not g.check_cong([a, b, a1, b1]) or not g.check_cong([c, d, c1, d1]):
1041
+ continue
1042
+
1043
+ self.why = []
1044
+ for args in [(a, b, a1, b1), (c, d, c1, d1)]:
1045
+ if len(set(args)) > 2:
1046
+ dep = Dependency('cong', args, None, None)
1047
+ self.why.append(dep.why_me_or_cache(g, level))
1048
+
1049
+ self.why += gm.why_equal(rat, rat0)
1050
+ break
1051
+
1052
+ else:
1053
+ raise ValueError('Not recognize', self.name)
1054
+
1055
+ def hashed(self, rename: bool = False) -> tuple[str, ...]:
1056
+ return hashed(self.name, self.args, rename=rename)
1057
+
1058
+
1059
+ def hashed(
1060
+ name: str, args: list[gm.Point], rename: bool = False
1061
+ ) -> tuple[str, ...]:
1062
+ if name == 's_angle':
1063
+ args = [p.name if not rename else p.new_name for p in args[:-1]] + [
1064
+ str(args[-1])
1065
+ ]
1066
+ else:
1067
+ args = [p.name if not rename else p.new_name for p in args]
1068
+ return hashed_txt(name, args)
1069
+
1070
+
1071
+ def hashed_txt(name: str, args: list[str]) -> tuple[str, ...]:
1072
+ """Return a tuple unique to name and args upto arg permutation equivariant."""
1073
+
1074
+ if name in ['const', 'aconst', 'rconst']:
1075
+ a, b, c, d, y = args
1076
+ a, b = sorted([a, b])
1077
+ c, d = sorted([c, d])
1078
+ return name, a, b, c, d, y
1079
+
1080
+ if name in ['npara', 'nperp', 'para', 'cong', 'perp', 'collx']:
1081
+ a, b, c, d = args
1082
+
1083
+ a, b = sorted([a, b])
1084
+ c, d = sorted([c, d])
1085
+ (a, b), (c, d) = sorted([(a, b), (c, d)])
1086
+
1087
+ return (name, a, b, c, d)
1088
+
1089
+ if name in ['midp', 'midpoint']:
1090
+ a, b, c = args
1091
+ b, c = sorted([b, c])
1092
+ return (name, a, b, c)
1093
+
1094
+ if name in ['coll', 'cyclic', 'ncoll', 'diff', 'triangle']:
1095
+ return (name,) + tuple(sorted(list(set(args))))
1096
+
1097
+ if name == 'circle':
1098
+ x, a, b, c = args
1099
+ return (name, x) + tuple(sorted([a, b, c]))
1100
+
1101
+ if name in ['eqangle', 'eqratio', 'eqangle6', 'eqratio6']:
1102
+ a, b, c, d, e, f, g, h = args
1103
+ a, b = sorted([a, b])
1104
+ c, d = sorted([c, d])
1105
+ e, f = sorted([e, f])
1106
+ g, h = sorted([g, h])
1107
+ if tuple(sorted([a, b, e, f])) > tuple(sorted([c, d, g, h])):
1108
+ a, b, e, f, c, d, g, h = c, d, g, h, a, b, e, f
1109
+ if (a, b, c, d) > (e, f, g, h):
1110
+ a, b, c, d, e, f, g, h = e, f, g, h, a, b, c, d
1111
+
1112
+ if name == 'eqangle6':
1113
+ name = 'eqangle'
1114
+ if name == 'eqratio6':
1115
+ name = 'eqratio'
1116
+ return (name,) + (a, b, c, d, e, f, g, h)
1117
+
1118
+ if name in ['contri', 'simtri', 'simtri2', 'contri2', 'contri*', 'simtri*']:
1119
+ a, b, c, x, y, z = args
1120
+ (a, x), (b, y), (c, z) = sorted([(a, x), (b, y), (c, z)], key=sorted)
1121
+ (a, b, c), (x, y, z) = sorted([(a, b, c), (x, y, z)], key=sorted)
1122
+ return (name, a, b, c, x, y, z)
1123
+
1124
+ if name in ['eqratio3']:
1125
+ a, b, c, d, o, o = args # pylint: disable=redeclared-assigned-name
1126
+ (a, c), (b, d) = sorted([(a, c), (b, d)], key=sorted)
1127
+ (a, b), (c, d) = sorted([(a, b), (c, d)], key=sorted)
1128
+ return (name, a, b, c, d, o, o)
1129
+
1130
+ if name in ['sameside', 's_angle']:
1131
+ return (name,) + tuple(args)
1132
+
1133
+ raise ValueError(f'Not recognize {name} to hash.')
external/alphageometry/problem_test.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2023 DeepMind Technologies Limited
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # ==============================================================================
15
+
16
+ """Unit tests for problem.py."""
17
+ import unittest
18
+
19
+ from absl.testing import absltest
20
+ import problem as pr
21
+
22
+
23
+ class ProblemTest(unittest.TestCase):
24
+
25
+ @classmethod
26
+ def setUpClass(cls):
27
+ super().setUpClass()
28
+ cls.defs = pr.Definition.from_txt_file('defs.txt', to_dict=True)
29
+
30
+ def test_orthocenter_no_translate(self):
31
+ txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long
32
+
33
+ # read the txt into pr.Problem object, do not change the name of points:
34
+ p = pr.Problem.from_txt(txt, translate=False)
35
+
36
+ # This is fed into the LM, translating from constructive to constrained:
37
+ setup_str = p.setup_str_from_problem(ProblemTest.defs)
38
+
39
+ self.assertEqual(
40
+ setup_str,
41
+ '{S} a : ; b : ; c : ; h : T a b c h 00 T a c b h 01 ? T a h b c',
42
+ )
43
+
44
+ def test_orthocenter_translate(self):
45
+ txt = 'a b c = triangle a b c; h = on_tline h b a c, on_tline h c a b ? perp a h b c' # pylint: disable=line-too-long
46
+
47
+ # Read the txt into pr.Problem object, change h -> d to match
48
+ # training data distribution.
49
+ p = pr.Problem.from_txt(txt, translate=True)
50
+
51
+ # This is fed into the LM, translating from constructive to constrained:
52
+ setup_str = p.setup_str_from_problem(ProblemTest.defs)
53
+
54
+ self.assertEqual(
55
+ setup_str,
56
+ '{S} a : ; b : ; c : ; d : T a b c d 00 T a c b d 01 ? T a d b c',
57
+ )
58
+
59
+
60
+ if __name__ == '__main__':
61
+ absltest.main()
external/alphageometry/requirements.in ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tensorflow==2.13.0
2
+ numpy==1.23.5
3
+ scipy==1.10.0
4
+ matplotlib==3.7.0
5
+ gdown==4.7.1
6
+ jax==0.4.6
7
+ jaxlib==0.4.6
8
+ flax==0.5.3
9
+ gin-config==0.5.0
10
+ gin==0.1.6
11
+ t5==0.9.4
12
+ sentencepiece==0.1.99
13
+ absl-py==1.4.0
14
+ clu==0.0.7
15
+ optax==0.1.7
16
+ seqio==0.0.18
17
+ tensorflow-datasets==4.9.3
external/alphageometry/requirements.txt ADDED
The diff for this file is too large to render. See raw diff