File size: 3,714 Bytes
aaef24a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python
"""Check that marimo notebooks in the same lesson directory agree on package versions.

It is acceptable for different notebooks in a directory to specify different packages,
but if two or more notebooks specify the same package, their version constraints must
be identical.
"""

import argparse
import re
import sys
from collections import defaultdict
from pathlib import Path


# Regex to extract the inline script metadata block (PEP 723): the opening
# "# /// script" line, the commented body lines (captured as group 1), and the
# closing "# ///" marker.
SCRIPT_BLOCK_RE = re.compile(r"^# /// script\s*\n((?:#[^\n]*\n)*?)# ///", re.MULTILINE)
# Matches one commented, double-quoted dependency line such as '#     "polars==1.0",'.
# parse_script_header does not use it; it is kept so the module-level name stays available.
DEPENDENCY_LINE_RE = re.compile(r'^#\s+"([^"]+)",?\s*$')
# Tokens of interest inside the dependencies array: a double- or single-quoted
# string, the closing bracket of the array, or the start of a TOML comment.
# Quoted strings are consumed whole, so a ']' or '#' inside one is never
# mistaken for the end of the array or a comment.
_DEPENDENCY_TOKEN_RE = re.compile(r'"([^"]*)"|\'([^\']*)\'|(\])|(#)')


def parse_script_header(text: str) -> list[str]:
    """Return the list of dependency strings from a PEP 723 script header, or [].

    Both layouts of the ``dependencies`` array are understood: one entry per
    line, and the whole array on a single line.  Entries are returned without
    their surrounding quotes, in the order they appear.  TOML comments inside
    the array are ignored.

    Args:
        text: Full source text of a notebook/script.

    Returns:
        The dependency strings, e.g. ``["marimo", "polars==1.22.0"]``; an empty
        list when there is no script block or no ``dependencies`` key.
    """
    match = SCRIPT_BLOCK_RE.search(text)
    if not match:
        return []

    deps: list[str] = []
    in_deps = False
    for raw_line in match.group(1).splitlines():
        # Drop only the single '#' that marks the line as part of the block, so
        # that a second '#' (a TOML comment) is still recognisable below.
        line = (raw_line[1:] if raw_line.startswith("#") else raw_line).strip()

        if not in_deps:
            key, equals, value = line.partition("=")
            if not equals or key.strip() != "dependencies":
                continue
            in_deps = True
            # Entries may follow the opening bracket on this very line
            # (single-line array); scan whatever comes after it.
            line = value.partition("[")[2]

        for token in _DEPENDENCY_TOKEN_RE.finditer(line):
            double_quoted, single_quoted, closing, comment = token.groups()
            if closing:
                # End of the array: later keys must not be read as dependencies.
                return deps
            if comment:
                # Rest of the line is a TOML comment.
                break
            dep = (double_quoted if double_quoted is not None else single_quoted).strip()
            if dep:
                deps.append(dep)

    # Array was never closed (malformed header): return what was collected.
    return deps


def package_name(dep: str) -> str:
    """Extract the bare, lower-cased package name from a PEP 508 dependency string.

    The name is everything before the first version operator (including the
    compatible-release operator ``~=``), extras bracket, parenthesis, URL
    marker ``@``, environment-marker ``;`` or whitespace.

    Examples:
        "polars==1.22.0"      -> "polars"
        "pandas>=2.0,<3"      -> "pandas"
        "polars~=1.0"         -> "polars"
        "requests[socks]>=2"  -> "requests"
        "pkg @ https://x/y"   -> "pkg"
        "marimo"              -> "marimo"
    """
    # Strip first so leading whitespace is not taken as the (empty) name.
    return re.split(r"[><=!~;@(\s\[]", dep.strip(), maxsplit=1)[0].lower()


def check_directory(lesson_dir: Path, only: set[str]) -> list[str]:
    """Return a list of error messages for version inconsistencies among *only* in lesson_dir.

    Only ``*.py`` files directly inside ``lesson_dir`` whose file name is in
    ``only`` and whose text contains ``marimo.App`` are inspected.  A file that
    cannot be read or decoded as UTF-8 is skipped with a warning on stderr.

    Args:
        lesson_dir: Directory holding the notebooks.
        only: File names (not paths) of the notebooks to compare.

    Returns:
        Message lines: for each package with more than one distinct dependency
        string, one header line (two-space indent) followed by one line per
        distinct string (four-space indent).  Empty when everything agrees.
    """
    # Map package name -> {version_spec: [notebook_name, ...]}
    seen: dict[str, dict[str, list[str]]] = defaultdict(lambda: defaultdict(list))

    for nb in sorted(lesson_dir.glob("*.py")):
        if nb.name not in only:
            continue
        try:
            text = nb.read_text(encoding="utf-8")
        except (OSError, UnicodeDecodeError) as exc:
            # UnicodeDecodeError is not an OSError; without it one badly
            # encoded file would abort the whole check.
            print(f"warning: skipping unreadable notebook {nb}: {exc}", file=sys.stderr)
            continue
        if "marimo.App" not in text:
            continue
        for dep in parse_script_header(text):
            seen[package_name(dep)][dep].append(nb.name)

    errors: list[str] = []
    for name, specs in sorted(seen.items()):
        if len(specs) <= 1:
            continue
        errors.append(f"  Package '{name}' has conflicting specifications:")
        errors.extend(
            f"    {spec!r} in: {', '.join(files)}" for spec, files in sorted(specs.items())
        )
    return errors


def main() -> None:
    """Command-line entry point.

    Groups the notebook paths given on the command line by parent directory,
    checks each directory with ``check_directory``, prints any conflicts, and
    exits with status 1 if any package is inconsistent, otherwise 0.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("notebooks", nargs="+", metavar="NOTEBOOK",
                        help="notebook files to check (grouped by directory)")
    args = parser.parse_args()

    # Map directory -> names of the notebooks in it that were passed in.
    dir_filter: dict[Path, set[str]] = defaultdict(set)
    for nb_path in map(Path, args.notebooks):
        dir_filter[nb_path.parent].add(nb_path.name)

    conflicting_packages = 0
    for lesson_dir, only in sorted(dir_filter.items()):
        errors = check_directory(lesson_dir, only=only)
        if not errors:
            continue
        print(f"\n{lesson_dir}/")
        for msg in errors:
            print(msg)
        # check_directory emits one two-space-indented header line per
        # conflicting package, followed by four-space-indented detail lines.
        # Count only the headers so the summary reports packages, not lines.
        conflicting_packages += sum(1 for msg in errors if not msg.startswith("    "))

    if conflicting_packages:
        print(f"\nFound package version inconsistencies in {conflicting_packages} package(s).")
        sys.exit(1)

    print("All package version specifications are consistent.")
    sys.exit(0)


if __name__ == "__main__":
    main()