---
base_model:
- Qwen/QwQ-32B
tags:
- code
---
# Model Summary
KernelCoder is trained on a curated dataset of reasoning traces and CUDA kernel pairs.
See the [paper](https://lkongam.github.io/ConCuR/) for details.
# Usage
```python
from vllm import LLM, SamplingParams
from transformers import AutoTokenizer
import torch
import re
from typing import List, Tuple
from string import Template
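# The prompt text goes between the quotes. The call to substitute(code=...)
# further down fills a $code placeholder in it.
PROMPT_TEMPLATE = Template('''
''')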
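class KernelCoder:
    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1, gpu_memory_utilization=0.9):
        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto"
        )
        self.tokenizer = self.llm.get_tokenizer()
        self.device = torch.device("cuda")

    def generate_raw(self, prompt, temperature=1.0, max_tokens=16384):
        # max_tokens is an assumed budget for the reasoning trace plus the
        # final code; adjust it to your context length.
        messages = [
            {"role": "user", "content": prompt}
        ]
        # Render the chat template to a plain string for vLLM.
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        # Run generation and return the raw completion text.
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text
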
def extract_last_code_block(text):
    # Prefer the last fenced code block in the response.
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # No fenced block: fall back to whatever follows the reasoning trace.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    # Drop any leading prose by starting at the first import statement.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()

origin_code = """
"""
model = KernelCoder(model_name="lkongam/KernelCoder")
prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
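
The extraction step can be tried without loading the model. The sketch below applies the same logic as `extract_last_code_block` to two made-up responses; the sample strings are hypothetical and only illustrate the fenced and unfenced cases. The fence marker is built programmatically so the example itself contains no literal code fence.

```python
import re

FENCE = "`" * 3  # three backticks, built at runtime


def extract_last_code_block(text):
    """Same logic as the helper above: prefer the last fenced block."""
    blocks = re.findall(FENCE + r"(?:python)?\n(.*?)" + FENCE, text, re.DOTALL)
    if blocks:
        return blocks[-1].strip()
    # No fenced block: use the text after the reasoning trace.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()


# Hypothetical response with two fenced blocks; the last one wins.
sample = (
    "<think>Fuse the add and the ReLU into one kernel.</think>\n"
    "Here is a draft:\n"
    + FENCE + "python\nimport torch\nx = 1\n" + FENCE + "\n"
    "Final version:\n"
    + FENCE + "python\nimport torch\nx = 2\n" + FENCE + "\n"
)
fenced = extract_last_code_block(sample)

# Hypothetical response with no fences; extraction starts at the first import.
unfenced = extract_last_code_block("<think>plan</think>\nSure.\nimport torch\ny = 3")

print(fenced)
print(unfenced)
```

When the response contains several fenced blocks, only the last one is returned, so a draft earlier in the answer is ignored in favor of the final version.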
# Evaluation
![Pass@1 and Pass@10 results](./stas.png)
Left: Pass@1. Right: Pass@10.
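
This README does not state how Pass@k was computed. One widely used choice is the unbiased estimator 1 - C(n-c, k) / C(n, k), where n samples are drawn per task and c of them pass. The sketch below assumes that estimator; the function name `pass_at_k` is hypothetical and is not taken from the KernelCoder code.

```python
from math import comb


def pass_at_k(n: int, c: int, k: int) -> float:
    """Unbiased pass@k estimate for one task: n samples drawn, c correct."""
    if k > n:
        raise ValueError("k must not exceed n")
    # Fewer than k incorrect samples: every size-k subset has a correct one.
    if n - c < k:
        return 1.0
    return 1.0 - comb(n - c, k) / comb(n, k)


# Example with made-up counts: 10 samples per task, 3 of them correct.
print(pass_at_k(n=10, c=3, k=1))
print(pass_at_k(n=10, c=3, k=10))
```

A benchmark-level score would then be the mean of this per-task value over all tasks.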