File size: 5,851 Bytes
6379283
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#!/usr/bin/env python3
"""
Quality validation for Stack 2.9 training dataset.
Checks: message structure, tool format, schema compliance.
"""

import argparse
import datetime
import json
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List

def load_tool_catalog(path: str) -> Dict[str, Any]:
    """Load the tool catalog JSON and index its entries by their ``tool`` key.

    The file is expected to hold a JSON array of objects, each carrying a
    ``tool`` key. An entry whose name repeats an earlier one replaces it.

    Args:
        path: Path to the catalog JSON file (a ``str`` or ``pathlib.Path``).

    Returns:
        Mapping of tool name -> the full catalog entry.

    Raises:
        OSError: if the file cannot be opened.
        json.JSONDecodeError: if the file is not valid JSON.
        KeyError: if an entry has no ``tool`` key.
    """
    # JSON is UTF-8 by spec; don't depend on the platform's locale encoding.
    with open(path, 'r', encoding='utf-8') as f:
        entries = json.load(f)
    return {entry["tool"]: entry for entry in entries}

def validate_example(example: Dict[str, Any], tool_catalog: Dict[str, Any]) -> List[str]:
    """Validate a single training example.

    Checks, in order:
      * the example is an object with a ``messages`` list of at least two
        message objects;
      * every message role is one of ``system``/``user``/``assistant``;
      * assistant ``tool_use`` blocks are objects with ``name`` and ``input``,
        and (when ``tool_catalog`` is non-empty) name a catalogued tool;
      * user ``tool_result`` blocks are objects with ``tool_use_id`` and
        ``content``;
      * any ``content`` present is a non-blank string, except on user messages
        that carry a ``tool_result``.

    Args:
        example: One parsed JSONL record.
        tool_catalog: Mapping of tool name -> catalog entry. An empty mapping
            disables the unknown-tool check, matching ``main``'s behaviour
            of "skipping tool validation" when no catalog file exists.

    Returns:
        A list of human-readable error strings; empty if the example is valid.
        Malformed input is reported here rather than raised, so one bad record
        cannot abort a whole validation run.
    """
    errors: List[str] = []

    if not isinstance(example, dict):
        errors.append("Example must be a JSON object")
        return errors

    if "messages" not in example:
        errors.append("Missing 'messages' field")
        return errors

    messages = example["messages"]
    if not isinstance(messages, list) or len(messages) < 2:
        errors.append("Invalid messages: must be list with at least 2 messages")
        return errors

    # Every check below calls dict methods on each message; a bare string or
    # number here would otherwise raise AttributeError.
    if not all(isinstance(msg, dict) for msg in messages):
        errors.append("Invalid messages: every message must be an object")
        return errors

    # Check roles. The isinstance guard keeps unhashable values (e.g. a list)
    # from raising TypeError on set membership.
    roles = [msg.get("role") for msg in messages]
    valid_roles = {"system", "user", "assistant"}
    if not all(isinstance(r, str) and r in valid_roles for r in roles):
        errors.append(f"Invalid roles: {roles}")

    # Tool use / tool result structure.
    for msg in messages:
        role = msg.get("role")

        if role == "assistant" and "tool_use" in msg:
            tool_use = msg["tool_use"]
            if not isinstance(tool_use, dict):
                errors.append("Tool use must be an object")
            else:
                has_name = "name" in tool_use
                tool_name = tool_use.get("name")
                if not has_name:
                    errors.append("Tool use missing 'name'")
                elif tool_catalog and (
                    not isinstance(tool_name, str) or tool_name not in tool_catalog
                ):
                    errors.append(f"Unknown tool: {tool_name}")
                if "input" not in tool_use:
                    # Only mention the tool when this message actually names one;
                    # previously an unbound or stale `tool_name` was used here.
                    if has_name:
                        errors.append(f"Tool use missing 'input' for {tool_name}")
                    else:
                        errors.append("Tool use missing 'input'")

        if role == "user" and "tool_result" in msg:
            tool_result = msg["tool_result"]
            if not isinstance(tool_result, dict):
                errors.append("Tool result must be an object")
            else:
                if "tool_use_id" not in tool_result:
                    errors.append("Tool result missing 'tool_use_id'")
                if "content" not in tool_result:
                    errors.append("Tool result missing 'content'")

    # Message content must be a non-blank string when present.
    for msg in messages:
        role = msg.get("role")
        if role == "user" and "tool_result" in msg:
            continue  # Tool result user message can have empty content
        content = msg.get("content")
        if content is None:
            continue
        if not isinstance(content, str):
            errors.append(f"Message content must be string, got {type(content)}")
        elif not content.strip():
            # `elif`: calling .strip() on non-string content used to crash.
            errors.append(f"Empty content in {role} message")

    return errors

def _safe_ratio(numerator: int, denominator: int) -> float:
    """Return numerator / denominator, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator > 0 else 0.0

def _scan_dataset(input_path: Path, tool_catalog: Dict[str, Any]):
    """Validate every JSONL record in *input_path*.

    Blank lines are skipped. Lines that are not valid JSON are counted as
    examples and recorded under the "JSON decode error" key.

    Returns:
        ``(total_examples, valid_examples, error_distribution, tool_usage)``.
        The last two are ``Counter`` objects keyed by error string and by
        tool name.
    """
    total_examples = 0
    valid_examples = 0
    error_distribution = Counter()
    tool_usage = Counter()

    with open(input_path, 'r', encoding='utf-8') as f:
        for line in f:
            # A blank line (e.g. an extra newline at EOF) is not a record.
            if not line.strip():
                continue
            total_examples += 1
            try:
                example = json.loads(line)
            except json.JSONDecodeError:
                error_distribution["JSON decode error"] += 1
                continue

            errors = validate_example(example, tool_catalog)
            if errors:
                error_distribution.update(errors)
            else:
                valid_examples += 1

            # Track tool usage regardless of validation, tolerating malformed
            # records: an uncaught KeyError/AttributeError here used to abort
            # the entire run.
            messages = example.get("messages") if isinstance(example, dict) else None
            if isinstance(messages, list):
                for msg in messages:
                    if not isinstance(msg, dict):
                        continue
                    tool_use = msg.get("tool_use")
                    if isinstance(tool_use, dict) and isinstance(tool_use.get("name"), str):
                        tool_usage[tool_use["name"]] += 1

    return total_examples, valid_examples, error_distribution, tool_usage

def main():
    """Command-line entry point: validate a JSONL dataset and write a report.

    Reads ``--input`` (one JSON example per line) and validates each example
    with ``validate_example``. Tool names are checked against the catalog at
    ``--catalog``; if that file does not exist, tool-name checking is skipped.
    Prints a summary to stdout and writes a JSON report to
    ``--output-report``, creating parent directories as needed.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--input", type=str, default="training-data/final/train.jsonl")
    parser.add_argument("--catalog", type=str, default="training-data/tools/catalog.json")
    parser.add_argument("--output-report", type=str, default="training-data/final/quality_report.json")
    args = parser.parse_args()

    input_path = Path(args.input)
    catalog_path = Path(args.catalog)

    if not input_path.exists():
        print(f"❌ Input not found: {input_path}")
        return

    if not catalog_path.exists():
        print(f"⚠️  Catalog not found: {catalog_path}, skipping tool validation")
        tool_catalog = {}
    else:
        tool_catalog = load_tool_catalog(catalog_path)
        print(f"✅ Loaded tool catalog with {len(tool_catalog)} tools")

    print(f"🔍 Validating {input_path}...")

    total_examples, valid_examples, error_distribution, tool_usage = _scan_dataset(
        input_path, tool_catalog
    )
    invalid_examples = total_examples - valid_examples
    # Guarded ratio: an empty input previously raised ZeroDivisionError.
    validity_rate = _safe_ratio(valid_examples, total_examples)

    print("\n📊 Validation Results:")
    print(f"   Total examples: {total_examples}")
    print(f"   Valid: {valid_examples} ({validity_rate*100:.1f}%)")
    print(f"   Invalid: {invalid_examples}")

    if error_distribution:
        print("\n   Error breakdown:")
        for err, count in error_distribution.most_common(10):
            print(f"     - {err}: {count}")

    print("\n   Tool usage (top 10):")
    for tool, count in tool_usage.most_common(10):
        print(f"     - {tool}: {count}")

    # Write report
    report = {
        "total_examples": total_examples,
        "valid_examples": valid_examples,
        "invalid_examples": invalid_examples,
        "validity_rate": validity_rate,
        "error_distribution": dict(error_distribution),
        "tool_usage": dict(tool_usage),
        "generated_at": datetime.datetime.now().isoformat(),
    }

    output_report = Path(args.output_report)
    output_report.parent.mkdir(parents=True, exist_ok=True)
    with open(output_report, 'w', encoding='utf-8') as f:
        json.dump(report, f, indent=2)

    print(f"\n✅ Report saved: {output_report}")

    if total_examples == 0:
        print("\n⚠️  No examples found in input.")
    elif validity_rate < 0.9:
        print("\n⚠️  Quality below 90%. Consider filtering invalid examples before training.")
    else:
        print("\n✅ Dataset quality looks good for training!")

if __name__ == "__main__":
    # Script entry point; all dependencies of main() are imported at module level.
    main()