小形克宏 commited on
Commit
f613e51
·
1 Parent(s): 92b319a

Add dataset format distribution analyzer

Browse files
Files changed (1) hide show
  1. dataset_analyzer.py +211 -0
dataset_analyzer.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Dataset Format Analyzer
3
+ SFTデータセットのフォーマット分布を分析するスクリプト
4
+
5
+ 指定されたHuggingFaceデータセットをダウンロードし、
6
+ 各サンプルのターゲット出力がどのフォーマット(JSON/YAML/TOML/XML/CSV)
7
+ であるかを判定・集計します。
8
+ """
9
+
10
+ import json
11
+ import re
12
+ import sys
13
+ from collections import Counter, defaultdict
14
+
15
+ def detect_format(text):
16
+ """テキストのフォーマットを推定する"""
17
+ text = text.strip()
18
+
19
+ # マークダウンブロック除去
20
+ cleaned = re.sub(r"```\w*\n?", "", text).strip()
21
+ if not cleaned:
22
+ return "EMPTY"
23
+
24
+ # JSON: { or [ で始まる
25
+ if cleaned.startswith("{") or cleaned.startswith("["):
26
+ try:
27
+ json.loads(cleaned)
28
+ return "JSON"
29
+ except:
30
+ return "JSON" # JSONっぽいが壊れている
31
+
32
+ # XML: < で始まる(<?xml or <tag)
33
+ if cleaned.startswith("<"):
34
+ return "XML"
35
+
36
+ # CSV: カンマ区切りの複数行
37
+ lines = cleaned.split("\n")
38
+ if len(lines) >= 2:
39
+ comma_counts = [line.count(",") for line in lines[:5] if line.strip()]
40
+ if comma_counts and all(c == comma_counts[0] and c > 0 for c in comma_counts):
41
+ return "CSV"
42
+
43
+ # TOML: [section] パターンまたは key = value パターン
44
+ if re.match(r"^\[[\w\.\-]+\]", cleaned) or re.match(r'^[\w\.\-]+\s*=\s*', cleaned):
45
+ return "TOML"
46
+
47
+ # YAML: key: value パターン(インデント構造)
48
+ if re.match(r'^[\w\-]+:\s', cleaned) or cleaned.startswith("---") or cleaned.startswith("- "):
49
+ return "YAML"
50
+
51
+ return "OTHER"
52
+
53
+
54
+ def detect_format_from_prompt(prompt_text):
55
+ """プロンプト(query)からターゲットフォーマットを推定"""
56
+ prompt_lower = prompt_text.lower()
57
+
58
+ # 明示的な指示を検索
59
+ patterns = {
60
+ "JSON": [r"output\s+json", r"to\s+json", r"in\s+json", r"json\s+code", r"json\s+format"],
61
+ "YAML": [r"output\s+yaml", r"to\s+yaml", r"in\s+yaml", r"yaml\s+code", r"yaml\s+format"],
62
+ "TOML": [r"output\s+toml", r"to\s+toml", r"in\s+toml", r"toml\s+code", r"toml\s+format"],
63
+ "XML": [r"output\s+xml", r"to\s+xml", r"in\s+xml", r"xml\s+code", r"xml\s+format"],
64
+ "CSV": [r"output\s+csv", r"to\s+csv", r"in\s+csv", r"csv\s+code", r"csv\s+format"],
65
+ }
66
+
67
+ for fmt, pats in patterns.items():
68
+ for pat in pats:
69
+ if re.search(pat, prompt_lower):
70
+ return fmt
71
+
72
+ # タスク名パターン (e.g., "Text to JSON", "CSV to YAML")
73
+ task_pattern = r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)"
74
+ match = re.search(task_pattern, prompt_lower)
75
+ if match:
76
+ return match.group(2).upper()
77
+
78
+ return None
79
+
80
+
81
+ def analyze_dataset(dataset_id):
82
+ """HuggingFaceデータセットを分析"""
83
+ from datasets import load_dataset
84
+
85
+ print(f"📥 データセットをダウンロード中: {dataset_id}")
86
+ ds = load_dataset(dataset_id, split="train")
87
+ print(f"✅ ダウンロード完了: {len(ds)} 件\n")
88
+
89
+ # messages構造を解析
90
+ format_from_output = Counter()
91
+ format_from_prompt = Counter()
92
+ task_types = Counter()
93
+ cot_count = 0
94
+ samples_by_format = defaultdict(list)
95
+
96
+ for i, row in enumerate(ds):
97
+ messages = row.get("messages", [])
98
+
99
+ # messagesからuser/assistantを抽出
100
+ user_msg = ""
101
+ assistant_msg = ""
102
+ has_cot = False
103
+
104
+ for msg in messages:
105
+ role = msg.get("role", "")
106
+ content = msg.get("content", "")
107
+ if role == "user":
108
+ user_msg = content
109
+ elif role == "assistant":
110
+ assistant_msg = content
111
+ if "<think>" in content or "</think>" in content:
112
+ has_cot = True
113
+
114
+ if has_cot:
115
+ cot_count += 1
116
+
117
+ # CoT部分を除去してアシスタントの最終出力を取得
118
+ final_output = assistant_msg
119
+ think_match = re.search(r"</think>\s*(.*)", assistant_msg, re.DOTALL)
120
+ if think_match:
121
+ final_output = think_match.group(1).strip()
122
+
123
+ # 出力フォーマットを判定(2つの方法)
124
+ fmt_output = detect_format(final_output)
125
+ fmt_prompt = detect_format_from_prompt(user_msg)
126
+
127
+ format_from_output[fmt_output] += 1
128
+ if fmt_prompt:
129
+ format_from_prompt[fmt_prompt] += 1
130
+ else:
131
+ format_from_prompt["UNKNOWN"] += 1
132
+
133
+ # タスクタイプ推定
134
+ task_match = re.search(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)", user_msg.lower())
135
+ if task_match:
136
+ task_type = f"{task_match.group(1).upper()} to {task_match.group(2).upper()}"
137
+ elif "please output" in user_msg.lower():
138
+ task_type = f"Text to {fmt_prompt or fmt_output}"
139
+ else:
140
+ task_type = "OTHER"
141
+ task_types[task_type] += 1
142
+
143
+ # サンプル保存(各フォーマット最大2件)
144
+ fmt_key = fmt_prompt or fmt_output
145
+ if len(samples_by_format[fmt_key]) < 2:
146
+ samples_by_format[fmt_key].append({
147
+ "index": i,
148
+ "prompt_preview": user_msg[:150],
149
+ "output_preview": final_output[:150],
150
+ })
151
+
152
+ # --- 結果出力 ---
153
+ total = len(ds)
154
+
155
+ print("=" * 70)
156
+ print(f"📊 データセット分析結果: {dataset_id}")
157
+ print(f" 総サンプル数: {total}")
158
+ print(f" CoTあり: {cot_count} ({cot_count/total*100:.1f}%)")
159
+ print("=" * 70)
160
+
161
+ print(f"\n📋 ターゲットフォーマット分布(プロンプトから判定):")
162
+ print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
163
+ print("-" * 30)
164
+ for fmt in ["JSON", "YAML", "TOML", "XML", "CSV", "UNKNOWN"]:
165
+ count = format_from_prompt.get(fmt, 0)
166
+ pct = f"{count/total*100:.1f}%"
167
+ bar = "█" * int(count/total*50)
168
+ print(f"{fmt:<12} {count:>6} {pct:>8} {bar}")
169
+
170
+ print(f"\n📋 出力フォーマット分布(出力内容から判定):")
171
+ print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
172
+ print("-" * 30)
173
+ for fmt, count in format_from_output.most_common():
174
+ pct = f"{count/total*100:.1f}%"
175
+ bar = "█" * int(count/total*50)
176
+ print(f"{fmt:<12} {count:>6} {pct:>8} {bar}")
177
+
178
+ print(f"\n📋 タスクタイプ分布:")
179
+ print(f"{'Task Type':<25} {'Count':>6} {'Percent':>8}")
180
+ print("-" * 45)
181
+ for task, count in task_types.most_common(20):
182
+ pct = f"{count/total*100:.1f}%"
183
+ print(f"{task:<25} {count:>6} {pct:>8}")
184
+
185
+ # public_150との比較
186
+ print(f"\n📋 public_150.json との比較(参考):")
187
+ print(f"{'Format':<8} {'public_150':>12} {'dataset':>12} {'充足度':>10}")
188
+ print("-" * 45)
189
+ public_counts = {"JSON": 50, "YAML": 35, "TOML": 25, "XML": 20, "CSV": 20}
190
+ for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
191
+ pub = public_counts[fmt]
192
+ ds_count = format_from_prompt.get(fmt, 0)
193
+ ratio = f"{ds_count/pub:.1f}x" if pub > 0 else "N/A"
194
+ print(f"{fmt:<8} {pub:>12} {ds_count:>12} {ratio:>10}")
195
+
196
+ print(f"\n📋 各フォーマットのサンプル:")
197
+ for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
198
+ samples = samples_by_format.get(fmt, [])
199
+ print(f"\n--- {fmt} サンプル ({len(samples)}件) ---")
200
+ for s in samples:
201
+ print(f" [#{s['index']}] prompt: {s['prompt_preview'][:100]}")
202
+ print(f" output: {s['output_preview'][:100]}")
203
+
204
+
205
+ if __name__ == "__main__":
206
+ if len(sys.argv) > 1:
207
+ dataset_id = sys.argv[1]
208
+ else:
209
+ dataset_id = "u-10bei/structured_data_with_cot_dataset_512_v4"
210
+
211
+ analyze_dataset(dataset_id)