#!/usr/bin/env python3
"""
Data preprocessing script.

Convert the generated dataset into a format directly consumable by
SFTTrainer. FunctionGemma expects a specific chat template structure.

Usage:
    python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json
"""

import json
import argparse
from pathlib import Path
from typing import List, Dict, Any, Optional

PROJECT_ROOT = Path(__file__).resolve().parent.parent
DEFAULT_INPUT = PROJECT_ROOT / "data" / "training_data.json"
DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "prepared_dataset.json"

# Turn header used by the flattened "text" format. Roles not listed here
# (e.g. "tool") fall back to their own role name as the header so no turn is
# silently dropped.
_TEXT_ROLE_HEADERS = {"system": "system", "user": "user", "assistant": "model"}

# Output formats accepted by prepare_dataset (mirrors the CLI --format choices).
_VALID_FORMATS = ("messages", "text")


def _content_of(msg: Dict) -> str:
    """Return a message's content, mapping a missing or None value to ""."""
    content = msg.get("content")
    return "" if content is None else content


def _normalize_arguments(args: Any) -> Any:
    """Decode tool-call arguments that arrive as a JSON-encoded string."""
    if isinstance(args, str):
        try:
            return json.loads(args)
        except json.JSONDecodeError:
            # Not valid JSON: keep the raw string, serialized as before.
            return args
    return args


def convert_tool_calls_to_text(tool_calls: Optional[List[Dict]]) -> str:
    """Convert tool_calls into plain text (FunctionGemma format).

    Each call is rendered as ``name(<arguments as JSON>)`` and multiple calls
    are joined with newlines.

    Args:
        tool_calls: Entries shaped like
            ``{"function": {"name": ..., "arguments": ...}}``. ``arguments``
            may be a dict or a JSON-encoded string; a string is decoded first
            so it is not double-encoded in the output.

    Returns:
        The flattened text, or "" when ``tool_calls`` is empty or None.
    """
    if not tool_calls:
        return ""

    result_parts = []
    for tc in tool_calls:
        func = tc.get("function") or {}
        name = func.get("name", "")
        args = _normalize_arguments(func.get("arguments", {}))
        # FunctionGemma format: functionName(arguments)
        args_str = json.dumps(args, ensure_ascii=False)
        result_parts.append(f"{name}({args_str})")

    return "\n".join(result_parts)


def _build_tools_description(tools: Optional[List[Dict]]) -> str:
    """Render function-type tools as an "Available tools" suffix for the system prompt."""
    if not tools:
        return ""

    lines = []
    for tool in tools:
        if tool.get("type") != "function":
            continue
        func = tool.get("function") or {}
        # NOTE(review): only name and description are listed; the parameter
        # schema is not included in the prompt — confirm this is intended.
        lines.append(f"- {func.get('name', '')}: {func.get('description', '')}")

    if not lines:
        return ""
    return "\n\nAvailable tools:\n" + "\n".join(lines)


def convert_messages_for_sft(messages: List[Dict], tools: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Convert message format for SFTTrainer.

    Input:
        [
            {"role": "developer", "content": "..."},   # or "system"
            {"role": "user", "content": "..."},
            {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."},
            {"role": "tool", "content": "..."}
        ]

    Output:
        [
            {"role": "system", "content": "..."},      # developer/system -> system (+ tool list)
            {"role": "user", "content": "..."},
            {"role": "assistant", "content": "..."},   # tool_calls flattened to text
            {"role": "tool", "content": "..."}
        ]

    Args:
        messages: Source chat messages.
        tools: Optional tool definitions. Function-type tools are appended as
            an "Available tools" list to every system/developer message.

    Returns:
        A new list of ``{"role", "content"}`` dicts; ``content`` is always a
        string. Messages with any other role are dropped.
    """
    tools_description = _build_tools_description(tools)
    converted = []

    for msg in messages:
        role = msg.get("role", "")

        if role in ("developer", "system"):
            # developer -> system; already-"system" messages are kept too.
            converted.append({
                "role": "system",
                "content": _content_of(msg) + tools_description,
            })
        elif role == "user":
            converted.append({"role": "user", "content": _content_of(msg)})
        elif role == "assistant":
            # Non-empty tool_calls take precedence over content; an empty or
            # None tool_calls field must not discard the textual content.
            if msg.get("tool_calls"):
                content = convert_tool_calls_to_text(msg["tool_calls"])
            else:
                content = _content_of(msg)
            converted.append({"role": "assistant", "content": content})
        elif role == "tool":
            # Tool response
            converted.append({"role": "tool", "content": _content_of(msg)})

    return converted


def _messages_to_text(messages: List[Dict]) -> str:
    """Flatten converted messages into "<header>\\n<content>" turns joined by newlines."""
    parts = []
    for msg in messages:
        header = _TEXT_ROLE_HEADERS.get(msg["role"], msg["role"])
        parts.append(f"{header}\n{msg['content']}")
    return "\n".join(parts)


def _truncate(text: str, limit: int) -> str:
    """Cut ``text`` to ``limit`` characters, adding "..." only when something was cut."""
    return text[:limit] + "..." if len(text) > limit else text


def _print_example(prepared_data: List[Dict], format_type: str) -> None:
    """Print the first prepared sample (truncated) for a quick visual check."""
    print("\n" + "=" * 60)
    print("Example:")
    print("=" * 60)
    if not prepared_data:
        print("(no samples)")
        return

    example = prepared_data[0]
    if format_type == "messages":
        for msg in example["messages"]:
            print(f"\n[{msg['role']}]")
            print(_truncate(msg["content"], 200))
    else:
        print(_truncate(example["text"], 500))


def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages") -> List[Dict]:
    """
    Prepare dataset.

    format_type:
        - "messages": output {"messages": [...]}
        - "text": output {"text": "..."} (flattened text)

    Args:
        input_path: JSON file containing a list of samples, each with
            "messages" and optionally "tools".
        output_path: Destination JSON file; parent directories are created.
        format_type: One of "messages" or "text".

    Returns:
        The list of prepared samples (also written to ``output_path``).

    Raises:
        ValueError: If ``format_type`` is unknown or the input is not a JSON list.
    """
    if format_type not in _VALID_FORMATS:
        raise ValueError(
            f"Unknown format_type {format_type!r}; expected one of {_VALID_FORMATS}"
        )

    print(f"Loading dataset: {input_path}")
    with open(input_path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(
            f"Expected a JSON list of samples in {input_path}, got {type(data).__name__}"
        )

    print(f"Raw samples: {len(data)}")

    prepared_data = []
    for item in data:
        converted_messages = convert_messages_for_sft(
            item.get("messages") or [], item.get("tools") or []
        )
        if format_type == "messages":
            prepared_data.append({"messages": converted_messages})
        else:
            prepared_data.append({"text": _messages_to_text(converted_messages)})

    print(f"Processed samples: {len(prepared_data)}")

    # Save (make sure the destination directory exists first).
    output = Path(output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    with open(output, "w", encoding="utf-8") as f:
        json.dump(prepared_data, f, ensure_ascii=False, indent=2)
    print(f"Saved to: {output_path}")

    _print_example(prepared_data, format_type)
    return prepared_data


def main():
    """Command-line entry point: parse arguments and run prepare_dataset."""
    parser = argparse.ArgumentParser(description="Dataset preparation")
    parser.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
    parser.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
    parser.add_argument("--format", type=str, choices=["messages", "text"], default="messages", help="Output format")
    args = parser.parse_args()

    prepare_dataset(args.input, args.output, args.format)


if __name__ == "__main__":
    main()