Spaces:
Sleeping
Sleeping
File size: 6,269 Bytes
0dd6c2f | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 | import json
import logging
from argparse import ArgumentParser
from typing import Any
from datasets import Dataset, DatasetDict, DownloadMode, load_dataset
from linalg_zero.shared.lib import get_tools
from linalg_zero.shared.system_prompts import get_sft_system_prompt
from linalg_zero.shared.utils import (
get_logger,
get_representative_examples_indices,
normalize_text,
setup_logging,
)
# Configure logging once, at import time: INFO level, timestamps requested.
# Log both to file and console (original author's note; the actual handlers
# are installed inside `setup_logging` — TODO confirm against that helper).
setup_logging(level=logging.INFO, include_timestamp=True)
# Module-level logger used by every function in this script.
logger = get_logger(__name__)
def load_datasets(src_train: str, src_test: str) -> DatasetDict:
    """Fetch the raw train and validation splits from the Hugging Face Hub.

    Args:
        src_train: Hub repository id whose ``train`` split is loaded. The
            download is forced (``FORCE_REDOWNLOAD``), presumably so that a
            stale local cache is never reused.
        src_test: Hub repository id whose ``validation`` split is loaded
            with the default download mode.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"validation"``.
    """
    logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}")
    train_split = load_dataset(src_train, split="train", download_mode=DownloadMode.FORCE_REDOWNLOAD)

    logger.info(f"Loading validation dataset from https://huggingface.co/datasets/{src_test}")
    validation_split = load_dataset(src_test, split="validation")

    # `load_dataset` has a broad return type; narrow both results to `Dataset`
    # before packing them together.
    assert isinstance(train_split, Dataset)
    assert isinstance(validation_split, Dataset)

    return DatasetDict({"train": train_split, "validation": validation_split})
def process_dataset(dataset: DatasetDict, normalize_unicode: bool, per_category: int, seed: int) -> DatasetDict:
    """Turn the already-loaded splits into the train/test pair used for SFT.

    Train split: shuffled, its ``messages`` column is decoded from a JSON
    string and a leading system message is overwritten with the SFT system
    prompt.

    Validation split: shuffled, a fresh ``messages`` list (system + user) is
    built from the ``query`` column, a ``tools`` column is filled in when
    absent, and only the rows chosen by
    ``get_representative_examples_indices`` are kept.

    Both splits are then stripped down to the SFT columns and the validation
    split is cast to the train split's features so the schemas match.

    Args:
        dataset: Mapping with a ``"train"`` and a ``"validation"`` split.
        normalize_unicode: Forwarded to ``normalize_text``; when false the
            extra per-message normalization pass is skipped entirely.
        per_category: Forwarded to ``get_representative_examples_indices``.
        seed: Seed used to shuffle both splits.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"test"`` (the
        processed validation split is published under ``"test"``).
    """
    # The necessary columns for SFT
    keep_columns = [
        "tools",
        "messages",
        "ground_truth",
        "stepwise_ground_truths",
    ]

    # The system prompt does not depend on the example, so build it once here
    # rather than once per row inside the `map` callbacks below.
    system_prompt = normalize_text(get_sft_system_prompt(), normalize_unicode)

    def _normalize_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Run ``normalize_text`` over every message's ``content`` in place."""
        if not normalize_unicode:
            return example
        msgs = example.get("messages", [])
        for m in msgs:
            if isinstance(m, dict) and "content" in m:
                m["content"] = normalize_text(m["content"], normalize_unicode)
        example["messages"] = msgs
        return example

    # Add missing columns (messages & tools)
    def ensure_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Build the system + user conversation from the ``query`` column."""
        example["messages"] = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": normalize_text(example["query"], normalize_unicode)},
        ]
        return _normalize_messages(example)

    def ensure_tools(example: dict[str, Any]) -> dict[str, Any]:
        """Fill in ``tools`` from ``get_tools()`` when missing or ``None``."""
        if example.get("tools") is None:
            example["tools"] = get_tools()
        return example

    def parse_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Convert messages from JSON string to array and replace system prompt"""
        example["messages"] = json.loads(example["messages"])

        # Replace the system prompt with the SFT system prompt
        if example["messages"] and example["messages"][0]["role"] == "system":
            example["messages"][0]["content"] = system_prompt
        return _normalize_messages(example)

    def _keep_sft_columns(split: Dataset) -> Dataset:
        """Drop every column that is not listed in ``keep_columns``."""
        # `remove_columns` is documented to take a str or a list of str, so
        # pass an order-preserving list instead of a set.
        extra_columns = [name for name in split.column_names if name not in keep_columns]
        return split.remove_columns(extra_columns)

    train_dataset = dataset["train"]
    train_dataset = train_dataset.shuffle(seed=seed)
    train_dataset = train_dataset.map(parse_messages)

    test_dataset = dataset["validation"]
    test_dataset = test_dataset.shuffle(seed=seed)
    test_dataset = test_dataset.map(ensure_messages)
    test_dataset = test_dataset.map(ensure_tools)

    indices = get_representative_examples_indices(test_dataset, per_category=per_category, include_remaining=False)
    test_dataset = test_dataset.select(indices)

    # Ensure only relevant columns are preserved
    train_dataset = _keep_sft_columns(train_dataset)
    test_dataset = _keep_sft_columns(test_dataset)

    # Ensure the two schemas align (in tools field)
    test_dataset = test_dataset.cast(train_dataset.features)

    # Prepare results
    assert isinstance(train_dataset, Dataset)
    assert isinstance(test_dataset, Dataset)
    return DatasetDict({"train": train_dataset, "test": test_dataset})
def prepare_debug(train: Dataset, validation: Dataset, dataset_size: int) -> DatasetDict:
    """Cut both splits down to a small prefix for quick debugging runs.

    Args:
        train: Training split to truncate.
        validation: Validation split to truncate.
        dataset_size: Maximum number of leading rows to keep from each split.
            A split with fewer rows is returned in full.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"validation"``.
    """
    # Clamp to each split's length: `select` rejects out-of-range indices, so
    # an oversized `dataset_size` would otherwise abort the debug run.
    train = train.select(range(min(dataset_size, len(train))))
    validation = validation.select(range(min(dataset_size, len(validation))))
    return DatasetDict({"train": train, "validation": validation})
def main(output_repo: str, push_to_hub: bool, normalize_unicode: bool, per_category: int, seed: int) -> None:
    """Load the source datasets, process them for SFT and optionally publish.

    Args:
        output_repo: Hub repository id the processed dataset is pushed to.
        push_to_hub: When false, processing still runs but nothing is uploaded.
        normalize_unicode: Forwarded to ``process_dataset``.
        per_category: Forwarded to ``process_dataset``.
        seed: Forwarded to ``process_dataset``.
    """
    logger.info("*** Loading datasets ***")
    raw_splits = load_datasets(
        src_train="atomwalk12/linalgzero-distilled-clean",
        src_test="atomwalk12/linalgzero",
    )

    logger.info("*** Processing dataset ***")
    processed = process_dataset(
        raw_splits,
        normalize_unicode=normalize_unicode,
        per_category=per_category,
        seed=seed,
    )

    if not push_to_hub:
        return

    logger.info("*** Pushing to Hub ***")
    try:
        processed.push_to_hub(output_repo)
        logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}")
    except Exception:
        # An upload failure is logged with its traceback, not propagated.
        logger.exception("Failed to push to hub")
if __name__ == "__main__":
    # Script entry point: parse the command-line options and hand them to
    # `main`. The description is shown by `--help`.
    parser = ArgumentParser(
        description="Prepare the SFT dataset and optionally push it to the Hugging Face Hub."
    )
    parser.add_argument("--output_repo", default="atomwalk12/linalgzero-sft", type=str, help="Output repository name")
    parser.add_argument(
        "--push_to_hub", default=False, action="store_true", help="Whether to push the dataset to HuggingFace"
    )
    parser.add_argument(
        "--no_normalize_unicode",
        default=False,
        action="store_true",
        help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep",
    )
    parser.add_argument("--per_category", default=40, type=int, help="Number of representative examples per category")
    parser.add_argument("--seed", default=42, type=int, help="Random seed for dataset shuffling")
    args = parser.parse_args()

    main(
        output_repo=args.output_repo,
        push_to_hub=args.push_to_hub,
        # The CLI flag is negative ("no_..."), so invert it for `main`.
        normalize_unicode=(not args.no_normalize_unicode),
        per_category=args.per_category,
        seed=args.seed,
    )
|