File size: 5,487 Bytes
a80f6e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import ast
import inspect
import os
import pathlib
from pathlib import Path
from typing import Any, List, Optional, Tuple, Type

HERE = Path(__file__).parent
# Climb five directory levels above this module's folder. The original note
# says this should bring us to [root]/src; the constants below treat it as the
# directory that directly holds the individual package folders.
PKGS_ROOT = HERE.parent.parent.parent.parent.parent

# Well-known package directories located directly under PKGS_ROOT.
LANGCHAIN_PKG, COMMUNITY_PKG, PARTNER_PKGS = (
    PKGS_ROOT / "langchain",
    PKGS_ROOT / "community",
    PKGS_ROOT / "partners",
)


class ImportExtractor(ast.NodeVisitor):
    """AST visitor recording ``(module, name)`` for each ``from ... import``.

    Only ``ImportFrom`` statements are collected: plain ``import x`` statements
    are not visited, and ``from . import x`` (which has no module name) is
    skipped.
    """

    def __init__(self, *, from_package: Optional[str] = None) -> None:
        """Prepare an empty result list and remember the optional filter.

        Args:
            from_package: when given, only imports whose module name starts
                with this string (a plain string-prefix test) are recorded.
        """
        self.imports: List[Tuple[str, str]] = []
        self.package = from_package

    def visit_ImportFrom(self, node: ast.ImportFrom) -> None:
        """Record every alias of ``node`` whose module passes the filter."""
        source_module = node.module
        wanted = bool(source_module) and (
            self.package is None or str(source_module).startswith(self.package)
        )
        if wanted:
            self.imports.extend(
                (source_module, imported.name) for imported in node.names
            )
        self.generic_visit(node)


def _get_class_names(code: str) -> List[str]:
    """Extract class names from a code string."""
    # Parse the content of the file into an AST
    tree = ast.parse(code)

    # Initialize a list to hold all class names
    class_names = []

    # Define a node visitor class to collect class names
    class ClassVisitor(ast.NodeVisitor):
        def visit_ClassDef(self, node):
            class_names.append(node.name)
            self.generic_visit(node)

    # Create an instance of the visitor and visit the AST
    visitor = ClassVisitor()
    visitor.visit(tree)
    return class_names


def is_subclass(class_obj: Any, classes_: List[Type]) -> bool:
    """Tell whether ``class_obj`` is a class derived from any entry of ``classes_``.

    Non-class values are tolerated on both sides: a non-class ``class_obj``
    yields ``False``, and non-class entries in ``classes_`` are skipped.
    """
    candidate_is_class = inspect.isclass(class_obj)
    for base in classes_:
        if candidate_is_class and inspect.isclass(base) and issubclass(class_obj, base):
            return True
    return False


def find_subclasses_in_module(module, classes_: List[Type]) -> List[str]:
    """Return ``__name__`` of each class attribute of ``module`` deriving from ``classes_``.

    Attributes are examined in ``inspect.getmembers`` order. The class's own
    ``__name__`` is reported, not the attribute name it is bound to.
    """
    class_members = inspect.getmembers(module, inspect.isclass)
    return [
        member.__name__
        for _attr_name, member in class_members
        if is_subclass(member, classes_)
    ]


def _get_all_classnames_from_file(file: Path, pkg: str) -> List[Tuple[str, str]]:
    """Pair every class defined in ``file`` with the dotted module it lives in.

    Args:
        file: Python source file, read as UTF-8.
        pkg: root directory the module name is computed relative to.

    Returns:
        ``(module_name, class_name)`` tuples, one per class in the file.
    """
    with open(file, encoding="utf-8") as handle:
        source = handle.read()
    module = _get_current_module(file, pkg)
    return [(module, name) for name in _get_class_names(source)]


def identify_all_imports_in_file(
    file: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
    """Return the ``(module, name)`` import pairs found in the source ``file``.

    Args:
        file: path of the Python source file to read (decoded as UTF-8).
        from_package: optional module-name prefix, forwarded unchanged to
            ``find_imports_from_package`` to filter the results.
    """
    with open(file, encoding="utf-8") as source_file:
        source_text = source_file.read()
    return find_imports_from_package(source_text, from_package=from_package)


def identify_pkg_source(pkg_root: str) -> pathlib.Path:
    """Identify the source of the package.

    The source directory is the single child directory of ``pkg_root`` whose
    name starts with ``langchain_`` and is a valid Python identifier. The
    identifier check means metadata folders such as ``langchain_foo.egg-info``
    are not mistaken for the package.

    Args:
        pkg_root: the root of the package. This contains source + tests, and other
            things like pyproject.toml, lock files etc

    Returns:
        Returns the path to the source code for the package.

    Raises:
        ValueError: if there is not exactly one matching directory.
    """
    matching_dirs = sorted(
        d
        for d in Path(pkg_root).iterdir()
        if d.is_dir() and d.name.startswith("langchain_") and d.name.isidentifier()
    )
    # Use an explicit check rather than ``assert`` so the validation still
    # runs under ``python -O``. Otherwise zero matches would surface as an
    # IndexError, and several matches would silently pick an arbitrary one.
    if len(matching_dirs) != 1:
        found = ", ".join(d.name for d in matching_dirs) or "none"
        raise ValueError(
            "There should be only one langchain package. "
            f"Found {len(matching_dirs)} in {pkg_root}: {found}"
        )
    return matching_dirs[0]


def list_classes_by_package(pkg_root: str) -> List[Tuple[str, str]]:
    """Collect ``(module_name, class_name)`` for every class in the package source.

    Args:
        pkg_root: package root directory; its ``langchain_*`` source folder is
            scanned recursively for ``*.py`` files.
    """
    source_dir = identify_pkg_source(pkg_root)
    found: List[Tuple[str, str]] = []
    for py_file in source_dir.rglob("*.py"):
        # NOTE(review): py_file always lies under source_dir (a ``langchain_*``
        # child of pkg_root), so this relative path starts with that folder's
        # name and the "tests" guard appears never to fire. Confirm the intent.
        if os.path.relpath(py_file, pkg_root).startswith("tests"):
            continue
        found += _get_all_classnames_from_file(py_file, pkg_root)
    return found


def list_init_imports_by_package(pkg_root: str) -> List[Tuple[str, str]]:
    """Map each ``__init__.py`` in the package source to the names it imports.

    Returns:
        ``(module_name, imported_name)`` pairs, where ``module_name`` is the
        dotted name of the package whose ``__init__.py`` contains the import.
    """
    source_dir = Path(identify_pkg_source(pkg_root))
    init_files = (
        candidate
        for candidate in source_dir.rglob("*.py")
        if candidate.name == "__init__.py"
    )
    collected: List[Tuple[str, str]] = []
    for init_file in init_files:
        found_imports = identify_all_imports_in_file(str(init_file))
        package_module = _get_current_module(init_file, pkg_root)
        for _source_module, imported_name in found_imports:
            collected.append((package_module, imported_name))
    return collected


def find_imports_from_package(
    code: str, *, from_package: Optional[str] = None
) -> List[Tuple[str, str]]:
    """Return ``(module, name)`` pairs for each ``from ... import`` in ``code``.

    Args:
        code: Python source text to parse.
        from_package: if set, keep only imports whose module name starts with
            this string (the filter applied by ``ImportExtractor``).

    Returns:
        The collected pairs, in AST traversal order.
    """
    syntax_tree = ast.parse(code)
    visitor = ImportExtractor(from_package=from_package)
    visitor.visit(syntax_tree)
    return visitor.imports


def _get_current_module(path: Path, pkg_root: str) -> str:
    """Convert a path to a module name."""
    path_as_pathlib = pathlib.Path(os.path.abspath(path))
    relative_path = path_as_pathlib.relative_to(pkg_root).with_suffix("")
    posix_path = relative_path.as_posix()
    norm_path = os.path.normpath(str(posix_path))
    fully_qualified_module = norm_path.replace("/", ".")
    # Strip __init__ if present
    if fully_qualified_module.endswith(".__init__"):
        return fully_qualified_module[:-9]
    return fully_qualified_module