ustwo-api / scripts /prepare_aihub_test_subset.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
15.8 kB
#!/usr/bin/env python3
"""AI Hub 감정 νƒœκΉ… μžμœ λŒ€ν™”(성인) 데이터셋 β†’ 벀치마크 ν…ŒμŠ€νŠΈ μ„œλΈŒμ…‹ μ€€λΉ„.
AI Hub #71631 λ°μ΄ν„°μ…‹μ˜ JSON 라벨 + μŠ€ν…Œλ ˆμ˜€ WAVμ—μ„œ λ°œν™” λ‹¨μœ„λ₯Ό μΆ”μΆœν•˜μ—¬
κ· ν˜• 작힌 6-class ν…ŒμŠ€νŠΈμ…‹μ„ μƒμ„±ν•œλ‹€.
Usage:
python scripts/prepare_aihub_test_subset.py --aihub-dir data/samples
python scripts/prepare_aihub_test_subset.py --aihub-dir /path/to/full/dataset --samples-per-class 83
"""
from __future__ import annotations
import argparse
import csv
import json
import logging
import os
import random
import sys
from collections import defaultdict
from pathlib import Path
import numpy as np
import soundfile as sf
# Configure root logging once, at import time, so every pipeline step
# reports progress with a timestamp and level at INFO verbosity.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Logger shared by all functions in this script.
logger = logging.getLogger(__name__)
# ──────────────────────────────────────────────
# Label Mapping: AI Hub ν•œκ΅­μ–΄ β†’ Project 6-class
# ──────────────────────────────────────────────
# Keys are the Korean emotion strings found in the AI Hub JSON labels; values
# are the project's English class names. A label missing from this table is
# skipped when utterances are parsed.
AIHUB_LABEL_MAP: dict[str, str] = {
    "기쁨": "joy",
    "놀라움": "surprise",
    "두려움": "fear",
    "사랑스러움": "joy",  # affection, folded into joy (user confirmed)
    "슬픔": "sadness",
    "화남": "anger",
    "없음": "neutral",  # "none" (no emotion)
    "중립": "neutral",  # seen in SpeakerEmotionCategory
}

# The six project classes evaluated on this dataset (it has no disgust class).
EVAL_LABELS: list[str] = ["neutral", "joy", "sadness", "anger", "surprise", "fear"]

# Accepted utterance length, in seconds. Shorter clips are treated as too
# brief to carry a reliable emotion; longer ones are dropped as outliers.
MIN_DURATION_SEC: float = 0.5
MAX_DURATION_SEC: float = 30.0
# ──────────────────────────────────────────────
# Step 1: Parse AI Hub JSON + discover WAV pairs
# ──────────────────────────────────────────────
def discover_pairs(aihub_dir: str) -> list[tuple[Path, Path]]:
    """Find matched WAV/JSON file pairs in an AI Hub dataset directory.

    Expected layout (the two top-level directory names are hard-coded below;
    everything beneath them is scanned recursively):
        aihub_dir/<source-audio dir>/.../xxx.wav
        aihub_dir/<label dir>/.../xxx.json

    A label file is paired with the WAV that has the same stem (file name
    without extension). Both trees are walked in sorted order, so the result
    is identical on every filesystem; the seeded sampling downstream depends
    on a stable input order to be reproducible.

    Args:
        aihub_dir: Dataset root containing the source-audio and label trees.

    Returns:
        ``(wav_path, json_path)`` tuples, ordered by JSON path.

    Exits the process with status 1 when either top-level directory is
    missing.
    """
    source_dir = Path(aihub_dir) / "01.원천데이터"
    label_dir = Path(aihub_dir) / "02.라벨링데이터"
    if not source_dir.exists() or not label_dir.exists():
        logger.error("Expected 01.원천데이터 and 02.라벨링데이터 under %s", aihub_dir)
        sys.exit(1)

    # Stem -> WAV path. Windows ":Zone.Identifier" side files cannot match the
    # "*.wav" glob, so they need no explicit filter. On a duplicate stem the
    # first path in sorted order wins, and the conflict is reported.
    wav_lookup: dict[str, Path] = {}
    for wav_path in sorted(source_dir.rglob("*.wav")):
        kept = wav_lookup.setdefault(wav_path.stem, wav_path)
        if kept is not wav_path:
            logger.warning(
                "Duplicate WAV stem %s: keeping %s, ignoring %s",
                wav_path.stem, kept, wav_path,
            )

    pairs: list[tuple[Path, Path]] = []
    for json_path in sorted(label_dir.rglob("*.json")):
        wav_path = wav_lookup.get(json_path.stem)
        if wav_path is None:
            logger.warning("No WAV match for %s", json_path.name)
            continue
        pairs.append((wav_path, json_path))

    logger.info("Discovered %d WAV-JSON pairs", len(pairs))
    return pairs
def parse_utterances(pairs: list[tuple[Path, Path]]) -> list[dict]:
    """Parse labelled utterances from AI Hub JSON label files.

    The ground-truth emotion of each utterance is its ``VerifyEmotionTarget``
    field (the annotator-verified label), mapped through ``AIHUB_LABEL_MAP``.
    An utterance is skipped when its label is missing, null or unmapped, when
    ``StartTime``/``EndTime`` is missing or non-numeric, or when its duration
    lies outside [MIN_DURATION_SEC, MAX_DURATION_SEC] (NaN included).
    A label file that cannot be read or decoded is skipped with a warning
    rather than aborting the whole run.

    Args:
        pairs: ``(wav_path, json_path)`` tuples, e.g. from ``discover_pairs``.

    Returns:
        One dict per kept utterance. Each dict holds:
        - the source paths and file stem;
        - the text number and text;
        - start/end/duration in seconds;
        - the English and Korean emotion and the ``VerifyEmotionLevel`` intensity;
        - speaker number, id, gender and age;
        - the audio channel to extract;
        - the sample rate and channel count declared in the JSON.
    """
    utterances: list[dict] = []
    for wav_path, json_path in pairs:
        try:
            # "utf-8-sig" also accepts files that begin with a UTF-8 BOM,
            # which json.load rejects when the file is opened as plain "utf-8".
            with open(json_path, encoding="utf-8-sig") as f:
                data = json.load(f)
        except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
            logger.warning("Skipping unreadable label file %s: %s", json_path, e)
            continue

        wav_info = data.get("Wav") or {}
        sr = int(wav_info.get("SamplingRate", 16000))
        n_channels = int(wav_info.get("NumberOfChannel", 2))

        # Per-speaker metadata, looked up later via each utterance's SpeakerNo.
        speakers = {}
        for key in ("Speaker1", "Speaker2"):
            spk = data.get(key) or {}
            speakers[key] = {
                "id": spk.get("ID", ""),
                "gender": spk.get("Gender", ""),
                "age": spk.get("Age", ""),
            }

        for utt in data.get("Conversation") or []:
            # A null or absent label becomes "" and is skipped like any unmapped one.
            emotion_kr = str(utt.get("VerifyEmotionTarget") or "").strip()
            emotion_en = AIHUB_LABEL_MAP.get(emotion_kr)
            if emotion_en is None or emotion_en not in EVAL_LABELS:
                continue

            try:
                # Commas are stripped before parsing (assumed to be thousands
                # separators, e.g. "1,234.5") — NOTE(review): confirm the AI Hub
                # format never uses a decimal comma.
                start = float(str(utt["StartTime"]).replace(",", ""))
                end = float(str(utt["EndTime"]).replace(",", ""))
            except (KeyError, ValueError):
                continue

            duration = end - start
            # Written as a positive range test so NaN durations are rejected too.
            if not MIN_DURATION_SEC <= duration <= MAX_DURATION_SEC:
                continue

            speaker_no = utt.get("SpeakerNo", "Speaker1")
            speaker_info = speakers.get(speaker_no, {})
            # 0-indexed channel: Speaker1 is taken from the left channel (0),
            # any other speaker from the right (1); mono files only have 0.
            channel = 0 if n_channels == 1 or speaker_no == "Speaker1" else 1

            utterances.append({
                "wav_path": str(wav_path),
                "json_path": str(json_path),
                "file_stem": wav_path.stem,
                "text_no": utt.get("TextNo", ""),
                "text": utt.get("Text", ""),
                "start": start,
                "end": end,
                "duration": duration,
                "emotion": emotion_en,
                "emotion_kr": emotion_kr,
                "intensity": utt.get("VerifyEmotionLevel", ""),
                "speaker_no": speaker_no,
                "speaker_id": speaker_info.get("id", ""),
                "speaker_gender": speaker_info.get("gender", ""),
                "speaker_age": speaker_info.get("age", ""),
                "channel": channel,
                "sample_rate": sr,
                "n_channels": n_channels,
            })

    logger.info("Parsed %d valid utterances across %d files", len(utterances), len(pairs))
    return utterances
# ──────────────────────────────────────────────
# Step 2: Balanced sampling
# ──────────────────────────────────────────────
def balanced_sample(
    utterances: list[dict],
    samples_per_class: int,
    seed: int = 42,
) -> list[dict]:
    """Draw up to ``samples_per_class`` utterances for each emotion in EVAL_LABELS.

    Class handling:
    - A class with at most ``samples_per_class`` utterances is taken whole
      (e.g. a rare class such as fear), so class sizes can end up unequal.
    - A larger class is sampled by duration bucket: short (< 3 s), medium
      (3-10 s) and long (>= 10 s), aiming for a 30/50/20 split.
    - Any shortfall from a thin bucket is topped up at random from the rest
      of the class.

    Speakers are not balanced. Sampling uses ``random.Random(seed)``, so the
    result is reproducible for a given input order and seed.

    Args:
        utterances: Utterance dicts with at least ``emotion`` and ``duration`` keys.
        samples_per_class: Target number of utterances per class (>= 0).
        seed: Seed for the private random generator.

    Returns:
        The selected utterance dicts, grouped by emotion in EVAL_LABELS order.

    Raises:
        ValueError: if ``samples_per_class`` is negative.
    """
    if samples_per_class < 0:
        raise ValueError(f"samples_per_class must be >= 0, got {samples_per_class}")
    rng = random.Random(seed)

    by_emotion: dict[str, list[dict]] = defaultdict(list)
    for utt in utterances:
        by_emotion[utt["emotion"]].append(utt)

    selected: list[dict] = []
    stats: dict[str, int] = {}
    for emotion in EVAL_LABELS:
        pool = by_emotion.get(emotion, [])
        if not pool:
            logger.warning("No samples for emotion '%s'", emotion)
            stats[emotion] = 0
            continue

        if len(pool) <= samples_per_class:
            # Rare class: keep everything available.
            chosen = pool
        else:
            short = [u for u in pool if u["duration"] < 3.0]
            medium = [u for u in pool if 3.0 <= u["duration"] < 10.0]
            long_ = [u for u in pool if u["duration"] >= 10.0]

            n_short = max(1, int(samples_per_class * 0.3))
            n_long = max(1, int(samples_per_class * 0.2))
            # Clamped at 0: for tiny targets n_short + n_long can exceed the
            # target, and a negative slice bound would grab almost the whole bucket.
            n_medium = max(0, samples_per_class - n_short - n_long)

            chosen = []
            for bucket, n in ((short, n_short), (medium, n_medium), (long_, n_long)):
                rng.shuffle(bucket)
                chosen.extend(bucket[:n])

            # Top up from the rest of the class if some bucket was too small.
            if len(chosen) < samples_per_class:
                # Identity-based membership: O(1) per lookup, with no deep
                # dict comparisons.
                taken = {id(u) for u in chosen}
                remaining = [u for u in pool if id(u) not in taken]
                rng.shuffle(remaining)
                chosen.extend(remaining[: samples_per_class - len(chosen)])
            chosen = chosen[:samples_per_class]

        selected.extend(chosen)
        stats[emotion] = len(chosen)

    logger.info("Sampling result: %s (total: %d)", stats, len(selected))
    return selected
# ──────────────────────────────────────────────
# Step 3: Extract utterance WAVs
# ──────────────────────────────────────────────
def _read_segment(wav_path: str, start_sec: float, end_sec: float) -> tuple[np.ndarray, int]:
    """Read only the frames between *start_sec* and *end_sec* of a sound file.

    Only the requested range is decoded, so memory use is bounded by the
    segment length rather than the whole conversation. The range is clamped
    to the file, so an out-of-range request yields an empty array.

    Returns:
        ``(audio, sample_rate)``. ``audio`` is float32: 1-D for mono files and
        ``(frames, channels)`` for multi-channel files.
    """
    with sf.SoundFile(wav_path) as snd:
        sr = snd.samplerate
        start = min(max(0, int(start_sec * sr)), snd.frames)
        stop = min(snd.frames, int(end_sec * sr))
        snd.seek(start)
        audio = snd.read(max(0, stop - start), dtype="float32")
    return audio, sr


def extract_utterances(
    selected: list[dict],
    output_dir: str,
) -> list[dict]:
    """Cut each selected utterance out of its conversation WAV and save it.

    For every utterance, the steps are:
    1. Read only its [start, end) time range from the source file.
    2. Keep the speaker's channel, clamped to the channels actually present.
    3. Resample to 16 kHz with librosa if the file uses another rate.
    4. Write mono 16-bit PCM to
       ``<output_dir>/test_audio/<emotion>/kr_<emotion>_<i:04d>.wav``, where
       ``i`` is the utterance's index in ``selected``.

    A source file that fails to open is logged once, and every utterance
    from it is skipped. Segments shorter than MIN_DURATION_SEC after clamping
    to the file length are skipped as well.

    Args:
        selected: Utterance dicts as produced by ``parse_utterances``.
        output_dir: Output root; sub-directories are created as needed.

    Returns:
        One record per written file, in ``selected`` order. Keys:
        - ``file_path`` (relative to output_dir), ``emotion``, ``duration`` in seconds;
        - ``speaker_id``, ``speaker_gender``, ``intensity``, ``text``;
        - ``source_file``.
    """
    target_sr = 16000
    out_path = Path(output_dir)
    records: list[dict] = []
    # Source files that already failed to open; skipped without re-reading or re-logging.
    unreadable: set[str] = set()

    for i, utt in enumerate(selected):
        emotion = utt["emotion"]
        emotion_dir = out_path / "test_audio" / emotion
        emotion_dir.mkdir(parents=True, exist_ok=True)

        wav_path = utt["wav_path"]
        if wav_path in unreadable:
            continue
        # Per-segment reads replace a cache that kept every decoded source
        # file in memory for the entire run.
        try:
            audio, sr = _read_segment(wav_path, utt["start"], utt["end"])
        except Exception as e:  # best effort: soundfile raises several error types
            logger.warning("Failed to read %s: %s", wav_path, e)
            unreadable.add(wav_path)
            continue

        if audio.ndim == 2:
            channel = min(utt["channel"], audio.shape[1] - 1)
            segment = audio[:, channel]
        else:
            segment = audio

        if len(segment) < int(MIN_DURATION_SEC * sr):
            logger.warning("Segment too short after extraction: %s_%s", utt["file_stem"], utt["text_no"])
            continue

        # A channel slice is a strided view; give the resampler/writer a
        # contiguous buffer.
        segment = np.ascontiguousarray(segment)
        if sr != target_sr:
            import librosa  # only needed when a source file is not already 16 kHz
            segment = librosa.resample(segment, orig_sr=sr, target_sr=target_sr)
        # Resampling can overshoot [-1, 1]. libsndfile does not clip on
        # float -> PCM_16 conversion by default, so clip explicitly.
        segment = np.clip(segment, -1.0, 1.0)

        filepath = emotion_dir / f"kr_{emotion}_{i:04d}.wav"
        sf.write(str(filepath), segment, target_sr, subtype="PCM_16")

        records.append({
            "file_path": str(filepath.relative_to(out_path)),
            "emotion": emotion,
            "duration": round(len(segment) / target_sr, 3),
            "speaker_id": utt["speaker_id"],
            "speaker_gender": utt["speaker_gender"],
            "intensity": utt["intensity"],
            "text": utt["text"],
            "source_file": utt["file_stem"],
        })

        if (i + 1) % 100 == 0:
            logger.info("Extracted %d/%d utterances", i + 1, len(selected))

    logger.info("Extracted %d utterance WAVs to %s", len(records), output_dir)
    return records
# ──────────────────────────────────────────────
# Step 4: Write labels CSV + metadata JSON
# ──────────────────────────────────────────────
def write_outputs(records: list[dict], output_dir: str, utterances: list[dict]):
    """Write ``test_labels.csv`` and ``metadata.json`` into *output_dir*.

    Args:
        records: One dict per extracted WAV (as returned by
            ``extract_utterances``). Each becomes a CSV row and feeds the
            emotion, intensity and duration summaries in the metadata.
        output_dir: Destination directory. It is created if missing, so this
            still works when no audio segment was extracted.
        utterances: All parsed source utterances; only their count is stored.
    """
    from collections import Counter

    out_path = Path(output_dir)
    out_path.mkdir(parents=True, exist_ok=True)

    # Labels CSV: one row per extracted utterance WAV.
    csv_path = out_path / "test_labels.csv"
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.DictWriter(f, fieldnames=[
            "file_path", "emotion", "duration", "speaker_id",
            "speaker_gender", "intensity", "text", "source_file",
        ])
        writer.writeheader()
        writer.writerows(records)
    logger.info("Wrote %s (%d records)", csv_path, len(records))

    # Metadata: dataset provenance plus summary statistics of the subset.
    durations = [r["duration"] for r in records]
    metadata = {
        "dataset": "AI Hub #71631 - 감정이 태깅된 자유대화 (성인)",
        "subset": "test",
        "total_samples": len(records),
        "eval_classes": EVAL_LABELS,
        "label_mapping": AIHUB_LABEL_MAP,
        "emotion_distribution": dict(Counter(r["emotion"] for r in records)),
        "intensity_distribution": dict(Counter(r["intensity"] for r in records)),
        "duration_stats": {
            # max(..., 1) and default=0 keep the stats defined for an empty subset.
            "mean": round(sum(durations) / max(len(durations), 1), 2),
            "min": round(min(durations, default=0), 2),
            "max": round(max(durations, default=0), 2),
        },
        "total_source_utterances": len(utterances),
        "note": "disgust class absent from AI Hub dataset — 6-class evaluation",
    }
    meta_path = out_path / "metadata.json"
    with open(meta_path, "w", encoding="utf-8") as f:
        json.dump(metadata, f, indent=2, ensure_ascii=False)
    logger.info("Wrote %s", meta_path)
# ──────────────────────────────────────────────
# Main
# ──────────────────────────────────────────────
def main():
    """Command-line entry point: build the AI Hub benchmark test subset.

    Pipeline:
    1. Discover WAV/JSON pairs.
    2. Parse labelled utterances.
    3. Draw a duration-stratified balanced sample.
    4. Cut the utterance WAVs.
    5. Write the labels CSV and the metadata JSON.

    Exits with status 1 when a stage yields nothing usable. Invalid options
    are reported through ``parser.error`` (argparse exits with status 2).
    """
    from collections import Counter

    parser = argparse.ArgumentParser(
        description="AI Hub 감정 데이터셋 → 벤치마크 테스트 서브셋 준비",
    )
    parser.add_argument("--aihub-dir", required=True, help="AI Hub 데이터 루트 (01.원천데이터, 02.라벨링데이터 포함)")
    parser.add_argument("--output-dir", default="data/evaluation/korean", help="출력 디렉토리")
    parser.add_argument("--samples-per-class", type=int, default=83, help="클래스당 목표 샘플 수 (default: 83)")
    parser.add_argument("--seed", type=int, default=42, help="랜덤 시드")
    parser.add_argument("--ground-truth", default="verify", choices=["verify", "speaker"],
                        help="Ground truth 소스: verify=검증자 라벨, speaker=화자 자기보고")
    args = parser.parse_args()

    if args.samples_per_class <= 0:
        parser.error("--samples-per-class must be a positive integer")
    # parse_utterances always reads the verifier label (VerifyEmotionTarget).
    # Refuse "speaker" rather than silently emitting verifier labels under
    # that name.
    if args.ground_truth != "verify":
        parser.error("--ground-truth speaker is not implemented yet; only 'verify' labels are supported")

    # 1. Discover pairs
    pairs = discover_pairs(args.aihub_dir)
    if not pairs:
        logger.error("No WAV-JSON pairs found")
        sys.exit(1)

    # 2. Parse utterances
    utterances = parse_utterances(pairs)
    if not utterances:
        logger.error("No valid utterances parsed")
        sys.exit(1)
    logger.info("Raw distribution: %s", dict(Counter(u["emotion"] for u in utterances)))

    # 3. Balanced sampling
    selected = balanced_sample(utterances, args.samples_per_class, seed=args.seed)

    # 4. Extract WAVs
    records = extract_utterances(selected, args.output_dir)
    if not records:
        logger.error("No utterance WAVs were extracted")
        sys.exit(1)

    # 5. Write outputs
    write_outputs(records, args.output_dir, utterances)

    print(f"\nDone! Test subset ready at {args.output_dir}/")
    print(f" - {len(records)} utterance WAVs in test_audio/")
    print(" - test_labels.csv")
    print(" - metadata.json")


if __name__ == "__main__":
    main()