"""
Dataset validation script for CodeLlama fine-tuning.

Validates format, content, and quality of JSONL datasets.
"""

import json
import sys
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Tuple


# Fields every record must carry, in the order their errors are reported.
_TEXT_FIELDS = ("instruction", "response")


def _validate_sample(sample: Any, min_length: int) -> Tuple[List[str], Dict[str, int], bool]:
    """Check one decoded JSONL record.

    Args:
        sample: The value produced by ``json.loads`` for one line.
        min_length: Minimum stripped length accepted for each text field.

    Returns:
        A tuple ``(errors, lengths, has_code)``:
        ``errors`` is the list of problems found (empty when the record is
        valid); ``lengths`` maps a field name to its stripped length for each
        field that passed its own content checks; ``has_code`` is True when
        the accepted response contains a Markdown code fence.
    """
    # A JSONL line may decode to a list, number or string; membership and
    # subscript checks on those are meaningless or raise, so reject early.
    if not isinstance(sample, dict):
        return [f"Sample must be a JSON object, got {type(sample).__name__}"], {}, False

    # Keep the historical reporting order: missing fields, then type errors,
    # then content errors.
    errors = [f"Missing '{field}' field" for field in _TEXT_FIELDS if field not in sample]
    errors += [
        f"'{field}' must be a string"
        for field in _TEXT_FIELDS
        if field in sample and not isinstance(sample[field], str)
    ]

    lengths: Dict[str, int] = {}
    has_code = False
    for field in _TEXT_FIELDS:
        value = sample.get(field)
        if not isinstance(value, str):
            # Missing or wrongly typed: already reported above, and calling
            # .strip() on it would raise.
            continue
        text = value.strip()
        if not text:
            errors.append(f"Empty '{field}' field")
        elif len(text) < min_length:
            errors.append(f"'{field}' too short (< {min_length} chars)")
        else:
            lengths[field] = len(text)
            # "```verilog" always contains "```", so one test is enough.
            if field == "response" and "```" in text:
                has_code = True

    return errors, lengths, has_code


def validate_dataset(input_file: str, min_length: int = 3) -> Dict[str, Any]:
    """Validate a JSONL instruction/response dataset and print a summary.

    Each non-blank line must be a JSON object with non-empty string fields
    ``instruction`` and ``response`` of at least ``min_length`` characters
    after stripping. Exact duplicate records are reported as warnings and do
    not make a sample invalid.

    Args:
        input_file: Path of the JSONL file to read (UTF-8, BOM tolerated).
        min_length: Minimum stripped length of each text field.

    Returns:
        A JSON-serializable dict with keys ``valid_samples`` (line numbers),
        ``invalid_samples`` (dicts with ``line`` and ``error``/``errors``),
        ``errors`` and ``warnings`` (message strings) and ``statistics``.

    Raises:
        OSError: If ``input_file`` cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    print(f"🔍 Validating dataset: {input_file}")
    print("=" * 70)

    results: Dict[str, Any] = {
        "valid_samples": [],
        "invalid_samples": [],
        "errors": [],
        "warnings": [],
        "statistics": {},
    }

    total_lines = 0
    valid_count = 0
    invalid_count = 0

    instruction_lengths: List[int] = []
    response_lengths: List[int] = []
    has_code_markers = 0
    duplicates: List[int] = []
    seen_samples = set()

    print("\n📋 Checking each sample...")

    # utf-8-sig reads plain UTF-8 unchanged but drops a leading BOM, which
    # would otherwise make the first record fail to parse.
    with open(input_file, "r", encoding="utf-8-sig") as f:
        for line_num, line in enumerate(f, 1):
            total_lines += 1  # blank lines are counted here but not validated
            line = line.strip()
            if not line:
                continue

            try:
                sample = json.loads(line)
            except json.JSONDecodeError as e:
                invalid_count += 1
                error_msg = f"Line {line_num}: Invalid JSON - {str(e)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "error": error_msg})
                continue

            validation_errors, lengths, has_code = _validate_sample(sample, min_length)

            # Length statistics are gathered per field: a field that passed
            # its own checks is counted even if the other field failed.
            if "instruction" in lengths:
                instruction_lengths.append(lengths["instruction"])
            if "response" in lengths:
                response_lengths.append(lengths["response"])
            if has_code:
                has_code_markers += 1

            # Canonical (key-sorted) serialization so key order does not hide
            # duplicates.
            sample_hash = hash(json.dumps(sample, sort_keys=True))
            if sample_hash in seen_samples:
                duplicates.append(line_num)
                results["warnings"].append(f"Line {line_num}: Duplicate sample")
            else:
                seen_samples.add(sample_hash)

            if validation_errors:
                invalid_count += 1
                error_msg = f"Line {line_num}: {'; '.join(validation_errors)}"
                results["errors"].append(error_msg)
                results["invalid_samples"].append({"line": line_num, "errors": validation_errors})
            else:
                valid_count += 1
                results["valid_samples"].append(line_num)

    def _mean(values: List[int]) -> float:
        return sum(values) / len(values) if values else 0

    stats = {
        "total_lines": total_lines,
        "valid_samples": valid_count,
        "invalid_samples": invalid_count,
        "duplicates": len(duplicates),
        "avg_instruction_length": _mean(instruction_lengths),
        "avg_response_length": _mean(response_lengths),
        "min_instruction_length": min(instruction_lengths) if instruction_lengths else 0,
        "max_instruction_length": max(instruction_lengths) if instruction_lengths else 0,
        "min_response_length": min(response_lengths) if response_lengths else 0,
        "max_response_length": max(response_lengths) if response_lengths else 0,
        "samples_with_code_markers": has_code_markers,
        "code_marker_percentage": (has_code_markers / valid_count * 100) if valid_count > 0 else 0,
    }
    results["statistics"] = stats

    print("\n📊 Validation Results:")
    print("=" * 70)
    print(f" Total lines: {total_lines}")
    print(f" ✅ Valid samples: {valid_count}")
    print(f" ❌ Invalid samples: {invalid_count}")
    print(f" ⚠️ Duplicates: {len(duplicates)}")

    if instruction_lengths:
        print("\n📝 Instruction Statistics:")
        print(f" Average length: {stats['avg_instruction_length']:.1f} chars")
        print(f" Min/Max: {stats['min_instruction_length']} / {stats['max_instruction_length']} chars")

    if response_lengths:
        print("\n📝 Response Statistics:")
        print(f" Average length: {stats['avg_response_length']:.1f} chars")
        print(f" Min/Max: {stats['min_response_length']} / {stats['max_response_length']} chars")
        print(f" Samples with code markers: {has_code_markers} ({stats['code_marker_percentage']:.1f}%)")

    if results["errors"]:
        print(f"\n❌ Errors ({len(results['errors'])}):")
        # Only the first few are echoed; the full list is in the result dict.
        for error in results["errors"][:10]:
            print(f" {error}")
        if len(results["errors"]) > 10:
            print(f" ... and {len(results['errors']) - 10} more errors")

    if results["warnings"]:
        print(f"\n⚠️ Warnings ({len(results['warnings'])}):")
        for warning in results["warnings"][:5]:
            print(f" {warning}")
        if len(results["warnings"]) > 5:
            print(f" ... and {len(results['warnings']) - 5} more warnings")

    print("\n" + "=" * 70)
    if invalid_count == 0 and len(duplicates) == 0:
        print("✅ DATASET VALIDATION PASSED - Ready for training!")
    elif invalid_count == 0:
        print("⚠️ DATASET VALIDATION PASSED (with warnings about duplicates)")
    else:
        print("❌ DATASET VALIDATION FAILED - Fix errors before training")
    print("=" * 70)

    return results


def main(argv=None) -> int:
    """Command-line entry point: validate a JSONL file and optionally save a report.

    Args:
        argv: Argument list to parse; ``None`` makes argparse read ``sys.argv``.

    Returns:
        Process exit code: 0 when no invalid samples were found, 1 when the
        input is not a readable file path or any sample is invalid.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Validate dataset for training")
    parser.add_argument("--input", required=True, help="Input JSONL file to validate")
    parser.add_argument("--report", help="Optional: Save validation report to JSON file")
    parser.add_argument("--min-length", type=int, default=3, help="Minimum field length (default: 3)")

    args = parser.parse_args(argv)

    # is_file() rather than exists(): a directory exists but cannot be opened
    # as a dataset, and would otherwise surface as a traceback.
    if not Path(args.input).is_file():
        print(f"❌ Error: File not found: {args.input}")
        return 1

    results = validate_dataset(args.input, args.min_length)

    if args.report:
        # Explicit encoding so the report does not depend on the platform
        # default.
        with open(args.report, "w", encoding="utf-8") as f:
            json.dump(results, f, indent=2)
        print(f"\n📄 Validation report saved to: {args.report}")

    # Duplicates are warnings only; the exit status reflects invalid samples.
    return 1 if results["statistics"]["invalid_samples"] > 0 else 0


if __name__ == "__main__":
    sys.exit(main())