|
|
import dataclasses |
|
|
from typing import Optional |
|
|
from .artifacts import AbstractArtifact |
|
|
import ast |
|
|
from .artifacts import FunctionArtifact |
|
|
|
|
|
@dataclasses.dataclass(frozen=True)
class Program:
    """A parsed Python program: an optional preface followed by artifacts.

    Rendering with `str()` emits the preface (when non-empty) followed by a
    newline, then the `str()` of every artifact joined by newlines.
    """

    # Text emitted verbatim before the artifacts; '' means "no preface".
    preface: str
    # Artifacts in rendering order; lookups match on each artifact's `.name`.
    artifacts: list[AbstractArtifact]

    def __str__(self) -> str:
        rendered_artifacts = '\n'.join(
            str(artifact) for artifact in self.artifacts)
        if self.preface:
            return f'{self.preface}\n' + rendered_artifacts
        return rendered_artifacts

    def find_artifact_index(self, artifact_name: str) -> int:
        """Returns the position in `artifacts` of the artifact `artifact_name`.

        Raises:
          ValueError: if no artifact has that name, or if several do.
        """
        matching_positions = [
            position
            for position, artifact in enumerate(self.artifacts)
            if artifact.name == artifact_name
        ]
        if not matching_positions:
            raise ValueError(
                f'artifact {artifact_name} does not exist in program:\n{str(self)}'
            )
        if len(matching_positions) > 1:
            raise ValueError(
                f'artifact {artifact_name} exists more than once in program:\n'
                f'{str(self)}'
            )
        return matching_positions[0]

    def get_artifact(self, artifact_name: str) -> AbstractArtifact:
        """Returns the unique artifact named `artifact_name`.

        Raises:
          ValueError: if the name is missing or ambiguous (see
            `find_artifact_index`).
        """
        return self.artifacts[self.find_artifact_index(artifact_name)]
|
|
|
|
|
|
|
|
class ProgramVisitor(ast.NodeVisitor):
    """Parses code to collect all required information to produce a `Program`.

    Only top-level (column 0) `def` statements become artifacts; all source
    lines above the first such `def` form the preface. Note that we do not
    store function decorators: they are not part of any artifact (decorators
    of the very first function end up in the preface, because they sit above
    its `def` line). `async def` functions are not collected either, since
    only `visit_FunctionDef` is implemented.
    """

    def __init__(self, sourcecode: str):
        # Raw source lines; artifact bodies are sliced from these verbatim.
        self._codelines: list[str] = sourcecode.splitlines()

        self._preface: str = ''
        self._artifacts: list[AbstractArtifact] = []
        # Name of the most recently visited top-level function, if any.
        self._current_artifact: Optional[str] = None

    def visit_FunctionDef(self,  # pylint: disable=invalid-name
                          node: ast.FunctionDef) -> None:
        """Collects all information about the function being parsed."""
        # Nested functions and methods are indented, so column 0 selects only
        # module-level functions.
        if node.col_offset == 0:
            self._current_artifact = node.name
            if not self._artifacts:
                # `node.lineno` is the 1-based line of the `def`, so this keeps
                # every line above it as the preface.
                self._preface = '\n'.join(self._codelines[:node.lineno - 1])

            function_end_line = node.end_lineno
            # 0-based index into `_codelines` of the first body statement.
            body_start_line = node.body[0].lineno - 1

            docstring = None
            first_statement = node.body[0]
            # `ast.Constant` + `str` check replaces the deprecated `ast.Str`
            # (removed in Python 3.14).
            if (isinstance(first_statement, ast.Expr)
                    and isinstance(first_statement.value, ast.Constant)
                    and isinstance(first_statement.value.value, str)):
                docstring = f' """{first_statement.value.value}"""'
                # The docstring is stored separately, so the body starts at the
                # next statement, or is empty when the docstring is all there is.
                if len(node.body) > 1:
                    body_start_line = node.body[1].lineno - 1
                else:
                    body_start_line = function_end_line

            self._artifacts.append(FunctionArtifact(
                name=node.name,
                args=ast.unparse(node.args),
                return_type=ast.unparse(node.returns) if node.returns else None,
                docstring=docstring,
                body='\n'.join(self._codelines[body_start_line:function_end_line]),
            ))
        self.generic_visit(node)

    def return_program(self) -> Program:
        """Returns a `Program` built from the preface and artifacts seen so far."""
        # NOTE: the artifact list is handed over by reference, not copied.
        return Program(preface=self._preface, artifacts=self._artifacts)
|
|
|
|
|
|
|
|
def _strip_markdown_fence(text: str) -> str:
    """Removes a Markdown code fence wrapped around `text`, if present."""
    # Opening fence: any "```..." line (e.g. "```python", "```py", "```"),
    # optionally preceded by whitespace. Drop the fence text but keep its
    # line break, exactly as slicing off just the fence marker would.
    leading_trimmed = text.lstrip()
    if leading_trimmed.startswith('```'):
        line_break = leading_trimmed.find('\n')
        text = leading_trimmed[line_break:] if line_break != -1 else ''
    # Closing fence: "```" optionally followed by trailing whitespace/newlines.
    trailing_trimmed = text.rstrip()
    if trailing_trimmed.endswith('```'):
        text = trailing_trimmed[:-3]
    return text


def text_to_program(text: str) -> Program:
    """Returns Program object by parsing input text using Python AST.

    `text` may be wrapped in a Markdown code fence (as language models often
    emit); the opening fence line and a closing "```" are removed first.
    Text without a fence is parsed unchanged.

    Raises:
      SyntaxError: if the (unfenced) text is not valid Python.
    """
    text = _strip_markdown_fence(text)
    tree = ast.parse(text)
    visitor = ProgramVisitor(text)
    visitor.visit(tree)
    return visitor.return_program()
|
|
|
|
|
def text_to_artifact(text: str) -> AbstractArtifact:
    """Returns the single artifact obtained by parsing `text` with Python AST.

    Raises:
      ValueError: if `text` does not contain exactly one top-level artifact.
    """
    program = text_to_program(text)
    if len(program.artifacts) != 1:
        # `Program` has no `functions` attribute; report its `artifacts`.
        raise ValueError(f'Only one artifact expected, got {len(program.artifacts)}'
                         f':\n{program.artifacts}')
    return program.artifacts[0]