| |
| """ |
| Validate dataset format for Stage 1 and Stage 2 training. |
| |
| Checks that all required fields are present and have correct shapes/types. |
| |
| Usage: |
| python validate_dataset_format.py --input ./data/dataset.pt [--samples 100] |
| """ |
| import argparse |
| import sys |
| import torch |
|
|
|
|
# Status markers for report lines (these were garbled into "β" in the reviewed copy).
_OK = "✓"
_FAIL = "✗"
_WARN = "⚠"

# Expectations checked by the validator (values taken from the original checks).
_WHISPER_DIM = 1280      # required size of whisper_features' last dimension
_SNAC_GROUP = 7          # snac_tokens lengths are expected to be a multiple of this
_SNAC_MIN_ID = 128266    # inclusive lower bound of the expected SNAC token-id range
_SNAC_MAX_ID = 160000    # exclusive upper bound of the expected SNAC token-id range

_REQUIRED_FIELDS = ("whisper_features", "snac_tokens")
_OPTIONAL_FIELDS = ("text_tokens", "answer", "text", "word_alignments")


def _stats(values: list) -> str:
    """Format min/max/mean of a non-empty list of numbers."""
    return f"min={min(values)}, max={max(values)}, avg={sum(values) / len(values):.1f}"


def _load_dataset(data_path: str):
    """Step 1: load the dataset file; return the loaded object, or None on failure."""
    print("[1/6] Loading dataset...")
    try:
        # SECURITY NOTE(review): weights_only=False lets torch.load unpickle
        # arbitrary Python objects, which can execute code. Only validate
        # files from a trusted source.
        data = torch.load(data_path, map_location="cpu", weights_only=False)
        count = len(data)
    except Exception as e:  # report any load/shape failure instead of crashing
        print(f"   {_FAIL} Failed to load: {e}")
        return None
    print(f"   {_OK} Loaded {count} samples")
    return data


def _check_required_fields(sample) -> bool:
    """Step 2: check field presence on the first sample; True if all required fields exist."""
    print("\n[2/6] Checking required fields...")
    if not isinstance(sample, dict):
        print(f"   {_FAIL} Sample 0 is not a dict (got {type(sample)})")
        return False
    for field in _REQUIRED_FIELDS:
        if field not in sample:
            print(f"   {_FAIL} Missing required field: {field}")
            return False
        print(f"   {_OK} {field}: present")
    for field in _OPTIONAL_FIELDS:
        if field in sample:
            print(f"   {_OK} {field}: present (optional)")
        else:
            print(f"   {_WARN} {field}: missing (optional)")
    return True


def _check_whisper_features(samples: list) -> bool:
    """Step 3: every checked sample needs a 2-D tensor whose last dim is _WHISPER_DIM."""
    print("\n[3/6] Validating whisper_features...")
    errors = 0
    seq_lens = []
    for i, item in enumerate(samples):
        if not isinstance(item, dict):
            print(f"   {_FAIL} Sample {i}: not a dict (got {type(item)})")
            errors += 1
            continue
        wf = item.get("whisper_features")
        if wf is None:
            print(f"   {_FAIL} Sample {i}: whisper_features is None")
            errors += 1
            continue
        if not isinstance(wf, torch.Tensor):
            print(f"   {_FAIL} Sample {i}: whisper_features is not a Tensor (got {type(wf)})")
            errors += 1
            continue
        if wf.dim() != 2:
            print(f"   {_FAIL} Sample {i}: whisper_features has wrong shape {wf.shape} (expected 2D)")
            errors += 1
            continue
        if wf.shape[1] != _WHISPER_DIM:
            print(f"   {_FAIL} Sample {i}: whisper_features dim {wf.shape[1]} != {_WHISPER_DIM}")
            errors += 1
            continue
        seq_lens.append(wf.shape[0])

    if errors:
        print(f"   {_FAIL} {errors}/{len(samples)} samples have invalid whisper_features")
        return False
    print(f"   {_OK} All samples have valid whisper_features [seq_len, {_WHISPER_DIM}]")
    print(f"      Sequence lengths: {_stats(seq_lens)}")
    return True


def _check_snac_tokens(samples: list) -> bool:
    """Step 4: every checked sample needs a non-empty 1-D snac_tokens tensor."""
    print("\n[4/6] Validating snac_tokens...")
    errors = 0
    lengths = []
    mins = []
    maxs = []
    for i, item in enumerate(samples):
        if not isinstance(item, dict):
            print(f"   {_FAIL} Sample {i}: not a dict (got {type(item)})")
            errors += 1
            continue
        st = item.get("snac_tokens")
        if st is None:
            print(f"   {_FAIL} Sample {i}: snac_tokens is None")
            errors += 1
            continue
        if not isinstance(st, torch.Tensor):
            print(f"   {_FAIL} Sample {i}: snac_tokens is not a Tensor (got {type(st)})")
            errors += 1
            continue
        if st.dim() != 1:
            print(f"   {_FAIL} Sample {i}: snac_tokens has wrong shape {st.shape} (expected 1D)")
            errors += 1
            continue
        if st.numel() == 0:
            # min()/max() raise on an empty tensor, so reject it before computing the range.
            print(f"   {_FAIL} Sample {i}: snac_tokens is empty")
            errors += 1
            continue
        if len(st) % _SNAC_GROUP != 0:
            print(f"   {_WARN} Sample {i}: snac_tokens length {len(st)} "
                  f"not multiple of {_SNAC_GROUP} (will be truncated)")
        lengths.append(len(st))
        mins.append(st.min().item())
        maxs.append(st.max().item())

    if errors:
        print(f"   {_FAIL} {errors}/{len(samples)} samples have invalid snac_tokens")
        return False
    print(f"   {_OK} All samples have valid snac_tokens")
    print(f"      Lengths: {_stats(lengths)}")
    lo, hi = min(mins), max(maxs)
    print(f"      Token range: {lo} - {hi}")
    # Range check is advisory only: an unusual range warns but does not fail.
    if lo >= _SNAC_MIN_ID and hi < _SNAC_MAX_ID:
        print(f"   {_OK} SNAC token offsets look correct ({_SNAC_MIN_ID}-based)")
    else:
        print(f"   {_WARN} SNAC token range unusual - check offsets")
    return True


def _check_text_tokens(samples: list, field_present: bool) -> int:
    """Step 5 (optional): report text_tokens lengths; return the number of invalid entries."""
    print("\n[5/6] Validating text_tokens (optional)...")
    if not field_present:
        print(f"   {_WARN} text_tokens not present (will be tokenized on-the-fly)")
        return 0
    errors = 0
    lengths = []
    for i, item in enumerate(samples):
        tt = item.get("text_tokens") if isinstance(item, dict) else None
        if tt is None:
            continue
        # len() raises on a 0-d tensor, so only tensors with at least one dim are measured.
        if isinstance(tt, list) or (isinstance(tt, torch.Tensor) and tt.dim() > 0):
            lengths.append(len(tt))
        else:
            kind = "a 0-d Tensor" if isinstance(tt, torch.Tensor) else str(type(tt))
            print(f"   {_WARN} Sample {i}: text_tokens is {kind}")
            errors += 1

    if lengths:
        print(f"   {_OK} text_tokens present in {len(lengths)}/{len(samples)} samples")
        print(f"      Lengths: {_stats(lengths)}")
    else:
        print(f"   {_WARN} no usable text_tokens in the checked samples")
    if errors:
        print(f"   {_WARN} {errors}/{len(samples)} samples have invalid text_tokens")
    return errors


def _check_word_alignments(samples: list, field_present: bool) -> int:
    """Step 6 (optional): report word_alignments counts; return the number of invalid entries."""
    print("\n[6/6] Validating word_alignments (optional)...")
    if not field_present:
        print(f"   {_WARN} word_alignments not present (will use proportional alignment)")
        return 0
    expected_keys = {"word", "start_frame", "end_frame"}
    errors = 0
    counts = []
    structure_checked = False
    for i, item in enumerate(samples):
        wa = item.get("word_alignments") if isinstance(item, dict) else None
        if wa is None:
            continue
        if not isinstance(wa, list):
            print(f"   {_WARN} Sample {i}: word_alignments is not a list")
            errors += 1
            continue
        counts.append(len(wa))
        # Inspect the first non-empty alignment list found, not just sample 0's.
        if wa and not structure_checked:
            structure_checked = True
            first = wa[0]
            if isinstance(first, dict) and expected_keys.issubset(first.keys()):
                print(f"   {_OK} word_alignments structure: {list(first.keys())}")
            else:
                print(f"   {_WARN} Unexpected word_alignment structure: {first}")

    if counts:
        print(f"   {_OK} word_alignments present in {len(counts)}/{len(samples)} samples")
        print(f"      Words per sample: {_stats(counts)}")
    else:
        print(f"   {_WARN} no usable word_alignments in the checked samples")
    if errors:
        print(f"   {_WARN} {errors}/{len(samples)} samples have invalid word_alignments")
    return errors


def _print_summary(data, sample: dict, n_checked: int, whisper_ok: bool, snac_ok: bool) -> None:
    """Print the final report; verdicts reflect the actual required-field results."""
    print(f"\n{'=' * 60}")
    print("SUMMARY")
    print(f"{'=' * 60}")
    print(f"  Total samples: {len(data)} (checked {n_checked})")
    print(f"  Required fields: whisper_features {_OK if whisper_ok else _FAIL}, "
          f"snac_tokens {_OK if snac_ok else _FAIL}")

    optional_present = []
    if "text_tokens" in sample:
        optional_present.append("text_tokens")
    if "answer" in sample or "text" in sample:
        optional_present.append("answer/text")
    if "word_alignments" in sample:
        optional_present.append("word_alignments")
    print(f"  Optional fields: {', '.join(optional_present) if optional_present else 'none'}")

    verdict = f"{_OK} YES" if (whisper_ok and snac_ok) else f"{_FAIL} NO"
    print(f"\n  Stage 1 compatible: {verdict}")
    print(f"  Stage 2 compatible: {verdict}")
    print(f"{'=' * 60}\n")


def validate_dataset(data_path: str, num_samples: int = 100) -> bool:
    """Validate dataset format for training compatibility.

    Loads ``data_path`` with ``torch.load`` and checks field presence on the
    first sample. It then validates up to ``num_samples`` samples from the
    start of the dataset and prints a human-readable report.

    Args:
        data_path: Path to the saved dataset (indexable collection of per-sample dicts).
        num_samples: How many leading samples to inspect; must be >= 1.

    Returns:
        True only if the file loads, is non-empty, has the required fields,
        and every checked sample has valid ``whisper_features`` and
        ``snac_tokens``. Problems in optional fields are reported as warnings
        and do not change the result.
    """
    print(f"\n{'=' * 60}")
    print(f"Dataset Validation: {data_path}")
    print(f"{'=' * 60}\n")

    if num_samples < 1:
        print(f"   {_FAIL} num_samples must be >= 1 (got {num_samples})")
        return False

    data = _load_dataset(data_path)
    if data is None:
        return False
    if len(data) == 0:
        print(f"   {_FAIL} Dataset is empty!")
        return False

    sample = data[0]
    if not _check_required_fields(sample):
        return False

    checked = data[:num_samples]  # may be shorter than num_samples
    whisper_ok = _check_whisper_features(checked)
    snac_ok = _check_snac_tokens(checked)
    _check_text_tokens(checked, "text_tokens" in sample)
    _check_word_alignments(checked, "word_alignments" in sample)

    _print_summary(data, sample, len(checked), whisper_ok, snac_ok)
    return whisper_ok and snac_ok
|
|
|
|
if __name__ == "__main__":
    # CLI entry point: process exit status is 0 when validation passes, 1 otherwise.
    cli = argparse.ArgumentParser(description="Validate dataset format")
    cli.add_argument(
        "--input",
        type=str,
        required=True,
        help="Path to dataset.pt",
    )
    cli.add_argument(
        "--samples",
        type=int,
        default=100,
        help="Number of samples to check",
    )
    opts = cli.parse_args()

    passed = validate_dataset(opts.input, opts.samples)
    exit_code = 0 if passed else 1
    sys.exit(exit_code)
|
|