# Source: codellama-fine-tuning / reformat_dataset_for_codellama.py
# Uploaded by Prithvik-1 with huggingface_hub (commit 0361c24, verified)
#!/usr/bin/env python3
"""
Reformat dataset to use CodeLlama chat template format
CodeLlama-Instruct expects: <s>[INST] <<SYS>>...<</SYS>> User [/INST] Response </s>
"""
import json
import sys
from pathlib import Path
from transformers import AutoTokenizer
# System prompt used when the instruction does not carry its own.
DEFAULT_SYSTEM_PROMPT = "You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements."

# Substrings that identify the first paragraph of an instruction as a system prompt.
_SYSTEM_MARKERS = ("Elinnos RTL Code Generator", "specialized Verilog", "You are")


def extract_system_and_user(instruction: str) -> "tuple[str, str]":
    """Split an instruction into (system_message, user_message).

    The instruction is expected to look like "System prompt...\\n\\nTask description".
    The text before the first blank line ("\\n\\n") is treated as the system prompt
    only if it contains one of ``_SYSTEM_MARKERS`` and the remaining task text is
    non-empty. In that case both parts are returned with surrounding whitespace
    stripped.

    Otherwise ``DEFAULT_SYSTEM_PROMPT`` is returned as the system message and the
    whole, unmodified ``instruction`` as the user message.

    Args:
        instruction: Raw instruction text from a dataset record.

    Returns:
        A ``(system_message, user_message)`` tuple of strings.
    """
    head, separator, tail = instruction.partition("\n\n")
    system_msg = head.strip()
    user_msg = tail.strip()
    # Only look for markers in the candidate system paragraph. A task body that
    # happens to contain "You are" must not turn its first paragraph into a
    # system prompt.
    if separator and user_msg and any(marker in system_msg for marker in _SYSTEM_MARKERS):
        return system_msg, user_msg
    return DEFAULT_SYSTEM_PROMPT, instruction
# Default tokenizer location; resolved relative to the current working directory.
DEFAULT_TOKENIZER_PATH = "models/base-models/CodeLlama-7B-Instruct"


def _load_jsonl(path: str) -> list:
    """Read a JSON Lines file, returning decoded records and reporting undecodable lines."""
    records = []
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, 1):
            if not line.strip():
                continue
            try:
                records.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"โš ๏ธ Skipping invalid JSON on line {line_no}: {e}")
    return records


def _text_field(sample, key: str) -> str:
    """Return sample[key] stripped if it is a string; "" for missing, null, non-string, or non-dict."""
    value = sample.get(key) if isinstance(sample, dict) else None
    return value.strip() if isinstance(value, str) else ""


def _write_jsonl(path: str, records: list) -> None:
    """Write records as UTF-8 JSON Lines, creating the parent directory if needed."""
    out_path = Path(path)
    out_path.parent.mkdir(parents=True, exist_ok=True)
    with open(out_path, 'w', encoding='utf-8') as f:
        for record in records:
            f.write(json.dumps(record, ensure_ascii=False) + '\n')


def _preview(text: str, limit: int) -> str:
    """Return the first `limit` characters of text, adding "..." only if it was truncated."""
    return text[:limit] + "..." if len(text) > limit else text


def reformat_dataset(input_file: str, output_file: str,
                     tokenizer_path: str = DEFAULT_TOKENIZER_PATH):
    """Rewrite an instruction/response JSONL dataset into CodeLlama chat-template prompts.

    Each record's "instruction" is split into system and user messages by
    ``extract_system_and_user``. It is then rendered with the tokenizer's chat
    template (``tokenize=False``, ``add_generation_prompt=True``) and stored back
    under "instruction". The stripped "response" is kept unchanged.

    Records are skipped, with a message, in these cases:
    - the line is not valid JSON;
    - the record is not a JSON object;
    - "instruction" or "response" is missing, null, non-string, or blank.

    Args:
        input_file: Path to the source JSONL file.
        output_file: Path to write the reformatted JSONL file. Parent
            directories are created.
        tokenizer_path: Path or hub id passed to ``AutoTokenizer.from_pretrained``.
            Defaults to the previous hard-coded location, relative to the CWD.

    Returns:
        The number of records written to ``output_file``.
    """
    print("=" * 80)
    print("๐Ÿ”„ REFORMATTING DATASET FOR CODELLAMA CHAT TEMPLATE")
    print("=" * 80)

    print(f"\n๐Ÿ“ฆ Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)

    samples = _load_jsonl(input_file)
    print(f"โœ… Loaded {len(samples)} samples from {input_file}")

    reformatted_samples = []
    for i, sample in enumerate(samples, 1):
        instruction = _text_field(sample, "instruction")
        response = _text_field(sample, "response")
        if not instruction or not response:
            print(f"โš ๏ธ Skipping sample {i}: missing instruction or response")
            continue

        system_message, user_message = extract_system_and_user(instruction)
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message},
        ]
        # Render the prompt text only (no tokenization). Per the module docstring
        # this is expected to be "<s>[INST] <<SYS>>...<</SYS>> User [/INST]".
        # NOTE(review): no EOS is appended here. The training script must add it
        # after the response. If the rendered prompt starts with the BOS text and
        # is later tokenized with special tokens enabled, BOS may be duplicated.
        # Confirm against the training script.
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True,
        )
        reformatted_samples.append({
            "instruction": formatted_prompt,  # prompt part
            "response": response,             # target the model should generate
        })

        if i % 10 == 0:
            print(f" Processed {i}/{len(samples)} samples...")

    _write_jsonl(output_file, reformatted_samples)
    print(f"\nโœ… Reformatted {len(reformatted_samples)} samples")
    print(f"๐Ÿ’พ Saved to: {output_file}")

    if reformatted_samples:
        example = reformatted_samples[0]
        print("\n๐Ÿ“ Example reformatted sample:")
        print("-" * 80)
        print("Instruction (first 400 chars):")
        print(_preview(example["instruction"], 400))
        print("\nResponse (first 200 chars):")
        print(_preview(example["response"], 200))
        print("=" * 80)

    return len(reformatted_samples)
if __name__ == "__main__":
    # Both the source and destination datasets live in <script dir>/datasets/processed.
    processed_dir = Path(__file__).parent / "datasets" / "processed"
    source_path = processed_dir / "elinnos_fifo_codellama_v1.jsonl"
    target_path = processed_dir / "elinnos_fifo_codellama_chat_format.jsonl"

    # Abort with a non-zero exit status when there is nothing to convert.
    if not source_path.exists():
        print(f"โŒ Error: Input file not found: {source_path}")
        sys.exit(1)

    written = reformat_dataset(str(source_path), str(target_path))
    print(f"\nโœ… Successfully reformatted {written} samples!")

    # Follow-up instructions for the operator, printed one per line.
    next_steps = (
        f"1. Split the reformatted dataset: python3 scripts/dataset_split.py --input {target_path}",
        "2. Update training script to use chat template format",
        "3. Retrain with new format",
    )
    print("\nNext steps:")
    for step in next_steps:
        print(step)