| |
| """Unified data preprocessing pipeline for LoRA emotion2vec 7-class fine-tuning. |
| |
| Extracts, preprocesses, and merges samples from three sources: |
| 1. AI Hub 263 — anchor (acted Korean, 7-class) |
| 2. AI Hub 71631 — booster (outdoor spontaneous Korean, mapped to 7-class) |
| 3. RAVDESS — English (acted, 7-class) |
| |
| Outputs a unified manifest (train/val JSONs) ready for LoRA training. |
| |
| Usage: |
| python scripts/prepare_lora_dataset.py \ |
| --anchor-dir "data/AI Hub 263" \ |
| --booster-label-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/02.라벨링데이터/TL_02.실외.zip" \ |
| --booster-audio-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/01.원천데이터/TS_02.실외.zip" \ |
| --ravdess-dir data/ravdess \ |
| --output-dir data/lora_7class |
| """ |
|
|
from __future__ import annotations

import argparse
import csv
import io
import json
import logging
import random
import zipfile
from collections import Counter, defaultdict
from contextlib import ExitStack
from pathlib import Path

import torch
import torchaudio
|
|
# Root logger configuration: wall-clock time, padded level name, message.
_LOGGING_CONFIG = {
    "level": logging.INFO,
    "format": "%(asctime)s %(levelname)-8s %(message)s",
    "datefmt": "%H:%M:%S",
}
logging.basicConfig(**_LOGGING_CONFIG)

# Module logger used by every stage of the pipeline.
log = logging.getLogger(__name__)
|
|
| |
| |
| |
# Canonical 7-class label set; the position in this tuple is the class index.
LABEL2IDX: dict[str, int] = {
    name: idx
    for idx, name in enumerate(
        (
            "happiness",
            "anger",
            "disgust",
            "fear",
            "neutral",
            "sadness",
            "surprise",
        )
    )
}

# Membership set used to validate already-canonical label strings.
VALID_LABELS = set(LABEL2IDX)

# Every processed clip is written at this sample rate (Hz).
TARGET_SR = 16000
# Clips whose RMS amplitude falls below this are treated as silent and dropped.
RMS_THRESHOLD = 0.001
|
|
| |
| |
| |
|
|
# AI Hub 263 annotator labels (lower-cased) -> canonical 7-class names.
# Only "angry" is renamed; the other six are already canonical.
_MAP_263: dict[str, str] = dict(
    angry="anger",
    happiness="happiness",
    neutral="neutral",
    sadness="sadness",
    surprise="surprise",
    fear="fear",
    disgust="disgust",
)


def map_263_label(raw: str) -> str | None:
    """Translate one AI Hub 263 annotator label to the 7-class scheme.

    Matching ignores case and surrounding whitespace. Empty/None input and
    unknown labels yield None.
    """
    return _MAP_263.get(raw.strip().lower()) if raw else None
|
|
|
|
# AI Hub 71631 Korean emotion labels -> canonical 7-class names.
# Both "없음" (none) and "중립" (neutral) collapse to "neutral".
_MAP_71631: dict[str, str] = dict(
    [
        ("기쁨", "happiness"),
        ("화남", "anger"),
        ("놀라움", "surprise"),
        ("슬픔", "sadness"),
        ("두려움", "fear"),
        ("없음", "neutral"),
        ("중립", "neutral"),
    ]
)


def map_71631_label(raw: str) -> str | None:
    """Translate an AI Hub 71631 Korean emotion label to the 7-class scheme.

    Surrounding whitespace is ignored. Empty/None input and unmapped labels
    yield None.
    """
    if raw:
        return _MAP_71631.get(raw.strip())
    return None
|
|
|
|
def map_ravdess_label(raw: str) -> str | None:
    """Translate a RAVDESS emotion label to the 7-class scheme.

    Labels are compared case-insensitively after stripping whitespace.
    "joy" becomes "happiness"; any other canonical label passes through
    unchanged. Anything else, including empty/None input, yields None.
    """
    if not raw:
        return None
    normalized = raw.strip().lower()
    if normalized == "joy":
        return "happiness"
    return normalized if normalized in VALID_LABELS else None
|
|
|
|
def majority_vote_263(emotions: list[str | None]) -> str | None:
    """Pick the label a clear majority of annotators agreed on.

    None entries (unmapped labels) are ignored. The winner needs at least
    3 votes and must not share its vote count with any other label;
    otherwise None is returned.
    """
    tally = Counter(e for e in emotions if e is not None)
    if not tally:
        return None

    ranked = tally.most_common()
    winner, votes = ranked[0]
    if votes < 3:
        return None
    # A runner-up with the same count means there is no unique winner.
    if len(ranked) > 1 and ranked[1][1] == votes:
        return None
    return winner
|
|
|
|
| |
| |
| |
|
|
|
|
| import re |
|
|
# Bracketed non-verbal annotations such as "(웃음)" or "[noise]". Full-width
# parentheses "（…）" are matched explicitly as well. Compiled once because
# _clean_text runs for every utterance.
_TAG_RE = re.compile(r"[(\[（][^)\]）]*[)\]）]")
_WS_RE = re.compile(r"\s+")


def _clean_text(text: str | None) -> str:
    """Clean a transcript for STT-friendly use.

    Removes bracketed non-verbal tags like (웃음) or (한숨) and collapses runs
    of whitespace into single spaces. Punctuation such as ?, ! and ... is
    kept.

    Args:
        text: Raw transcript. None or empty input (e.g. a JSON ``"Text": null``)
            yields "" instead of raising TypeError inside ``re.sub``.

    Returns:
        The cleaned, stripped transcript.
    """
    if not text:
        return ""
    text = _TAG_RE.sub("", text)
    return _WS_RE.sub(" ", text).strip()
|
|
|
|
def _compute_rms(waveform: torch.Tensor) -> float:
    """Return the root-mean-square amplitude of *waveform*.

    Integer tensors are promoted to float first. An empty tensor returns 0.0.
    ``torch.mean`` of an empty tensor is NaN, and NaN compares False against
    any threshold, so an empty clip would otherwise pass the silence filter.
    """
    if waveform.numel() == 0:
        return 0.0
    return float(torch.sqrt(torch.mean(waveform.float() ** 2)))
|
|
|
|
def preprocess_audio(input_path: Path, output_path: Path) -> bool:
    """Load an audio file, resample to 16 kHz mono, trim silence, and save it.

    This loads *input_path* with ``torchaudio.load`` and hands the waveform to
    :func:`preprocess_audio_from_tensor`. Files on disk and in-memory segments
    therefore share one pipeline, instead of two hand-synchronised copies.

    Args:
        input_path: Source audio file.
        output_path: Destination wav; parent directories are created.

    Returns:
        True if the file was saved, False if it was rejected as too quiet.
    """
    waveform, sr = torchaudio.load(str(Path(input_path)))
    return preprocess_audio_from_tensor(waveform, sr, output_path)
|
|
|
|
def preprocess_audio_from_tensor(
    waveform: torch.Tensor, sr: int, output_path: Path
) -> bool:
    """Preprocess an in-memory waveform tensor and save it to *output_path*.

    Pipeline:
    1. Down-mix to mono.
    2. Resample to ``TARGET_SR``.
    3. Trim leading silence with ``torchaudio.functional.vad``.
    4. Reject quiet clips.
    5. Save.

    Used for 71631, where segments are sliced in memory from the full
    conversation wav.

    Args:
        waveform: 1-D ``(time,)`` or 2-D ``(channels, time)`` tensor.
        sr: Sample rate of *waveform* in Hz.
        output_path: Destination wav; parent directories are created.

    Returns:
        True if the file was saved. False if the clip was empty, or if its
        RMS was below ``RMS_THRESHOLD`` or NaN.
    """
    output_path = Path(output_path)

    # Nothing to process. An empty clip would also produce a NaN RMS, which
    # slips past a plain `<` threshold comparison.
    if waveform.numel() == 0:
        return False

    # Normalise to a single-channel (1, time) tensor.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)

    # torchaudio's vad trims silence from the front of the clip only. If it
    # would remove everything, keep the untrimmed audio instead.
    trimmed = torchaudio.functional.vad(waveform, TARGET_SR)
    if trimmed.numel() > 0:
        waveform = trimmed

    # Written as `not (>=)` so that a NaN RMS (e.g. NaN samples) is rejected too.
    rms = _compute_rms(waveform)
    if not (rms >= RMS_THRESHOLD):
        return False

    output_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output_path), waveform, TARGET_SR)
    return True
|
|
|
|
| |
| |
| |
|
|
|
|
def parse_263_row(row: list[str]) -> dict | None:
    """Turn one AI Hub 263 CSV row into a labelled record.

    Expected columns (15):
    - ``wav_id``, ``발화문`` (utterance), ``상황`` (situation);
    - five (emotion, intensity) pairs, ``1번감정``/``1번감정세기`` …
      ``5번감정``/``5번감정세기``;
    - ``나이`` (age), ``성별`` (gender).

    Args:
        row: The raw CSV fields.

    Returns:
        A dict with ``wav_id``, ``text``, ``label``, ``agreement`` and
        ``max_intensity``. Returns None if the row is short or has no clear
        majority label.
    """
    if len(row) < 15:
        return None

    wav_id = row[0].strip()

    # Annotator k (0..4) uses column 3 + 2k for emotion and 4 + 2k for
    # intensity. The length check above guarantees these columns exist.
    votes: list[str | None] = []
    strengths: list[int] = []
    for emo_col in range(3, 13, 2):
        votes.append(map_263_label(row[emo_col].strip()))
        try:
            strengths.append(int(row[emo_col + 1].strip()))
        except ValueError:
            # Unparseable intensity counts as 0.
            strengths.append(0)

    label = majority_vote_263(votes)
    if label is None:
        return None

    # Intensities given by the annotators who voted for the winning label.
    matching = [s for v, s in zip(votes, strengths) if v == label]

    return {
        "wav_id": wav_id,
        "text": _clean_text(row[1].strip()),
        "label": label,
        "agreement": len(matching),
        "max_intensity": max(matching, default=0),
    }
|
|
|
|
def extract_anchor_263(
    anchor_dir: Path,
    output_dir: Path,
    cap_per_class: int = 2100,
) -> list[dict]:
    """Extract and preprocess the AI Hub 263 (anchor) dataset.

    Steps:
      1. Parse every ``*.csv`` in *anchor_dir* (cp949, first row is a header)
         with :func:`parse_263_row`, keeping rows with a clear majority label.
      2. Per label, sort by (agreement, max_intensity) descending and keep at
         most *cap_per_class* rows.
      3. Find each selected ``<wav_id>.wav`` inside the ``*.zip`` archives,
         decode it in memory, and run it through
         :func:`preprocess_audio_from_tensor`. Output goes to
         ``output_dir/<label>/<wav_id>.wav``.

    Args:
        anchor_dir: Directory holding the 263 CSVs and audio ZIPs.
        output_dir: Root directory for the processed wavs.
        cap_per_class: Maximum number of rows kept per label.

    Returns:
        One manifest dict per saved wav. Rows whose wav is missing, fails to
        decode, or is rejected by preprocessing are counted as skipped.
    """
    anchor_dir = Path(anchor_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    csv_files = sorted(anchor_dir.glob("*.csv"))
    zip_files = sorted(anchor_dir.glob("*.zip"))
    log.info("263: Found %d CSVs, %d ZIPs", len(csv_files), len(zip_files))

    # --- 1. Parse CSVs and majority-vote the annotator labels -------------
    all_parsed: list[dict] = []
    for csv_path in csv_files:
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            # Skip the header row. next(..., None) tolerates an empty file
            # instead of raising StopIteration.
            if next(reader, None) is None:
                log.warning("263: Empty CSV skipped: %s", csv_path)
                continue
            for row in reader:
                result = parse_263_row(row)
                if result is not None:
                    result["csv_source"] = csv_path.stem
                    all_parsed.append(result)
    log.info("263: Parsed %d rows with majority vote", len(all_parsed))

    # --- 2. Priority-sort and cap per class -------------------------------
    by_label: dict[str, list[dict]] = defaultdict(list)
    for item in all_parsed:
        by_label[item["label"]].append(item)

    selected: list[dict] = []
    for label, items in by_label.items():
        items.sort(key=lambda x: (x["agreement"], x["max_intensity"]), reverse=True)
        capped = items[:cap_per_class]
        selected.extend(capped)
        log.info("263: %s — %d available, %d selected", label, len(items), len(capped))

    # --- 3. Decode the selected wavs from the ZIPs and preprocess ---------
    samples: list[dict] = []
    skipped = 0
    # ExitStack closes every archive even if an exception escapes this block.
    # A trailing manual close loop would be skipped in that case.
    with ExitStack() as stack:
        wav_to_zip: dict[str, tuple[zipfile.ZipFile, str]] = {}
        for zp in zip_files:
            zf = stack.enter_context(zipfile.ZipFile(zp))
            for name in zf.namelist():
                if name.endswith(".wav"):
                    # If the same stem appears in several archives, the later one wins.
                    wav_to_zip[Path(name).stem] = (zf, name)

        for item in selected:
            wav_id = item["wav_id"]
            located = wav_to_zip.get(wav_id)
            if located is None:
                skipped += 1
                continue
            zf, zip_entry = located
            out_path = output_dir / item["label"] / f"{wav_id}.wav"

            try:
                # Decode in memory, as the 71631 path does. A temp file would
                # be left behind in output_dir whenever preprocessing raised.
                with zf.open(zip_entry) as src:
                    waveform, sr = torchaudio.load(io.BytesIO(src.read()))
                ok = preprocess_audio_from_tensor(waveform, sr, out_path)
            except Exception as e:
                log.warning("263: Failed to process %s: %s", wav_id, e)
                skipped += 1
                continue

            if not ok:
                skipped += 1
                continue

            samples.append({
                "path": str(out_path),
                "label": item["label"],
                "label_idx": LABEL2IDX[item["label"]],
                "source": "263",
                # NOTE(review): the first 8 chars of wav_id serve as the
                # speaker key for speaker-isolated splitting. The CSV has no
                # speaker column, so confirm this prefix identifies a speaker.
                "speaker_id": f"263_{wav_id[:8]}",
                "text": item.get("text", ""),
                "agreement": item["agreement"],
                "intensity": item["max_intensity"],
            })

    log.info("263: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
|
|
|
|
| |
| |
| |
|
|
|
|
def parse_71631_utterance(conv_entry: dict) -> dict | None:
    """Parse one ``Conversation`` entry from an AI Hub 71631 label JSON.

    Keeps only entries whose ``VerifyEmotionLevel`` is 보통 (medium) or
    강함 (strong), so 약함 (weak) is rejected. The ``VerifyEmotionTarget``
    must map to the 7-class scheme, and the entry must have a usable time
    span.

    Args:
        conv_entry: A single utterance dict from the JSON ``Conversation`` list.

    Returns:
        A dict with ``label``, ``intensity``, ``start_time``, ``end_time``,
        ``speaker_no`` and ``text``, or None if the entry is filtered out.
    """
    level = conv_entry.get("VerifyEmotionLevel", "")
    if level not in ("보통", "강함"):
        return None

    label = map_71631_label(conv_entry.get("VerifyEmotionTarget", ""))
    if label is None:
        return None

    try:
        start_time = float(conv_entry["StartTime"])
        end_time = float(conv_entry["EndTime"])
    except (KeyError, TypeError, ValueError):
        # A TypeError covers explicit JSON nulls, e.g. "StartTime": null.
        return None

    # The chained comparison also rejects NaN (all comparisons are False),
    # negative starts (which would slice from the end of the wav), and inf
    # (which int() cannot convert).
    if not (0.0 <= start_time < end_time < float("inf")):
        return None

    return {
        "label": label,
        "intensity": level,
        "start_time": start_time,
        "end_time": end_time,
        "speaker_no": conv_entry.get("SpeakerNo", ""),
        # `or ""` turns a JSON null transcript into an empty string.
        "text": _clean_text(conv_entry.get("Text") or ""),
    }
|
|
|
|
def _load_71631_utterances(label_zip: Path) -> list[dict]:
    """Return annotated utterances from every label JSON in *label_zip*.

    Only utterances accepted by :func:`parse_71631_utterance` are kept. Each
    one is annotated with its conversation, speaker and wav information.
    """
    utterances: list[dict] = []
    with zipfile.ZipFile(label_zip) as lzf:
        json_files = [n for n in lzf.namelist() if n.endswith(".json")]
        log.info("71631: Found %d label JSONs", len(json_files))

        for jf in json_files:
            try:
                with lzf.open(jf) as f:
                    data = json.load(f)
            except Exception as e:
                log.warning("71631: Failed to parse %s: %s", jf, e)
                continue
            if not isinstance(data, dict):
                log.warning("71631: Unexpected JSON layout in %s", jf)
                continue

            # `or {}` also guards against explicit nulls, which
            # `.get(key, {})` would pass through.
            filename = (data.get("File") or {}).get("FileName", "")
            if not filename:
                # Without a wav name none of these utterances can be sliced.
                # Drop them now so they do not use up speaker/class cap slots.
                log.warning("71631: No File.FileName in %s", jf)
                continue
            speaker_ids = {
                "Speaker1": (data.get("Speaker1") or {}).get("ID", ""),
                "Speaker2": (data.get("Speaker2") or {}).get("ID", ""),
            }

            for idx, entry in enumerate(data.get("Conversation") or []):
                parsed = parse_71631_utterance(entry)
                if parsed is None:
                    continue

                # Speaker1/Speaker2 resolve to the JSON's speaker IDs; any
                # other SpeakerNo value is used verbatim.
                spk_no = parsed["speaker_no"]
                spk_id = speaker_ids.get(spk_no, spk_no)
                if not spk_id:
                    # A missing ID would pool all such utterances into a single
                    # "71631_" speaker that hits the speaker cap almost
                    # immediately. Scope the ID to this conversation instead.
                    spk_id = f"{filename}_{spk_no}"

                # TextNo names the output file. Fall back to the entry index so
                # utterances without one do not overwrite each other.
                text_no = entry.get("TextNo", "")
                if text_no in ("", None):
                    text_no = f"idx{idx}"

                parsed["conversation_id"] = filename
                parsed["speaker_id"] = f"71631_{spk_id}"
                parsed["wav_filename"] = filename
                parsed["text_no"] = text_no
                utterances.append(parsed)
    return utterances


def _select_71631(
    utterances: list[dict], cap: int, max_per_speaker: int
) -> list[dict]:
    """Apply the per-speaker cap (강함 first), then the per-class cap."""
    # Stable sort: 강함 (strong) comes before 보통, otherwise file order is
    # kept. The speaker cap therefore keeps each speaker's strongest utterances.
    ordered = sorted(utterances, key=lambda u: 0 if u["intensity"] == "강함" else 1)

    per_speaker: Counter = Counter()
    speaker_capped: list[dict] = []
    for utt in ordered:
        spk = utt["speaker_id"]
        if per_speaker[spk] < max_per_speaker:
            speaker_capped.append(utt)
            per_speaker[spk] += 1
    log.info(
        "71631: After speaker cap (%d/spk): %d utterances",
        max_per_speaker,
        len(speaker_capped),
    )

    by_label: dict[str, list[dict]] = defaultdict(list)
    for utt in speaker_capped:
        by_label[utt["label"]].append(utt)

    selected: list[dict] = []
    for label, items in by_label.items():
        # Items are still in 강함-first order, so the cap favours strong ones.
        capped = items[:cap]
        selected.extend(capped)
        log.info("71631: %s — %d available, %d selected", label, len(items), len(capped))
    return selected


def _slice_71631_audio(
    selected: list[dict], audio_zip: Path, output_dir: Path
) -> tuple[list[dict], int]:
    """Cut each selected utterance from its conversation wav, then preprocess and save it.

    Returns:
        ``(manifest samples, number of utterances skipped)``.
    """
    # Group by source wav so each conversation is decoded only once.
    by_wav: dict[str, list[dict]] = defaultdict(list)
    for utt in selected:
        by_wav[utt["wav_filename"]].append(utt)

    samples: list[dict] = []
    skipped = 0
    with zipfile.ZipFile(audio_zip) as azf:
        # Stem -> archive member name. If a stem repeats, the later entry wins.
        wav_lookup = {Path(n).stem: n for n in azf.namelist() if n.endswith(".wav")}

        for wav_filename, utterances in by_wav.items():
            zip_entry = wav_lookup.get(wav_filename)
            if zip_entry is None:
                log.warning("71631: WAV not found in zip: %s", wav_filename)
                skipped += len(utterances)
                continue

            try:
                with azf.open(zip_entry) as src:
                    waveform, sr = torchaudio.load(io.BytesIO(src.read()))
            except Exception as e:
                log.warning("71631: Failed to load %s: %s", wav_filename, e)
                skipped += len(utterances)
                continue

            # Down-mix once per conversation rather than once per segment.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            n_frames = waveform.shape[1]

            for utt in utterances:
                start_sample = int(utt["start_time"] * sr)
                end_sample = min(int(utt["end_time"] * sr), n_frames)
                if start_sample >= end_sample:
                    skipped += 1
                    continue

                segment = waveform[:, start_sample:end_sample]
                out_name = f"{wav_filename}_{utt['text_no']}.wav"
                out_path = output_dir / utt["label"] / out_name

                try:
                    ok = preprocess_audio_from_tensor(segment, sr, out_path)
                except Exception as e:
                    # One bad segment must not abort the whole extraction.
                    # This matches the 263 extractor's per-item handling.
                    log.warning("71631: Failed to process %s: %s", out_name, e)
                    ok = False
                if not ok:
                    skipped += 1
                    continue

                samples.append({
                    "path": str(out_path),
                    "label": utt["label"],
                    "label_idx": LABEL2IDX[utt["label"]],
                    "source": "71631",
                    "speaker_id": utt["speaker_id"],
                    "text": utt.get("text", ""),
                    "conversation_id": utt["conversation_id"],
                    "intensity": utt["intensity"],
                })
    return samples, skipped


def extract_booster_71631(
    label_zip: Path,
    audio_zip: Path,
    output_dir: Path,
    cap: int = 3500,
    max_per_speaker: int = 20,
) -> list[dict]:
    """Extract and preprocess the AI Hub 71631 outdoor (booster) dataset.

    Steps:
    1. Parse the label ZIP's JSONs, keeping only 보통/강함 intensity.
    2. Cap each speaker at *max_per_speaker* utterances, preferring 강함.
    3. Cap each class at *cap*.
    4. Slice the segments out of the audio ZIP's conversation wavs and pass
       them to :func:`preprocess_audio_from_tensor`. That function resamples
       to 16 kHz, trims, and applies the RMS filter to every clip, whatever
       its label.

    Args:
        label_zip: ZIP of 71631 label JSONs.
        audio_zip: ZIP of the matching conversation wavs.
        output_dir: Root for ``<label>/<wav>_<TextNo>.wav`` outputs.
        cap: Maximum number of utterances kept per label.
        max_per_speaker: Maximum number of utterances kept per speaker.

    Returns:
        One manifest dict per saved segment.
    """
    label_zip = Path(label_zip)
    audio_zip = Path(audio_zip)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    utterances = _load_71631_utterances(label_zip)
    log.info("71631: Parsed %d utterances (보통/강함)", len(utterances))

    selected = _select_71631(utterances, cap, max_per_speaker)
    samples, skipped = _slice_71631_audio(selected, audio_zip, output_dir)

    log.info("71631: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
|
|
|
|
| |
| |
| |
|
|
|
|
def extract_ravdess(ravdess_dir: Path, output_dir: Path) -> list[dict]:
    """Extract and preprocess RAVDESS clips listed in ``manifest.csv``.

    Manifest rows need ``clean_path``, ``emotion`` and ``actor_id`` columns.
    A row is counted as skipped when:
    - its emotion is unmapped;
    - its fields are malformed;
    - its audio file is missing;
    - preprocessing raises; or
    - the clip is rejected as too quiet.

    Args:
        ravdess_dir: Directory containing ``manifest.csv``.
        output_dir: Root for ``<label>/<file name>`` outputs.

    Returns:
        One manifest dict per saved clip. Returns [] if the manifest is missing.
    """
    ravdess_dir = Path(ravdess_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = ravdess_dir / "manifest.csv"
    if not manifest_path.exists():
        log.error("RAVDESS manifest not found: %s", manifest_path)
        return []

    samples: list[dict] = []
    skipped = 0

    # Explicit encoding, so parsing does not depend on the platform locale.
    with open(manifest_path, encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            label = map_ravdess_label(row.get("emotion", ""))
            if label is None:
                skipped += 1
                continue

            try:
                # NOTE(review): relative paths resolve against the CWD, not
                # ravdess_dir. Confirm how the manifest is written.
                clean_path = Path(row["clean_path"])
                actor_id = int(row["actor_id"])
            except (KeyError, TypeError, ValueError) as e:
                log.warning("RAVDESS: Malformed manifest row %s: %s", row, e)
                skipped += 1
                continue

            if not clean_path.exists():
                skipped += 1
                continue

            out_path = output_dir / label / clean_path.name
            try:
                ok = preprocess_audio(clean_path, out_path)
            except Exception as e:
                # A single corrupt file must not abort the whole pipeline.
                log.warning("RAVDESS: Failed to process %s: %s", clean_path, e)
                ok = False
            if not ok:
                skipped += 1
                continue

            samples.append({
                "path": str(out_path),
                "label": label,
                "label_idx": LABEL2IDX[label],
                "source": "ravdess",
                "actor_id": actor_id,
                "speaker_id": f"ravdess_{actor_id}",
                "text": "",
            })

    log.info("RAVDESS: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
|
|
|
|
| |
| |
| |
|
|
|
|
def _group_isolated_split(
    samples: list[dict], key: str, val_ratio: float
) -> tuple[list[dict], list[dict]]:
    """Split *samples* so that no value of ``sample[key]`` appears in both halves.

    Groups are visited in a random order (global ``random`` state) and moved
    whole into val until val holds at least ``int(len(samples) * val_ratio)``
    samples. Because whole groups move at once, val may overshoot the target.
    Train keeps the input order.
    """
    if not samples:
        return [], []

    groups: dict[str, list[dict]] = defaultdict(list)
    for sample in samples:
        groups[sample[key]].append(sample)

    group_order = list(groups)
    random.shuffle(group_order)

    target_val = int(len(samples) * val_ratio)
    val_samples: list[dict] = []
    val_groups: set[str] = set()
    for group in group_order:
        if len(val_samples) >= target_val:
            break
        val_samples.extend(groups[group])
        val_groups.add(group)

    train_samples = [s for s in samples if s[key] not in val_groups]
    return train_samples, val_samples


def speaker_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split by speaker_id — no leakage between train/val."""
    return _group_isolated_split(samples, "speaker_id", val_ratio)


def conversation_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split by conversation_id — no leakage between train/val."""
    return _group_isolated_split(samples, "conversation_id", val_ratio)
|
|
|
|
def actor_split_ravdess(
    samples: list[dict], val_actors: list[int]
) -> tuple[list[dict], list[dict]]:
    """Partition RAVDESS samples by ``actor_id``.

    Samples whose actor is listed in *val_actors* go to val; the rest go to
    train. Both lists keep the input order.
    """
    held_out = frozenset(val_actors)
    train: list[dict] = []
    val: list[dict] = []
    for sample in samples:
        (val if sample["actor_id"] in held_out else train).append(sample)
    return train, val
|
|
|
|
| |
| |
| |
|
|
|
|
def save_manifest(samples: list[dict], path: Path) -> None:
    """Write *samples* to *path* as a single pretty-printed JSON array.

    The output is one JSON document that ``json.load`` can read; it is not
    JSON Lines. Parent directories are created. ``ensure_ascii=False`` keeps
    Korean text readable, so the file is opened as UTF-8 explicitly. The
    platform default encoding may not be UTF-8 and may fail on, or garble,
    non-ASCII text.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2, ensure_ascii=False)
    log.info("Saved manifest: %s (%d samples)", path, len(samples))
|
|
|
|
def save_stats(train: list[dict], val: list[dict], path: Path) -> None:
    """Write per-split label/source counts and the label mapping to *path*.

    The JSON has ``train`` and ``val`` sections, each with ``total``,
    ``by_label`` and ``by_source`` (keys sorted), plus ``label2idx``.
    The file is opened as UTF-8 explicitly, because ``ensure_ascii=False``
    output must not depend on the platform default encoding.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    def _count_stats(samples: list[dict]) -> dict:
        by_label: Counter = Counter(s["label"] for s in samples)
        by_source: Counter = Counter(s["source"] for s in samples)
        return {
            "total": len(samples),
            "by_label": dict(sorted(by_label.items())),
            "by_source": dict(sorted(by_source.items())),
        }

    stats = {
        "train": _count_stats(train),
        "val": _count_stats(val),
        "label2idx": LABEL2IDX,
    }

    with open(path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    log.info("Saved stats: %s", path)
|
|
|
|
| |
| |
| |
|
|
|
|
def main() -> None:
    """Command-line entry point: extract, split, and write the unified dataset.

    Each enabled source is extracted into ``<output-dir>/<source>/`` and split
    on its own isolation key: speaker for 263, conversation for 71631, and
    actor for RAVDESS. The per-source train/val lists are concatenated and
    written under ``--output-dir`` as ``train_manifest.json``,
    ``val_manifest.json`` and ``stats.json``.
    """
    parser = argparse.ArgumentParser(
        description="Prepare unified LoRA 7-class dataset"
    )
    parser.add_argument(
        "--anchor-dir",
        type=Path,
        default=Path("data/AI Hub 263"),
        help="Path to AI Hub 263 directory with CSVs + ZIPs",
    )
    parser.add_argument(
        "--booster-label-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "02.라벨링데이터/TL_02.실외.zip"
        ),
        help="Path to 71631 label ZIP",
    )
    parser.add_argument(
        "--booster-audio-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "01.원천데이터/TS_02.실외.zip"
        ),
        help="Path to 71631 audio ZIP",
    )
    parser.add_argument(
        "--ravdess-dir",
        type=Path,
        default=Path("data/ravdess"),
        help="Path to RAVDESS directory with manifest.csv",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/lora_7class"),
        help="Output directory for processed dataset",
    )
    parser.add_argument("--cap-263", type=int, default=2100, help="Cap per class for 263")
    parser.add_argument("--cap-71631", type=int, default=3500, help="Cap per class for 71631")
    parser.add_argument("--max-per-speaker-71631", type=int, default=20, help="Max utterances per speaker for 71631")
    parser.add_argument("--val-ratio", type=float, default=0.1, help="Val ratio for splits")
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--ravdess-val-actors",
        type=int,
        nargs="+",
        default=[21, 22, 23, 24],
        help="RAVDESS actor IDs held out for validation (default: 21 22 23 24)",
    )
    parser.add_argument(
        "--skip-263", action="store_true", help="Skip AI Hub 263 extraction"
    )
    parser.add_argument(
        "--skip-71631", action="store_true", help="Skip AI Hub 71631 extraction"
    )
    parser.add_argument(
        "--skip-ravdess", action="store_true", help="Skip RAVDESS extraction"
    )
    args = parser.parse_args()

    if not 0.0 <= args.val_ratio < 1.0:
        parser.error(f"--val-ratio must be in [0, 1), got {args.val_ratio}")

    random.seed(args.seed)

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    all_train: list[dict] = []
    all_val: list[dict] = []

    # Each split is preceded by random.seed(args.seed). A source's split then
    # depends only on --seed and its own samples, not on which --skip-* flags
    # made earlier splits consume the global RNG.

    if not args.skip_263:
        log.info("=" * 60)
        log.info("Extracting AI Hub 263 (anchor)")
        samples_263 = extract_anchor_263(
            args.anchor_dir,
            output_dir / "263",
            cap_per_class=args.cap_263,
        )
        random.seed(args.seed)
        train_263, val_263 = speaker_isolated_split(samples_263, args.val_ratio)
        log.info("263: train=%d, val=%d", len(train_263), len(val_263))
        all_train.extend(train_263)
        all_val.extend(val_263)

    if not args.skip_71631:
        log.info("=" * 60)
        log.info("Extracting AI Hub 71631 (booster)")
        samples_71631 = extract_booster_71631(
            args.booster_label_zip,
            args.booster_audio_zip,
            output_dir / "71631",
            cap=args.cap_71631,
            max_per_speaker=args.max_per_speaker_71631,
        )
        random.seed(args.seed)
        train_71631, val_71631 = conversation_isolated_split(
            samples_71631, args.val_ratio
        )
        log.info("71631: train=%d, val=%d", len(train_71631), len(val_71631))
        all_train.extend(train_71631)
        all_val.extend(val_71631)

    if not args.skip_ravdess:
        log.info("=" * 60)
        log.info("Extracting RAVDESS")
        samples_ravdess = extract_ravdess(
            args.ravdess_dir,
            output_dir / "ravdess",
        )
        # The actor split is deterministic, so no reseed is needed here.
        train_ravdess, val_ravdess = actor_split_ravdess(
            samples_ravdess, args.ravdess_val_actors
        )
        log.info("RAVDESS: train=%d, val=%d", len(train_ravdess), len(val_ravdess))
        all_train.extend(train_ravdess)
        all_val.extend(val_ravdess)

    log.info("=" * 60)
    log.info("Total: train=%d, val=%d", len(all_train), len(all_val))

    save_manifest(all_train, output_dir / "train_manifest.json")
    save_manifest(all_val, output_dir / "val_manifest.json")
    save_stats(all_train, all_val, output_dir / "stats.json")

    log.info("Done! Output: %s", output_dir)


if __name__ == "__main__":
    main()
|
|