File size: 11,359 Bytes
d69fc90
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
"""Extract defmacro definitions from Common Lisp source files.

Handles:
- String literals with escaped quotes
- Character literals (#\", #\\, #\(, etc.)
- Line comments (;) and block comments (#| ... |#)
- Reader macros (#', #(, #., #+, #-, etc.)
- Nested parens
- defmacro, define-compiler-macro, defmacro!, defmacro/g!
"""

from __future__ import annotations

import re
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class DefmacroDef:
    """One macro-defining form extracted from Common Lisp source text.

    Instances are built by ``DefmacroParser._parse_defmacro``; all textual
    fields hold raw source text exactly as it appeared in the file.
    """

    # Second top-level token of the form (the name being defined).
    macro_name: str
    # Text of the lambda list, including its parentheses; "()" when the form
    # has nothing after the name.
    args: str
    # Everything after the lambda list, with surrounding whitespace stripped.
    body: str
    # The complete balanced form, including the outer parentheses.
    full_form: str
    # Leading string literal of the body without its quotes, or None.
    docstring: str | None = None
    # Line number recorded by the parser; stays 0 if never set.
    line_number: int = 0
    # Filled in by DefmacroParser.extract_file(); None for in-memory text.
    source_file: Path | None = None
    # Which defining operator introduced the form (one of MACRO_STARTERS).
    form_type: str = "defmacro"


# Operator names that mark a form as a macro definition. The parser matches
# them case-sensitively as a prefix of the first token after "(" and then
# requires whitespace or "(" right after the match, so "defmacro" does not
# swallow the longer "defmacro!" / "defmacro/g!" spellings.
MACRO_STARTERS = [
    "defmacro",
    "define-compiler-macro",
    "defmacro!",
    "defmacro/g!",
]


class DefmacroParser:
    """Hand-written scanner that extracts macro definitions from Lisp source.

    The scanner never builds a syntax tree. It walks the text character by
    character, stepping over string literals, character literals and comments
    so that parentheses inside them are not mistaken for structure, and it
    counts the remaining parentheses to find balanced forms.

    Attributes:
        errors: Human-readable messages for files that could not be read.
            The list is reset at the start of every ``extract_all`` call.
    """

    def __init__(self):
        self.errors: list[str] = []

    def extract_all(self, source_text: str) -> list[DefmacroDef]:
        """Find all macro definitions in source text.

        Args:
            source_text: Complete text of a Lisp source file.

        Returns:
            One ``DefmacroDef`` per balanced macro-defining form, in source
            order. Forms whose parentheses never balance are skipped.
        """
        self.errors = []
        results = []
        i = 0
        # Line numbers are accumulated incrementally: form starts only move
        # forward, so each newline in the source is counted at most once.
        line = 1
        counted_upto = 0
        while i < len(source_text):
            form_start = self._find_next_macro_form(source_text, i)
            if form_start is None:
                break
            form = self._extract_balanced(source_text, form_start)
            if form:
                line += source_text.count('\n', counted_upto, form_start)
                counted_upto = form_start
                parsed = self._parse_defmacro(form, form_start, line_number=line)
                if parsed:
                    results.append(parsed)
                i = form_start + len(form)
            else:
                # Unbalanced form: step past its "(" and keep looking.
                i = form_start + 1
        return results

    def _skip_non_code(self, text: str, i: int) -> int | None:
        """Step over a literal or comment that starts at index *i*.

        Recognises string literals, ``#\\`` character literals, ``;`` line
        comments and ``#| ... |#`` block comments.

        Returns:
            The index at which scanning should resume (always greater than
            *i*), or ``None`` when ``text[i]`` starts none of these.
        """
        ch = text[i]
        if ch == '"':
            return self._skip_string(text, i)
        if ch == ';':
            return self._skip_line_comment(text, i)
        if ch == '#' and i + 1 < len(text):
            nxt = text[i + 1]
            if nxt == '\\':
                return self._skip_char_literal(text, i)
            if nxt == '|':
                return self._skip_block_comment(text, i)
        return None

    def _find_next_macro_form(self, text: str, start: int) -> int | None:
        """Find the opening paren of the next macro-defining form.

        Args:
            text: Source text to scan.
            start: Index at which to begin scanning.

        Returns:
            Index of the ``(`` that opens the next form whose first token is
            one of ``MACRO_STARTERS``, or ``None`` if there is none.
        """
        i = start
        n = len(text)
        while i < n:
            resume = self._skip_non_code(text, i)
            if resume is not None:
                i = resume
                continue

            if text[i] == '(':
                name_start = self._find_form_name(text, i + 1)
                if name_start is not None:
                    for starter in MACRO_STARTERS:
                        if text.startswith(starter, name_start):
                            # Require a token boundary so "defmacro" does not
                            # match inside "defmacro!" or "defmacro-foo".
                            after = name_start + len(starter)
                            if after >= n or text[after] in ' \t\n\r(':
                                return i
            i += 1
        return None

    def _find_form_name(self, text: str, pos: int) -> int | None:
        """Return the index of the first non-whitespace char at or after *pos*.

        Returns ``None`` when only whitespace (or nothing) remains.
        """
        n = len(text)
        while pos < n and text[pos] in ' \t\n\r':
            pos += 1
        return pos if pos < n else None

    def _extract_balanced(self, text: str, start: int) -> str | None:
        """Extract the balanced s-expression that opens at ``text[start]``.

        Returns:
            The full form including its outer parens, or ``None`` when
            ``text[start]`` is not ``(`` or the form never closes.
        """
        if start >= len(text) or text[start] != '(':
            return None

        depth = 0
        i = start
        n = len(text)
        while i < n:
            resume = self._skip_non_code(text, i)
            if resume is not None:
                i = resume
                continue

            # Dispatch macros such as #' #( #. #+ #- need no special casing:
            # the "#" is an ordinary character here, and the "(" of a #(
            # vector must be counted so that its ")" balances.
            ch = text[i]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    return text[start : i + 1]
            i += 1

        return None

    def _skip_string(self, text: str, start: int) -> int:
        """Skip a string literal whose opening quote is at *start*.

        Returns the index after the closing quote, or ``len(text)`` when the
        string is never closed. A backslash escapes the character after it.
        """
        i = start + 1
        n = len(text)
        while i < n:
            ch = text[i]
            if ch == '\\' and i + 1 < n:
                i += 2
                continue
            if ch == '"':
                return i + 1
            i += 1
        return i

    def _skip_char_literal(self, text: str, start: int) -> int:
        """Skip a character literal ``#\\X`` that starts at *start*.

        The character right after ``#\\`` always belongs to the literal, even
        when it is a paren, a quote, a semicolon or whitespace. If that
        character is not whitespace, any alphabetic characters that follow
        are consumed too, which covers names such as ``#\\Space``.

        Returns:
            The index after the literal.
        """
        i = start + 2  # step over "#" and the backslash
        if i >= len(text):
            return i
        first = text[i]
        i += 1
        if first in ' \t\n\r':
            return i
        while i < len(text) and text[i].isalpha():
            i += 1
        return i

    def _skip_line_comment(self, text: str, start: int) -> int:
        """Skip a ``;`` comment; returns the index of the newline or EOF."""
        newline = text.find('\n', start + 1)
        return len(text) if newline == -1 else newline

    def _skip_block_comment(self, text: str, start: int) -> int:
        """Skip a ``#| ... |#`` block comment, honouring nesting.

        Returns the index after the matching ``|#``, or ``len(text)`` when
        the comment is never closed, so that no part of an unterminated
        comment is handed back to the caller as code.
        """
        depth = 1
        i = start + 2  # step over the opening "#|"
        last = len(text) - 1
        while i < last:
            ch, nxt = text[i], text[i + 1]
            if ch == '|' and nxt == '#':
                depth -= 1
                i += 2
                if depth == 0:
                    return i
                continue
            if ch == '#' and nxt == '|':
                depth += 1
                i += 2
                continue
            i += 1
        return len(text)

    def _parse_defmacro(
        self, form: str, pos: int, line_number: int | None = None
    ) -> DefmacroDef | None:
        """Parse one balanced macro-defining form into its components.

        Args:
            form: The balanced form, including its outer parens.
            pos: Offset of the form within the text it was taken from.
            line_number: 1-based line on which the form starts. When omitted,
                a legacy estimate computed from *form* and *pos* is used; it
                is only right when *form* is itself the text *pos* indexes.

        Returns:
            A ``DefmacroDef``, or ``None`` when the form is not wrapped in
            parens, has fewer than two tokens, or does not start with one of
            ``MACRO_STARTERS``.
        """
        if line_number is None:
            line_number = form[:pos].count('\n') + 1 if pos > 0 else 1

        inner = self._strip_outer_parens(form)
        if inner is None:
            return None

        # Split into: defining operator, macro name, everything else.
        parts = self._split_top_level(inner, 3)
        if len(parts) < 2:
            return None

        form_type = parts[0].strip()
        if form_type not in MACRO_STARTERS:
            return None

        macro_name = parts[1].strip()
        if len(parts) >= 3:
            remaining = parts[2].strip()
            args_end = self._find_args_end(remaining)
            if args_end >= 0:
                args = remaining[:args_end].strip()
                body = remaining[args_end:].strip()
            else:
                # Lambda list never closes: treat all of it as the args.
                args = remaining
                body = ""
        else:
            args = "()"
            body = ""

        return DefmacroDef(
            macro_name=macro_name,
            args=args,
            body=body,
            full_form=form,
            docstring=self._extract_docstring(body),
            line_number=line_number,
            form_type=form_type,
        )

    def _strip_outer_parens(self, form: str) -> str | None:
        """Return *form* without its outer parens, or ``None`` if it has none."""
        form = form.strip()
        if not form or form[0] != '(' or form[-1] != ')':
            return None
        return form[1:-1].strip()

    def _split_top_level(self, text: str, n: int) -> list[str]:
        """Split *text* into at most *n* parts at top-level whitespace.

        Whitespace inside parentheses, strings, character literals and
        comments does not split. The last part receives the unsplit rest.
        """
        parts = []
        depth = 0
        start = 0
        i = 0
        while i < len(text) and len(parts) < n - 1:
            resume = self._skip_non_code(text, i)
            if resume is not None:
                i = resume
                continue
            ch = text[i]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
            elif depth == 0 and ch in ' \t\n\r':
                part = text[start:i].strip()
                if part:
                    parts.append(part)
                    start = i + 1
            i += 1

        if start < len(text) and len(parts) < n:
            parts.append(text[start:].strip())

        return parts

    def _find_args_end(self, text: str) -> int:
        """Find the end of the lambda list at the start of *text*.

        Leading whitespace is stripped first and the result indexes into the
        stripped text.

        Returns:
            The index after the closing paren of the lambda list, ``0`` when
            the text does not start with ``(``, or ``-1`` when the list never
            closes.
        """
        text = text.lstrip()
        if not text or text[0] != '(':
            return 0

        depth = 0
        i = 0
        while i < len(text):
            resume = self._skip_non_code(text, i)
            if resume is not None:
                i = resume
                continue
            ch = text[i]
            if ch == '(':
                depth += 1
            elif ch == ')':
                depth -= 1
                if depth == 0:
                    return i + 1
            i += 1
        return -1

    def _extract_docstring(self, body: str) -> str | None:
        """Extract the docstring if the body starts with a string literal.

        Returns:
            The raw text between the quotes (escape sequences are left as
            written), or ``None`` when the body does not start with a string
            or the string is never closed.
        """
        body = body.strip()
        if not body.startswith('"'):
            return None
        i = 1
        while i < len(body):
            ch = body[i]
            if ch == '\\':
                # A backslash escapes exactly one character, so "\\\\" is an
                # escaped backslash and does not escape a quote after it.
                i += 2
                continue
            if ch == '"':
                return body[1:i]
            i += 1
        return None

    def extract_file(self, filepath: Path) -> list[DefmacroDef]:
        """Extract all macro definitions from a .lisp file.

        The file is read as UTF-8 with undecodable bytes replaced, so the
        result does not depend on the platform's default encoding.

        Args:
            filepath: Path of the Lisp source file.

        Returns:
            The definitions found, each with ``source_file`` set to
            *filepath*. On a read failure a message is appended to
            ``self.errors`` and an empty list is returned.
        """
        try:
            text = filepath.read_text(encoding="utf-8", errors="replace")
        except (OSError, UnicodeDecodeError) as e:
            self.errors.append(f"Failed to read {filepath}: {e}")
            return []

        results = self.extract_all(text)
        for r in results:
            r.source_file = filepath
        return results