# File size: 7,141 Bytes | a9bd396 (file-viewer header kept as a comment; line-number gutter removed)
import os
import libcst as cst
# Files coming from external libraries that should not be tracked.
# Maps a library name to a list of `{"name": ..., "type": ...}` entries.
# E.g. for habana, the dependencies from `modeling_all_models.py` are not tracked
# because that file is not part of the transformers library.
EXCLUDED_EXTERNAL_FILES = {
    "habana": [
        {"name": "modeling_all_models", "type": "modeling"},
    ],
}
def convert_relative_import_to_absolute(
    import_node: cst.ImportFrom,
    file_path: str,
    package_name: str | None = "transformers",
) -> cst.ImportFrom:
    """
    Convert a relative `libcst.ImportFrom` node into an absolute one,
    using the file path and package name.

    Args:
        import_node: An import node, e.g. `from ..utils import helper`. Nodes that are
            already absolute are returned unchanged.
        file_path: Path to the file containing the import (can be absolute or relative;
            it is made absolute with `os.path.abspath` before being analysed).
        package_name: The top-level package name (e.g. 'myproject'). It must be the name
            of one of the directories in `file_path`.

    Returns:
        A new ImportFrom node with the absolute import path.

    Raises:
        ValueError: If `package_name` is not a directory component of `file_path`, or if
            the relative import has more leading dots than there are path components
            below the package directory.
    """
    rel_level = len(import_node.relative) if import_node.relative else 0
    if rel_level == 0:
        return import_node  # Already absolute

    file_path = os.path.abspath(file_path)

    # Strip file extension and split into parts
    file_parts = file_path.removesuffix(".py").split(os.path.sep)

    # Look for the package among the *directory* components only, and take the innermost
    # match: the path is absolute, so an outer directory (e.g. a checkout folder named
    # like the package, as in `~/transformers/src/transformers/...`) must not be mistaken
    # for the package root.
    dir_parts = file_parts[:-1]
    if package_name not in dir_parts:
        raise ValueError(f"Package name '{package_name}' not found in file path '{file_path}'")
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)

    # Everything below the package directory, e.g. ['module', 'submodule', 'foo']
    module_parts = file_parts[pkg_index + 1 :]
    if len(module_parts) < rel_level:
        raise ValueError(f"Relative import level ({rel_level}) goes beyond package root.")

    # One dot refers to the package containing the file, so dropping `rel_level`
    # trailing components (the file itself counts as one) gives the base package.
    base_parts = module_parts[:-rel_level]

    # Flatten the module being imported (if any): walk the Attribute chain from the
    # outermost attribute down to the base Name, then restore left-to-right order.
    import_parts: list[str] = []
    module = import_node.module
    while isinstance(module, cst.Attribute):
        import_parts.append(module.attr.value)
        module = module.value
    if isinstance(module, cst.Name):
        import_parts.append(module.value)
    import_parts.reverse()

    # Combine to get the full absolute import path
    full_parts = [package_name] + base_parts + import_parts

    # Handle special case where the import comes from a namespace package (e.g. optimum with
    # `optimum.habana`, `optimum.intel` instead of `src.optimum`). The parent directory is only
    # used when it can actually be part of a dotted module path: it is an empty string when the
    # package sits at the filesystem root, which is not a valid name.
    parent_dir = file_parts[pkg_index - 1] if pkg_index > 0 else ""
    if package_name != "transformers" and parent_dir != "src" and parent_dir.isidentifier():
        full_parts = [parent_dir] + full_parts

    # Build the dotted module path as nested Attribute nodes: ((a.b).c).d
    dotted_module: cst.BaseExpression | None = None
    for part in full_parts:
        name = cst.Name(part)
        dotted_module = name if dotted_module is None else cst.Attribute(value=dotted_module, attr=name)

    # Return a new ImportFrom node with absolute import
    return import_node.with_changes(module=dotted_module, relative=[])
def convert_to_relative_import(import_node: cst.ImportFrom, file_path: str, package_name: str) -> cst.ImportFrom:
    """
    Convert an absolute import to a relative one if it belongs to `package_name`.

    An import is considered part of the package when its module is `package_name` itself,
    starts with `package_name.`, or starts with `optimum.package_name.`.

    Parameters:
    - import_node: The ImportFrom node to possibly transform.
    - file_path: Path to the file containing the import (e.g., '/path/to/mypackage/foo/bar.py').
    - package_name: The top-level package name (e.g., 'mypackage').

    Returns:
    - A possibly modified ImportFrom node. The node is returned unchanged when it is already
      relative, when it does not belong to the package, or when `package_name` is not one of
      the directories of `file_path`.
    """
    if import_node.relative:
        return import_node  # Already relative import

    # Flatten the dotted module expression into its components, e.g. ['optimum', 'habana', 'utils']
    module_parts: list[str] = []
    module = import_node.module
    while isinstance(module, cst.Attribute):
        module_parts.append(module.attr.value)
        module = module.value
    if isinstance(module, cst.Name):
        module_parts.append(module.value)
    module_parts.reverse()

    # Check if it's from the target package, and how many leading components form the prefix.
    # The number of prefix components differs between `optimum.<pkg>.x` (2) and `<pkg>.x` (1),
    # so it must not be hard-coded when comparing against the file's directories below.
    if len(module_parts) > 2 and module_parts[:2] == ["optimum", package_name]:
        prefix_len = 2
    elif module_parts[:1] == [package_name]:
        prefix_len = 1
    else:
        return import_node  # Not from target package

    # Module path relative to the package root (may be empty for `from <pkg> import x`)
    target_parts = module_parts[prefix_len:]

    # Locate the package root inside the file path. Only directory components are considered
    # and the innermost match wins, so that an outer folder named like the package is ignored.
    path_parts = os.path.normpath(file_path).split(os.sep)
    dir_parts = path_parts[:-1]
    if package_name not in dir_parts:
        # Package name not found in path — assume we can't resolve relative depth
        return import_node
    pkg_index = len(dir_parts) - 1 - dir_parts[::-1].index(package_name)

    # Directories between the package root and the current file
    sub_dirs = dir_parts[pkg_index + 1 :]

    # Count how many leading directories the file and the imported module share: each shared
    # directory saves one level of "going up". `zip` stops at the shorter sequence, so a module
    # path shorter than the directory path cannot cause an out-of-range access.
    shared = 0
    for directory, component in zip(sub_dirs, target_parts):
        if directory != component:
            break
        shared += 1

    # One dot for the package containing the file, plus one per non-shared directory.
    depth = len(sub_dirs) + 1 - shared

    # Create the correct number of dots (a distinct node per dot)
    relative = [cst.Dot() for _ in range(depth)]

    # Build new module node from what remains after the shared directories
    remaining_parts = target_parts[shared:]
    new_module: cst.BaseExpression | None = None
    for part in remaining_parts:
        name = cst.Name(part)
        new_module = name if new_module is None else cst.Attribute(value=new_module, attr=name)

    return import_node.with_changes(module=new_module, relative=relative)
class AbsoluteImportTransformer(cst.CSTTransformer):
    """
    CST transformer that rewrites every `from ... import ...` statement it visits
    through `convert_relative_import_to_absolute`.

    Attributes:
        relative_path: Path of the file being transformed; forwarded as `file_path`.
        source_library: Name of the package; forwarded as `package_name`.
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        self.relative_path, self.source_library = relative_path, source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Work on `updated_node` so that changes made to child nodes are kept.
        absolute_node = convert_relative_import_to_absolute(
            import_node=updated_node,
            file_path=self.relative_path,
            package_name=self.source_library,
        )
        return absolute_node
class RelativeImportTransformer(cst.CSTTransformer):
    """
    CST transformer that rewrites every `from ... import ...` statement it visits
    through `convert_to_relative_import`.

    Attributes:
        relative_path: Path of the file being transformed; forwarded as `file_path`.
        source_library: Name of the package; forwarded as `package_name`.
    """

    def __init__(self, relative_path: str, source_library: str):
        super().__init__()
        self.relative_path, self.source_library = relative_path, source_library

    def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom) -> cst.ImportFrom:
        # Work on `updated_node` so that changes made to child nodes are kept.
        relative_node = convert_to_relative_import(updated_node, self.relative_path, self.source_library)
        return relative_node
|