structeval-analyz / dataset_analyzer.py
小形克宏
Add dataset format distribution analyzer
f613e51
"""
Dataset Format Analyzer
SFTデータセットのフォーマット分布を分析するスクリプト
指定されたHuggingFaceデータセットをダウンロードし、
各サンプルのターゲット出力がどのフォーマット(JSON/YAML/TOML/XML/CSV)
であるかを判定・集計します。
"""
import json
import re
import sys
from collections import Counter, defaultdict
def detect_format(text):
"""テキストのフォーマットを推定する"""
text = text.strip()
# マークダウンブロック除去
cleaned = re.sub(r"```\w*\n?", "", text).strip()
if not cleaned:
return "EMPTY"
# JSON: { or [ で始まる
if cleaned.startswith("{") or cleaned.startswith("["):
try:
json.loads(cleaned)
return "JSON"
except:
return "JSON" # JSONっぽいが壊れている
# XML: < で始まる(<?xml or <tag)
if cleaned.startswith("<"):
return "XML"
# CSV: カンマ区切りの複数行
lines = cleaned.split("\n")
if len(lines) >= 2:
comma_counts = [line.count(",") for line in lines[:5] if line.strip()]
if comma_counts and all(c == comma_counts[0] and c > 0 for c in comma_counts):
return "CSV"
# TOML: [section] パターンまたは key = value パターン
if re.match(r"^\[[\w\.\-]+\]", cleaned) or re.match(r'^[\w\.\-]+\s*=\s*', cleaned):
return "TOML"
# YAML: key: value パターン(インデント構造)
if re.match(r'^[\w\-]+:\s', cleaned) or cleaned.startswith("---") or cleaned.startswith("- "):
return "YAML"
return "OTHER"
def detect_format_from_prompt(prompt_text):
"""プロンプト(query)からターゲットフォーマットを推定"""
prompt_lower = prompt_text.lower()
# 明示的な指示を検索
patterns = {
"JSON": [r"output\s+json", r"to\s+json", r"in\s+json", r"json\s+code", r"json\s+format"],
"YAML": [r"output\s+yaml", r"to\s+yaml", r"in\s+yaml", r"yaml\s+code", r"yaml\s+format"],
"TOML": [r"output\s+toml", r"to\s+toml", r"in\s+toml", r"toml\s+code", r"toml\s+format"],
"XML": [r"output\s+xml", r"to\s+xml", r"in\s+xml", r"xml\s+code", r"xml\s+format"],
"CSV": [r"output\s+csv", r"to\s+csv", r"in\s+csv", r"csv\s+code", r"csv\s+format"],
}
for fmt, pats in patterns.items():
for pat in pats:
if re.search(pat, prompt_lower):
return fmt
# タスク名パターン (e.g., "Text to JSON", "CSV to YAML")
task_pattern = r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)"
match = re.search(task_pattern, prompt_lower)
if match:
return match.group(2).upper()
return None
def analyze_dataset(dataset_id):
"""HuggingFaceデータセットを分析"""
from datasets import load_dataset
print(f"📥 データセットをダウンロード中: {dataset_id}")
ds = load_dataset(dataset_id, split="train")
print(f"✅ ダウンロード完了: {len(ds)} 件\n")
# messages構造を解析
format_from_output = Counter()
format_from_prompt = Counter()
task_types = Counter()
cot_count = 0
samples_by_format = defaultdict(list)
for i, row in enumerate(ds):
messages = row.get("messages", [])
# messagesからuser/assistantを抽出
user_msg = ""
assistant_msg = ""
has_cot = False
for msg in messages:
role = msg.get("role", "")
content = msg.get("content", "")
if role == "user":
user_msg = content
elif role == "assistant":
assistant_msg = content
if "<think>" in content or "</think>" in content:
has_cot = True
if has_cot:
cot_count += 1
# CoT部分を除去してアシスタントの最終出力を取得
final_output = assistant_msg
think_match = re.search(r"</think>\s*(.*)", assistant_msg, re.DOTALL)
if think_match:
final_output = think_match.group(1).strip()
# 出力フォーマットを判定(2つの方法)
fmt_output = detect_format(final_output)
fmt_prompt = detect_format_from_prompt(user_msg)
format_from_output[fmt_output] += 1
if fmt_prompt:
format_from_prompt[fmt_prompt] += 1
else:
format_from_prompt["UNKNOWN"] += 1
# タスクタイプ推定
task_match = re.search(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)", user_msg.lower())
if task_match:
task_type = f"{task_match.group(1).upper()} to {task_match.group(2).upper()}"
elif "please output" in user_msg.lower():
task_type = f"Text to {fmt_prompt or fmt_output}"
else:
task_type = "OTHER"
task_types[task_type] += 1
# サンプル保存(各フォーマット最大2件)
fmt_key = fmt_prompt or fmt_output
if len(samples_by_format[fmt_key]) < 2:
samples_by_format[fmt_key].append({
"index": i,
"prompt_preview": user_msg[:150],
"output_preview": final_output[:150],
})
# --- 結果出力 ---
total = len(ds)
print("=" * 70)
print(f"📊 データセット分析結果: {dataset_id}")
print(f" 総サンプル数: {total}")
print(f" CoTあり: {cot_count} ({cot_count/total*100:.1f}%)")
print("=" * 70)
print(f"\n📋 ターゲットフォーマット分布(プロンプトから判定):")
print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
print("-" * 30)
for fmt in ["JSON", "YAML", "TOML", "XML", "CSV", "UNKNOWN"]:
count = format_from_prompt.get(fmt, 0)
pct = f"{count/total*100:.1f}%"
bar = "█" * int(count/total*50)
print(f"{fmt:<12} {count:>6} {pct:>8} {bar}")
print(f"\n📋 出力フォーマット分布(出力内容から判定):")
print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
print("-" * 30)
for fmt, count in format_from_output.most_common():
pct = f"{count/total*100:.1f}%"
bar = "█" * int(count/total*50)
print(f"{fmt:<12} {count:>6} {pct:>8} {bar}")
print(f"\n📋 タスクタイプ分布:")
print(f"{'Task Type':<25} {'Count':>6} {'Percent':>8}")
print("-" * 45)
for task, count in task_types.most_common(20):
pct = f"{count/total*100:.1f}%"
print(f"{task:<25} {count:>6} {pct:>8}")
# public_150との比較
print(f"\n📋 public_150.json との比較(参考):")
print(f"{'Format':<8} {'public_150':>12} {'dataset':>12} {'充足度':>10}")
print("-" * 45)
public_counts = {"JSON": 50, "YAML": 35, "TOML": 25, "XML": 20, "CSV": 20}
for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
pub = public_counts[fmt]
ds_count = format_from_prompt.get(fmt, 0)
ratio = f"{ds_count/pub:.1f}x" if pub > 0 else "N/A"
print(f"{fmt:<8} {pub:>12} {ds_count:>12} {ratio:>10}")
print(f"\n📋 各フォーマットのサンプル:")
for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]:
samples = samples_by_format.get(fmt, [])
print(f"\n--- {fmt} サンプル ({len(samples)}件) ---")
for s in samples:
print(f" [#{s['index']}] prompt: {s['prompt_preview'][:100]}")
print(f" output: {s['output_preview'][:100]}")
if __name__ == "__main__":
if len(sys.argv) > 1:
dataset_id = sys.argv[1]
else:
dataset_id = "u-10bei/structured_data_with_cot_dataset_512_v4"
analyze_dataset(dataset_id)