ustwo-api / scripts /prepare_meld_fusion_data.py
asdfasdfqrqwer's picture
Deploy from GitHub 2026-04-23T03:56:31Z
c857b85
Raw
History Blame Contribute Delete
4.68 kB
#!/usr/bin/env python3
"""Prepare MELD test split for English fusion grid search.
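
Reads ``data/english_test.zip``, converts every test-split mp4 clip that has a
matching row in ``test_sent_emo.csv`` to a 16 kHz mono wav under
``data/meld_fusion/audio/``, and writes ``data/meld_fusion/manifest.json`` with
one entry per clip (wav path, utterance text, emotion label, dialogue and
utterance ids).

Requires ``ffmpeg`` on PATH. All paths are relative to the current working
directory.

Usage:
    python scripts/prepare_meld_fusion_data.py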
"""
from __future__ import annotations
import csv
import io
import json
import logging
import subprocess
import tempfile
import zipfile
from collections import Counter
from pathlib import Path
# Root logging is configured at import time: INFO level, with a timestamp and
# the level name in front of every message.
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Emotion label set used by the project.
PROJECT_LABELS = [
    "neutral",
    "joy",
    "sadness",
    "anger",
    "surprise",
    "fear",
    "disgust",
]

# MELD emotion names coincide with the project labels, so the mapping is the
# identity over PROJECT_LABELS; any other MELD value has no entry.
MELD_LABEL_MAP = {name: name for name in PROJECT_LABELS}
def _parse_clip_key(fname: str) -> "tuple[int, int] | None":
    """Parse ``dia{D}_utt{U}.mp4`` into ``(D, U)``; return None if it does not fit."""
    parts = Path(fname).stem.split("_")
    try:
        return int(parts[0].replace("dia", "")), int(parts[1].replace("utt", ""))
    except (ValueError, IndexError):
        return None


def _index_test_clips(zf: zipfile.ZipFile) -> dict:
    """Map ``(dialogue_id, utterance_id)`` to the zip member name of each test mp4."""
    clips = {}
    for name in zf.namelist():
        if "output_repeated_splits_test" not in name or not name.endswith(".mp4"):
            continue
        fname = Path(name).name
        if fname.startswith("._"):
            continue  # skip macOS metadata
        key = _parse_clip_key(fname)
        if key is not None:
            clips[key] = name
    return clips


def _convert_to_wav(zf: zipfile.ZipFile, mp4_name: str, wav_path: Path) -> bool:
    """Extract one mp4 from the zip and convert it to a 16 kHz mono wav with ffmpeg.

    ffmpeg writes to a sibling ``.part`` file that is renamed onto ``wav_path``
    only after a zero exit status, so a failed or timed-out conversion never
    leaves a truncated wav that a later run would mistake for a finished one.

    Returns:
        True if ``wav_path`` was produced, False if the clip had to be skipped
        (the reason is logged as a warning).
    """
    part_path = wav_path.with_name(wav_path.name + ".part")
    tmp_mp4 = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp:
            tmp_mp4 = Path(tmp.name)
            tmp.write(zf.read(mp4_name))
        result = subprocess.run(
            ["ffmpeg", "-y", "-i", str(tmp_mp4),
             "-ar", "16000", "-ac", "1", "-f", "wav",
             str(part_path)],
            capture_output=True, timeout=30,
        )
        if result.returncode != 0:
            stderr_text = result.stderr.decode("utf-8", errors="replace").strip()
            last_line = stderr_text.splitlines()[-1] if stderr_text else "no stderr output"
            logger.warning(
                "ffmpeg exited with %d for %s: %s",
                result.returncode, wav_path.stem, last_line,
            )
            return False
        part_path.replace(wav_path)
        return True
    except Exception as e:
        # Deliberate per-clip best-effort boundary: one bad clip (timeout,
        # corrupt zip member, I/O error) must not abort the whole run.
        logger.warning("Failed %s: %s", wav_path.stem, e)
        return False
    finally:
        # Runs on success, failure and timeout alike, so no temp files pile up.
        if tmp_mp4 is not None:
            tmp_mp4.unlink(missing_ok=True)
        part_path.unlink(missing_ok=True)


def _manifest_entry(zf: zipfile.ZipFile, row: dict, mp4_name: str, key: tuple,
                    audio_dir: Path) -> "dict | None":
    """Build the manifest record for one matched clip, or None if it is skipped.

    A clip is skipped when its emotion is not in ``MELD_LABEL_MAP``, its
    utterance text is empty, or its wav is absent and cannot be produced.
    """
    label = MELD_LABEL_MAP.get(row["Emotion"])
    if label is None:
        return None
    # DictReader fills missing trailing columns with None; treat that as empty.
    text = (row["Utterance"] or "").strip()
    if not text:
        return None
    dia_id, utt_id = key
    wav_path = audio_dir / f"dia{dia_id}_utt{utt_id}.wav"
    # An existing wav is reused as-is so re-runs only convert what is missing.
    if not wav_path.exists() and not _convert_to_wav(zf, mp4_name, wav_path):
        return None
    return {
        "path": str(wav_path),
        "text": text,
        "label": label,
        "source": "meld_test",
        "dialogue_id": dia_id,
        "utterance_id": utt_id,
    }


def main():
    """Convert the MELD test clips to wav and write the fusion manifest.

    Reads ``data/english_test.zip``, matches rows of ``test_sent_emo.csv`` to
    the test mp4 clips by (dialogue id, utterance id), converts each matched
    clip to ``data/meld_fusion/audio/dia{D}_utt{U}.wav`` and writes
    ``data/meld_fusion/manifest.json``. Per-label counts are printed at the end.
    """
    zip_path = Path("data/english_test.zip")
    output_dir = Path("data/meld_fusion")
    audio_dir = output_dir / "audio"
    audio_dir.mkdir(parents=True, exist_ok=True)

    manifest = []
    skipped = 0
    # The context manager closes the archive even if parsing or conversion raises.
    with zipfile.ZipFile(zip_path) as zf:
        # Step 1: Parse test CSV
        logger.info("Parsing MELD test CSV...")
        with zf.open("MELD.Raw/MELD.Raw/test_sent_emo.csv") as f:
            rows = list(csv.DictReader(io.TextIOWrapper(f, encoding="utf-8")))
        logger.info("MELD test: %d utterances", len(rows))
        # Build lookup: (dia_id, utt_id) → row
        csv_lookup = {
            (int(r["Dialogue_ID"]), int(r["Utterance_ID"])): r for r in rows
        }

        # Step 2: Find mp4 files in zip
        test_mp4s = _index_test_clips(zf)
        logger.info("Found %d test mp4 files (excluding macOS metadata)", len(test_mp4s))

        # Step 3: Match CSV ↔ mp4, extract wav
        matched_keys = sorted(csv_lookup.keys() & test_mp4s.keys())
        logger.info("Matched CSV↔mp4: %d", len(matched_keys))
        for done, key in enumerate(matched_keys, start=1):
            entry = _manifest_entry(zf, csv_lookup[key], test_mp4s[key], key, audio_dir)
            if entry is None:
                skipped += 1
            else:
                manifest.append(entry)
            # Report progress for every 200th clip, skipped or not.
            if done % 200 == 0:
                logger.info("Processed %d / %d", done, len(matched_keys))

    # Step 4: Save manifest
    manifest_path = output_dir / "manifest.json"
    with open(manifest_path, "w", encoding="utf-8") as f:
        json.dump(manifest, f, indent=2, ensure_ascii=False)
    logger.info("Saved %d samples to %s (skipped %d)", len(manifest), manifest_path, skipped)

    # Stats: label counts, most frequent first
    emotions = Counter(s["label"] for s in manifest)
    for e, c in emotions.most_common():
        print(f" {e}: {c}")


if __name__ == "__main__":
    main()