File size: 4,426 Bytes
40eb9bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
"""
Python port of query_normalize.c from pg_plan_cache.

Replicates the exact normalization and SHA-256 hashing logic so the agent
can compute cache keys without needing PostgreSQL.
"""

import hashlib
from enum import Enum, auto


class _State(Enum):
    """Lexer states of the normalize_query() state machine."""

    DEFAULT = auto()           # ordinary SQL text outside any literal/comment
    IN_SINGLE_QUOTE = auto()   # inside a '...' string literal
    IN_DOUBLE_QUOTE = auto()   # inside a "..." quoted identifier
    IN_NUMBER = auto()         # consuming the remainder of a numeric literal
    IN_LINE_COMMENT = auto()   # after "--", up to and including the newline
    IN_BLOCK_COMMENT = auto()  # inside /* ... */ (nesting is tracked)


# Character classes are spelled out as ASCII sets on purpose.  The str
# predicates (isdigit / isspace / isalpha) and str.lower() are Unicode-aware:
# they accept characters such as U+00B2, U+0663 or U+00A0 and case-fold
# non-ASCII letters, which a byte-oriented scanner would copy through
# unchanged.
# NOTE(review): this assumes query_normalize.c classifies bytes with the
# C-locale <ctype.h> functions -- confirm against the C source.
_ASCII_DIGITS = frozenset("0123456789")
_ASCII_WHITESPACE = frozenset(" \t\n\r\f\v")
_ASCII_UPPERCASE = frozenset("ABCDEFGHIJKLMNOPQRSTUVWXYZ")
_NUMBER_BODY = _ASCII_DIGITS | frozenset(".eE+-")


def normalize_query(query: str) -> str:
    """
    Normalize a SQL query string so equivalent queries share one cache key.

    Transformations, applied in a single left-to-right pass:
      - '...' string literals become positional parameters ($1, $2, ...);
        a doubled quote ('') inside a literal does not terminate it
      - numeric literals become positional parameters from the same counter
      - runs of ASCII whitespace collapse to one space; leading and trailing
        whitespace is dropped
      - ASCII letters outside double quotes are lowercased
      - "..." quoted identifiers are copied verbatim, including "" escapes
      - "--" line comments and (nested) /* */ block comments are removed

    Only ASCII characters are classified as digits, whitespace or uppercase
    letters; every non-ASCII character is copied through unchanged.

    Known quirks, kept on purpose because the module's goal is to produce
    the same output as the C implementation (NOTE(review): confirm each
    against query_normalize.c before changing it):
      - a digit inside an identifier starts a parameter ("t1" -> "t$1")
      - '+', '-', '.', 'e', 'E' directly after a digit are absorbed into the
        number ("1+2" -> "$1")
      - a comment with no whitespace around it joins its neighbours
        ("a/*x*/b" -> "ab")

    Args:
        query: Raw SQL text. An empty (or otherwise falsy) value yields "".

    Returns:
        The normalized query text.
    """
    if not query:
        return ""

    buf: list[str] = []
    param_num = 0
    i = 0
    n = len(query)
    state = _State.DEFAULT
    last_was_space = False
    block_comment_depth = 0

    while i < n:
        c = query[i]
        # One character of lookahead; "" at end of input never matches any
        # of the comparisons or set lookups below.
        nxt = query[i + 1] if i + 1 < n else ""

        if state is _State.DEFAULT:
            if c == "-" and nxt == "-":
                state = _State.IN_LINE_COMMENT
                i += 2
            elif c == "/" and nxt == "*":
                state = _State.IN_BLOCK_COMMENT
                block_comment_depth = 1
                i += 2
            elif c == "'":
                # The whole literal is represented by one placeholder; its
                # contents are skipped in IN_SINGLE_QUOTE.
                param_num += 1
                buf.append(f"${param_num}")
                state = _State.IN_SINGLE_QUOTE
                last_was_space = False
                i += 1
            elif c == '"':
                buf.append(c)
                state = _State.IN_DOUBLE_QUOTE
                last_was_space = False
                i += 1
            elif c in _ASCII_DIGITS or (c == "." and nxt in _ASCII_DIGITS):
                # Emit the placeholder now; IN_NUMBER swallows the rest.
                param_num += 1
                buf.append(f"${param_num}")
                state = _State.IN_NUMBER
                last_was_space = False
                i += 1
            elif c in _ASCII_WHITESPACE:
                # Skip leading whitespace (empty buf) and repeated whitespace.
                if buf and not last_was_space:
                    buf.append(" ")
                    last_was_space = True
                i += 1
            else:
                buf.append(c.lower() if c in _ASCII_UPPERCASE else c)
                last_was_space = False
                i += 1

        elif state is _State.IN_SINGLE_QUOTE:
            if c == "'" and nxt == "'":
                i += 2  # escaped quote, still inside the literal
            else:
                if c == "'":
                    state = _State.DEFAULT
                i += 1

        elif state is _State.IN_DOUBLE_QUOTE:
            buf.append(c)
            if c == '"':
                if nxt == '"':
                    # "" is an escaped quote: copy both and stay in the
                    # identifier.
                    buf.append(nxt)
                    i += 1
                else:
                    state = _State.DEFAULT
            i += 1

        elif state is _State.IN_NUMBER:
            if c in _NUMBER_BODY:
                i += 1
            else:
                # Do not advance: c is re-examined as ordinary text.
                state = _State.DEFAULT

        elif state is _State.IN_LINE_COMMENT:
            # The terminating newline is consumed together with the comment.
            if c == "\n":
                state = _State.DEFAULT
            i += 1

        elif state is _State.IN_BLOCK_COMMENT:
            if c == "/" and nxt == "*":
                block_comment_depth += 1
                i += 2
            elif c == "*" and nxt == "/":
                block_comment_depth -= 1
                i += 2
                if block_comment_depth == 0:
                    state = _State.DEFAULT
            else:
                i += 1

    # At most one trailing space can have been emitted; drop it.
    return "".join(buf).rstrip(" ")


def compute_query_hash(normalized_query: str) -> str:
    """
    Return the cache key for an already-normalized query.

    The key is the SHA-256 digest of the query's UTF-8 bytes, rendered as
    64 lowercase hex characters.  An empty query is not hashed; it maps to
    a string of 64 zeros instead (the original docstring states this matches
    compute_query_hash() in the C code).
    """
    if normalized_query:
        hasher = hashlib.sha256()
        hasher.update(normalized_query.encode("utf-8"))
        return hasher.hexdigest()

    # Sentinel key for "no query".
    return "0" * 64