File size: 7,141 Bytes
a9bd396
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
import os

import libcst as cst


# Files from external libraries whose dependencies should not be tracked, grouped per library.
# E.g. for habana, `modeling_all_models.py` is not part of the transformers library, so the
# dependencies it pulls in must not be tracked.
EXCLUDED_EXTERNAL_FILES = {
    "habana": [
        {
            "name": "modeling_all_models",
            "type": "modeling",
        },
    ],
}


def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Convert a relative libcst.ImportFrom node into an absolute one,
    using the file path and package name.

    For example, in `.../src/transformers/models/bert/modeling_bert.py`,
    `from ..auto import AutoConfig` becomes `from transformers.models.auto import AutoConfig`.

    Args:
        import_node: A relative import node (e.g. `from ..utils import helper`). Nodes that
            are already absolute are returned unchanged.
        file_path: Path to the file containing the import (can be absolute or relative;
            relative paths are resolved against the current working directory).
        package_name: The top-level package name (e.g. 'myproject'). When it is not
            "transformers" and its parent directory is not `src`, the parent directory is
            treated as a namespace package and prepended (e.g. `optimum.habana`).

    Returns:
        A new ImportFrom node with the absolute import path

    Raises:
        ValueError: If `package_name` is not a directory of `file_path`, or if the
            relative import climbs above the package root.
    """
    if not import_node.relative:
        return import_node  # Already absolute

    file_path = os.path.abspath(file_path)
    rel_level = len(import_node.relative)

    # Strip file extension and split into parts
    file_parts = file_path.removesuffix(".py").split(os.path.sep)

    # Look for the package among the *directories* only (the last part is the module itself),
    # searching from the end: `abspath` prepends the working directory, which may be a
    # checkout named after the package (e.g. `.../transformers/src/transformers/...`).
    dir_parts = file_parts[:-1]
    if package_name not in dir_parts:
        raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'")
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)

    # Module path relative to the package root, e.g. ['models', 'bert', 'modeling_bert']
    module_parts = file_parts[pkg_index + 1 :]
    if len(module_parts) < rel_level:
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # One dot means "the package containing this module", so drop the module name itself
    # plus one more component for every additional dot.
    base_parts = module_parts[:-rel_level]

    # Flatten the module being imported (if any) into its dotted components
    def flatten_module(module: cst.BaseExpression | None) -> list[str]:
        if module is None:
            return []
        if isinstance(module, cst.Name):
            return [module.value]
        if isinstance(module, cst.Attribute):
            parts = []
            while isinstance(module, cst.Attribute):
                parts.insert(0, module.attr.value)
                module = module.value
            if isinstance(module, cst.Name):
                parts.insert(0, module.value)
            return parts
        return []

    import_parts = flatten_module(import_node.module)

    # Combine to get the full absolute import path
    full_parts = [package_name] + base_parts + import_parts

    # Handle special case where the import comes from a namespace package (e.g. optimum with
    # `optimum.habana`, `optimum.intel` instead of `src.optimum`)
    parent_dir = file_parts[pkg_index - 1] if pkg_index > 0 else ""
    if package_name != "transformers" and parent_dir and parent_dir != "src":
        full_parts = [parent_dir] + full_parts

    # Build the dotted module path
    dotted_module: cst.BaseExpression | None = None
    for part in full_parts:
        name = cst.Name(part)
        dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)

    # Return a new ImportFrom node with absolute import
    return import_node.with_changes(module=dotted_module, relative=[])


def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    Imports of `<package_name>`, `<package_name>.x...` and `optimum.<package_name>.x...` are
    rewritten relative to the importing file, dropping the directories the file and the
    imported module share. For example, in `.../transformers/models/bert/modeling_bert.py`,
    `from transformers.models.bert.configuration_bert import BertConfig` becomes
    `from .configuration_bert import BertConfig`, and `from transformers import X` becomes
    `from ... import X`. Note that `optimum.<package_name>` itself (no submodule) is left
    unchanged.

    Parameters:
    - import_node: The ImportFrom node to possibly transform.
    - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. It is returned unchanged when it is already
      relative, does not import from `package_name`, or when `package_name` is not a
      directory of `file_path`.
    """
    if import_node.relative:
        return import_node  # Already relative import

    # Extract the dotted components of the imported module (empty if there is none)
    def get_module_parts(module: cst.BaseExpression | None) -> list[str]:
        parts = []
        while isinstance(module, cst.Attribute):
            parts.append(module.attr.value)
            module = module.value
        if isinstance(module, cst.Name):
            parts.append(module.value)
        parts.reverse()
        return parts

    module_name = ".".join(get_module_parts(import_node.module))

    # Check if it's from the target package and strip the package prefix
    optimum_prefix = f"optimum.{package_name}."
    if module_name.startswith(optimum_prefix):
        stripped_name = module_name[len(optimum_prefix) :]
    elif module_name == package_name or module_name.startswith(package_name + "."):
        stripped_name = module_name[len(package_name) :].lstrip(".")
    else:
        return import_node  # Not from target package
    # Module path relative to the package root, e.g. ['models', 'bert', 'configuration_bert']
    inner_parts = stripped_name.split(".") if stripped_name else []

    # Locate the package root among the directories of the file path. Search from the end
    # so that a checkout directory named after the package is not taken as the package root.
    dir_parts = os.path.normpath(file_path).split(os.sep)[:-1]
    if package_name not in dir_parts:
        # Package name not found in path — we can't resolve the relative depth
        return import_node
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)
    file_dirs = dir_parts[pkg_index + 1 :]  # directories containing the file, below the package

    # Count the leading directories shared by the importing file and the imported module
    common = 0
    for file_dir, module_part in zip(file_dirs, inner_parts):
        if file_dir != module_part:
            break
        common += 1

    # One dot is the file's own directory; each extra dot climbs one unshared directory
    relative = [cst.Dot() for _ in range(len(file_dirs) - common + 1)]

    # Build the remaining module path, or `None` for `from <dots> import ...`
    new_module: cst.BaseExpression | None = None
    for part in inner_parts[common:]:
        name = cst.Name(part)
        new_module = name if new_module is None else cst.Attribute(value=new_module, attr=name)

    return import_node.with_changes(module=new_module, relative=relative)


class AbsoluteImportTransformer(cst.CSTTransformer):
    """
    libcst transformer that rewrites relative `from ... import` statements as absolute ones.

    Each ImportFrom node is passed to `convert_relative_import_to_absolute`; imports that are
    already absolute come back unchanged. Any ValueError raised there (package not found in
    the path, relative level above the package root) propagates out of the transform.

    Args:
        relative_path: Path of the file being transformed, used to resolve the leading dots.
        source_library: Top-level package the file belongs to (e.g. "transformers").
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Rewrite `updated_node` (not `original_node`) so edits already applied to children are kept
        absolute_node = convert_relative_import_to_absolute(
            updated_node,
            self.relative_path,
            self.source_library,
        )
        return absolute_node


class RelativeImportTransformer(cst.CSTTransformer):
    """
    libcst transformer that rewrites absolute imports from `source_library` as relative ones.

    Each ImportFrom node is passed to `convert_to_relative_import`, which leaves relative
    imports and imports from other packages unchanged.

    Args:
        relative_path: Path of the file being transformed, used to compute the number of dots.
        source_library: Top-level package whose imports should become relative.
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        self.source_library = source_library
        self.relative_path = relative_path

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Rewrite `updated_node` (not `original_node`) so edits already applied to children are kept
        relative_node = convert_to_relative_import(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return relative_node