r"""Extract defmacro definitions from Common Lisp source files.

Handles:
- String literals with escaped quotes
- Character literals (#\", #\\, #\(, #\Space, etc.)
- Line comments (;) and nested block comments (#| ... |#)
- Reader macros (#', #(, #., #+, #-, etc.)
- Nested parens
- defmacro, define-compiler-macro, defmacro!, defmacro/g!
  (matched case-insensitively, since the Lisp reader upcases symbols)
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class DefmacroDef:
    """One macro definition extracted from Lisp source."""

    macro_name: str  # name exactly as written in the source
    args: str  # lambda-list text as written, e.g. "(a &body b)"
    body: str  # text after the lambda list, including any docstring
    full_form: str  # the complete form, outer parentheses included
    docstring: str | None = None  # unescaped docstring, if present
    line_number: int = 0  # 1-based line of the opening paren
    source_file: Path | None = None  # set by DefmacroParser.extract_file
    form_type: str = "defmacro"  # matching MACRO_STARTERS entry (lowercase)


MACRO_STARTERS = [
    "defmacro",
    "define-compiler-macro",
    "defmacro!",
    "defmacro/g!",
]

# Characters the Lisp reader treats as whitespace between tokens.
_WHITESPACE = " \t\n\r\f"
# Characters that end a token: whitespace plus terminating macro characters.
_TOKEN_TERMINATORS = _WHITESPACE + "()\"';`,"


class DefmacroParser:
    """Scan Common Lisp source text for macro-defining forms.

    ``errors`` holds human-readable problems found by the most recent
    ``extract_all`` / ``extract_file`` call (unreadable file, unbalanced or
    malformed macro forms).
    """

    def __init__(self):
        self.errors: list[str] = []

    def extract_all(self, source_text: str) -> list[DefmacroDef]:
        """Find all macro definitions in *source_text*, in source order.

        A macro form nested inside an already-extracted macro form (e.g. a
        macro-defining macro) belongs to the outer form and is not reported
        separately. Resets ``self.errors`` and records unbalanced or
        malformed forms there.
        """
        self.errors = []
        results: list[DefmacroDef] = []
        line_number = 1  # 1-based line of the most recent form start
        counted_upto = 0  # newlines before this offset are already counted
        i = 0
        while i < len(source_text):
            form_start = self._find_next_macro_form(source_text, i)
            if form_start is None:
                break
            # Count newlines incrementally so the whole scan stays linear.
            line_number += source_text.count("\n", counted_upto, form_start)
            counted_upto = form_start
            form = self._extract_balanced(source_text, form_start)
            if form is None:
                self.errors.append(f"Unbalanced macro form at line {line_number}")
                # Resume just inside it so balanced nested forms are still found.
                i = form_start + 1
                continue
            parsed = self._parse_defmacro(form, form_start, line_number)
            if parsed is None:
                self.errors.append(f"Malformed macro form at line {line_number}")
            else:
                results.append(parsed)
            i = form_start + len(form)
        return results

    def _find_next_macro_form(self, text: str, start: int) -> int | None:
        """Return the index of the '(' opening the next macro form at or
        after *start*, skipping strings, character literals and comments.
        Returns None when there is no further macro form."""
        i = start
        n = len(text)
        while i < n:
            ch = text[i]
            nxt = text[i + 1] if i + 1 < n else ""
            if ch == '"':
                i = self._skip_string(text, i)
            elif ch == "#" and nxt == "\\":
                i = self._skip_char_literal(text, i)
            elif ch == ";":
                i = self._skip_line_comment(text, i)
            elif ch == "#" and nxt == "|":
                i = self._skip_block_comment(text, i)
            elif ch == "(" and self._macro_starter_at(text, i + 1) is not None:
                return i
            else:
                i += 1
        return None

    def _macro_starter_at(self, text: str, pos: int) -> str | None:
        """Return the MACRO_STARTERS entry naming the operator that begins at
        or after *pos* (leading whitespace skipped), or None.

        Matching is case-insensitive and requires a word boundary after the
        name, so ``defmacro`` does not match ``defmacro-foo`` and the
        ``defmacro!`` / ``defmacro/g!`` variants are told apart.
        """
        name_start = self._find_form_name(text, pos)
        if name_start is None:
            return None
        for starter in MACRO_STARTERS:
            end = name_start + len(starter)
            if text[name_start:end].lower() != starter:
                continue
            if end >= len(text) or text[end] in _WHITESPACE or text[end] == "(":
                return starter
        return None

    def _find_form_name(self, text: str, pos: int) -> int | None:
        """Skip whitespace from *pos*; return the index of the first
        non-whitespace character, or None if the text ends first."""
        while pos < len(text) and text[pos] in _WHITESPACE:
            pos += 1
        return pos if pos < len(text) else None

    def _extract_balanced(self, text: str, start: int) -> str | None:
        """Extract a balanced s-expression starting at 'start'.

        Returns the full form including the outer parens, or None if
        unbalanced. Parens inside strings, character literals (``#\\(``)
        and comments are ignored. Dispatch macros such as ``#(``, ``#'(``
        and ``#+(`` need no special case: the ``#`` and its sub-character
        are plain characters here, and the following '(' is counted.
        """
        if start >= len(text) or text[start] != "(":
            return None
        depth = 0
        i = start
        n = len(text)
        while i < n:
            ch = text[i]
            nxt = text[i + 1] if i + 1 < n else ""
            if ch == '"':
                i = self._skip_string(text, i)
            elif ch == "#" and nxt == "\\":
                i = self._skip_char_literal(text, i)
            elif ch == ";":
                i = self._skip_line_comment(text, i)
            elif ch == "#" and nxt == "|":
                i = self._skip_block_comment(text, i)
            else:
                if ch == "(":
                    depth += 1
                elif ch == ")":
                    depth -= 1
                    if depth == 0:
                        return text[start : i + 1]
                i += 1
        return None

    def _skip_string(self, text: str, start: int) -> int:
        """Skip past a string literal starting at the opening quote.
        Returns the index after the closing quote, or len(text) if the
        string is unterminated."""
        i = start + 1  # skip opening quote
        while i < len(text):
            if text[i] == "\\" and i + 1 < len(text):
                i += 2  # a backslash escapes the next character
                continue
            if text[i] == '"':
                return i + 1
            i += 1
        return i

    def _skip_char_literal(self, text: str, start: int) -> int:
        """Skip past a character literal ``#\\X`` starting at *start*.

        The character right after ``#\\`` is always part of the literal, even
        when it is a parenthesis, quote, semicolon or whitespace (``#\\(`` is
        the '(' character). Following constituent characters belong to a
        character name such as ``#\\Space`` or ``#\\Newline``.
        Returns the index just past the literal.
        """
        i = start + 2  # skip the '#' and the backslash
        if i >= len(text):
            return len(text)
        i += 1  # the (first) character of the literal, whatever it is
        while i < len(text) and text[i] not in _TOKEN_TERMINATORS:
            i += 1
        return i

    def _skip_line_comment(self, text: str, start: int) -> int:
        """Skip a ';' comment. Returns the index of the terminating newline
        (left unconsumed) or len(text) if the comment runs to the end."""
        newline = text.find("\n", start + 1)
        return len(text) if newline == -1 else newline

    def _skip_block_comment(self, text: str, start: int) -> int:
        """Skip past a #| ... |# block comment. Handles nesting.
        Returns the index after the closing '|#', or len(text) if the
        comment is unterminated."""
        depth = 1
        i = start + 2  # skip #|
        n = len(text)
        while i < n - 1:
            pair = text[i : i + 2]
            if pair == "|#":
                depth -= 1
                i += 2
                if depth == 0:
                    return i
            elif pair == "#|":
                depth += 1
                i += 2
            else:
                i += 1
        return n

    def _skip_whitespace_and_comments(self, text: str, pos: int) -> int:
        """Return the index of the first character at or after *pos* that is
        neither whitespace nor inside a comment (len(text) if none)."""
        n = len(text)
        while pos < n:
            ch = text[pos]
            if ch in _WHITESPACE:
                pos += 1
            elif ch == ";":
                pos = self._skip_line_comment(text, pos)
            elif ch == "#" and pos + 1 < n and text[pos + 1] == "|":
                pos = self._skip_block_comment(text, pos)
            else:
                break
        return pos

    def _parse_defmacro(
        self, form: str, pos: int, line_number: int | None = None
    ) -> DefmacroDef | None:
        """Parse a balanced macro form into a DefmacroDef.

        Args:
            form: The full form text, including its outer parentheses.
            pos: Offset of *form* in the source text. Kept for backward
                compatibility; the form text alone cannot locate its line,
                so callers should pass *line_number*.
            line_number: 1-based source line of the form (defaults to 1).

        Returns:
            The parsed definition, or None if *form* is not a recognized
            macro definition.
        """
        if line_number is None:
            line_number = 1
        inner = self._strip_outer_parens(form)
        if inner is None:
            return None
        parts = self._split_top_level(inner, 3)
        if len(parts) < 2:
            return None
        # The default reader upcases symbols, so DEFMACRO means defmacro.
        form_type = parts[0].lower()
        if form_type not in MACRO_STARTERS:
            return None
        macro_name = parts[1]
        if len(parts) >= 3:
            remaining = parts[2]
            # Drop comments sitting between the name and the lambda list.
            remaining = remaining[self._skip_whitespace_and_comments(remaining, 0):]
            args_end = self._find_args_end(remaining)
            if args_end >= 0:
                args = remaining[:args_end].strip()
                body = remaining[args_end:].strip()
            else:
                # Unbalanced lambda list: keep everything as args.
                args = remaining
                body = ""
        else:
            args = "()"
            body = ""
        return DefmacroDef(
            macro_name=macro_name,
            args=args,
            body=body,
            full_form=form,
            docstring=self._extract_docstring(body),
            line_number=line_number,
            form_type=form_type,
        )

    def _strip_outer_parens(self, form: str) -> str | None:
        """Remove outer parens from a balanced form; None if *form* is not
        wrapped in parentheses."""
        form = form.strip()
        if not form or form[0] != "(" or form[-1] != ")":
            return None
        return form[1:-1].strip()

    def _split_top_level(self, text: str, n: int) -> list[str]:
        """Split *text* into at most *n* stripped parts at top-level whitespace.

        The first n-1 parts are single top-level tokens or lists; top-level
        comments between them act as separators and are not parts. The last
        part is the unsplit remainder of the text.
        """
        parts: list[str] = []
        depth = 0
        start = 0
        i = 0
        length = len(text)
        while i < length and len(parts) < n - 1:
            ch = text[i]
            nxt = text[i + 1] if i + 1 < length else ""
            if ch == '"':
                i = self._skip_string(text, i)
                continue
            if ch == "#" and nxt == "\\":
                i = self._skip_char_literal(text, i)
                continue
            if ch == ";" or (ch == "#" and nxt == "|"):
                if ch == ";":
                    comment_end = self._skip_line_comment(text, i)
                else:
                    comment_end = self._skip_block_comment(text, i)
                if depth == 0:
                    # A top-level comment ends the current part.
                    piece = text[start:i].strip()
                    if piece:
                        parts.append(piece)
                    start = comment_end
                i = comment_end
                continue
            if ch == "(":
                depth += 1
            elif ch == ")":
                depth -= 1
            elif depth == 0 and ch in _WHITESPACE:
                piece = text[start:i].strip()
                if piece:
                    parts.append(piece)
                start = i + 1
            i += 1
        if len(parts) < n:
            tail = text[start:].strip()
            if tail:
                parts.append(tail)
        return parts

    def _find_args_end(self, text: str) -> int:
        """Return the index just past the lambda list at the start of *text*.

        Leading whitespace and comments are skipped. A parenthesized list
        ends at its matching ')'; an atom such as ``nil`` ends at the next
        delimiter. Returns 0 if no lambda list is present and -1 if the
        list is unbalanced.
        """
        n = len(text)
        i = self._skip_whitespace_and_comments(text, 0)
        if i >= n:
            return 0
        if text[i] != "(":
            end = i
            while end < n and text[end] not in _TOKEN_TERMINATORS:
                end += 1
            return end if end > i else 0
        lambda_list = self._extract_balanced(text, i)
        return -1 if lambda_list is None else i + len(lambda_list)

    def _extract_docstring(self, body: str) -> str | None:
        """Return the docstring if *body* starts with a string literal that
        is followed by at least one more form; otherwise None.

        A string that is the last form of the body is the macro's return
        value, not documentation (CLHS 3.4.11). Lisp string escapes (a
        backslash quotes the next character) are removed from the result.
        """
        n = len(body)
        start = self._skip_whitespace_and_comments(body, 0)
        if start >= n or body[start] != '"':
            return None
        chars: list[str] = []
        i = start + 1
        while i < n:
            ch = body[i]
            if ch == "\\" and i + 1 < n:
                chars.append(body[i + 1])
                i += 2
            elif ch == '"':
                break
            else:
                chars.append(ch)
                i += 1
        else:
            return None  # unterminated string literal
        if self._skip_whitespace_and_comments(body, i + 1) >= n:
            return None  # the string is the body's last form
        return "".join(chars)

    def extract_file(self, filepath: Path | str) -> list[DefmacroDef]:
        """Extract all macro definitions from a .lisp file.

        The file is decoded as UTF-8; undecodable bytes are replaced rather
        than aborting. On a read failure ``self.errors`` holds the single
        failure message and an empty list is returned.
        """
        path = Path(filepath)
        try:
            text = path.read_text(encoding="utf-8", errors="replace")
        except (OSError, UnicodeDecodeError) as e:
            self.errors = [f"Failed to read {path}: {e}"]
            return []
        results = self.extract_all(text)
        for r in results:
            r.source_file = path
        return results