File size: 5,490 Bytes
0361c24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
"""
Reformat dataset to use CodeLlama chat template format
CodeLlama-Instruct expects: <s>[INST] <<SYS>>...<</SYS>> User [/INST] Response </s>
"""

import json
import sys
from pathlib import Path
from transformers import AutoTokenizer

def extract_system_and_user(instruction: str):
    """Extract system prompt and user message from instruction"""
    # The instruction contains: "System prompt...\n\nTask description"
    parts = instruction.split("\n\n", 1)
    
    if len(parts) == 2:
        system_msg = parts[0].strip()
        user_msg = parts[1].strip()
        
        # Check if system message contains the role description
        if "Elinnos RTL Code Generator" in system_msg or "specialized Verilog" in system_msg:
            return system_msg, user_msg
    
    # Default: extract system prompt
    if "You are" in instruction and "\n\n" in instruction:
        parts = instruction.split("\n\n", 1)
        system_msg = parts[0]
        user_msg = parts[1] if len(parts) > 1 else ""
        return system_msg, user_msg
    
    # Fallback: use default system prompt
    system_msg = "You are Elinnos RTL Code Generator v1.0, a specialized Verilog/SystemVerilog code generation agent. Your role: Generate clean, synthesizable RTL code for hardware design tasks. Output ONLY functional RTL code with no $display, assertions, comments, or debug statements."
    user_msg = instruction
    
    return system_msg, user_msg

def reformat_dataset(input_file: str, output_file: str):
    """Reformat dataset to use CodeLlama chat template format"""
    
    print("=" * 80)
    print("🔄 REFORMATTING DATASET FOR CODELLAMA CHAT TEMPLATE")
    print("=" * 80)
    
    # Load tokenizer
    tokenizer_path = "models/base-models/CodeLlama-7B-Instruct"
    print(f"\n📦 Loading tokenizer from: {tokenizer_path}")
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
    
    # Read input dataset
    samples = []
    with open(input_file, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                try:
                    samples.append(json.loads(line))
                except json.JSONDecodeError as e:
                    print(f"⚠️  Skipping invalid JSON: {e}")
                    continue
    
    print(f"✅ Loaded {len(samples)} samples from {input_file}")
    
    # Reformat each sample
    reformatted_samples = []
    
    for i, sample in enumerate(samples, 1):
        instruction = sample.get("instruction", "").strip()
        response = sample.get("response", "").strip()
        
        if not instruction or not response:
            print(f"⚠️  Skipping sample {i}: missing instruction or response")
            continue
        
        # Extract system and user messages
        system_message, user_message = extract_system_and_user(instruction)
        
        # Create messages for CodeLlama chat template
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message}
        ]
        
        # Apply chat template to get the prompt part
        # This will create: <s>[INST] <<SYS>>...<</SYS>> User [/INST]
        formatted_prompt = tokenizer.apply_chat_template(
            messages,
            tokenize=False,
            add_generation_prompt=True  # Adds [/INST] at the end, ready for generation
        )
        
        # For training, we need:
        # formatted_prompt + response + EOS
        # The formatted_prompt already ends with [/INST]
        # We append response + EOS
        
        reformatted_samples.append({
            "instruction": formatted_prompt,  # The prompt part (ends with [/INST])
            "response": response  # What model should generate
        })
        
        if i % 10 == 0:
            print(f"   Processed {i}/{len(samples)} samples...")
    
    # Save reformatted dataset
    output_path = Path(output_file)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    
    with open(output_file, 'w', encoding='utf-8') as f:
        for sample in reformatted_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + '\n')
    
    print(f"\n✅ Reformatted {len(reformatted_samples)} samples")
    print(f"💾 Saved to: {output_file}")
    
    # Show example
    if reformatted_samples:
        print("\n📝 Example reformatted sample:")
        print("-" * 80)
        example = reformatted_samples[0]
        print(f"Instruction (first 400 chars):")
        print(example["instruction"][:400] + "...")
        print(f"\nResponse (first 200 chars):")
        print(example["response"][:200] + "...")
        print("=" * 80)
    
    return len(reformatted_samples)

if __name__ == "__main__":
    script_dir = Path(__file__).parent
    
    input_file = script_dir / "datasets" / "processed" / "elinnos_fifo_codellama_v1.jsonl"
    output_file = script_dir / "datasets" / "processed" / "elinnos_fifo_codellama_chat_format.jsonl"
    
    if not input_file.exists():
        print(f"❌ Error: Input file not found: {input_file}")
        sys.exit(1)
    
    count = reformat_dataset(str(input_file), str(output_file))
    print(f"\n✅ Successfully reformatted {count} samples!")
    print(f"\nNext steps:")
    print(f"1. Split the reformatted dataset: python3 scripts/dataset_split.py --input {output_file}")
    print(f"2. Update training script to use chat template format")
    print(f"3. Retrain with new format")