Spaces:
Building
Building
| #!/usr/bin/env python3 | |
| import argparse | |
| import copy | |
| import json | |
| import re | |
| import subprocess | |
| from enum import Enum as PyEnum | |
| from pathlib import Path | |
| from typing import Callable | |
| from urllib import request | |
# Type of the zero-argument callbacks handed to the printer's layout helpers.
VoidFn = Callable[[], None]

# Location of the cheatcode specification downloaded when `--from` is omitted.
CHEATCODES_JSON_URL = (
    "https://raw.githubusercontent.com/foundry-rs/foundry/master"
    "/crates/cheatcodes/assets/cheatcodes.json"
)

# File the generated Solidity source is written to.
OUT_PATH = "src/Vm.sol"

# Doc comment emitted directly in front of the `VmSafe` interface.
VM_SAFE_DOC = (
    "/// The `VmSafe` interface does not allow manipulation of the EVM state or other actions that may\n"
    "/// result in Script simulations differing from on-chain execution. It is recommended to only use\n"
    "/// these cheats in scripts.\n"
)

# Doc comment emitted directly in front of the `Vm` interface.
VM_DOC = (
    "/// The `Vm` interface does allow manipulation of the EVM state. These are all intended to be used\n"
    "/// in tests, but it is not recommended to use these cheats in scripts.\n"
)
def main():
    """Generate the Solidity `VmSafe` / `Vm` interfaces and write them to `OUT_PATH`.

    The cheatcode specification is read from the file given via `--from`, or
    downloaded from `CHEATCODES_JSON_URL` when that option is omitted.
    Cheatcodes whose status is "experimental" or "internal" are dropped. The
    remaining ones are split by their `safety` value into the `VmSafe` and
    `Vm` interfaces, prefixed with group header comments, rendered, written
    to `OUT_PATH` and finally formatted in place with `forge fmt`.

    Raises:
        RuntimeError: If `forge fmt` exits with a non-zero status.
    """
    parser = argparse.ArgumentParser(
        description="Generate Vm.sol based on the cheatcodes json created by Foundry")
    parser.add_argument(
        "--from",
        metavar="PATH",
        dest="path",
        required=False,
        help="path to a json file containing the Vm interface, as generated by Foundry")
    args = parser.parse_args()

    if args.path is None:
        # Use the response as a context manager so the connection is closed
        # deterministically instead of being left to garbage collection.
        with request.urlopen(CHEATCODES_JSON_URL) as response:
            json_str = response.read().decode("utf-8")
    else:
        # Decode local files as UTF-8 as well, so both input sources are
        # interpreted the same way regardless of the platform's locale.
        json_str = Path(args.path).read_text(encoding="utf-8")
    contract = Cheatcodes.from_json(json_str)

    ccs = [
        cc for cc in contract.cheatcodes
        if cc.status not in ("experimental", "internal")
    ]
    ccs.sort(key=lambda cc: cc.func.id)

    safe = [cc for cc in ccs if cc.safety == "safe"]
    safe.sort(key=CmpCheatcode)
    unsafe = [cc for cc in ccs if cc.safety == "unsafe"]
    unsafe.sort(key=CmpCheatcode)
    # Internal invariant: every kept cheatcode is either "safe" or "unsafe".
    assert len(safe) + len(unsafe) == len(ccs)

    prefix_with_group_headers(safe)
    prefix_with_group_headers(unsafe)

    out = ""

    out += "// Automatically @generated by scripts/vm.py. Do not modify manually.\n\n"

    pp = CheatcodesPrinter(
        spdx_identifier="MIT OR Apache-2.0",
        solidity_requirement=">=0.8.13 <0.9.0",
    )
    # Emit the license/pragma prelude once, then switch it off so that the
    # two `p_contract` calls below do not repeat it.
    pp.p_prelude()
    pp.prelude = False
    out += pp.finish()

    out += "\n\n"
    out += VM_SAFE_DOC
    vm_safe = Cheatcodes(
        # TODO: Custom errors were introduced in 0.8.4
        errors=[],  # contract.errors
        events=contract.events,
        enums=contract.enums,
        structs=contract.structs,
        cheatcodes=safe,
    )
    pp.p_contract(vm_safe, "VmSafe")
    out += pp.finish()

    out += "\n\n"
    out += VM_DOC
    # `Vm` inherits from `VmSafe`, so only the unsafe functions are listed.
    vm_unsafe = Cheatcodes(
        errors=[],
        events=[],
        enums=[],
        structs=[],
        cheatcodes=unsafe,
    )
    pp.p_contract(vm_unsafe, "Vm", "VmSafe")
    out += pp.finish()

    # Compatibility with <0.8.0
    # Each regex match starts at a " memory " that has "returns" later on the
    # same line; that leading " memory " is rewritten to " calldata ".
    def memory_to_calldata(m: re.Match) -> str:
        return " calldata " + m.group(1)

    out = re.sub(r" memory (.*returns)", memory_to_calldata, out)

    with open(OUT_PATH, "w", encoding="utf-8") as f:
        f.write(out)

    forge_fmt = ["forge", "fmt", OUT_PATH]
    res = subprocess.run(forge_fmt)
    # Raise explicitly: an `assert` would be stripped under `python -O` and a
    # failed formatting run would then be reported as success.
    if res.returncode != 0:
        raise RuntimeError(f"command failed: {forge_fmt}")

    print(f"Wrote to {OUT_PATH}")
class CmpCheatcode:
    """Sort-key wrapper that orders cheatcodes through `cmp_cheatcode`.

    Intended to be passed as `key=` to `list.sort`: each element is wrapped
    and the rich comparisons below delegate to the three-way comparison.
    """

    cheatcode: "Cheatcode"

    def __init__(self, cheatcode: "Cheatcode"):
        self.cheatcode = cheatcode

    def _compare(self, other: "CmpCheatcode") -> int:
        # Single place where the wrapped values are handed to cmp_cheatcode.
        return cmp_cheatcode(self.cheatcode, other.cheatcode)

    def __lt__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) < 0

    def __eq__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) == 0

    def __gt__(self, other: "CmpCheatcode") -> bool:
        return self._compare(other) > 0
def cmp_cheatcode(a: "Cheatcode", b: "Cheatcode") -> int:
    """Three-way compare two cheatcodes.

    The fields are compared in order of precedence: group, status, safety,
    then the function id. Returns -1 if `a` sorts before `b`, 1 if it sorts
    after, and 0 when all four fields are equal.
    """
    field_pairs = (
        (a.group, b.group),
        (a.status, b.status),
        (a.safety, b.safety),
        (a.func.id, b.func.id),
    )
    for left, right in field_pairs:
        if left != right:
            # The first differing field decides the ordering.
            return -1 if left < right else 1
    return 0
# HACK: A way to add group header comments without having to modify printer code
def prefix_with_group_headers(cheats: list["Cheatcode"]):
    """Insert a header pseudo-cheatcode before the first cheatcode of each group.

    The header is a deep copy of that first cheatcode whose function
    description is cleared and whose declaration is replaced by a
    `// ======== <Group> ========` comment line.

    Args:
        cheats: Cheatcodes to annotate. The list is updated in place.

    Returns:
        The same list object that was passed in.
    """
    seen_groups = set()
    with_headers = []
    # Build the result separately instead of inserting into `cheats` while
    # iterating over it, then write it back so callers still see the
    # in-place update.
    for cheat in cheats:
        if cheat.group not in seen_groups:
            seen_groups.add(cheat.group)
            header = copy.deepcopy(cheat)
            header.func.description = ""
            header.func.declaration = f"// ======== {group(header.group)} ========"
            with_headers.append(header)
        with_headers.append(cheat)
    cheats[:] = with_headers
    return cheats
def group(s: str) -> str:
    """Return the display name used in header comments for group `s`.

    "evm" and "json" are rendered fully upper-cased; any other name only has
    its first character upper-cased, the rest is left untouched. An empty
    string is returned unchanged.
    """
    acronyms = {"evm": "EVM", "json": "JSON"}
    if s in acronyms:
        return acronyms[s]
    # Slicing (rather than indexing) keeps an empty name from raising.
    return s[:1].upper() + s[1:]
class Visibility(PyEnum):
    """Function visibility; each member's value is the keyword text."""

    EXTERNAL = "external"
    PUBLIC = "public"
    INTERNAL = "internal"
    PRIVATE = "private"

    def __str__(self):
        # Render as the bare keyword rather than `Visibility.NAME`.
        return self.value
class Mutability(PyEnum):
    """Function state mutability; `NONE` is represented by the empty string."""

    PURE = "pure"
    VIEW = "view"
    NONE = ""

    def __str__(self):
        # Render as the bare keyword rather than `Mutability.NAME`.
        return self.value
class Function:
    """Description of one cheatcode function as found in the cheatcodes JSON."""

    id: str
    description: str
    declaration: str
    visibility: Visibility
    mutability: Mutability
    signature: str
    selector: str
    selector_bytes: bytes

    def __init__(
        self,
        id: str,
        description: str,
        declaration: str,
        visibility: Visibility,
        mutability: Mutability,
        signature: str,
        selector: str,
        selector_bytes: bytes,
    ):
        # NOTE: `id` shadows the builtin, but it is part of the constructor's
        # keyword interface and is therefore kept.
        self.id = id
        self.description = description
        self.declaration = declaration
        self.visibility = visibility
        self.mutability = mutability
        self.signature = signature
        self.selector = selector
        self.selector_bytes = selector_bytes

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Function":
        """Build a `Function` from its JSON object.

        Reads the keys "id", "description", "declaration", "visibility",
        "mutability", "signature", "selector" and "selectorBytes".

        Raises:
            KeyError: If one of those keys is missing.
            ValueError: If "visibility" or "mutability" is not a known value.
        """
        return Function(
            d["id"],
            d["description"],
            d["declaration"],
            Visibility(d["visibility"]),
            Mutability(d["mutability"]),
            d["signature"],
            d["selector"],
            # presumably a list of byte values (0-255) — verify against the JSON
            bytes(d["selectorBytes"]),
        )
class Cheatcode:
    """A cheatcode function together with its group, status and safety labels."""

    func: Function
    group: str
    status: str
    safety: str

    def __init__(self, func: Function, group: str, status: str, safety: str):
        self.func = func
        self.group = group
        self.status = status
        self.safety = safety

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Cheatcode":
        """Build a `Cheatcode` from its JSON object.

        Reads the keys "func", "group", "status" and "safety"; the last three
        are converted with `str()`.

        Raises:
            KeyError: If one of those keys is missing.
        """
        return Cheatcode(
            Function.from_dict(d["func"]),
            str(d["group"]),
            str(d["status"]),
            str(d["safety"]),
        )
class Error:
    """A custom error: its name, description and full declaration text."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name = name
        self.description = description
        self.declaration = declaration

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Error":
        """Build an `Error` from a dict whose keys match the constructor parameters.

        Raises:
            TypeError: If `d` has missing or unexpected keys.
        """
        return Error(**d)
class Event:
    """An event: its name, description and full declaration text."""

    name: str
    description: str
    declaration: str

    def __init__(self, name: str, description: str, declaration: str):
        self.name = name
        self.description = description
        self.declaration = declaration

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Event":
        """Build an `Event` from a dict whose keys match the constructor parameters.

        Raises:
            TypeError: If `d` has missing or unexpected keys.
        """
        return Event(**d)
class EnumVariant:
    """One variant of an enum: its name and its description."""

    name: str
    description: str

    def __init__(self, name: str, description: str):
        self.name, self.description = name, description
class Enum:
    """An enum definition: name, description and its ordered variants."""

    name: str
    description: str
    variants: list[EnumVariant]

    def __init__(self, name: str, description: str, variants: list[EnumVariant]):
        self.name = name
        self.description = description
        self.variants = variants

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Enum":
        """Build an `Enum` from its JSON object.

        Reads the keys "name", "description" and "variants"; every entry of
        "variants" is expanded into the `EnumVariant` constructor.

        Raises:
            KeyError: If one of the three keys is missing.
            TypeError: If a variant entry has missing or unexpected keys.
        """
        return Enum(
            d["name"],
            d["description"],
            [EnumVariant(**variant) for variant in d["variants"]],
        )
class StructField:
    """One field of a struct: its name, its type text and its description."""

    name: str
    ty: str
    description: str

    def __init__(self, name: str, ty: str, description: str):
        self.name, self.ty, self.description = name, ty, description
class Struct:
    """A struct definition: name, description and its ordered fields."""

    name: str
    description: str
    fields: list[StructField]

    def __init__(self, name: str, description: str, fields: list[StructField]):
        self.name = name
        self.description = description
        self.fields = fields

    # Declared static so the constructor also works when reached through an
    # instance; without the decorator the instance would be bound as `d`.
    @staticmethod
    def from_dict(d: dict) -> "Struct":
        """Build a `Struct` from its JSON object.

        Reads the keys "name", "description" and "fields"; every entry of
        "fields" is expanded into the `StructField` constructor.

        Raises:
            KeyError: If one of the three keys is missing.
            TypeError: If a field entry has missing or unexpected keys.
        """
        return Struct(
            d["name"],
            d["description"],
            [StructField(**field) for field in d["fields"]],
        )
class Cheatcodes:
    """The full cheatcode specification: errors, events, enums, structs and functions."""

    errors: list[Error]
    events: list[Event]
    enums: list[Enum]
    structs: list[Struct]
    cheatcodes: list[Cheatcode]

    def __init__(
        self,
        errors: list[Error],
        events: list[Event],
        enums: list[Enum],
        structs: list[Struct],
        cheatcodes: list[Cheatcode],
    ):
        self.errors = errors
        self.events = events
        self.enums = enums
        self.structs = structs
        self.cheatcodes = cheatcodes

    # The three constructors below are declared static so they also work when
    # reached through an instance; without the decorator the instance would
    # be bound as the first argument.
    @staticmethod
    def from_dict(d: dict) -> "Cheatcodes":
        """Build a `Cheatcodes` from a parsed JSON object.

        Reads the keys "errors", "events", "enums", "structs" and
        "cheatcodes", converting each entry with the matching `from_dict`.

        Raises:
            KeyError: If one of those keys is missing.
        """
        return Cheatcodes(
            errors=[Error.from_dict(e) for e in d["errors"]],
            events=[Event.from_dict(e) for e in d["events"]],
            enums=[Enum.from_dict(e) for e in d["enums"]],
            structs=[Struct.from_dict(e) for e in d["structs"]],
            cheatcodes=[Cheatcode.from_dict(e) for e in d["cheatcodes"]],
        )

    @staticmethod
    def from_json(s) -> "Cheatcodes":
        """Parse the JSON document `s` and build a `Cheatcodes` from it."""
        return Cheatcodes.from_dict(json.loads(s))

    @staticmethod
    def from_json_file(file_path: str) -> "Cheatcodes":
        """Read the JSON file at `file_path` and build a `Cheatcodes` from it."""
        # Decode explicitly as UTF-8 rather than the platform's locale encoding.
        with open(file_path, "r", encoding="utf-8") as f:
            return Cheatcodes.from_dict(json.load(f))
class Item(PyEnum):
    """Kinds of items that can appear inside a printed interface."""

    ERROR = "error"
    EVENT = "event"
    ENUM = "enum"
    STRUCT = "struct"
    FUNCTION = "function"
class ItemOrder:
    """Order in which the different kinds of items are printed."""

    _list: list[Item]

    # NOTE: the parameter name `list` shadows the builtin inside `__init__`;
    # it is kept because it is part of the constructor's keyword interface.
    def __init__(self, list: list[Item]) -> None:
        """Store `list` as the printing order.

        The list may omit item kinds but must not contain more entries than
        there are `Item` members, nor any duplicates.
        """
        assert len(list) <= len(Item), "list must not contain more items than Item"
        assert len(list) == len(set(list)), "list must not contain duplicates"
        self._list = list

    def get_list(self) -> list[Item]:
        """Return the stored order (the internal list itself, not a copy)."""
        return self._list

    # Declared static so it can also be called on an instance.
    @staticmethod
    def default() -> "ItemOrder":
        """Return the order: errors, events, enums, structs, functions."""
        return ItemOrder(
            [
                Item.ERROR,
                Item.EVENT,
                Item.ENUM,
                Item.STRUCT,
                Item.FUNCTION,
            ]
        )
class CheatcodesPrinter:
    """Render a `Cheatcodes` description as Solidity interface source text.

    Text is appended to `buffer`; `finish()` returns what has been accumulated
    and clears the buffer, so one printer can be reused for several pieces.

    Indentation convention: the per-item printers (`p_error`, `p_enum`, ...)
    are invoked after the caller has already emitted the current indent, and
    `_p_comment` therefore leaves the first line of a `//` comment
    un-indented. When an item's description is empty no comment is written
    and the declaration ends up preceded by the indent twice; that raw output
    is kept unchanged here.
    """

    buffer: str

    prelude: bool
    spdx_identifier: str
    solidity_requirement: str

    block_doc_style: bool

    indent_level: int
    _indent_str: str

    nl_str: str

    items_order: ItemOrder

    def __init__(
        self,
        buffer: str = "",
        prelude: bool = True,
        spdx_identifier: str = "UNLICENSED",
        solidity_requirement: str = "",
        block_doc_style: bool = False,
        indent_level: int = 0,
        indent_with: int | str = 4,
        nl_str: str = "\n",
        items_order: ItemOrder | None = None,
    ):
        """Configure the printer.

        Args:
            buffer: Initial content of the output buffer.
            prelude: Whether `p_contract` emits the license/pragma prelude.
            spdx_identifier: Value of the SPDX license comment.
            solidity_requirement: Version range for the pragma; when empty,
                `p_prelude` falls back to ">=0.8.13 <0.9.0".
            block_doc_style: Emit `/* ... */` comments instead of `//` lines.
            indent_level: Starting indentation depth.
            indent_with: Number of spaces per level, or the literal string
                to use for one level.
            nl_str: String written for a line break.
            items_order: Order in which item kinds are printed; `None`
                selects `ItemOrder.default()`.

        Raises:
            AssertionError: If `indent_with` is neither an int nor a str.
        """
        self.prelude = prelude
        self.spdx_identifier = spdx_identifier
        self.solidity_requirement = solidity_requirement
        self.block_doc_style = block_doc_style
        self.buffer = buffer
        self.indent_level = indent_level
        self.nl_str = nl_str

        if isinstance(indent_with, int):
            assert indent_with >= 0
            self._indent_str = " " * indent_with
        elif isinstance(indent_with, str):
            self._indent_str = indent_with
        else:
            # Raised explicitly (instead of `assert False`) so the check is
            # not stripped under `python -O`, leaving `_indent_str` unset.
            raise AssertionError("indent_with must be int or str")

        # Build the default per instance: a default evaluated in the
        # signature would be created once and shared by every printer.
        if items_order is None:
            items_order = ItemOrder.default()
        self.items_order = items_order

    def finish(self) -> str:
        """Return the buffered text without trailing whitespace and clear the buffer."""
        ret = self.buffer.rstrip()
        self.buffer = ""
        return ret

    def p_contract(self, contract: Cheatcodes, name: str, inherits: str = ""):
        """Print `interface <name> [is <inherits>] { ... }` with all items of `contract`.

        The prelude is printed first when `self.prelude` is true.
        """
        if self.prelude:
            self.p_prelude(contract)

        self._p_str("interface ")
        name = name.strip()
        if name != "":
            self._p_str(name)
            self._p_str(" ")
        if inherits != "":
            self._p_str("is ")
            self._p_str(inherits)
            self._p_str(" ")
        self._p_str("{")
        self._p_nl()
        self._with_indent(lambda: self._p_items(contract))
        self._p_str("}")
        self._p_nl()

    def _p_items(self, contract: Cheatcodes):
        """Print every item kind of `contract` in the configured order."""
        for item in self.items_order.get_list():
            if item == Item.ERROR:
                self.p_errors(contract.errors)
            elif item == Item.EVENT:
                self.p_events(contract.events)
            elif item == Item.ENUM:
                self.p_enums(contract.enums)
            elif item == Item.STRUCT:
                self.p_structs(contract.structs)
            elif item == Item.FUNCTION:
                self.p_functions(contract.cheatcodes)
            else:
                # Explicit raise so the check survives `python -O`.
                raise AssertionError(f"unknown item {item}")

    def p_prelude(self, contract: Cheatcodes | None = None):
        """Print the SPDX license comment and the `pragma solidity` line.

        `contract` is accepted but not used.
        """
        self._p_str(f"// SPDX-License-Identifier: {self.spdx_identifier}")
        self._p_nl()

        if self.solidity_requirement != "":
            req = self.solidity_requirement
        else:
            req = ">=0.8.13 <0.9.0"
        self._p_str(f"pragma solidity {req};")
        self._p_nl()
        self._p_nl()

    # In the p_<plural> methods below the lambda is called immediately by
    # `_p_line`, so capturing the loop variable is safe. `_p_line` also adds
    # the line break that separates consecutive items.

    def p_errors(self, errors: list[Error]):
        """Print every error, each followed by an extra line break."""
        for error in errors:
            self._p_line(lambda: self.p_error(error))

    def p_error(self, error: Error):
        """Print the doc comment and declaration of one error."""
        self._p_comment(error.description, doc=True)
        self._p_line(lambda: self._p_str(error.declaration))

    def p_events(self, events: list[Event]):
        """Print every event, each followed by an extra line break."""
        for event in events:
            self._p_line(lambda: self.p_event(event))

    def p_event(self, event: Event):
        """Print the doc comment and declaration of one event."""
        self._p_comment(event.description, doc=True)
        self._p_line(lambda: self._p_str(event.declaration))

    def p_enums(self, enums: list[Enum]):
        """Print every enum, each followed by an extra line break."""
        for enum in enums:
            self._p_line(lambda: self.p_enum(enum))

    def p_enum(self, enum: Enum):
        """Print the doc comment, header, variants and closing brace of one enum."""
        self._p_comment(enum.description, doc=True)
        self._p_line(lambda: self._p_str(f"enum {enum.name} {{"))
        self._with_indent(lambda: self.p_enum_variants(enum.variants))
        self._p_line(lambda: self._p_str("}"))

    def p_enum_variants(self, variants: list[EnumVariant]):
        """Print the variants one per line, comma-separated, each with its comment."""
        for i, variant in enumerate(variants):
            self._p_indent()
            self._p_comment(variant.description)

            self._p_indent()
            self._p_str(variant.name)
            # No trailing comma after the last variant.
            if i < len(variants) - 1:
                self._p_str(",")
            self._p_nl()

    def p_structs(self, structs: list[Struct]):
        """Print every struct, each followed by an extra line break."""
        for struct in structs:
            self._p_line(lambda: self.p_struct(struct))

    def p_struct(self, struct: Struct):
        """Print the doc comment, header, fields and closing brace of one struct."""
        self._p_comment(struct.description, doc=True)
        self._p_line(lambda: self._p_str(f"struct {struct.name} {{"))
        self._with_indent(lambda: self.p_struct_fields(struct.fields))
        self._p_line(lambda: self._p_str("}"))

    def p_struct_fields(self, fields: list[StructField]):
        """Print every struct field on its own line."""
        for field in fields:
            self._p_line(lambda: self.p_struct_field(field))

    def p_struct_field(self, field: StructField):
        """Print the comment and the `<type> <name>;` line of one field."""
        self._p_comment(field.description)
        self._p_indented(lambda: self._p_str(f"{field.ty} {field.name};"))

    def p_functions(self, cheatcodes: list[Cheatcode]):
        """Print the function of every cheatcode, each followed by an extra line break."""
        for cheatcode in cheatcodes:
            self._p_line(lambda: self.p_function(cheatcode.func))

    def p_function(self, func: Function):
        """Print the doc comment and declaration of one function."""
        self._p_comment(func.description, doc=True)
        self._p_line(lambda: self._p_str(func.declaration))

    def _p_comment(self, s: str, doc: bool = False):
        """Print `s` as a comment; nothing is written when `s` is blank.

        Each line of `s` is stripped of leading whitespace. With
        `block_doc_style` a `/* ... */` block (`/** ... */` when `doc`) is
        written; otherwise every line becomes a `// ` (`/// ` when `doc`)
        line comment.
        """
        s = s.strip()
        if s == "":
            return

        lines = [line.lstrip() for line in s.split("\n")]
        if self.block_doc_style:
            self._p_str("/*")
            if doc:
                self._p_str("*")
            self._p_nl()
            for line in lines:
                self._p_indent()
                self._p_str(" ")
                if doc:
                    self._p_str("* ")
                self._p_str(line)
                self._p_nl()
            self._p_indent()
            self._p_str(" */")
            self._p_nl()
        else:
            first_line = True
            for line in lines:
                # The caller has already emitted the indent for the first line.
                if not first_line:
                    self._p_indent()
                first_line = False

                if doc:
                    self._p_str("/// ")
                else:
                    self._p_str("// ")
                self._p_str(line)
                self._p_nl()

    def _with_indent(self, f: VoidFn):
        """Run `f` with the indentation level increased by one."""
        self._inc_indent()
        try:
            f()
        finally:
            # Restore the level even if `f` raises, so the printer is not
            # left permanently over-indented.
            self._dec_indent()

    def _p_line(self, f: VoidFn):
        """Emit the current indent, run `f`, then emit a line break."""
        self._p_indent()
        f()
        self._p_nl()

    def _p_indented(self, f: VoidFn):
        """Emit the current indent, then run `f` (no trailing line break)."""
        self._p_indent()
        f()

    def _p_indent(self):
        """Append one indent string per indentation level."""
        for _ in range(self.indent_level):
            self._p_str(self._indent_str)

    def _p_nl(self):
        """Append the configured line-break string."""
        self._p_str(self.nl_str)

    def _p_str(self, txt: str):
        """Append `txt` to the buffer."""
        self.buffer += txt

    def _inc_indent(self):
        """Increase the indentation level by one."""
        self.indent_level += 1

    def _dec_indent(self):
        """Decrease the indentation level by one."""
        self.indent_level -= 1
# Run the generator only when this file is executed as a script, not when it
# is imported.
if __name__ == "__main__":
    main()