File size: 7,924 Bytes
f613e51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
"""
Dataset Format Analyzer
SFTデータセットのフォーマット分布を分析するスクリプト

指定されたHuggingFaceデータセットをダウンロードし、
各サンプルのターゲット出力がどのフォーマット(JSON/YAML/TOML/XML/CSV)
であるかを判定・集計します。
"""

import json
import re
import sys
from collections import Counter, defaultdict

def detect_format(text):
    """テキストのフォーマットを推定する"""
    text = text.strip()
    
    # マークダウンブロック除去
    cleaned = re.sub(r"```\w*\n?", "", text).strip()
    if not cleaned:
        return "EMPTY"
    
    # JSON: { or [ で始まる
    if cleaned.startswith("{") or cleaned.startswith("["):
        try:
            json.loads(cleaned)
            return "JSON"
        except:
            return "JSON"  # JSONっぽいが壊れている
    
    # XML: < で始まる(<?xml or <tag)
    if cleaned.startswith("<"):
        return "XML"
    
    # CSV: カンマ区切りの複数行
    lines = cleaned.split("\n")
    if len(lines) >= 2:
        comma_counts = [line.count(",") for line in lines[:5] if line.strip()]
        if comma_counts and all(c == comma_counts[0] and c > 0 for c in comma_counts):
            return "CSV"
    
    # TOML: [section] パターンまたは key = value パターン
    if re.match(r"^\[[\w\.\-]+\]", cleaned) or re.match(r'^[\w\.\-]+\s*=\s*', cleaned):
        return "TOML"
    
    # YAML: key: value パターン(インデント構造)
    if re.match(r'^[\w\-]+:\s', cleaned) or cleaned.startswith("---") or cleaned.startswith("- "):
        return "YAML"
    
    return "OTHER"


def detect_format_from_prompt(prompt_text):
    """プロンプト(query)からターゲットフォーマットを推定"""
    prompt_lower = prompt_text.lower()
    
    # 明示的な指示を検索
    patterns = {
        "JSON": [r"output\s+json", r"to\s+json", r"in\s+json", r"json\s+code", r"json\s+format"],
        "YAML": [r"output\s+yaml", r"to\s+yaml", r"in\s+yaml", r"yaml\s+code", r"yaml\s+format"],
        "TOML": [r"output\s+toml", r"to\s+toml", r"in\s+toml", r"toml\s+code", r"toml\s+format"],
        "XML":  [r"output\s+xml",  r"to\s+xml",  r"in\s+xml",  r"xml\s+code",  r"xml\s+format"],
        "CSV":  [r"output\s+csv",  r"to\s+csv",  r"in\s+csv",  r"csv\s+code",  r"csv\s+format"],
    }
    
    for fmt, pats in patterns.items():
        for pat in pats:
            if re.search(pat, prompt_lower):
                return fmt
    
    # タスク名パターン (e.g., "Text to JSON", "CSV to YAML")
    task_pattern = r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)"
    match = re.search(task_pattern, prompt_lower)
    if match:
        return match.group(2).upper()
    
    return None


def analyze_dataset(dataset_id):
    """HuggingFaceデータセットを分析"""
    from datasets import load_dataset
    
    print(f"📥 データセットをダウンロード中: {dataset_id}")
    ds = load_dataset(dataset_id, split="train")
    print(f"✅ ダウンロード完了: {len(ds)} 件\n")
    
    # messages構造を解析
    format_from_output = Counter()
    format_from_prompt = Counter()
    task_types = Counter()
    cot_count = 0
    samples_by_format = defaultdict(list)
    
    for i, row in enumerate(ds):
        messages = row.get("messages", [])
        
        # messagesからuser/assistantを抽出
        user_msg = ""
        assistant_msg = ""
        has_cot = False
        
        for msg in messages:
            role = msg.get("role", "")
            content = msg.get("content", "")
            if role == "user":
                user_msg = content
            elif role == "assistant":
                assistant_msg = content
                if "<think>" in content or "</think>" in content:
                    has_cot = True
        
        if has_cot:
            cot_count += 1
        
        # CoT部分を除去してアシスタントの最終出力を取得
        final_output = assistant_msg
        think_match = re.search(r"</think>\s*(.*)", assistant_msg, re.DOTALL)
        if think_match:
            final_output = think_match.group(1).strip()
        
        # 出力フォーマットを判定(2つの方法)
        fmt_output = detect_format(final_output)
        fmt_prompt = detect_format_from_prompt(user_msg)
        
        format_from_output[fmt_output] += 1
        if fmt_prompt:
            format_from_prompt[fmt_prompt] += 1
        else:
            format_from_prompt["UNKNOWN"] += 1
        
        # タスクタイプ推定
        task_match = re.search(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)", user_msg.lower())
        if task_match:
            task_type = f"{task_match.group(1).upper()} to {task_match.group(2).upper()}"
        elif "please output" in user_msg.lower():
            task_type = f"Text to {fmt_prompt or fmt_output}"
        else:
            task_type = "OTHER"
        task_types[task_type] += 1
        
        # サンプル保存(各フォーマット最大2件)
        fmt_key = fmt_prompt or fmt_output
        if len(samples_by_format[fmt_key]) < 2:
            samples_by_format[fmt_key].append({
                "index": i,
                "prompt_preview": user_msg[:150],
                "output_preview": final_output[:150],
            })
    
    # --- 結果出力 ---
    total = len(ds)
    
    print("=" * 70)
    print(f"📊 データセット分析結果: {dataset_id}")
    print(f"   総サンプル数: {total}")
    print(f"   CoTあり: {cot_count} ({cot_count/total*100:.1f}%)")
    print("=" * 70)
    
    print(f"\n📋 ターゲットフォーマット分布(プロンプトから判定):")
    print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
    print("-" * 30)
    for fmt in ["JSON", "YAML", "TOML", "XML", "CSV", "UNKNOWN"]:
        count = format_from_prompt.get(fmt, 0)
        pct = f"{count/total*100:.1f}%"
        bar = "█" * int(count/total*50)
        print(f"{fmt:<12} {count:>6} {pct:>8}  {bar}")
    
    print(f"\n📋 出力フォーマット分布(出力内容から判定):")
    print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
    print("-" * 30)
    for fmt, count in format_from_output.most_common():
        pct = f"{count/total*100:.1f}%"
        bar = "█" * int(count/total*50)
        print(f"{fmt:<12} {count:>6} {pct:>8}  {bar}")
    
    print(f"\n📋 タスクタイプ分布:")
    print(f"{'Task Type':<25} {'Count':>6} {'Percent':>8}")
    print("-" * 45)
    for task, count in task_types.most_common(20):
        pct = f"{count/total*100:.1f}%"
        print(f"{task:<25} {count:>6} {pct:>8}")
    
    # public_150との比較
    print(f"\n📋 public_150.json との比較(参考):")
    print(f"{'Format':<8} {'public_150':>12} {'dataset':>12} {'充足度':>10}")
    print("-" * 45)
    public_counts = {"JSON": 50, "YAML": 35, "TOML": 25, "XML": 20, "CSV": 20}
    for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
        pub = public_counts[fmt]
        ds_count = format_from_prompt.get(fmt, 0)
        ratio = f"{ds_count/pub:.1f}x" if pub > 0 else "N/A"
        print(f"{fmt:<8} {pub:>12} {ds_count:>12} {ratio:>10}")
    
    print(f"\n📋 各フォーマットのサンプル:")
    for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
        samples = samples_by_format.get(fmt, [])
        print(f"\n--- {fmt} サンプル ({len(samples)}件) ---")
        for s in samples:
            print(f"  [#{s['index']}] prompt: {s['prompt_preview'][:100]}")
            print(f"         output: {s['output_preview'][:100]}")


if __name__ == "__main__":
    if len(sys.argv) > 1:
        dataset_id = sys.argv[1]
    else:
        dataset_id = "u-10bei/structured_data_with_cot_dataset_512_v4"
    
    analyze_dataset(dataset_id)