#!/usr/bin/env python3
"""
Dataset validation script for CodeLlama fine-tuning
Validates format, content, and quality of JSONL datasets
"""
import json
import sys
from pathlib import Path
from typing import Dict, List, Tuple
from collections import Counter
def validate_dataset(input_file: str, min_length: int = 3) -> Dict:
    """Validate a JSONL instruction/response dataset for fine-tuning.

    Each line of ``input_file`` must be a JSON object with non-empty string
    fields ``instruction`` and ``response``, each at least ``min_length``
    characters after stripping whitespace. Duplicate samples are reported
    as warnings (they still count as valid).

    Args:
        input_file: Path to the JSONL dataset to check.
        min_length: Minimum stripped length for each field (default 3).

    Returns:
        Dict with keys:
            - "valid_samples": list of 1-based line numbers that passed.
            - "invalid_samples": list of dicts describing failures.
            - "errors": human-readable error strings.
            - "warnings": human-readable warning strings (duplicates).
            - "statistics": aggregate counts and length statistics.
    """
    print(f"🔍 Validating dataset: {input_file}")
    print("=" * 70)
    results = {
        "valid_samples": [],
        "invalid_samples": [],
        "errors": [],
        "warnings": [],
        "statistics": {}
    }
    total_lines = 0
    valid_count = 0
    invalid_count = 0
    # Statistics accumulators (lengths only recorded for fields that pass checks)
    instruction_lengths = []
    response_lengths = []
    has_code_markers = 0
    duplicates = []
    seen_samples = set()
    print("\n📝 Checking each sample...")
    with open(input_file, 'r', encoding='utf-8') as f:
        for line_num, line in enumerate(f, 1):
            total_lines += 1
            line = line.strip()
            if not line:
                # Blank lines are counted in total_lines but otherwise ignored
                continue
            sample = None
            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                invalid_count += 1
                error_msg = f"Line {line_num}: Invalid JSON - {str(e)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "error": error_msg})
                continue
            # Collect every problem for this sample so the report is complete
            validation_errors = []
            # Check required fields
            if "instruction" not in sample:
                validation_errors.append("Missing 'instruction' field")
            if "response" not in sample:
                validation_errors.append("Missing 'response' field")
            # Check data types
            if "instruction" in sample and not isinstance(sample["instruction"], str):
                validation_errors.append("'instruction' must be a string")
            if "response" in sample and not isinstance(sample["response"], str):
                validation_errors.append("'response' must be a string")
            # Check content
            if "instruction" in sample:
                instruction = sample["instruction"].strip()
                if not instruction:
                    validation_errors.append("Empty 'instruction' field")
                elif len(instruction) < min_length:
                    validation_errors.append(f"'instruction' too short (< {min_length} chars)")
                else:
                    instruction_lengths.append(len(instruction))
            if "response" in sample:
                response = sample["response"].strip()
                if not response:
                    validation_errors.append("Empty 'response' field")
                elif len(response) < min_length:
                    validation_errors.append(f"'response' too short (< {min_length} chars)")
                else:
                    response_lengths.append(len(response))
                    # Any fenced block counts ('```' also matches '```verilog')
                    # NOTE(review): counted even if the instruction later fails
                    # validation — marker stats may slightly exceed valid samples
                    if '```' in response:
                        has_code_markers += 1
            # Duplicate detection: hash the canonical (key-sorted) JSON form
            sample_hash = hash(json.dumps(sample, sort_keys=True))
            if sample_hash in seen_samples:
                duplicates.append(line_num)
                results["warnings"].append(f"Line {line_num}: Duplicate sample")
            else:
                seen_samples.add(sample_hash)
            # Record result
            if validation_errors:
                invalid_count += 1
                error_msg = f"Line {line_num}: {'; '.join(validation_errors)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "errors": validation_errors})
            else:
                valid_count += 1
                results["valid_samples"].append(line_num)
    # Calculate statistics (guard every aggregate against empty input)
    results["statistics"] = {
        "total_lines": total_lines,
        "valid_samples": valid_count,
        "invalid_samples": invalid_count,
        "duplicates": len(duplicates),
        "avg_instruction_length": sum(instruction_lengths) / len(instruction_lengths) if instruction_lengths else 0,
        "avg_response_length": sum(response_lengths) / len(response_lengths) if response_lengths else 0,
        "min_instruction_length": min(instruction_lengths) if instruction_lengths else 0,
        "max_instruction_length": max(instruction_lengths) if instruction_lengths else 0,
        "min_response_length": min(response_lengths) if response_lengths else 0,
        "max_response_length": max(response_lengths) if response_lengths else 0,
        "samples_with_code_markers": has_code_markers,
        "code_marker_percentage": (has_code_markers / valid_count * 100) if valid_count > 0 else 0
    }
    # Print results
    print(f"\n📊 Validation Results:")
    print("=" * 70)
    print(f"  Total lines: {total_lines}")
    print(f"  ✅ Valid samples: {valid_count}")
    print(f"  ❌ Invalid samples: {invalid_count}")
    print(f"  ⚠️ Duplicates: {len(duplicates)}")
    if instruction_lengths:
        print(f"\n📄 Instruction Statistics:")
        print(f"  Average length: {results['statistics']['avg_instruction_length']:.1f} chars")
        print(f"  Min/Max: {results['statistics']['min_instruction_length']} / {results['statistics']['max_instruction_length']} chars")
    if response_lengths:
        print(f"\n📄 Response Statistics:")
        print(f"  Average length: {results['statistics']['avg_response_length']:.1f} chars")
        print(f"  Min/Max: {results['statistics']['min_response_length']} / {results['statistics']['max_response_length']} chars")
        print(f"  Samples with code markers: {has_code_markers} ({results['statistics']['code_marker_percentage']:.1f}%)")
    if results["errors"]:
        print(f"\n❌ Errors ({len(results['errors'])}):")
        for error in results["errors"][:10]:  # Show first 10
            print(f"  {error}")
        if len(results["errors"]) > 10:
            print(f"  ... and {len(results['errors']) - 10} more errors")
    if results["warnings"]:
        print(f"\n⚠️ Warnings ({len(results['warnings'])}):")
        for warning in results["warnings"][:5]:  # Show first 5
            print(f"  {warning}")
        if len(results["warnings"]) > 5:
            print(f"  ... and {len(results['warnings']) - 5} more warnings")
    # Validation summary
    print(f"\n" + "=" * 70)
    if invalid_count == 0 and len(duplicates) == 0:
        print("✅ DATASET VALIDATION PASSED - Ready for training!")
    elif invalid_count == 0:
        print("⚠️ DATASET VALIDATION PASSED (with warnings about duplicates)")
    else:
        print("❌ DATASET VALIDATION FAILED - Fix errors before training")
    print("=" * 70)
    return results
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Validate dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file to validate")
    parser.add_argument("--report", help="Optional: Save validation report to JSON file")
    parser.add_argument("--min-length", type=int, default=3, help="Minimum field length (default: 3)")
    args = parser.parse_args()

    # Fail fast with a clear message instead of an open() traceback
    if not Path(args.input).exists():
        print(f"❌ Error: File not found: {args.input}")
        sys.exit(1)

    results = validate_dataset(args.input, args.min_length)

    # Save report if requested; explicit utf-8 avoids platform-default
    # codec failures when sample text contains non-ASCII characters
    if args.report:
        with open(args.report, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2)
        print(f"\n📄 Validation report saved to: {args.report}")

    # Non-zero exit code lets CI pipelines fail the build on a bad dataset
    sys.exit(1 if results["statistics"]["invalid_samples"] > 0 else 0)
|