#!/usr/bin/env python3
"""
Dataset Statistics Script for Flux Identity LoRA Training

Features:
- Count total images and caption files
- List missing captions
- Resolution distribution histogram
- Average/min/max dimensions
- File format breakdown
- Caption length statistics
"""

import os
import sys
import argparse
from pathlib import Path
from collections import defaultdict, Counter

from PIL import Image
from rich.console import Console
from rich.markup import escape
from rich.table import Table
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn

console = Console()

# Supported image formats
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff', '.tif'}

# Maximum number of file names listed in any "problem files" section.
MAX_LISTED_FILES = 15


def scan_files(directory: Path) -> tuple:
    """Scan directory for image and caption files.

    Only the direct children of ``directory`` are inspected (no recursion).
    A file is an image when its lower-cased suffix is in IMAGE_EXTENSIONS and
    a caption when its lower-cased suffix is ``.txt``.

    Returns:
        A ``(images, captions)`` tuple of sorted lists of paths.
    """
    images = []
    captions = []

    for f in directory.iterdir():
        if not f.is_file():
            continue
        suffix = f.suffix.lower()
        if suffix in IMAGE_EXTENSIONS:
            images.append(f)
        elif suffix == '.txt':
            captions.append(f)

    return sorted(images), sorted(captions)


def get_image_info(image_path: Path) -> dict:
    """Get information about an image.

    Returns:
        On success a dict with keys ``path``, ``width``, ``height``,
        ``format``, ``mode`` and ``size_kb``. On any failure a dict with only
        ``path`` and ``error`` (the exception text), so one bad file never
        aborts the whole scan.
    """
    try:
        with Image.open(image_path) as img:
            return {
                'path': image_path,
                'width': img.width,
                'height': img.height,
                'format': img.format,
                'mode': img.mode,
                'size_kb': image_path.stat().st_size / 1024
            }
    except Exception as e:
        # Deliberately broad: the failure is recorded and reported later
        # instead of stopping the scan.
        return {
            'path': image_path,
            'error': str(e)
        }


def get_caption_info(caption_path: Path) -> dict:
    """Get information about a caption file.

    The file is read as UTF-8 and stripped of surrounding whitespace.

    Returns:
        On success a dict with keys ``path``, ``length`` (characters),
        ``word_count`` (whitespace-separated tokens) and ``content`` (the
        text, truncated to 200 characters plus ``...`` when longer). On any
        failure a dict with only ``path`` and ``error``.
    """
    try:
        content = caption_path.read_text(encoding='utf-8').strip()
        words = content.split()
        return {
            'path': caption_path,
            'length': len(content),
            'word_count': len(words),
            'content': content[:200] + '...' if len(content) > 200 else content
        }
    except Exception as e:
        # Deliberately broad: the failure is recorded and reported later.
        return {
            'path': caption_path,
            'error': str(e)
        }


def create_histogram(values: list, bins: int = 10, width: int = 40) -> str:
    """Create a simple text histogram.

    Args:
        values: Numbers to bucket.
        bins: Number of equal-width buckets between min and max.
        width: Bar length, in characters, of the fullest bucket.

    Returns:
        One line per bucket joined with newlines, ``"No data"`` for empty
        input, or ``"All values: X"`` when every value is identical.

    Raises:
        ValueError: If ``bins`` is smaller than 1 and bucketing is needed.
    """
    if not values:
        return "No data"

    min_val = min(values)
    max_val = max(values)

    if min_val == max_val:
        return f"All values: {min_val}"

    if bins < 1:
        # Would otherwise surface as a ZeroDivisionError or an empty result.
        raise ValueError("bins must be a positive integer")

    bin_size = (max_val - min_val) / bins
    histogram = defaultdict(int)

    for v in values:
        # The maximum value lands exactly on the upper edge; clamp it into
        # the last bucket.
        bin_idx = min(int((v - min_val) / bin_size), bins - 1)
        histogram[bin_idx] += 1

    max_count = max(histogram.values()) if histogram else 1

    lines = []
    for i in range(bins):
        bin_start = min_val + i * bin_size
        bin_end = bin_start + bin_size
        count = histogram[i]
        bar_len = int((count / max_count) * width)
        bar = '█' * bar_len
        lines.append(f"{bin_start:6.0f}-{bin_end:6.0f} | {bar} ({count})")

    return '\n'.join(lines)


def _print_file_counts(images: list, captions: list, failed_images: list,
                       failed_captions: list) -> set:
    """Print the file-count table plus problem-file lists.

    Returns the set of image stems that have no caption file.
    """
    console.print("\n[bold cyan]═══ File Counts ═══[/bold cyan]")

    counts_table = Table(show_header=False)
    counts_table.add_column("Metric", style="cyan")
    counts_table.add_column("Value", style="green")

    counts_table.add_row("Total Images", str(len(images)))
    counts_table.add_row("Total Caption Files", str(len(captions)))

    # Check for matching pairs (an image and a caption pair up by stem)
    image_stems = {img.stem for img in images}
    caption_stems = {cap.stem for cap in captions}

    matched = image_stems & caption_stems
    images_without_captions = image_stems - caption_stems
    captions_without_images = caption_stems - image_stems

    counts_table.add_row("Matched Image-Caption Pairs", str(len(matched)))
    counts_table.add_row("Images Missing Captions", str(len(images_without_captions)))
    counts_table.add_row("Orphan Caption Files", str(len(captions_without_images)))

    # Only shown when something failed, so a clean dataset prints the same
    # table as before.
    if failed_images:
        counts_table.add_row("Unreadable Images", str(len(failed_images)))
    if failed_captions:
        counts_table.add_row("Unreadable Caption Files", str(len(failed_captions)))

    console.print(counts_table)

    # Missing captions
    if images_without_captions:
        console.print("\n[bold yellow]Images Missing Captions:[/bold yellow]")
        for stem in sorted(images_without_captions)[:MAX_LISTED_FILES]:
            # markup=False: a stem such as "img[1]" must not be parsed as markup.
            console.print(f" • {stem}", markup=False)
        if len(images_without_captions) > MAX_LISTED_FILES:
            console.print(f" ... and {len(images_without_captions) - MAX_LISTED_FILES} more")

    # Files that could not be analysed
    failed = failed_images + failed_captions
    if failed:
        console.print("\n[bold yellow]Unreadable Files:[/bold yellow]")
        for info in failed[:MAX_LISTED_FILES]:
            console.print(f" • {info['path'].name}: {info['error']}", markup=False)
        if len(failed) > MAX_LISTED_FILES:
            console.print(f" ... and {len(failed) - MAX_LISTED_FILES} more")

    return images_without_captions


def _print_image_stats(valid_infos: list) -> None:
    """Print dimension, resolution, format and file-size tables for readable images."""
    console.print("\n[bold cyan]═══ Image Dimensions ═══[/bold cyan]")

    widths = [i['width'] for i in valid_infos]
    heights = [i['height'] for i in valid_infos]
    sizes = [i['size_kb'] for i in valid_infos]
    total = len(valid_infos)

    dim_table = Table()
    dim_table.add_column("Metric", style="cyan")
    dim_table.add_column("Width", style="green")
    dim_table.add_column("Height", style="green")

    dim_table.add_row("Minimum", str(min(widths)), str(min(heights)))
    dim_table.add_row("Maximum", str(max(widths)), str(max(heights)))
    dim_table.add_row("Average", f"{sum(widths)/len(widths):.0f}", f"{sum(heights)/len(heights):.0f}")

    console.print(dim_table)

    # Resolution distribution
    console.print("\n[bold]Resolution Distribution:[/bold]")
    resolutions = Counter(f"{i['width']}x{i['height']}" for i in valid_infos)

    res_table = Table()
    res_table.add_column("Resolution", style="cyan")
    res_table.add_column("Count", style="green")
    res_table.add_column("Percentage", style="yellow")
    res_table.add_column("", style="dim")

    for res, count in resolutions.most_common(10):
        pct = (count / total) * 100
        bar = '█' * int(pct / 2)  # one block per 2%, i.e. at most 50 blocks
        res_table.add_row(res, str(count), f"{pct:.1f}%", bar)

    if len(resolutions) > 10:
        res_table.add_row("...", f"+{len(resolutions) - 10} more", "", "")

    console.print(res_table)

    # File format breakdown
    console.print("\n[bold]File Format Breakdown:[/bold]")
    formats = Counter(i['format'] for i in valid_infos)

    fmt_table = Table()
    fmt_table.add_column("Format", style="cyan")
    fmt_table.add_column("Count", style="green")
    fmt_table.add_column("Percentage", style="yellow")

    for fmt, count in formats.most_common():
        pct = (count / total) * 100
        fmt_table.add_row(fmt or "Unknown", str(count), f"{pct:.1f}%")

    console.print(fmt_table)

    # File size statistics
    console.print("\n[bold]File Size Statistics:[/bold]")
    size_table = Table(show_header=False)
    size_table.add_column("Metric", style="cyan")
    size_table.add_column("Value", style="green")

    size_table.add_row("Minimum Size", f"{min(sizes):.1f} KB")
    size_table.add_row("Maximum Size", f"{max(sizes):.1f} KB")
    size_table.add_row("Average Size", f"{sum(sizes)/len(sizes):.1f} KB")
    size_table.add_row("Total Size", f"{sum(sizes)/1024:.1f} MB")

    console.print(size_table)


def _print_caption_stats(valid_captions: list) -> None:
    """Print length statistics, a length histogram and sample captions."""
    console.print("\n[bold cyan]═══ Caption Statistics ═══[/bold cyan]")

    lengths = [c['length'] for c in valid_captions]
    word_counts = [c['word_count'] for c in valid_captions]

    cap_table = Table()
    cap_table.add_column("Metric", style="cyan")
    cap_table.add_column("Characters", style="green")
    cap_table.add_column("Words", style="green")

    cap_table.add_row("Minimum", str(min(lengths)), str(min(word_counts)))
    cap_table.add_row("Maximum", str(max(lengths)), str(max(word_counts)))
    cap_table.add_row("Average", f"{sum(lengths)/len(lengths):.0f}", f"{sum(word_counts)/len(word_counts):.1f}")

    console.print(cap_table)

    # Caption length histogram
    console.print("\n[bold]Caption Length Distribution (characters):[/bold]")
    console.print(create_histogram(lengths, bins=8, width=30))

    # Sample captions
    console.print("\n[bold]Sample Captions:[/bold]")
    for cap in valid_captions[:3]:
        console.print(f"\n [dim]{escape(cap['path'].stem)}:[/dim]")
        # Caption text is arbitrary user content (it may contain "[...]"),
        # so it is printed literally rather than parsed as markup.
        console.print(f" {cap['content']}", markup=False)


def _print_recommendations(images_without_captions: set, valid_infos: list,
                           valid_captions: list, failed_images: list,
                           failed_captions: list) -> None:
    """Print actionable warnings, or a success line when there are none."""
    console.print("\n[bold cyan]═══ Recommendations ═══[/bold cyan]")

    recommendations = []

    if images_without_captions:
        recommendations.append(f"⚠ Add captions for {len(images_without_captions)} image(s)")

    if failed_images:
        recommendations.append(f"⚠ {len(failed_images)} image(s) could not be read")

    if failed_captions:
        recommendations.append(f"⚠ {len(failed_captions)} caption file(s) could not be read")

    if valid_infos:
        small = sum(1 for i in valid_infos if i['width'] < 512 or i['height'] < 512)
        if small > 0:
            recommendations.append(f"⚠ {small} image(s) are smaller than 512px")

        large = sum(1 for i in valid_infos if i['width'] > 2048 or i['height'] > 2048)
        if large > 0:
            recommendations.append(f"ℹ {large} image(s) are larger than 2048px (will be resized)")

    if valid_captions:
        short = sum(1 for c in valid_captions if c['word_count'] < 5)
        if short > 0:
            recommendations.append(f"⚠ {short} caption(s) are very short (<5 words)")

    if not recommendations:
        recommendations.append("✓ Dataset looks good!")

    for rec in recommendations:
        console.print(f" {rec}")


def print_stats(images: list, captions: list, image_infos: list, caption_infos: list, directory: Path):
    """Print comprehensive dataset statistics.

    Args:
        images: Image paths found in the dataset directory.
        captions: Caption (.txt) paths found in the dataset directory.
        image_infos: Results of get_image_info for ``images``.
        caption_infos: Results of get_caption_info for ``captions``.
        directory: Dataset directory, shown in the report header.
    """
    console.print(Panel.fit(
        f"[bold blue]Dataset Statistics[/bold blue]\n[dim]{escape(str(directory))}[/dim]",
        border_style="blue"
    ))

    # Entries carrying an 'error' key failed to load; they are excluded from
    # the statistics but reported so problems are not hidden.
    valid_infos = [i for i in image_infos if 'error' not in i]
    failed_images = [i for i in image_infos if 'error' in i]
    valid_captions = [c for c in caption_infos if 'error' not in c]
    failed_captions = [c for c in caption_infos if 'error' in c]

    images_without_captions = _print_file_counts(
        images, captions, failed_images, failed_captions
    )

    if valid_infos:
        _print_image_stats(valid_infos)

    if valid_captions:
        _print_caption_stats(valid_captions)

    _print_recommendations(
        images_without_captions, valid_infos, valid_captions,
        failed_images, failed_captions
    )


def _collect_infos(paths: list, loader, noun: str) -> list:
    """Run ``loader`` on every path behind a progress spinner and return the results."""
    infos = []
    if not paths:
        return infos

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task(f"Analyzing {len(paths)} {noun}...", total=len(paths))
        for path in paths:
            infos.append(loader(path))
            progress.advance(task)

    return infos


def main():
    """Parse command-line arguments, analyse the dataset and print the report.

    Exits with status 1 when the path is missing or is not a directory, and
    with status 0 when the directory holds neither images nor captions.
    """
    parser = argparse.ArgumentParser(description="Generate dataset statistics for Flux LoRA training")
    parser.add_argument(
        "directory",
        nargs="?",
        default="/workspace/flux-project/datasets/identity/images",
        help="Directory containing training images and captions"
    )
    parser.add_argument(
        "-o", "--output",
        help="Save statistics to file"
    )

    args = parser.parse_args()

    dataset_dir = Path(args.directory)
    if not dataset_dir.exists():
        console.print(f"[red]Error: Directory not found: {escape(str(dataset_dir))}[/red]")
        sys.exit(1)
    if not dataset_dir.is_dir():
        # iterdir() would raise NotADirectoryError on a regular file.
        console.print(f"[red]Error: Not a directory: {escape(str(dataset_dir))}[/red]")
        sys.exit(1)

    console.print(f"[bold]Scanning: {escape(str(dataset_dir))}[/bold]\n")

    # Scan files
    images, captions = scan_files(dataset_dir)

    if not images and not captions:
        console.print("[yellow]No images or captions found in directory.[/yellow]")
        sys.exit(0)

    # Gather image and caption info
    image_infos = _collect_infos(images, get_image_info, "images")
    caption_infos = _collect_infos(captions, get_caption_info, "captions")

    # Print statistics
    print_stats(images, captions, image_infos, caption_infos, dataset_dir)

    # Save to file if requested
    if args.output:
        # Explicit encoding so non-ASCII paths write the same on every platform.
        with open(args.output, 'w', encoding='utf-8') as f:
            f.write(f"Dataset Statistics: {dataset_dir}\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Images: {len(images)}\n")
            f.write(f"Captions: {len(captions)}\n")
            # Add more details as needed
        console.print(f"\n[dim]Statistics saved to: {escape(str(args.output))}[/dim]")


if __name__ == "__main__":
    main()