# cl-ds / src /cl_macros /ext /defmacro_parser.py
# j14i's picture
# 977 CL macro transformation examples: CL-native pipeline with SBCL verification
# d69fc90 verified
r"""Extract defmacro definitions from Common Lisp source files.

Handles:
- String literals with escaped quotes
- Character literals (#\", #\\, #\(, etc.)
- Line comments (;) and block comments (#| ... |#)
- Reader macros (#', #(, #., #+, #-, etc.)
- Nested parens
- defmacro, define-compiler-macro, defmacro!, defmacro/g!
"""
from __future__ import annotations
import re
from dataclasses import dataclass, field
from pathlib import Path
@dataclass
class DefmacroDef:
    """One macro-defining form found in Common Lisp source text.

    All textual fields hold raw source text; nothing is evaluated or
    normalised.
    """

    # Second top-level token of the form (the name being defined).
    macro_name: str
    # Text of the lambda list, including its parentheses.
    args: str
    # Everything that follows the lambda list inside the form.
    body: str
    # The complete form, outer parentheses included.
    full_form: str
    # Contents of a leading string literal in the body, if there is one.
    docstring: str | None = None
    # Line number the parser associates with the form; 0 means "not set".
    line_number: int = 0
    # File the form was read from; None when parsed from a bare string.
    source_file: Path | None = None
    # Defining operator, e.g. "defmacro" or "define-compiler-macro".
    form_type: str = "defmacro"
# Operators whose forms count as macro definitions: the two standard ones
# followed by the "defmacro!" / "defmacro/g!" variants.
MACRO_STARTERS = [
    "defmacro", "define-compiler-macro",
    "defmacro!", "defmacro/g!",
]
class DefmacroParser:
    """Extract macro-defining forms from Common Lisp source text.

    This is a hand-written character scanner, not a Lisp reader: it tracks
    parenthesis depth while stepping over string literals, character
    literals and comments so that parentheses inside them are not counted.

    Attributes:
        errors: Messages collected by the most recent ``extract_all`` call
            (unbalanced forms) or by a failed read in ``extract_file``.
    """

    _WHITESPACE = " \t\n\r"

    def __init__(self):
        self.errors: list[str] = []

    def extract_all(self, source_text: str) -> list[DefmacroDef]:
        """Find all macro definitions in source text.

        Resets ``self.errors``, then scans ``source_text`` from left to
        right. A form nested inside an already extracted form is not
        reported separately because scanning resumes after the outer form.

        Args:
            source_text: Common Lisp source code.

        Returns:
            One ``DefmacroDef`` per successfully parsed form, in source
            order, each carrying the 1-based line of its opening paren.
        """
        self.errors = []
        results = []
        cursor = 0
        # Line numbers are accumulated incrementally so the whole scan
        # counts each newline once instead of re-counting from offset 0.
        line = 1
        counted_to = 0
        while cursor < len(source_text):
            form_start = self._find_next_macro_form(source_text, cursor)
            if form_start is None:
                break
            line += source_text.count("\n", counted_to, form_start)
            counted_to = form_start
            form = self._extract_balanced(source_text, form_start)
            if form is None:
                # Record the problem instead of dropping the form silently,
                # then keep scanning just past the offending paren.
                self.errors.append(
                    f"Unbalanced macro form starting at line {line}"
                )
                cursor = form_start + 1
                continue
            parsed = self._parse_defmacro(form, form_start, line_number=line)
            if parsed is not None:
                results.append(parsed)
            cursor = form_start + len(form)
        return results

    def _skip_non_code(self, text: str, i: int) -> int | None:
        """Step over a string, character literal or comment starting at ``i``.

        Returns the index just past the construct, or ``None`` when
        ``text[i]`` does not start one of them.
        """
        ch = text[i]
        if ch == '"':
            return self._skip_string(text, i)
        if ch == ";":
            return self._skip_line_comment(text, i)
        if ch == "#" and i + 1 < len(text):
            nxt = text[i + 1]
            if nxt == "\\":
                return self._skip_char_literal(text, i)
            if nxt == "|":
                return self._skip_block_comment(text, i)
        return None

    def _find_next_macro_form(self, text: str, start: int) -> int | None:
        """Find the opening paren of the next defmacro form.

        Args:
            text: Source text to scan.
            start: Index at which scanning begins.

        Returns:
            Index of the ``(`` that opens a form whose operator is one of
            ``MACRO_STARTERS``, or ``None`` if there is none at or after
            ``start``.
        """
        end = len(text)
        i = start
        while i < end:
            jumped = self._skip_non_code(text, i)
            if jumped is not None:
                i = jumped
                continue
            if text[i] == "(":
                name_start = self._find_form_name(text, i + 1)
                if name_start is not None:
                    for starter in MACRO_STARTERS:
                        if not text.startswith(starter, name_start):
                            continue
                        # Require a word boundary so that "defmacro" does
                        # not match the prefix of e.g. "defmacro-helper".
                        after = name_start + len(starter)
                        if after >= end or text[after] in " \t\n\r(":
                            return i
            i += 1
        return None

    def _find_form_name(self, text: str, pos: int) -> int | None:
        """Skip whitespace after an opening paren.

        Returns the index of the first non-whitespace character at or after
        ``pos``, or ``None`` when only whitespace remains.
        """
        end = len(text)
        while pos < end and text[pos] in self._WHITESPACE:
            pos += 1
        return pos if pos < end else None

    def _extract_balanced(self, text: str, start: int) -> str | None:
        """Extract a balanced s-expression starting at ``start``.

        Args:
            text: Source text.
            start: Index of the opening ``(``.

        Returns:
            The full form including the outer parens, or ``None`` if
            ``start`` is not on a ``(`` or the form is never closed.
        """
        end = len(text)
        if start >= end or text[start] != "(":
            return None
        depth = 0
        i = start
        while i < end:
            jumped = self._skip_non_code(text, i)
            if jumped is not None:
                i = jumped
                continue
            # Dispatch macros such as #' #. #+ #- and #( need no special
            # handling: "#" is an ordinary character here, and the "(" of a
            # "#(" vector literal MUST be counted so that its closing paren
            # does not end the enclosing form early.
            ch = text[i]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    return text[start : i + 1]
            i += 1
        return None

    def _skip_string(self, text: str, start: int) -> int:
        """Skip past a string literal starting at the opening quote.

        Returns the index after the closing quote, or ``len(text)`` when
        the string is not terminated.
        """
        end = len(text)
        i = start + 1  # step over the opening quote
        while i < end:
            ch = text[i]
            if ch == "\\" and i + 1 < end:
                i += 2  # backslash escapes the next character
            elif ch == '"':
                return i + 1
            else:
                i += 1
        return i

    def _skip_char_literal(self, text: str, start: int) -> int:
        r"""Skip past a character literal ``#\X`` starting at ``start``.

        The character right after ``#\`` is always part of the literal,
        even when it is a parenthesis, quote, semicolon or whitespace
        (``#\(``, ``#\)``, ``#\"``). Alphabetic characters that follow are
        consumed too, covering names such as ``#\Space`` or ``#\Newline``.

        Returns:
            Index just past the literal.
        """
        end = len(text)
        i = start + 2  # step over "#\"
        if i >= end:
            return i
        i += 1  # the literal character itself, whatever it is
        while i < end and text[i].isalpha():
            i += 1
        return i

    def _skip_line_comment(self, text: str, start: int) -> int:
        """Skip past a semicolon comment to end of line.

        Returns the index of the terminating newline, or ``len(text)`` when
        the comment runs to the end of the text.
        """
        newline = text.find("\n", start + 1)
        return len(text) if newline == -1 else newline

    def _skip_block_comment(self, text: str, start: int) -> int:
        """Skip past a ``#| ... |#`` block comment, honouring nesting.

        Returns the index just past the matching ``|#``, or ``len(text)``
        when the comment is never closed (it then swallows the remainder).
        """
        depth = 1
        i = start + 2  # step over "#|"
        last = len(text) - 1
        while i < last:
            pair = text[i : i + 2]
            if pair == "|#":
                depth -= 1
                i += 2
                if depth == 0:
                    return i
            elif pair == "#|":
                depth += 1
                i += 2
            else:
                i += 1
        return len(text)

    def _parse_defmacro(
        self, form: str, pos: int, line_number: int | None = None
    ) -> DefmacroDef | None:
        """Parse a defmacro form into its components.

        Args:
            form: Text of one balanced form, outer parens included.
            pos: Offset of the form in the source text. Kept for interface
                compatibility; the line number cannot be derived from it
                without the source text.
            line_number: 1-based source line on which the form starts.
                Defaults to 1 when the caller does not know it.

        Returns:
            The parsed definition, or ``None`` when the form is not wrapped
            in parens, has no name, or its operator is not one of
            ``MACRO_STARTERS``.
        """
        inner = self._strip_outer_parens(form)
        if inner is None:
            return None
        parts = self._split_top_level(inner, 3)
        if len(parts) < 2:
            return None
        form_type = parts[0].strip()
        if form_type not in MACRO_STARTERS:
            return None
        macro_name = parts[1].strip()

        # A form with only an operator and a name gets an empty lambda list.
        args, body = "()", ""
        if len(parts) >= 3:
            remaining = parts[2].strip()
            args_end = self._find_args_end(remaining)
            if args_end >= 0:
                args = remaining[:args_end].strip()
                body = remaining[args_end:].strip()
            else:
                # Lambda list never closes: treat everything as arguments.
                args, body = remaining, ""

        return DefmacroDef(
            macro_name=macro_name,
            args=args,
            body=body,
            full_form=form,
            docstring=self._extract_docstring(body),
            line_number=1 if line_number is None else line_number,
            form_type=form_type,
        )

    def _strip_outer_parens(self, form: str) -> str | None:
        """Remove the outer parens from a balanced form.

        Returns the stripped interior, or ``None`` when ``form`` (after
        trimming whitespace) does not start with ``(`` and end with ``)``.
        """
        form = form.strip()
        if not (form.startswith("(") and form.endswith(")")):
            return None
        return form[1:-1].strip()

    def _split_top_level(self, text: str, n: int) -> list[str]:
        """Split text into at most ``n`` parts at top-level whitespace.

        Whitespace inside parentheses, strings, character literals and
        comments does not split. Once ``n - 1`` parts have been collected
        the untouched remainder becomes the final part.
        """
        parts = []
        depth = 0
        part_start = 0
        i = 0
        end = len(text)
        while i < end and len(parts) < n - 1:
            jumped = self._skip_non_code(text, i)
            if jumped is not None:
                i = jumped
                continue
            ch = text[i]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif depth == 0 and ch in self._WHITESPACE:
                piece = text[part_start:i].strip()
                if piece:
                    parts.append(piece)
                part_start = i + 1
            i += 1
        if part_start < end and len(parts) < n:
            parts.append(text[part_start:].strip())
        return parts

    def _find_args_end(self, text: str) -> int:
        """Find the end of the lambda list (args).

        Returns:
            Index just after the closing paren of the leading
            parenthesised list, measured in ``text`` with leading
            whitespace removed; ``0`` when the text does not start with
            ``(``; ``-1`` when the list is never closed.
        """
        text = text.lstrip()
        if not text.startswith("("):
            return 0
        depth = 0
        i = 0
        end = len(text)
        while i < end:
            jumped = self._skip_non_code(text, i)
            if jumped is not None:
                i = jumped
                continue
            ch = text[i]
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
                if depth == 0:
                    return i + 1
            i += 1
        return -1

    def _extract_docstring(self, body: str) -> str | None:
        """Extract the docstring if the body starts with a string literal.

        Returns the raw text between the quotes (escape sequences are left
        as written), or ``None`` when the body does not begin with a
        terminated string literal.
        """
        body = body.strip()
        if not body.startswith('"'):
            return None
        i = 1
        end = len(body)
        while i < end:
            ch = body[i]
            if ch == "\\":
                # Skip the escaped character as a pair; looking only at the
                # single preceding character would misread a string that
                # ends with an escaped backslash.
                i += 2
                continue
            if ch == '"':
                return body[1:i]
            i += 1
        return None

    def extract_file(self, filepath: Path) -> list[DefmacroDef]:
        """Extract all macro definitions from a .lisp file.

        The file is decoded as UTF-8 with undecodable bytes replaced, so
        the result does not depend on the platform's locale encoding.

        Args:
            filepath: File to read.

        Returns:
            The definitions found, each with ``source_file`` set to
            ``filepath``; an empty list (plus an entry in ``self.errors``)
            when the file cannot be read.
        """
        try:
            text = filepath.read_text(encoding="utf-8", errors="replace")
        except (OSError, UnicodeDecodeError) as e:
            self.errors.append(f"Failed to read {filepath}: {e}")
            return []
        results = self.extract_all(text)
        for definition in results:
            definition.source_file = filepath
        return results