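# NOTE: the lines below are page-header residue captured with this file; kept as comments.
# omini-model / datasets / validate_dataset_format.py
# marcos
# feat: Add full fine-tuning (no LoRA) and dataset generation tools
# cbe0918
# Raw
# History Blame Contribute Delete
# 7.85 kB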
#!/usr/bin/env python3
"""
Validate dataset format for Stage 1 and Stage 2 training.
Checks that all required fields are present and have correct shapes/types.
Usage:
python validate_dataset_format.py --input ./data/dataset.pt [--samples 100]
"""
import argparse
import sys
import torch
# Fields every sample must carry, and fields that are only reported when present.
_REQUIRED_FIELDS = ("whisper_features", "snac_tokens")
_OPTIONAL_FIELDS = ("text_tokens", "answer", "text", "word_alignments")

# SNAC token ids are expected in [128266, 160000); ids outside this window
# only produce a warning ("should be in range 128266 to ~155000").
_SNAC_TOKEN_MIN = 128266
_SNAC_TOKEN_MAX_EXCLUSIVE = 160000
# snac_tokens lengths are expected to be a multiple of this; other lengths are
# reported with a "will be truncated" warning.
_SNAC_TOKENS_PER_GROUP = 7

_RULE = "=" * 60


def _length_stats(label: str, values: list) -> str:
    """Format ``"<label>: min=.., max=.., avg=.."`` for a non-empty list of numbers."""
    return (
        f"{label}: min={min(values)}, max={max(values)}, "
        f"avg={sum(values) / len(values):.1f}"
    )


def _check_whisper_features(samples, whisper_dim: int) -> int:
    """Report malformed ``whisper_features`` in *samples* and return the error count.

    A valid entry is a 2-D ``torch.Tensor`` whose second dimension equals
    *whisper_dim*.
    """
    errors = 0
    seq_lens = []
    for i, item in enumerate(samples):
        wf = item.get("whisper_features")
        if wf is None:
            print(f" ✗ Sample {i}: whisper_features is None")
        elif not isinstance(wf, torch.Tensor):
            print(f" ✗ Sample {i}: whisper_features is not a Tensor (got {type(wf)})")
        elif wf.dim() != 2:
            print(f" ✗ Sample {i}: whisper_features has wrong shape {wf.shape} (expected 2D)")
        elif wf.shape[1] != whisper_dim:
            print(f" ✗ Sample {i}: whisper_features dim {wf.shape[1]} != {whisper_dim}")
        else:
            seq_lens.append(wf.shape[0])
            continue
        errors += 1

    if errors:
        print(f" ✗ {errors}/{len(samples)} samples have invalid whisper_features")
    elif seq_lens:
        print(f" ✓ All samples have valid whisper_features [seq_len, {whisper_dim}]")
        print(" " + _length_stats("Sequence lengths", seq_lens))
    else:
        print(" ○ No samples inspected")
    return errors


def _check_snac_tokens(samples) -> int:
    """Report malformed ``snac_tokens`` in *samples* and return the error count.

    A valid entry is a non-empty 1-D ``torch.Tensor``. Lengths that are not a
    multiple of ``_SNAC_TOKENS_PER_GROUP`` and token ids outside the expected
    SNAC window only produce warnings.
    """
    errors = 0
    lengths, lows, highs = [], [], []
    for i, item in enumerate(samples):
        st = item.get("snac_tokens")
        if st is None:
            print(f" ✗ Sample {i}: snac_tokens is None")
        elif not isinstance(st, torch.Tensor):
            print(f" ✗ Sample {i}: snac_tokens is not a Tensor (got {type(st)})")
        elif st.dim() != 1:
            print(f" ✗ Sample {i}: snac_tokens has wrong shape {st.shape} (expected 1D)")
        elif st.numel() == 0:
            # min()/max() raise on an empty tensor, so report it instead of crashing.
            print(f" ✗ Sample {i}: snac_tokens is empty")
        else:
            if len(st) % _SNAC_TOKENS_PER_GROUP:
                print(
                    f" ⚠ Sample {i}: snac_tokens length {len(st)} not multiple of "
                    f"{_SNAC_TOKENS_PER_GROUP} (will be truncated)"
                )
            lengths.append(len(st))
            lows.append(st.min().item())
            highs.append(st.max().item())
            continue
        errors += 1

    if errors:
        print(f" ✗ {errors}/{len(samples)} samples have invalid snac_tokens")
    elif not lengths:
        print(" ○ No samples inspected")
    else:
        print(" ✓ All samples have valid snac_tokens")
        print(" " + _length_stats("Lengths", lengths))
        lo, hi = min(lows), max(highs)
        print(f" Token range: {lo} - {hi}")
        if lo >= _SNAC_TOKEN_MIN and hi < _SNAC_TOKEN_MAX_EXCLUSIVE:
            print(f" ✓ SNAC token offsets look correct ({_SNAC_TOKEN_MIN}-based)")
        else:
            print(" ⚠ SNAC token range unusual - check offsets")
    return errors


def _report_text_tokens(samples, present_in_first: bool) -> None:
    """Print (warning-only) statistics for the optional ``text_tokens`` field."""
    if not present_in_first:
        print(" ○ text_tokens not present (will be tokenized on-the-fly)")
        return
    token_lens = []
    warnings = 0
    for i, item in enumerate(samples):
        tt = item.get("text_tokens")
        if tt is None:
            continue
        if isinstance(tt, torch.Tensor) and tt.dim() == 0:
            # len() is undefined for a 0-d tensor.
            print(f" ⚠ Sample {i}: text_tokens is a 0-d Tensor")
            warnings += 1
        elif isinstance(tt, (torch.Tensor, list)):
            token_lens.append(len(tt))
        else:
            print(f" ⚠ Sample {i}: text_tokens is {type(tt)}")
            warnings += 1
    if token_lens:
        print(f" ✓ text_tokens present in {len(token_lens)}/{len(samples)} samples")
        print(" " + _length_stats("Lengths", token_lens))
    if warnings:
        print(f" ⚠ {warnings}/{len(samples)} samples have unexpected text_tokens types")


def _report_word_alignments(samples, present_in_first: bool) -> None:
    """Print (warning-only) statistics for the optional ``word_alignments`` field.

    Each entry is expected to be a list; the first non-empty list found has its
    first element checked for the ``word``/``start_frame``/``end_frame`` keys.
    """
    if not present_in_first:
        print(" ○ word_alignments not present (will use proportional alignment)")
        return
    expected_keys = {"word", "start_frame", "end_frame"}
    counts = []
    structure_checked = False
    for i, item in enumerate(samples):
        wa = item.get("word_alignments")
        if wa is None:
            continue
        if not isinstance(wa, list):
            print(f" ⚠ Sample {i}: word_alignments is not a list")
            continue
        counts.append(len(wa))
        # Inspect the entry layout once, on the first non-empty alignment list.
        if wa and not structure_checked:
            structure_checked = True
            first = wa[0]
            if isinstance(first, dict) and expected_keys.issubset(first.keys()):
                print(f" ✓ word_alignments structure: {list(first.keys())}")
            else:
                print(f" ⚠ Unexpected word_alignment structure: {first}")
    if counts:
        print(f" ✓ word_alignments present in {len(counts)}/{len(samples)} samples")
        print(" " + _length_stats("Words per sample", counts))


def validate_dataset(data_path: str, num_samples: int = 100, whisper_dim: int = 1280) -> bool:
    """Validate dataset format for Stage 1 / Stage 2 training compatibility.

    Loads a ``torch.save``-d list (or tuple) of per-sample dicts. It checks that
    sample 0 has the required fields (``whisper_features``, ``snac_tokens``).
    It then validates those fields on the first *num_samples* samples and
    reports statistics for the optional ``text_tokens`` / ``word_alignments``
    fields. A human-readable report is printed to stdout.

    Args:
        data_path: Path to the ``.pt`` file to validate.
        num_samples: Number of leading samples to inspect in detail. Values
            <= 0 inspect none (only sample 0's field presence is checked).
        whisper_dim: Expected size of the second dimension of each
            ``whisper_features`` tensor.

    Returns:
        True if the file loaded as a non-empty sequence of dicts and every
        inspected sample has valid required fields; False otherwise.
        Optional-field and SNAC-range findings are warnings only.
    """
    print(f"\n{_RULE}")
    print(f"Dataset Validation: {data_path}")
    print(f"{_RULE}\n")

    print("[1/6] Loading dataset...")
    try:
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # Python objects, which can execute code. Only validate trusted files.
        data = torch.load(data_path, map_location="cpu", weights_only=False)
    except Exception as e:  # CLI boundary: report any load failure and fail
        print(f" ✗ Failed to load: {e}")
        return False
    if not isinstance(data, (list, tuple)):
        print(f" ✗ Expected a list of samples, got {type(data).__name__}")
        return False
    print(f" ✓ Loaded {len(data)} samples")
    if not data:
        print(" ✗ Dataset is empty!")
        return False

    # Clamp so a negative count does not slice from the end of the dataset.
    checked = data[:max(num_samples, 0)]

    print("\n[2/6] Checking required fields...")
    sample = data[0]
    if not isinstance(sample, dict):
        print(f" ✗ Sample 0 is not a dict (got {type(sample)})")
        return False
    for field in _REQUIRED_FIELDS:
        if field not in sample:
            print(f" ✗ Missing required field: {field}")
            return False
        print(f" ✓ {field}: present")
    for field in _OPTIONAL_FIELDS:
        if field in sample:
            print(f" ✓ {field}: present (optional)")
        else:
            print(f" ○ {field}: missing (optional)")
    for i, item in enumerate(checked):
        if not isinstance(item, dict):
            print(f" ✗ Sample {i} is not a dict (got {type(item)})")
            return False

    print("\n[3/6] Validating whisper_features...")
    whisper_errors = _check_whisper_features(checked, whisper_dim)

    print("\n[4/6] Validating snac_tokens...")
    snac_errors = _check_snac_tokens(checked)

    print("\n[5/6] Validating text_tokens (optional)...")
    _report_text_tokens(checked, "text_tokens" in sample)

    print("\n[6/6] Validating word_alignments (optional)...")
    _report_word_alignments(checked, "word_alignments" in sample)

    ok = whisper_errors == 0 and snac_errors == 0

    print(f"\n{_RULE}")
    print("SUMMARY")
    print(_RULE)
    print(f" Total samples: {len(data)}")
    whisper_mark = "✓" if whisper_errors == 0 else "✗"
    snac_mark = "✓" if snac_errors == 0 else "✗"
    print(f" Required fields: whisper_features {whisper_mark}, snac_tokens {snac_mark}")
    optional_present = [
        label
        for label, keys in (
            ("text_tokens", ("text_tokens",)),
            ("answer/text", ("answer", "text")),
            ("word_alignments", ("word_alignments",)),
        )
        if any(key in sample for key in keys)
    ]
    print(f" Optional fields: {', '.join(optional_present) if optional_present else 'none'}")
    # Both stages consume the same required fields, so they share one verdict.
    verdict = "✓ YES" if ok else "✗ NO"
    print(f"\n Stage 1 compatible: {verdict}")
    print(f" Stage 2 compatible: {verdict}")
    print(f"{_RULE}\n")
    return ok
if __name__ == "__main__":
    # CLI entry point: exit status 0 when the dataset validates, 1 otherwise.
    cli = argparse.ArgumentParser(description="Validate dataset format")
    cli.add_argument("--input", type=str, required=True, help="Path to dataset.pt")
    cli.add_argument(
        "--samples",
        type=int,
        default=100,
        help="Number of samples to check",
    )
    opts = cli.parse_args()
    passed = validate_dataset(opts.input, opts.samples)
    sys.exit(int(not passed))