File size: 1,701 Bytes
08157a5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

"""Source code information used for diagnostic messages."""

from __future__ import annotations

import ast
from typing import Callable, Optional


class SourceInfo:
    """Information about onnxscript source fragment, used for diagnostic messages.

    Attributes:
        ast_node: The AST node the diagnostic refers to; its ``lineno`` and
            ``col_offset`` are used to locate the fragment.
        code: Optional source text from which the offending line is quoted.
        function_name: Optional function name included in the reported location.
    """

    def __init__(
        self,
        ast_node: ast.AST,
        code: Optional[str] = None,
        function_name: Optional[str] = None,
    ):
        self.ast_node = ast_node
        self.code = code
        self.function_name = function_name

    @property
    def lineno(self):
        """Line number reported by the wrapped AST node."""
        return self.ast_node.lineno

    def msg(self, error_message: str) -> str:
        """Format ``error_message`` together with its source location.

        The result always contains the message and the location (prefixed by
        the function name when one was given). When ``code`` is available and
        contains the reported line, that line is quoted, followed by a ``^``
        marker indented by the node's ``col_offset``.

        Args:
            error_message: Description of the problem.

        Returns:
            The multi-line diagnostic text.
        """
        lineno = self.lineno
        if self.function_name:
            source_loc = f"Function '{self.function_name}', line {lineno}"
        else:
            source_loc = f"Line {lineno}"

        source_line = ""
        if self.code:
            lines = self.code.split("\n")
            # Quote the line only if it exists in `code`. An out-of-range line
            # number must not raise here (that would hide the error being
            # reported), and a line number below 1 must not wrap around to the
            # end of the source through negative indexing.
            if 1 <= lineno <= len(lines):
                line = lines[lineno - 1]
                marker_prefix = " " * self.ast_node.col_offset
                source_line = f"{line}\n{marker_prefix}^\n"

        return f"ERROR: {error_message}\nat: {source_loc}\n{source_line}"

    def __str__(self) -> str:
        """Always raise; use :meth:`msg` to obtain a printable diagnostic."""
        raise ValueError("Cannot happen!")


# Type of a diagnostic formatter: a callable that takes an AST node and a
# message string and returns the formatted message string.
Formatter = Callable[[ast.AST, str], str]


def formatter(source_code: Optional[str]) -> Formatter:
    """Create a message formatter bound to ``source_code``.

    Args:
        source_code: Source text handed to every ``SourceInfo`` the returned
            formatter builds; may be ``None``.

    Returns:
        A callable that wraps the given node in a ``SourceInfo`` (without a
        function name) and returns the result of its ``msg`` method.
    """

    def format(node: ast.AST, message: str) -> str:
        """Format ``message`` with the source location of ``node``."""
        info = SourceInfo(node, source_code)
        return info.msg(message)

    return format