File size: 6,829 Bytes
acf77ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
from __future__ import annotations

import ast
import importlib
import importlib.util
import logging
from typing import Literal

from pydantic import BaseModel, ConfigDict

_log = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# Models
# ---------------------------------------------------------------------------


class Symbol(BaseModel):
    """A single symbol extracted from source code by AST walking.

    Instances are immutable (``frozen=True``); ``ground`` builds one per
    import name or attribute access it examines.
    """

    model_config = ConfigDict(frozen=True)
    # Dotted module path the symbol was checked against (e.g. ``os.path``).
    module: str
    # Name imported via ``from m import name`` or the accessed attribute name;
    # ``None`` for a plain ``import m``.
    attr: str | None
    # "import" for names introduced by an import statement, "attribute" for an
    # attribute access rooted at an imported module name.
    kind: Literal["import", "attribute"]
    # True when the symbol was judged to exist (see ``ground`` for how
    # ``local_modules`` are treated).
    resolved: bool
    # Line number of the originating AST node (``node.lineno``).
    line: int


class GroundingReport(BaseModel):
    """Result of grounding analysis on source code.

    Produced by ``ground``; immutable (``frozen=True``).
    """

    model_config = ConfigDict(frozen=True)
    # Number of symbols examined (``len(grounded) + len(ungrounded)``).
    total_symbols: int
    # Symbols whose ``resolved`` flag is True.
    grounded: tuple[Symbol, ...]
    # Symbols whose ``resolved`` flag is False.
    ungrounded: tuple[Symbol, ...]
    # ``ground`` sets this to grounded / total, to 0.5 when no symbols were
    # found, and to 0.0 when the source could not be parsed.
    groundedness: float


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _module_spec(name: str) -> bool:
    """Return True if the module can be found by the import system."""
    try:
        return importlib.util.find_spec(name) is not None
    except (ImportError, ValueError, ModuleNotFoundError):
        return False


def _has_attr(module_name: str, attr: str) -> bool:
    """Check if *module_name* exposes *attr*.

    Uses the FULL module path (e.g. ``os.path``) β€” not just
    the top-level package.  This is the fix for SYSTEM_DESIGN Β§4.8.3
    bug #3.
    """
    try:
        mod = importlib.import_module(module_name)
    except Exception:
        return False
    return hasattr(mod, attr)


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def ground(
    source: str,
    *,
    local_modules: frozenset[str] = frozenset(),
) -> GroundingReport:
    """AST-parse *source*, check every import and attribute access resolves.

    Three fixes baked in from day one (SYSTEM_DESIGN §4.8.3):
    1. SyntaxError → groundedness=0.0  (was 1.0)
    2. Zero symbols → groundedness=0.5 (was 1.0)
    3. Attribute resolution against full module path (was top-level only)

    *local_modules*: set of module names (e.g. ``{"core", "main"}``) that are
    local to the agent's project and should be treated as grounded even though
    ``importlib.util.find_spec`` cannot resolve them from the grader process.
    This applies both to imports of those modules and to attribute accesses
    on names bound to them.

    Symbols collected:

    * ``import a.b [as c]`` — one ``kind="import"`` symbol for ``a.b``.
    * ``from m import x`` — one ``kind="import"`` symbol per name; *x* may be
      an attribute of *m* or a submodule ``m.x``.  ``from m import *`` is
      grounded when *m* itself resolves.  Relative imports are skipped.
    * ``name.attr`` / ``name.x.y.attr`` where *name* was bound by a plain
      ``import`` — one ``kind="attribute"`` symbol per ``ast.Attribute`` node,
      checked against the deepest importable module along the chain and then
      the remaining attribute links.

    Resolving symbols imports the referenced modules into this process.

    Returns:
        A ``GroundingReport`` whose ``groundedness`` is grounded / total,
        0.5 when there are no symbols, and 0.0 when *source* cannot be parsed.
    """
    # ----- parse --------------------------------------------------------
    try:
        tree = ast.parse(source)
    except (SyntaxError, ValueError) as exc:
        # FIX 1: unparseable code → 0.0, not 1.0.  ValueError is what
        # ast.parse raises for source containing NUL bytes before 3.12.
        _log.debug("ground: source failed to parse: %s", exc)
        return GroundingReport(
            total_symbols=0,
            grounded=(),
            ungrounded=(),
            groundedness=0.0,
        )

    symbols: list[Symbol] = []
    # Name bound in the module namespace by ``import`` → dotted module path
    # that name refers to at runtime.
    bound_modules: dict[str, str] = {}

    # ----- walk imports -------------------------------------------------
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                top = alias.name.split(".")[0]
                # Local modules are always treated as grounded
                resolved = top in local_modules or _module_spec(alias.name)
                symbols.append(
                    Symbol(
                        module=alias.name,
                        attr=None,
                        kind="import",
                        resolved=resolved,
                        line=node.lineno,
                    )
                )
                # ``import a.b`` binds ``a`` to the top-level package ``a``;
                # only ``import a.b as c`` binds ``c`` to ``a.b`` itself.
                if alias.asname:
                    bound_modules[alias.asname] = alias.name
                else:
                    bound_modules[top] = top
        elif isinstance(node, ast.ImportFrom):
            symbols.extend(_from_import_symbols(node, local_modules))

    # ----- walk attribute accesses --------------------------------------
    for node in ast.walk(tree):
        if not isinstance(node, ast.Attribute):
            continue
        split = _split_attribute(node)
        if split is None:
            continue
        base, chain = split
        mod_name = bound_modules.get(base)
        if mod_name is None:
            continue

        if mod_name.split(".")[0] in local_modules:
            # The grader cannot import project-local modules; trust accesses
            # on them exactly as their imports are trusted above.
            check_mod, resolved = mod_name, True
        else:
            # FIX 3: resolve against full module path, not just top-level
            check_mod, resolved = _resolve_attribute(mod_name, chain, node.attr)
        symbols.append(
            Symbol(
                module=check_mod,
                attr=node.attr,
                kind="attribute",
                resolved=resolved,
                line=node.lineno,
            )
        )

    # ----- compute groundedness -----------------------------------------
    grounded = tuple(s for s in symbols if s.resolved)
    ungrounded = tuple(s for s in symbols if not s.resolved)
    total = len(symbols)

    # FIX 2: zero symbols → 0.5 (neutral), not 1.0
    groundedness = 0.5 if total == 0 else len(grounded) / total

    return GroundingReport(
        total_symbols=total,
        grounded=grounded,
        ungrounded=ungrounded,
        groundedness=groundedness,
    )


# ----- ground() internals ----------------------------------------------------


def _from_import_symbols(
    node: ast.ImportFrom,
    local_modules: frozenset[str],
) -> list[Symbol]:
    """Build one ``kind="import"`` symbol per name in ``from m import ...``."""
    # Relative imports (``from . import x``) depend on the agent's package
    # layout, which the grader cannot see, so they are skipped.
    if node.level != 0 or node.module is None:
        return []
    module = node.module
    is_local = module.split(".")[0] in local_modules
    module_found = is_local or _module_spec(module)

    out: list[Symbol] = []
    for alias in node.names:
        if not module_found:
            resolved = False
        elif is_local or alias.name == "*":
            # Local modules are trusted; a star import names nothing specific,
            # so it is grounded exactly when the module itself is.
            resolved = True
        else:
            # ``from pkg import sub`` may name a submodule that is not yet an
            # attribute of ``pkg`` (submodules are set only once imported).
            resolved = _has_attr(module, alias.name) or _module_spec(
                f"{module}.{alias.name}"
            )
        out.append(
            Symbol(
                module=module,
                attr=alias.name,
                kind="import",
                resolved=resolved,
                line=node.lineno,
            )
        )
    return out


def _split_attribute(node: ast.Attribute) -> tuple[str, list[str]] | None:
    """Split ``a.b.c.attr`` into base name ``"a"`` and chain ``["b", "c"]``.

    Returns None when the chain is not rooted at a plain name
    (e.g. ``f().x`` or ``"s".join``).
    """
    chain: list[str] = []
    cursor: ast.expr = node.value
    while isinstance(cursor, ast.Attribute):
        chain.append(cursor.attr)
        cursor = cursor.value
    if not isinstance(cursor, ast.Name):
        return None
    # Links were collected from the access end back toward the base;
    # reverse to get top-down order.
    chain.reverse()
    return cursor.id, chain


def _resolve_attribute(
    mod_name: str,
    chain: list[str],
    attr: str,
) -> tuple[str, bool]:
    """Check whether ``<mod_name>.<chain...>.<attr>`` exists.

    The longest prefix of *chain* that is importable as a submodule of
    *mod_name* becomes the reported module (``os.path.join`` is checked
    against ``os.path``).  Any remaining links are followed with ``getattr``
    (``os.environ.get`` is checked on the ``os.environ`` object, not on
    ``os``).

    Returns:
        ``(module, resolved)`` where *module* is the deepest importable
        module found along the chain.
    """
    depth = len(chain)
    while depth and not _module_spec(".".join([mod_name, *chain[:depth]])):
        depth -= 1
    check_mod = ".".join([mod_name, *chain[:depth]])
    remainder = chain[depth:]
    if not remainder:
        return check_mod, _has_attr(check_mod, attr)
    try:
        obj = importlib.import_module(check_mod)
        for part in remainder:
            obj = getattr(obj, part)
        return check_mod, hasattr(obj, attr)
    except Exception:
        # Import-time or attribute-lookup failure: the access cannot be
        # confirmed, so it counts as ungrounded.
        return check_mod, False