#!/usr/bin/env python3
"""
Validate dataset format for Stage 1 and Stage 2 training.

Checks that all required fields are present and have correct shapes/types.

Usage:
    python validate_dataset_format.py --input ./data/dataset.pt [--samples 100]
"""

import argparse
import sys
from collections.abc import Mapping
from typing import Any, List, Optional, Sequence

import torch

# Required size of the last dimension of every whisper_features tensor.
WHISPER_DIM = 1280
# snac_tokens lengths not divisible by this only produce a warning.
SNAC_GROUP = 7
# Plausible id window for offset SNAC tokens (128266-based offsets,
# expected to top out around ~155000).
SNAC_MIN_TOKEN = 128266
SNAC_MAX_TOKEN = 160000  # exclusive upper bound

REQUIRED_FIELDS = ("whisper_features", "snac_tokens")
OPTIONAL_FIELDS = ("text_tokens", "answer", "text", "word_alignments")
WORD_ALIGNMENT_KEYS = {"word", "start_frame", "end_frame"}


def _length_stats(values: List[int]) -> str:
    """Format min/max/avg of a non-empty list of lengths for display."""
    avg = sum(values) / len(values)
    return f"min={min(values)}, max={max(values)}, avg={avg:.1f}"


def _load(data_path: str) -> Any:
    """Load the dataset file; return the loaded object, or None on failure."""
    print("[1/6] Loading dataset...")
    try:
        # SECURITY: weights_only=False unpickles arbitrary Python objects and
        # can execute code embedded in the file. Only run on trusted datasets.
        data = torch.load(data_path, map_location="cpu", weights_only=False)
        print(f"  ✓ Loaded {len(data)} samples")
    except Exception as e:  # CLI boundary: report any load/len failure and stop
        print(f"  ✗ Failed to load: {e}")
        return None
    return data


def _check_fields(sample: Any) -> bool:
    """Report field presence on the first sample; True iff all required fields exist."""
    print("\n[2/6] Checking required fields...")
    if not isinstance(sample, Mapping):
        print(f"  ✗ Samples must be dict-like (got {type(sample)})")
        return False
    for field in REQUIRED_FIELDS:
        if field not in sample:
            print(f"  ✗ Missing required field: {field}")
            return False
        print(f"  ✓ {field}: present")
    for field in OPTIONAL_FIELDS:
        if field in sample:
            print(f"  ✓ {field}: present (optional)")
        else:
            print(f"  ○ {field}: missing (optional)")
    return True


def _whisper_problem(item: Any) -> Optional[str]:
    """Return why *item*'s whisper_features is invalid, or None if it is valid."""
    if not isinstance(item, Mapping):
        return f"sample is not dict-like (got {type(item)})"
    wf = item.get("whisper_features")
    if wf is None:
        return "whisper_features is None"
    if not isinstance(wf, torch.Tensor):
        return f"whisper_features is not a Tensor (got {type(wf)})"
    if wf.dim() != 2:
        return f"whisper_features has wrong shape {wf.shape} (expected 2D)"
    if wf.shape[1] != WHISPER_DIM:
        return f"whisper_features dim {wf.shape[1]} != {WHISPER_DIM}"
    return None


def _validate_whisper(subset: Sequence[Any]) -> int:
    """Check whisper_features on every sample in *subset*; return the error count."""
    print("\n[3/6] Validating whisper_features...")
    errors = 0
    seq_lens = []
    for i, item in enumerate(subset):
        problem = _whisper_problem(item)
        if problem:
            print(f"  ✗ Sample {i}: {problem}")
            errors += 1
        else:
            seq_lens.append(item["whisper_features"].shape[0])

    if errors:
        print(f"  ✗ {errors}/{len(subset)} samples have invalid whisper_features")
        return errors
    print(f"  ✓ All samples have valid whisper_features [seq_len, {WHISPER_DIM}]")
    if seq_lens:  # empty when zero samples were requested
        print(f"    Sequence lengths: {_length_stats(seq_lens)}")
    return errors


def _snac_problem(item: Any) -> Optional[str]:
    """Return why *item*'s snac_tokens is invalid, or None if it is valid."""
    if not isinstance(item, Mapping):
        return f"sample is not dict-like (got {type(item)})"
    st = item.get("snac_tokens")
    if st is None:
        return "snac_tokens is None"
    if not isinstance(st, torch.Tensor):
        return f"snac_tokens is not a Tensor (got {type(st)})"
    if st.dim() != 1:
        return f"snac_tokens has wrong shape {st.shape} (expected 1D)"
    if st.numel() == 0:
        # min()/max() below cannot reduce an empty tensor.
        return "snac_tokens is empty"
    return None


def _validate_snac(subset: Sequence[Any]) -> int:
    """Check snac_tokens on every sample in *subset*; return the error count."""
    print("\n[4/6] Validating snac_tokens...")
    errors = 0
    lengths = []
    ranges = []
    for i, item in enumerate(subset):
        problem = _snac_problem(item)
        if problem:
            print(f"  ✗ Sample {i}: {problem}")
            errors += 1
            continue
        st = item["snac_tokens"]
        # Non-multiple lengths are only a warning, not a validation error.
        if len(st) % SNAC_GROUP != 0:
            print(f"  ⚠ Sample {i}: snac_tokens length {len(st)} not multiple of {SNAC_GROUP} (will be truncated)")
        lengths.append(len(st))
        ranges.append((st.min().item(), st.max().item()))

    if errors:
        print(f"  ✗ {errors}/{len(subset)} samples have invalid snac_tokens")
        return errors
    print("  ✓ All samples have valid snac_tokens")
    if lengths:  # empty when zero samples were requested
        print(f"    Lengths: {_length_stats(lengths)}")
        min_val = min(lo for lo, _ in ranges)
        max_val = max(hi for _, hi in ranges)
        print(f"    Token range: {min_val} - {max_val}")
        if SNAC_MIN_TOKEN <= min_val and max_val < SNAC_MAX_TOKEN:
            print(f"  ✓ SNAC token offsets look correct ({SNAC_MIN_TOKEN}-based)")
        else:
            print("  ⚠ SNAC token range unusual - check offsets")
    return errors


def _validate_text_tokens(subset: Sequence[Any], sample: Mapping) -> None:
    """Report text_tokens lengths; problems in this optional field are warnings only."""
    print("\n[5/6] Validating text_tokens (optional)...")
    if "text_tokens" not in sample:
        print("  ○ text_tokens not present (will be tokenized on-the-fly)")
        return
    lengths = []
    for i, item in enumerate(subset):
        tt = item.get("text_tokens") if isinstance(item, Mapping) else None
        if tt is None:
            continue
        # len() is undefined for 0-d tensors, so those fall through to the warning.
        if isinstance(tt, list) or (isinstance(tt, torch.Tensor) and tt.dim() > 0):
            lengths.append(len(tt))
        else:
            print(f"  ⚠ Sample {i}: text_tokens is {type(tt)}")
    if lengths:
        print(f"  ✓ text_tokens present in {len(lengths)}/{len(subset)} samples")
        print(f"    Lengths: {_length_stats(lengths)}")


def _validate_word_alignments(subset: Sequence[Any], sample: Mapping) -> None:
    """Report word_alignments counts/structure; problems are warnings only."""
    print("\n[6/6] Validating word_alignments (optional)...")
    if "word_alignments" not in sample:
        print("  ○ word_alignments not present (will use proportional alignment)")
        return
    counts = []
    structure_checked = False
    for i, item in enumerate(subset):
        wa = item.get("word_alignments") if isinstance(item, Mapping) else None
        if wa is None:
            continue
        if not isinstance(wa, list):
            print(f"  ⚠ Sample {i}: word_alignments is not a list")
            continue
        counts.append(len(wa))
        # Inspect entry structure once, on the first non-empty list, so an
        # empty list in an early sample does not skip the check entirely.
        if wa and not structure_checked:
            structure_checked = True
            first = wa[0]
            if isinstance(first, Mapping) and WORD_ALIGNMENT_KEYS.issubset(first.keys()):
                print(f"  ✓ word_alignments structure: {list(first.keys())}")
            else:
                print(f"  ⚠ Unexpected word_alignment structure: {first}")
    if counts:
        print(f"  ✓ word_alignments present in {len(counts)}/{len(subset)} samples")
        print(f"    Words per sample: {_length_stats(counts)}")


def _print_summary(total: int, sample: Mapping, whisper_ok: bool, snac_ok: bool) -> None:
    """Print the final report, reflecting the actual required-field results."""
    print(f"\n{'='*60}")
    print("SUMMARY")
    print(f"{'='*60}")
    print(f"  Total samples: {total}")
    w_mark = "✓" if whisper_ok else "✗"
    s_mark = "✓" if snac_ok else "✗"
    print(f"  Required fields: whisper_features {w_mark}, snac_tokens {s_mark}")

    optional_present = []
    if "text_tokens" in sample:
        optional_present.append("text_tokens")
    if "answer" in sample or "text" in sample:
        optional_present.append("answer/text")
    if "word_alignments" in sample:
        optional_present.append("word_alignments")
    print(f"  Optional fields: {', '.join(optional_present) if optional_present else 'none'}")

    verdict = "✓ YES" if (whisper_ok and snac_ok) else "✗ NO"
    print(f"\n  Stage 1 compatible: {verdict}")
    print(f"  Stage 2 compatible: {verdict}")
    print(f"{'='*60}\n")


def validate_dataset(data_path: str, num_samples: int = 100) -> bool:
    """Validate dataset format for training compatibility.

    Loads ``data_path`` with ``torch.load``. Checks field presence on the first
    sample, then validates the first ``num_samples`` samples in detail.

    Args:
        data_path: Path to a ``.pt`` file containing a sequence of dict samples.
        num_samples: How many leading samples to check in detail (negative
            values are treated as 0).

    Returns:
        True if the file loads, is non-empty, and every checked sample has
        valid required fields (``whisper_features`` and ``snac_tokens``).
        False otherwise. Optional-field issues are reported but never fail.
    """
    print(f"\n{'='*60}")
    print(f"Dataset Validation: {data_path}")
    print(f"{'='*60}\n")

    data = _load(data_path)
    if data is None:
        return False
    if isinstance(data, Mapping):
        # data[0] / data[:n] below expect a sequence of samples, not a dict.
        print(f"  ✗ Expected a sequence of samples (got {type(data)})")
        return False
    if len(data) == 0:
        print("  ✗ Dataset is empty!")
        return False

    sample = data[0]
    if not _check_fields(sample):
        return False

    subset = data[:max(num_samples, 0)]
    whisper_ok = _validate_whisper(subset) == 0
    snac_ok = _validate_snac(subset) == 0
    _validate_text_tokens(subset, sample)
    _validate_word_alignments(subset, sample)

    _print_summary(len(data), sample, whisper_ok, snac_ok)
    return whisper_ok and snac_ok


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Validate dataset format")
    parser.add_argument("--input", type=str, required=True, help="Path to dataset.pt")
    parser.add_argument("--samples", type=int, default=100, help="Number of samples to check")
    args = parser.parse_args()

    success = validate_dataset(args.input, args.samples)
    sys.exit(0 if success else 1)