Spaces:
Runtime error
Runtime error
Upload 2 files
Browse files
ag4masses/alphageometry/alphageometry.py
CHANGED
|
@@ -33,6 +33,8 @@ import problem as pr
|
|
| 33 |
#=============
|
| 34 |
import sys, os, math, re
|
| 35 |
import multiprocessing
|
|
|
|
|
|
|
| 36 |
model = None # global variable used in multi-processing workers
|
| 37 |
|
| 38 |
_GIN_SEARCH_PATHS = flags.DEFINE_list(
|
|
@@ -199,12 +201,11 @@ def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
|
|
| 199 |
rule_name = r2name.get(con.rule_name, '')
|
| 200 |
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
|
| 201 |
solution += '{:03}. '.format(i + 1) + nl + '\n'
|
| 202 |
-
logging.info(solution)
|
| 203 |
if out_file:
|
| 204 |
with open(out_file, 'w') as f:
|
| 205 |
f.write(solution)
|
| 206 |
-
logging.info('Solution written to %s.', out_file)
|
| 207 |
-
|
| 208 |
|
| 209 |
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
|
| 210 |
lm.parse_gin_configuration(
|
|
@@ -233,7 +234,6 @@ def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
|
|
| 233 |
return False
|
| 234 |
|
| 235 |
write_solution(g, p, out_file)
|
| 236 |
-
|
| 237 |
gh.nm.draw(
|
| 238 |
g.type2nodes[gh.Point],
|
| 239 |
g.type2nodes[gh.Line],
|
|
@@ -598,7 +598,7 @@ def bqsearch(i_nd, srch_inputs, out_file) -> tuple[int, bool, list]: # ( iNode,
|
|
| 598 |
return (i_nd, False, ret)
|
| 599 |
|
| 600 |
def run_alphageometry(
|
| 601 |
-
#
|
| 602 |
p: pr.Problem,
|
| 603 |
search_depth: int,
|
| 604 |
beam_size: int,
|
|
@@ -739,9 +739,9 @@ def main(_):
|
|
| 739 |
run_ddar(g, this_problem, _OUT_FILE.value)
|
| 740 |
|
| 741 |
elif _MODE.value == 'alphageometry':
|
| 742 |
-
|
| 743 |
run_alphageometry(
|
| 744 |
-
|
| 745 |
this_problem,
|
| 746 |
_SEARCH_DEPTH.value,
|
| 747 |
_BEAM_SIZE.value,
|
|
|
|
| 33 |
#=============
|
| 34 |
import sys, os, math, re
|
| 35 |
import multiprocessing
|
| 36 |
+
import warnings
|
| 37 |
+
warnings.filterwarnings("ignore")
|
| 38 |
model = None # global variable used in multi-processing workers
|
| 39 |
|
| 40 |
_GIN_SEARCH_PATHS = flags.DEFINE_list(
|
|
|
|
| 201 |
rule_name = r2name.get(con.rule_name, '')
|
| 202 |
nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
|
| 203 |
solution += '{:03}. '.format(i + 1) + nl + '\n'
|
| 204 |
+
# logging.info(solution)
|
| 205 |
if out_file:
|
| 206 |
with open(out_file, 'w') as f:
|
| 207 |
f.write(solution)
|
| 208 |
+
# logging.info('Solution written to %s.', out_file)
|
|
|
|
| 209 |
|
| 210 |
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
|
| 211 |
lm.parse_gin_configuration(
|
|
|
|
| 234 |
return False
|
| 235 |
|
| 236 |
write_solution(g, p, out_file)
|
|
|
|
| 237 |
gh.nm.draw(
|
| 238 |
g.type2nodes[gh.Point],
|
| 239 |
g.type2nodes[gh.Line],
|
|
|
|
| 598 |
return (i_nd, False, ret)
|
| 599 |
|
| 600 |
def run_alphageometry(
|
| 601 |
+
# model: lm.LanguageModelInference,
|
| 602 |
p: pr.Problem,
|
| 603 |
search_depth: int,
|
| 604 |
beam_size: int,
|
|
|
|
| 739 |
run_ddar(g, this_problem, _OUT_FILE.value)
|
| 740 |
|
| 741 |
elif _MODE.value == 'alphageometry':
|
| 742 |
+
model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
|
| 743 |
run_alphageometry(
|
| 744 |
+
model,
|
| 745 |
this_problem,
|
| 746 |
_SEARCH_DEPTH.value,
|
| 747 |
_BEAM_SIZE.value,
|
ag4masses/alphageometry/numericals.py
CHANGED
|
@@ -29,6 +29,10 @@ from numpy.random import uniform as unif # pylint: disable=g-importing-member
|
|
| 29 |
import graph as gh
|
| 30 |
from collections import defaultdict
|
| 31 |
from itertools import combinations
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
matplotlib.use('TkAgg')
|
| 34 |
|
|
@@ -1060,7 +1064,7 @@ def _draw_line(
|
|
| 1060 |
|
| 1061 |
|
| 1062 |
def draw_line(
|
| 1063 |
-
ax: matplotlib.axes.Axes, line: Line, color: Any = 'white'
|
| 1064 |
) -> tuple[Point, Point]:
|
| 1065 |
"""Draw a line."""
|
| 1066 |
points = line.neighbors(gm.Point)
|
|
@@ -1080,8 +1084,11 @@ def draw_line(
|
|
| 1080 |
pmax = p, v
|
| 1081 |
|
| 1082 |
p1, p2 = pmin[0], pmax[0]
|
| 1083 |
-
|
| 1084 |
-
|
|
|
|
|
|
|
|
|
|
| 1085 |
|
| 1086 |
|
| 1087 |
def _draw_circle(
|
|
@@ -1315,10 +1322,6 @@ def highlight(
|
|
| 1315 |
) -> None:
|
| 1316 |
"""Draw highlights."""
|
| 1317 |
args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
|
| 1318 |
-
|
| 1319 |
-
if name == 'cyclic':
|
| 1320 |
-
a, b, c, d = args
|
| 1321 |
-
_draw_circle(ax, Circle(p1=a, p2=b, p3=c), color=color1, lw=2.0)
|
| 1322 |
if name == 'coll':
|
| 1323 |
a, b, c = args
|
| 1324 |
a, b = max(a, b, c), min(a, b, c)
|
|
@@ -1407,6 +1410,90 @@ def find_pairs_with_same_distance(line_lengths):
|
|
| 1407 |
|
| 1408 |
return result
|
| 1409 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1410 |
def _draw(
|
| 1411 |
ax: matplotlib.axes.Axes,
|
| 1412 |
points: list[gm.Point],
|
|
@@ -1471,6 +1558,7 @@ def _draw(
|
|
| 1471 |
|
| 1472 |
# Call the highlight function with the determined color
|
| 1473 |
highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
|
|
|
|
| 1474 |
if equals:
|
| 1475 |
for i, segs in enumerate(equals['segments']):
|
| 1476 |
color = colors[i % len(colors)]
|
|
|
|
| 29 |
import graph as gh
|
| 30 |
from collections import defaultdict
|
| 31 |
from itertools import combinations
|
| 32 |
+
import numpy as np
|
| 33 |
+
import matplotlib.patches
|
| 34 |
+
import matplotlib.pyplot as plt
|
| 35 |
+
from itertools import combinations
|
| 36 |
|
| 37 |
matplotlib.use('TkAgg')
|
| 38 |
|
|
|
|
| 1064 |
|
| 1065 |
|
| 1066 |
def draw_line(
|
| 1067 |
+
ax: matplotlib.axes.Axes, line: Line, color: Any = 'white', draw: bool = True
|
| 1068 |
) -> tuple[Point, Point]:
|
| 1069 |
"""Draw a line."""
|
| 1070 |
points = line.neighbors(gm.Point)
|
|
|
|
| 1084 |
pmax = p, v
|
| 1085 |
|
| 1086 |
p1, p2 = pmin[0], pmax[0]
|
| 1087 |
+
if draw:
|
| 1088 |
+
_draw_line(ax, p1, p2, color=color)
|
| 1089 |
+
return p1, p2
|
| 1090 |
+
else:
|
| 1091 |
+
return p1, p2
|
| 1092 |
|
| 1093 |
|
| 1094 |
def _draw_circle(
|
|
|
|
| 1322 |
) -> None:
|
| 1323 |
"""Draw highlights."""
|
| 1324 |
args = list(map(lambda x: x.num if isinstance(x, gm.Point) else x, args))
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1325 |
if name == 'coll':
|
| 1326 |
a, b, c = args
|
| 1327 |
a, b = max(a, b, c), min(a, b, c)
|
|
|
|
| 1410 |
|
| 1411 |
return result
|
| 1412 |
|
| 1413 |
+
def calculate_angle(p1, p2, p3, p4):
|
| 1414 |
+
"""Calculates the angle between two lines formed by points (p1, p2) and (p3, p4) in degrees."""
|
| 1415 |
+
# Determine the common point
|
| 1416 |
+
if p2 == p3:
|
| 1417 |
+
common = p2
|
| 1418 |
+
other_points = (p1, p4)
|
| 1419 |
+
v1 = np.array([p1.x - p2.x, p1.y - p2.y])
|
| 1420 |
+
v2 = np.array([p4.x - p2.x, p4.y - p2.y])
|
| 1421 |
+
elif p1 == p4:
|
| 1422 |
+
common = p1
|
| 1423 |
+
other_points = (p2, p3)
|
| 1424 |
+
v1 = np.array([p2.x - p1.x, p2.y - p1.y])
|
| 1425 |
+
v2 = np.array([p3.x - p1.x, p3.y - p1.y])
|
| 1426 |
+
elif p1 == p3:
|
| 1427 |
+
common = p1
|
| 1428 |
+
other_points = (p2, p4)
|
| 1429 |
+
v1 = np.array([p2.x - p1.x, p2.y - p1.y])
|
| 1430 |
+
v2 = np.array([p4.x - p1.x, p4.y - p1.y])
|
| 1431 |
+
elif p2 == p4:
|
| 1432 |
+
common = p2
|
| 1433 |
+
other_points = (p1, p3)
|
| 1434 |
+
v1 = np.array([p1.x - p2.x, p1.y - p2.y])
|
| 1435 |
+
v2 = np.array([p3.x - p2.x, p3.y - p2.y])
|
| 1436 |
+
else:
|
| 1437 |
+
return None, None, None # No shared point, angle cannot be calculated
|
| 1438 |
+
|
| 1439 |
+
# Calculate the angle
|
| 1440 |
+
cos_angle = np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2))
|
| 1441 |
+
cos_angle = np.clip(cos_angle, -1, 1) # Ensure valid range for acos
|
| 1442 |
+
angle_rad = np.arccos(cos_angle)
|
| 1443 |
+
return common, other_points, round(np.degrees(angle_rad), 5)
|
| 1444 |
+
|
| 1445 |
+
def highlight_angle2(ax, origin, p1, p2, radius, color):
|
| 1446 |
+
"""Highlights the angle formed by two vectors meeting at 'origin'."""
|
| 1447 |
+
# Calculate angles of vectors
|
| 1448 |
+
angle1 = np.arctan2(p1.y - origin.y, p1.x - origin.x)
|
| 1449 |
+
angle2 = np.arctan2(p2.y - origin.y, p2.x - origin.x)
|
| 1450 |
+
|
| 1451 |
+
# Convert to degrees and ensure the smaller angle is first
|
| 1452 |
+
angle1_deg, angle2_deg = sorted(np.degrees([angle1, angle2]))
|
| 1453 |
+
if angle2_deg - angle1_deg > 180:
|
| 1454 |
+
angle1_deg, angle2_deg = angle2_deg, angle1_deg
|
| 1455 |
+
|
| 1456 |
+
# Draw the wedge
|
| 1457 |
+
wedge = matplotlib.patches.Wedge(
|
| 1458 |
+
center=(origin.x, origin.y),
|
| 1459 |
+
r=radius,
|
| 1460 |
+
theta1=angle1_deg,
|
| 1461 |
+
theta2=angle2_deg,
|
| 1462 |
+
color=color,
|
| 1463 |
+
alpha=0.5
|
| 1464 |
+
)
|
| 1465 |
+
ax.add_patch(wedge)
|
| 1466 |
+
# print("Angle highlighted with color:", color)
|
| 1467 |
+
|
| 1468 |
+
def search_in_dict(num, my_dict):
|
| 1469 |
+
for key in my_dict.keys():
|
| 1470 |
+
if round(key, 3) == round(num, 3):
|
| 1471 |
+
return True
|
| 1472 |
+
return False
|
| 1473 |
+
|
| 1474 |
+
def highlight_same_angle(ax, lines, color_list):
|
| 1475 |
+
"""Highlights angles formed at shared points by pairs of lines."""
|
| 1476 |
+
lines_list = [(draw_line(ax, l, draw=False)) for l in lines] # Extract points for all lines
|
| 1477 |
+
angle_color_radius = {}
|
| 1478 |
+
|
| 1479 |
+
for line1, line2 in combinations(lines_list, 2):
|
| 1480 |
+
# Calculate the angle
|
| 1481 |
+
common_point, other_points, angle = calculate_angle(*line1, *line2)
|
| 1482 |
+
if angle is None or angle > 90:
|
| 1483 |
+
continue # Skip invalid or small angles
|
| 1484 |
+
|
| 1485 |
+
if search_in_dict(angle, angle_color_radius) == False:
|
| 1486 |
+
# Assign color and radius for this unique angle
|
| 1487 |
+
color = color_list[len(angle_color_radius) % len(color_list)]
|
| 1488 |
+
radius = 0.1 + len(angle_color_radius) * 0.05
|
| 1489 |
+
angle_color_radius[round(angle,3)] = (color, radius)
|
| 1490 |
+
# print(type(angle))
|
| 1491 |
+
else:
|
| 1492 |
+
color, radius = angle_color_radius[round(angle, 3)]
|
| 1493 |
+
# print(type(angle))
|
| 1494 |
+
|
| 1495 |
+
# Highlight the angle
|
| 1496 |
+
highlight_angle2(ax, common_point, *other_points, radius, color)
|
| 1497 |
def _draw(
|
| 1498 |
ax: matplotlib.axes.Axes,
|
| 1499 |
points: list[gm.Point],
|
|
|
|
| 1558 |
|
| 1559 |
# Call the highlight function with the determined color
|
| 1560 |
highlight(ax, 'cong', [p1, p2, p3, p4], lcolor, color, color)
|
| 1561 |
+
highlight_same_angle(ax, lines, color_list=colors_highlight)
|
| 1562 |
if equals:
|
| 1563 |
for i, segs in enumerate(equals['segments']):
|
| 1564 |
color = colors[i % len(colors)]
|