|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Run DD+AR or AlphaGeometry solver.
|
|
|
|
|
|
Please refer to README.md for detailed instructions.
|
|
|
"""
|
|
|
|
|
|
import traceback
|
|
|
|
|
|
from absl import app
|
|
|
from absl import flags
|
|
|
from absl import logging
|
|
|
import ddar
|
|
|
import graph as gh
|
|
|
import lm_inference as lm
|
|
|
import pretty as pt
|
|
|
import problem as pr
|
|
|
|
|
|
|
|
|
# Language-model (Gin) configuration, consumed by get_lm().
_GIN_SEARCH_PATHS = flags.DEFINE_list(
    'gin_search_paths',
    ['third_party/py/meliad/transformer/configs'],
    'List of paths where the Gin config files are located.',
)
_GIN_FILE = flags.DEFINE_multi_string(
    'gin_file', ['base_htrans.gin'], 'List of Gin config files.'
)
_GIN_PARAM = flags.DEFINE_multi_string(
    'gin_param', None, 'Newline separated list of Gin parameter bindings.'
)

# Problem selection and solver mode.
_PROBLEMS_FILE = flags.DEFINE_string(
    'problems_file',
    'imo_ag_30.txt',
    'text file containing the problem strings. See imo_ag_30.txt for example.',
)
_PROBLEM_NAME = flags.DEFINE_string(
    'problem_name',
    'imo_2000_p1',
    # The help text names the actual flag (`problems_file`).
    'name of the problem to solve, must be in the problems_file.',
)
_MODE = flags.DEFINE_string(
    'mode', 'ddar', 'either `ddar` (DD+AR) or `alphageometry`'
)

# Construction definitions and deduction rules, loaded in main().
_DEFS_FILE = flags.DEFINE_string(
    'defs_file',
    'defs.txt',
    'definitions of available constructions to state a problem.',
)
_RULES_FILE = flags.DEFINE_string(
    'rules_file', 'rules.txt', 'list of deduction rules used by DD.'
)

# Language model artifacts; main() only reads these in `alphageometry` mode.
_CKPT_PATH = flags.DEFINE_string('ckpt_path', '', 'checkpoint of the LM model.')
_VOCAB_PATH = flags.DEFINE_string(
    'vocab_path', '', 'path to the LM vocab file.'
)

# Output location; an empty value means the solution is only logged.
_OUT_FILE = flags.DEFINE_string(
    'out_file', '', 'path to the solution output file.'
)

# Proof-search parameters, passed to run_alphageometry().
_BEAM_SIZE = flags.DEFINE_integer(
    'beam_size', 1, 'beam size of the proof search.'
)
_SEARCH_DEPTH = flags.DEFINE_integer(
    'search_depth', 1, 'search depth of the proof search.'
)
|
|
|
|
|
|
# Module-level state populated by main():
#   DEFINITIONS <- pr.Definition.from_txt_file(--defs_file, to_dict=True)
#   RULES       <- pr.Theorem.from_txt_file(--rules_file, to_dict=True)
# Both stay None until main() runs.
DEFINITIONS = RULES = None
|
|
|
|
|
|
|
|
|
def natural_language_statement(logical_statement: pr.Dependency) -> str:
  """Render a logical statement as (pseudo) natural language.

  Point names are upper-cased. A name longer than one character gets an
  underscore after its first character (e.g. "x1" -> "X_1"). The resulting
  names are then handed to pt.pretty_nl together with the predicate name.

  Args:
    logical_statement: pr.Dependency with .name and .args

  Returns:
    a string of (pseudo) natural language of the predicate for human reader.
  """
  rendered_names = []
  for arg in logical_statement.args:
    upper = arg.name.upper()
    if len(upper) > 1:
      upper = upper[0] + '_' + upper[1:]
    rendered_names.append(upper)
  return pt.pretty_nl(logical_statement.name, rendered_names)
|
|
|
|
|
|
|
|
|
def proof_step_string(
    proof_step: pr.Dependency, refs: dict[tuple[str, ...], int], last_step: bool
) -> str:
  """Render one proof step as `premises ⇒ conclusion` in natural language.

  Each premise is tagged with its two-digit reference number from `refs`.
  The conclusion is registered in `refs` under the next free number (the
  current size of `refs`), mutating the dict in place. It is tagged with that
  number unless this is the final step. A step with no premises is rendered
  with the word 'similarly'.

  Args:
    proof_step: pr.Dependency with .name and .args
    refs: dict(hash: int) to keep track of derived predicates
    last_step: boolean to keep track whether this is the last step.

  Returns:
    a string of (pseudo) natural language of the proof step for human reader.
  """
  premises, [conclusion] = proof_step

  if premises:
    premises_nl = ' & '.join(
        f'{natural_language_statement(prem)} [{refs[prem.hashed()]:02}]'
        for prem in premises
    )
  else:
    premises_nl = 'similarly'

  # Register the conclusion so later steps can cite it.
  conclusion_key = conclusion.hashed()
  refs[conclusion_key] = len(refs)

  conclusion_nl = natural_language_statement(conclusion)
  if not last_step:
    conclusion_nl = f'{conclusion_nl} [{refs[conclusion_key]:02}]'

  return f'{premises_nl} \u21d2 {conclusion_nl}'
|
|
|
|
|
|
|
|
|
def write_solution(g: gh.Graph, p: pr.Problem, out_file: str) -> None:
  """Output the solution to out_file.

  The text has three sections:
    * the points introduced by the theorem premises;
    * the auxiliary constructions;
    * the numbered proof steps, each labelled with a rule name where one is
      known.
  It is always logged, and also written to `out_file` when that is non-empty.

  Args:
    g: gh.Graph object, containing the proof state.
    p: pr.Problem object, containing the theorem.
    out_file: file to write to, empty string to skip writing to file.
  """
  setup, aux, proof_steps, refs = ddar.get_proof_steps(
      g, p.goal, merge_trivials=False
  )

  solution = '\n=========================='
  solution += '\n * From theorem premises:\n'
  premises_nl = []
  for premises, [points] in setup:
    solution += ' '.join([point.name.upper() for point in points]) + ' '
    if not premises:
      continue
    premises_nl += [
        natural_language_statement(dep) + ' [{:02}]'.format(refs[dep.hashed()])
        for dep in premises
    ]
  solution += ': Points\n' + '\n'.join(premises_nl)

  solution += '\n\n * Auxiliary Constructions:\n'
  aux_premises_nl = []
  for premises, [points] in aux:
    solution += ' '.join([point.name.upper() for point in points]) + ' '
    aux_premises_nl += [
        natural_language_statement(dep) + ' [{:02}]'.format(refs[dep.hashed()])
        for dep in premises
    ]
  solution += ': Points\n' + '\n'.join(aux_premises_nl)

  # Human-readable labels shown next to a step's arrow. Rules not listed here
  # get no label.
  r2name = {
      'r32': '(SSS)',
      'r33': '(SAS)',
      'r34': '(Similar Triangles)',
      'r35': '(Similar Triangles)',
      'r36': '(ASA)',
      'r37': '(ASA)',
      'r38': '(Similar Triangles)',
      'r39': '(Similar Triangles)',
      'r40': '(Congruent Triangles)',
      'a00': '(Distance chase)',
      'a01': '(Ratio chase)',
      'a02': '(Angle chase)',
  }

  solution += '\n\n * Proof steps:\n'
  for i, step in enumerate(proof_steps):
    _, [con] = step
    nl = proof_step_string(step, refs, last_step=i == len(proof_steps) - 1)
    rule_name = r2name.get(con.rule_name, '')
    nl = nl.replace('\u21d2', f'{rule_name}\u21d2 ')
    solution += '{:03}. '.format(i + 1) + nl + '\n'

  solution += '==========================\n'
  logging.info(solution)
  if out_file:
    # The text always contains U+21D2 ('⇒'). Pin UTF-8 so the write does not
    # depend on the platform's locale encoding (e.g. cp1252 cannot encode it).
    with open(out_file, 'w', encoding='utf-8') as f:
      f.write(solution)
    logging.info('Solution written to %s.', out_file)
|
|
|
|
|
|
|
|
|
def get_lm(ckpt_init: str, vocab_path: str) -> lm.LanguageModelInference:
  """Build the language model used to propose auxiliary constructions.

  Gin configuration is parsed first, from the --gin_file, --gin_param and
  --gin_search_paths flags. Only then is the inference wrapper created.

  Args:
    ckpt_init: path of the LM checkpoint to load.
    vocab_path: path to the LM vocab file.

  Returns:
    An lm.LanguageModelInference created with mode='beam_search'.
  """
  gin_files = _GIN_FILE.value
  gin_params = _GIN_PARAM.value
  search_paths = _GIN_SEARCH_PATHS.value
  lm.parse_gin_configuration(gin_files, gin_params, gin_paths=search_paths)

  model = lm.LanguageModelInference(vocab_path, ckpt_init, mode='beam_search')
  return model
|
|
|
|
|
|
|
|
|
def run_ddar(g: gh.Graph, p: pr.Problem, out_file: str) -> bool:
  """Run DD+AR on a problem and report the outcome.

  The deductive closure is computed in place on `g` (up to max_level=1000).
  If the goal then holds, the solution is written and the final proof state
  is drawn.

  Args:
    g: gh.Graph object, containing the proof state (mutated by the solver).
    p: pr.Problem object, containing the problem statement.
    out_file: path to output file if solution is found.

  Returns:
    Boolean, whether DD+AR finishes successfully.
  """
  ddar.solve(g, RULES, p, max_level=1000)

  goal = p.goal
  solved = g.check(goal.name, g.names2nodes(goal.args))
  if solved:
    write_solution(g, p, out_file)
    # Draw the points, lines, circles and segments of the final state.
    gh.nm.draw(
        g.type2nodes[gh.Point],
        g.type2nodes[gh.Line],
        g.type2nodes[gh.Circle],
        g.type2nodes[gh.Segment],
    )
    return True

  logging.info('DD+AR failed to solve the problem.')
  return False
|
|
|
|
|
|
|
|
|
def translate_constrained_to_constructive(
    point: str, name: str, args: list[str]
) -> tuple[str, list[str]]:
  """Translate a predicate from constraint-based to construction-based.

  The predicate `name args` is a constraint the new point `point` must
  satisfy. Each recognised predicate (short or long spelling) is mapped to a
  construction name, and its arguments are permuted around `point`. In most
  branches `point` ends up as the first argument. Unrecognised predicate
  names are returned unchanged.

  Args:
    point: str: name of the new point
    name: str: name of the predicate, e.g., perp, para, etc.
    args: list[str]: list of predicate args.

  Returns:
    (name, args): translated to constructive predicate.

  Raises:
    ValueError: implicitly, when `args` does not have the number of elements
      the matching branch unpacks (4 for perp/para/cong, 3 for coll, 6 for
      eqangle, 3 non-`point` args for cyclic).
  """
  if name in ['T', 'perp']:
    a, b, c, d = args
    # Move the pair containing `point` to the front ...
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    # ... and put `point` first within that pair.
    if point == b:
      a, b = b, a
    # If `point` also appears in the second pair, put it first there too.
    if point == d:
      c, d = d, c
    # Both pairs start with `point`: use the 3-argument form.
    if a == c and a == point:
      return 'on_dia', [a, b, d]
    return 'on_tline', [a, b, c, d]

  elif name in ['P', 'para']:
    a, b, c, d = args
    # Same reordering as perp: `point` ends up in slot a.
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    if point == b:
      a, b = b, a
    return 'on_pline', [a, b, c, d]

  elif name in ['D', 'cong']:
    a, b, c, d = args
    # Put `point` first in the first pair (and in the second, if present).
    if point in [c, d]:
      a, b, c, d = c, d, a, b
    if point == b:
      a, b = b, a
    if point == d:
      c, d = d, c
    # `point` starts both pairs: 3-argument on_bline form.
    if a == c and a == point:
      return 'on_bline', [a, b, d]
    # `b` also appears in the second pair: make `d` the other endpoint of that
    # pair and use the 3-argument on_circle form.
    if b in [c, d]:
      if b == d:
        c, d = d, c
      return 'on_circle', [a, b, d]
    return 'eqdistance', [a, b, c, d]

  elif name in ['C', 'coll']:
    a, b, c = args
    # Rotate so `point` is in slot a.
    if point == b:
      a, b = b, a
    if point == c:
      a, b, c = c, a, b
    return 'on_line', [a, b, c]

  elif name in ['^', 'eqangle']:
    # NOTE(review): this branch unpacks exactly 6 args, but check_valid_args
    # only accepts 'eqangle' with 8 args. If pt.map_symbol('^') == 'eqangle'
    # (assumed, not shown here), try_translate_constrained_to_construct
    # rejects 6-arg input before reaching this code, while 8-arg input fails
    # to unpack here. Confirm whether eqangle candidates can ever succeed.
    a, b, c, d, e, f = args

    # Move the triple containing `point` to the front.
    if point in [d, e, f]:
      a, b, c, d, e, f = d, e, f, a, b, c

    # Relabel: the first triple becomes (a, x, b) and the second (c, y, d),
    # so x and y are the middle args of the two triples.
    x, b, y, c, d = b, c, e, d, f
    # If `point` is the last arg of the first triple, swap the outer args of
    # both triples together so `point` moves to slot a.
    if point == b:
      a, b, c, d = b, a, d, c

    # `point` is an outer arg of both triples and the middle args coincide.
    if point == d and x == y:
      return 'angle_bisector', [point, b, x, c]

    # `point` is the middle arg of the first triple.
    if point == x:
      return 'eqangle3', [x, a, b, y, c, d]

    return 'on_aline', [a, x, b, c, y, d]

  elif name in ['cyclic', 'O']:
    # Requires exactly three args other than `point`.
    a, b, c = [x for x in args if x != point]
    return 'on_circum', [point, a, b, c]

  return name, args
|
|
|
|
|
|
|
|
|
def check_valid_args(name: str, args: list[str]) -> bool:
  """Check whether a predicate is grammatically correct.

  Only two things are validated: the arity, and the minimum number of
  distinct point names, for the predicates handled below. Any other
  predicate name is accepted as-is.

  Args:
    name: str: name of the predicate
    args: list[str]: args of the predicate

  Returns:
    bool: whether the predicate arg count is valid.
  """

  def n_distinct(points) -> int:
    return len(set(points))

  if name in ('perp', 'cong'):
    # Two pairs; each pair needs two distinct points (pairs may share one).
    return (
        len(args) == 4
        and n_distinct(args[:2]) >= 2
        and n_distinct(args[2:]) >= 2
    )
  if name in ('para', 'cyclic'):
    # Four pairwise-distinct points.
    return len(args) == 4 and n_distinct(args) == 4
  if name == 'coll':
    return len(args) == 3 and n_distinct(args) == 3
  if name == 'eqangle':
    # Two groups of four; each group needs at least three distinct points.
    return (
        len(args) == 8
        and n_distinct(args[:4]) >= 3
        and n_distinct(args[4:]) >= 3
    )
  return True
|
|
|
|
|
|
|
|
|
def try_translate_constrained_to_construct(string: str, g: gh.Graph) -> str:
  """Whether a string of aux construction can be constructed.

  Expected input: `<point> : <pred> <args...> <num> [<pred> <args...> <num>] ;`.
  Numeric tokens close a predicate, and at most two predicates are allowed.
  An illustrative example is `e : C a b e 02 D a e b e 03 ;`. Each predicate
  is translated to a construction, and the resulting clause is dry-run on a
  copy of `g`. Malformed input of any kind is reported as an 'ERROR:' string
  rather than raised, so a single bad LM output cannot abort the search.

  Args:
    string: str: the string describing aux construction.
    g: gh.Graph: the current proof state.

  Returns:
    str: whether this construction is valid. If not, starts with "ERROR:".
  """
  if not string.endswith(';'):
    return 'ERROR: must end with ;'

  parts = string.split(' : ')
  if len(parts) != 2:
    return 'ERROR: expected exactly one " : " separator'
  head, prem_str = parts
  point = head.strip()

  if len(point) != 1:
    return f'ERROR: invalid point name {point}'

  existing_points = {p.name for p in g.all_points()}
  if point in existing_points:
    return f'ERROR: point {point} already exists.'

  # Drop the trailing ';' token. Each numeric token ends the current
  # predicate; a trailing number does not open a new (empty) one.
  prem_toks = prem_str.split()[:-1]
  prems = [[]]
  for i, tok in enumerate(prem_toks):
    if tok.isdigit():
      if i < len(prem_toks) - 1:
        prems.append([])
    else:
      prems[-1].append(tok)

  if len(prems) > 2:
    return 'ERROR: there cannot be more than two predicates.'

  constructions = []
  for prem in prems:
    if not prem:
      # e.g. no tokens at all, or two consecutive numbers.
      return 'ERROR: empty predicate.'
    name, *args = prem

    if point not in args:
      return f'ERROR: {point} not found in predicate args.'

    if not check_valid_args(pt.map_symbol(name), args):
      return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)

    for a in args:
      if a != point and a not in existing_points:
        return f'ERROR: point {a} does not exist.'

    try:
      name, args = translate_constrained_to_constructive(point, name, args)
    except ValueError:
      # Tuple unpacking failed: wrong number of args for this predicate.
      return 'ERROR: Invalid predicate ' + name + ' ' + ' '.join(args)

    if name == 'on_aline' and args.count(point) > 1:
      return f'ERROR: on_aline involves twice {point}'

    constructions.append(name + ' ' + ' '.join(args))

  clause_txt = point + ' = ' + ', '.join(constructions)

  try:
    clause = pr.Clause.from_txt(clause_txt)
    # Dry-run on a copy of `g`; only success or failure matters here.
    g.copy().add_clause(clause, 0, DEFINITIONS)
  except Exception:
    # Validation boundary for LM output: report any failure to the caller
    # instead of aborting the proof search.
    return 'ERROR: ' + traceback.format_exc()

  return clause_txt
|
|
|
|
|
|
|
|
|
def insert_aux_to_premise(pstring: str, auxstring: str) -> str:
  """Insert an auxiliary construction into a problem string.

  The construction is appended to the setup clauses, i.e. placed just before
  the ' ? ' that introduces the goal.

  Args:
    pstring: str: describing the problem to solve.
    auxstring: str: describing the auxiliary construction.

  Returns:
    str: new pstring with auxstring inserted before the conclusion.

  Raises:
    ValueError: if `pstring` does not contain exactly one ' ? ' separator.
  """
  pieces = pstring.split(' ? ')
  setup, goal = pieces
  return ''.join([setup, '; ', auxstring, ' ? ', goal])
|
|
|
|
|
|
|
|
|
class BeamQueue:
  """Keep only the top k objects according to their values.

  Items are stored as (val, node) pairs in insertion order. Once `max_size`
  items are held, a new item replaces the current lowest-valued item, but
  only if its value is strictly higher.
  """

  def __init__(self, max_size: int = 512):
    self.queue = []
    self.max_size = max_size

  def add(self, node: object, val: float) -> None:
    """Add a new node to this queue.

    Args:
      node: arbitrary payload; it is never compared, so it need not be
        orderable.
      val: score used to rank nodes (higher is kept).
    """
    if len(self.queue) < self.max_size:
      self.queue.append((val, node))
      return

    if not self.queue:
      # max_size <= 0: nothing can be kept (and min() of [] would raise).
      return

    # Rank by value only. Comparing whole (val, node) tuples falls through to
    # comparing the nodes when values tie. That raises TypeError for payloads
    # without an ordering (run_alphageometry stores graph tuples here).
    min_idx, (min_val, _) = min(
        enumerate(self.queue), key=lambda item: item[1][0]
    )

    if val > min_val:
      self.queue[min_idx] = (val, node)

  def __iter__(self):
    """Yield the stored (val, node) pairs in storage order."""
    yield from self.queue

  def __len__(self) -> int:
    return len(self.queue)
|
|
|
|
|
|
|
|
|
def run_alphageometry(
    model: lm.LanguageModelInference,
    p: pr.Problem,
    search_depth: int,
    beam_size: int,
    out_file: str,
) -> bool:
  """Simplified code to run AlphaGeometry proof search.

  We removed all optimizations that are infrastructure-dependent, e.g.
  parallelized model inference on multi GPUs,
  parallelized DD+AR on multiple CPUs,
  parallel execution of LM and DD+AR,
  shared pool of CPU workers across different problems, etc.

  Many other speed optimizations and abstractions are also removed to
  better present the core structure of the proof search.

  Args:
    model: Interface with inference-related endpoints to JAX's model.
    p: pr.Problem object describing the problem to solve.
    search_depth: max proof search depth.
    beam_size: beam size of the proof search.
    out_file: path to output file if solution is found.

  Returns:
    boolean of whether this is solved.
  """
  # LM prompt: the problem setup followed by ' {F1} x00'. The same ' x00'
  # suffix is re-appended after every accepted LM output below (presumably
  # the token requesting a new auxiliary point).
  string = p.setup_str_from_problem(DEFINITIONS)
  string += ' {F1} x00'

  # Build the proof-state graph and try plain DD+AR before any search.
  g, _ = gh.Graph.build_problem(p, DEFINITIONS)
  if run_ddar(g, p, out_file):
    return True

  # Each beam node is (proof-state graph, LM prompt, problem string). Its
  # value is the accumulated LM score along its path.
  beam_queue = BeamQueue(max_size=beam_size)
  beam_queue.add(node=(g, string, p.txt()), val=0.0)

  for depth in range(search_depth):
    logging.info(
        'Depth %s. There are %i nodes to expand:', depth, len(beam_queue)
    )
    for _, (_, prompt, _) in beam_queue:
      logging.info(prompt)

    new_queue = BeamQueue(max_size=beam_size)

    for prev_score, (g, string, pstring) in beam_queue:
      logging.info('Decoding from %s', string)
      outputs = model.beam_decode(string, eos_tokens=[';'])

      # Translate each LM output (constraint form) into a construction
      # clause. Invalid outputs come back as strings starting with 'ERROR:'.
      translations = [
          try_translate_constrained_to_construct(o, g)
          for o in outputs['seqs_str']
      ]

      # Pair each LM output with its translation and score.
      candidates = zip(outputs['seqs_str'], translations, outputs['scores'])

      # Visit candidates in reverse of the LM's output order. Presumably
      # beam_decode lists the best-scoring sequence last, so this tries the
      # best first; confirm against lm_inference.
      candidates = reversed(list(candidates))

      for lm_out, translation, score in candidates:
        logging.info('LM output (score=%f): "%s"', score, lm_out)
        logging.info('Translation: "%s"\n', translation)

        if translation.startswith('ERROR:'):
          # The construction is invalid; move on to the next candidate.
          continue

        # Add the auxiliary point to the constructive problem statement.
        candidate_pstring = insert_aux_to_premise(pstring, translation)

        logging.info('Solving: "%s"', candidate_pstring)
        try:
          p_new = pr.Problem.from_txt(candidate_pstring)
          # Fresh proof-state graph that includes the auxiliary point.
          g_new, _ = gh.Graph.build_problem(p_new, DEFINITIONS)
        except Exception:
          # One unbuildable candidate must not abort the whole search.
          logging.exception(
              'Failed to build candidate problem; skipping: "%s"',
              candidate_pstring,
          )
          continue

        if run_ddar(g_new, p_new, out_file):
          logging.info('Solved.')
          return True

        new_queue.add(
            # Next prompt = current prompt + LM output + ' x00' again.
            node=(g_new, string + ' ' + lm_out + ' x00', candidate_pstring),
            # Path score is the sum of LM scores along the path. All nodes in
            # one beam have the same depth, so no length normalisation is
            # needed.
            val=prev_score + score,
        )
        # new_queue keeps at most beam_size nodes, so this one may be dropped.

    # Expand the surviving nodes at the next depth.
    beam_queue = new_queue

  return False
|
|
|
|
|
|
|
|
|
def main(_):
  """Load definitions, rules and the chosen problem, then run the solver.

  Raises:
    ValueError: if --problem_name is not present in --problems_file, or if
      --mode is neither 'ddar' nor 'alphageometry'.
  """
  global DEFINITIONS, RULES

  # Module-level state read by the solver helpers.
  DEFINITIONS = pr.Definition.from_txt_file(_DEFS_FILE.value, to_dict=True)
  RULES = pr.Theorem.from_txt_file(_RULES_FILE.value, to_dict=True)

  # Problem translation (renaming) is requested in every mode except ddar.
  problems = pr.Problem.from_txt_file(
      _PROBLEMS_FILE.value,
      to_dict=True,
      translate=(_MODE.value != 'ddar'),
  )

  problem_name = _PROBLEM_NAME.value
  if problem_name not in problems:
    raise ValueError(
        f'Problem name `{problem_name}` not found in `{_PROBLEMS_FILE.value}`'
    )
  this_problem = problems[problem_name]

  mode = _MODE.value
  if mode == 'ddar':
    g, _ = gh.Graph.build_problem(this_problem, DEFINITIONS)
    run_ddar(g, this_problem, _OUT_FILE.value)
  elif mode == 'alphageometry':
    model = get_lm(_CKPT_PATH.value, _VOCAB_PATH.value)
    run_alphageometry(
        model,
        this_problem,
        _SEARCH_DEPTH.value,
        _BEAM_SIZE.value,
        _OUT_FILE.value,
    )
  else:
    raise ValueError(f'Unknown FLAGS.mode: {mode}')
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
  # Script entry point: hand control to absl's app.run, which calls main.
  app.run(main)
|
|
|
|