""" Dataset Format Analyzer SFTデータセットのフォーマット分布を分析するスクリプト 指定されたHuggingFaceデータセットをダウンロードし、 各サンプルのターゲット出力がどのフォーマット(JSON/YAML/TOML/XML/CSV) であるかを判定・集計します。 """ import json import re import sys from collections import Counter, defaultdict def detect_format(text): """テキストのフォーマットを推定する""" text = text.strip() # マークダウンブロック除去 cleaned = re.sub(r"```\w*\n?", "", text).strip() if not cleaned: return "EMPTY" # JSON: { or [ で始まる if cleaned.startswith("{") or cleaned.startswith("["): try: json.loads(cleaned) return "JSON" except: return "JSON" # JSONっぽいが壊れている # XML: < で始まる(= 2: comma_counts = [line.count(",") for line in lines[:5] if line.strip()] if comma_counts and all(c == comma_counts[0] and c > 0 for c in comma_counts): return "CSV" # TOML: [section] パターンまたは key = value パターン if re.match(r"^\[[\w\.\-]+\]", cleaned) or re.match(r'^[\w\.\-]+\s*=\s*', cleaned): return "TOML" # YAML: key: value パターン(インデント構造) if re.match(r'^[\w\-]+:\s', cleaned) or cleaned.startswith("---") or cleaned.startswith("- "): return "YAML" return "OTHER" def detect_format_from_prompt(prompt_text): """プロンプト(query)からターゲットフォーマットを推定""" prompt_lower = prompt_text.lower() # 明示的な指示を検索 patterns = { "JSON": [r"output\s+json", r"to\s+json", r"in\s+json", r"json\s+code", r"json\s+format"], "YAML": [r"output\s+yaml", r"to\s+yaml", r"in\s+yaml", r"yaml\s+code", r"yaml\s+format"], "TOML": [r"output\s+toml", r"to\s+toml", r"in\s+toml", r"toml\s+code", r"toml\s+format"], "XML": [r"output\s+xml", r"to\s+xml", r"in\s+xml", r"xml\s+code", r"xml\s+format"], "CSV": [r"output\s+csv", r"to\s+csv", r"in\s+csv", r"csv\s+code", r"csv\s+format"], } for fmt, pats in patterns.items(): for pat in pats: if re.search(pat, prompt_lower): return fmt # タスク名パターン (e.g., "Text to JSON", "CSV to YAML") task_pattern = r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)" match = re.search(task_pattern, prompt_lower) if match: return match.group(2).upper() return None def analyze_dataset(dataset_id): """HuggingFaceデータセットを分析""" from datasets import load_dataset print(f"📥 データセットをダウンロード中: {dataset_id}") ds = load_dataset(dataset_id, split="train") print(f"✅ ダウンロード完了: {len(ds)} 件\n") # messages構造を解析 format_from_output = Counter() format_from_prompt = Counter() task_types = Counter() cot_count = 0 samples_by_format = defaultdict(list) for i, row in enumerate(ds): messages = row.get("messages", []) # messagesからuser/assistantを抽出 user_msg = "" assistant_msg = "" has_cot = False for msg in messages: role = msg.get("role", "") content = msg.get("content", "") if role == "user": user_msg = content elif role == "assistant": assistant_msg = content if "" in content or "" in content: has_cot = True if has_cot: cot_count += 1 # CoT部分を除去してアシスタントの最終出力を取得 final_output = assistant_msg think_match = re.search(r"\s*(.*)", assistant_msg, re.DOTALL) if think_match: final_output = think_match.group(1).strip() # 出力フォーマットを判定(2つの方法) fmt_output = detect_format(final_output) fmt_prompt = detect_format_from_prompt(user_msg) format_from_output[fmt_output] += 1 if fmt_prompt: format_from_prompt[fmt_prompt] += 1 else: format_from_prompt["UNKNOWN"] += 1 # タスクタイプ推定 task_match = re.search(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)", user_msg.lower()) if task_match: task_type = f"{task_match.group(1).upper()} to {task_match.group(2).upper()}" elif "please output" in user_msg.lower(): task_type = f"Text to {fmt_prompt or fmt_output}" else: task_type = "OTHER" task_types[task_type] += 1 # サンプル保存(各フォーマット最大2件) fmt_key = fmt_prompt or fmt_output if len(samples_by_format[fmt_key]) < 2: samples_by_format[fmt_key].append({ "index": i, "prompt_preview": user_msg[:150], "output_preview": final_output[:150], }) # --- 結果出力 --- total = len(ds) print("=" * 70) print(f"📊 データセット分析結果: {dataset_id}") print(f" 総サンプル数: {total}") print(f" CoTあり: {cot_count} ({cot_count/total*100:.1f}%)") print("=" * 70) print(f"\n📋 ターゲットフォーマット分布(プロンプトから判定):") print(f"{'Format':<12} {'Count':>6} {'Percent':>8}") print("-" * 30) for fmt in ["JSON", "YAML", "TOML", "XML", "CSV", "UNKNOWN"]: count = format_from_prompt.get(fmt, 0) pct = f"{count/total*100:.1f}%" bar = "█" * int(count/total*50) print(f"{fmt:<12} {count:>6} {pct:>8} {bar}") print(f"\n📋 出力フォーマット分布(出力内容から判定):") print(f"{'Format':<12} {'Count':>6} {'Percent':>8}") print("-" * 30) for fmt, count in format_from_output.most_common(): pct = f"{count/total*100:.1f}%" bar = "█" * int(count/total*50) print(f"{fmt:<12} {count:>6} {pct:>8} {bar}") print(f"\n📋 タスクタイプ分布:") print(f"{'Task Type':<25} {'Count':>6} {'Percent':>8}") print("-" * 45) for task, count in task_types.most_common(20): pct = f"{count/total*100:.1f}%" print(f"{task:<25} {count:>6} {pct:>8}") # public_150との比較 print(f"\n📋 public_150.json との比較(参考):") print(f"{'Format':<8} {'public_150':>12} {'dataset':>12} {'充足度':>10}") print("-" * 45) public_counts = {"JSON": 50, "YAML": 35, "TOML": 25, "XML": 20, "CSV": 20} for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]: pub = public_counts[fmt] ds_count = format_from_prompt.get(fmt, 0) ratio = f"{ds_count/pub:.1f}x" if pub > 0 else "N/A" print(f"{fmt:<8} {pub:>12} {ds_count:>12} {ratio:>10}") print(f"\n📋 各フォーマットのサンプル:") for fmt in ["JSON", "YAML", "TOML", "XML", "CSV"]: samples = samples_by_format.get(fmt, []) print(f"\n--- {fmt} サンプル ({len(samples)}件) ---") for s in samples: print(f" [#{s['index']}] prompt: {s['prompt_preview'][:100]}") print(f" output: {s['output_preview'][:100]}") if __name__ == "__main__": if len(sys.argv) > 1: dataset_id = sys.argv[1] else: dataset_id = "u-10bei/structured_data_with_cot_dataset_512_v4" analyze_dataset(dataset_id)