"""Prepare a benchmark test subset from the AI Hub emotion-tagged free conversation (adult) dataset.

Extracts utterance-level segments from the JSON labels and stereo WAV
recordings of AI Hub dataset #71631 and builds a balanced 6-class test set.

Usage:
    python scripts/prepare_aihub_test_subset.py --aihub-dir data/samples
    python scripts/prepare_aihub_test_subset.py --aihub-dir /path/to/full/dataset --samples-per-class 83
"""
|
|
from __future__ import annotations

import argparse
import csv
import json
import logging
import os
import random
import sys
from collections import Counter, defaultdict
from pathlib import Path

import numpy as np
import soundfile as sf
|
|
# Configure logging once at import time so every pipeline stage reports with a
# timestamp and level; all functions in this module log through `logger`.
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
logger = logging.getLogger(__name__)
|
|
| |
| |
| |
|
|
# Korean emotion labels found in the AI Hub annotation JSON, mapped to the
# English class names used for evaluation. Two source labels collapse into
# "joy" and two into "neutral", which yields the six classes in EVAL_LABELS.
AIHUB_LABEL_MAP = {
    "기쁨": "joy",
    "놀라움": "surprise",
    "두려움": "fear",
    "사랑스러움": "joy",
    "슬픔": "sadness",
    "화남": "anger",
    "없음": "neutral",
    "중립": "neutral",
}

# Evaluation classes, in the order they are sampled and reported.
EVAL_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear"]

# Utterances shorter than this many seconds are discarded.
MIN_DURATION_SEC = 0.5
# Utterances longer than this many seconds are discarded.
MAX_DURATION_SEC = 30.0
|
|
|
|
| |
| |
| |
|
|
def discover_pairs(aihub_dir: str) -> list[tuple[Path, Path]]:
    """Find matched WAV/JSON file pairs in an AI Hub directory tree.

    Expected structure:
        aihub_dir/01.원천데이터/{01.실내,02.실외}/xxx.wav
        aihub_dir/02.라벨링데이터/{01.실내,02.실외}/xxx.json

    Files are matched by stem (file name without extension), regardless of
    which sub-directory they sit in.

    Args:
        aihub_dir: Dataset root containing the two directories above.

    Returns:
        ``(wav_path, json_path)`` tuples, sorted by JSON path so that the
        result (and any seeded sampling done on it later) does not depend on
        filesystem enumeration order.

    Raises:
        SystemExit: If either expected top-level directory is missing.
    """
    root = Path(aihub_dir)
    source_dir = root / "01.원천데이터"
    label_dir = root / "02.라벨링데이터"

    if not source_dir.exists() or not label_dir.exists():
        logger.error("Expected 01.원천데이터 and 02.라벨링데이터 under %s", aihub_dir)
        sys.exit(1)

    # Index every WAV by stem. Sorting makes the "last one wins" rule for
    # duplicate stems deterministic instead of filesystem-dependent.
    wav_lookup: dict[str, Path] = {}
    for wav_path in sorted(source_dir.rglob("*.wav")):
        # Defensive: skip Windows "Zone.Identifier" sidecar entries.
        if str(wav_path).endswith(":Zone.Identifier"):
            continue
        previous = wav_lookup.get(wav_path.stem)
        if previous is not None:
            logger.warning(
                "Duplicate WAV stem %s: using %s instead of %s",
                wav_path.stem, wav_path, previous,
            )
        wav_lookup[wav_path.stem] = wav_path

    # Pair each label file with the WAV that shares its stem.
    pairs = []
    for json_path in sorted(label_dir.rglob("*.json")):
        if str(json_path).endswith(":Zone.Identifier"):
            continue
        wav_path = wav_lookup.get(json_path.stem)
        if wav_path is None:
            logger.warning("No WAV match for %s", json_path.name)
            continue
        pairs.append((wav_path, json_path))

    logger.info("Discovered %d WAV-JSON pairs", len(pairs))
    return pairs
|
|
|
|
def parse_utterances(pairs: list[tuple[Path, Path]]) -> list[dict]:
    """Parse all usable utterances from the JSON label files.

    The ``VerifyEmotionTarget`` field of each conversation entry is used as
    the ground-truth label. An utterance is kept only if its label maps to
    one of ``EVAL_LABELS``, its start/end times parse as numbers, and its
    duration lies within ``[MIN_DURATION_SEC, MAX_DURATION_SEC]``.

    Args:
        pairs: ``(wav_path, json_path)`` tuples, e.g. from ``discover_pairs``.

    Returns:
        One dict per kept utterance with source paths, timing, the English
        and Korean label, speaker metadata and the audio channel to read.
        Label files that cannot be read or parsed are skipped with a warning.
    """
    utterances = []

    for wav_path, json_path in pairs:
        # One corrupt label file must not abort the whole preparation run.
        try:
            with open(json_path, encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            logger.warning("Skipping unreadable label file %s: %s", json_path, exc)
            continue
        if not isinstance(data, dict):
            logger.warning("Skipping label file with unexpected layout: %s", json_path)
            continue

        # `or` fallbacks also cover keys that are present but null.
        wav_info = data.get("Wav") or {}
        sr = int(wav_info.get("SamplingRate") or 16000)
        n_channels = int(wav_info.get("NumberOfChannel") or 2)

        # Per-conversation speaker metadata, looked up by "SpeakerNo" below.
        speakers = {}
        for key in ("Speaker1", "Speaker2"):
            spk = data.get(key) or {}
            speakers[key] = {
                "id": spk.get("ID", ""),
                "gender": spk.get("Gender", ""),
                "age": spk.get("Age", ""),
            }

        for utt in data.get("Conversation") or []:
            emotion_kr = str(utt.get("VerifyEmotionTarget") or "").strip()
            emotion_en = AIHUB_LABEL_MAP.get(emotion_kr)
            if emotion_en is None or emotion_en not in EVAL_LABELS:
                continue

            # Timestamps may carry thousands separators; strip them first.
            try:
                start = float(str(utt["StartTime"]).replace(",", ""))
                end = float(str(utt["EndTime"]).replace(",", ""))
            except (KeyError, ValueError):
                continue

            duration = end - start
            # Written as a range test so a NaN duration is rejected too
            # (NaN compares false against both bounds).
            if not (MIN_DURATION_SEC <= duration <= MAX_DURATION_SEC):
                continue

            speaker_no = utt.get("SpeakerNo", "Speaker1")
            speaker_info = speakers.get(speaker_no, {})

            # Speaker1 is read from channel 0, any other speaker from
            # channel 1; mono recordings only have channel 0.
            channel = 0 if speaker_no == "Speaker1" or n_channels == 1 else 1

            utterances.append({
                "wav_path": str(wav_path),
                "json_path": str(json_path),
                "file_stem": wav_path.stem,
                "text_no": utt.get("TextNo", ""),
                "text": utt.get("Text", ""),
                "start": start,
                "end": end,
                "duration": duration,
                "emotion": emotion_en,
                "emotion_kr": emotion_kr,
                "intensity": utt.get("VerifyEmotionLevel", ""),
                "speaker_no": speaker_no,
                "speaker_id": speaker_info.get("id", ""),
                "speaker_gender": speaker_info.get("gender", ""),
                "speaker_age": speaker_info.get("age", ""),
                "channel": channel,
                "sample_rate": sr,
                "n_channels": n_channels,
            })

    logger.info("Parsed %d valid utterances across %d files", len(utterances), len(pairs))
    return utterances
|
|
|
|
| |
| |
| |
|
|
def balanced_sample(
    utterances: list[dict],
    samples_per_class: int,
    seed: int = 42,
) -> list[dict]:
    """Draw up to ``samples_per_class`` utterances for each evaluation label.

    For every label in ``EVAL_LABELS``:
      - if the pool holds no more than the target, all of it is taken
        (this is what happens for rare classes);
      - otherwise the pool is split into short (< 3 s), medium (3-10 s) and
        long (>= 10 s) buckets, sampled at roughly 30% / 50% / 20%, and any
        shortfall is topped up at random from the rest of the pool.

    Speaker identity is not taken into account when sampling.

    Args:
        utterances: Dicts with at least ``"emotion"`` and ``"duration"`` keys.
        samples_per_class: Target count per label; values below 1 select
            nothing.
        seed: Seed for the private random generator, so results are
            reproducible for a given input order.

    Returns:
        The selected utterances, grouped by label in ``EVAL_LABELS`` order.
    """
    rng = random.Random(seed)
    target = max(0, samples_per_class)

    by_emotion: dict[str, list[dict]] = defaultdict(list)
    for utt in utterances:
        by_emotion[utt["emotion"]].append(utt)

    selected = []
    stats = {}

    for emotion in EVAL_LABELS:
        pool = by_emotion.get(emotion, [])
        if not pool:
            logger.warning("No samples for emotion '%s'", emotion)
            stats[emotion] = 0
            continue

        if len(pool) <= target:
            # Not enough data to be selective: take everything.
            chosen = list(pool)
        else:
            short = [u for u in pool if u["duration"] < 3.0]
            medium = [u for u in pool if 3.0 <= u["duration"] < 10.0]
            long = [u for u in pool if u["duration"] >= 10.0]

            n_short = max(1, int(target * 0.3))
            n_long = max(1, int(target * 0.2))
            # Clamp at zero: for tiny targets the two minimums above already
            # exceed the target, and a negative count would turn the slice
            # below into "everything but the last n".
            n_medium = max(0, target - n_short - n_long)

            chosen = []
            for bucket, n in [(short, n_short), (medium, n_medium), (long, n_long)]:
                rng.shuffle(bucket)
                chosen.extend(bucket[:n])

            # Top up from the unused part of the pool when a bucket ran dry.
            if len(chosen) < target:
                taken = {id(u) for u in chosen}
                remaining = [u for u in pool if id(u) not in taken]
                rng.shuffle(remaining)
                chosen.extend(remaining[: target - len(chosen)])

            chosen = chosen[:target]

        selected.extend(chosen)
        stats[emotion] = len(chosen)

    logger.info("Sampling result: %s (total: %d)", stats, len(selected))
    return selected
|
|
|
|
| |
| |
| |
|
|
def extract_utterances(
    selected: list[dict],
    output_dir: str,
) -> list[dict]:
    """Cut each selected utterance out of its conversation WAV.

    For every utterance the speaker's channel is taken from the source
    recording, sliced to the utterance's time span, resampled to 16 kHz if
    necessary and written as a mono 16-bit PCM WAV to
    ``output_dir/test_audio/<emotion>/kr_<emotion>_<index>.wav``, where
    ``index`` is the utterance's position in ``selected``.

    Args:
        selected: Utterance dicts as produced by ``parse_utterances``.
        output_dir: Root directory for the extracted audio.

    Returns:
        One record per written file, in the order of ``selected``.
        Utterances whose source file cannot be read, or whose slice is
        shorter than ``MIN_DURATION_SEC``, are skipped with a warning.
    """
    out_path = Path(output_dir)
    total = len(selected)

    # Visit utterances grouped by source file so that only one decoded
    # conversation is held in memory at a time. Output names and record
    # order still follow the original position in `selected`.
    visit_order = sorted(range(total), key=lambda idx: selected[idx]["wav_path"])

    by_index: dict[int, dict] = {}
    loaded_path = None
    loaded = None  # (samples, sample_rate), or None if the read failed

    for done, i in enumerate(visit_order, start=1):
        utt = selected[i]
        wav_path = utt["wav_path"]

        if wav_path != loaded_path:
            loaded_path = wav_path
            try:
                loaded = sf.read(wav_path, dtype="float32")
            except Exception as e:
                # Best effort: an unreadable file only costs its own utterances.
                logger.warning("Failed to read %s: %s", wav_path, e)
                loaded = None

        if loaded is not None:
            audio, sr = loaded
            record = _extract_segment(utt, i, audio, sr, out_path)
            if record is not None:
                by_index[i] = record

        if done % 100 == 0:
            logger.info("Processed %d/%d utterances", done, total)

    records = [by_index[i] for i in sorted(by_index)]
    logger.info("Extracted %d utterance WAVs to %s", len(records), output_dir)
    return records


def _extract_segment(utt: dict, index: int, audio: np.ndarray, sr: int, out_path: Path) -> dict | None:
    """Write one utterance segment to disk and return its record, or None if too short."""
    target_sr = 16000
    emotion = utt["emotion"]
    emotion_dir = out_path / "test_audio" / emotion
    emotion_dir.mkdir(parents=True, exist_ok=True)

    # Pick the speaker's channel; clamp in case the file has fewer channels
    # than the label metadata claimed.
    if audio.ndim == 2:
        channel = min(utt["channel"], audio.shape[1] - 1)
        mono = audio[:, channel]
    else:
        mono = audio

    # Convert the time span to sample indices and clamp to the recording.
    start_sample = max(0, int(utt["start"] * sr))
    end_sample = min(len(mono), int(utt["end"] * sr))
    segment = mono[start_sample:end_sample]

    if len(segment) < int(MIN_DURATION_SEC * sr):
        logger.warning("Segment too short after extraction: %s_%s", utt["file_stem"], utt["text_no"])
        return None

    if sr != target_sr:
        # Imported lazily so librosa is only required for non-16 kHz sources.
        import librosa
        segment = librosa.resample(segment, orig_sr=sr, target_sr=target_sr)

    filepath = emotion_dir / f"kr_{emotion}_{index:04d}.wav"
    sf.write(str(filepath), segment, target_sr, subtype="PCM_16")

    return {
        "file_path": str(filepath.relative_to(out_path)),
        "emotion": emotion,
        "duration": round(len(segment) / target_sr, 3),
        "speaker_id": utt["speaker_id"],
        "speaker_gender": utt["speaker_gender"],
        "intensity": utt["intensity"],
        "text": utt["text"],
        "source_file": utt["file_stem"],
    }
|
|
|
|
| |
| |
| |
|
|
def write_outputs(records: list[dict], output_dir: str, utterances: list[dict]):
    """Write ``test_labels.csv`` and ``metadata.json`` into ``output_dir``.

    Args:
        records: Per-file records as returned by ``extract_utterances``.
        output_dir: Destination directory; created if it does not exist.
        utterances: All parsed utterances before sampling; only their count
            is recorded in the metadata.
    """
    out_path = Path(output_dir)
    # Nothing upstream creates the directory when there are no records.
    out_path.mkdir(parents=True, exist_ok=True)

    # Label table: one row per extracted WAV.
    csv_path = out_path / "test_labels.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "file_path", "emotion", "duration", "speaker_id",
            "speaker_gender", "intensity", "text", "source_file",
        ])
        writer.writeheader()
        writer.writerows(records)
    logger.info("Wrote %s (%d records)", csv_path, len(records))

    # Summary statistics for the metadata file.
    emotion_dist = Counter(r["emotion"] for r in records)
    intensity_dist = Counter(r["intensity"] for r in records)
    durations = [r["duration"] for r in records]

    metadata = {
        "dataset": "AI Hub #71631 - 감정이 태깅된 자유대화 (성인)",
        "subset": "test",
        "total_samples": len(records),
        "eval_classes": EVAL_LABELS,
        "label_mapping": AIHUB_LABEL_MAP,
        "emotion_distribution": dict(emotion_dist),
        "intensity_distribution": dict(intensity_dist),
        "duration_stats": {
            # max(..., 1) and default=0 keep the empty case well-defined.
            "mean": round(sum(durations) / max(len(durations), 1), 2),
            "min": round(min(durations, default=0), 2),
            "max": round(max(durations, default=0), 2),
        },
        "total_source_utterances": len(utterances),
        "note": "disgust class absent from AI Hub dataset → 6-class evaluation",
    }

    meta_path = out_path / "metadata.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        # ensure_ascii=False keeps the Korean labels readable in the file.
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    logger.info("Wrote %s", meta_path)
|
|
|
|
| |
| |
| |
|
|
def main():
    """Command-line entry point: discover, parse, sample, extract, write.

    Exits with status 1 when no WAV/JSON pairs or no valid utterances are
    found, and with an argparse usage error when ``--samples-per-class`` is
    smaller than 1.
    """
    parser = argparse.ArgumentParser(
        description="AI Hub 감정 데이터셋 → 벤치마크 테스트 서브셋 준비",
    )
    parser.add_argument("--aihub-dir", required=True, help="AI Hub 데이터 루트 (01.원천데이터, 02.라벨링데이터 포함)")
    parser.add_argument("--output-dir", default="data/evaluation/korean", help="출력 디렉토리")
    parser.add_argument("--samples-per-class", type=int, default=83, help="클래스당 목표 샘플 수 (default: 83)")
    parser.add_argument("--seed", type=int, default=42, help="랜덤 시드")
    parser.add_argument("--ground-truth", default="verify", choices=["verify", "speaker"],
                        help="Ground truth 소스: verify=검증자 라벨, speaker=화자 자기보고")
    args = parser.parse_args()

    if args.samples_per_class < 1:
        parser.error("--samples-per-class must be at least 1")

    # The option is accepted, but labels are always read from the
    # annotator-verified field; say so instead of silently ignoring it.
    if args.ground_truth != "verify":
        logger.warning(
            "--ground-truth=%s is not implemented; using annotator-verified labels",
            args.ground_truth,
        )

    # 1. Locate WAV/JSON pairs.
    pairs = discover_pairs(args.aihub_dir)
    if not pairs:
        logger.error("No WAV-JSON pairs found")
        sys.exit(1)

    # 2. Parse utterance-level labels.
    utterances = parse_utterances(pairs)
    if not utterances:
        logger.error("No valid utterances parsed")
        sys.exit(1)

    raw_dist = Counter(u["emotion"] for u in utterances)
    logger.info("Raw distribution: %s", dict(raw_dist))

    # 3. Balanced sampling per emotion class.
    selected = balanced_sample(utterances, args.samples_per_class, seed=args.seed)

    # 4. Cut the selected utterances out of the conversation recordings.
    records = extract_utterances(selected, args.output_dir)

    # 5. Write the label table and metadata.
    write_outputs(records, args.output_dir, utterances)

    print(f"\nDone! Test subset ready at {args.output_dir}/")
    print(f" - {len(records)} utterance WAVs in test_audio/")
    print(" - test_labels.csv")
    print(" - metadata.json")
|
|
|
|
# Run the pipeline only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|