| | import os |
| | import argparse |
| |
|
| | from lats import run_lats |
| |
|
| |
|
def get_args(argv=None):
    """Parse the command-line flags for a LATS run.

    Args:
        argv: Optional list of argument strings to parse instead of the
            process command line. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, so existing zero-argument calls behave
            exactly as before.

    Returns:
        argparse.Namespace with one attribute per flag defined below.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, help="The name of the run")
    parser.add_argument("--root_dir", type=str,
                        help="The root logging directory", default="root")
    parser.add_argument("--dataset_path", type=str,
                        help="The path to the benchmark dataset", default="root")
    # NOTE(review): strategy_factory in this file ignores this value and
    # always runs LATS; the listed choices look stale.
    parser.add_argument("--strategy", type=str,
                        help="Strategy: `simple`, `reflexion`")
    parser.add_argument("--language", type=str, help="Language: `py` or `rs`")
    parser.add_argument(
        "--model", type=str, help="OpenAI models only for now. For best results, use GPT-4")
    parser.add_argument("--pass_at_k", type=int,
                        help="Pass@k metric", default=1)
    parser.add_argument("--max_iters", type=int,
                        help="The maximum number of self-improvement iterations", default=10)
    parser.add_argument("--expansion_factor", type=int,
                        help="The expansion factor for the reflexion UCS and A* strategy", default=3)
    parser.add_argument("--verbose", action='store_true',
                        help="To print live logs")
    parser.add_argument("--instruction", type=str,
                        help="text string", default="")
    parser.add_argument("--n_samples", type=int,
                        help="The number of nodes added during expansion", default=3)
    parser.add_argument("--depth", type=int,
                        help="Tree depth", default=5)

    return parser.parse_args(argv)
| |
|
| |
|
def strategy_factory(strategy: str):
    """Return the callable that executes a search run.

    Args:
        strategy: Name of the requested strategy (the ``--strategy`` flag).
            NOTE(review): currently ignored -- every value, including
            ``None``, yields a wrapper around ``run_lats``.

    Returns:
        A function accepting keyword arguments only, which forwards them to
        ``run_lats`` (after dropping any ``delete_keys``) and returns its
        result.
    """
    def kwargs_wrapper_gen(func, delete_keys=()):
        # Build a keyword-only forwarder that removes ``delete_keys`` before
        # calling ``func``. An immutable tuple default avoids the
        # shared-mutable-default pitfall of ``[]``.
        def kwargs_wrapper(**kwargs):
            # ``kwargs`` is a fresh dict on every call, so deleting from it
            # never touches the caller's data. A listed key that is absent
            # raises KeyError, as before.
            for key in delete_keys:
                del kwargs[key]
            return func(**kwargs)
        return kwargs_wrapper

    # No keys need stripping for run_lats, so the wrapper forwards everything.
    return kwargs_wrapper_gen(run_lats, delete_keys=())
| | |
| |
|
def lats_main(args):
    """Run the selected search strategy using settings taken from ``args``.

    Only some parsed flags are forwarded to the runner. ``run_name``,
    ``root_dir``, ``dataset_path``, ``pass_at_k`` and ``expansion_factor``
    are read by ``get_args`` but not passed on here.

    Args:
        args: The namespace produced by ``get_args`` (any object exposing the
            attributes read below).

    Returns:
        Whatever the strategy runner returns.
    """
    runner = strategy_factory(args.strategy)

    # Map CLI attribute names onto the runner's keyword parameters.
    run_kwargs = dict(
        model_name=args.model,
        language=args.language,
        max_iters=args.max_iters,
        verbose=args.verbose,
        instruction=args.instruction,
        n_samples=args.n_samples,
        depth=args.depth,
    )
    return runner(**run_kwargs)
| |
|
| |
|
| |
|
def main(args):
    """Script entry point: run LATS for the parsed ``args``, discarding the result."""
    lats_main(args=args)
| |
|
if __name__ == "__main__":
    # Parse the command line (bound at module level as ``args``) and run.
    args = get_args()
    main(args=args)
| |
|