banhmi-gemma4-e4b / scripts /prepare_data.py
bradduy's picture
Add Unsloth training pipeline (train, evaluate, export, prepare_data, training_logger)
4942b80 verified
#!/usr/bin/env python3
"""
Dataset preparation for Gemma 4 fine-tuning.
Converts raw datasets into Gemma 4 chat format and saves to data/processed/.
Usage:
python scripts/prepare_data.py --dataset <name> --output data/processed/train.jsonl
"""
import argparse
import json
import os
from datasets import load_dataset
# Speaker labels seen in source datasets, mapped to the role names emitted
# in the output. Labels not listed here (e.g. "system") pass through as-is.
_ROLE_ALIASES = {
    "assistant": "model",
    "gpt": "model",
    "bot": "model",
    "human": "user",
    "user": "user",
}


def convert_to_gemma4_chat(example, system_prompt=None):
    """Convert a single example to Gemma 4 chat format.

    Gemma 4 uses "model" (not "assistant") as the role name.

    The layout is detected from the keys present, checked in this order:

    1. ``conversations``: list of turns keyed by ``role``/``content`` or
       ``from``/``value``.
    2. ``messages``: list of dicts with ``role`` and ``content``.
    3. ``instruction`` (+ optional ``input``) and ``output``.
    4. ``question`` and ``answer``.
    5. ``prompt`` and ``response``.

    Args:
        example: Mapping holding one raw dataset record.
        system_prompt: Optional text; when truthy it is emitted as a leading
            ``system`` message.

    Returns:
        ``{"messages": [...]}`` where each message is a dict with ``role``
        and ``content`` keys.

    Raises:
        ValueError: If the record matches none of the layouts above, or is
            instruction-style but has no ``output`` field.
    """
    messages = []
    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    # Handle different dataset formats
    if "conversations" in example:
        for turn in example["conversations"]:
            role = turn.get("role", turn.get("from", ""))
            content = turn.get("content", turn.get("value", ""))
            messages.append(
                {"role": _ROLE_ALIASES.get(role, role), "content": content}
            )
    elif "messages" in example:
        for msg in example["messages"]:
            role = msg["role"]
            # Same alias table as the "conversations" branch so both layouts
            # normalize speaker labels identically.
            messages.append(
                {"role": _ROLE_ALIASES.get(role, role), "content": msg["content"]}
            )
    elif "instruction" in example:
        # Raise ValueError (not KeyError) so callers that skip malformed
        # records by catching ValueError also skip this one.
        if "output" not in example:
            raise ValueError(
                "Instruction-style example is missing 'output'. "
                f"Keys: {list(example.keys())}"
            )
        user_content = example["instruction"]
        if example.get("input"):
            user_content += f"\n\nInput: {example['input']}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "model", "content": example["output"]})
    elif "question" in example and "answer" in example:
        messages.append({"role": "user", "content": example["question"]})
        messages.append({"role": "model", "content": example["answer"]})
    elif "prompt" in example and "response" in example:
        messages.append({"role": "user", "content": example["prompt"]})
        messages.append({"role": "model", "content": example["response"]})
    else:
        raise ValueError(f"Unknown dataset format. Keys: {list(example.keys())}")

    return {"messages": messages}
def load_and_convert(dataset_name, split="train", system_prompt=None, max_samples=None):
    """Load a HuggingFace dataset and convert it to Gemma 4 chat format.

    Args:
        dataset_name: Dataset identifier handed to ``load_dataset``.
        split: Name of the split to load.
        system_prompt: Optional system prompt forwarded to the converter.
        max_samples: When truthy, the split is requested as
            ``"<split>[:<max_samples>]"`` so only a leading slice is loaded.

    Returns:
        List of converted examples. Records for which the converter raises
        ``ValueError`` are left out; only the first such error is printed,
        followed by a total count of skipped records.
    """
    print(f"Loading dataset: {dataset_name} (split={split})")

    split_spec = f"{split}[:{max_samples}]" if max_samples else split
    dataset = load_dataset(dataset_name, split=split_spec)
    print(f"Loaded {len(dataset)} examples")

    converted = []
    skipped = 0
    for record in dataset:
        try:
            chat = convert_to_gemma4_chat(record, system_prompt)
        except ValueError as exc:
            # Report only the first failure to keep the output readable.
            if not skipped:
                print(f" Warning: {exc}")
            skipped += 1
        else:
            converted.append(chat)

    if skipped:
        print(f" Skipped {skipped} examples due to format errors")
    print(f"Converted {len(converted)} examples to Gemma 4 chat format")
    return converted
def save_jsonl(data, output_path):
    """Write ``data`` to ``output_path`` as JSONL, one JSON document per line.

    The parent directory of ``output_path`` is created if it does not exist;
    a bare filename is written relative to the current directory. A summary
    line is printed once the file has been written.
    """
    parent_dir = os.path.dirname(output_path)
    if not parent_dir:
        parent_dir = "."
    os.makedirs(parent_dir, exist_ok=True)

    with open(output_path, "w") as handle:
        handle.writelines(f"{json.dumps(record)}\n" for record in data)

    print(f"Saved {len(data)} examples to {output_path}")
def main():
    """Command-line entry point: load, convert, and write a dataset as JSONL.

    Parses the CLI arguments, converts the requested dataset split, and
    writes the result to ``--output``. When ``--eval-split`` is positive and
    more than 20 examples were converted, the trailing fraction of the data
    is held out and written next to the training file with ``_eval``
    inserted before the file extension (``train.jsonl`` ->
    ``train_eval.jsonl``). Otherwise all examples go to ``--output``.

    Exits through ``parser.error`` when ``--eval-split`` is 1 or greater,
    since that would leave no training data.
    """
    parser = argparse.ArgumentParser(description="Prepare dataset for Gemma 4 fine-tuning")
    parser.add_argument("--dataset", type=str, required=True,
                        help="HuggingFace dataset name (e.g., 'mlabonne/FineTome-100k')")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to use")
    parser.add_argument("--output", type=str, default="data/processed/train.jsonl",
                        help="Output JSONL file path")
    parser.add_argument("--system-prompt", type=str, default=None,
                        help="System prompt to prepend to every conversation")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Maximum number of samples to use")
    parser.add_argument("--eval-split", type=float, default=0.05,
                        help="Fraction of data to hold out for evaluation (0 to disable)")
    args = parser.parse_args()

    # A fraction of 1 or more would move every example into the eval file.
    if args.eval_split >= 1:
        parser.error("--eval-split must be less than 1")

    data = load_and_convert(
        args.dataset,
        split=args.split,
        system_prompt=args.system_prompt,
        max_samples=args.max_samples,
    )

    if args.eval_split > 0 and len(data) > 20:
        eval_size = max(1, int(len(data) * args.eval_split))
        train_data = data[:-eval_size]
        eval_data = data[-eval_size:]
        save_jsonl(train_data, args.output)
        # Split on the real extension: str.replace(".jsonl", ...) returns the
        # same path when the output has another suffix, which would overwrite
        # the training file, and it also rewrites ".jsonl" in directory names.
        root, ext = os.path.splitext(args.output)
        eval_path = f"{root}_eval{ext}"
        save_jsonl(eval_data, eval_path)
    else:
        save_jsonl(data, args.output)
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()