| |
"""
Dataset Validation Script for Flux Identity LoRA Training

Features:
- Scan all images in the dataset directory
- Detect corrupt/unreadable files using PIL
- Report image resolutions
- Find duplicate images by matching perceptual hashes (imagehash)
- Verify that each image has a matching caption file (.txt)
- Generate a validation report
"""
|
|
import argparse
import os
import sys
from collections import defaultdict
from pathlib import Path

import imagehash
from PIL import Image
from rich.console import Console
from rich.markup import escape
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
|
|
# Shared rich console; the report and CLI functions below print through it.
console = Console()
|
|
| |
# Lower-case file suffixes (leading dot included) that count as images.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff', '.tif'}


def scan_images(directory: Path) -> list:
    """Return the image files located directly inside *directory*.

    A file counts as an image when its suffix, compared case-insensitively,
    is listed in ``IMAGE_EXTENSIONS``; this also picks up mixed-case
    suffixes such as ``.Jpg``. Sub-directories are not descended into, and
    a directory whose name merely ends in an image suffix is ignored.

    Args:
        directory: Folder to scan.

    Returns:
        A sorted list of ``Path`` objects. The list is empty when
        *directory* does not exist or is not a directory.
    """
    if not directory.is_dir():
        return []

    # A single directory pass replaces two globs per extension. Only real
    # directories are filtered out, so broken symlinks still reach the
    # validator and get reported as unreadable.
    return sorted(
        entry
        for entry in directory.iterdir()
        if entry.suffix.lower() in IMAGE_EXTENSIONS and not entry.is_dir()
    )
|
|
|
|
def validate_image(image_path: Path) -> dict:
    """Validate a single image file and collect its metadata.

    The file is first checked with ``Image.verify()`` and then reopened and
    fully decoded. It is marked valid only when every step succeeds, so a
    file that passes the header check but cannot be decoded is reported as
    invalid together with its error message.

    Args:
        image_path: Path of the image file to inspect.

    Returns:
        A dict with the keys ``path``, ``valid``, ``width``, ``height``,
        ``format``, ``mode``, ``error``, ``phash``, ``has_caption`` and
        ``caption_path``. ``error`` holds the failure message when
        validation raised, otherwise ``None``. Width/height/format/mode may
        be filled in even for an invalid file if the failure happened
        after the header was read.
    """
    result = {
        'path': image_path,
        'valid': False,
        'width': 0,
        'height': 0,
        'format': None,
        'mode': None,
        'error': None,
        'phash': None,
        'has_caption': False,
        'caption_path': None,
    }

    try:
        with Image.open(image_path) as img:
            img.verify()

        # Pillow requires reopening the file after verify() before the
        # image can be used again.
        with Image.open(image_path) as img:
            result['width'] = img.width
            result['height'] = img.height
            result['format'] = img.format
            result['mode'] = img.mode

            # verify() does not decode pixel data; force a full decode so
            # truncated files are caught here.
            img.load()
            result['phash'] = str(imagehash.phash(img))

            # Set last: any failure above leaves the file marked invalid.
            result['valid'] = True
    except Exception as e:
        # Broad on purpose: any failure means "unusable image", and the
        # message is surfaced in the report instead of aborting the scan.
        result['error'] = str(e) or type(e).__name__

    # The caption lookup does not depend on whether the image is valid.
    caption_path = image_path.with_suffix('.txt')
    if caption_path.exists():
        result['has_caption'] = True
        result['caption_path'] = caption_path

    return result
|
|
|
|
def find_duplicates(results: list) -> dict:
    """Group image paths that share an identical perceptual hash.

    Args:
        results: Per-image dicts; only the ``phash`` and ``path`` entries
            are read. Entries with a falsy ``phash`` are skipped.

    Returns:
        A dict mapping each hash that occurs more than once to the list of
        paths carrying it, in the order the hashes were first seen.
    """
    paths_by_hash = {}
    for entry in results:
        digest = entry['phash']
        if not digest:
            continue
        paths_by_hash.setdefault(digest, []).append(entry['path'])

    # Only hashes shared by two or more images are duplicates.
    shared = {}
    for digest, members in paths_by_hash.items():
        if len(members) > 1:
            shared[digest] = members
    return shared
|
|
|
|
def _write_text_report(output_file, totals, invalid_files, missing_captions, duplicates):
    """Write the plain-text version of the report to *output_file*."""
    lines = ["Dataset Validation Report\n", "=" * 50 + "\n\n"]
    lines.extend(f"{label}: {value}\n" for label, value in totals)

    if invalid_files:
        lines.append("\nInvalid Files:\n")
        lines.extend(f" - {r['path']}: {r['error']}\n" for r in invalid_files)

    if missing_captions:
        lines.append("\nMissing Captions:\n")
        lines.extend(f" - {r['path']}\n" for r in missing_captions)

    # The console view truncates duplicate groups, so the file carries the
    # complete list.
    if duplicates:
        lines.append("\nPotential Duplicates:\n")
        for hash_val, paths in duplicates.items():
            lines.append(f" Hash: {hash_val}\n")
            lines.extend(f" - {p}\n" for p in paths)

    # Explicit UTF-8 so non-ASCII filenames do not fail on platforms whose
    # default text encoding cannot represent them.
    with open(output_file, 'w', encoding='utf-8') as f:
        f.writelines(lines)


def generate_report(results: list, duplicates: dict, output_file: Path = None):
    """Print the validation report and optionally save a plain-text copy.

    File names, error messages and paths are escaped before being printed
    so that square brackets in them appear literally instead of being
    parsed as rich markup.

    Args:
        results: Per-image dicts. The keys read here are ``path``,
            ``valid``, ``width``, ``height``, ``error`` and
            ``has_caption``.
        duplicates: Mapping of hash value to the list of paths sharing it.
        output_file: When given, a text report is written to this path
            (UTF-8). When ``None`` or empty, nothing is written.
    """
    invalid_files = [r for r in results if not r['valid']]
    missing_captions = [r for r in results if r['valid'] and not r['has_caption']]

    valid_count = sum(1 for r in results if r['valid'])
    invalid_count = len(invalid_files)
    # Counts every scanned file that has a caption (valid or not), whereas
    # the "missing" figure only considers valid images.
    with_caption = sum(1 for r in results if r['has_caption'])
    without_caption = len(missing_captions)

    console.print("\n[bold blue]═══ Dataset Validation Report ═══[/bold blue]\n")

    # Summary table
    summary = Table(title="Summary", show_header=False)
    summary.add_column("Metric", style="cyan")
    summary.add_column("Value", style="green")
    summary.add_row("Total Images Scanned", str(len(results)))
    summary.add_row("Valid Images", str(valid_count))
    summary.add_row("Invalid/Corrupt Images", str(invalid_count))
    summary.add_row("Images with Captions", str(with_caption))
    summary.add_row("Images Missing Captions", str(without_caption))
    summary.add_row("Duplicate Groups Found", str(len(duplicates)))
    console.print(summary)

    # Resolution distribution over valid images (ten most common shown)
    resolutions = defaultdict(int)
    for r in results:
        if r['valid']:
            resolutions[f"{r['width']}x{r['height']}"] += 1

    if resolutions:
        console.print("\n[bold]Resolution Distribution:[/bold]")
        res_table = Table()
        res_table.add_column("Resolution", style="cyan")
        res_table.add_column("Count", style="green")
        res_table.add_column("Percentage", style="yellow")

        # A non-empty distribution implies valid_count >= 1, so the
        # division below cannot hit zero.
        for res, count in sorted(resolutions.items(), key=lambda x: -x[1])[:10]:
            pct = (count / valid_count) * 100
            res_table.add_row(res, str(count), f"{pct:.1f}%")

        console.print(res_table)

    # Invalid files
    if invalid_files:
        console.print("\n[bold red]Invalid/Corrupt Files:[/bold red]")
        for r in invalid_files:
            console.print(f" • {escape(r['path'].name)}: {escape(str(r['error']))}")

    # Missing captions (first 20 shown)
    if missing_captions:
        console.print("\n[bold yellow]Images Missing Captions:[/bold yellow]")
        for r in missing_captions[:20]:
            console.print(f" • {escape(r['path'].name)}")
        if len(missing_captions) > 20:
            console.print(f" ... and {len(missing_captions) - 20} more")

    # Duplicate groups (first 10 shown)
    if duplicates:
        console.print("\n[bold yellow]Potential Duplicates (by perceptual hash):[/bold yellow]")
        for hash_val, paths in list(duplicates.items())[:10]:
            console.print(f"\n Hash: {hash_val}")
            for p in paths:
                console.print(f" • {escape(p.name)}")
        if len(duplicates) > 10:
            console.print(f"\n ... and {len(duplicates) - 10} more duplicate groups")

    # Recommendations
    console.print("\n[bold blue]Recommendations:[/bold blue]")

    if invalid_count > 0:
        console.print(f" ⚠ Remove or fix {invalid_count} corrupt image(s)")

    if without_caption > 0:
        console.print(f" ⚠ Add captions for {without_caption} image(s)")

    if duplicates:
        # Every image beyond the first in a group is a redundant copy.
        total_dups = sum(len(paths) - 1 for paths in duplicates.values())
        console.print(f" ⚠ Review {total_dups} potential duplicate image(s)")

    small_images = [r for r in results if r['valid'] and (r['width'] < 512 or r['height'] < 512)]
    if small_images:
        console.print(f" ⚠ {len(small_images)} image(s) are smaller than 512px (may affect quality)")

    if valid_count > 0 and without_caption == 0 and invalid_count == 0 and not duplicates:
        console.print(" ✓ Dataset looks good! Ready for training.")

    # Optional plain-text copy
    if output_file:
        totals = [
            ("Total Images", len(results)),
            ("Valid", valid_count),
            ("Invalid", invalid_count),
            ("With Captions", with_caption),
            ("Missing Captions", without_caption),
            ("Duplicate Groups", len(duplicates)),
        ]
        _write_text_report(output_file, totals, invalid_files, missing_captions, duplicates)
        console.print(f"\n[dim]Report saved to: {escape(str(output_file))}[/dim]")
|
|
|
|
def main():
    """Command-line entry point: scan, validate and report on a dataset.

    Parses the arguments, checks that the target is an existing directory,
    validates every image found in it, optionally groups duplicates and
    then prints (and optionally saves) the report.

    Exits with status 1 when the target is missing or is not a directory,
    and with status 0 when the directory contains no images.
    """
    parser = argparse.ArgumentParser(description="Validate dataset for Flux LoRA training")
    parser.add_argument(
        "directory",
        nargs="?",
        default="/workspace/flux-project/datasets/identity/images",
        help="Directory containing training images"
    )
    parser.add_argument(
        "-o", "--output",
        help="Save report to file"
    )
    parser.add_argument(
        "--no-duplicates",
        action="store_true",
        help="Skip duplicate detection (faster)"
    )

    args = parser.parse_args()

    dataset_dir = Path(args.directory)
    # Escaped so brackets in the path are printed literally by rich.
    shown_dir = escape(str(dataset_dir))

    if not dataset_dir.exists():
        console.print(f"[red]Error: Directory not found: {shown_dir}[/red]")
        sys.exit(1)

    # A regular file passes exists(); without this check the scan finds
    # nothing and reports a misleading "No images found in directory."
    if not dataset_dir.is_dir():
        console.print(f"[red]Error: Not a directory: {shown_dir}[/red]")
        sys.exit(1)

    console.print(f"[bold]Scanning directory: {shown_dir}[/bold]")

    images = scan_images(dataset_dir)

    if not images:
        console.print("[yellow]No images found in directory.[/yellow]")
        sys.exit(0)

    console.print(f"Found {len(images)} image file(s)\n")

    # Validate every image, advancing the spinner task per file.
    results = []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task("Validating images...", total=len(images))

        for img_path in images:
            results.append(validate_image(img_path))
            progress.advance(task)

    # Duplicate detection can be skipped with --no-duplicates.
    duplicates = {}
    if not args.no_duplicates:
        with Progress(
            SpinnerColumn(),
            TextColumn("[progress.description]{task.description}"),
            console=console
        ) as progress:
            task = progress.add_task("Checking for duplicates...", total=1)
            duplicates = find_duplicates(results)
            progress.advance(task)

    output_file = Path(args.output) if args.output else None
    generate_report(results, duplicates, output_file)
|
|
|
|
# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|