ustwo-api / scripts /prepare_lora_dataset.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
28.6 kB
#!/usr/bin/env python3
"""Unified data preprocessing pipeline for LoRA emotion2vec 7-class fine-tuning.
Extracts, preprocesses, and merges samples from three sources:
1. AI Hub 263 — anchor (acted Korean, 7-class)
2. AI Hub 71631 — booster (outdoor spontaneous Korean, mapped to 7-class)
3. RAVDESS — English (acted, 7-class)
Outputs a unified manifest (train/val JSONs) ready for LoRA training.
Usage:
python scripts/prepare_lora_dataset.py \
--anchor-dir "data/AI Hub 263" \
--booster-label-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/02.라벨링데이터/TL_02.실외.zip" \
--booster-audio-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/01.원천데이터/TS_02.실외.zip" \
--ravdess-dir data/ravdess \
--output-dir data/lora_7class
"""
from __future__ import annotations
import argparse
import csv
import io
import json
import logging
import random
import zipfile
from collections import Counter, defaultdict
from pathlib import Path
import torch
import torchaudio
# Root logging setup for the pipeline: time-of-day stamp, padded level name,
# then the message.
_LOG_SETTINGS = {
    "level": logging.INFO,
    "format": "%(asctime)s %(levelname)-8s %(message)s",
    "datefmt": "%H:%M:%S",
}
logging.basicConfig(**_LOG_SETTINGS)

# Module-level logger shared by every stage below.
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------
# Emotion class names in index order: the position of a name is its class id.
_EMOTION_CLASSES = (
    "happiness",
    "anger",
    "disgust",
    "fear",
    "neutral",
    "sadness",
    "surprise",
)
LABEL2IDX: dict[str, int] = {name: idx for idx, name in enumerate(_EMOTION_CLASSES)}
VALID_LABELS = set(LABEL2IDX)

# Sample rate (Hz) every output clip is resampled to.
TARGET_SR = 16_000

# Clips whose RMS is below this value are rejected. Lowered from 0.005 to
# 0.001 so that low-energy utterances (e.g. disgust) are kept and only
# genuinely silent clips are dropped.
RMS_THRESHOLD = 0.001
# ---------------------------------------------------------------------------
# Task 1 — Label Mappers
# ---------------------------------------------------------------------------
# AI Hub 263 annotator labels (lower-cased) -> 7-class names. Only "angry"
# is renamed; every other label already matches its target name.
_MAP_263: dict[str, str] = {
    "angry": "anger",
    **{
        name: name
        for name in ("happiness", "neutral", "sadness", "surprise", "fear", "disgust")
    },
}


def map_263_label(raw: str) -> str | None:
    """Translate one AI Hub 263 annotator label into a 7-class label.

    Surrounding whitespace and letter case are ignored.

    Args:
        raw: Label text as found in the CSV cell.

    Returns:
        The 7-class label, or None when ``raw`` is empty or not a known label.
    """
    return _MAP_263.get(raw.strip().lower()) if raw else None
# AI Hub 71631 Korean emotion tags -> 7-class names. Two source tags
# collapse onto "neutral"; no source tag maps to "disgust".
_MAP_71631: dict[str, str] = dict(
    [
        ("기쁨", "happiness"),
        ("화남", "anger"),
        ("놀라움", "surprise"),
        ("슬픔", "sadness"),
        ("두려움", "fear"),
        ("없음", "neutral"),
        ("중립", "neutral"),
    ]
)


def map_71631_label(raw: str) -> str | None:
    """Translate one AI Hub 71631 Korean emotion tag into a 7-class label.

    Surrounding whitespace is ignored; the match is otherwise exact.

    Args:
        raw: Emotion tag as found in the label JSON.

    Returns:
        The 7-class label, or None when ``raw`` is empty or not a known tag.
    """
    return _MAP_71631.get(raw.strip()) if raw else None
def map_ravdess_label(raw: str) -> str | None:
    """Translate one RAVDESS manifest label into a 7-class label.

    The label is trimmed and lower-cased. "joy" becomes "happiness"; any
    other value is passed through only if it is already a 7-class name.

    Args:
        raw: Label text from the manifest.

    Returns:
        The 7-class label, or None when ``raw`` is empty or unrecognised.
    """
    if not raw:
        return None
    normalized = raw.strip().lower()
    if normalized == "joy":
        return "happiness"
    return normalized if normalized in VALID_LABELS else None
def majority_vote_263(emotions: list[str | None]) -> str | None:
    """Pick the label that at least three annotators agree on.

    None entries (unmapped annotator labels) are ignored.

    Args:
        emotions: One mapped label (or None) per annotator.

    Returns:
        The winning label, or None when there are no usable votes, the most
        frequent label has fewer than three votes, or the top count is shared
        by more than one label.
    """
    votes = Counter(e for e in emotions if e is not None)
    if not votes:
        return None

    ranked = votes.most_common(2)
    winner, winner_votes = ranked[0]
    if winner_votes < 3:
        return None

    # A runner-up with the same count means the top spot is tied.
    if len(ranked) > 1 and ranked[1][1] == winner_votes:
        return None
    return winner
# ---------------------------------------------------------------------------
# Task 1 — Audio Preprocessing
# ---------------------------------------------------------------------------
import re
# Bracketed non-verbal tags such as "(웃음)", "(한숨)", "(침묵)" or "[noise]".
_NONVERBAL_TAG_RE = re.compile(r"[(\[(][^)\])]*[)\])]")
# Any run of whitespace.
_WHITESPACE_RE = re.compile(r"\s+")


def _clean_text(text: str | None) -> str:
    """Clean a transcript into an STT-friendly form.

    Bracketed non-verbal tags are removed, whitespace runs are collapsed to a
    single space and the result is stripped. Punctuation such as ``?``, ``!``
    and ``...`` is left untouched.

    Args:
        text: Raw transcript. None or an empty string is accepted so that a
            missing/null transcript field does not abort the pipeline.

    Returns:
        The cleaned transcript; an empty string when ``text`` is None or empty.
    """
    if not text:
        return ""
    without_tags = _NONVERBAL_TAG_RE.sub("", text)
    return _WHITESPACE_RE.sub(" ", without_tags).strip()
def _compute_rms(waveform: torch.Tensor) -> float:
    """Return the root-mean-square amplitude of ``waveform``.

    Args:
        waveform: Audio samples; every element contributes to the mean
            regardless of the tensor's shape.

    Returns:
        The RMS as a Python float. An empty tensor yields 0.0 rather than
        NaN, so a ``rms < threshold`` silence check rejects it instead of
        letting it through (any comparison against NaN is False).
    """
    if waveform.numel() == 0:
        return 0.0
    return float(torch.sqrt(torch.mean(waveform.float() ** 2)))
def preprocess_audio(input_path: Path, output_path: Path) -> bool:
    """Load an audio file and run it through the shared preprocessing steps.

    The file is decoded with ``torchaudio.load`` and handed to
    ``preprocess_audio_from_tensor``, which performs the mono mix-down,
    resampling to ``TARGET_SR``, silence trimming, RMS gate and save. Routing
    file-based and in-memory inputs through one implementation keeps the two
    code paths from drifting apart.

    Args:
        input_path: Audio file to read.
        output_path: Destination file; parent directories are created by the
            delegate when the clip is accepted.

    Returns:
        True if the processed clip was saved, False if it was rejected.

    Raises:
        Whatever ``torchaudio.load`` raises for an unreadable input is
        propagated to the caller.
    """
    waveform, sr = torchaudio.load(str(Path(input_path)))
    return preprocess_audio_from_tensor(waveform, sr, output_path)
def preprocess_audio_from_tensor(
    waveform: torch.Tensor, sr: int, output_path: Path
) -> bool:
    """Preprocess an in-memory waveform and save it to ``output_path``.

    Steps: mix down to a single channel, resample to ``TARGET_SR``, trim
    silence from both ends, reject clips whose RMS is below
    ``RMS_THRESHOLD``, then save.

    Args:
        waveform: Audio samples, either 1-D or with channels on the first
            axis.
        sr: Sample rate of ``waveform``.
        output_path: Destination file; parent directories are created when
            the clip is accepted.

    Returns:
        True if the clip was saved, False if it was rejected (empty input or
        RMS below the threshold).
    """
    output_path = Path(output_path)

    # Normalise to a single leading channel axis.
    if waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)
    elif waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)

    # Nothing to resample, trim or measure.
    if waveform.numel() == 0:
        return False

    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)

    # torchaudio's VAD only trims from the front of the signal, so trailing
    # silence is removed by running it a second time on the time-reversed
    # waveform. An empty VAD result (nothing detected) keeps the input as is.
    front_trimmed = torchaudio.functional.vad(waveform, TARGET_SR)
    if front_trimmed.numel() > 0:
        waveform = front_trimmed
    back_trimmed = torchaudio.functional.vad(waveform.flip(-1), TARGET_SR)
    if back_trimmed.numel() > 0:
        waveform = back_trimmed.flip(-1)

    if _compute_rms(waveform) < RMS_THRESHOLD:
        return False

    output_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output_path), waveform, TARGET_SR)
    return True
# ---------------------------------------------------------------------------
# Task 2 — AI Hub 263 Extraction
# ---------------------------------------------------------------------------
def parse_263_row(row: list[str]) -> dict | None:
    """Parse one CSV row from AI Hub 263.

    Expected column layout (15 columns): wav_id, utterance text, situation,
    then five (emotion, intensity) pairs - one per annotator - followed by
    age and gender.

    Args:
        row: One row as produced by ``csv.reader``.

    Returns:
        A dict with ``wav_id``, ``text`` (cleaned utterance), ``label``
        (majority label), ``agreement`` (number of annotators who chose that
        label) and ``max_intensity`` (highest intensity among those
        annotators), or None when the row has fewer than 15 columns or the
        annotators reach no majority.
    """
    if len(row) < 15:
        return None

    wav_id = row[0].strip()

    # Annotator pairs sit at columns (3, 4), (5, 6), (7, 8), (9, 10), (11, 12).
    # The length check above guarantees all of them exist.
    annotator_emotions: list[str | None] = []
    intensities: list[int] = []
    for emo_col in range(3, 13, 2):
        annotator_emotions.append(map_263_label(row[emo_col].strip()))
        try:
            intensities.append(int(row[emo_col + 1].strip()))
        except ValueError:
            # Blank or non-numeric intensity counts as 0.
            intensities.append(0)

    label = majority_vote_263(annotator_emotions)
    if label is None:
        return None

    # Intensities given by the annotators who voted for the winning label.
    winning_intensities = [
        level
        for emotion, level in zip(annotator_emotions, intensities)
        if emotion == label
    ]

    return {
        "wav_id": wav_id,
        "text": _clean_text(row[1].strip()),
        "label": label,
        "agreement": len(winning_intensities),
        "max_intensity": max(winning_intensities, default=0),
    }
def extract_anchor_263(
    anchor_dir: Path,
    output_dir: Path,
    cap_per_class: int = 2100,
) -> list[dict]:
    """Extract and preprocess the AI Hub 263 dataset.

    Steps:
      1. Parse every ``*.csv`` in ``anchor_dir`` (cp949) with
         ``parse_263_row``, keeping rows that reach a majority label.
      2. Per label, sort by agreement then intensity (both descending) and
         keep the first ``cap_per_class`` rows.
      3. Index the ``.wav`` entries of every ``*.zip`` in ``anchor_dir`` by
         file stem.
      4. For each selected row, extract its WAV and run ``preprocess_audio``.

    Args:
        anchor_dir: Directory holding the CSV label files and audio ZIPs.
        output_dir: Directory that receives ``<label>/<wav_id>.wav`` files.
        cap_per_class: Maximum number of rows selected per label. Rows that
            are later skipped are not replaced.

    Returns:
        One manifest entry per saved clip with keys ``path``, ``label``,
        ``label_idx``, ``source``, ``speaker_id``, ``text``, ``agreement``
        and ``intensity``.
    """
    anchor_dir = Path(anchor_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    csv_files = sorted(anchor_dir.glob("*.csv"))
    zip_files = sorted(anchor_dir.glob("*.zip"))
    log.info("263: Found %d CSVs, %d ZIPs", len(csv_files), len(zip_files))

    # Step 1: Parse all CSVs
    all_parsed: list[dict] = []
    for csv_path in csv_files:
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            # Skip the header; the default keeps an empty file from raising
            # StopIteration.
            next(reader, None)
            for row in reader:
                result = parse_263_row(row)
                if result is not None:
                    result["csv_source"] = csv_path.stem
                    all_parsed.append(result)
    log.info("263: Parsed %d rows with majority vote", len(all_parsed))

    # Step 2: Group by label, priority sort, cap
    by_label: dict[str, list[dict]] = defaultdict(list)
    for item in all_parsed:
        by_label[item["label"]].append(item)

    selected: list[dict] = []
    for label, items in by_label.items():
        # Highest agreement first, ties broken by highest intensity.
        items.sort(key=lambda x: (x["agreement"], x["max_intensity"]), reverse=True)
        capped = items[:cap_per_class]
        selected.extend(capped)
        log.info("263: %s — %d available, %d selected", label, len(items), len(capped))

    samples: list[dict] = []
    skipped = 0
    zip_handles: list[zipfile.ZipFile] = []
    try:
        # Step 3: Build wav_id -> (zip handle, entry name) lookup. Handles are
        # registered one by one so the finally block closes whatever was
        # opened, even if a later archive fails to open.
        wav_to_zip: dict[str, tuple[zipfile.ZipFile, str]] = {}
        for zp in zip_files:
            zf = zipfile.ZipFile(zp)
            zip_handles.append(zf)
            for name in zf.namelist():
                if name.endswith(".wav"):
                    wav_to_zip[Path(name).stem] = (zf, name)

        # Step 4: Extract and preprocess
        for item in selected:
            wav_id = item["wav_id"]
            if wav_id not in wav_to_zip:
                skipped += 1
                continue

            zf, zip_entry = wav_to_zip[wav_id]
            out_path = output_dir / item["label"] / f"{wav_id}.wav"
            # preprocess_audio reads from a path, so the archive member is
            # staged in a temporary file first.
            tmp_path = output_dir / f"_tmp_{wav_id}.wav"
            try:
                with zf.open(zip_entry) as src:
                    tmp_path.write_bytes(src.read())
                ok = preprocess_audio(tmp_path, out_path)
            except Exception as e:
                # Best effort: one bad clip must not abort the whole run.
                log.warning("263: Failed to process %s: %s", wav_id, e)
                skipped += 1
                continue
            finally:
                # Remove the staged file on success and on failure alike.
                tmp_path.unlink(missing_ok=True)

            if not ok:
                skipped += 1
                continue

            samples.append({
                "path": str(out_path),
                "label": item["label"],
                "label_idx": LABEL2IDX[item["label"]],
                "source": "263",
                "speaker_id": f"263_{wav_id[:8]}",
                "text": item.get("text", ""),
                "agreement": item["agreement"],
                "intensity": item["max_intensity"],
            })
    finally:
        for zf in zip_handles:
            zf.close()

    log.info("263: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
# ---------------------------------------------------------------------------
# Task 3 — AI Hub 71631 Outdoor Extraction
# ---------------------------------------------------------------------------
def parse_71631_utterance(conv_entry: dict) -> dict | None:
    """Parse one conversation entry from an AI Hub 71631 label JSON.

    Only entries whose ``VerifyEmotionLevel`` is "보통" (medium) or "강함"
    (strong) are kept; any other level, e.g. "약함" (weak), is rejected.

    Args:
        conv_entry: One element of the JSON's ``Conversation`` list.

    Returns:
        A dict with ``label``, ``intensity``, ``start_time``, ``end_time``,
        ``speaker_no`` and ``text``, or None when the level is not accepted,
        the emotion does not map to a 7-class label, the timestamps are
        missing or not numeric, or the end time is not after the start time.
    """
    level = conv_entry.get("VerifyEmotionLevel", "")
    if level not in ("보통", "강함"):
        return None

    label = map_71631_label(conv_entry.get("VerifyEmotionTarget", ""))
    if label is None:
        return None

    try:
        start_time = float(conv_entry["StartTime"])
        end_time = float(conv_entry["EndTime"])
    except (KeyError, TypeError, ValueError):
        # TypeError covers a timestamp that is present but null.
        return None
    if end_time <= start_time:
        return None

    return {
        "label": label,
        "intensity": level,
        "start_time": start_time,
        "end_time": end_time,
        "speaker_no": conv_entry.get("SpeakerNo", ""),
        # `or ""` covers a transcript that is present but null.
        "text": _clean_text(conv_entry.get("Text") or ""),
    }
def extract_booster_71631(
    label_zip: Path,
    audio_zip: Path,
    output_dir: Path,
    cap: int = 3500,
    max_per_speaker: int = 20,
) -> list[dict]:
    """Extract and preprocess the AI Hub 71631 outdoor dataset.

    Steps:
      1. Parse every label JSON in ``label_zip`` and keep the utterances
         accepted by ``parse_71631_utterance``.
      2. Keep at most ``max_per_speaker`` utterances per speaker, taking
         "강함" (strong) utterances before "보통" (medium) ones.
      3. Keep at most ``cap`` utterances per label.
      4. Group the survivors by conversation WAV so each file is decoded once.
      5. Slice every utterance out of its WAV and run it through
         ``preprocess_audio_from_tensor``, which applies its RMS gate to
         every clip regardless of label.

    Args:
        label_zip: ZIP archive of label JSON files.
        audio_zip: ZIP archive of conversation WAV files.
        output_dir: Directory that receives ``<label>/<wav>_<text_no>.wav``.
        cap: Maximum number of utterances selected per label.
        max_per_speaker: Maximum number of utterances selected per speaker.

    Returns:
        One manifest entry per saved clip with keys ``path``, ``label``,
        ``label_idx``, ``source``, ``speaker_id``, ``text``,
        ``conversation_id`` and ``intensity``.
    """
    label_zip = Path(label_zip)
    audio_zip = Path(audio_zip)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Step 1: Parse all label JSONs
    all_utterances: list[dict] = []
    with zipfile.ZipFile(label_zip) as lzf:
        json_files = [n for n in lzf.namelist() if n.endswith(".json")]
        log.info("71631: Found %d label JSONs", len(json_files))
        for jf in json_files:
            try:
                with lzf.open(jf) as f:
                    data = json.load(f)
            except Exception as e:
                # Best effort: one unreadable label file must not abort the run.
                log.warning("71631: Failed to parse %s: %s", jf, e)
                continue
            if not isinstance(data, dict):
                log.warning("71631: Unexpected JSON structure in %s", jf)
                continue

            # `or {}` / `or []` also cover keys that are present but null.
            filename = (data.get("File") or {}).get("FileName", "")
            conv_id = filename  # the file name doubles as the conversation id
            spk1_id = (data.get("Speaker1") or {}).get("ID", "")
            spk2_id = (data.get("Speaker2") or {}).get("ID", "")

            for entry in data.get("Conversation") or []:
                parsed = parse_71631_utterance(entry)
                if parsed is None:
                    continue

                # Resolve the per-conversation speaker slot to a speaker ID;
                # an unrecognised slot is used verbatim.
                spk_no = parsed["speaker_no"]
                if spk_no == "Speaker1":
                    spk_id = spk1_id
                elif spk_no == "Speaker2":
                    spk_id = spk2_id
                else:
                    spk_id = spk_no

                parsed["conversation_id"] = conv_id
                parsed["speaker_id"] = f"71631_{spk_id}"
                parsed["wav_filename"] = filename
                parsed["text_no"] = entry.get("TextNo", "")
                all_utterances.append(parsed)
    log.info("71631: Parsed %d utterances (보통/강함)", len(all_utterances))

    # Step 2: Speaker cap. The stable sort puts strong utterances first, so
    # they claim each speaker's quota before medium ones.
    all_utterances.sort(key=lambda x: (0 if x["intensity"] == "강함" else 1))
    speaker_counts: Counter = Counter()
    speaker_capped: list[dict] = []
    for utt in all_utterances:
        spk = utt["speaker_id"]
        if speaker_counts[spk] < max_per_speaker:
            speaker_capped.append(utt)
            speaker_counts[spk] += 1
    log.info("71631: After speaker cap (%d/spk): %d utterances", max_per_speaker, len(speaker_capped))

    # Step 3: Group by label, cap per class
    by_label: dict[str, list[dict]] = defaultdict(list)
    for utt in speaker_capped:
        by_label[utt["label"]].append(utt)

    selected: list[dict] = []
    for label, items in by_label.items():
        # Each list is still in strong-first order from step 2.
        capped = items[:cap]
        selected.extend(capped)
        log.info("71631: %s — %d available, %d selected", label, len(items), len(capped))

    # Step 4: Group by wav filename for efficient audio loading
    by_wav: dict[str, list[dict]] = defaultdict(list)
    for utt in selected:
        by_wav[utt["wav_filename"]].append(utt)

    # Step 5: Extract audio segments
    samples: list[dict] = []
    skipped = 0
    with zipfile.ZipFile(audio_zip) as azf:
        # Archive entries are matched to label file names by file stem.
        wav_lookup: dict[str, str] = {
            Path(name).stem: name
            for name in azf.namelist()
            if name.endswith(".wav")
        }

        for wav_filename, utterances in by_wav.items():
            if wav_filename not in wav_lookup:
                log.warning("71631: WAV not found in zip: %s", wav_filename)
                skipped += len(utterances)
                continue

            zip_entry = wav_lookup[wav_filename]
            try:
                with azf.open(zip_entry) as src:
                    audio_bytes = src.read()
                waveform, sr = torchaudio.load(io.BytesIO(audio_bytes))
            except Exception as e:
                # Best effort: one unreadable WAV must not abort the run.
                log.warning("71631: Failed to load %s: %s", wav_filename, e)
                skipped += len(utterances)
                continue

            # Mix down once per conversation instead of once per segment.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)

            total_samples = waveform.shape[1]
            for utt in utterances:
                # Clamp to the valid range: a negative start would otherwise
                # index from the end of the tensor.
                start_sample = max(0, int(utt["start_time"] * sr))
                end_sample = min(int(utt["end_time"] * sr), total_samples)
                if start_sample >= end_sample:
                    skipped += 1
                    continue

                segment = waveform[:, start_sample:end_sample]
                out_name = f"{wav_filename}_{utt['text_no']}.wav"
                out_path = output_dir / utt["label"] / out_name

                if not preprocess_audio_from_tensor(segment, sr, out_path):
                    skipped += 1
                    continue

                samples.append({
                    "path": str(out_path),
                    "label": utt["label"],
                    "label_idx": LABEL2IDX[utt["label"]],
                    "source": "71631",
                    "speaker_id": utt["speaker_id"],
                    "text": utt.get("text", ""),
                    "conversation_id": utt["conversation_id"],
                    "intensity": utt["intensity"],
                })

    log.info("71631: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
# ---------------------------------------------------------------------------
# Task 4 — RAVDESS Extraction
# ---------------------------------------------------------------------------
def extract_ravdess(ravdess_dir: Path, output_dir: Path) -> list[dict]:
    """Extract and preprocess the RAVDESS dataset listed in ``manifest.csv``.

    Each manifest row supplies ``clean_path`` (source audio), ``emotion`` and
    ``actor_id``. Rows with an unmapped emotion, an invalid actor id, a
    missing source file, or audio that fails or is rejected by
    ``preprocess_audio`` are counted as skipped.

    Args:
        ravdess_dir: Directory containing ``manifest.csv``.
        output_dir: Directory that receives ``<label>/<file name>`` clips.

    Returns:
        One manifest entry per saved clip with keys ``path``, ``label``,
        ``label_idx``, ``source``, ``actor_id``, ``speaker_id`` and ``text``.
        An empty list when the manifest file does not exist.

    Raises:
        KeyError: If the manifest lacks the ``clean_path`` or ``actor_id``
            column.
    """
    ravdess_dir = Path(ravdess_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = ravdess_dir / "manifest.csv"
    if not manifest_path.exists():
        log.error("RAVDESS manifest not found: %s", manifest_path)
        return []

    samples: list[dict] = []
    skipped = 0
    # Explicit encoding avoids depending on the platform default; newline=""
    # is what the csv module expects from file objects.
    with open(manifest_path, encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            clean_raw = row["clean_path"]

            label = map_ravdess_label(row.get("emotion", ""))
            if label is None:
                skipped += 1
                continue

            try:
                actor_id = int(row["actor_id"])
            except (TypeError, ValueError):
                # Blank, truncated or non-numeric cell: skip just this row.
                log.warning("RAVDESS: Invalid actor_id %r, skipping row", row.get("actor_id"))
                skipped += 1
                continue

            if not clean_raw:
                skipped += 1
                continue
            clean_path = Path(clean_raw)
            if not clean_path.exists():
                skipped += 1
                continue

            out_path = output_dir / label / clean_path.name
            try:
                ok = preprocess_audio(clean_path, out_path)
            except Exception as e:
                # Same best-effort policy as the AI Hub 263 extraction: one
                # unreadable file must not abort the whole run.
                log.warning("RAVDESS: Failed to process %s: %s", clean_path, e)
                skipped += 1
                continue
            if not ok:
                skipped += 1
                continue

            samples.append({
                "path": str(out_path),
                "label": label,
                "label_idx": LABEL2IDX[label],
                "source": "ravdess",
                "actor_id": actor_id,
                "speaker_id": f"ravdess_{actor_id}",
                "text": "",  # RAVDESS uses fixed sentences, not useful for text emotion
            })

    log.info("RAVDESS: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
# ---------------------------------------------------------------------------
# Task 4 — Train/Val Splits
# ---------------------------------------------------------------------------
def _group_isolated_split(
    samples: list[dict], group_key: str, val_ratio: float
) -> tuple[list[dict], list[dict]]:
    """Split ``samples`` so that no value of ``group_key`` spans both sides.

    Groups are shuffled with the module-level ``random`` generator (seed it
    beforehand for reproducibility) and moved whole into the validation set
    until it holds at least ``int(len(samples) * val_ratio)`` samples.
    """
    if not samples:
        return [], []

    groups: dict[str, list[dict]] = defaultdict(list)
    for sample in samples:
        groups[sample[group_key]].append(sample)

    group_ids = list(groups)
    random.shuffle(group_ids)

    target_val = int(len(samples) * val_ratio)
    val_samples: list[dict] = []
    val_groups: set[str] = set()
    for group_id in group_ids:
        if len(val_samples) >= target_val:
            break
        val_samples.extend(groups[group_id])
        val_groups.add(group_id)

    # Train keeps the original sample order.
    train_samples = [s for s in samples if s[group_key] not in val_groups]
    return train_samples, val_samples


def speaker_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split into train/val by ``speaker_id`` so no speaker is in both.

    Args:
        samples: Manifest entries, each carrying a ``speaker_id`` key.
        val_ratio: Target fraction of samples for validation. Because whole
            speakers are moved, the validation set can overshoot the target.

    Returns:
        ``(train_samples, val_samples)``; both empty when ``samples`` is empty.
    """
    return _group_isolated_split(samples, "speaker_id", val_ratio)


def conversation_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split into train/val by ``conversation_id`` so no conversation is in both.

    Args:
        samples: Manifest entries, each carrying a ``conversation_id`` key.
        val_ratio: Target fraction of samples for validation. Because whole
            conversations are moved, the validation set can overshoot the
            target.

    Returns:
        ``(train_samples, val_samples)``; both empty when ``samples`` is empty.
    """
    return _group_isolated_split(samples, "conversation_id", val_ratio)
def actor_split_ravdess(
    samples: list[dict], val_actors: list[int]
) -> tuple[list[dict], list[dict]]:
    """Partition RAVDESS samples by actor.

    Args:
        samples: Manifest entries, each carrying an ``actor_id`` key.
        val_actors: Actor ids whose samples go to validation.

    Returns:
        ``(train, val)``, each preserving the input order.
    """
    held_out = frozenset(val_actors)
    train: list[dict] = []
    val: list[dict] = []
    for sample in samples:
        bucket = val if sample["actor_id"] in held_out else train
        bucket.append(sample)
    return train, val
# ---------------------------------------------------------------------------
# Task 4 — Manifest & Stats
# ---------------------------------------------------------------------------
def save_manifest(samples: list[dict], path: Path) -> None:
    """Write the manifest to ``path`` as a single indented JSON array.

    Args:
        samples: Manifest entries to serialise.
        path: Destination file; parent directories are created if needed.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    # ensure_ascii=False emits Korean text verbatim, so the file encoding must
    # be pinned to UTF-8 instead of relying on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2, ensure_ascii=False)
    log.info("Saved manifest: %s (%d samples)", path, len(samples))
def save_stats(train: list[dict], val: list[dict], path: Path) -> None:
    """Write per-split sample counts and the label index map to ``path``.

    The JSON holds, for ``train`` and ``val`` each, the total number of
    samples plus counts by ``label`` and by ``source`` (keys sorted), and the
    ``LABEL2IDX`` mapping under ``label2idx``.

    Args:
        train: Training manifest entries.
        val: Validation manifest entries.
        path: Destination file; parent directories are created if needed.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    def _count_stats(samples: list[dict]) -> dict:
        """Summarise one split: total plus sorted label and source counts."""
        by_label = Counter(s["label"] for s in samples)
        by_source = Counter(s["source"] for s in samples)
        return {
            "total": len(samples),
            "by_label": dict(sorted(by_label.items())),
            "by_source": dict(sorted(by_source.items())),
        }

    stats = {
        "train": _count_stats(train),
        "val": _count_stats(val),
        "label2idx": LABEL2IDX,
    }
    # Pin UTF-8 because ensure_ascii=False writes non-ASCII text verbatim.
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    log.info("Saved stats: %s", path)
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the dataset preparation script."""
    parser = argparse.ArgumentParser(
        description="Prepare unified LoRA 7-class dataset"
    )
    parser.add_argument(
        "--anchor-dir",
        type=Path,
        default=Path("data/AI Hub 263"),
        help="Path to AI Hub 263 directory with CSVs + ZIPs",
    )
    parser.add_argument(
        "--booster-label-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "02.라벨링데이터/TL_02.실외.zip"
        ),
        help="Path to 71631 label ZIP",
    )
    parser.add_argument(
        "--booster-audio-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "01.원천데이터/TS_02.실외.zip"
        ),
        help="Path to 71631 audio ZIP",
    )
    parser.add_argument(
        "--ravdess-dir",
        type=Path,
        default=Path("data/ravdess"),
        help="Path to RAVDESS directory with manifest.csv",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/lora_7class"),
        help="Output directory for processed dataset",
    )
    parser.add_argument("--cap-263", type=int, default=2100, help="Cap per class for 263")
    parser.add_argument("--cap-71631", type=int, default=3500, help="Cap per class for 71631")
    parser.add_argument("--max-per-speaker-71631", type=int, default=20, help="Max utterances per speaker for 71631")
    parser.add_argument("--val-ratio", type=float, default=0.1, help="Val ratio for splits")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--ravdess-val-actors",
        type=int,
        nargs="+",
        default=[21, 22, 23, 24],
        help="RAVDESS actor ids held out for validation",
    )
    parser.add_argument(
        "--skip-263", action="store_true", help="Skip AI Hub 263 extraction"
    )
    parser.add_argument(
        "--skip-71631", action="store_true", help="Skip AI Hub 71631 extraction"
    )
    parser.add_argument(
        "--skip-ravdess", action="store_true", help="Skip RAVDESS extraction"
    )
    return parser


def main(argv: list[str] | None = None) -> None:
    """Run the full pipeline: extract each source, split, merge and save.

    Each enabled source is extracted into its own sub-directory of the output
    directory and split into train/val with a source-specific strategy
    (speaker-isolated for 263, conversation-isolated for 71631, fixed
    held-out actors for RAVDESS). The merged manifests and a statistics file
    are then written to the output directory.

    Args:
        argv: Command-line arguments to parse. None (the default) makes
            argparse read ``sys.argv[1:]``.
    """
    args = _build_arg_parser().parse_args(argv)

    # Seed before any split so the shuffles are reproducible.
    random.seed(args.seed)

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    all_train: list[dict] = []
    all_val: list[dict] = []
    banner = "=" * 60

    # ---- 263 Anchor ----
    if not args.skip_263:
        log.info(banner)
        log.info("Extracting AI Hub 263 (anchor)")
        samples_263 = extract_anchor_263(
            args.anchor_dir,
            output_dir / "263",
            cap_per_class=args.cap_263,
        )
        train_263, val_263 = speaker_isolated_split(samples_263, args.val_ratio)
        log.info("263: train=%d, val=%d", len(train_263), len(val_263))
        all_train.extend(train_263)
        all_val.extend(val_263)

    # ---- 71631 Booster ----
    if not args.skip_71631:
        log.info(banner)
        log.info("Extracting AI Hub 71631 (booster)")
        samples_71631 = extract_booster_71631(
            args.booster_label_zip,
            args.booster_audio_zip,
            output_dir / "71631",
            cap=args.cap_71631,
            max_per_speaker=args.max_per_speaker_71631,
        )
        train_71631, val_71631 = conversation_isolated_split(
            samples_71631, args.val_ratio
        )
        log.info("71631: train=%d, val=%d", len(train_71631), len(val_71631))
        all_train.extend(train_71631)
        all_val.extend(val_71631)

    # ---- RAVDESS ----
    if not args.skip_ravdess:
        log.info(banner)
        log.info("Extracting RAVDESS")
        samples_ravdess = extract_ravdess(
            args.ravdess_dir,
            output_dir / "ravdess",
        )
        train_ravdess, val_ravdess = actor_split_ravdess(
            samples_ravdess, args.ravdess_val_actors
        )
        log.info("RAVDESS: train=%d, val=%d", len(train_ravdess), len(val_ravdess))
        all_train.extend(train_ravdess)
        all_val.extend(val_ravdess)

    # ---- Save ----
    log.info(banner)
    log.info("Total: train=%d, val=%d", len(all_train), len(all_val))
    save_manifest(all_train, output_dir / "train_manifest.json")
    save_manifest(all_val, output_dir / "val_manifest.json")
    save_stats(all_train, all_val, output_dir / "stats.json")
    log.info("Done! Output: %s", output_dir)


if __name__ == "__main__":
    main()