File size: 6,709 Bytes
a5be23e
0b5416e
a5be23e
 
 
 
984e3c2
1a6672d
 
 
a5be23e
28263c0
1a6672d
28263c0
1a6672d
a5be23e
1a6672d
 
a5be23e
1a6672d
984e3c2
 
 
 
1a6672d
984e3c2
 
 
 
1a6672d
 
 
a5be23e
984e3c2
 
1a6672d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
984e3c2
 
 
 
1a6672d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a5be23e
984e3c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28263c0
 
 
984e3c2
28263c0
 
 
 
 
 
 
0b5416e
 
984e3c2
28263c0
 
 
984e3c2
 
 
 
 
 
 
28263c0
 
 
984e3c2
28263c0
 
 
a5be23e
27c4e2c
 
 
 
1a6672d
 
 
 
 
 
 
 
 
 
 
 
 
984e3c2
 
1a6672d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# pylint: disable=broad-exception-caught
import logging

from ..models import AnalyzerResult, WorkloadType
from ..tools.llm_client import LLMClient
from ..tools.json_utils import safe_json_loads
from ..tools import static_analyzer

# Module-level LLM client, instantiated once at import time; chat_complete()
# below routes every completion request through this instance.
llm_client = LLMClient()


def chat_complete(messages: list, temperature: float = 0.7, max_tokens: int = 4000) -> str:
    """Forward a chat request to the module-level ``llm_client``.

    Args:
        messages: Chat messages, passed positionally to
            ``llm_client.chat_completion`` unchanged.
        temperature: Sampling temperature forwarded as a keyword argument.
        max_tokens: Completion token limit forwarded as a keyword argument.

    Returns:
        Whatever ``llm_client.chat_completion`` returns (annotated as ``str``).
    """
    sampling_options = {"temperature": temperature, "max_tokens": max_tokens}
    return llm_client.chat_completion(messages, **sampling_options)


def generate_prediction(workload_type: WorkloadType, line_count: int) -> str:
    """Return a one-line performance forecast for the analyzed kernel.

    Args:
        workload_type: Classification of the kernel; only ``MEMORY_BOUND``
            and ``COMPUTE_BOUND`` get a tailored message.
        line_count: Number of code lines; more than 200 is labelled "large",
            anything else (including a falsy count) "small/medium".

    Returns:
        A human-readable prediction string. Any workload type other than the
        two above yields the generic "Unknown workload type" message, which
        does not mention the size label.
    """
    # Size label is computed up front, exactly as before, even though the
    # generic fall-through message does not use it.
    if line_count and line_count > 200:
        scale_label = "large"
    else:
        scale_label = "small/medium"

    if workload_type == WorkloadType.MEMORY_BOUND:
        return (
            f"🧠 Prediction: This {scale_label} kernel is memory-bound → "
            "HIGH potential gain on MI300X (5.3 TB/s vs H100 3.35 TB/s bandwidth)"
        )

    if workload_type == WorkloadType.COMPUTE_BOUND:
        return (
            f"🧠 Prediction: This {scale_label} kernel is compute-bound → "
            "MODERATE gain on MI300X (wavefront efficiency improvements)"
        )

    return "🧠 Prediction: Unknown workload type → LIMITED gain prediction without further analysis"


# Base system prompt for the analyzer LLM. It is static text: run() appends
# the formatted static-scan findings (and, when the scan reports CRITICAL
# items, an extra warp-size instruction) to it at call time.
# The JSON skeleton at the end lists the keys that run() reads from the reply.
_BASE_SYSTEM_PROMPT = """You are an expert CUDA and GPU architecture engineer analyzing CUDA code before porting it to AMD ROCm/HIP.

Your job is to deeply analyze CUDA code and output a structured JSON analysis. Be specific and technical.

CRITICAL things to detect:
1. All CUDA kernel functions (__global__ functions)
2. All CUDA API calls (cudaMalloc, cudaMemcpy, cudaFree, etc.)
3. Warp size assumptions - NVIDIA warp = 32, AMD wavefront = 64. This causes SILENT BUGS.
   Look for: warpSize, __shfl_*, __ballot_sync, hardcoded 32 in thread calculations, WARP_SIZE defines
4. Workload type classification:
   - memory-bound: lots of global memory reads/writes, low arithmetic intensity
   - compute-bound: lots of math operations, high reuse of loaded data
5. Multi-GPU sharding code (written for NVIDIA's 80GB limit - unnecessary on MI300X 192GB)
6. Porting difficulty
7. Code complexity estimation (line count, nested loops, memory access patterns)

A static pre-scan has already run and its findings are included below your instructions.
You MUST confirm those findings and MAY add additional findings.
Do NOT contradict the static scan without strong evidence from the code.

Respond ONLY with this exact JSON structure, no markdown, no extra text:
{
  "kernels_found": ["kernel1", "kernel2"],
  "cuda_apis": ["cudaMalloc", "cudaMemcpy"],
  "warp_size_issue": true,
  "warp_size_detail": "Line 23: hardcoded warpSize=32 in block reduction. AMD wavefront=64 -- this will produce incorrect results.",
  "workload_type": "memory-bound",
  "sharding_detected": false,
  "difficulty": "Medium",
  "difficulty_reason": "Warp-level primitives require manual rewriting beyond hipify scope",
  "line_count": 150,
  "complexity_score": 7
}"""


def _first_critical_description(risk_report):
    """Return the description of the first CRITICAL static-scan item, or None."""
    return next(
        (
            item.description
            for item in risk_report.items
            if item.risk_level == "CRITICAL"
        ),
        None,
    )


def _fallback_analysis(risk_report, line_count: int) -> dict:
    """Build the default analysis dict used when the LLM reply is unusable."""
    has_critical = risk_report.critical_count > 0
    return {
        "kernels_found": ["unknown_kernel"],
        "cuda_apis": [],
        # Preserve the static scan's critical signal even without the LLM.
        "warp_size_issue": has_critical,
        # Use a CRITICAL item's description, not simply whichever item
        # happens to be first in the report.
        "warp_size_detail": (
            _first_critical_description(risk_report) if has_critical else None
        ),
        "workload_type": "memory-bound",
        "sharding_detected": False,
        "difficulty": "Medium",
        "difficulty_reason": "LLM analysis failed; static scan findings preserved",
        "line_count": line_count,
        "complexity_score": 5,
    }


def run(cuda_code: str) -> AnalyzerResult:
    """Analyze CUDA source and return a structured ``AnalyzerResult``.

    The static analyzer runs first; its formatted findings are appended to the
    base system prompt so the LLM answers with that context. The LLM reply is
    parsed as JSON. If the call fails, parsing fails, or the reply is not a
    JSON object, defaults informed by the static scan are used instead.

    Args:
        cuda_code: The CUDA source text to analyze.

    Returns:
        An ``AnalyzerResult`` populated from the LLM reply (or the fallback
        defaults), a generated prediction string, and the static risk report.
        When the static scan reports CRITICAL items, ``warp_size_issue`` is
        always ``True`` regardless of what the LLM answered.
    """
    # Non-blank line count; used for the prediction and as the line_count default.
    line_count = sum(1 for line in cuda_code.split('\n') if line.strip())

    # Step 1: static scan, run before the LLM call.
    risk_report = static_analyzer.scan(cuda_code)
    static_context = static_analyzer.format_for_llm_prompt(risk_report)

    # Step 2: system prompt with the static findings appended.
    system_prompt = _BASE_SYSTEM_PROMPT + "\n\n" + static_context

    has_critical = risk_report.critical_count > 0
    if has_critical:
        # Tell the LLM explicitly which critical patterns were found.
        critical_patterns = [
            str(item.pattern)
            for item in risk_report.items
            if item.risk_level == "CRITICAL"
        ]
        system_prompt += (
            f"\n\nIMPORTANT: The static scan found {risk_report.critical_count} CRITICAL "
            f"warp-size issue(s): {', '.join(critical_patterns)}. "
            "You MUST set warp_size_issue=true in your JSON response."
        )

    try:
        raw = chat_complete(
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": f"Analyze this CUDA code:\n\n```cuda\n{cuda_code}\n```"}
            ],
            temperature=0.1,
            max_tokens=1024,
        )
        data = safe_json_loads(raw)
        # A non-object reply (None, list, string) would break the .get()
        # lookups below, so treat it as a failed analysis.
        if not isinstance(data, dict):
            raise ValueError(
                f"expected a JSON object from the LLM, got {type(data).__name__}"
            )
    except Exception:
        logging.exception(
            "Analyzer LLM call failed; falling back to static-scan defaults")
        data = _fallback_analysis(risk_report, line_count)

    def field(key, default):
        """Return data[key], treating a missing key and JSON null alike."""
        value = data.get(key)
        return default if value is None else value

    try:
        workload_type = WorkloadType(field("workload_type", "unknown"))
    except (ValueError, TypeError):
        # Unrecognised or malformed workload label from the LLM.
        workload_type = WorkloadType.UNKNOWN
    prediction = generate_prediction(workload_type, line_count)

    # The static scan's CRITICAL findings win over the LLM's answer: the
    # prompt asks the LLM to agree, but that is not guaranteed.
    warp_size_issue = True if has_critical else field("warp_size_issue", False)
    warp_size_detail = data.get("warp_size_detail")
    if has_critical and not warp_size_detail:
        warp_size_detail = _first_critical_description(risk_report)

    return AnalyzerResult(
        kernels_found=field("kernels_found", []),
        cuda_apis=field("cuda_apis", []),
        warp_size_issue=warp_size_issue,
        warp_size_detail=warp_size_detail,
        workload_type=workload_type,
        sharding_detected=field("sharding_detected", False),
        difficulty=field("difficulty", "Medium"),
        difficulty_reason=field("difficulty_reason", ""),
        prediction=prediction,
        line_count=field("line_count", line_count),
        complexity_score=field("complexity_score", 5),
        static_risk_report=risk_report,
    )