# linalg-zero/linalg_zero/sft/scripts/prepare_dataset.py
# Author: atomwalk12 — initial commit (0dd6c2f)
import json
import logging
from argparse import ArgumentParser
from typing import Any
from datasets import Dataset, DatasetDict, DownloadMode, load_dataset
from linalg_zero.shared.lib import get_tools
from linalg_zero.shared.system_prompts import get_sft_system_prompt
from linalg_zero.shared.utils import (
get_logger,
get_representative_examples_indices,
normalize_text,
setup_logging,
)
# Log both to file and console
setup_logging(level=logging.INFO, include_timestamp=True)
# Module-level logger shared by every function in this script.
logger = get_logger(__name__)
def load_datasets(src_train: str, src_test: str) -> DatasetDict:
    """Fetch the train and validation splits from the HuggingFace Hub.

    Args:
        src_train: Hub repository id holding the training split.
        src_test: Hub repository id holding the validation split.

    Returns:
        A DatasetDict with "train" and "validation" entries.
    """
    logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}")
    # Force a fresh download so a stale cache entry never leaks into training.
    train_split = load_dataset(src_train, split="train", download_mode=DownloadMode.FORCE_REDOWNLOAD)

    logger.info(f"Loading validation dataset from https://huggingface.co/datasets/{src_test}")
    validation_split = load_dataset(src_test, split="validation")

    # load_dataset can return other container types; the pipeline needs plain Datasets.
    assert isinstance(train_split, Dataset)
    assert isinstance(validation_split, Dataset)
    return DatasetDict({"train": train_split, "validation": validation_split})
def process_dataset(dataset: DatasetDict, normalize_unicode: bool, per_category: int, seed: int) -> DatasetDict:
    """Load and process dataset for SFT training."""
    # Columns required by the SFT trainer; all others are stripped below.
    required_columns = {
        "tools",
        "messages",
        "ground_truth",
        "stepwise_ground_truths",
    }

    def _apply_unicode_normalization(example: dict[str, Any]) -> dict[str, Any]:
        # No-op unless normalization was requested on the command line.
        if not normalize_unicode:
            return example
        messages = example.get("messages", [])
        for message in messages:
            if isinstance(message, dict) and "content" in message:
                message["content"] = normalize_text(message["content"], normalize_unicode)
        example["messages"] = messages
        return example

    def _build_messages(example: dict[str, Any]) -> dict[str, Any]:
        # Validation rows carry only a raw query; synthesize the chat turns.
        example["messages"] = [
            {"role": "system", "content": normalize_text(get_sft_system_prompt(), normalize_unicode)},
            {"role": "user", "content": normalize_text(example["query"], normalize_unicode)},
        ]
        return _apply_unicode_normalization(example)

    def _build_tools(example: dict[str, Any]) -> dict[str, Any]:
        # Fill in the tool schema when the row lacks one.
        if "tools" not in example or example["tools"] is None:
            example["tools"] = get_tools()
        return example

    def _decode_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Convert messages from JSON string to array and replace system prompt"""
        example["messages"] = json.loads(example["messages"])
        # Replace the system prompt with the SFT system prompt
        if example["messages"] and example["messages"][0]["role"] == "system":
            example["messages"][0]["content"] = normalize_text(get_sft_system_prompt(), normalize_unicode)
        return _apply_unicode_normalization(example)

    # Train rows already hold serialized conversations; decode them.
    train_split = dataset["train"].shuffle(seed=seed).map(_decode_messages)

    # Validation rows need messages/tools built, then a per-category subsample.
    validation_split = dataset["validation"].shuffle(seed=seed).map(_build_messages).map(_build_tools)
    chosen = get_representative_examples_indices(validation_split, per_category=per_category, include_remaining=False)
    validation_split = validation_split.select(chosen)

    # Ensure only relevant columns are preserved
    train_split = train_split.remove_columns(set(train_split.column_names) - required_columns)
    validation_split = validation_split.remove_columns(set(validation_split.column_names) - required_columns)

    # Ensure the two schemas align (in tools field)
    validation_split = validation_split.cast(train_split.features)

    assert isinstance(train_split, Dataset)
    assert isinstance(validation_split, Dataset)
    return DatasetDict({"train": train_split, "test": validation_split})
def prepare_debug(train: Dataset, validation: Dataset, dataset_size: int) -> DatasetDict:
    """Build a small debug DatasetDict by truncating both splits.

    Args:
        train: Full training split.
        validation: Full validation split.
        dataset_size: Maximum number of examples to keep per split.

    Returns:
        A DatasetDict with truncated "train" and "validation" splits.
    """
    # Clamp to the split length: select(range(n)) raises when n exceeds
    # the number of rows, which made this crash on small datasets.
    train = train.select(range(min(dataset_size, len(train))))
    validation = validation.select(range(min(dataset_size, len(validation))))
    return DatasetDict({"train": train, "validation": validation})
def main(
    output_repo: str,
    push_to_hub: bool,
    normalize_unicode: bool,
    per_category: int,
    seed: int,
    train_repo: str = "atomwalk12/linalgzero-distilled-clean",
    test_repo: str = "atomwalk12/linalgzero",
) -> None:
    """Main processing function.

    Args:
        output_repo: Hub repository the processed dataset is pushed to.
        push_to_hub: When True, upload the result to the Hub.
        normalize_unicode: When True, NFKC-normalize message content.
        per_category: Representative validation examples kept per category.
        seed: Shuffle seed for both splits.
        train_repo: Source repository for the training split (generalized
            from the previous hard-coded value; default keeps old behavior).
        test_repo: Source repository for the validation split.
    """
    # Load
    logger.info("*** Loading datasets ***")
    dataset = load_datasets(train_repo, test_repo)
    # Process
    logger.info("*** Processing dataset ***")
    dataset = process_dataset(dataset, normalize_unicode=normalize_unicode, per_category=per_category, seed=seed)
    # Push to hub
    if push_to_hub:
        logger.info("*** Pushing to Hub ***")
        try:
            dataset.push_to_hub(output_repo)
            logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}")
        except Exception:
            # Best-effort upload: log the full traceback but don't crash,
            # so the processed dataset is not lost on transient Hub errors.
            logger.exception("Failed to push to hub")
if __name__ == "__main__":
    # Script entry point for SFT training: parse CLI flags and run main().
    arg_parser = ArgumentParser()
    arg_parser.add_argument("--output_repo", default="atomwalk12/linalgzero-sft", type=str, help="Output repository name")
    arg_parser.add_argument(
        "--push_to_hub", default=False, action="store_true", help="Whether to push the dataset to HuggingFace"
    )
    arg_parser.add_argument(
        "--no_normalize_unicode",
        default=False,
        action="store_true",
        help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep",
    )
    arg_parser.add_argument("--per_category", default=40, type=int, help="Number of representative examples per category")
    arg_parser.add_argument("--seed", default=42, type=int, help="Random seed for dataset shuffling")
    cli_args = arg_parser.parse_args()

    # The CLI exposes a negative flag; invert it into the positive option main() expects.
    normalize = not cli_args.no_normalize_unicode
    main(
        output_repo=cli_args.output_repo,
        push_to_hub=cli_args.push_to_hub,
        normalize_unicode=normalize,
        per_category=cli_args.per_category,
        seed=cli_args.seed,
    )