#!/usr/bin/env python3
"""Unified data preprocessing pipeline for LoRA emotion2vec 7-class fine-tuning.

Extracts, preprocesses, and merges samples from three sources:
    1. AI Hub 263   — anchor (acted Korean, 7-class)
    2. AI Hub 71631 — booster (outdoor spontaneous Korean, mapped to 7-class)
    3. RAVDESS      — English (acted, 7-class)

Outputs a unified manifest (train/val JSON files) ready for LoRA training.

Usage:
    python scripts/prepare_lora_dataset.py \\
        --anchor-dir "data/AI Hub 263" \\
        --booster-label-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/02.라벨링데이터/TL_02.실외.zip" \\
        --booster-audio-zip "data/AI Hub 71631/01-1.정식개방데이터/Training/01.원천데이터/TS_02.실외.zip" \\
        --ravdess-dir data/ravdess \\
        --output-dir data/lora_7class
"""
from __future__ import annotations

import argparse
import contextlib
import csv
import io
import json
import logging
import random
import re
import zipfile
from collections import Counter, defaultdict
from pathlib import Path

import torch
import torchaudio

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)-8s %(message)s",
    datefmt="%H:%M:%S",
)
log = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# Constants
# ---------------------------------------------------------------------------

LABEL2IDX: dict[str, int] = {
    "happiness": 0,
    "anger": 1,
    "disgust": 2,
    "fear": 3,
    "neutral": 4,
    "sadness": 5,
    "surprise": 6,
}
VALID_LABELS = set(LABEL2IDX.keys())

TARGET_SR = 16_000
# Lowered from 0.005 to 0.001 so low-energy utterances (e.g. disgust) are kept;
# only effectively silent clips are rejected.
RMS_THRESHOLD = 0.001

# ---------------------------------------------------------------------------
# Task 1 — Label Mappers
# ---------------------------------------------------------------------------

# AI Hub 263 annotator labels (lower-cased) -> unified 7-class labels.
_MAP_263: dict[str, str] = {
    "angry": "anger",
    "happiness": "happiness",
    "neutral": "neutral",
    "sadness": "sadness",
    "surprise": "surprise",
    "fear": "fear",
    "disgust": "disgust",
}


def map_263_label(raw: str) -> str | None:
    """Map an AI Hub 263 annotator label to the unified 7-class label.

    Matching is case-insensitive and ignores surrounding whitespace.
    Returns None for empty or unknown labels.
    """
    if not raw:
        return None
    return _MAP_263.get(raw.strip().lower())


# AI Hub 71631 Korean emotion labels -> unified 7-class labels.
# 기쁨=joy, 화남=anger, 놀라움=surprise, 슬픔=sadness, 두려움=fear,
# 없음 (none) and 중립 (neutral) both map to neutral. No entry maps to disgust.
_MAP_71631: dict[str, str] = {
    "기쁨": "happiness",
    "화남": "anger",
    "놀라움": "surprise",
    "슬픔": "sadness",
    "두려움": "fear",
    "없음": "neutral",
    "중립": "neutral",
}


def map_71631_label(raw: str) -> str | None:
    """Map an AI Hub 71631 Korean emotion label to the unified 7-class label.

    Surrounding whitespace is ignored. Returns None for empty or unknown labels.
    """
    if not raw:
        return None
    return _MAP_71631.get(raw.strip())


def map_ravdess_label(raw: str) -> str | None:
    """Map a RAVDESS emotion label to the unified 7-class label.

    Labels already in the 7-class set pass through (case-insensitive);
    "joy" becomes "happiness". Anything else returns None.
    """
    if not raw:
        return None
    lbl = raw.strip().lower()
    if lbl == "joy":
        return "happiness"
    if lbl in VALID_LABELS:
        return lbl
    return None


def majority_vote_263(
    emotions: list[str | None], min_votes: int = 3
) -> str | None:
    """Return the majority label among annotator votes.

    Args:
        emotions: Mapped annotator labels; None entries (empty or unmapped
            annotations) are ignored.
        min_votes: Minimum number of identical votes required. The default
            of 3 means 3 of 5 annotators.

    Returns:
        The winning label, or None if no label reaches ``min_votes`` or the
        top vote count is tied.
    """
    counts = Counter(e for e in emotions if e is not None)
    if not counts:
        return None
    ranked = counts.most_common(2)
    top_label, top_count = ranked[0]
    if top_count < min_votes:
        return None
    # A second label with the same count means there is no unique majority.
    if len(ranked) > 1 and ranked[1][1] == top_count:
        return None
    return top_label


# ---------------------------------------------------------------------------
# Task 1 — Audio Preprocessing
# ---------------------------------------------------------------------------

# Non-verbal annotation tags such as (웃음) "laughter", (한숨) "sigh",
# (침묵) "silence" or [noise]. ASCII and full-width parentheses and square
# brackets are all recognised.
_NONVERBAL_TAG_RE = re.compile(r"[(\[（][^)\]）]*[)\]）]")
_WHITESPACE_RE = re.compile(r"\s+")


def _clean_text(text: str | None) -> str:
    """Normalize a transcript into an STT-friendly string.

    Removes bracketed non-verbal tags such as (웃음) or [noise], collapses
    runs of whitespace into single spaces and strips both ends. Punctuation
    such as ?, ! and ... is kept. None or empty input yields "".
    """
    if not text:
        return ""
    text = _NONVERBAL_TAG_RE.sub("", text)
    return _WHITESPACE_RE.sub(" ", text).strip()


def _compute_rms(waveform: torch.Tensor) -> float:
    """Return the root-mean-square amplitude of ``waveform``.

    Returns 0.0 for an empty tensor; the mean of an empty tensor would be NaN,
    and NaN never compares below the rejection threshold.
    """
    if waveform.numel() == 0:
        return 0.0
    return float(torch.sqrt(torch.mean(waveform.float() ** 2)))


def _trim_silence(waveform: torch.Tensor) -> torch.Tensor:
    """Trim leading and trailing silence from a waveform sampled at TARGET_SR.

    ``torchaudio.functional.vad`` only trims the front of the signal. The tail
    is trimmed by flipping the time axis, running VAD again and flipping back.
    If a VAD pass would remove everything, the last non-empty result is kept
    so the RMS check can make the final decision.
    """
    trimmed = torchaudio.functional.vad(waveform, TARGET_SR)
    if trimmed.numel() == 0:
        return waveform
    reversed_trimmed = torchaudio.functional.vad(
        torch.flip(trimmed, dims=[-1]), TARGET_SR
    )
    if reversed_trimmed.numel() == 0:
        return trimmed
    return torch.flip(reversed_trimmed, dims=[-1])


def preprocess_audio_from_tensor(
    waveform: torch.Tensor, sr: int, output_path: Path
) -> bool:
    """Preprocess an in-memory waveform and save it to ``output_path``.

    Steps: downmix to mono, resample to TARGET_SR, trim leading and trailing
    silence, reject if RMS < RMS_THRESHOLD, then write the WAV file.

    Args:
        waveform: Audio tensor of shape (time,) or (channels, time).
        sr: Sample rate of ``waveform``.
        output_path: Destination file; parent directories are created.

    Returns:
        True if the file was written, False if it was rejected as (near-)silent.
    """
    output_path = Path(output_path)

    # Downmix to mono and make sure the result is (1, time) for torchaudio.save.
    if waveform.dim() > 1 and waveform.shape[0] > 1:
        waveform = waveform.mean(dim=0, keepdim=True)
    elif waveform.dim() == 1:
        waveform = waveform.unsqueeze(0)

    if sr != TARGET_SR:
        waveform = torchaudio.functional.resample(waveform, sr, TARGET_SR)

    waveform = _trim_silence(waveform)

    if _compute_rms(waveform) < RMS_THRESHOLD:
        return False

    output_path.parent.mkdir(parents=True, exist_ok=True)
    torchaudio.save(str(output_path), waveform, TARGET_SR)
    return True


def preprocess_audio(input_path: Path, output_path: Path) -> bool:
    """Load an audio file and apply the shared preprocessing.

    Downmixes to mono, resamples to 16 kHz, trims silence and rejects clips
    whose RMS is below RMS_THRESHOLD.

    Returns:
        True if the file was saved, False if it was rejected.
    """
    waveform, sr = torchaudio.load(str(Path(input_path)))
    return preprocess_audio_from_tensor(waveform, sr, output_path)


# ---------------------------------------------------------------------------
# Task 2 — AI Hub 263 Extraction
# ---------------------------------------------------------------------------

# Column layout of the AI Hub 263 CSVs:
#   0: wav_id, 1: 발화문 (utterance text), 2: 상황 (situation),
#   3..12: five (N번감정 emotion, N번감정세기 intensity) annotator pairs,
#   13: 나이 (age), 14: 성별 (gender)
_MIN_COLUMNS_263 = 15
_EMOTION_COLUMNS_263 = range(3, 13, 2)  # 3, 5, 7, 9, 11; intensity is col + 1


def parse_263_row(row: list[str]) -> dict | None:
    """Parse one data row of an AI Hub 263 CSV.

    Columns: wav_id, 발화문, 상황, 1번감정, 1번감정세기, 2번감정, 2번감정세기,
    3번감정, 3번감정세기, 4번감정, 4번감정세기, 5번감정, 5번감정세기, 나이, 성별.

    Returns:
        A dict with ``wav_id``, ``text`` (cleaned), ``label``, ``agreement``
        (number of annotators who voted for ``label``) and ``max_intensity``
        (highest intensity among those annotators). Returns None if the row
        is too short, has an empty wav_id, or has no majority label.
    """
    if len(row) < _MIN_COLUMNS_263:
        return None
    wav_id = row[0].strip()
    if not wav_id:
        return None

    annotator_emotions: list[str | None] = []
    intensities: list[int] = []
    for emo_col in _EMOTION_COLUMNS_263:
        annotator_emotions.append(map_263_label(row[emo_col].strip()))
        try:
            intensities.append(int(row[emo_col + 1].strip()))
        except ValueError:
            # A missing or non-numeric intensity counts as 0.
            intensities.append(0)

    label = majority_vote_263(annotator_emotions)
    if label is None:
        return None

    voter_intensities = [
        intensity
        for emotion, intensity in zip(annotator_emotions, intensities)
        if emotion == label
    ]
    return {
        "wav_id": wav_id,
        "text": _clean_text(row[1]),
        "label": label,
        "agreement": len(voter_intensities),
        "max_intensity": max(voter_intensities, default=0),
    }


def _parse_263_csvs(csv_files: list[Path]) -> list[dict]:
    """Parse every AI Hub 263 CSV (cp949) and keep rows with a majority label."""
    parsed: list[dict] = []
    for csv_path in csv_files:
        with open(csv_path, encoding="cp949", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerates an empty file
            for row in reader:
                result = parse_263_row(row)
                if result is not None:
                    result["csv_source"] = csv_path.stem
                    parsed.append(result)
    return parsed


def _select_263(parsed: list[dict], cap_per_class: int) -> list[dict]:
    """Keep at most ``cap_per_class`` rows per label.

    Rows are ranked by agreement, then intensity, both descending.
    """
    by_label: dict[str, list[dict]] = defaultdict(list)
    for item in parsed:
        by_label[item["label"]].append(item)

    selected: list[dict] = []
    for label, items in by_label.items():
        items.sort(key=lambda x: (x["agreement"], x["max_intensity"]), reverse=True)
        capped = items[:cap_per_class]
        selected.extend(capped)
        log.info("263: %s — %d available, %d selected", label, len(items), len(capped))
    return selected


def extract_anchor_263(
    anchor_dir: Path,
    output_dir: Path,
    cap_per_class: int = 2100,
) -> list[dict]:
    """Extract and preprocess the AI Hub 263 dataset.

    Parses the CSVs in ``anchor_dir`` (cp949) and majority-votes the annotator
    labels. Each class is then ranked (agreement desc, intensity desc) and
    capped at ``cap_per_class``. Matching WAVs are read from the ZIPs in
    ``anchor_dir`` and preprocessed into ``<output_dir>/<label>/<wav_id>.wav``.

    Returns:
        Sample dicts with path, label, label_idx, source, speaker_id, text,
        agreement and intensity.
    """
    anchor_dir = Path(anchor_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    csv_files = sorted(anchor_dir.glob("*.csv"))
    zip_files = sorted(anchor_dir.glob("*.zip"))
    log.info("263: Found %d CSVs, %d ZIPs", len(csv_files), len(zip_files))

    all_parsed = _parse_263_csvs(csv_files)
    log.info("263: Parsed %d rows with majority vote", len(all_parsed))
    selected = _select_263(all_parsed, cap_per_class)

    samples: list[dict] = []
    skipped = 0
    # ExitStack closes every archive even if indexing or processing raises.
    with contextlib.ExitStack() as stack:
        # Map wav_id (file stem) -> (open archive, member name). If the same
        # stem appears in several archives, the archive sorted last wins.
        wav_to_zip: dict[str, tuple[zipfile.ZipFile, str]] = {}
        for zip_path in zip_files:
            zf = stack.enter_context(zipfile.ZipFile(zip_path))
            for name in zf.namelist():
                if name.lower().endswith(".wav"):
                    wav_to_zip[Path(name).stem] = (zf, name)

        for item in selected:
            wav_id = item["wav_id"]
            entry = wav_to_zip.get(wav_id)
            if entry is None:
                skipped += 1
                continue
            zf, zip_entry = entry
            out_path = output_dir / item["label"] / f"{wav_id}.wav"
            try:
                # Decode straight from memory; this avoids temporary files that
                # could be left behind when processing fails.
                with zf.open(zip_entry) as src:
                    waveform, sr = torchaudio.load(io.BytesIO(src.read()))
                ok = preprocess_audio_from_tensor(waveform, sr, out_path)
            except Exception as e:
                log.warning("263: Failed to process %s: %s", wav_id, e)
                skipped += 1
                continue
            if not ok:
                skipped += 1
                continue
            samples.append({
                "path": str(out_path),
                "label": item["label"],
                "label_idx": LABEL2IDX[item["label"]],
                "source": "263",
                # NOTE(review): assumes the first 8 chars of wav_id identify the
                # speaker; the speaker-isolated split relies on this. Confirm
                # against the dataset specification.
                "speaker_id": f"263_{wav_id[:8]}",
                "text": item.get("text", ""),
                "agreement": item["agreement"],
                "intensity": item["max_intensity"],
            })

    log.info("263: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples


# ---------------------------------------------------------------------------
# Task 3 — AI Hub 71631 Outdoor Extraction
# ---------------------------------------------------------------------------

# Accepted VerifyEmotionLevel values: 보통 (medium) and 강함 (strong).
# 약함 (weak) and anything else are rejected.
_ACCEPTED_LEVELS_71631 = ("보통", "강함")


def parse_71631_utterance(conv_entry: dict) -> dict | None:
    """Parse one ``Conversation`` entry from a 71631 label JSON.

    Keeps only entries whose VerifyEmotionLevel is 보통/강함 (medium/strong),
    whose VerifyEmotionTarget maps to a 7-class label, and whose numeric
    StartTime < EndTime.

    Returns:
        A dict with label, intensity (the raw level string), start_time,
        end_time, speaker_no and cleaned text, or None if rejected.
    """
    level = conv_entry.get("VerifyEmotionLevel", "")
    if level not in _ACCEPTED_LEVELS_71631:
        return None

    label = map_71631_label(conv_entry.get("VerifyEmotionTarget", ""))
    if label is None:
        return None

    try:
        start_time = float(conv_entry["StartTime"])
        end_time = float(conv_entry["EndTime"])
    except (KeyError, TypeError, ValueError):
        # Missing, null or non-numeric timestamps.
        return None
    if end_time <= start_time:
        return None

    return {
        "label": label,
        "intensity": level,
        "start_time": start_time,
        "end_time": end_time,
        "speaker_no": conv_entry.get("SpeakerNo", ""),
        "text": _clean_text(conv_entry.get("Text", "")),
    }


def _parse_71631_labels(label_zip: Path) -> list[dict]:
    """Read every label JSON in ``label_zip`` and return the accepted utterances.

    Each result is a ``parse_71631_utterance`` dict extended with
    ``conversation_id`` and ``wav_filename`` (both taken from File.FileName),
    ``speaker_id`` ("71631_" + the speaker's ID) and ``text_no``.
    """
    utterances: list[dict] = []
    with zipfile.ZipFile(label_zip) as lzf:
        json_files = [n for n in lzf.namelist() if n.endswith(".json")]
        log.info("71631: Found %d label JSONs", len(json_files))

        for jf in json_files:
            try:
                with lzf.open(jf) as f:
                    data = json.load(f)
            except Exception as e:
                log.warning("71631: Failed to parse %s: %s", jf, e)
                continue
            if not isinstance(data, dict):
                log.warning("71631: Unexpected JSON structure in %s", jf)
                continue

            # `or {}` also covers sections that are present but null.
            filename = (data.get("File") or {}).get("FileName", "")
            # "Speaker1"/"Speaker2" resolve to the IDs declared in the file;
            # any other SpeakerNo value is used as-is.
            speaker_ids = {
                "Speaker1": (data.get("Speaker1") or {}).get("ID", ""),
                "Speaker2": (data.get("Speaker2") or {}).get("ID", ""),
            }

            for entry in data.get("Conversation") or []:
                parsed = parse_71631_utterance(entry)
                if parsed is None:
                    continue
                spk_no = parsed["speaker_no"]
                spk_id = speaker_ids.get(spk_no, spk_no)
                parsed["conversation_id"] = filename  # file name doubles as conversation id
                parsed["speaker_id"] = f"71631_{spk_id}"
                parsed["wav_filename"] = filename
                parsed["text_no"] = entry.get("TextNo", "")
                utterances.append(parsed)
    return utterances


def _cap_71631(
    utterances: list[dict], cap: int, max_per_speaker: int
) -> list[dict]:
    """Apply the per-speaker cap, then the per-class cap, preferring 강함 (strong)."""
    # Stable sort: 강함 before 보통; file order is kept within each level.
    ordered = sorted(utterances, key=lambda u: 0 if u["intensity"] == "강함" else 1)

    speaker_counts: Counter = Counter()
    speaker_capped: list[dict] = []
    for utt in ordered:
        spk = utt["speaker_id"]
        if speaker_counts[spk] < max_per_speaker:
            speaker_capped.append(utt)
            speaker_counts[spk] += 1
    log.info(
        "71631: After speaker cap (%d/spk): %d utterances",
        max_per_speaker,
        len(speaker_capped),
    )

    by_label: dict[str, list[dict]] = defaultdict(list)
    for utt in speaker_capped:
        by_label[utt["label"]].append(utt)

    selected: list[dict] = []
    for label, items in by_label.items():
        capped = items[:cap]  # items are already ordered strong-first
        selected.extend(capped)
        log.info("71631: %s — %d available, %d selected", label, len(items), len(capped))
    return selected


def extract_booster_71631(
    label_zip: Path,
    audio_zip: Path,
    output_dir: Path,
    cap: int = 3500,
    max_per_speaker: int = 20,
) -> list[dict]:
    """Extract and preprocess the AI Hub 71631 outdoor dataset.

    Steps:
      1. Parse the label ZIP's JSON files, keeping 보통/강함 (medium/strong)
         utterances.
      2. Cap utterances per speaker, then per class, taking 강함 first.
      3. Slice each utterance out of its conversation WAV.
      4. Run the shared audio preprocessing on every class: mono, 16 kHz,
         silence trim and RMS rejection.

    Args:
        label_zip: ZIP containing the label JSON files.
        audio_zip: ZIP containing the conversation WAV files.
        output_dir: Destination root; clips go to ``<output_dir>/<label>/``.
        cap: Maximum utterances kept per class.
        max_per_speaker: Maximum utterances kept per speaker.

    Returns:
        Sample dicts with path, label, label_idx, source, speaker_id, text,
        conversation_id and intensity.
    """
    label_zip = Path(label_zip)
    audio_zip = Path(audio_zip)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    all_utterances = _parse_71631_labels(label_zip)
    log.info("71631: Parsed %d utterances (보통/강함)", len(all_utterances))

    selected = _cap_71631(all_utterances, cap, max_per_speaker)

    # Group by source WAV so each conversation file is decoded only once.
    by_wav: dict[str, list[dict]] = defaultdict(list)
    for utt in selected:
        by_wav[utt["wav_filename"]].append(utt)

    samples: list[dict] = []
    skipped = 0
    with zipfile.ZipFile(audio_zip) as azf:
        wav_lookup: dict[str, str] = {
            Path(name).stem: name
            for name in azf.namelist()
            if name.lower().endswith(".wav")
        }

        for wav_filename, utterances in by_wav.items():
            zip_entry = wav_lookup.get(wav_filename)
            if zip_entry is None:
                log.warning("71631: WAV not found in zip: %s", wav_filename)
                skipped += len(utterances)
                continue

            try:
                with azf.open(zip_entry) as src:
                    waveform, sr = torchaudio.load(io.BytesIO(src.read()))
            except Exception as e:
                log.warning("71631: Failed to load %s: %s", wav_filename, e)
                skipped += len(utterances)
                continue

            # Downmix once per conversation rather than once per segment.
            if waveform.shape[0] > 1:
                waveform = waveform.mean(dim=0, keepdim=True)
            n_frames = waveform.shape[1]

            for utt in utterances:
                # Clamp the segment to the decoded audio.
                start_sample = max(0, int(utt["start_time"] * sr))
                end_sample = min(int(utt["end_time"] * sr), n_frames)
                if start_sample >= end_sample:
                    skipped += 1
                    continue

                out_name = f"{wav_filename}_{utt['text_no']}.wav"
                out_path = output_dir / utt["label"] / out_name
                try:
                    ok = preprocess_audio_from_tensor(
                        waveform[:, start_sample:end_sample], sr, out_path
                    )
                except Exception as e:
                    log.warning("71631: Failed to process %s: %s", out_name, e)
                    skipped += 1
                    continue
                if not ok:
                    skipped += 1
                    continue
                samples.append({
                    "path": str(out_path),
                    "label": utt["label"],
                    "label_idx": LABEL2IDX[utt["label"]],
                    "source": "71631",
                    "speaker_id": utt["speaker_id"],
                    "text": utt.get("text", ""),
                    "conversation_id": utt["conversation_id"],
                    "intensity": utt["intensity"],
                })

    log.info("71631: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples
# ---------------------------------------------------------------------------
# Task 4 — RAVDESS Extraction
# ---------------------------------------------------------------------------


def extract_ravdess(ravdess_dir: Path, output_dir: Path) -> list[dict]:
    """Extract and preprocess the RAVDESS dataset listed in ``manifest.csv``.

    Reads ``<ravdess_dir>/manifest.csv`` (columns used: ``clean_path``,
    ``emotion``, ``actor_id``), maps each emotion to the 7-class scheme, and
    writes preprocessed audio to ``<output_dir>/<label>/<file name>``.

    A row is counted as skipped when:
      * its emotion does not map to the 7-class scheme;
      * its path or actor id is missing or malformed;
      * its audio file does not exist;
      * its audio is rejected as silent;
      * its audio fails to process.

    Returns:
        Sample dicts with path, label, label_idx, source, actor_id,
        speaker_id and text (always "" for RAVDESS).
    """
    ravdess_dir = Path(ravdess_dir)
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    manifest_path = ravdess_dir / "manifest.csv"
    if not manifest_path.exists():
        log.error("RAVDESS manifest not found: %s", manifest_path)
        return []

    samples: list[dict] = []
    skipped = 0
    # utf-8-sig also strips a BOM, which would otherwise corrupt the first
    # header name. newline="" is what the csv module expects.
    with open(manifest_path, encoding="utf-8-sig", newline="") as f:
        for row in csv.DictReader(f):
            label = map_ravdess_label(row.get("emotion", ""))
            if label is None:
                skipped += 1
                continue

            try:
                # NOTE(review): relative clean_path values resolve against the
                # current working directory, not ravdess_dir — confirm how
                # manifest.csv was generated.
                clean_path = Path(row["clean_path"])
                actor_id = int(row["actor_id"])
            except (KeyError, TypeError, ValueError) as e:
                log.warning("RAVDESS: Malformed manifest row %s: %s", row, e)
                skipped += 1
                continue

            if not clean_path.exists():
                skipped += 1
                continue

            out_path = output_dir / label / clean_path.name
            try:
                ok = preprocess_audio(clean_path, out_path)
            except Exception as e:
                log.warning("RAVDESS: Failed to process %s: %s", clean_path, e)
                skipped += 1
                continue
            if not ok:
                skipped += 1
                continue

            samples.append({
                "path": str(out_path),
                "label": label,
                "label_idx": LABEL2IDX[label],
                "source": "ravdess",
                "actor_id": actor_id,
                "speaker_id": f"ravdess_{actor_id}",
                "text": "",  # RAVDESS uses fixed sentences, not useful for text emotion
            })

    log.info("RAVDESS: Extracted %d samples, skipped %d", len(samples), skipped)
    return samples


# ---------------------------------------------------------------------------
# Task 4 — Train/Val Splits
# ---------------------------------------------------------------------------


def _group_isolated_split(
    samples: list[dict], key: str, val_ratio: float
) -> tuple[list[dict], list[dict]]:
    """Split ``samples`` so no value of ``sample[key]`` appears in both halves.

    Groups are shuffled with the module-level ``random`` generator (seeded by
    ``main``). Whole groups are then moved to val until val holds at least
    ``int(len(samples) * val_ratio)`` samples, so the last group may overshoot
    the target. Train keeps the input order; val is ordered by group.
    """
    if not samples:
        return [], []

    groups: dict[str, list[dict]] = defaultdict(list)
    for s in samples:
        groups[s[key]].append(s)

    group_ids = list(groups)
    random.shuffle(group_ids)

    target_val = int(len(samples) * val_ratio)
    val_samples: list[dict] = []
    val_groups: set[str] = set()
    for gid in group_ids:
        if len(val_samples) >= target_val:
            break
        val_samples.extend(groups[gid])
        val_groups.add(gid)

    train_samples = [s for s in samples if s[key] not in val_groups]
    return train_samples, val_samples


def speaker_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split by ``speaker_id`` so no speaker appears in both train and val."""
    return _group_isolated_split(samples, "speaker_id", val_ratio)


def conversation_isolated_split(
    samples: list[dict], val_ratio: float = 0.1
) -> tuple[list[dict], list[dict]]:
    """Split by ``conversation_id`` so no conversation appears in both train and val."""
    return _group_isolated_split(samples, "conversation_id", val_ratio)


def actor_split_ravdess(
    samples: list[dict], val_actors: list[int]
) -> tuple[list[dict], list[dict]]:
    """Split RAVDESS samples by ``actor_id``; actors in ``val_actors`` go to val."""
    val_set = set(val_actors)
    train = [s for s in samples if s["actor_id"] not in val_set]
    val = [s for s in samples if s["actor_id"] in val_set]
    return train, val


# ---------------------------------------------------------------------------
# Task 4 — Manifest & Stats
# ---------------------------------------------------------------------------


def save_manifest(samples: list[dict], path: Path) -> None:
    """Write ``samples`` to ``path`` as one pretty-printed JSON array.

    The file is written as UTF-8 explicitly. ``ensure_ascii=False`` keeps
    Korean text as-is, which fails or yields a non-UTF-8 file when the
    platform's default encoding is not UTF-8.
    """
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "w", encoding="utf-8") as f:
        json.dump(samples, f, indent=2, ensure_ascii=False)
    log.info("Saved manifest: %s (%d samples)", path, len(samples))


def save_stats(train: list[dict], val: list[dict], path: Path) -> None:
    """Write per-split totals and per-label/per-source counts to ``path`` (UTF-8 JSON)."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    def _count_stats(samples: list[dict]) -> dict:
        by_label = Counter(s["label"] for s in samples)
        by_source = Counter(s["source"] for s in samples)
        return {
            "total": len(samples),
            "by_label": dict(sorted(by_label.items())),
            "by_source": dict(sorted(by_source.items())),
        }

    stats = {
        "train": _count_stats(train),
        "val": _count_stats(val),
        "label2idx": LABEL2IDX,
    }
    with open(path, "w", encoding="utf-8") as f:
        json.dump(stats, f, indent=2, ensure_ascii=False)
    log.info("Saved stats: %s", path)


# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------


def _build_arg_parser() -> argparse.ArgumentParser:
    """Build the command-line interface for this script."""
    parser = argparse.ArgumentParser(
        description="Prepare unified LoRA 7-class dataset"
    )
    parser.add_argument(
        "--anchor-dir",
        type=Path,
        default=Path("data/AI Hub 263"),
        help="Path to AI Hub 263 directory with CSVs + ZIPs",
    )
    parser.add_argument(
        "--booster-label-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "02.라벨링데이터/TL_02.실외.zip"
        ),
        help="Path to 71631 label ZIP",
    )
    parser.add_argument(
        "--booster-audio-zip",
        type=Path,
        default=Path(
            "data/AI Hub 71631/01-1.정식개방데이터/Training/"
            "01.원천데이터/TS_02.실외.zip"
        ),
        help="Path to 71631 audio ZIP",
    )
    parser.add_argument(
        "--ravdess-dir",
        type=Path,
        default=Path("data/ravdess"),
        help="Path to RAVDESS directory with manifest.csv",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("data/lora_7class"),
        help="Output directory for processed dataset",
    )
    parser.add_argument("--cap-263", type=int, default=2100,
                        help="Cap per class for 263")
    parser.add_argument("--cap-71631", type=int, default=3500,
                        help="Cap per class for 71631")
    parser.add_argument("--max-per-speaker-71631", type=int, default=20,
                        help="Max utterances per speaker for 71631")
    parser.add_argument("--val-ratio", type=float, default=0.1,
                        help="Val ratio for splits")
    parser.add_argument(
        "--ravdess-val-actors",
        type=int,
        nargs="+",
        default=[21, 22, 23, 24],
        help="RAVDESS actor IDs held out for validation",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed")
    parser.add_argument(
        "--skip-263", action="store_true", help="Skip AI Hub 263 extraction"
    )
    parser.add_argument(
        "--skip-71631", action="store_true", help="Skip AI Hub 71631 extraction"
    )
    parser.add_argument(
        "--skip-ravdess", action="store_true", help="Skip RAVDESS extraction"
    )
    return parser


def main() -> None:
    """Run every enabled extractor, split each source, and write manifests and stats."""
    args = _build_arg_parser().parse_args()
    # Seeds the module-level RNG used by the speaker/conversation splits.
    random.seed(args.seed)

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)

    all_train: list[dict] = []
    all_val: list[dict] = []

    # ---- 263 Anchor ----
    if not args.skip_263:
        log.info("=" * 60)
        log.info("Extracting AI Hub 263 (anchor)")
        samples_263 = extract_anchor_263(
            args.anchor_dir,
            output_dir / "263",
            cap_per_class=args.cap_263,
        )
        train_263, val_263 = speaker_isolated_split(samples_263, args.val_ratio)
        log.info("263: train=%d, val=%d", len(train_263), len(val_263))
        all_train.extend(train_263)
        all_val.extend(val_263)

    # ---- 71631 Booster ----
    if not args.skip_71631:
        log.info("=" * 60)
        log.info("Extracting AI Hub 71631 (booster)")
        samples_71631 = extract_booster_71631(
            args.booster_label_zip,
            args.booster_audio_zip,
            output_dir / "71631",
            cap=args.cap_71631,
            max_per_speaker=args.max_per_speaker_71631,
        )
        train_71631, val_71631 = conversation_isolated_split(
            samples_71631, args.val_ratio
        )
        log.info("71631: train=%d, val=%d", len(train_71631), len(val_71631))
        all_train.extend(train_71631)
        all_val.extend(val_71631)

    # ---- RAVDESS ----
    if not args.skip_ravdess:
        log.info("=" * 60)
        log.info("Extracting RAVDESS")
        samples_ravdess = extract_ravdess(
            args.ravdess_dir,
            output_dir / "ravdess",
        )
        train_ravdess, val_ravdess = actor_split_ravdess(
            samples_ravdess, args.ravdess_val_actors
        )
        log.info("RAVDESS: train=%d, val=%d", len(train_ravdess), len(val_ravdess))
        all_train.extend(train_ravdess)
        all_val.extend(val_ravdess)

    # ---- Save ----
    log.info("=" * 60)
    log.info("Total: train=%d, val=%d", len(all_train), len(all_val))
    save_manifest(all_train, output_dir / "train_manifest.json")
    save_manifest(all_val, output_dir / "val_manifest.json")
    save_stats(all_train, all_val, output_dir / "stats.json")
    log.info("Done! Output: %s", output_dir)


if __name__ == "__main__":
    main()