File size: 6,269 Bytes
0dd6c2f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
import json
import logging
from argparse import ArgumentParser
from typing import Any

from datasets import Dataset, DatasetDict, DownloadMode, load_dataset

from linalg_zero.shared.lib import get_tools
from linalg_zero.shared.system_prompts import get_sft_system_prompt
from linalg_zero.shared.utils import (
    get_logger,
    get_representative_examples_indices,
    normalize_text,
    setup_logging,
)

# Configure logging once at import time; per the project's setup this logs
# both to file and console.
setup_logging(level=logging.INFO, include_timestamp=True)
logger = get_logger(__name__)  # module-level logger shared by all functions below


def load_datasets(src_train: str, src_test: str) -> DatasetDict:
    """Fetch the train and validation splits from the Hub and bundle them.

    Args:
        src_train: Hub repository id providing the "train" split (always re-downloaded).
        src_test: Hub repository id providing the "validation" split.

    Returns:
        DatasetDict with "train" and "validation" entries.
    """
    logger.info(f"Loading train dataset from https://huggingface.co/datasets/{src_train}")
    train_split = load_dataset(src_train, split="train", download_mode=DownloadMode.FORCE_REDOWNLOAD)

    logger.info(f"Loading validation dataset from https://huggingface.co/datasets/{src_test}")
    validation_split = load_dataset(src_test, split="validation")

    # load_dataset can return other container types; the pipeline downstream
    # expects plain Dataset objects for both splits.
    assert isinstance(train_split, Dataset)
    assert isinstance(validation_split, Dataset)

    return DatasetDict({"train": train_split, "validation": validation_split})


def process_dataset(dataset: DatasetDict, normalize_unicode: bool, per_category: int, seed: int) -> DatasetDict:
    """Process an already-loaded dataset into the SFT training format.

    Args:
        dataset: DatasetDict with "train" and "validation" splits (see load_datasets).
        normalize_unicode: Whether to Unicode-normalize message content.
        per_category: Number of representative validation examples kept per category.
        seed: Random seed used to shuffle both splits.

    Returns:
        DatasetDict with "train" and "test" splits restricted to the SFT
        columns and sharing a single schema.
    """

    # The necessary columns for SFT
    keep_columns = [
        "tools",
        "messages",
        "ground_truth",
        "stepwise_ground_truths",
    ]

    def _normalize_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Normalize every message's content in place (no-op when disabled)."""
        if not normalize_unicode:
            return example
        msgs = example.get("messages", [])
        for m in msgs:
            if isinstance(m, dict) and "content" in m:
                m["content"] = normalize_text(m["content"], normalize_unicode)
        example["messages"] = msgs
        return example

    def ensure_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Build the system/user message pair for examples that only carry a query.

        Content is normalized exactly once here; the previous extra
        _normalize_messages pass re-normalized already-normalized text.
        (Assumes normalize_text is idempotent, which made the double pass redundant.)
        """
        example["messages"] = [
            {"role": "system", "content": normalize_text(get_sft_system_prompt(), normalize_unicode)},
            {"role": "user", "content": normalize_text(example["query"], normalize_unicode)},
        ]
        return example

    def ensure_tools(example: dict[str, Any]) -> dict[str, Any]:
        """Fill in the default tool list when "tools" is absent or None."""
        if example.get("tools") is None:
            example["tools"] = get_tools()
        return example

    def parse_messages(example: dict[str, Any]) -> dict[str, Any]:
        """Convert messages from JSON string to array and replace system prompt"""
        example["messages"] = json.loads(example["messages"])

        # Replace the system prompt with the SFT system prompt
        if example["messages"] and example["messages"][0]["role"] == "system":
            example["messages"][0]["content"] = normalize_text(get_sft_system_prompt(), normalize_unicode)

        return _normalize_messages(example)

    def _keep_only_sft_columns(ds: Dataset) -> Dataset:
        """Drop every column that is not required for SFT."""
        return ds.remove_columns(set(ds.column_names) - set(keep_columns))

    train_dataset = dataset["train"]
    train_dataset = train_dataset.shuffle(seed=seed)
    train_dataset = train_dataset.map(parse_messages)

    test_dataset = dataset["validation"]
    test_dataset = test_dataset.shuffle(seed=seed)
    test_dataset = test_dataset.map(ensure_messages)
    test_dataset = test_dataset.map(ensure_tools)
    indices = get_representative_examples_indices(test_dataset, per_category=per_category, include_remaining=False)
    test_dataset = test_dataset.select(indices)

    # Ensure only relevant columns are preserved
    train_dataset = _keep_only_sft_columns(train_dataset)
    test_dataset = _keep_only_sft_columns(test_dataset)

    # Ensure the two schemas align (in tools field)
    test_dataset = test_dataset.cast(train_dataset.features)

    # Prepare results
    assert isinstance(train_dataset, Dataset)
    assert isinstance(test_dataset, Dataset)

    return DatasetDict({"train": train_dataset, "test": test_dataset})


def prepare_debug(train: Dataset, validation: Dataset, dataset_size: int) -> DatasetDict:
    """Truncate both splits to their first `dataset_size` rows for quick debug runs."""
    subset = {
        "train": train.select(range(dataset_size)),
        "validation": validation.select(range(dataset_size)),
    }
    return DatasetDict(subset)


def main(output_repo: str, push_to_hub: bool, normalize_unicode: bool, per_category: int, seed: int) -> None:
    """Load the source datasets, process them for SFT, and optionally publish.

    Args:
        output_repo: Hub repository id to push the processed dataset to.
        push_to_hub: When True, upload the result to the Hub.
        normalize_unicode: Forwarded to process_dataset.
        per_category: Forwarded to process_dataset.
        seed: Forwarded to process_dataset.
    """
    # Source repositories are fixed for this pipeline.
    train_repo = "atomwalk12/linalgzero-distilled-clean"
    test_repo = "atomwalk12/linalgzero"

    logger.info("*** Loading datasets ***")
    raw = load_datasets(train_repo, test_repo)

    logger.info("*** Processing dataset ***")
    processed = process_dataset(raw, normalize_unicode=normalize_unicode, per_category=per_category, seed=seed)

    if not push_to_hub:
        return

    logger.info("*** Pushing to Hub ***")
    try:
        processed.push_to_hub(output_repo)
        logger.info(f"Successfully pushed dataset to https://huggingface.co/datasets/{output_repo}")
    except Exception:
        # Best-effort publish: log the failure instead of crashing the run.
        logger.exception("Failed to push to hub")


if __name__ == "__main__":
    # Script entry point for SFT dataset preparation. (Previously a bare
    # string literal, which is a no-op statement here, not a docstring.)
    parser = ArgumentParser()
    parser.add_argument("--output_repo", default="atomwalk12/linalgzero-sft", type=str, help="Output repository name")
    # store_true flags already default to False, so the explicit default=False
    # on the two boolean switches was redundant and has been dropped.
    parser.add_argument(
        "--push_to_hub", action="store_true", help="Whether to push the dataset to HuggingFace"
    )
    parser.add_argument(
        "--no_normalize_unicode",
        action="store_true",
        help="Disable Unicode NFKC normalization and minus-sign replacement during dataset prep",
    )
    parser.add_argument("--per_category", default=40, type=int, help="Number of representative examples per category")
    parser.add_argument("--seed", default=42, type=int, help="Random seed for dataset shuffling")
    args = parser.parse_args()

    main(
        output_repo=args.output_repo,
        push_to_hub=args.push_to_hub,
        # The CLI exposes the negative flag; invert it for the positive parameter.
        normalize_unicode=(not args.no_normalize_unicode),
        per_category=args.per_category,
        seed=args.seed,
    )