Spaces:
Running on Zero
Running on Zero
| import json | |
| import logging | |
| from argparse import ArgumentParser | |
| from typing import Any | |
| from datasets import Dataset, DatasetDict, DownloadMode, load_dataset | |
| from linalg_zero.shared.lib import get_tools | |
| from linalg_zero.shared.system_prompts import get_sft_system_prompt | |
| from linalg_zero.shared.utils import ( | |
| get_logger, | |
| get_representative_examples_indices, | |
| normalize_text, | |
| setup_logging, | |
| ) | |
# Log both to file and console (handled by the project's setup_logging helper),
# then obtain the logger used throughout this module.
setup_logging(include_timestamp=True, level=logging.INFO)
logger = get_logger(__name__)
def load_datasets(src_train: str, src_test: str) -> DatasetDict:
    """Fetch the raw train and validation splits from the Hugging Face Hub.

    Args:
        src_train: Hub repository id whose ``train`` split is loaded. This
            split is always re-downloaded (``DownloadMode.FORCE_REDOWNLOAD``).
        src_test: Hub repository id whose ``validation`` split is loaded with
            the library's default download mode.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"validation"``.

    Raises:
        AssertionError: If either loaded object is not a ``Dataset``.
    """
    # Train split: force a fresh download
    logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}")
    train_split = load_dataset(
        src_train,
        split="train",
        download_mode=DownloadMode.FORCE_REDOWNLOAD,
    )

    # Validation split
    logger.info(f"Loading validation dataset from https://huggingface.co/datasets/{src_test}")
    validation_split = load_dataset(src_test, split="validation")

    # Both loads must have produced plain Dataset objects before packaging them
    splits = {"train": train_split, "validation": validation_split}
    for split in splits.values():
        assert isinstance(split, Dataset)
    return DatasetDict(splits)
def process_dataset(dataset: DatasetDict, normalize_unicode: bool, per_category: int, seed: int) -> DatasetDict:
    """Shape the raw train/validation splits into the layout used for SFT.

    Train split: shuffled, then ``messages`` is decoded from JSON and a leading
    system message gets its content replaced by the SFT system prompt.

    Validation split: shuffled, ``messages`` is rebuilt from the SFT system
    prompt plus the example's ``query``, ``tools`` is filled in where missing,
    and only the rows picked by ``get_representative_examples_indices`` are
    kept.

    Finally both splits are reduced to the SFT columns and the validation
    split is cast to the train split's features.

    Args:
        dataset: Must contain the splits ``"train"`` and ``"validation"``.
        normalize_unicode: Forwarded to ``normalize_text``; when false, the
            per-message normalization pass is skipped.
        per_category: Forwarded to ``get_representative_examples_indices``.
        seed: Seed used to shuffle both splits.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"test"``; the
        processed validation split is returned under ``"test"``.
    """
    # The necessary columns for SFT
    keep_columns = [
        "tools",
        "messages",
        "ground_truth",
        "stepwise_ground_truths",
    ]

    def _normalize_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Run ``normalize_text`` over the ``content`` of every message dict."""
        if not normalize_unicode:
            return example
        msgs = example.get("messages", [])
        for m in msgs:
            if isinstance(m, dict) and "content" in m:
                m["content"] = normalize_text(m["content"], normalize_unicode)
        example["messages"] = msgs
        return example

    # Add missing columns (messages & tools)
    def ensure_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Build a two-message conversation: SFT system prompt + user query."""
        example["messages"] = [
            {"role": "system", "content": normalize_text(get_sft_system_prompt(), normalize_unicode)},
            {"role": "user", "content": normalize_text(example["query"], normalize_unicode)},
        ]
        return _normalize_messages(example)

    def ensure_tools(example: dict[str, Any]) -> dict[str, Any]:
        """Fill ``tools`` from ``get_tools()`` when the column is absent or null."""
        if "tools" not in example or example["tools"] is None:
            example["tools"] = get_tools()
        return example

    def parse_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Convert messages from JSON string to array and replace system prompt"""
        raw_messages = example["messages"]
        # Only decode when the value is still serialized; rows that already
        # hold a decoded list are passed through instead of raising TypeError.
        if isinstance(raw_messages, (str, bytes, bytearray)):
            example["messages"] = json.loads(raw_messages)

        # Replace the system prompt with the SFT system prompt
        if example["messages"] and example["messages"][0]["role"] == "system":
            example["messages"][0]["content"] = normalize_text(get_sft_system_prompt(), normalize_unicode)

        return _normalize_messages(example)

    def _extra_columns(split: Dataset) -> list[str]:
        """List the split's columns that are not needed for SFT, in column order."""
        return [name for name in split.column_names if name not in keep_columns]

    train_dataset = dataset["train"]
    train_dataset = train_dataset.shuffle(seed=seed)
    train_dataset = train_dataset.map(parse_messages)

    test_dataset = dataset["validation"]
    test_dataset = test_dataset.shuffle(seed=seed)
    test_dataset = test_dataset.map(ensure_messages)
    test_dataset = test_dataset.map(ensure_tools)

    indices = get_representative_examples_indices(test_dataset, per_category=per_category, include_remaining=False)
    test_dataset = test_dataset.select(indices)

    # Ensure only relevant columns are preserved. remove_columns is documented
    # to take a str or a list of str, so pass an ordered list rather than a set.
    train_dataset = train_dataset.remove_columns(_extra_columns(train_dataset))
    test_dataset = test_dataset.remove_columns(_extra_columns(test_dataset))

    # Ensure the two schemas align (in tools field)
    test_dataset = test_dataset.cast(train_dataset.features)

    # Prepare results
    assert isinstance(train_dataset, Dataset)
    assert isinstance(test_dataset, Dataset)
    return DatasetDict({"train": train_dataset, "test": test_dataset})
def prepare_debug(train: Dataset, validation: Dataset, dataset_size: int) -> DatasetDict:
    """Return the leading rows of each split as a small debugging dataset.

    Args:
        train: Train split to truncate.
        validation: Validation split to truncate.
        dataset_size: Maximum number of rows to keep from the start of each
            split. A split shorter than this is kept whole.

    Returns:
        A ``DatasetDict`` with the keys ``"train"`` and ``"validation"``.
    """
    # Clamp to each split's length: selecting indices past the end would raise
    # IndexError, which made the helper unusable on splits smaller than
    # ``dataset_size``.
    train_rows = min(dataset_size, len(train))
    validation_rows = min(dataset_size, len(validation))

    train = train.select(range(train_rows))
    validation = validation.select(range(validation_rows))
    return DatasetDict({"train": train, "validation": validation})
def main(
    output_repo: str,
    push_to_hub: bool,
    normalize_unicode: bool,
    per_category: int,
    seed: int,
    train_repo: str = "atomwalk12/linalgzero-distilled-clean",
    test_repo: str = "atomwalk12/linalgzero",
) -> None:
    """Load the source datasets, process them for SFT and optionally publish.

    Args:
        output_repo: Hub repository id the processed dataset is pushed to.
        push_to_hub: When false, the dataset is processed but not uploaded.
        normalize_unicode: Forwarded to ``process_dataset``.
        per_category: Forwarded to ``process_dataset``.
        seed: Forwarded to ``process_dataset``.
        train_repo: Hub repository id providing the train split.
        test_repo: Hub repository id providing the validation split.

    Raises:
        Exception: Whatever ``push_to_hub`` raised. The failure is logged and
            then re-raised so that the script does not exit successfully
            after a failed upload.
    """
    # Load
    logger.info("*** Loading datasets ***")
    dataset = load_datasets(train_repo, test_repo)

    # Process
    logger.info("*** Processing dataset ***")
    dataset = process_dataset(dataset, normalize_unicode=normalize_unicode, per_category=per_category, seed=seed)

    # Push to hub
    if not push_to_hub:
        return

    logger.info("*** Pushing to Hub ***")
    try:
        dataset.push_to_hub(output_repo)
    except Exception:
        # Top-level boundary: record the traceback, then propagate so callers
        # (and the process exit code) see the failure.
        logger.exception("Failed to push to hub")
        raise
    else:
        logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}")
if __name__ == "__main__":
    # Script entry point: read the command-line flags and hand them to main().
    cli = ArgumentParser()
    cli.add_argument(
        "--output_repo",
        default="atomwalk12/linalgzero-sft",
        help="Output repository name",
    )
    cli.add_argument(
        "--push_to_hub",
        action="store_true",
        help="Whether to push the dataset to HuggingFace",
    )
    cli.add_argument(
        "--no_normalize_unicode",
        action="store_true",
        help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep",
    )
    cli.add_argument(
        "--per_category",
        type=int,
        default=40,
        help="Number of representative examples per category",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for dataset shuffling",
    )
    options = cli.parse_args()

    # The flag is phrased negatively on the command line; main() takes the
    # positive form.
    normalize_unicode = not options.no_normalize_unicode

    main(
        output_repo=options.output_repo,
        push_to_hub=options.push_to_hub,
        normalize_unicode=normalize_unicode,
        per_category=options.per_category,
        seed=options.seed,
    )