# codellama-fine-tuning / scripts / validate_dataset.py
# Uploaded by Prithvik-1 via huggingface_hub (commit eada9ff, verified)
#!/usr/bin/env python3
"""
Dataset validation script for CodeLlama fine-tuning
Validates format, content, and quality of JSONL datasets
"""
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple
from collections import Counter
def _check_sample(sample: Dict, min_length: int) -> List[str]:
    """Return all validation error messages for one parsed JSON sample.

    Checks, per field ('instruction' and 'response'): presence, string
    type, non-emptiness after stripping, and minimum length. Type failure
    short-circuits the content checks so non-string values never have
    str methods called on them.
    """
    errors = []
    for field in ("instruction", "response"):
        if field not in sample:
            errors.append(f"Missing '{field}' field")
            continue
        value = sample[field]
        if not isinstance(value, str):
            errors.append(f"'{field}' must be a string")
            continue
        stripped = value.strip()
        if not stripped:
            errors.append(f"Empty '{field}' field")
        elif len(stripped) < min_length:
            errors.append(f"'{field}' too short (< {min_length} chars)")
    return errors


def validate_dataset(input_file: str, min_length: int = 3) -> Dict:
    """Comprehensively validate a JSONL instruction/response dataset.

    Each line of ``input_file`` must be a JSON object with non-empty
    string fields 'instruction' and 'response', each at least
    ``min_length`` characters after stripping. Prints a human-readable
    report as it goes.

    Args:
        input_file: Path to the JSONL file to validate.
        min_length: Minimum character count for each field (default 3).

    Returns:
        Dict with keys 'valid_samples' (line numbers), 'invalid_samples'
        (line numbers + error details), 'errors', 'warnings', and
        'statistics' (counts, length stats, code-marker stats).
    """
    print(f"🔍 Validating dataset: {input_file}")
    print("=" * 70)
    results = {
        "valid_samples": [],
        "invalid_samples": [],
        "errors": [],
        "warnings": [],
        "statistics": {}
    }
    total_lines = 0
    valid_count = 0
    invalid_count = 0
    # Aggregate statistics
    instruction_lengths = []
    response_lengths = []
    has_code_markers = 0
    duplicates = []
    # Store the canonical JSON text itself, not hash(text): Python's
    # hash() can collide, which would silently flag distinct samples
    # as duplicates.
    seen_samples = set()
    print("\n📋 Checking each sample...")
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            total_lines += 1
            line = line.strip()
            if not line:
                continue
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                invalid_count += 1
                error_msg = f"Line {line_num}: Invalid JSON - {str(e)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "error": error_msg})
                continue
            # Field-level checks (presence, type, emptiness, length).
            # Guarding on isinstance below fixes a crash in the original
            # logic, which called .strip() on non-string field values.
            validation_errors = _check_sample(sample, min_length)
            # Collect content statistics only for well-formed fields
            if isinstance(sample.get("instruction"), str):
                instruction = sample["instruction"].strip()
                if instruction and len(instruction) >= min_length:
                    instruction_lengths.append(len(instruction))
            if isinstance(sample.get("response"), str):
                response = sample["response"].strip()
                if response and len(response) >= min_length:
                    response_lengths.append(len(response))
                    # '```' also matches '```verilog', so one test suffices
                    if '```' in response:
                        has_code_markers += 1
            # Duplicate detection on the canonical (key-sorted) serialization
            sample_key = json.dumps(sample, sort_keys=True)
            if sample_key in seen_samples:
                duplicates.append(line_num)
                results["warnings"].append(f"Line {line_num}: Duplicate sample")
            else:
                seen_samples.add(sample_key)
            # Record per-line result
            if validation_errors:
                invalid_count += 1
                error_msg = f"Line {line_num}: {'; '.join(validation_errors)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "errors": validation_errors})
            else:
                valid_count += 1
                results["valid_samples"].append(line_num)
    # Calculate statistics
    results["statistics"] = {
        "total_lines": total_lines,
        "valid_samples": valid_count,
        "invalid_samples": invalid_count,
        "duplicates": len(duplicates),
        "avg_instruction_length": sum(instruction_lengths) / len(instruction_lengths) if instruction_lengths else 0,
        "avg_response_length": sum(response_lengths) / len(response_lengths) if response_lengths else 0,
        "min_instruction_length": min(instruction_lengths) if instruction_lengths else 0,
        "max_instruction_length": max(instruction_lengths) if instruction_lengths else 0,
        "min_response_length": min(response_lengths) if response_lengths else 0,
        "max_response_length": max(response_lengths) if response_lengths else 0,
        "samples_with_code_markers": has_code_markers,
        "code_marker_percentage": (has_code_markers / valid_count * 100) if valid_count > 0 else 0
    }
    # Print results
    print(f"\n📊 Validation Results:")
    print("=" * 70)
    print(f"   Total lines: {total_lines}")
    print(f"   ✅ Valid samples: {valid_count}")
    print(f"   ❌ Invalid samples: {invalid_count}")
    print(f"   ⚠️  Duplicates: {len(duplicates)}")
    if instruction_lengths:
        print(f"\n📏 Instruction Statistics:")
        print(f"   Average length: {results['statistics']['avg_instruction_length']:.1f} chars")
        print(f"   Min/Max: {results['statistics']['min_instruction_length']} / {results['statistics']['max_instruction_length']} chars")
    if response_lengths:
        print(f"\n📏 Response Statistics:")
        print(f"   Average length: {results['statistics']['avg_response_length']:.1f} chars")
        print(f"   Min/Max: {results['statistics']['min_response_length']} / {results['statistics']['max_response_length']} chars")
        print(f"   Samples with code markers: {has_code_markers} ({results['statistics']['code_marker_percentage']:.1f}%)")
    if results["errors"]:
        print(f"\n❌ Errors ({len(results['errors'])}):")
        for error in results["errors"][:10]:  # Show first 10
            print(f"   {error}")
        if len(results["errors"]) > 10:
            print(f"   ... and {len(results['errors']) - 10} more errors")
    if results["warnings"]:
        print(f"\n⚠️  Warnings ({len(results['warnings'])}):")
        for warning in results["warnings"][:5]:  # Show first 5
            print(f"   {warning}")
        if len(results["warnings"]) > 5:
            print(f"   ... and {len(results['warnings']) - 5} more warnings")
    # Validation summary
    print(f"\n" + "=" * 70)
    if invalid_count == 0 and len(duplicates) == 0:
        print("✅ DATASET VALIDATION PASSED - Ready for training!")
    elif invalid_count == 0:
        print("⚠️  DATASET VALIDATION PASSED (with warnings about duplicates)")
    else:
        print("❌ DATASET VALIDATION FAILED - Fix errors before training")
    print("=" * 70)
    return results
if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description="Validate dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file to validate")
    parser.add_argument("--report", help="Optional: Save validation report to JSON file")
    parser.add_argument("--min-length", type=int, default=3, help="Minimum field length (default: 3)")
    args = parser.parse_args()
    # Fail fast with a clear message instead of an open() traceback
    if not Path(args.input).exists():
        print(f"❌ Error: File not found: {args.input}")
        sys.exit(1)
    results = validate_dataset(args.input, args.min_length)
    # Save report if requested; write UTF-8 explicitly so the report is
    # portable regardless of the platform's default encoding
    if args.report:
        with open(args.report, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, ensure_ascii=False)
        print(f"\n📄 Validation report saved to: {args.report}")
    # Exit non-zero when any sample failed validation so CI pipelines
    # can gate on the result; duplicates alone are only a warning
    if results["statistics"]["invalid_samples"] > 0:
        sys.exit(1)
    else:
        sys.exit(0)