---
base_model:
- Qwen/QwQ-32B
tags:
- code
---

# Model Summary
KernelCoder is built on Qwen/QwQ-32B and trained on a curated dataset of paired reasoning traces and CUDA kernels.

See the [paper](https://lkongam.github.io/ConCuR/) for details.

# Usage

```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import torch
import re
from typing import List, Tuple
from string import Template
PROMPT_TEMPLATE = Template('''
''')

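class KernelCoder:

    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1, gpu_memory_utilization=0.9):

        self.model_name = model_name

        # Load the model with vLLM; dtype="auto" follows the checkpoint's dtype.
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto"
        )

        self.tokenizer = self.llm.get_tokenizer()
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=16384):
        # Wrap the prompt in the chat template so the model sees the expected format.
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        # max_tokens is an assumed budget, not a value from this model card; adjust as needed.
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        # Return the generated text: the reasoning trace followed by the final answer.
        return outputs[0].outputs[0].text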

def extract_last_code_block(text):
    # Prefer the last fenced code block in the output.
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # Otherwise fall back to whatever follows the reasoning trace.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    # Drop any leading prose by starting at the first `import`.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()

origin_code = """
"""

model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
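
The fallback order in `extract_last_code_block` can be checked without loading the model. The sketch below repeats the helper so it runs on its own and feeds it three made-up output strings (hypothetical, not real model outputs): one with a fenced block, one with unfenced code after `</think>`, and one with nothing after the reasoning trace. The `FENCE` constant is only there so the example nests cleanly inside Markdown.

```python
import re

FENCE = "`" * 3  # three backticks, built programmatically for Markdown nesting

def extract_last_code_block(text):
    # Same logic as the helper in the usage snippet, repeated for self-containment.
    code_blocks = re.findall(FENCE + r"(?:python)?\n(.*?)" + FENCE, text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()

# Hypothetical outputs, invented for illustration only.
fenced = ("reasoning...</think>\nHere is the kernel:\n"
          + FENCE + "python\nimport torch\nx = 1\n" + FENCE + "\n")
unfenced = "reasoning...</think>\nSure. import torch\nx = 1"
empty = "reasoning...</think>   "

print(extract_last_code_block(fenced))    # code inside the last fenced block
print(extract_last_code_block(unfenced))  # text starting at the first `import`
print(extract_last_code_block(empty))     # None: nothing follows the reasoning
```

Note that the unfenced fallback starts at the first `import` keyword, so an answer that opens with `from x import y` would lose the leading `from x`. Fenced output avoids this.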

# Evaluation
![sta](./stas.png)

Left: Pass@1, Right: Pass@10.
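
This card does not state how Pass@k is computed. As a rough sketch, assuming the commonly used unbiased estimator (draw `n` samples per task, count the `c` that pass, and estimate the chance that at least one of `k` drawn samples passes), the metric can be written as follows. The function name and arguments are illustrative.

```python
from math import comb

def pass_at_k(n, c, k):
    """Estimate pass@k from n samples of which c are correct (assumed estimator)."""
    if k > n:
        raise ValueError("k must not exceed the number of samples n")
    # 1 - C(n - c, k) / C(n, k): probability that a random size-k subset
    # of the n samples contains at least one correct sample.
    return 1.0 - comb(n - c, k) / comb(n, k)

# Illustrative numbers only, not results from this model.
print(pass_at_k(n=10, c=3, k=1))
print(pass_at_k(n=10, c=3, k=10))
```

Per-task estimates are then averaged over the benchmark to give the reported score.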