File size: 4,966 Bytes
6379283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
#!/usr/bin/env python3
"""
Combine all training data sources into final dataset.
Applies deduplication and quality filtering.
"""

import argparse
import hashlib
import json
import random
from collections import Counter
from datetime import datetime
from pathlib import Path
from typing import Iterator, Optional

def hash_messages(messages: list) -> str:
    """Return the MD5 hex digest of the canonical JSON form of *messages*.

    Keys are sorted during serialisation, so dicts that differ only in key
    order produce the same digest. The digest is used as a duplicate-detection
    fingerprint.
    """
    canonical = json.dumps(messages, sort_keys=True)
    return hashlib.md5(canonical.encode()).hexdigest()

def _iter_records(path: Path) -> Iterator[object]:
    """Yield every decoded JSON record stored in *path*.

    ``.json`` files are first tried as one JSON document: a top-level array
    yields its elements, any other value is yielded as-is. Files that do not
    parse as a single document (and every other suffix) are read as JSON
    Lines. Blank lines are ignored; a line that is not valid JSON is yielded
    as ``None`` so the caller can count it as skipped.
    """
    if path.suffix == ".json":
        try:
            document = json.loads(path.read_text(encoding="utf-8"))
        except json.JSONDecodeError:
            # Not a single document (e.g. JSON Lines saved with a .json
            # suffix); fall through to line-by-line parsing below.
            pass
        else:
            if isinstance(document, list):
                yield from document
            else:
                yield document
            return

    with open(path, "r", encoding="utf-8") as f:
        for line in f:
            if not line.strip():
                continue
            try:
                yield json.loads(line)
            except json.JSONDecodeError:
                yield None


def _normalize_example(record: object, source: str) -> Optional[dict]:
    """Return *record* in message format, or ``None`` if it is unusable.

    A dict holding both ``code`` and ``comment`` is converted into a
    two-message user/assistant example. Anything that is not a dict, or that
    still has no ``messages`` key after conversion, is rejected.
    """
    if not isinstance(record, dict):
        return None

    if "code" in record and "comment" in record:
        # Convert code-pair to message format
        record = {
            "messages": [
                {"role": "user", "content": record["comment"]},
                {"role": "assistant", "content": f"Here's the code:\n{record['code']}"},
            ],
            "source": source,
            "type": "code_pair",
        }

    if "messages" not in record:
        return None
    return record


def _write_jsonl(path: Path, examples: list) -> None:
    """Write *examples* to *path*, one JSON object per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(json.dumps(ex) + "\n" for ex in examples)


def main():
    """Combine the configured sources, deduplicate, split and write the dataset.

    Reads each existing source file, converts code-pair records to message
    format, drops examples whose messages hash to an already-seen value, and
    stops once ``--max-dataset`` examples have been collected. The result is
    shuffled with a fixed seed, split into train/val/test by the
    ``--train-size`` / ``--val-size`` fractions (test receives the remainder),
    and written as ``train.jsonl``, ``val.jsonl``, ``test.jsonl`` plus
    ``manifest.json`` into the parent directory of ``--output``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--output",
        type=str,
        default="training-data/final/dataset.jsonl",
        help="Output location; split files and manifest are written to its parent directory",
    )
    parser.add_argument("--train-size", type=float, default=0.8)
    parser.add_argument("--val-size", type=float, default=0.1)
    parser.add_argument("--max-dataset", type=int, default=50000, help="Max examples to include")
    args = parser.parse_args()

    output_path = Path(args.output)
    output_dir = output_path.parent
    output_dir.mkdir(parents=True, exist_ok=True)

    # List all source files
    sources = [
        ("training-data/synthetic/examples.jsonl", "original_synthetic"),
        ("training-data/advanced-patterns/examples.jsonl", "advanced_patterns"),
        ("training-data/code-pairs/pairs.json", "code_pairs"),
        ("training-data/code-pairs/extended_pairs.json", "code_pairs_extended"),
        ("training-data/scaled/synthetic_final.jsonl", "synthetic_augmented"),
        ("training-data/scaled/random_10k.jsonl", "random_10k"),
        ("training-data/scaled/random_5_5k.jsonl", "random_5k"),
    ]

    all_examples = []
    seen_hashes = set()
    duplicates_removed = 0
    skipped = 0

    print("📦 Combining datasets...")
    for file_path, source in sources:
        # Stop before opening further sources once the cap is hit; breaking
        # only the per-record loop would let each later source add one more.
        if len(all_examples) >= args.max_dataset:
            print(f"   Reached --max-dataset limit ({args.max_dataset}); skipping remaining sources")
            break

        path = Path(file_path)
        if not path.exists():
            print(f"   ⚠️  Not found: {path}")
            continue

        print(f"   Loading {source}...")
        count = 0
        for record in _iter_records(path):
            ex = _normalize_example(record, source)
            if ex is None:
                skipped += 1
                continue

            # Deduplication
            msg_hash = hash_messages(ex["messages"])
            if msg_hash in seen_hashes:
                duplicates_removed += 1
                continue
            seen_hashes.add(msg_hash)

            # Add metadata
            ex["source_original"] = source
            all_examples.append(ex)
            count += 1

            if len(all_examples) >= args.max_dataset:
                break

        print(f"      ✅ Added {count} examples")

    print(f"\n✨ Total collected: {len(all_examples)} examples")
    print(f"   Duplicates removed: {duplicates_removed}")
    if skipped:
        print(f"   Skipped (malformed or missing 'messages'): {skipped}")

    # Shuffle with a dedicated, fixed-seed generator: same order as seeding the
    # module-level generator with 42, without touching global random state.
    random.Random(42).shuffle(all_examples)

    # Split; the test set takes whatever remains after train and val.
    n_total = len(all_examples)
    n_train = int(n_total * args.train_size)
    n_val = int(n_total * args.val_size)

    train_set = all_examples[:n_train]
    val_set = all_examples[n_train:n_train + n_val]
    test_set = all_examples[n_train + n_val:]

    print(f"\n📊 Split:")
    print(f"   Train: {len(train_set)}")
    print(f"   Val: {len(val_set)}")
    print(f"   Test: {len(test_set)}")

    # Save splits
    for split_name, split_data in [("train", train_set), ("val", val_set), ("test", test_set)]:
        split_path = output_dir / f"{split_name}.jsonl"
        _write_jsonl(split_path, split_data)
        print(f"   Saved {split_name} to {split_path}")

    # Create manifest
    manifest = {
        "dataset": "Stack 2.9 Training Data",
        "version": "1.0",
        "created": datetime.now().isoformat(),
        "total_examples": n_total,
        "splits": {
            "train": len(train_set),
            "val": len(val_set),
            "test": len(test_set),
        },
        # Single pass over the examples instead of one scan per source.
        "sources": dict(Counter(ex.get("source_original") for ex in all_examples)),
    }

    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2)
    print(f"\n📄 Manifest: {manifest_path}")

    print("\n✅ Dataset complete!")

# Script entry point: run the pipeline only when this file is executed directly.
if __name__ == "__main__":
    # Makes `random` available as a module global before main() runs;
    # main() uses it to shuffle the collected examples.
    import random
    main()