| |
| """ |
| Dataset Statistics Script for Flux Identity LoRA Training |
| |
| Features: |
| - Count total images and caption files |
| - List missing captions |
| - Resolution distribution histogram |
| - Average/min/max dimensions |
| - File format breakdown |
| - Caption length statistics |
| """ |
|
|
import argparse
import os
import sys
from collections import Counter, defaultdict
from pathlib import Path

from PIL import Image
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
|
|
# Shared Rich console; every status line, table and progress display in this
# script is written through it.
console = Console()


# File suffixes that scan_files classifies as images. Suffixes are lower-cased
# before the membership test, so ".JPG" and ".jpg" both match.
IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff', '.tif'}
|
|
|
|
def scan_files(directory: Path) -> tuple:
    """Collect the image and caption files located directly inside *directory*.

    Only the top level is inspected (no recursion), and only entries for which
    ``Path.is_file()`` is true. A file counts as an image when its lower-cased
    suffix is in ``IMAGE_EXTENSIONS`` and as a caption when its lower-cased
    suffix is ``.txt``; all other files are ignored.

    Returns:
        ``(images, captions)`` -- two lists of paths, each sorted.
    """
    regular_files = [entry for entry in directory.iterdir() if entry.is_file()]
    image_files = sorted(p for p in regular_files if p.suffix.lower() in IMAGE_EXTENSIONS)
    caption_files = sorted(p for p in regular_files if p.suffix.lower() == '.txt')
    return image_files, caption_files
|
|
|
|
def get_image_info(image_path: Path) -> dict:
    """Read basic metadata for a single image file.

    On success the result has the keys ``path``, ``width``, ``height``,
    ``format``, ``mode`` and ``size_kb`` (on-disk size divided by 1024).
    If anything fails while opening or inspecting the file, the result holds
    only ``path`` and ``error`` (the exception message), so one bad file does
    not abort a scan of the whole dataset.
    """
    try:
        with Image.open(image_path) as img:
            width, height = img.size
            image_format = img.format
            color_mode = img.mode
            size_kb = image_path.stat().st_size / 1024
    except Exception as exc:
        return {'path': image_path, 'error': str(exc)}

    return {
        'path': image_path,
        'width': width,
        'height': height,
        'format': image_format,
        'mode': color_mode,
        'size_kb': size_kb,
    }
|
|
|
|
def get_caption_info(caption_path: Path) -> dict:
    """Read one caption file and summarise its text.

    The file is decoded as UTF-8. A leading byte-order mark (often written by
    Windows editors) is discarded, so it is neither counted as a character nor
    glued onto the first word. Surrounding whitespace is stripped.

    Returns:
        On success, a dict with ``path``, ``length`` (characters),
        ``word_count`` (whitespace-separated tokens) and ``content`` (the text,
        cut to its first 200 characters plus ``'...'`` when longer).
        If the file cannot be read or decoded, a dict with only ``path`` and
        ``error`` (the exception message).
    """
    try:
        # 'utf-8-sig' decodes plain UTF-8 identically and also drops a leading BOM.
        content = caption_path.read_text(encoding='utf-8-sig').strip()
    except (OSError, UnicodeError) as e:
        return {
            'path': caption_path,
            'error': str(e)
        }

    words = content.split()
    return {
        'path': caption_path,
        'length': len(content),
        'word_count': len(words),
        'content': content[:200] + '...' if len(content) > 200 else content
    }
|
|
|
|
def create_histogram(values: list, bins: int = 10, width: int = 40) -> str:
    """Render *values* as a plain-text horizontal histogram.

    The range ``[min(values), max(values)]`` is split into *bins* equal-width
    buckets. Each output line shows the bucket bounds, a bar of ``'█'``
    characters scaled so the fullest bucket is *width* long, and the bucket
    count. The maximum value is placed in the last bucket.

    Args:
        values: Numeric samples to bucket.
        bins: Number of buckets; must be at least 1.
        width: Bar length, in characters, of the most populated bucket.

    Returns:
        The histogram as newline-separated lines, ``"No data"`` for empty
        input, or ``"All values: <v>"`` when every value is identical.

    Raises:
        ValueError: If *bins* is less than 1 (and the input is non-degenerate).
    """
    if not values:
        return "No data"

    min_val = min(values)
    max_val = max(values)
    if min_val == max_val:
        return f"All values: {min_val}"

    if bins < 1:
        raise ValueError(f"bins must be >= 1, got {bins}")

    bin_size = (max_val - min_val) / bins
    counts = [0] * bins
    for v in values:
        # max_val maps exactly to index `bins`; clamp it into the last bucket.
        counts[min(int((v - min_val) / bin_size), bins - 1)] += 1

    max_count = max(counts)
    lines = []
    for i, count in enumerate(counts):
        bin_start = min_val + i * bin_size
        bin_end = bin_start + bin_size
        bar = '█' * int(count / max_count * width)
        lines.append(f"{bin_start:6.0f}-{bin_end:6.0f} | {bar} ({count})")

    return '\n'.join(lines)
|
|
|
|
def print_stats(images: list, captions: list, image_infos: list, caption_infos: list, directory: Path):
    """Print a complete statistics report for a training dataset.

    Sections printed: file counts and image/caption pairing, image dimensions,
    resolution/format/file-size breakdowns, caption length statistics, and
    recommendations. Images and captions are paired by filename stem
    (``foo.png`` <-> ``foo.txt``).

    File names, the directory path and caption text are passed through
    ``rich.markup.escape`` before being embedded in markup. Bracketed text such
    as ``[ohwx]`` is therefore shown literally instead of being parsed as a
    style tag.

    Args:
        images: Image file paths found in the dataset directory.
        captions: Caption (``.txt``) file paths found in the dataset directory.
        image_infos: Per-image dicts as produced by ``get_image_info``. Entries
            containing an ``'error'`` key are reported as unreadable and are
            excluded from the dimension statistics.
        caption_infos: Per-caption dicts as produced by ``get_caption_info``.
            Entries containing an ``'error'`` key are excluded from the caption
            statistics.
        directory: Dataset directory, shown in the report header.
    """
    console.print(Panel.fit(
        f"[bold blue]Dataset Statistics[/bold blue]\n[dim]{escape(str(directory))}[/dim]",
        border_style="blue",
    ))

    caption_stems = {cap.stem for cap in captions}
    # Count per image file rather than per stem, so foo.jpg and foo.png are
    # both reported when foo.txt is absent.
    missing = sorted(img for img in images if img.stem not in caption_stems)
    unreadable = [info for info in image_infos if 'error' in info]
    valid_infos = [info for info in image_infos if 'error' not in info]
    valid_captions = [cap for cap in caption_infos if 'error' not in cap]

    _print_file_counts(images, captions, missing, unreadable)
    if valid_infos:
        _print_image_stats(valid_infos)
    if valid_captions:
        _print_caption_stats(valid_captions)
    _print_recommendations(missing, unreadable, valid_infos, valid_captions)


def _print_file_counts(images: list, captions: list, missing: list, unreadable: list) -> None:
    """Print the file-count table, then list captionless and unreadable images."""
    console.print("\n[bold cyan]═══ File Counts ═══[/bold cyan]")

    image_stems = {img.stem for img in images}
    caption_stems = {cap.stem for cap in captions}

    counts_table = Table(show_header=False)
    counts_table.add_column("Metric", style="cyan")
    counts_table.add_column("Value", style="green")
    counts_table.add_row("Total Images", str(len(images)))
    counts_table.add_row("Total Caption Files", str(len(captions)))
    counts_table.add_row("Matched Image-Caption Pairs", str(len(images) - len(missing)))
    counts_table.add_row("Images Missing Captions", str(len(missing)))
    counts_table.add_row("Orphan Caption Files", str(len(caption_stems - image_stems)))
    counts_table.add_row("Unreadable Images", str(len(unreadable)))
    console.print(counts_table)

    list_limit = 15
    if missing:
        console.print("\n[bold yellow]Images Missing Captions:[/bold yellow]")
        for img in missing[:list_limit]:
            console.print(f"  • {escape(img.name)}")
        if len(missing) > list_limit:
            console.print(f"  ... and {len(missing) - list_limit} more")

    if unreadable:
        console.print("\n[bold red]Unreadable Images:[/bold red]")
        for info in unreadable[:list_limit]:
            console.print(f"  • {escape(info['path'].name)}: {escape(info['error'])}")
        if len(unreadable) > list_limit:
            console.print(f"  ... and {len(unreadable) - list_limit} more")


def _print_image_stats(valid_infos: list) -> None:
    """Print dimension, resolution, format and file-size statistics (non-empty input)."""
    total = len(valid_infos)
    widths = [i['width'] for i in valid_infos]
    heights = [i['height'] for i in valid_infos]
    sizes = [i['size_kb'] for i in valid_infos]

    console.print("\n[bold cyan]═══ Image Dimensions ═══[/bold cyan]")
    dim_table = Table()
    dim_table.add_column("Metric", style="cyan")
    dim_table.add_column("Width", style="green")
    dim_table.add_column("Height", style="green")
    dim_table.add_row("Minimum", str(min(widths)), str(min(heights)))
    dim_table.add_row("Maximum", str(max(widths)), str(max(heights)))
    dim_table.add_row("Average", f"{sum(widths) / total:.0f}", f"{sum(heights) / total:.0f}")
    console.print(dim_table)

    console.print("\n[bold]Resolution Distribution:[/bold]")
    resolutions = Counter(f"{i['width']}x{i['height']}" for i in valid_infos)
    res_table = Table()
    res_table.add_column("Resolution", style="cyan")
    res_table.add_column("Count", style="green")
    res_table.add_column("Percentage", style="yellow")
    res_table.add_column("", style="dim")
    for res, count in resolutions.most_common(10):
        pct = count / total * 100
        # One block per 2 percentage points, so a 100% share is 50 characters.
        res_table.add_row(res, str(count), f"{pct:.1f}%", '█' * int(pct / 2))
    if len(resolutions) > 10:
        res_table.add_row("...", f"+{len(resolutions) - 10} more", "", "")
    console.print(res_table)

    console.print("\n[bold]File Format Breakdown:[/bold]")
    formats = Counter(i['format'] for i in valid_infos)
    fmt_table = Table()
    fmt_table.add_column("Format", style="cyan")
    fmt_table.add_column("Count", style="green")
    fmt_table.add_column("Percentage", style="yellow")
    for fmt, count in formats.most_common():
        fmt_table.add_row(fmt or "Unknown", str(count), f"{count / total * 100:.1f}%")
    console.print(fmt_table)

    console.print("\n[bold]File Size Statistics:[/bold]")
    size_table = Table(show_header=False)
    size_table.add_column("Metric", style="cyan")
    size_table.add_column("Value", style="green")
    size_table.add_row("Minimum Size", f"{min(sizes):.1f} KB")
    size_table.add_row("Maximum Size", f"{max(sizes):.1f} KB")
    size_table.add_row("Average Size", f"{sum(sizes) / total:.1f} KB")
    size_table.add_row("Total Size", f"{sum(sizes) / 1024:.1f} MB")
    console.print(size_table)


def _print_caption_stats(valid_captions: list) -> None:
    """Print caption length statistics, a length histogram and sample captions (non-empty input)."""
    console.print("\n[bold cyan]═══ Caption Statistics ═══[/bold cyan]")

    lengths = [c['length'] for c in valid_captions]
    word_counts = [c['word_count'] for c in valid_captions]

    cap_table = Table()
    cap_table.add_column("Metric", style="cyan")
    cap_table.add_column("Characters", style="green")
    cap_table.add_column("Words", style="green")
    cap_table.add_row("Minimum", str(min(lengths)), str(min(word_counts)))
    cap_table.add_row("Maximum", str(max(lengths)), str(max(word_counts)))
    cap_table.add_row(
        "Average",
        f"{sum(lengths) / len(lengths):.0f}",
        f"{sum(word_counts) / len(word_counts):.1f}",
    )
    console.print(cap_table)

    console.print("\n[bold]Caption Length Distribution (characters):[/bold]")
    console.print(create_histogram(lengths, bins=8, width=30))

    console.print("\n[bold]Sample Captions:[/bold]")
    for cap in valid_captions[:3]:
        # Caption text routinely contains brackets (e.g. trigger tokens), so escape it.
        console.print(f"\n  [dim]{escape(cap['path'].stem)}:[/dim]")
        console.print(f"  {escape(cap['content'])}")


def _print_recommendations(missing: list, unreadable: list, valid_infos: list, valid_captions: list) -> None:
    """Print actionable warnings about the dataset, or a success message if none apply."""
    console.print("\n[bold cyan]═══ Recommendations ═══[/bold cyan]")

    recommendations = []
    if missing:
        recommendations.append(f"⚠ Add captions for {len(missing)} image(s)")
    if unreadable:
        recommendations.append(f"⚠ Fix or remove {len(unreadable)} unreadable image(s)")

    small = sum(1 for i in valid_infos if i['width'] < 512 or i['height'] < 512)
    if small:
        recommendations.append(f"⚠ {small} image(s) are smaller than 512px")

    large = sum(1 for i in valid_infos if i['width'] > 2048 or i['height'] > 2048)
    if large:
        recommendations.append(f"ℹ {large} image(s) are larger than 2048px (will be resized)")

    short = sum(1 for c in valid_captions if c['word_count'] < 5)
    if short:
        recommendations.append(f"⚠ {short} caption(s) are very short (<5 words)")

    if not recommendations:
        recommendations.append("✓ Dataset looks good!")

    for rec in recommendations:
        console.print(f"  {rec}")
|
|
|
|
def _analyze_all(paths: list, noun: str, analyze_one) -> list:
    """Apply *analyze_one* to every path while showing a spinner; results keep input order."""
    results = []
    if not paths:
        return results
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console,
    ) as progress:
        task = progress.add_task(f"Analyzing {len(paths)} {noun}...", total=len(paths))
        for path in paths:
            results.append(analyze_one(path))
            progress.advance(task)
    return results


def _write_summary(output_path: str, dataset_dir: Path, images: list, captions: list,
                   image_infos: list, caption_infos: list) -> None:
    """Write a plain-text (UTF-8) summary of the dataset counts to *output_path*.

    Raises:
        OSError: If the file cannot be opened or written.
    """
    image_stems = {img.stem for img in images}
    caption_stems = {cap.stem for cap in captions}
    lines = [
        f"Dataset Statistics: {dataset_dir}",
        "=" * 50,
        "",
        f"Images: {len(images)}",
        f"Captions: {len(captions)}",
        f"Images Missing Captions: {sum(1 for img in images if img.stem not in caption_stems)}",
        f"Orphan Caption Files: {len(caption_stems - image_stems)}",
        f"Unreadable Images: {sum(1 for info in image_infos if 'error' in info)}",
        f"Unreadable Captions: {sum(1 for info in caption_infos if 'error' in info)}",
    ]
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


def main():
    """Command-line entry point: analyse a dataset directory and print a report.

    Exits with status 1 when the path is missing, is not a directory, or the
    ``--output`` file cannot be written. Exits with status 0, without a report,
    when the directory holds no images or caption files.
    """
    parser = argparse.ArgumentParser(description="Generate dataset statistics for Flux LoRA training")
    parser.add_argument(
        "directory",
        nargs="?",
        default="/workspace/flux-project/datasets/identity/images",
        help="Directory containing training images and captions"
    )
    parser.add_argument(
        "-o", "--output",
        help="Save a plain-text summary of the dataset counts to this file"
    )
    args = parser.parse_args()

    dataset_dir = Path(args.directory)
    # is_dir() rather than exists(): a regular file would make iterdir() crash.
    if not dataset_dir.is_dir():
        reason = "Not a directory" if dataset_dir.exists() else "Directory not found"
        console.print(f"[red]Error: {reason}: {escape(str(dataset_dir))}[/red]")
        sys.exit(1)

    console.print(f"[bold]Scanning: {escape(str(dataset_dir))}[/bold]\n")

    images, captions = scan_files(dataset_dir)
    if not images and not captions:
        console.print("[yellow]No images or captions found in directory.[/yellow]")
        sys.exit(0)

    image_infos = _analyze_all(images, "images", get_image_info)
    caption_infos = _analyze_all(captions, "captions", get_caption_info)

    print_stats(images, captions, image_infos, caption_infos, dataset_dir)

    if args.output:
        try:
            _write_summary(args.output, dataset_dir, images, captions, image_infos, caption_infos)
        except OSError as err:
            console.print(f"\n[red]Error: could not write {escape(args.output)}: {escape(str(err))}[/red]")
            sys.exit(1)
        console.print(f"\n[dim]Statistics saved to: {escape(args.output)}[/dim]")


if __name__ == "__main__":
    main()
|
|