nbaldwin's picture
first version FunSearch
97e363b
raw
history blame
4.08 kB
import dataclasses
from typing import Optional
from .artifacts import AbstractArtifact
import ast
from .artifacts import FunctionArtifact
@dataclasses.dataclass(frozen=True)
class Program:
    """A parsed Python program: a preface followed by a list of artifacts."""

    # Source text that precedes the first artifact (everything from the
    # beginning of the code until the first function is found).
    preface: str
    # Artifacts in the order in which they were collected.
    artifacts: list[AbstractArtifact]

    def __str__(self) -> str:
        """Renders the preface (if any) followed by one artifact per line."""
        rendered = '\n'.join(str(artifact) for artifact in self.artifacts)
        if not self.preface:
            return rendered
        return f'{self.preface}\n{rendered}'

    def find_artifact_index(self, artifact_name: str) -> int:
        """Returns the position of the artifact called `artifact_name`.

        Raises:
            ValueError: if no artifact, or more than one artifact, has that
                name.
        """
        matches = [
            position
            for position, artifact in enumerate(self.artifacts)
            if artifact.name == artifact_name
        ]
        if not matches:
            raise ValueError(
                f'artifact {artifact_name} does not exist in program:\n{str(self)}'
            )
        if len(matches) > 1:
            raise ValueError(
                f'artifact {artifact_name} exists more than once in program:\n'
                f'{str(self)}'
            )
        return matches[0]

    def get_artifact(self, artifact_name: str) -> AbstractArtifact:
        """Returns the unique artifact called `artifact_name`.

        Raises:
            ValueError: propagated from `find_artifact_index` when the name is
                missing or ambiguous.
        """
        return self.artifacts[self.find_artifact_index(artifact_name)]
# TODO: Support other artifact types; only functions are handled right now.
class ProgramVisitor(ast.NodeVisitor):
"""Parses code to collect all required information to produce a `Program`.
Note that we do not store function decorators.
"""
def __init__(self, sourcecode: str):
self._codelines: list[str] = sourcecode.splitlines()
self._preface: str = ''
self._artifacts: list[AbstractArtifact] = []
self._current_artifact: Optional[str] = None
def visit_FunctionDef(self,
node: ast.FunctionDef):
"""Collects all information about the function being parsed."""
if node.col_offset == 0: # We only care about first level functions.
self._current_function = node.name
if not self._artifacts:
self._preface = '\n'.join(self._codelines[:node.lineno - 1])
function_end_line = node.end_lineno
body_start_line = node.body[0].lineno - 1
# Extract the docstring.
docstring = None
if isinstance(node.body[0], ast.Expr) and isinstance(node.body[0].value,ast.Str):
docstring = f' """{ast.literal_eval(ast.unparse(node.body[0]))}"""'
if len(node.body) > 1:
body_start_line = node.body[1].lineno - 1
else:
body_start_line = function_end_line
self._artifacts.append(FunctionArtifact(
name=node.name,
args=ast.unparse(node.args),
return_type=ast.unparse(node.returns) if node.returns else None,
docstring=docstring,
body='\n'.join(self._codelines[body_start_line:function_end_line]),
))
self.generic_visit(node)
def return_program(self) -> Program:
return Program(preface=self._preface, artifacts=self._artifacts)
def text_to_program(text: str) -> Program:
    """Returns Program object by parsing input text using Python AST.

    We assume that the program is composed of some preface (e.g. imports,
    classes, assignments, ...) followed by a sequence of functions.

    Args:
        text: Python source, optionally wrapped in a Markdown "```python" code
            fence.

    Raises:
        SyntaxError: propagated from `ast.parse` if the (unfenced) text is not
            valid Python.
    """
    fence_open = "```python"
    fence_close = "```"

    # The text often comes back wrapped in a code block, so remove the fence.
    # Whitespace around the fence (e.g. a trailing newline after the closing
    # "```") is tolerated; text without a fence is left untouched so that its
    # indentation is preserved.
    without_leading_ws = text.lstrip()
    if without_leading_ws.startswith(fence_open):
        text = without_leading_ws[len(fence_open):]
    without_trailing_ws = text.rstrip()
    if without_trailing_ws.endswith(fence_close):
        text = without_trailing_ws[:-len(fence_close)]

    tree = ast.parse(text)
    # The visitor slices the same text by AST line numbers, so it must receive
    # exactly what was parsed.
    visitor = ProgramVisitor(text)
    visitor.visit(tree)
    return visitor.return_program()
def text_to_artifact(text: str) -> AbstractArtifact:
    """Returns the single artifact obtained by parsing input text.

    Args:
        text: Python source expected to contain exactly one top-level artifact.

    Raises:
        ValueError: if the parsed program does not contain exactly one
            artifact.
    """
    program = text_to_program(text)
    if len(program.artifacts) != 1:
        # `Program` exposes `artifacts` (there is no `functions` attribute), so
        # report those; otherwise this branch fails with AttributeError instead
        # of the intended ValueError.
        raise ValueError(f'Only one artifact expected, got {len(program.artifacts)}'
                         f':\n{program.artifacts}')
    return program.artifacts[0]