Spaces:
Running
Running
| """ | |
| Dataset Format Analyzer | |
| SFTデータセットのフォーマット分布を分析するスクリプト | |
| 指定されたHuggingFaceデータセットをダウンロードし、 | |
| 各サンプルのターゲット出力がどのフォーマット(JSON/YAML/TOML/XML/CSV) | |
| であるかを判定・集計します。 | |
| """ | |
| import json | |
| import re | |
| import sys | |
| from collections import Counter, defaultdict | |
# TOML table header ("[server]") or array-of-tables header ("[[servers]]").
_TOML_HEADER_RE = re.compile(r"^\[\[?[\w.\-]+\]\]?")
# TOML "key = value" assignment at the very start of the text.
_TOML_ASSIGN_RE = re.compile(r"^[\w.\-]+\s*=\s*")
# YAML "key: value" mapping entry at the very start of the text.
_YAML_KEY_RE = re.compile(r"^[\w\-]+:\s")
# Markdown code fence such as "```json" plus an optional trailing newline.
_FENCE_RE = re.compile(r"```\w*\n?")


def detect_format(text):
    """Guess the structured-data format of *text* using simple heuristics.

    Markdown code fences are removed first; the remaining text is then
    checked in this order: JSON, XML, CSV, TOML, YAML.

    Args:
        text: The model output to classify. ``None`` is treated as empty.

    Returns:
        One of ``"JSON"``, ``"XML"``, ``"CSV"``, ``"TOML"``, ``"YAML"``,
        ``"EMPTY"`` (nothing left after removing fences and whitespace) or
        ``"OTHER"``.
    """
    cleaned = _FENCE_RE.sub("", (text or "").strip()).strip()
    if not cleaned:
        return "EMPTY"

    # JSON: starts with "{" or "[".
    if cleaned.startswith(("{", "[")):
        try:
            json.loads(cleaned)
        except (ValueError, RecursionError):
            # A leading "[" may also be a TOML table header ("[server]").
            # Without this check TOML documents opening with a table would
            # always be reported as (broken) JSON.
            if _TOML_HEADER_RE.match(cleaned):
                return "TOML"
        # Valid JSON, or JSON-looking text that is malformed (e.g. truncated).
        return "JSON"

    # XML: starts with "<" (<?xml ...?> or a tag).
    if cleaned.startswith("<"):
        return "XML"

    # CSV: several lines, each of the first (up to 5) non-blank lines having
    # the same, non-zero number of commas.
    lines = cleaned.split("\n")
    if len(lines) >= 2:
        comma_counts = [line.count(",") for line in lines[:5] if line.strip()]
        if comma_counts and comma_counts[0] > 0 and all(
            c == comma_counts[0] for c in comma_counts
        ):
            return "CSV"

    # TOML: "key = value" (table headers were handled in the JSON branch).
    if _TOML_ASSIGN_RE.match(cleaned):
        return "TOML"

    # YAML: "key: value", document marker, or block sequence item.
    if _YAML_KEY_RE.match(cleaned) or cleaned.startswith(("---", "- ")):
        return "YAML"

    return "OTHER"
# Explicit target-format phrases; formats are tried in this dict's order.
_PROMPT_FORMAT_PATTERNS = {
    "JSON": [r"output\s+json", r"to\s+json", r"in\s+json", r"json\s+code", r"json\s+format"],
    "YAML": [r"output\s+yaml", r"to\s+yaml", r"in\s+yaml", r"yaml\s+code", r"yaml\s+format"],
    "TOML": [r"output\s+toml", r"to\s+toml", r"in\s+toml", r"toml\s+code", r"toml\s+format"],
    "XML": [r"output\s+xml", r"to\s+xml", r"in\s+xml", r"xml\s+code", r"xml\s+format"],
    "CSV": [r"output\s+csv", r"to\s+csv", r"in\s+csv", r"csv\s+code", r"csv\s+format"],
}
# Task-name pattern such as "Text to JSON" or "CSV to YAML"; group(2) is the target.
_TASK_NAME_RE = re.compile(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)")


def detect_format_from_prompt(prompt_text):
    """Infer the requested output format from a user prompt.

    The explicit "<source> to <target>" task name is checked first because
    it names the target unambiguously. Otherwise the first format (in
    JSON, YAML, TOML, XML, CSV order) with a matching phrase such as
    "output json" or "yaml format" wins. Matching is case-insensitive.

    Args:
        prompt_text: The user message. ``None`` or ``""`` yields ``None``.

    Returns:
        The upper-case format name, or ``None`` if nothing matched.
    """
    if not prompt_text:
        return None
    prompt_lower = prompt_text.lower()

    # Checked first on purpose: if this ran after the phrase loop it could
    # never fire, since any match also matches a "to <target>" phrase, and an
    # unrelated earlier-priority mention (e.g. "json code") would win instead.
    match = _TASK_NAME_RE.search(prompt_lower)
    if match:
        return match.group(2).upper()

    for fmt, pats in _PROMPT_FORMAT_PATTERNS.items():
        if any(re.search(pat, prompt_lower) for pat in pats):
            return fmt
    return None
# Task-name pattern ("Text to JSON", "CSV to YAML", ...) used for task-type stats.
_TASK_TYPE_RE = re.compile(r"(text|json|yaml|toml|xml|csv)\s+to\s+(json|yaml|toml|xml|csv)")
# Captures everything after the first closing </think> tag.
_THINK_END_RE = re.compile(r"</think>\s*(.*)", re.DOTALL)
# Reference per-format counts for public_150.json (comparison table only).
_PUBLIC_150_COUNTS = {"JSON": 50, "YAML": 35, "TOML": 25, "XML": 20, "CSV": 20}
# Formats shown in the per-format tables, in display order.
_REPORT_FORMATS = ["JSON", "YAML", "TOML", "XML", "CSV"]


def _split_messages(messages):
    """Return (last user content, last assistant content, has_cot) for one chat."""
    user_msg = ""
    assistant_msg = ""
    has_cot = False
    for msg in messages or []:
        role = msg.get("role", "")
        # "content" may be present but None; treat that as empty text.
        content = msg.get("content") or ""
        if role == "user":
            user_msg = content
        elif role == "assistant":
            assistant_msg = content
            if "<think>" in content or "</think>" in content:
                has_cot = True
    return user_msg, assistant_msg, has_cot


def _strip_cot(assistant_msg):
    """Return the text after the first </think> tag, or the message unchanged."""
    match = _THINK_END_RE.search(assistant_msg)
    return match.group(1).strip() if match else assistant_msg


def _print_format_row(label, count, total):
    """Print one 'label / count / percent / bar' row of a distribution table."""
    pct = f"{count/total*100:.1f}%"
    bar = "█" * int(count/total*50)
    print(f"{label:<12} {count:>6} {pct:>8} {bar}")


def analyze_dataset(dataset_id, split="train"):
    """Download a Hugging Face dataset and print its output-format statistics.

    For each row, the last user and assistant messages are read from the
    ``messages`` column. The assistant's final answer (text after the first
    ``</think>``, if any) is classified with ``detect_format`` and the prompt
    with ``detect_format_from_prompt``. Distribution tables, task types, a
    comparison with public_150.json and a few samples are printed to stdout.

    Args:
        dataset_id: Hugging Face dataset repository id.
        split: Split to load (defaults to ``"train"``).
    """
    from datasets import load_dataset

    print(f"📥 データセットをダウンロード中: {dataset_id}")
    ds = load_dataset(dataset_id, split=split)
    total = len(ds)
    print(f"✅ ダウンロード完了: {total} 件\n")
    if total == 0:
        # Every percentage below divides by the row count.
        print("⚠️ サンプルが0件のため分析をスキップします")
        return

    format_from_output = Counter()
    format_from_prompt = Counter()
    task_types = Counter()
    cot_count = 0
    samples_by_format = defaultdict(list)

    for i, row in enumerate(ds):
        user_msg, assistant_msg, has_cot = _split_messages(row.get("messages", []))
        if has_cot:
            cot_count += 1

        # Classify only the final answer, not the chain-of-thought.
        final_output = _strip_cot(assistant_msg)

        # Two independent format guesses: from the output and from the prompt.
        fmt_output = detect_format(final_output)
        fmt_prompt = detect_format_from_prompt(user_msg)
        format_from_output[fmt_output] += 1
        format_from_prompt[fmt_prompt or "UNKNOWN"] += 1

        # Task type from the "<source> to <target>" task name, if present.
        user_lower = user_msg.lower()
        task_match = _TASK_TYPE_RE.search(user_lower)
        if task_match:
            task_type = f"{task_match.group(1).upper()} to {task_match.group(2).upper()}"
        elif "please output" in user_lower:
            task_type = f"Text to {fmt_prompt or fmt_output}"
        else:
            task_type = "OTHER"
        task_types[task_type] += 1

        # Keep at most two preview samples per format.
        fmt_key = fmt_prompt or fmt_output
        if len(samples_by_format[fmt_key]) < 2:
            samples_by_format[fmt_key].append({
                "index": i,
                "prompt_preview": user_msg[:150],
                "output_preview": final_output[:150],
            })

    # --- Report ---
    print("=" * 70)
    print(f"📊 データセット分析結果: {dataset_id}")
    print(f" 総サンプル数: {total}")
    print(f" CoTあり: {cot_count} ({cot_count/total*100:.1f}%)")
    print("=" * 70)

    print("\n📋 ターゲットフォーマット分布(プロンプトから判定):")
    print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
    print("-" * 30)
    for fmt in [*_REPORT_FORMATS, "UNKNOWN"]:
        _print_format_row(fmt, format_from_prompt.get(fmt, 0), total)

    print("\n📋 出力フォーマット分布(出力内容から判定):")
    print(f"{'Format':<12} {'Count':>6} {'Percent':>8}")
    print("-" * 30)
    for fmt, count in format_from_output.most_common():
        _print_format_row(fmt, count, total)

    print("\n📋 タスクタイプ分布:")
    print(f"{'Task Type':<25} {'Count':>6} {'Percent':>8}")
    print("-" * 45)
    for task, count in task_types.most_common(20):
        pct = f"{count/total*100:.1f}%"
        print(f"{task:<25} {count:>6} {pct:>8}")

    # Comparison with the public_150 reference distribution.
    print("\n📋 public_150.json との比較(参考):")
    print(f"{'Format':<8} {'public_150':>12} {'dataset':>12} {'充足度':>10}")
    print("-" * 45)
    for fmt in _REPORT_FORMATS:
        pub = _PUBLIC_150_COUNTS[fmt]
        ds_count = format_from_prompt.get(fmt, 0)
        ratio = f"{ds_count/pub:.1f}x" if pub > 0 else "N/A"
        print(f"{fmt:<8} {pub:>12} {ds_count:>12} {ratio:>10}")

    print("\n📋 各フォーマットのサンプル:")
    for fmt in _REPORT_FORMATS:
        samples = samples_by_format.get(fmt, [])
        print(f"\n--- {fmt} サンプル ({len(samples)}件) ---")
        for s in samples:
            print(f" [#{s['index']}] prompt: {s['prompt_preview'][:100]}")
            print(f" output: {s['output_preview'][:100]}")
if __name__ == "__main__":
    # Usage: python <script> [dataset_id] — the first CLI argument, when
    # given, overrides the built-in default dataset id.
    cli_args = sys.argv[1:]
    dataset_id = cli_args[0] if cli_args else "u-10bei/structured_data_with_cot_dataset_512_v4"
    analyze_dataset(dataset_id)