| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| |
|
| | import json |
| | from typing import Any, Literal, NotRequired, TypedDict |
| |
|
| | from ...utils import logging |
| | from ...utils.plugin import BasePlugin |
| | from ...utils.types import DPOSample, Sample, SFTSample, ToolCall |
| |
|
| |
|
| | logger = logging.get_logger(__name__) |
| |
|
| |
|
| | class AlpacaSample(TypedDict, total=False): |
| | system: NotRequired[str] |
| | instruction: str |
| | input: NotRequired[str] |
| | output: str |
| |
|
| |
|
# One turn of a ShareGPT conversation. The functional ``TypedDict`` form is
# used because ``from`` is a reserved word and cannot be a class attribute.
SharegptMessage = TypedDict(
    "SharegptMessage",
    {
        # Speaker tag of the turn.
        "from": Literal["human", "gpt", "system", "function_call", "observation"],
        # Turn payload as a string.
        "value": str,
    },
)
| |
|
| |
|
class SharegptSample(TypedDict, total=False):
    """Raw record in the ShareGPT conversation layout (all keys optional)."""

    # Ordered conversation turns.
    conversations: list[SharegptMessage]
    # Tool definitions as a string; ``sharegpt_converter`` parses it as JSON.
    tools: NotRequired[str]
| |
|
| |
|
class OpenaiMessage(TypedDict, total=False):
    """Single chat message in the OpenAI ``role``/``content`` layout.

    ``"system"`` is included in the accepted roles: ``pair_converter``
    forwards every non-tool role unchanged as a text message, so system
    prompts are valid input and should type-check.
    """

    # Speaker role of the message.
    role: Literal["system", "user", "assistant", "tool"]
    # Message payload as a string.
    content: str
| |
|
| |
|
class OpenaiSample(TypedDict, total=False):
    """Raw record holding one OpenAI-style conversation (key optional)."""

    # Ordered chat messages of the conversation.
    messages: list[OpenaiMessage]
| |
|
| |
|
class PairSample(TypedDict, total=False):
    """Raw preference pair made of two OpenAI-style conversations.

    All keys are optional (``total=False``). ``tools`` is declared because
    ``pair_converter`` reads ``raw_sample.get("tools")`` and parses it as JSON.
    """

    # Messages of the preferred conversation.
    chosen: list[OpenaiMessage]
    # Messages of the dispreferred conversation.
    rejected: list[OpenaiMessage]
    # Optional tool definitions encoded as a JSON string.
    tools: NotRequired[str]
| |
|
| |
|
class DataConverterPlugin(BasePlugin):
    """Plugin namespace under which the dataset converters are registered."""

    def __call__(self, raw_sample: dict[str, Any]) -> Sample:
        """Forward ``raw_sample`` to ``BasePlugin.__call__`` and return its result.

        Args:
            raw_sample (dict[str, Any]): Raw dataset record.

        Returns:
            Sample: Whatever the parent ``__call__`` returns.
        """
        converted = super().__call__(raw_sample)
        return converted
| |
|
| |
|
@DataConverterPlugin("alpaca").register()
def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Convert Alpaca sample to SFT sample.

    See raw example at: https://huggingface.co/datasets/llamafactory/alpaca_gpt4_en

    A field whose value is ``None`` is treated like a missing field, so a
    record with null columns neither raises ``TypeError`` during string
    concatenation nor yields a message whose text value is ``None``.

    Args:
        raw_sample (AlpacaSample): Alpaca sample.

    Returns:
        SFTSample: SFT sample with up to three messages (system, user,
            assistant). Only the assistant message has a non-zero loss weight.
    """
    messages = []

    system = raw_sample.get("system")
    if system is not None:
        messages.append({"role": "system", "content": [{"type": "text", "value": system}], "loss_weight": 0.0})

    if "instruction" in raw_sample or "input" in raw_sample:
        # ``or ""`` maps a present-but-None field to an empty string so the
        # concatenation below cannot fail.
        instruction = raw_sample.get("instruction") or ""
        extra_input = raw_sample.get("input") or ""
        messages.append(
            {
                "role": "user",
                "content": [{"type": "text", "value": instruction + extra_input}],
                "loss_weight": 0.0,
            }
        )

    output = raw_sample.get("output")
    if output is not None:
        messages.append({"role": "assistant", "content": [{"type": "text", "value": output}], "loss_weight": 1.0})

    return {"messages": messages}
| |
|
| |
|
@DataConverterPlugin("sharegpt").register()
def sharegpt_converter(raw_sample: SharegptSample) -> SFTSample:
    """Convert ShareGPT sample to SFT sample.

    See raw example at: https://huggingface.co/datasets/llamafactory/glaive_toolcall_en

    Turns with an unknown ``from`` tag and ``function_call`` turns whose value
    is not valid JSON are skipped with a warning. A ``tools`` string that is
    not valid JSON is likewise reported and left out of the result.

    Args:
        raw_sample (SharegptSample): ShareGPT sample.

    Returns:
        SFTSample: SFT sample with ``messages`` and, when parsable, ``tools``.
    """
    role_by_tag = {
        "system": "system",
        "human": "user",
        "gpt": "assistant",
        "observation": "tool",
        "function_call": "assistant",
    }

    messages = []
    for message in raw_sample.get("conversations", []):
        tag = message["from"]
        if tag not in role_by_tag:
            logger.warning_rank0(f"Unsupported role tag {tag} in message: {message}")
            continue

        if tag != "function_call":
            # Plain text turn; only "gpt" turns carry a non-zero loss weight.
            weight = 1.0 if tag == "gpt" else 0.0
            text_part = {"type": "text", "value": message["value"]}
            messages.append({"role": role_by_tag[tag], "content": [text_part], "loss_weight": weight})
            continue

        # Function-call turn: the value holds one tool call or a list of them as JSON.
        try:
            parsed_calls: ToolCall | list[ToolCall] = json.loads(message["value"])
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tool call format: {str(message['value'])}")
            continue

        call_list = parsed_calls if isinstance(parsed_calls, list) else [parsed_calls]
        call_parts = [{"type": "tool_call", "value": json.dumps(call)} for call in call_list]
        messages.append({"role": "assistant", "content": call_parts, "loss_weight": 1.0})

    sample = {"messages": messages}

    tools = raw_sample.get("tools")
    if tools:
        # Round-trip through json to validate the tool definitions and re-serialize them.
        try:
            parsed_tools: list[dict[str, Any]] = json.loads(tools)
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tools format: {str(tools)}")
        else:
            sample["tools"] = json.dumps(parsed_tools)

    return sample
| |
|
| |
|
@DataConverterPlugin("pair").register()
def pair_converter(raw_sample: PairSample) -> DPOSample:
    """Convert Pair sample to DPO sample.

    See raw example at: https://huggingface.co/datasets/HuggingFaceH4/orca_dpo_pairs

    Args:
        raw_sample (PairSample): pair sample with chosen, rejected fields and
            an optional JSON-encoded ``tools`` string.

    Returns:
        DPOSample: DPO sample with chosen_messages and rejected_messages, plus
            ``tools`` when the raw value is valid JSON.
    """

    def process_message(raw_messages: list[OpenaiMessage]) -> list[dict[str, Any]]:
        """Convert OpenAI-style messages into the internal message layout.

        Tool-role messages whose content is not valid JSON are skipped with a
        warning. Only assistant messages get a loss weight of 1.0.
        """
        messages = []
        for message in raw_messages:
            if message["role"] == "tool":
                # NOTE(review): tool-role content is parsed as JSON and emitted as
                # "tool_call" parts, whereas sharegpt_converter emits "observation"
                # turns as plain text -- confirm this asymmetry is intended.
                try:
                    tool_calls: ToolCall | list[ToolCall] = json.loads(message["content"])
                except json.JSONDecodeError:
                    logger.warning_rank0(f"Invalid tool call format: {str(message['content'])}")
                    continue

                if not isinstance(tool_calls, list):
                    tool_calls = [tool_calls]

                messages.append(
                    {
                        "role": message["role"],
                        "content": [{"type": "tool_call", "value": json.dumps(tool_call)} for tool_call in tool_calls],
                        # The role is "tool" in this branch, so the weight is always 0.0.
                        "loss_weight": 0.0,
                    }
                )
            else:
                messages.append(
                    {
                        "role": message["role"],
                        "content": [{"type": "text", "value": message["content"]}],
                        "loss_weight": 1.0 if message["role"] == "assistant" else 0.0,
                    }
                )

        return messages

    sample = {}
    sample["chosen_messages"] = process_message(raw_sample.get("chosen", []))
    sample["rejected_messages"] = process_message(raw_sample.get("rejected", []))

    tools = raw_sample.get("tools")
    if tools:
        # Round-trip through json to validate the tool definitions and re-serialize them.
        try:
            parsed_tools: list[dict[str, Any]] = json.loads(tools)
        except json.JSONDecodeError:
            logger.warning_rank0(f"Invalid tools format: {str(tools)}")
        else:
            sample["tools"] = json.dumps(parsed_tools)

    return sample
| |
|