File size: 4,984 Bytes
4942b80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
#!/usr/bin/env python3
"""
Dataset preparation for Gemma 4 fine-tuning.
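Converts a HuggingFace dataset into Gemma 4 chat format and saves it as JSONL
(by default to data/processed/train.jsonl). When --eval-split is positive
(default 0.05) and more than 20 examples were converted, a held-out evaluation
split is written alongside it with an `_eval` suffix
(e.g. data/processed/train_eval.jsonl).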

Usage:
    python scripts/prepare_data.py --dataset <name> --output data/processed/train.jsonl
"""

import argparse
import json
import os
from datasets import load_dataset


# Dataset-specific role names mapped onto Gemma 4 chat roles. Roles not listed
# here (e.g. "user", "system", "model") are passed through unchanged.
_ROLE_ALIASES = {
    "assistant": "model",
    "gpt": "model",
    "bot": "model",
    "human": "user",
}


def _normalize_role(role):
    """Return the Gemma 4 role name for a dataset-specific *role*."""
    return _ROLE_ALIASES.get(role, role)


def _require(record, key):
    """Return ``record[key]``; raise ValueError if the key is missing or None.

    ValueError (rather than KeyError/TypeError) is used so that callers which
    skip malformed examples by catching ValueError handle missing fields the
    same way as unrecognised formats.
    """
    value = record.get(key)
    if value is None:
        raise ValueError(
            f"Missing required field {key!r}. Keys: {list(record.keys())}"
        )
    return value


def convert_to_gemma4_chat(example, system_prompt=None):
    """Convert a single example to Gemma 4 chat format.

    Gemma 4 uses "model" (not "assistant") as the role name.

    Supported input layouts, checked in this order:
      * "conversations": list of turns using "role"/"content" keys or
        ShareGPT-style "from"/"value" keys.
      * "messages": list of {"role": ..., "content": ...} dicts.
      * "instruction" (+ optional "input") and "output" (Alpaca-style).
      * "question" and "answer".
      * "prompt" and "response".
    In the first two layouts, roles "assistant"/"gpt"/"bot" become "model" and
    "human" becomes "user"; other roles are kept as-is.

    Args:
        example: Mapping for one dataset row.
        system_prompt: If truthy, prepended as a "system" message.

    Returns:
        ``{"messages": [{"role": ..., "content": ...}, ...]}``

    Raises:
        ValueError: If the layout is unrecognised, or a required field is
            missing or None.
    """
    messages = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    if "conversations" in example:
        for turn in example["conversations"]:
            # Accept both OpenAI-style ("role"/"content") and ShareGPT-style
            # ("from"/"value") turn keys.
            role = turn.get("role", turn.get("from", ""))
            content = turn.get("content", turn.get("value", ""))
            messages.append({"role": _normalize_role(role), "content": content})

    elif "messages" in example:
        for msg in example["messages"]:
            messages.append({
                "role": _normalize_role(_require(msg, "role")),
                "content": _require(msg, "content"),
            })

    elif "instruction" in example:
        user_content = _require(example, "instruction")
        if example.get("input"):
            user_content += f"\n\nInput: {example['input']}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "model", "content": _require(example, "output")})

    elif "question" in example and "answer" in example:
        messages.append({"role": "user", "content": _require(example, "question")})
        messages.append({"role": "model", "content": _require(example, "answer")})

    elif "prompt" in example and "response" in example:
        messages.append({"role": "user", "content": _require(example, "prompt")})
        messages.append({"role": "model", "content": _require(example, "response")})

    else:
        raise ValueError(f"Unknown dataset format. Keys: {list(example.keys())}")

    return {"messages": messages}


def load_and_convert(dataset_name, split="train", system_prompt=None, max_samples=None):
    """Load a HuggingFace dataset and convert every row to Gemma 4 chat format.

    Args:
        dataset_name: Dataset identifier passed to ``datasets.load_dataset``.
        split: Split name to load (default "train").
        system_prompt: Optional system message prepended to every conversation.
        max_samples: If truthy, only the first ``max_samples`` rows of the split
            are loaded via HF split slicing (``"<split>[:N]"``); None or 0 loads
            the whole split.

    Returns:
        List of ``{"messages": [...]}`` dicts, one per successfully converted
        row. Rows whose conversion raises ValueError are skipped and counted;
        only the first such error message is printed.
    """
    print(f"Loading dataset: {dataset_name} (split={split})")

    if max_samples:
        dataset = load_dataset(dataset_name, split=f"{split}[:{max_samples}]")
    else:
        dataset = load_dataset(dataset_name, split=split)

    print(f"Loaded {len(dataset)} examples")

    converted = []
    errors = 0
    for example in dataset:
        try:
            converted.append(convert_to_gemma4_chat(example, system_prompt))
        except ValueError as e:
            # Show only the first failure to avoid flooding the console;
            # the remaining ones are just counted.
            if errors == 0:
                print(f"  Warning: {e}")
            errors += 1

    if errors:
        print(f"  Skipped {errors} examples due to format errors")

    print(f"Converted {len(converted)} examples to Gemma 4 chat format")
    return converted


def save_jsonl(data, output_path):
    """Write each item of *data* as one JSON object per line to *output_path*.

    Parent directories are created if needed, and an existing file is
    overwritten. ``json.dumps`` escapes non-ASCII characters by default, so the
    content is ASCII. The encoding is still pinned to UTF-8 so the output never
    depends on the platform's locale default.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        for item in data:
            f.write(json.dumps(item) + "\n")
    print(f"Saved {len(data)} examples to {output_path}")


def main():
    """Command-line entry point: load, convert, split, and save a dataset.

    Training examples are written to ``--output``. When ``--eval-split`` is
    positive and more than 20 examples were converted, the last
    ``max(1, int(n * eval_split))`` examples are held out. They are written to a
    sibling file whose name adds ``_eval`` before the extension
    (e.g. ``train.jsonl`` -> ``train_eval.jsonl``).
    """
    parser = argparse.ArgumentParser(description="Prepare dataset for Gemma 4 fine-tuning")
    parser.add_argument("--dataset", type=str, required=True,
                        help="HuggingFace dataset name (e.g., 'mlabonne/FineTome-100k')")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to use")
    parser.add_argument("--output", type=str, default="data/processed/train.jsonl",
                        help="Output JSONL file path")
    parser.add_argument("--system-prompt", type=str, default=None,
                        help="System prompt to prepend to every conversation")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Maximum number of samples to use")
    parser.add_argument("--eval-split", type=float, default=0.05,
                        help="Fraction of data to hold out for evaluation (0 to disable)")
    args = parser.parse_args()

    # A fraction >= 1 would leave the training set empty. A negative
    # --max-samples would turn into a "drop the last N rows" split slice.
    if not 0 <= args.eval_split < 1:
        parser.error("--eval-split must be in the range [0, 1)")
    if args.max_samples is not None and args.max_samples < 0:
        parser.error("--max-samples must be non-negative")

    data = load_and_convert(
        args.dataset,
        split=args.split,
        system_prompt=args.system_prompt,
        max_samples=args.max_samples,
    )

    if args.eval_split > 0 and len(data) > 20:
        eval_size = max(1, int(len(data) * args.eval_split))
        # Hold out the tail of the data. No shuffling is done, so the split
        # follows the dataset's original row order.
        train_data = data[:-eval_size]
        eval_data = data[-eval_size:]

        save_jsonl(train_data, args.output)
        # Insert "_eval" before the extension. A plain str.replace(".jsonl", ...)
        # returns the path unchanged when it lacks that extension, so the eval
        # file would overwrite the training file. It would also rewrite
        # ".jsonl" appearing in directory names.
        root, ext = os.path.splitext(args.output)
        eval_path = f"{root}_eval{ext or '.jsonl'}"
        save_jsonl(eval_data, eval_path)
    else:
        if args.eval_split > 0:
            print(f"Only {len(data)} examples; skipping eval split")
        save_jsonl(data, args.output)


if __name__ == "__main__":
    main()