| |
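"""Prepare the MELD test split for the English fusion grid search.

Reads data/english_test.zip, converts each test-split mp4 to a 16 kHz mono wav
(via ffmpeg, which must be on PATH), and writes a JSON manifest under
data/meld_fusion/ that pairs each wav with its utterance text and emotion label.

Usage:
    python scripts/prepare_meld_fusion_data.py
"""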
| from __future__ import annotations |
|
|
import csv
import io
import json
import logging
import shutil
import subprocess
import tempfile
import zipfile
from collections import Counter
from pathlib import Path
|
|
# Script-wide logging: configure the root logger for INFO and above with
# timestamps (basicConfig is a no-op if the root logger already has handlers).
_LOG_FORMAT = "%(asctime)s %(levelname)s %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
# Emotion labels used by the project, in this fixed order.
PROJECT_LABELS = ["neutral", "joy", "sadness", "anger", "surprise", "fear", "disgust"]

# MELD "Emotion" column value -> project label. The MELD emotion names handled
# here match the project labels one-to-one, so this is an identity mapping;
# any emotion value missing from it is treated as unusable by the caller.
MELD_LABEL_MAP = {label: label for label in PROJECT_LABELS}
|
|
|
|
def _index_csv_rows(
    zf: zipfile.ZipFile, member: str
) -> dict[tuple[int, int], dict[str, str]]:
    """Parse the MELD CSV *member* of *zf*, keyed by (Dialogue_ID, Utterance_ID)."""
    with zf.open(member) as f:
        rows = list(csv.DictReader(io.TextIOWrapper(f, encoding="utf-8")))
    logger.info("MELD test: %d utterances", len(rows))
    # Later rows win on duplicate keys.
    return {(int(r["Dialogue_ID"]), int(r["Utterance_ID"])): r for r in rows}


def _index_test_mp4s(zf: zipfile.ZipFile, split_marker: str) -> dict[tuple[int, int], str]:
    """Map (dialogue_id, utterance_id) to the zip member name of each test-split mp4."""
    found: dict[tuple[int, int], str] = {}
    for name in zf.namelist():
        if split_marker not in name or not name.endswith(".mp4"):
            continue
        member_path = Path(name)
        # Skip macOS "._*" metadata entries that sit next to the real files.
        if member_path.name.startswith("._"):
            continue
        # Expected file name shape: dia<D>_utt<U>.mp4; anything else is ignored.
        parts = member_path.stem.split("_")
        try:
            key = (int(parts[0].replace("dia", "")), int(parts[1].replace("utt", "")))
        except (ValueError, IndexError):
            continue
        found[key] = name
    return found


def _extract_wav(
    zf: zipfile.ZipFile, mp4_name: str, wav_path: Path, ffmpeg: str, timeout: float
) -> bool:
    """Convert zip member *mp4_name* to a 16 kHz mono wav at *wav_path*; True on success."""
    # ffmpeg writes to a sibling ".part" file that is renamed only on success, so
    # a failed, timed-out or interrupted conversion never leaves a truncated wav
    # that a later run's exists() check would mistake for a finished one.
    part_path = wav_path.with_name(wav_path.name + ".part")
    tmp_path: Path | None = None
    try:
        mp4_bytes = zf.read(mp4_name)
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_path = Path(tmp.name)
            tmp.write(mp4_bytes)

        result = subprocess.run(
            [ffmpeg, "-y", "-i", str(tmp_path), "-vn",
             "-ar", "16000", "-ac", "1", "-f", "wav",
             str(part_path)],
            stdin=subprocess.DEVNULL,  # keep ffmpeg from reading the terminal
            capture_output=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            stderr_tail = result.stderr.decode("utf-8", errors="replace").strip()[-500:]
            logger.warning(
                "ffmpeg failed for %s (exit %d): %s",
                wav_path.stem, result.returncode, stderr_tail,
            )
            return False
        part_path.replace(wav_path)
        return True
    except Exception as e:
        # Best-effort per utterance: log it and let the caller count it as skipped.
        logger.warning("Failed %s: %s", wav_path.stem, e)
        return False
    finally:
        # Runs on success, failure and KeyboardInterrupt alike.
        part_path.unlink(missing_ok=True)
        if tmp_path is not None:
            tmp_path.unlink(missing_ok=True)


def _build_entry(
    zf: zipfile.ZipFile,
    key: tuple[int, int],
    row: dict[str, str],
    mp4_name: str,
    audio_dir: Path,
    ffmpeg: str | None,
    timeout: float,
) -> dict | None:
    """Return the manifest entry for one utterance, or None if it must be skipped."""
    dia_id, utt_id = key

    # Normalise before lookup so stray whitespace/case in the CSV is not dropped.
    label = MELD_LABEL_MAP.get((row.get("Emotion") or "").strip().lower())
    if label is None:
        return None

    text = (row.get("Utterance") or "").strip()
    if not text:
        return None

    wav_path = audio_dir / f"dia{dia_id}_utt{utt_id}.wav"
    if not wav_path.exists():
        if ffmpeg is None:
            raise SystemExit(
                "ffmpeg not found on PATH; it is required to extract audio "
                "from the MELD mp4 files"
            )
        if not _extract_wav(zf, mp4_name, wav_path, ffmpeg, timeout):
            return None

    return {
        "path": str(wav_path),
        "text": text,
        "label": label,
        "source": "meld_test",
        "dialogue_id": dia_id,
        "utterance_id": utt_id,
    }


def main(
    zip_path: Path | str = "data/english_test.zip",
    output_dir: Path | str = "data/meld_fusion",
    *,
    ffmpeg_timeout: float = 30,
) -> None:
    """Build the MELD test-split wav files and manifest for the fusion grid search.

    Reads the MELD test CSV and the test-split mp4 files from *zip_path*. Each
    utterance that has both a CSV row and an mp4 is converted to
    ``<output_dir>/audio/dia{D}_utt{U}.wav`` (16 kHz mono). The function then
    writes ``<output_dir>/manifest.json`` with one entry per usable utterance
    and prints per-label counts. Existing wav files are reused, so a re-run
    only extracts what is missing.

    Args:
        zip_path: MELD archive containing ``MELD.Raw/MELD.Raw/test_sent_emo.csv``
            and the ``output_repeated_splits_test`` mp4 files.
        output_dir: Destination for the ``audio`` directory and ``manifest.json``.
        ffmpeg_timeout: Per-file ffmpeg timeout in seconds.

    Raises:
        SystemExit: If a wav has to be extracted but ``ffmpeg`` is not on PATH.
    """
    output_dir = Path(output_dir)
    audio_dir = output_dir / "audio"
    audio_dir.mkdir(parents=True, exist_ok=True)

    # Resolved once; only required if some wav actually has to be extracted.
    ffmpeg = shutil.which("ffmpeg")

    manifest: list[dict] = []
    skipped = 0

    with zipfile.ZipFile(Path(zip_path)) as zf:
        logger.info("Parsing MELD test CSV...")
        csv_lookup = _index_csv_rows(zf, "MELD.Raw/MELD.Raw/test_sent_emo.csv")

        test_mp4s = _index_test_mp4s(zf, "output_repeated_splits_test")
        logger.info("Found %d test mp4 files (excluding macOS metadata)", len(test_mp4s))

        matched_keys = sorted(csv_lookup.keys() & test_mp4s.keys())
        logger.info("Matched CSV↔mp4: %d", len(matched_keys))

        for i, key in enumerate(matched_keys, start=1):
            entry = _build_entry(
                zf, key, csv_lookup[key], test_mp4s[key],
                audio_dir, ffmpeg, ffmpeg_timeout,
            )
            if entry is None:
                skipped += 1
            else:
                manifest.append(entry)

            # Logged for every 200th item, whether it was kept or skipped.
            if i % 200 == 0:
                logger.info("Processed %d / %d", i, len(matched_keys))

    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2, ensure_ascii=False)

    logger.info("Saved %d samples to %s (skipped %d)", len(manifest), manifest_path, skipped)

    # Per-label counts, most frequent first.
    for emotion, count in Counter(s["label"] for s in manifest).most_common():
        print(f"  {emotion}: {count}")
|
|
|
if __name__ == "__main__":
    # Script entry point: main() returns None, so SystemExit(None) exits with status 0.
    raise SystemExit(main())
|
|