|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, TypedDict |
|
|
|
|
|
from typing_extensions import NotRequired |
|
|
|
|
|
from ...extras.types import Sample, SFTSample |
|
|
|
|
|
|
|
|
class AlpacaSample(TypedDict, total=False):
    """A raw sample in Alpaca format.

    Every key is optional. ``total=False`` already makes all keys optional;
    each field is additionally wrapped in ``NotRequired`` to state that
    explicitly at the field level.
    """

    # Optional system prompt; turned into a system message by the converter.
    system: NotRequired[str]
    # Task instruction; contributes to the user message.
    instruction: NotRequired[str]
    # Extra input accompanying the instruction; also contributes to the user message.
    input: NotRequired[str]
    # Target response; turned into the assistant message.
    output: NotRequired[str]
|
|
|
|
|
|
|
|
def _text_message(role: str, text: str, loss_weight: float) -> dict:
    """Build one single-part text message in the SFT message schema."""
    return {"role": role, "content": [{"type": "text", "value": text}], "loss_weight": loss_weight}


def alpaca_converter(raw_sample: AlpacaSample) -> SFTSample:
    """Convert an Alpaca sample to an SFT sample.

    - ``system`` becomes a system message.
    - ``instruction`` and ``input`` become a single user message. Their
      non-empty parts are joined with a newline.
    - ``output`` becomes the assistant message.

    Only the assistant message carries loss (``loss_weight`` 1.0). The system
    and user messages use 0.0. A field that is missing or set to ``None`` is
    treated as absent, and no message is emitted for it.

    Args:
        raw_sample (AlpacaSample): Alpaca sample.

    Returns:
        SFTSample: SFT sample, a dict with a single ``"messages"`` list.
    """
    messages = []

    system = raw_sample.get("system")
    if system is not None:
        messages.append(_text_message("system", system, 0.0))

    instruction = raw_sample.get("instruction")
    query = raw_sample.get("input")
    if instruction is not None or query is not None:
        # Separate instruction and input with a newline so they are not fused
        # into one run of text. Empty parts are skipped so an empty input adds
        # no trailing newline.
        prompt = "\n".join(part for part in (instruction, query) if part)
        messages.append(_text_message("user", prompt, 0.0))

    output = raw_sample.get("output")
    if output is not None:
        messages.append(_text_message("assistant", output, 1.0))

    return {"messages": messages}
|
|
|
|
|
|
|
|
# Registry of raw-format converters, keyed by the name passed to
# `get_converter`. Each value converts one raw sample dict into a sample in
# the internal schema.
CONVERTERS = dict(
    alpaca=alpaca_converter,
)
|
|
|
|
|
|
|
|
def get_converter(converter_name: str) -> Callable[[dict], Sample]:
    """Return the converter registered under ``converter_name``.

    Args:
        converter_name (str): A key of ``CONVERTERS``, e.g. ``"alpaca"``.

    Returns:
        Callable[[dict], Sample]: The registered converter function.

    Raises:
        ValueError: If no converter is registered under ``converter_name``.
    """
    try:
        return CONVERTERS[converter_name]
    except KeyError:
        # Report an unknown name as ValueError; hide the internal KeyError.
        raise ValueError(f"Converter {converter_name} not found.") from None
|
|
|