File size: 8,912 Bytes
984e3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a1a69
 
 
 
 
 
 
 
 
984e3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a1a69
984e3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
e7a1a69
 
 
 
 
 
 
 
 
984e3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
"""
static_analyzer.py — Pure-Python wavefront correctness scanner.

Runs BEFORE the LLM sees any code. Zero external dependencies. Typical run time < 5ms.

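Detects common CUDA→AMD porting hazards by regex: chiefly those caused by the NVIDIA
warpSize=32 vs AMD wavefront=64 mismatch (hardcoded 32-lane conditions and masks, shuffle
and ballot primitives, lane-ID arithmetic), plus inline PTX, leftover CUDA runtime/library
dependencies, and unpadded shared-memory arrays. Results are fed as structured pre-analysis
context into the LLM analyzer prompt, making the LLM's job more targeted and auditable.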
"""

import re
import time
from typing import List

from ..models import RiskItem, StaticRiskReport


# ---------------------------------------------------------------------------
# Risk pattern definitions
# Each entry: (pattern_name, regex, risk_level, description, amd_fix_hint)
# scan() runs every regex over the whole source in list order and emits one
# RiskItem per match, so entries earlier in this list appear first in reports.
# ---------------------------------------------------------------------------

_PATTERNS: List[tuple] = [
    (
        "warp_size_hardcoded_32_conditional",
        re.compile(r'\btid\s*<\s*32\b|\bthreadIdx\.x\s*<\s*32\b|\bi\s*<\s*32\b', re.MULTILINE),
        "CRITICAL",
        "Hardcoded '<32' in thread conditional — assumes NVIDIA warpSize=32. "
        "On AMD wavefront=64 this silently skips lanes 32–63 in final reduction stages, "
        "producing incorrect results.",
        "Expand final stage: check 'tid < 64' first, then 'tid < 32'. "
        "See AMD wavefront reduction pattern in docs/JUDGE_MODE.md."
    ),
    (
        "warp_size_define_32",
        re.compile(r'#\s*define\s+WARP_SIZE\s+32\b', re.MULTILINE),
        "CRITICAL",
        "#define WARP_SIZE 32 — this constant will produce wrong kernel geometry on AMD. "
        "Wavefront size is 64 on all GCN/CDNA architectures including MI300X.",
        "Change to #define WARP_SIZE 64 or use the runtime constant wavefrontSize "
        "from hipDeviceGetAttribute(HIP_DEVICE_ATTRIBUTE_WAVEFRONT_SIZE)."
    ),
    (
        "shfl_sync_warp_primitive",
        re.compile(r'\b__shfl_sync\b|\b__shfl_up_sync\b|\b__shfl_down_sync\b|\b__shfl_xor_sync\b', re.MULTILINE),
        "CRITICAL",
        "__shfl_sync family requires the 0xffffffff mask to be reinterpreted for 64-lane wavefronts. "
        "hipify replaces the function name but not the mask — lanes 32–63 are excluded.",
        "Replace with __shfl, __shfl_up, __shfl_down, __shfl_xor (no mask arg in HIP). "
        "Verify lane shuffle ranges cover the full 64-lane wavefront."
    ),
    (
        "ballot_sync_mask",
        # Accept either hex prefix case and integer-literal suffixes (0xffffffffu,
        # 0xFFFFFFFFUL), which are common ways of spelling the full 32-bit mask.
        re.compile(r'\b__ballot_sync\s*\(\s*0[xX][Ff]+[uUlL]*\s*,', re.MULTILINE),
        "CRITICAL",
        "__ballot_sync(0xffffffff, ...) uses a 32-bit full mask. On AMD this is __ballot() "
        "with no mask argument — the 32-bit mask is semantically wrong for a 64-lane wavefront.",
        "Replace __ballot_sync(0xffffffff, cond) with __ballot(cond). "
        "The return type changes from uint32_t to uint64_t — update downstream bitmask logic."
    ),
    (
        "shfl_wavefront_offset_16",
        # The optional `_sync` suffix is required to catch the CUDA spelling
        # (__shfl_down_sync(mask, v, 16)); without it only HIP-style calls matched.
        re.compile(r'\b__shfl(?:_down|_up|_xor)?(?:_sync)?\s*\([^;]*,\s*16\s*(?:,|\))', re.MULTILINE),
        "HIGH",
        "__shfl* with offset=16 often encodes a 32-lane warp reduction tail. "
        "On AMD wavefront=64 the reduction should include an offset=32 step first.",
        "Audit the shuffle reduction and add a wavefront-64 step, e.g. offset=32 "
        "before offset=16 where the algorithm reduces a full wavefront."
    ),
    (
        "activemask_warp",
        re.compile(r'\b__activemask\s*\(\s*\)', re.MULTILINE),
        "HIGH",
        "__activemask() returns a 32-bit value on NVIDIA. On AMD __activemask() "
        "or __ballot(1) returns a 64-bit value. Storing in uint32_t will truncate lanes 32–63.",
        "Declare the result as uint64_t. Audit all bitmask operations for 64-bit correctness."
    ),
    (
        "threadidx_modulo_warpsize",
        re.compile(r'threadIdx\.x\s*%\s*(?:32|warpSize)\b', re.MULTILINE),
        "HIGH",
        "threadIdx.x % 32 assumes 32-lane warps. On AMD wavefront=64 the lane ID "
        "within a wavefront requires modulo 64.",
        "Use threadIdx.x % 64 or threadIdx.x & 63 for the lane ID within a wavefront."
    ),
    (
        "reduction_loop_stops_at_32",
        re.compile(r'for\s*\([^)]*\bs\s*>\s*32\b', re.MULTILINE),
        "HIGH",
        "Reduction loop terminates at s>32 before manually unrolling the final 32 lanes. "
        "On AMD the loop should terminate at s>64 to correctly handle the 64-lane warp tail.",
        "Change loop bound from s>32 to s>64. Expand the manual unroll below the loop "
        "to cover tid<64 before the tid<32 block."
    ),
    (
        "inline_ptx_block",
        # Match `asm(`, `asm volatile(` and the GNU spellings `__asm__` /
        # `__volatile__`; inline PTX is frequently written without `volatile`.
        re.compile(r'\b(?:asm|__asm__)(?:\s+(?:volatile|__volatile__))?\s*\(', re.MULTILINE),
        "CRITICAL",
        "Inline PTX assembly is NVIDIA-specific ISA. hipify cannot translate PTX semantics. "
        "The kernel may compile under hipcc but will have undefined or incorrect behaviour.",
        "Replace inline PTX with portable HIP intrinsics or CDNA ISA equivalents. "
        "Common cases: lane_id → __lane_id(), __clz → __clz() (same name in HIP)."
    ),
    (
        "cuda_runtime_include",
        re.compile(r'#\s*include\s*[<\"]cuda_runtime(?:_api)?\.h[>\"]', re.MULTILINE),
        "MEDIUM",
        "cuda_runtime.h / cuda_runtime_api.h must be replaced with hip/hip_runtime.h. "
        "hipify handles this mechanically but the check confirms it was applied.",
        "Replace with #include <hip/hip_runtime.h>. "
        "hipify-clang does this automatically in its first pass."
    ),
    (
        "cuda_library_dependency",
        # The library name must start the include path or follow a '/', and must
        # not run into further letters/digits. This stops already-hipified
        # <hipcub/...> and unrelated headers such as "cube.h" from matching,
        # while still catching <cub/...>, <thrust/...>, <cudnn.h>, "cudnn_frontend.h".
        re.compile(
            r'#\s*include\s*[<"](?:[^>"]*/)?(?:cub|thrust|cudnn)(?![A-Za-z0-9])[^>"]*[>"]'
            r'|\b(?:cub|thrust|cudnn)::',
            re.MULTILINE,
        ),
        "HIGH",
        "CUDA library dependency detected. hipify can rename some CUB/Thrust/cuDNN symbols, "
        "but API coverage and performance behavior are not guaranteed to match rocPRIM/hipCUB/MIOpen.",
        "Manually review the translated library call, compare against rocPRIM/hipCUB/MIOpen, "
        "and add correctness/performance tests for the specific primitive."
    ),
    (
        "shared_memory_no_padding",
        # Flag only when the INNERMOST dimension is a power of two. The old pattern
        # looked only at the first dimension, so it also flagged the padded
        # `tile[32][33]` that the fix hint recommends. Multi-word element types
        # (`unsigned int`, `volatile float`) are accepted, and outer dimensions may
        # be macros. The negative lookahead ensures the matched bracket is the last one.
        re.compile(
            r'__shared__\s+(?:\w+\s+)+\w+\s*(?:\[\s*\w+\s*\]\s*)*'
            r'\[\s*(?:16|32|64|128|256|512|1024|2048|4096|8192|16384)\s*\](?!\s*\[)',
            re.MULTILINE,
        ),
        "MEDIUM",
        "Fixed-size shared memory array detected without padding. AMD LDS has 32 banks of 4B. "
        "Arrays whose inner dimension is a power-of-2 may cause systematic bank conflicts.",
        "Add +1 padding to the inner dimension, e.g., __shared__ float tile[32][33]. "
        "This staggers accesses across banks and eliminates the conflict."
    ),
]


def _find_line_number(code: str, match_start: int) -> int:
    """Map a character offset in *code* to its 1-based line number.

    The line number is one more than the count of newline characters that
    occur before *match_start*. The bounds follow slice semantics, so
    out-of-range or negative offsets behave like ``code[:match_start]``.
    """
    newlines_before = code.count('\n', 0, match_start)
    return newlines_before + 1


def scan(cuda_code: str) -> StaticRiskReport:
    """
    Run every entry of ``_PATTERNS`` over *cuda_code* and collect the hits.

    Each regex match becomes one RiskItem. It carries the pattern's name,
    severity, description and fix hint, plus the 1-indexed line on which the
    match starts. Items are ordered first by pattern (in ``_PATTERNS`` order),
    then by position within the source.

    Returns a StaticRiskReport holding the items, the number of CRITICAL,
    HIGH and MEDIUM findings, and the wall-clock scan time in milliseconds
    rounded to three decimal places.
    """
    started = time.perf_counter()

    findings: List[RiskItem] = [
        RiskItem(
            line=_find_line_number(cuda_code, hit.start()),
            pattern=name,
            risk_level=level,
            description=text,
            amd_fix_hint=hint,
        )
        for name, compiled, level, text, hint in _PATTERNS
        for hit in compiled.finditer(cuda_code)
    ]

    duration_ms = (time.perf_counter() - started) * 1000.0

    def tally(level: str) -> int:
        # Count findings whose severity equals *level* exactly.
        return sum(1 for finding in findings if finding.risk_level == level)

    return StaticRiskReport(
        items=findings,
        critical_count=tally("CRITICAL"),
        high_count=tally("HIGH"),
        medium_count=tally("MEDIUM"),
        scan_duration_ms=round(duration_ms, 3),
    )


def format_for_llm_prompt(report: StaticRiskReport) -> str:
    """
    Turn a StaticRiskReport into a plain-text block for inclusion in an LLM prompt.

    An empty report yields a single "no hazards" sentence. Otherwise the output
    has a header with the per-severity counts and two instruction lines. Each
    finding then follows as a severity/pattern/location line, an "Issue" line
    and a "Fix" line, with a blank line after it. Lines are joined with "\\n".
    """
    if not report.items:
        return "Static pre-scan: No known AMD compatibility hazards detected."

    header = (
        f"=== STATIC PRE-SCAN ({report.critical_count} CRITICAL, "
        f"{report.high_count} HIGH, {report.medium_count} MEDIUM) ==="
    )
    out = [
        header,
        "The following hazards were detected by deterministic pattern matching BEFORE LLM analysis.",
        "Confirm and expand on these findings — do NOT contradict them without strong evidence.",
        "",
    ]

    for entry in report.items:
        # A falsy line (e.g. 0 or None) is reported as an unknown location.
        where = "location unknown" if not entry.line else f"line {entry.line}"
        out.extend((
            f"[{entry.risk_level}] {entry.pattern} @ {where}",
            f"  Issue: {entry.description}",
            f"  Fix:   {entry.amd_fix_hint}",
            "",
        ))

    return "\n".join(out)