"""Functions related to Black's formatting by line ranges feature."""
import difflib
from collections.abc import Collection, Iterator, Sequence
from dataclasses import dataclass
from black.nodes import (
LN,
STANDALONE_COMMENT,
Leaf,
Node,
Visitor,
first_leaf,
furthest_ancestor_with_last_leaf,
last_leaf,
syms,
)
from blib2to3.pgen2.token import ASYNC, NEWLINE
def parse_line_ranges(line_ranges: Sequence[str]) -> list[tuple[int, int]]:
    """Parse ``START-END`` strings into ``(start, end)`` integer tuples.

    Args:
        line_ranges: the raw range strings, each expected to look like
            ``"START-END"``.

    Returns:
        One ``(start, end)`` tuple per input string, in input order. No check is
        made here that ``start <= end``.

    Raises:
        ValueError: if a string does not split into exactly two parts on ``"-"``,
            or if either part is not accepted by ``int()``.
    """
    parsed: list[tuple[int, int]] = []
    for spec in line_ranges:
        pieces = spec.split("-")
        if len(pieces) != 2:
            raise ValueError(
                "Incorrect --line-ranges format, expect 'START-END', found"
                f" {spec!r}"
            )
        raw_start, raw_end = pieces
        try:
            bounds = (int(raw_start), int(raw_end))
        except ValueError:
            # `from None` hides the low-level int() error; the message below
            # already carries the offending input.
            raise ValueError(
                "Incorrect --line-ranges value, expect integer ranges, found"
                f" {spec!r}"
            ) from None
        parsed.append(bounds)
    return parsed
def is_valid_line_range(lines: tuple[int, int]) -> bool:
    """Tell whether a ``(start, end)`` pair describes a usable range.

    An empty value counts as valid; otherwise the start must not come after the
    end.
    """
    if not lines:
        return True
    first, last = lines[0], lines[1]
    return first <= last
def sanitized_lines(
    lines: Collection[tuple[int, int]], src_contents: str
) -> Collection[tuple[int, int]]:
    """Clamp the given line ranges to the lines that exist in the source.

    Ranges that start after the last source line, or that are empty once the
    start has been raised to 1, are dropped. Every remaining range has its end
    capped at the (1-based) number of the last source line.

    Args:
        lines: the ``(start, end)`` ranges to sanitize.
        src_contents: the source text the ranges refer to.

    Returns:
        The surviving ranges, in input order. Empty when `src_contents` is empty.
    """
    if not src_contents:
        return []

    # A final line without a trailing "\n" still counts as a line.
    unterminated_tail = 0 if src_contents.endswith("\n") else 1
    total_lines = src_contents.count("\n") + unterminated_tail

    kept = []
    for start, end in lines:
        if start > total_lines:
            continue
        # line-ranges are 1-based
        clamped_start = max(start, 1)
        if end < clamped_start:
            continue
        kept.append((clamped_start, min(end, total_lines)))
    return kept
def adjusted_lines(
    lines: Collection[tuple[int, int]],
    original_source: str,
    modified_source: str,
) -> list[tuple[int, int]]:
    """Returns the adjusted line ranges based on edits from the original code.

    This computes the new line ranges by diffing original_source and
    modified_source, and adjusts each range based on how the range overlaps with
    the diffs.

    Note the diff can contain lines outside of the original line ranges. This can
    happen when the formatting has to be done on adjacent lines to maintain
    consistent local results. For example:

    1. def my_func(arg1, arg2,
    2.             arg3,):
    3.   pass

    If it restricts to line 2-2, it can't simply reformat line 2, it also has
    to reformat line 1:

    1. def my_func(
    2.     arg1,
    3.     arg2,
    4.     arg3,
    5. ):
    6.   pass

    In this case, we will expand the line ranges to also include the whole diff
    block.

    Args:
        lines: a collection of line ranges.
        original_source: the original source.
        modified_source: the modified source.

    Returns:
        The adjusted ranges, ordered by the sorted order of `lines`. A range is
        omitted when its start or end line cannot be located in the diff
        mappings, or when the adjusted range ends before it starts.
    """
    lines_mappings = _calculate_lines_mappings(original_source, modified_source)
    mappings_count = len(lines_mappings)
    new_lines = []
    # Keep an index of the current search. Since the lines and lines_mappings are
    # sorted, valid inputs never need to look behind this index.
    current_mapping_index = 0
    for start, end in sorted(lines):
        start_mapping_index = _find_lines_mapping_index(
            start,
            lines_mappings,
            current_mapping_index,
        )
        if start_mapping_index >= mappings_count:
            # Protect against invalid inputs. The search index is deliberately
            # left untouched: moving it past the end here would make every later
            # (possibly valid) range unfindable as well.
            continue
        current_mapping_index = start_mapping_index
        end_mapping_index = _find_lines_mapping_index(
            end,
            lines_mappings,
            start_mapping_index,
        )
        if end_mapping_index >= mappings_count:
            # Protect against invalid inputs.
            continue
        start_mapping = lines_mappings[start_mapping_index]
        end_mapping = lines_mappings[end_mapping_index]
        if start_mapping.is_changed_block:
            # When the line falls into a changed block, expands to the whole block.
            new_start = start_mapping.modified_start
        else:
            # Unchanged block: keep the same offset from the block's first line.
            new_start = (
                start - start_mapping.original_start + start_mapping.modified_start
            )
        if end_mapping.is_changed_block:
            # When the line falls into a changed block, expands to the whole block.
            new_end = end_mapping.modified_end
        else:
            new_end = end - end_mapping.original_start + end_mapping.modified_start
        new_range = (new_start, new_end)
        if is_valid_line_range(new_range):
            new_lines.append(new_range)
    return new_lines
def convert_unchanged_lines(src_node: Node, lines: Collection[tuple[int, int]]) -> None:
    r"""Replace code outside the requested lines with STANDALONE_COMMENT leaves.

    This mirrors the approach used for `# fmt: on/off`: the nodes that must stay
    untouched are collapsed into a single `STANDALONE_COMMENT` leaf whose value
    is the unformatted code. `STANDALONE_COMMENT` is a "fake" token that is
    emitted as-is, with only its prefix normalized.

    The tree is modified in place, in two passes:

    1. Top-level statements that don't overlap the requested lines are each
       collapsed into one `STANDALONE_COMMENT`. This speeds up formatting when
       some of the top-level statements aren't changed.
    2. The remaining "unwrapped lines" are handled one at a time. An "unwrapped
       line" is whatever ends with a `NEWLINE` token, e.g. a multi-line
       statement is *one* "unwrapped line": it may span several physical lines,
       but the tokenizer only sees its last '\n' as the `NEWLINE` token.

    NOTE: During pass (2), comment prefixes and indentations are ALWAYS
    normalized even when the lines aren't changed. This is fixable by moving
    more formatting to pass (1). However, it's hard to get it correct when
    incorrect indentations are used. So we defer this to future optimizations.

    Args:
        src_node: the root of the tree to modify.
        lines: inclusive ``(start, end)`` ranges of the lines to keep formattable.
    """
    # Flatten the inclusive ranges into the set of individual line numbers.
    lines_set: set[int] = {
        lineno for start, end in lines for lineno in range(start, end + 1)
    }
    # Pass 1. The visitor yields nothing useful; it is iterated to completion
    # only so that its tree edits are carried out.
    for _ in _TopLevelStatementsVisitor(lines_set).visit(src_node):
        pass
    # Pass 2.
    _convert_unchanged_line_by_line(src_node, lines_set)
def _contains_standalone_comment(node: LN) -> bool:
    """Return whether `node` is, or has anywhere below it, a STANDALONE_COMMENT."""
    if isinstance(node, Leaf):
        return node.type == STANDALONE_COMMENT
    # Depth-first search through the children; stops at the first hit.
    return any(_contains_standalone_comment(child) for child in node.children)
class _TopLevelStatementsVisitor(Visitor[None]):
    """
    Node visitor that collapses untouched top-level statements into
    STANDALONE_COMMENT leaves.

    It complements _convert_unchanged_line_by_line: handling whole unchanged
    top-level classes/functions/statements up front speeds up formatting.
    """

    def __init__(self, lines_set: set[int]):
        # The line numbers that were requested for formatting.
        self._lines_set = lines_set

    def visit_simple_stmt(self, node: Node) -> Iterator[None]:
        """Collapse a simple statement that overlaps none of the requested lines."""
        # This is only called for top-level statements, since `visit_suite`
        # won't visit its children nodes.
        yield from []  # Nothing to yield; this just makes the method a generator.

        trailing_newline = last_leaf(node)
        if not trailing_newline:
            return
        assert (
            trailing_newline.type == NEWLINE
        ), f"Unexpectedly found leaf.type={trailing_newline.type}"

        # A `suite` can simply be a `simple_stmt` when the body sits on the same
        # line as its header (example: `if cond: pass`), so climb to the furthest
        # ancestor that still ends with this NEWLINE.
        statement = furthest_ancestor_with_last_leaf(trailing_newline)
        if _get_line_range(statement).intersection(self._lines_set):
            return
        _convert_node_to_standalone_comment(statement)

    def visit_suite(self, node: Node) -> Iterator[None]:
        """Collapse the statement owning this suite when it is entirely untouched."""
        yield from []  # Nothing to yield; this just makes the method a generator.

        # If there is a STANDALONE_COMMENT node, it means parts of the node tree
        # have fmt on/off/skip markers. Those STANDALONE_COMMENT nodes can't
        # be simply converted by calling str(node). So we just don't convert
        # here.
        if _contains_standalone_comment(node):
            return

        # Find the semantic parent of this suite. For `async_stmt` and
        # `async_funcdef`, the ASYNC token is defined on a separate level by the
        # grammar, so step one level further up when the parent follows an ASYNC.
        owner = node.parent
        if owner is None:
            return
        preceding = owner.prev_sibling
        if preceding is not None and preceding.type == ASYNC:
            owner = owner.parent
            if owner is None:
                return

        if _get_line_range(owner).intersection(self._lines_set):
            return
        _convert_node_to_standalone_comment(owner)
def _convert_unchanged_line_by_line(node: Node, lines_set: set[int]) -> None:
    """Collapse each untouched "unwrapped line" into a STANDALONE_COMMENT.

    Args:
        node: the root of the tree; it is modified in place.
        lines_set: the line numbers that were requested for formatting.
    """
    for newline in node.leaves():
        # We only consider "unwrapped lines", which are divided by the NEWLINE
        # token.
        if newline.type != NEWLINE:
            continue

        if newline.parent and newline.parent.type == syms.match_stmt:
            # The `suite` node is defined as:
            #   match_stmt: "match" subject_expr ':' NEWLINE INDENT case_block+ DEDENT
            # Here we need to check `subject_expr`. The `case_block+` will be
            # checked by their own NEWLINEs.
            match_header: list[LN] = []
            sibling = newline.prev_sibling
            while sibling:
                match_header.append(sibling)
                sibling = sibling.prev_sibling
            # Collected walking backwards; restore source order.
            match_header.reverse()
            if not _get_line_range(match_header).intersection(lines_set):
                _convert_nodes_to_standalone_comment(match_header, newline=newline)
        elif newline.parent and newline.parent.type == syms.suite:
            # The `suite` node is defined as:
            #   suite: simple_stmt | NEWLINE INDENT stmt+ DEDENT
            # We will check `simple_stmt` and `stmt+` separately against the lines set
            suite = newline.parent
            block_header: list[LN] = []
            sibling = suite.prev_sibling
            # NOTE: Multiple suite nodes can exist as siblings in e.g. `if_stmt`,
            # so stop at the previous suite.
            while sibling and sibling.type != syms.suite:
                block_header.append(sibling)
                sibling = sibling.prev_sibling
            # Special case for `async_stmt` and `async_funcdef` where the ASYNC
            # token is on the grandparent node.
            grandparent = suite.parent
            if grandparent is not None:
                async_candidate = grandparent.prev_sibling
                if async_candidate is not None and async_candidate.type == ASYNC:
                    block_header.append(async_candidate)
            # Collected walking backwards; restore source order.
            block_header.reverse()
            if not _get_line_range(block_header).intersection(lines_set):
                _convert_nodes_to_standalone_comment(block_header, newline=newline)
        else:
            statement = furthest_ancestor_with_last_leaf(newline)
            # Consider multiple decorators as a whole block, as their
            # newlines have different behaviors than the rest of the grammar.
            if (
                statement.type == syms.decorator
                and statement.parent
                and statement.parent.type == syms.decorators
            ):
                statement = statement.parent
            if not _get_line_range(statement).intersection(lines_set):
                _convert_node_to_standalone_comment(statement)
def _convert_node_to_standalone_comment(node: LN) -> None:
    """Replace `node` in its parent with one STANDALONE_COMMENT leaf, in place.

    Nothing happens when `node` has no parent, has no leaves, or consists of a
    single leaf.
    """
    parent = node.parent
    if not parent:
        return
    first = first_leaf(node)
    last = last_leaf(node)
    if not (first and last):
        return
    if first is last:
        # This can happen on the following edge cases:
        # 1. A block of `# fmt: off/on` code except the `# fmt: on` is placed
        #    on the end of the last line instead of on a new line.
        # 2. A single backslash on its own line followed by a comment line.
        # Ideally we don't want to format them when not requested, but fixing
        # isn't easy. These cases are also badly formatted code, so it isn't
        # too bad we reformat them.
        return

    # The prefix contains comments and indentation whitespaces. They are
    # reformatted accordingly to the correct indentation level.
    # This also means the indentation will be changed on the unchanged lines, and
    # this is actually required to not break incremental reformatting.
    # It is moved from the first leaf onto the replacement leaf below.
    saved_prefix = first.prefix
    first.prefix = ""

    position = node.remove()
    if position is None:
        return

    # Because of the special handling of multiple decorators, if the decorated
    # item is a single line then there will be a missing newline between the
    # decorator and item, so add it back. This doesn't affect any other case
    # since a decorated item with a newline would hit the earlier suite case
    # in _convert_unchanged_line_by_line that correctly handles the newlines.
    if node.type == syms.decorated:
        # A leaf of type decorated wouldn't make sense, since it should always
        # have at least the decorator + the decorated item, so if this assert
        # hits that means there's a problem in the parser.
        assert isinstance(node, Node)
        # 1 will always be the correct index since before this function is
        # called all the decorators are collapsed into a single leaf
        node.insert_child(1, Leaf(NEWLINE, "\n"))

    # Remove the '\n', as STANDALONE_COMMENT will have '\n' appended when
    # generating the formatted code.
    unformatted_code = str(node)[:-1]
    replacement = Leaf(
        STANDALONE_COMMENT,
        unformatted_code,
        prefix=saved_prefix,
        fmt_pass_converted_first_leaf=first,
    )
    parent.insert_child(position, replacement)
def _convert_nodes_to_standalone_comment(nodes: Sequence[LN], *, newline: Leaf) -> None:
    """Replace a run of nodes with one STANDALONE_COMMENT leaf, in place.

    Args:
        nodes: the nodes to collapse; the replacement leaf is inserted into the
            parent of the first one, at that node's former position.
        newline: the NEWLINE leaf that ends the statement. A non-empty prefix on
            it is moved into the replacement's value.
    """
    if not nodes:
        return
    head = nodes[0]
    parent = head.parent
    first = first_leaf(head)
    if not parent or not first:
        return

    # Move the leading prefix off the first leaf so that it ends up on the
    # replacement leaf instead of inside its value.
    saved_prefix = first.prefix
    first.prefix = ""

    pieces = [str(item) for item in nodes]
    # The prefix comment on the NEWLINE leaf is the trailing comment of the statement.
    if newline.prefix:
        pieces.append(newline.prefix)
        newline.prefix = ""
    unformatted_code = "".join(pieces)

    position = head.remove()
    for leftover in nodes[1:]:
        leftover.remove()
    if position is None:
        return

    replacement = Leaf(
        STANDALONE_COMMENT,
        unformatted_code,
        prefix=saved_prefix,
        fmt_pass_converted_first_leaf=first,
    )
    parent.insert_child(position, replacement)
def _leaf_line_end(leaf: Leaf) -> int:
    """Return the number of the last line that `leaf` occupies."""
    if leaf.type == NEWLINE:
        # A NEWLINE leaf is reported on its own line; the line breaks in its
        # text are not counted.
        return leaf.lineno
    # Leaf nodes like multiline strings can occupy multiple lines.
    extra_lines = str(leaf).count("\n")
    return leaf.lineno + extra_lines
def _get_line_range(node_or_nodes: LN | list[LN]) -> set[int]:
    """Return the set of line numbers spanned by a node or a list of nodes.

    For a list, the span runs from the first leaf of the first node to the last
    leaf of the last node. The result is empty for an empty list and whenever
    either boundary leaf cannot be found.
    """
    if isinstance(node_or_nodes, list):
        if not node_or_nodes:
            return set()
        first = first_leaf(node_or_nodes[0])
        last = last_leaf(node_or_nodes[-1])
    elif isinstance(node_or_nodes, Leaf):
        # A lone leaf is its own first and last leaf.
        return set(range(node_or_nodes.lineno, _leaf_line_end(node_or_nodes) + 1))
    else:
        first = first_leaf(node_or_nodes)
        last = last_leaf(node_or_nodes)

    if not (first and last):
        return set()
    # `+ 1` because the end line is inclusive.
    return set(range(first.lineno, _leaf_line_end(last) + 1))
@dataclass
class _LinesMapping:
"""1-based lines mapping from original source to modified source.
Lines [original_start, original_end] from original source
are mapped to [modified_start, modified_end].
The ranges are inclusive on both ends.
"""
original_start: int
original_end: int
modified_start: int
modified_end: int
# Whether this range corresponds to a changed block, or an unchanged block.
is_changed_block: bool
def _calculate_lines_mappings(
original_source: str,
modified_source: str,
) -> Sequence[_LinesMapping]:
"""Returns a sequence of _LinesMapping by diffing the sources.
For example, given the following diff:
import re
- def func(arg1,
- arg2, arg3):
+ def func(arg1, arg2, arg3):
pass
It returns the following mappings:
original -> modified
(1, 1) -> (1, 1), is_changed_block=False (the "import re" line)
(2, 3) -> (2, 2), is_changed_block=True (the diff)
(4, 4) -> (3, 3), is_changed_block=False (the "pass" line)
You can think of this visually as if it brings up a side-by-side diff, and tries
to map the line ranges from the left side to the right side:
(1, 1)->(1, 1) 1. import re 1. import re
(2, 3)->(2, 2) 2. def func(arg1, 2. def func(arg1, arg2, arg3):
3. arg2, arg3):
(4, 4)->(3, 3) 4. pass 3. pass
Args:
original_source: the original source.
modified_source: the modified source.
"""
matcher = difflib.SequenceMatcher(
None,
original_source.splitlines(keepends=True),
modified_source.splitlines(keepends=True),
)
matching_blocks = matcher.get_matching_blocks()
lines_mappings: list[_LinesMapping] = []
# matching_blocks is a sequence of "same block of code ranges", see
# https://docs.python.org/3/library/difflib.html#difflib.SequenceMatcher.get_matching_blocks
# Each block corresponds to a _LinesMapping with is_changed_block=False,
# and the ranges between two blocks corresponds to a _LinesMapping with
# is_changed_block=True,
# NOTE: matching_blocks is 0-based, but _LinesMapping is 1-based.
for i, block in enumerate(matching_blocks):
if i == 0:
if block.a != 0 or block.b != 0:
lines_mappings.append(
_LinesMapping(
original_start=1,
original_end=block.a,
modified_start=1,
modified_end=block.b,
is_changed_block=False,
)
)
else:
previous_block = matching_blocks[i - 1]
lines_mappings.append(
_LinesMapping(
original_start=previous_block.a + previous_block.size + 1,
original_end=block.a,
modified_start=previous_block.b + previous_block.size + 1,
modified_end=block.b,
is_changed_block=True,
)
)
if i < len(matching_blocks) - 1:
lines_mappings.append(
_LinesMapping(
original_start=block.a + 1,
original_end=block.a + block.size,
modified_start=block.b + 1,
modified_end=block.b + block.size,
is_changed_block=False,
)
)
return lines_mappings
def _find_lines_mapping_index(
    original_line: int,
    lines_mappings: Sequence[_LinesMapping],
    start_index: int,
) -> int:
    """Locate the mapping whose original range contains `original_line`.

    Args:
        original_line: a line number in the original source.
        lines_mappings: the mappings to search.
        start_index: the index at which the scan begins.

    Returns:
        The index of the first mapping at or after `start_index` that contains
        the line. When there is none, an index that is not valid for
        `lines_mappings` (at least its length).
    """
    mappings_count = len(lines_mappings)
    for index in range(start_index, mappings_count):
        candidate = lines_mappings[index]
        if candidate.original_start <= original_line <= candidate.original_end:
            return index
    # Not found: report where the scan stopped, which is `start_index` itself
    # when that was already past the end.
    return max(start_index, mappings_count)
|