File size: 4,677 Bytes
a80f6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
"""This script checks documentation for broken import statements."""

import importlib
import json
import logging
import os
import re
import warnings
from pathlib import Path
from typing import List, Tuple

# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Directory scanned when the script is run directly: the "docs" folder located
# in the parent of the directory that contains this file (``parents[1]``).
DOCS_DIR = Path(os.path.abspath(__file__)).parents[1] / "docs"
import_pattern = re.compile(
    r"import\s+(\w+)|from\s+([\w\.]+)\s+import\s+((?:\w+(?:,\s*)?)+|\(.*?\))", re.DOTALL
)


def _get_imports_from_code_cell(code_lines: str) -> List[Tuple[str, str]]:
    """Get (module, import) statements from a single code cell."""
    import_statements = []
    for line in code_lines:
        line = line.strip()
        if line.startswith("#") or not line:
            continue
        # Join lines that end with a backslash
        if line.endswith("\\"):
            line = line[:-1].rstrip() + " "
            continue
        matches = import_pattern.findall(line)
        for match in matches:
            if match[0]:  # simple import statement
                import_statements.append((match[0], ""))
            else:  # from ___ import statement
                module, items = match[1], match[2]
                items_list = items.replace(" ", "").split(",")
                for item in items_list:
                    import_statements.append((module, item))
    return import_statements


def _extract_import_statements(notebook_path: str) -> List[Tuple[str, str]]:
    """Get (module, import) statements from a Jupyter notebook."""
    with open(notebook_path, "r", encoding="utf-8") as file:
        notebook = json.load(file)
    code_cells = [cell for cell in notebook["cells"] if cell["cell_type"] == "code"]
    import_statements = []
    for cell in code_cells:
        code_lines = cell["source"]
        import_statements.extend(_get_imports_from_code_cell(code_lines))
    return import_statements


def _get_bad_imports(import_statements: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
    """Collect offending import statements."""
    offending_imports = []
    for module, item in import_statements:
        try:
            if item:
                try:
                    # submodule
                    full_module_name = f"{module}.{item}"
                    importlib.import_module(full_module_name)
                except ModuleNotFoundError:
                    # attribute
                    try:
                        imported_module = importlib.import_module(module)
                        getattr(imported_module, item)
                    except AttributeError:
                        offending_imports.append((module, item))
                except Exception:
                    offending_imports.append((module, item))
            else:
                importlib.import_module(module)
        except Exception:
            offending_imports.append((module, item))

    return offending_imports


def _is_relevant_import(module: str) -> bool:
    """Check if module is recognized."""
    # Ignore things like langchain_{bla}, where bla is unrecognized.
    recognized_packages = [
        "langchain",
        "langchain_core",
        "langchain_community",
        "langchain_experimental",
        "langchain_text_splitters",
    ]
    return module.split(".")[0] in recognized_packages


def _serialize_bad_imports(bad_files: list) -> str:
    """Serialize bad imports to a string."""
    bad_imports_str = ""
    for file, bad_imports in bad_files:
        bad_imports_str += f"File: {file}\n"
        for module, item in bad_imports:
            bad_imports_str += f"    {module}.{item}\n"
    return bad_imports_str


def check_notebooks(directory: str) -> list:
    """Walk a directory tree and report notebooks with broken imports.

    Every ``.ipynb`` file below ``directory`` is inspected, except files
    ending in ``-checkpoint.ipynb``. Returns a list of
    ``(notebook_path, bad_imports)`` tuples, one per notebook that has at
    least one offending import.
    """
    bad_files = []
    for root, _, files in os.walk(directory):
        for file in files:
            if not file.endswith(".ipynb") or file.endswith("-checkpoint.ipynb"):
                continue
            notebook_path = os.path.join(root, file)
            relevant_imports = [
                (mod, name)
                for mod, name in _extract_import_statements(notebook_path)
                if _is_relevant_import(mod)
            ]
            bad_imports = _get_bad_imports(relevant_imports)
            if bad_imports:
                bad_files.append((notebook_path, bad_imports))
    return bad_files


if __name__ == "__main__":
    # Scan the docs directory and fail loudly if any notebook has bad imports.
    bad_files = check_notebooks(DOCS_DIR)
    if bad_files:
        report = _serialize_bad_imports(bad_files)
        raise ImportError("Found bad imports:\n" + report)