File size: 6,404 Bytes
49ffe54
 
 
12c2955
49ffe54
 
 
 
 
 
2c32ffd
49ffe54
 
12c2955
49ffe54
 
 
12c2955
49ffe54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c2955
49ffe54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12c2955
49ffe54
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
#!/usr/bin/env python3
"""
Create a minimal training dataset for rapid prototyping.
Samples N examples from the full data/final/train.jsonl ensuring tool diversity.
"""

import argparse
import json
import random
from pathlib import Path
from typing import List, Dict
from collections import defaultdict, Counter

def load_full_dataset(train_path: str = "data/final/train.jsonl") -> List[Dict]:
    """Load every example from a JSON Lines training file.

    Args:
        train_path: Path to a ``.jsonl`` file holding one JSON document per line.

    Returns:
        The decoded documents, in file order. Blank lines are ignored.

    Raises:
        FileNotFoundError: If ``train_path`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON. The message
            is prefixed with the file path and the 1-based line number.
    """
    path = Path(train_path)
    if not path.exists():
        raise FileNotFoundError(f"Training data not found at {path}. Please ensure the file exists.")

    data = []
    # JSON text is UTF-8; do not depend on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        for line_no, line in enumerate(f, start=1):
            # Tolerate blank lines (e.g. a trailing newline at end of file).
            if not line.strip():
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as err:
                # Re-raise the same exception type, but say where the bad record is.
                raise json.JSONDecodeError(
                    f"{path}, line {line_no}: {err.msg}", err.doc, err.pos
                ) from err
    return data

def extract_tool_calls(example: Dict) -> List[str]:
    """Extract the names of the tools called by assistant messages in an example.

    Only messages whose ``role`` is ``"assistant"`` are inspected. For each entry
    of such a message's ``tool_calls`` list the value at ``function.name`` is
    collected. Missing, ``None`` or wrongly-typed fields are skipped rather than
    raising, so one malformed record cannot abort processing of a whole dataset.

    Args:
        example: A single dataset record; expected to hold a ``"messages"`` list.

    Returns:
        Non-empty tool names in order of appearance (duplicates preserved).
    """
    tools = []
    messages = example.get("messages")
    if not isinstance(messages, (list, tuple)):
        return tools

    for msg in messages:
        if not isinstance(msg, dict) or msg.get("role") != "assistant":
            continue
        calls = msg.get("tool_calls")
        # ``tool_calls`` may be absent or an explicit null on plain-text replies.
        if not isinstance(calls, (list, tuple)):
            continue
        for tc in calls:
            func = tc.get("function") if isinstance(tc, dict) else None
            name = func.get("name") if isinstance(func, dict) else None
            # Names are later used as grouping keys, so accept only real strings.
            if isinstance(name, str) and name:
                tools.append(name)
    return tools

# Key under which examples without any tool call are tracked in plans and stats.
_NOTOOL_KEY = "__notool__"


def _percent(part: int, whole: int) -> float:
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is zero."""
    return part / whole * 100 if whole else 0.0


def _plan_allocation(
    group_sizes: Dict[str, int],
    n_notool: int,
    n_samples: int,
    min_per_tool: int = 3,
) -> Dict[str, int]:
    """Decide how many examples to draw from each pool without exceeding the target.

    Args:
        group_sizes: Number of available examples per primary tool.
        n_notool: Number of available examples that contain no tool call.
        n_samples: Total number of examples wanted.
        min_per_tool: Quota given to every tool that has at least this many examples.

    Returns:
        Mapping of pool name (a tool name or ``"__notool__"``) to a positive
        count. The counts never exceed a pool's size and sum to at most
        ``n_samples`` (exactly ``n_samples`` whenever enough data exists).
    """
    eligible = [tool for tool, size in group_sizes.items() if size >= min_per_tool]
    plan = {tool: 0 for tool in eligible}
    remaining = n_samples

    # Pass 1: minimum quota. Largest groups are served first so that, when the
    # target is too small to cover every tool, the budget is never overdrawn.
    for tool in sorted(eligible, key=group_sizes.get, reverse=True):
        if remaining <= 0:
            break
        quota = min(min_per_tool, remaining)
        plan[tool] = quota
        remaining -= quota

    # Pass 2: proportional share of what is left. The budget is frozen before the
    # loop so every group's share is computed against the same total; integer
    # floor division keeps the sum of the shares within the budget.
    budget = remaining
    total_weight = sum(group_sizes[tool] for tool in eligible)
    if budget > 0 and total_weight > 0:
        for tool in eligible:
            size = group_sizes[tool]
            extra = min(budget * size // total_weight, size - plan[tool])
            plan[tool] += extra
            remaining -= extra

    # Pass 3: fill the leftover with no-tool examples.
    # NOTE(review): no-tool examples act purely as filler here, so they usually
    # get only the rounding leftovers. If a representative share is wanted they
    # should become a stratum of pass 2 - confirm the intended mix.
    if remaining > 0 and n_notool > 0:
        filler = min(remaining, n_notool)
        plan[_NOTOOL_KEY] = filler
        remaining -= filler

    # Pass 4: still short - top up from the largest groups, including the
    # groups that were too small to receive a minimum quota.
    if remaining > 0:
        by_size = sorted(group_sizes.items(), key=lambda kv: kv[1], reverse=True)
        for tool, size in by_size:
            if remaining <= 0:
                break
            can_take = min(remaining, size - plan.get(tool, 0))
            if can_take > 0:
                plan[tool] = plan.get(tool, 0) + can_take
                remaining -= can_take

    return {tool: count for tool, count in plan.items() if count > 0}


def create_mini_dataset(
    output_path: str,
    n_samples: int = 5000,
    train_source: str = "data/final/train.jsonl",
    seed: int = 42
) -> List[Dict]:
    """Create a stratified mini dataset and write it as JSON Lines.

    Examples are grouped by the first tool they call. Every tool with at least
    three examples gets a minimum quota, the rest of the budget is split in
    proportion to group size, and any leftover is filled first with no-tool
    examples and then from the largest tool groups. Progress and distribution
    statistics are printed to stdout.

    Args:
        output_path: Destination ``.jsonl`` file; parent directories are created.
        n_samples: Target number of examples. Fewer are written only when the
            source does not contain enough data.
        train_source: Path of the full JSON Lines dataset to sample from.
        seed: Seed for the global ``random`` generator, for reproducible output.

    Returns:
        The sampled examples in the (shuffled) order they were written.

    Raises:
        ValueError: If ``n_samples`` is negative.
        FileNotFoundError: If ``train_source`` does not exist.
    """
    if n_samples < 0:
        raise ValueError(f"n_samples must be non-negative, got {n_samples}")

    random.seed(seed)

    print(f"Loading full dataset from {train_source}...")
    full_data = load_full_dataset(train_source)
    n_total = len(full_data)
    print(f"Loaded {n_total} total examples")

    # Group by tool usage, using the first tool call as the primary category.
    tool_groups = defaultdict(list)
    unknown_tools = []
    for ex in full_data:
        tools = extract_tool_calls(ex)
        if tools:
            tool_groups[tools[0]].append(ex)
        else:
            unknown_tools.append(ex)

    print("\nTool distribution in full dataset:")
    largest_first = sorted(tool_groups.items(), key=lambda kv: len(kv[1]), reverse=True)
    for tool, examples in largest_first[:15]:
        print(f"  {tool}: {len(examples)} examples ({_percent(len(examples), n_total):.1f}%)")
    print(f"  No-tool examples: {len(unknown_tools)} ({_percent(len(unknown_tools), n_total):.1f}%)")

    samples_per_tool = _plan_allocation(
        {tool: len(examples) for tool, examples in tool_groups.items()},
        len(unknown_tools),
        n_samples,
    )

    print(f"\nSampling plan (target {n_samples}):")
    for tool, n in sorted(samples_per_tool.items(), key=lambda kv: kv[1], reverse=True):
        available = len(unknown_tools) if tool == _NOTOOL_KEY else len(tool_groups[tool])
        print(f"  {tool}: {n} examples ({_percent(n, n_samples):.1f}%) from {available} available")

    # Perform sampling without replacement inside each pool.
    mini_dataset = []
    for tool, n_to_sample in samples_per_tool.items():
        source_pool = unknown_tools if tool == _NOTOOL_KEY else tool_groups[tool]
        mini_dataset.extend(random.sample(source_pool, min(n_to_sample, len(source_pool))))

    # Shuffle so the output is not ordered by tool.
    random.shuffle(mini_dataset)

    # Write output
    output_path = Path(output_path)
    output_path.parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(json.dumps(ex) + '\n' for ex in mini_dataset)

    print(f"\n✅ Mini dataset created: {len(mini_dataset)} examples")
    print(f"   Saved to: {output_path}")

    # Stats on what was actually sampled.
    tool_counts = Counter()
    for ex in mini_dataset:
        tools = extract_tool_calls(ex)
        tool_counts[tools[0] if tools else _NOTOOL_KEY] += 1

    print("\nFinal tool distribution:")
    for tool, count in tool_counts.most_common(15):
        print(f"  {tool}: {count} ({_percent(count, len(mini_dataset)):.1f}%)")

    return mini_dataset

def main():
    """Command-line entry point: read the options and build the mini dataset."""
    cli = argparse.ArgumentParser(description="Create mini dataset for fast prototyping")

    # One row per option: (flag, value type, default, help text).
    option_table = (
        ("--size", int, 5000, "Number of examples in mini dataset"),
        ("--output", str, "./data_mini/train_mini.jsonl", "Output file path"),
        ("--source", str, "data/final/train.jsonl", "Source full dataset"),
        ("--seed", int, 42, "Random seed for sampling"),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    opts = cli.parse_args()

    # Positional order follows the signature: output_path, n_samples, train_source, seed.
    create_mini_dataset(opts.output, opts.size, opts.source, opts.seed)

# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()