#!/usr/bin/env python3
"""
Dataset preparation for Gemma 4 fine-tuning.

Converts raw datasets into Gemma 4 chat format and saves to data/processed/.

Usage:
    python scripts/prepare_data.py --dataset DATASET_NAME --output data/processed/train.jsonl
"""

import argparse
import json
import os

try:
    from datasets import load_dataset
except ImportError:
    # Defer the failure to load_and_convert() so the pure helpers in this
    # module (conversion, saving, path derivation) stay importable without
    # the `datasets` package installed.
    load_dataset = None


# Source-role spellings mapped onto the role names written to the output.
# Roles not listed here (e.g. "system", "model") are passed through unchanged.
_ROLE_ALIASES = {
    "assistant": "model",
    "gpt": "model",
    "bot": "model",
    "human": "user",
    "user": "user",
}


def convert_to_gemma4_chat(example, system_prompt=None):
    """Convert a single example to Gemma 4 chat format.

    Gemma 4 uses "model" (not "assistant") as the role name.

    Supported input layouts, checked in this order:
      * ``conversations``: list of turns keyed by ``role``/``content`` or
        ``from``/``value``
      * ``messages``: list of turns, same key handling as above
      * ``instruction`` + ``output`` (optional ``input`` appended to the
        user turn)
      * ``question`` + ``answer``
      * ``prompt`` + ``response``

    Args:
        example: Mapping holding one raw dataset record.
        system_prompt: Optional text emitted as a leading "system" message.

    Returns:
        ``{"messages": [{"role": ..., "content": ...}, ...]}``

    Raises:
        ValueError: If the example matches none of the supported layouts.
    """
    messages = []

    if system_prompt:
        messages.append({"role": "system", "content": system_prompt})

    if "conversations" in example or "messages" in example:
        # "conversations" takes precedence when both keys are present.
        key = "conversations" if "conversations" in example else "messages"
        for turn in example[key]:
            role = turn.get("role", turn.get("from", ""))
            content = turn.get("content", turn.get("value", ""))
            # Same alias table for both layouts, so "gpt"/"human" style roles
            # are normalised regardless of which key carried the turns.
            messages.append(
                {"role": _ROLE_ALIASES.get(role, role), "content": content}
            )
    elif "instruction" in example and "output" in example:
        user_content = example["instruction"]
        if example.get("input"):
            user_content += f"\n\nInput: {example['input']}"
        messages.append({"role": "user", "content": user_content})
        messages.append({"role": "model", "content": example["output"]})
    elif "question" in example and "answer" in example:
        messages.append({"role": "user", "content": example["question"]})
        messages.append({"role": "model", "content": example["answer"]})
    elif "prompt" in example and "response" in example:
        messages.append({"role": "user", "content": example["prompt"]})
        messages.append({"role": "model", "content": example["response"]})
    else:
        raise ValueError(
            f"Unknown dataset format. Keys: {list(example.keys())}"
        )

    return {"messages": messages}


def load_and_convert(dataset_name, split="train", system_prompt=None, max_samples=None):
    """Load a HuggingFace dataset and convert to Gemma 4 format.

    Args:
        dataset_name: Dataset identifier passed to ``load_dataset``.
        split: Name of the split to load.
        system_prompt: Optional system prompt forwarded to the converter.
        max_samples: If truthy, only the first ``max_samples`` rows of the
            split are requested; ``None`` or ``0`` loads the whole split.

    Returns:
        List of converted examples. Examples whose format is not recognised
        are skipped; the first such error is printed along with a count.

    Raises:
        ImportError: If the ``datasets`` package is not installed.
    """
    if load_dataset is None:
        raise ImportError(
            "The 'datasets' package is required to load datasets "
            "(pip install datasets)."
        )

    print(f"Loading dataset: {dataset_name} (split={split})")

    if max_samples:
        dataset = load_dataset(dataset_name, split=f"{split}[:{max_samples}]")
    else:
        dataset = load_dataset(dataset_name, split=split)

    print(f"Loaded {len(dataset)} examples")

    converted = []
    errors = 0
    for example in dataset:
        try:
            converted.append(convert_to_gemma4_chat(example, system_prompt))
        except ValueError as e:
            # Only the first failure is printed to avoid flooding the log.
            if errors == 0:
                print(f"  Warning: {e}")
            errors += 1

    if errors:
        print(f"  Skipped {errors} examples due to format errors")

    print(f"Converted {len(converted)} examples to Gemma 4 chat format")
    return converted


def save_jsonl(data, output_path):
    """Save data as JSONL file.

    Creates the parent directory if needed and writes one JSON document per
    line.

    Args:
        data: Sequence of JSON-serialisable items.
        output_path: Destination file path; an existing file is overwritten.
    """
    os.makedirs(os.path.dirname(output_path) or ".", exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(item) + "\n" for item in data)
    print(f"Saved {len(data)} examples to {output_path}")


def derive_eval_path(output_path):
    """Return the evaluation-file path that accompanies ``output_path``.

    ``_eval`` is inserted before the file extension, so the result always
    differs from ``output_path`` and directory components are left untouched
    (``data/train.jsonl`` -> ``data/train_eval.jsonl``).
    """
    root, ext = os.path.splitext(output_path)
    return f"{root}_eval{ext}"


def main(argv=None):
    """Command-line entry point.

    Args:
        argv: Optional argument list; ``None`` makes argparse read
            ``sys.argv``.
    """
    parser = argparse.ArgumentParser(description="Prepare dataset for Gemma 4 fine-tuning")
    parser.add_argument("--dataset", type=str, required=True,
                        help="HuggingFace dataset name (e.g., 'mlabonne/FineTome-100k')")
    parser.add_argument("--split", type=str, default="train",
                        help="Dataset split to use")
    parser.add_argument("--output", type=str, default="data/processed/train.jsonl",
                        help="Output JSONL file path")
    parser.add_argument("--system-prompt", type=str, default=None,
                        help="System prompt to prepend to every conversation")
    parser.add_argument("--max-samples", type=int, default=None,
                        help="Maximum number of samples to use")
    parser.add_argument("--eval-split", type=float, default=0.05,
                        help="Fraction of data to hold out for evaluation (0 to disable)")
    args = parser.parse_args(argv)

    # A fraction of 1 or more would leave nothing in the training file.
    if args.eval_split >= 1:
        parser.error("--eval-split must be a fraction below 1")

    data = load_and_convert(
        args.dataset,
        split=args.split,
        system_prompt=args.system_prompt,
        max_samples=args.max_samples,
    )

    # Hold-out is taken from the tail and only when there are more than 20
    # converted examples; otherwise everything goes to the training file.
    if args.eval_split > 0 and len(data) > 20:
        eval_size = max(1, int(len(data) * args.eval_split))
        save_jsonl(data[:-eval_size], args.output)
        save_jsonl(data[-eval_size:], derive_eval_path(args.output))
    else:
        save_jsonl(data, args.output)


if __name__ == "__main__":
    main()