Spatial-BEATs / analyze_label_mapping.py
dieKarotte's picture
Add files using upload-large-folder tool
29615e9 verified
Raw
History Blame Contribute Delete
3.84 kB
#!/usr/bin/env python3
"""Analyze mono_primary_label -> mono_target_label mapping in ov1_foa.jsonl (train split)."""
import json
from collections import defaultdict, Counter
JSONL = "/apdcephfs_cq10/share_1603164/user/schmittzhu/data/metadata/ov1_foa.jsonl"
# Collect data (train split only) in a single pass over the JSONL file.
# Records with target "string_instrument": primary label -> count.
string_primary = Counter()
string_all_labels = defaultdict(list)  # primary_label -> list of mono_audio_labels combos
# Records with target "percussion": primary label -> count.
percussion_primary = Counter()
# Full mapping: mono_primary_label -> mono_target_label
full_mapping = defaultdict(set)  # primary -> set of targets
full_mapping_counts = defaultdict(Counter)  # target -> Counter of primary labels
# JSON Lines text is UTF-8, so do not depend on the platform's locale encoding.
with open(JSONL, encoding="utf-8") as f:
    for line_no, line in enumerate(f, start=1):
        # Tolerate blank lines (e.g. a trailing newline) instead of crashing in json.loads.
        if not line.strip():
            continue
        try:
            rec = json.loads(line)
        except json.JSONDecodeError as err:
            # json's own message only gives a position within this single line;
            # add the file and line number so the bad record can be located.
            raise ValueError(f"{JSONL}:{line_no}: invalid JSON record") from err
        if rec["split"] != "train":
            continue
        primary = rec["mono_primary_label"]
        target = rec["mono_target_label"]
        audio_labels = rec["mono_audio_labels"]
        full_mapping[primary].add(target)
        full_mapping_counts[target][primary] += 1
        if target == "string_instrument":
            string_primary[primary] += 1
            # Tuples are hashable, so combos can later be tallied with Counter.
            string_all_labels[primary].append(tuple(audio_labels))
        elif target == "percussion":
            percussion_primary[primary] += 1
# ============================================================
# Section 1: which primary labels ended up under target "string_instrument",
# most frequent first, followed by the overall total.
heavy_rule = "=" * 80
print(heavy_rule)
print("1) mono_target_label == 'string_instrument' : mono_primary_label counts")
print(heavy_rule)
for prim_label, n_clips in string_primary.most_common():
    print(f" {prim_label:45s} {n_clips:6d}")
string_total = sum(string_primary.values())
print(f" {'TOTAL':45s} {string_total:6d}")
# Suspicious non-string labels: primaries that should never map to
# target "string_instrument".
SUSPECT_STRING = {
    "Hi-hat", "Cymbal", "Crash_cymbal", "Drum", "Snare_drum", "Bass_drum",
    "Drum_kit", "Tabla", "Gong", "Tambourine", "Marimba_and_xylophone",
    "Mallet_percussion", "Vibraphone", "Steelpan",
}
# Substrings that mark a tag as percussion-like. They are matched
# case-insensitively: a case-sensitive test lets "Snare_drum", "Bass_drum"
# and "Crash_cymbal" slip past "Drum" / "Cymbal".
PERCUSSION_TERMS = (
    "Drum", "Cymbal", "Hi-hat", "Percussion", "Gong", "Tambourine", "Tabla",
    "Mallet", "Marimba", "Vibraphone", "Steelpan",
)
_PERCUSSION_TERMS_FOLDED = tuple(term.casefold() for term in PERCUSSION_TERMS)


def _looks_percussive(tag):
    """Return True if *tag* contains any of PERCUSSION_TERMS, ignoring case."""
    folded = str(tag).casefold()
    return any(term in folded for term in _PERCUSSION_TERMS_FOLDED)


suspect_found = {k for k in string_primary if k in SUSPECT_STRING}
print()
print("-" * 80)
print("Non-string suspects in string_instrument (with full audio_labels combos):")
print("-" * 80)
# Show every string_instrument primary that is a known suspect, whose own name
# looks percussive, or whose audio_labels combos contain a percussion-like tag.
for label in sorted(string_primary):
    combos = Counter(string_all_labels[label])
    is_suspect = (
        label in suspect_found
        or _looks_percussive(label)
        or any(_looks_percussive(tag) for combo in combos for tag in combo)
    )
    if not is_suspect:
        continue
    print(f"\n ** {label} (count={string_primary[label]}) **")
    for combo, n in combos.most_common():
        print(f" x{n:4d} {list(combo)}")
# ============================================================
# Section 2: the same per-primary breakdown, this time for target "percussion".
print()
for header_line in (
    "=" * 80,
    "2) mono_target_label == 'percussion' : mono_primary_label counts",
    "=" * 80,
):
    print(header_line)
for perc_label, perc_count in percussion_primary.most_common():
    print(f" {perc_label:45s} {perc_count:6d}")
perc_total = sum(percussion_primary.values())
print(f" {'TOTAL':45s} {perc_total:6d}")
# ============================================================
# Section 3: the complete target -> primary table, with per-target subtotals,
# followed by any primary label that is split across several targets.
print()
print("=" * 80)
print("3) Complete mapping: mono_primary_label -> mono_target_label (train split)")
print("=" * 80)
# Targets are listed alphabetically; within a target, primaries are listed by
# descending count (Counter.most_common), not alphabetically.
all_targets = sorted(full_mapping_counts.keys())
print(f"\nTotal unique mono_target_label classes: {len(all_targets)}")
print(f"Total unique mono_primary_label values: {len(full_mapping)}")
print()
# The header uses the same " {:28s} {:45s} {:8}" layout as the data rows so
# the column titles line up with their values.
print(f" {'mono_target_label':28s} {'mono_primary_label':45s} {'count':>8s}")
print("-" * 90)
for target in all_targets:
    primaries = full_mapping_counts[target]
    for i, (prim, cnt) in enumerate(primaries.most_common()):
        # Print the target name only on its first row so each group reads as a block.
        t_display = target if i == 0 else ""
        print(f" {t_display:28s} {prim:45s} {cnt:8d}")
    # subtotal
    total = sum(primaries.values())
    print(f" {'':28s} {'--- subtotal ---':45s} {total:8d}")
    print()
# full_mapping keeps primary -> set of targets; a primary with more than one
# target means the primary -> target mapping is not a function. Report those.
ambiguous = {p: ts for p, ts in full_mapping.items() if len(ts) > 1}
print(f"Primary labels mapped to more than one target: {len(ambiguous)}")
for prim in sorted(ambiguous):
    print(f" {prim:45s} -> {sorted(ambiguous[prim])}")