|
|
|
|
|
""" |
|
|
Data preprocessing script. |
|
|
|
|
|
Convert the generated dataset into a format directly consumable by SFTTrainer. |
|
|
FunctionGemma expects a specific chat template structure. |
|
|
|
|
|
Usage: |
|
|
python -m src.prepare_dataset --input ./data/training_data.json --output ./data/prepared_dataset.json |
|
|
""" |
|
|
|
|
|
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
|
|
|
|
|
|
|
|
# Repository root: this file lives in <root>/src/ (see module Usage line),
# so two .parent hops from the resolved file path reach the project root.
PROJECT_ROOT = Path(__file__).resolve().parent.parent


# Default input/output locations used by the CLI when --input/--output are omitted.
DEFAULT_INPUT = PROJECT_ROOT / "data" / "training_data.json"


DEFAULT_OUTPUT = PROJECT_ROOT / "data" / "prepared_dataset.json"
|
|
|
|
|
|
|
|
def convert_tool_calls_to_text(tool_calls: List[Dict]) -> str:
    """Render tool calls as newline-separated ``name({json-args})`` lines (FunctionGemma format).

    Returns "" for an empty or missing tool-call list.
    """
    if not tool_calls:
        return ""

    def _render(call: Dict) -> str:
        # Each call carries a {"function": {"name": ..., "arguments": {...}}} payload.
        spec = call.get("function", {})
        serialized_args = json.dumps(spec.get("arguments", {}), ensure_ascii=False)
        return f"{spec.get('name', '')}({serialized_args})"

    return "\n".join(_render(call) for call in tool_calls)
|
|
|
|
|
|
|
|
def convert_messages_for_sft(messages: List[Dict], tools: Optional[List[Dict]] = None) -> List[Dict]:
    """
    Convert message format for SFTTrainer.

    Mapping rules:
      - "developer" -> "system"; when `tools` is given, a plain-text summary of the
        available functions is appended to the system content.
      - "assistant" with "tool_calls" -> content flattened to text via
        convert_tool_calls_to_text; otherwise content passes through.
      - "user" and "tool" messages pass through with their content only
        (note: the originating tool name is NOT preserved for tool results).
      - Messages with any other role are silently dropped.

    Input:
        [
            {"role": "developer", "content": "..."},
            {"role": "user", "content": "..."},
            {"role": "assistant", "tool_calls": [...]} or {"role": "assistant", "content": "..."}
        ]

    Output:
        [
            {"role": "system", "content": "..."},   # developer -> system
            {"role": "user", "content": "..."},
            {"role": "assistant", "content": "..."} # tool_calls flattened to text
        ]

    Args:
        messages: Conversation in the generator's format.
        tools: Optional OpenAI-style tool specs ({"type": "function", "function": {...}}).

    Returns:
        A new list of {"role", "content"} dicts; `messages` is not mutated.
    """
    converted = []

    # Build the tool summary once; it is appended to every developer/system
    # message encountered below.
    tools_description = ""
    if tools:
        tools_desc_parts = []
        for tool in tools:
            if tool.get("type") == "function":
                func = tool.get("function", {})
                name = func.get("name", "")
                desc = func.get("description", "")
                # Parameter schemas are intentionally omitted from the summary;
                # only name and description are surfaced to the model.
                tools_desc_parts.append(f"- {name}: {desc}")
        if tools_desc_parts:
            tools_description = "\n\nAvailable tools:\n" + "\n".join(tools_desc_parts)

    for msg in messages:
        role = msg.get("role", "")

        if role == "developer":
            # developer -> system, with the tool summary appended.
            content = msg.get("content", "")
            if tools_description:
                content = content + tools_description
            converted.append({
                "role": "system",
                "content": content
            })

        elif role == "user":
            converted.append({
                "role": "user",
                "content": msg.get("content", "")
            })

        elif role == "assistant":
            if "tool_calls" in msg:
                # Flatten structured tool calls into plain assistant text.
                converted.append({
                    "role": "assistant",
                    "content": convert_tool_calls_to_text(msg["tool_calls"])
                })
            else:
                converted.append({
                    "role": "assistant",
                    "content": msg.get("content", "")
                })

        elif role == "tool":
            converted.append({
                "role": "tool",
                "content": msg.get("content", "")
            })

        # Unknown roles fall through and are dropped.

    return converted
|
|
|
|
|
|
|
|
def prepare_dataset(input_path: str, output_path: str, format_type: str = "messages"):
    """
    Prepare dataset and write the converted samples to `output_path`.

    Args:
        input_path: JSON file containing a list of
            {"messages": [...], "tools": [...]} items.
        output_path: Destination JSON file (written with indent=2, UTF-8).
        format_type:
            - "messages": output {"messages": [...]}
            - "text": output {"text": "..."} (flattened Gemma-style turn text;
              tool-role messages have no turn marker and are dropped here).

    Returns:
        The list of prepared samples (also written to `output_path`).
    """
    print(f"Loading dataset: {input_path}")

    with open(input_path, 'r', encoding='utf-8') as f:
        data = json.load(f)

    print(f"Raw samples: {len(data)}")

    prepared_data = []

    # Mapping from converted roles to Gemma turn tags used by the "text" format.
    role_to_turn = {"system": "system", "user": "user", "assistant": "model"}

    for item in data:
        messages = item.get("messages", [])
        tools = item.get("tools", [])

        converted_messages = convert_messages_for_sft(messages, tools)

        if format_type == "messages":
            prepared_data.append({
                "messages": converted_messages
            })
        elif format_type == "text":
            # Flatten into chat-template text; roles outside role_to_turn
            # (i.e. "tool") are intentionally skipped.
            text_parts = [
                f"<start_of_turn>{role_to_turn[msg['role']]}\n{msg['content']}<end_of_turn>"
                for msg in converted_messages
                if msg["role"] in role_to_turn
            ]
            prepared_data.append({
                "text": "\n".join(text_parts)
            })

    print(f"Processed samples: {len(prepared_data)}")

    with open(output_path, 'w', encoding='utf-8') as f:
        json.dump(prepared_data, f, ensure_ascii=False, indent=2)

    print(f"Saved to: {output_path}")

    # Show one sample as a sanity check. Guard against an empty dataset:
    # the previous version raised IndexError on prepared_data[0] here.
    if prepared_data:
        print("\n" + "=" * 60)
        print("Example:")
        print("=" * 60)

        if format_type == "messages":
            example = prepared_data[0]
            for msg in example["messages"]:
                print(f"\n[{msg['role']}]")
                content = msg["content"]
                print(content[:200] + "..." if len(content) > 200 else content)
        else:
            print(prepared_data[0]["text"][:500] + "...")

    return prepared_data
|
|
|
|
|
|
|
|
def main():
    """CLI entry point: parse command-line options and run dataset preparation."""
    cli = argparse.ArgumentParser(description="Dataset preparation")
    cli.add_argument("--input", type=str, default=str(DEFAULT_INPUT), help="Input file path")
    cli.add_argument("--output", type=str, default=str(DEFAULT_OUTPUT), help="Output file path")
    cli.add_argument(
        "--format",
        type=str,
        choices=["messages", "text"],
        default="messages",
        help="Output format",
    )
    opts = cli.parse_args()
    prepare_dataset(opts.input, opts.output, opts.format)
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
main() |
|
|
|