DMind-3-nano / src /prepare_dataset.py
yuzhe's picture
Upload 13 files
6f09d40 verified
#!/usr/bin/env python3
"""
Data preprocessing script.
Convert the generated dataset into a format directly consumable by SFTTrainer.
FunctionGemma expects a specific chat template structure.
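Usage:
    python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json [--format messages|text]

    --format messages (default) writes [{"messages": [...]}, ...];
    --format text writes [{"text": "..."}, ...] using Gemma <start_of_turn>/<end_of_turn> markup.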
"""
import json
import argparse
from pathlib import Path
from typing import List, Dict, Any
# Project root: the directory two levels above this (resolved) source file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Default raw-input and prepared-output locations, both under <root>/data/.
DEFAULT_INPUT = PROJECT_ROOT.joinpath("data", "training_data.json")
DEFAULT_OUTPUT = PROJECT_ROOT.joinpath("data", "prepared_dataset.json")
def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str:
    """Convert tool_calls into plain text (FunctionGemma format).

    Each call is rendered as ``name(<JSON arguments>)``, one call per line,
    e.g. ``get_price({"symbol": "BTC"})``.

    Args:
        tool_calls: Tool-call dicts shaped like
            ``{"function": {"name": ..., "arguments": ...}}``. ``arguments``
            may be a dict or a JSON-encoded string. A string is decoded first,
            so it is not emitted as a double-escaped JSON string literal.

    Returns:
        The rendered calls joined with ``"\\n"``, or ``""`` when
        ``tool_calls`` is empty or None.
    """
    if not tool_calls:
        return ""
    result_parts = []
    for tc in tool_calls:
        # ``or {}`` also tolerates an explicit ``"function": None``.
        func = tc.get("function") or {}
        name = func.get("name", "")
        args = func.get("arguments", {})
        if isinstance(args, str):
            # Without decoding, '{"a": 1}' would be dumped as the quoted string
            # "{\"a\": 1}" rather than as the object {"a": 1}.
            try:
                args = json.loads(args)
            except json.JSONDecodeError:
                # Not valid JSON: keep the raw string (previous behaviour).
                pass
        args_str = json.dumps(args, ensure_ascii=False)
        result_parts.append(f"{name}({args_str})")
    return "\n".join(result_parts)
def _build_tools_description(tools: List[Dict]) -> str:
    """Render function-type tools as an "Available tools" section, or "" if there are none.

    Only each tool's name and description are listed; parameter schemas are
    not included. ``tools`` may be None or empty.
    """
    lines = []
    for tool in tools or []:
        if tool.get("type") != "function":
            continue
        func = tool.get("function") or {}
        lines.append(f"- {func.get('name', '')}: {func.get('description', '')}")
    if not lines:
        return ""
    return "\n\nAvailable tools:\n" + "\n".join(lines)


def convert_messages_for_sft(messages: List[Dict], tools: List[Dict] = None) -> List[Dict]:
    """
    Convert message format for SFTTrainer.

    Input:
    [
        {"role": "developer", "content": "..."},   # "system" is accepted too
        {"role": "user", "content": "..."},
        {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
    ]
    Output:
    [
        {"role": "system", "content": "..."},      # developer/system -> system (+ tools description)
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."}    # tool_calls flattened to text
    ]

    "tool" messages are kept as-is; messages with any other role are dropped.
    A missing or None ``content`` becomes "".

    Args:
        messages: The raw conversation.
        tools: Optional tool specs. The names and descriptions of function-type
            tools are appended to every system message.

    Returns:
        A new list of ``{"role", "content"}`` dicts.
    """
    tools_description = _build_tools_description(tools)
    converted = []
    for msg in messages:
        role = msg.get("role", "")
        # ``or ""`` covers both a missing key and an explicit ``"content": None``
        # (common on assistant tool-call turns), so content is never None.
        content = msg.get("content") or ""
        if role in ("developer", "system"):
            converted.append({"role": "system", "content": content + tools_description})
        elif role == "user":
            converted.append({"role": "user", "content": content})
        elif role == "assistant":
            tool_calls = msg.get("tool_calls")
            # An empty/None tool_calls list must not wipe out the text reply.
            if tool_calls:
                content = convert_tool_calls_to_text(tool_calls)
            converted.append({"role": "assistant", "content": content})
        elif role == "tool":
            # Tool response
            converted.append({"role": "tool", "content": content})
    return converted
def _messages_to_gemma_text(messages: List[Dict]) -> str:
    """Flatten converted messages into Gemma ``<start_of_turn>`` markup, one turn per line."""
    # "assistant" becomes Gemma's "model" turn. Roles not in this map
    # (notably "tool") are omitted from the text format.
    turn_names = {"system": "system", "user": "user", "assistant": "model"}
    return "\n".join(
        f"<start_of_turn>{turn_names[msg['role']]}\n{msg['content']}<end_of_turn>"
        for msg in messages
        if msg["role"] in turn_names
    )


def _preview(text, limit: int) -> str:
    """Return ``text`` cut to ``limit`` characters, with "..." appended only if it was cut."""
    text = "" if text is None else str(text)
    return text[:limit] + "..." if len(text) > limit else text


def _print_example(prepared_data: List[Dict], format_type: str) -> None:
    """Print a truncated view of the first prepared sample, if there is one."""
    print("\n" + "=" * 60)
    print("Example:")
    print("=" * 60)
    if not prepared_data:
        print("(no samples)")
        return
    example = prepared_data[0]
    if format_type == "messages":
        for msg in example["messages"]:
            print(f"\n[{msg['role']}]")
            print(_preview(msg["content"], 200))
    else:
        print(_preview(example["text"], 500))


def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"):
    """
    Prepare dataset.

    Reads a JSON list of samples, each with optional "messages" and "tools" keys.
    Converts each sample with ``convert_messages_for_sft``, writes the result to
    ``output_path`` as indented UTF-8 JSON, and prints a preview of the first sample.

    format_type:
    - "messages": output {"messages": [...]}
    - "text": output {"text": "..."} (flattened Gemma turn markup; "tool" turns are omitted)

    Args:
        input_path: Path to the raw JSON dataset.
        output_path: Destination path. Missing parent directories are created.
        format_type: "messages" or "text".

    Returns:
        The list of prepared samples that was written to ``output_path``.

    Raises:
        ValueError: If ``format_type`` is not "messages" or "text".
    """
    # Validate up front; otherwise an unknown format would write an empty file.
    if format_type not in ("messages", "text"):
        raise ValueError(f"Unknown format_type {format_type!r}; expected 'messages' or 'text'")

    print(f"Loading dataset: {input_path}")
    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    print(f"Raw samples: {len(data)}")

    prepared_data = []
    for item in data:
        converted_messages = convert_messages_for_sft(item.get("messages", []), item.get("tools", []))
        if format_type == "messages":
            prepared_data.append({"messages": converted_messages})
        else:
            prepared_data.append({"text": _messages_to_gemma_text(converted_messages)})
    print(f"Processed samples: {len(prepared_data)}")

    # Create the destination directory so writing into a not-yet-existing
    # directory does not fail.
    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(prepared_data, f, ensure_ascii=False, indent=2)
    print(f"Saved to: {output_path}")

    _print_example(prepared_data, format_type)
    return prepared_data
def main():
    """Command-line entry point: parse --input/--output/--format and run prepare_dataset."""
    cli = argparse.ArgumentParser(description="Dataset preparation")
    cli.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
    cli.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
    cli.add_argument(
        "--format",
        type=str,
        choices=["messages", "text"],
        default="messages",
        help="Output format",
    )
    options = cli.parse_args()
    prepare_dataset(
        input_path=options.input,
        output_path=options.output,
        format_type=options.format,
    )


# Run the CLI only when executed as a script (or via ``python -m``), not on import.
if __name__ == "__main__":
    main()