#!/usr/bin/env python3 """AI Hub 감정 태깅 자유대화(성인) 데이터셋 → 벤치마크 테스트 서브셋 준비. AI Hub #71631 데이터셋의 JSON 라벨 + 스테레오 WAV에서 발화 단위를 추출하여 균형 잡힌 6-class 테스트셋을 생성한다. Usage: python scripts/prepare_aihub_test_subset.py --aihub-dir data/samples python scripts/prepare_aihub_test_subset.py --aihub-dir /path/to/full/dataset --samples-per-class 83 """ from __future__ import annotations import argparse import csv import json import logging import os import random import sys from collections import defaultdict from pathlib import Path import numpy as np import soundfile as sf logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s") logger = logging.getLogger(__name__) # ────────────────────────────────────────────── # Label Mapping: AI Hub 한국어 → Project 6-class # ────────────────────────────────────────────── AIHUB_LABEL_MAP = { "기쁨": "joy", "놀라움": "surprise", "두려움": "fear", "사랑스러움": "joy", # Affection → joy (user confirmed) "슬픔": "sadness", "화남": "anger", "없음": "neutral", "중립": "neutral", # appears in SpeakerEmotionCategory } EVAL_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear"] # Minimum utterance duration (seconds) — too short = unreliable emotion MIN_DURATION_SEC = 0.5 # Maximum utterance duration — cap very long utterances MAX_DURATION_SEC = 30.0 # ────────────────────────────────────────────── # Step 1: Parse AI Hub JSON + discover WAV pairs # ────────────────────────────────────────────── def discover_pairs(aihub_dir: str) -> list[tuple[Path, Path]]: """Find matched WAV-JSON file pairs in AI Hub directory structure. 
Expected structure: aihub_dir/01.원천데이터/{01.실내,02.실외}/xxx.wav aihub_dir/02.라벨링데이터/{01.실내,02.실외}/xxx.json """ source_dir = Path(aihub_dir) / "01.원천데이터" label_dir = Path(aihub_dir) / "02.라벨링데이터" if not source_dir.exists() or not label_dir.exists(): logger.error("Expected 01.원천데이터 and 02.라벨링데이터 under %s", aihub_dir) sys.exit(1) # Build WAV lookup: stem → path wav_lookup = {} for wav_path in source_dir.rglob("*.wav"): if str(wav_path).endswith(":Zone.Identifier"): continue wav_lookup[wav_path.stem] = wav_path # Match JSON → WAV pairs = [] for json_path in label_dir.rglob("*.json"): if str(json_path).endswith(":Zone.Identifier"): continue stem = json_path.stem wav_path = wav_lookup.get(stem) if wav_path: pairs.append((wav_path, json_path)) else: logger.warning("No WAV match for %s", json_path.name) logger.info("Discovered %d WAV-JSON pairs", len(pairs)) return pairs def parse_utterances(pairs: list[tuple[Path, Path]]) -> list[dict]: """Parse all utterances from JSON label files. Uses VerifyEmotionTarget as ground truth (annotator-verified label). 
""" utterances = [] for wav_path, json_path in pairs: with open(json_path, encoding="utf-8") as f: data = json.load(f) wav_info = data.get("Wav", {}) file_info = data.get("File", {}) sr = int(wav_info.get("SamplingRate", 16000)) n_channels = int(wav_info.get("NumberOfChannel", 2)) # Speaker info speakers = {} for key in ("Speaker1", "Speaker2"): spk = data.get(key, {}) speakers[key] = { "id": spk.get("ID", ""), "gender": spk.get("Gender", ""), "age": spk.get("Age", ""), } for utt in data.get("Conversation", []): emotion_kr = utt.get("VerifyEmotionTarget", "").strip() emotion_en = AIHUB_LABEL_MAP.get(emotion_kr) if emotion_en is None: continue # Unknown label, skip if emotion_en not in EVAL_LABELS: continue try: start = float(str(utt["StartTime"]).replace(",", "")) end = float(str(utt["EndTime"]).replace(",", "")) except (KeyError, ValueError): continue duration = end - start if duration < MIN_DURATION_SEC or duration > MAX_DURATION_SEC: continue speaker_no = utt.get("SpeakerNo", "Speaker1") speaker_info = speakers.get(speaker_no, {}) # Determine which channel to extract (0-indexed) # Speaker1 = left channel (0), Speaker2 = right channel (1) channel = 0 if speaker_no == "Speaker1" else 1 if n_channels == 1: channel = 0 utterances.append({ "wav_path": str(wav_path), "json_path": str(json_path), "file_stem": wav_path.stem, "text_no": utt.get("TextNo", ""), "text": utt.get("Text", ""), "start": start, "end": end, "duration": duration, "emotion": emotion_en, "emotion_kr": emotion_kr, "intensity": utt.get("VerifyEmotionLevel", ""), "speaker_no": speaker_no, "speaker_id": speaker_info.get("id", ""), "speaker_gender": speaker_info.get("gender", ""), "speaker_age": speaker_info.get("age", ""), "channel": channel, "sample_rate": sr, "n_channels": n_channels, }) logger.info("Parsed %d valid utterances across %d files", len(utterances), len(pairs)) return utterances # ────────────────────────────────────────────── # Step 2: Balanced sampling # 
def _duration_stratified(pool: list[dict], target: int, rng: random.Random) -> list[dict]:
    """Draw ``target`` utterances from ``pool`` at a ~30/50/20 short/medium/long mix.

    Buckets: short < 3 s, medium 3–10 s, long >= 10 s. Any shortfall from a
    thin bucket is filled at random from the rest of ``pool``. Assumes
    ``target >= 1`` and ``len(pool) > target``.
    """
    short = [u for u in pool if u["duration"] < 3.0]
    medium = [u for u in pool if 3.0 <= u["duration"] < 10.0]
    long_ = [u for u in pool if u["duration"] >= 10.0]

    n_short = max(1, int(target * 0.3))
    n_long = max(1, int(target * 0.2))
    # Clamp: for target == 1 the two minimum quotas already exceed the target,
    # and a negative slice bound would mean "all but the last" items.
    n_medium = max(0, target - n_short - n_long)

    chosen: list[dict] = []
    # The shuffle sequence (short, medium, long, then fill) is kept fixed so a
    # given seed keeps producing the same subset.
    for bucket, n in ((short, n_short), (medium, n_medium), (long_, n_long)):
        rng.shuffle(bucket)
        chosen.extend(bucket[:n])

    # Fill remaining if any bucket was short
    if len(chosen) < target:
        taken = {id(u) for u in chosen}  # identity, not dict equality: O(1) and exact
        remaining = [u for u in pool if id(u) not in taken]
        rng.shuffle(remaining)
        chosen.extend(remaining[: target - len(chosen)])

    return chosen[:target]


def balanced_sample(
    utterances: list[dict],
    samples_per_class: int,
    seed: int = 42,
) -> list[dict]:
    """Stratified balanced sampling: target ``samples_per_class`` per emotion.

    For each label in EVAL_LABELS (in that order):
      - Classes with at most ``samples_per_class`` utterances are taken in
        full, so rare classes (e.g. fear) may end up below target.
      - Larger classes are sampled with a duration mix of roughly
        30% short (< 3 s), 50% medium (3–10 s), and 20% long (>= 10 s).

    Speakers are not balanced explicitly; any speaker spread comes only from
    random selection. Results are deterministic for a given ``seed`` and
    input order. A non-positive ``samples_per_class`` selects nothing.

    Args:
        utterances: utterance dicts with at least ``emotion`` and ``duration``.
        samples_per_class: per-emotion target count.
        seed: seed for the local ``random.Random`` instance.

    Returns:
        The selected utterance dicts (same objects as the input), grouped by
        emotion in EVAL_LABELS order.
    """
    rng = random.Random(seed)

    # Group by emotion
    by_emotion: dict[str, list[dict]] = defaultdict(list)
    for utt in utterances:
        by_emotion[utt["emotion"]].append(utt)

    selected = []
    stats = {}

    for emotion in EVAL_LABELS:
        pool = by_emotion.get(emotion, [])
        if not pool:
            logger.warning("No samples for emotion '%s'", emotion)
            stats[emotion] = 0
            continue

        if samples_per_class <= 0:
            chosen = []
        elif len(pool) <= samples_per_class:
            chosen = list(pool)  # Take all for rare classes
        else:
            chosen = _duration_stratified(pool, samples_per_class, rng)

        selected.extend(chosen)
        stats[emotion] = len(chosen)

    logger.info("Sampling result: %s (total: %d)", stats, len(selected))
    return selected


# ──────────────────────────────────────────────
# Step 3: Extract utterance WAVs
# ──────────────────────────────────────────────

def _read_segment(wav_path: str, start: float, end: float, channel: int) -> tuple[np.ndarray, int]:
    """Read the [start, end) second range of one channel from ``wav_path``.

    Only the needed frames are decoded, so long conversation recordings are
    never held in memory in full. The range is clamped to the file length;
    an empty array is returned when nothing remains.

    Returns:
        ``(segment, sample_rate)`` where ``segment`` is a 1-D float32 array.
    """
    with sf.SoundFile(wav_path) as f:
        sr = f.samplerate
        start_frame = max(0, int(start * sr))
        end_frame = min(f.frames, int(end * sr))
        if end_frame <= start_frame:
            return np.zeros(0, dtype="float32"), sr
        f.seek(start_frame)
        audio = f.read(end_frame - start_frame, dtype="float32")

    # Multichannel reads are 2-D (frames, channels); mono reads are 1-D.
    if audio.ndim == 2:
        return audio[:, min(channel, audio.shape[1] - 1)], sr
    return audio, sr


def extract_utterances(
    selected: list[dict],
    output_dir: str,
) -> list[dict]:
    """Extract individual utterance WAV segments from conversation files.

    For each selected utterance this reads only its time range from the
    source WAV and keeps the speaker's channel. It resamples to 16 kHz when
    needed and writes a mono 16-bit PCM file to
    ``<output_dir>/test_audio/<emotion>/kr_<emotion>_<index>.wav``.
    ``<index>`` is the position in ``selected``.

    Unreadable source files, and segments shorter than MIN_DURATION_SEC
    after clamping, are logged and skipped. Skipped items leave gaps in
    the file numbering.

    Args:
        selected: utterance dicts as produced by ``balanced_sample``.
        output_dir: output root directory.

    Returns:
        One record per written file (path relative to ``output_dir`` plus
        label/speaker metadata), in ``selected`` order.
    """
    target_sr = 16000
    out_path = Path(output_dir)
    records = []

    for i, utt in enumerate(selected):
        emotion = utt["emotion"]
        emotion_dir = out_path / "test_audio" / emotion
        emotion_dir.mkdir(parents=True, exist_ok=True)

        wav_path = utt["wav_path"]
        try:
            segment, sr = _read_segment(wav_path, utt["start"], utt["end"], utt["channel"])
        except Exception as e:  # best-effort: skip an unreadable file, keep going
            logger.warning("Failed to read %s: %s", wav_path, e)
            continue

        if len(segment) < int(MIN_DURATION_SEC * sr):
            logger.warning("Segment too short after extraction: %s_%s", utt["file_stem"], utt["text_no"])
            continue

        # Resample to 16kHz if needed
        if sr != target_sr:
            import librosa  # lazy: only needed for non-16 kHz sources
            segment = librosa.resample(segment, orig_sr=sr, target_sr=target_sr)

        # Resampling can overshoot [-1, 1], and libsndfile does not clip
        # float→integer conversion by default, so clip before writing PCM_16.
        segment = np.clip(segment, -1.0, 1.0)

        # Save
        filename = f"kr_{emotion}_{i:04d}.wav"
        filepath = emotion_dir / filename
        sf.write(str(filepath), segment, target_sr, subtype="PCM_16")

        records.append({
            "file_path": str(filepath.relative_to(out_path)),
            "emotion": emotion,
            "duration": round(len(segment) / target_sr, 3),
            "speaker_id": utt["speaker_id"],
            "speaker_gender": utt["speaker_gender"],
            "intensity": utt["intensity"],
            "text": utt["text"],
            "source_file": utt["file_stem"],
        })

        if (i + 1) % 100 == 0:
            logger.info("Extracted %d/%d utterances", i + 1, len(selected))

    logger.info("Extracted %d utterance WAVs to %s", len(records), output_dir)
    return records


# ──────────────────────────────────────────────
# Step 4: Write labels CSV + metadata JSON
# ──────────────────────────────────────────────
def write_outputs(records: list[dict], output_dir: str, utterances: list[dict]):
    """Write ``test_labels.csv`` and ``metadata.json`` under ``output_dir``.

    Args:
        records: per-file records (``file_path``, ``emotion``, ``duration``,
            ``speaker_id``, ``speaker_gender``, ``intensity``, ``text``,
            ``source_file``).
        output_dir: output root; created if it does not exist yet.
        utterances: the full pre-sampling utterance list; only its length is
            recorded (as ``total_source_utterances``).
    """
    from collections import Counter

    out_path = Path(output_dir)
    # The root may not exist yet (e.g. when no utterance was extracted).
    out_path.mkdir(parents=True, exist_ok=True)

    # CSV
    csv_path = out_path / "test_labels.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "file_path", "emotion", "duration", "speaker_id",
            "speaker_gender", "intensity", "text", "source_file",
        ])
        writer.writeheader()
        writer.writerows(records)
    logger.info("Wrote %s (%d records)", csv_path, len(records))

    # Metadata
    emotion_dist = Counter(r["emotion"] for r in records)
    duration_stats = [r["duration"] for r in records]
    intensity_dist = Counter(r["intensity"] for r in records)

    metadata = {
        "dataset": "AI Hub #71631 - 감정이 태깅된 자유대화 (성인)",
        "subset": "test",
        "total_samples": len(records),
        "eval_classes": EVAL_LABELS,
        "label_mapping": AIHUB_LABEL_MAP,
        "emotion_distribution": dict(emotion_dist),
        "intensity_distribution": dict(intensity_dist),
        "duration_stats": {
            "mean": round(sum(duration_stats) / max(len(duration_stats), 1), 2),
            "min": round(min(duration_stats, default=0), 2),
            "max": round(max(duration_stats, default=0), 2),
        },
        "total_source_utterances": len(utterances),
        "note": "disgust class absent from AI Hub dataset — 6-class evaluation",
    }

    meta_path = out_path / "metadata.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    logger.info("Wrote %s", meta_path)


# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────

def main():
    """CLI entry point: discover pairs → parse → sample → extract WAVs → write outputs."""
    parser = argparse.ArgumentParser(
        description="AI Hub 감정 데이터셋 → 벤치마크 테스트 서브셋 준비",
    )
    parser.add_argument("--aihub-dir", required=True,
                        help="AI Hub 데이터 루트 (01.원천데이터, 02.라벨링데이터 포함)")
    parser.add_argument("--output-dir", default="data/evaluation/korean",
                        help="출력 디렉토리")
    parser.add_argument("--samples-per-class", type=int, default=83,
                        help="클래스당 목표 샘플 수 (default: 83)")
    parser.add_argument("--seed", type=int, default=42, help="랜덤 시드")
    parser.add_argument("--ground-truth", default="verify", choices=["verify", "speaker"],
                        help="Ground truth 소스: verify=검증자 라벨, speaker=화자 자기보고")
    args = parser.parse_args()

    if args.samples_per_class < 1:
        parser.error("--samples-per-class must be >= 1")
    # Utterance parsing reads only the annotator-verified label
    # (VerifyEmotionTarget). Accepting "speaker" would silently build a
    # verify-labelled test set under the wrong ground-truth name.
    if args.ground_truth != "verify":
        parser.error("--ground-truth speaker is not implemented yet; "
                     "labels are always taken from VerifyEmotionTarget")

    # 1. Discover pairs
    pairs = discover_pairs(args.aihub_dir)
    if not pairs:
        logger.error("No WAV-JSON pairs found")
        sys.exit(1)

    # 2. Parse utterances
    utterances = parse_utterances(pairs)
    if not utterances:
        logger.error("No valid utterances parsed")
        sys.exit(1)

    # Log distribution before sampling
    from collections import Counter
    raw_dist = Counter(u["emotion"] for u in utterances)
    logger.info("Raw distribution: %s", dict(raw_dist))

    # 3. Balanced sampling
    selected = balanced_sample(utterances, args.samples_per_class, seed=args.seed)

    # 4. Extract WAVs
    records = extract_utterances(selected, args.output_dir)

    # 5. Write outputs
    write_outputs(records, args.output_dir, utterances)

    print(f"\nDone! Test subset ready at {args.output_dir}/")
    print(f"  - {len(records)} utterance WAVs in test_audio/")
    print("  - test_labels.csv")
    print("  - metadata.json")


if __name__ == "__main__":
    main()