---
base_model:
- Qwen/QwQ-32B
tags:
- code
---

# Model Summary

KernelCoder is trained on a curated dataset of paired reasoning traces and CUDA kernels.

Details are in the [ConCuR paper](https://lkongam.github.io/ConCuR/).
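
Each training example pairs a reasoning trace with the CUDA kernel it derives. As a rough illustration only (the field names below are hypothetical, not the dataset's actual schema; see the paper for specifics), a record might look like:

```python
# Hypothetical record layout, for illustration only.
example = {
    "reasoning": "The op is an elementwise add, so one thread per element suffices...",
    "kernel": """
__global__ void add_kernel(const float* a, const float* b, float* out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) out[i] = a[i] + b[i];
}
""",
}
```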

# Usage
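
The snippet below loads KernelCoder with vLLM, builds a prompt from the source program via a string template, samples a completion, and extracts the final code block from the model's output.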

```python
from vllm import LLM, SamplingParams
import re
from string import Template

# Fill in the full generation prompt here. It should contain a `$code`
# placeholder so that PROMPT_TEMPLATE.substitute(code=...) can inject the
# source program below.
PROMPT_TEMPLATE = Template('''
''')


class KernelCoder:

    def __init__(self, model_name="lkongam/KernelCoder", tensor_parallel_size=1, gpu_memory_utilization=0.9):
        self.model_name = model_name
        self.llm = LLM(
            model=model_name,
            tensor_parallel_size=tensor_parallel_size,
            gpu_memory_utilization=gpu_memory_utilization,
            trust_remote_code=True,
            dtype="auto"
        )
        self.tokenizer = self.llm.get_tokenizer()

    def generate_raw(self, prompt, temperature=1.0, max_tokens=32768):
        messages = [
            {"role": "user", "content": prompt}
        ]
        text = self.tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
            enable_thinking=True
        )
        # Reasoning traces are long, so allow a generous completion budget.
        sampling_params = SamplingParams(temperature=temperature, max_tokens=max_tokens)
        outputs = self.llm.generate([text], sampling_params)
        return outputs[0].outputs[0].text


def extract_last_code_block(text):
    # Prefer the last fenced code block in the completion.
    code_blocks = re.findall(r"```(?:python)?\n(.*?)```", text, re.DOTALL)
    if code_blocks:
        return code_blocks[-1].strip()
    # Otherwise, fall back to everything after the reasoning trace.
    match = re.search(r"</think>(.*)", text, re.S)
    after_think = match.group(1).strip() if match else text
    if not after_think:
        return None
    # Heuristic: generated code usually starts at the first `import`.
    import_match = re.search(r"\bimport\b", after_think)
    if import_match:
        return after_think[import_match.start():].strip()
    return after_think.strip()


# Paste the source program to be rewritten as a CUDA kernel here.
origin_code = """
"""

model = KernelCoder(model_name="lkongam/KernelCoder")

prompt = PROMPT_TEMPLATE.substitute(code=origin_code)
code_output = model.generate_raw(prompt)
code = extract_last_code_block(code_output)
print(code)
```
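
What the completion contains depends on your prompt template; a common pattern is a self-contained Python script that builds the kernel with `torch.utils.cpp_extension.load_inline`. A minimal sketch for smoke-testing the extracted code (assuming it is a runnable script; sandbox untrusted generated code before executing it):

```python
# Minimal sketch: write the generated script to disk and execute it.
# Assumes the extracted completion is a self-contained Python program.
import subprocess

with open("generated_kernel.py", "w") as f:
    f.write(code)

result = subprocess.run(
    ["python", "generated_kernel.py"],
    capture_output=True, text=True, timeout=300,
)
print(result.stdout or result.stderr)
```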

# Evaluation

![Evaluation results](kernelcoder.png)

Left: Pass@1, right: Pass@10.