MYai / scripts /dataset_stats.py
montignyp's picture
Initial commit: Flux Identity LoRA Training Environment
1a3a976
#!/usr/bin/env python3
"""
Dataset Statistics Script for Flux Identity LoRA Training
Features:
- Count total images and caption files
- List missing captions
- Resolution distribution histogram
- Average/min/max dimensions
- File format breakdown
- Caption length statistics
"""
import argparse
import os
import sys
from collections import Counter, defaultdict
from pathlib import Path

from PIL import Image
from rich.console import Console
from rich.markup import escape
from rich.panel import Panel
from rich.progress import Progress, SpinnerColumn, TextColumn
from rich.table import Table
# Rich console used for all terminal output in this script.
console = Console()

# Lower-case file suffixes (including the leading dot) that scan_files()
# classifies as images.
IMAGE_EXTENSIONS = {
    '.bmp',
    '.jpeg',
    '.jpg',
    '.png',
    '.tif',
    '.tiff',
    '.webp',
}
def scan_files(directory: Path) -> tuple:
    """Split the files directly inside *directory* into images and captions.

    Only the top level is examined (no recursion); sub-directories and other
    non-regular entries are skipped. A file counts as an image when its
    lower-cased suffix is in ``IMAGE_EXTENSIONS``, otherwise as a caption when
    its lower-cased suffix is ``.txt``; anything else is ignored.

    Returns:
        ``(images, captions)``: two sorted lists of ``Path`` objects.
    """
    image_files = []
    caption_files = []
    for entry in directory.iterdir():
        if not entry.is_file():
            continue
        suffix = entry.suffix.lower()
        if suffix in IMAGE_EXTENSIONS:
            image_files.append(entry)
        elif suffix == '.txt':
            caption_files.append(entry)
    return sorted(image_files), sorted(caption_files)
def get_image_info(image_path: Path) -> dict:
    """Open *image_path* with Pillow and summarise it.

    The returned dict always has ``'path'``. On success it also carries the
    Pillow-reported ``'width'``, ``'height'``, ``'format'`` and ``'mode'``,
    plus ``'size_kb'`` (on-disk size in bytes divided by 1024). If opening the
    image or stat-ing the file raises, the dict instead carries ``'error'``
    with the exception text, so one bad file does not abort the whole scan.
    """
    try:
        with Image.open(image_path) as opened:
            info = {'path': image_path}
            info['width'] = opened.width
            info['height'] = opened.height
            info['format'] = opened.format
            info['mode'] = opened.mode
            info['size_kb'] = image_path.stat().st_size / 1024
            return info
    except Exception as exc:
        # Deliberately broad: any failure is recorded for this file only.
        return {'path': image_path, 'error': str(exc)}
def get_caption_info(caption_path: Path) -> dict:
    """Read a caption file and summarise its text.

    The file is decoded as UTF-8. A leading byte-order mark (written by some
    Windows editors) is dropped, so it is neither counted as a character nor
    glued onto the first word. Surrounding whitespace is stripped.

    Returns:
        A dict with ``'path'`` plus either:

        * ``'length'`` (characters), ``'word_count'`` (whitespace-separated
          tokens) and ``'content'`` (the text, cut to 200 characters with a
          trailing ``'...'`` when longer); or
        * ``'error'`` with the message, if the file could not be read or
          decoded.
    """
    try:
        # 'utf-8-sig' decodes plain UTF-8 identically but strips a leading BOM,
        # which str.strip() would otherwise leave in place.
        content = caption_path.read_text(encoding='utf-8-sig').strip()
    except (OSError, UnicodeError) as exc:
        return {'path': caption_path, 'error': str(exc)}
    preview = content if len(content) <= 200 else content[:200] + '...'
    return {
        'path': caption_path,
        'length': len(content),
        'word_count': len(content.split()),
        'content': preview,
    }
def create_histogram(values: list, bins: int = 10, width: int = 40) -> str:
    """Render *values* as a plain-text horizontal bar histogram.

    The range ``[min(values), max(values)]`` is split into *bins* equal-width
    buckets; the maximum value falls into the last bucket. Each output line
    reads ``"<start>-<end> | <bar> (<count>)"``. The fullest bucket's bar is
    *width* characters long and the others are scaled proportionally
    (truncated towards zero).

    Args:
        values: Numbers to bucket.
        bins: Number of buckets; must be at least 1.
        width: Bar length, in characters, of the fullest bucket.

    Returns:
        The newline-joined histogram lines, ``"No data"`` for empty input, or
        ``"All values: <v>"`` when every value is identical.

    Raises:
        ValueError: if *bins* is less than 1.
    """
    if not values:
        return "No data"
    if bins < 1:
        raise ValueError(f"bins must be >= 1, got {bins}")
    low = min(values)
    high = max(values)
    if low == high:
        return f"All values: {low}"
    bin_size = (high - low) / bins
    # Clamp so that the maximum value (index == bins) lands in the last bucket.
    counts = Counter(min(int((v - low) / bin_size), bins - 1) for v in values)
    max_count = max(counts.values())
    lines = []
    for i in range(bins):
        bin_start = low + i * bin_size
        bin_end = bin_start + bin_size
        count = counts[i]
        bar = '█' * int((count / max_count) * width)
        lines.append(f"{bin_start:6.0f}-{bin_end:6.0f} | {bar} ({count})")
    return '\n'.join(lines)
def print_stats(images: list, captions: list, image_infos: list, caption_infos: list, directory: Path):
    """Print a full statistics report for one dataset directory to ``console``.

    The report has these sections, in order:

    1. A header panel.
    2. File counts, plus up to 15 images with no same-stem caption.
    3. Any files whose analysis failed.
    4. Image dimensions, resolution distribution, file formats and file sizes.
    5. Caption length statistics, a histogram and three sample captions.
    6. A list of recommendations.

    Args:
        images: Image paths found in *directory*.
        captions: Caption (``.txt``) paths found in *directory*.
        image_infos: Results of ``get_image_info``. Entries carrying an
            ``'error'`` key are listed as unreadable and excluded from the
            numeric statistics.
        caption_infos: Results of ``get_caption_info``, with the same
            ``'error'`` convention.
        directory: The scanned directory, shown in the header.

    Paths, file names, error messages and caption text are passed through
    ``rich.markup.escape``. Square brackets in them are therefore printed
    literally instead of being interpreted as Rich markup tags.
    """
    console.print(Panel.fit(
        f"[bold blue]Dataset Statistics[/bold blue]\n[dim]{escape(str(directory))}[/dim]",
        border_style="blue"
    ))

    images_without_captions = _print_file_counts(images, captions)

    # Previously these were collected but silently dropped from the report.
    failed = [info for info in (*image_infos, *caption_infos) if 'error' in info]
    _print_unreadable(failed)

    valid_infos = [i for i in image_infos if 'error' not in i]
    if valid_infos:
        _print_image_stats(valid_infos)

    valid_captions = [c for c in caption_infos if 'error' not in c]
    if valid_captions:
        _print_caption_stats(valid_captions)

    _print_recommendations(images_without_captions, len(failed), valid_infos, valid_captions)


def _print_file_counts(images: list, captions: list) -> set:
    """Print the file-count table and up to 15 image stems lacking a caption.

    Images and captions are paired by file stem. Returns the set of image
    stems that have no caption with the same stem.
    """
    image_stems = {img.stem for img in images}
    caption_stems = {cap.stem for cap in captions}
    matched = image_stems & caption_stems
    images_without_captions = image_stems - caption_stems
    captions_without_images = caption_stems - image_stems

    console.print("\n[bold cyan]═══ File Counts ═══[/bold cyan]")
    counts_table = Table(show_header=False)
    counts_table.add_column("Metric", style="cyan")
    counts_table.add_column("Value", style="green")
    counts_table.add_row("Total Images", str(len(images)))
    counts_table.add_row("Total Caption Files", str(len(captions)))
    counts_table.add_row("Matched Image-Caption Pairs", str(len(matched)))
    counts_table.add_row("Images Missing Captions", str(len(images_without_captions)))
    counts_table.add_row("Orphan Caption Files", str(len(captions_without_images)))
    console.print(counts_table)

    if images_without_captions:
        console.print("\n[bold yellow]Images Missing Captions:[/bold yellow]")
        for stem in sorted(images_without_captions)[:15]:
            console.print(f" • {escape(stem)}")
        if len(images_without_captions) > 15:
            console.print(f" ... and {len(images_without_captions) - 15} more")
    return images_without_captions


def _print_unreadable(failed: list) -> None:
    """List up to 15 info dicts carrying an ``'error'`` key; print nothing if there are none."""
    if not failed:
        return
    console.print(f"\n[bold red]Unreadable Files ({len(failed)}):[/bold red]")
    for info in failed[:15]:
        console.print(f" • {escape(info['path'].name)}: {escape(info['error'])}")
    if len(failed) > 15:
        console.print(f" ... and {len(failed) - 15} more")


def _print_image_stats(valid_infos: list) -> None:
    """Print the dimension, resolution, format and file-size tables.

    *valid_infos* must be non-empty successful ``get_image_info`` results.
    """
    total = len(valid_infos)
    widths = [i['width'] for i in valid_infos]
    heights = [i['height'] for i in valid_infos]
    sizes = [i['size_kb'] for i in valid_infos]

    console.print("\n[bold cyan]═══ Image Dimensions ═══[/bold cyan]")
    dim_table = Table()
    dim_table.add_column("Metric", style="cyan")
    dim_table.add_column("Width", style="green")
    dim_table.add_column("Height", style="green")
    dim_table.add_row("Minimum", str(min(widths)), str(min(heights)))
    dim_table.add_row("Maximum", str(max(widths)), str(max(heights)))
    dim_table.add_row("Average", f"{sum(widths)/total:.0f}", f"{sum(heights)/total:.0f}")
    console.print(dim_table)

    # Top 10 exact WxH sizes; the bar gets one block per 2% of images.
    console.print("\n[bold]Resolution Distribution:[/bold]")
    resolutions = Counter(f"{i['width']}x{i['height']}" for i in valid_infos)
    res_table = Table()
    res_table.add_column("Resolution", style="cyan")
    res_table.add_column("Count", style="green")
    res_table.add_column("Percentage", style="yellow")
    res_table.add_column("", style="dim")
    for res, count in resolutions.most_common(10):
        pct = (count / total) * 100
        res_table.add_row(res, str(count), f"{pct:.1f}%", '█' * int(pct / 2))
    if len(resolutions) > 10:
        res_table.add_row("...", f"+{len(resolutions) - 10} more", "", "")
    console.print(res_table)

    console.print("\n[bold]File Format Breakdown:[/bold]")
    formats = Counter(i['format'] for i in valid_infos)
    fmt_table = Table()
    fmt_table.add_column("Format", style="cyan")
    fmt_table.add_column("Count", style="green")
    fmt_table.add_column("Percentage", style="yellow")
    for fmt, count in formats.most_common():
        pct = (count / total) * 100
        fmt_table.add_row(fmt or "Unknown", str(count), f"{pct:.1f}%")
    console.print(fmt_table)

    console.print("\n[bold]File Size Statistics:[/bold]")
    size_table = Table(show_header=False)
    size_table.add_column("Metric", style="cyan")
    size_table.add_column("Value", style="green")
    size_table.add_row("Minimum Size", f"{min(sizes):.1f} KB")
    size_table.add_row("Maximum Size", f"{max(sizes):.1f} KB")
    size_table.add_row("Average Size", f"{sum(sizes)/len(sizes):.1f} KB")
    size_table.add_row("Total Size", f"{sum(sizes)/1024:.1f} MB")
    console.print(size_table)


def _print_caption_stats(valid_captions: list) -> None:
    """Print caption length tables, a character-length histogram and three sample captions.

    *valid_captions* must be non-empty successful ``get_caption_info`` results.
    """
    lengths = [c['length'] for c in valid_captions]
    word_counts = [c['word_count'] for c in valid_captions]

    console.print("\n[bold cyan]═══ Caption Statistics ═══[/bold cyan]")
    cap_table = Table()
    cap_table.add_column("Metric", style="cyan")
    cap_table.add_column("Characters", style="green")
    cap_table.add_column("Words", style="green")
    cap_table.add_row("Minimum", str(min(lengths)), str(min(word_counts)))
    cap_table.add_row("Maximum", str(max(lengths)), str(max(word_counts)))
    cap_table.add_row("Average", f"{sum(lengths)/len(lengths):.0f}", f"{sum(word_counts)/len(word_counts):.1f}")
    console.print(cap_table)

    console.print("\n[bold]Caption Length Distribution (characters):[/bold]")
    console.print(create_histogram(lengths, bins=8, width=30))

    console.print("\n[bold]Sample Captions:[/bold]")
    for cap in valid_captions[:3]:
        # Caption text is free-form user input: escape it so brackets survive.
        console.print(f"\n [dim]{escape(cap['path'].stem)}:[/dim]")
        console.print(f" {escape(cap['content'])}")


def _print_recommendations(images_without_captions: set, unreadable_count: int,
                           valid_infos: list, valid_captions: list) -> None:
    """Print warnings derived from the statistics, or a single all-clear line."""
    console.print("\n[bold cyan]═══ Recommendations ═══[/bold cyan]")
    recommendations = []
    if images_without_captions:
        recommendations.append(f"⚠ Add captions for {len(images_without_captions)} image(s)")
    if unreadable_count:
        recommendations.append(f"⚠ Fix or remove {unreadable_count} unreadable file(s)")
    small = sum(1 for i in valid_infos if i['width'] < 512 or i['height'] < 512)
    if small:
        recommendations.append(f"⚠ {small} image(s) are smaller than 512px")
    large = sum(1 for i in valid_infos if i['width'] > 2048 or i['height'] > 2048)
    if large:
        recommendations.append(f"ℹ {large} image(s) are larger than 2048px (will be resized)")
    short = sum(1 for c in valid_captions if c['word_count'] < 5)
    if short:
        recommendations.append(f"⚠ {short} caption(s) are very short (<5 words)")
    if not recommendations:
        recommendations.append("✓ Dataset looks good!")
    for rec in recommendations:
        console.print(f" {rec}")
def _analyze_paths(paths: list, analyze, noun: str) -> list:
    """Apply *analyze* to each path behind a Rich spinner; return results in input order.

    Returns an empty list without showing a spinner when *paths* is empty.
    """
    if not paths:
        return []
    results = []
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        console=console
    ) as progress:
        task = progress.add_task(f"Analyzing {len(paths)} {noun}...", total=len(paths))
        for path in paths:
            results.append(analyze(path))
            progress.advance(task)
    return results


def _write_summary(output_path: str, dataset_dir: Path, images: list, captions: list,
                   image_infos: list, caption_infos: list) -> None:
    """Write a plain-text UTF-8 summary of the dataset to *output_path*.

    Raises:
        OSError: if the file cannot be opened or written.
    """
    image_stems = {img.stem for img in images}
    caption_stems = {cap.stem for cap in captions}
    missing = sorted(image_stems - caption_stems)
    lines = [
        f"Dataset Statistics: {dataset_dir}",
        "=" * 50,
        "",
        f"Images: {len(images)}",
        f"Captions: {len(captions)}",
        f"Matched Image-Caption Pairs: {len(image_stems & caption_stems)}",
        f"Images Missing Captions: {len(missing)}",
        f"Orphan Caption Files: {len(caption_stems - image_stems)}",
        f"Unreadable Images: {sum(1 for info in image_infos if 'error' in info)}",
        f"Unreadable Captions: {sum(1 for info in caption_infos if 'error' in info)}",
    ]
    if missing:
        lines.append("")
        lines.append("Images missing captions:")
        lines.extend(f"  {stem}" for stem in missing)
    # Explicit UTF-8: paths and stems may contain non-ASCII characters.
    with open(output_path, 'w', encoding='utf-8') as f:
        f.write('\n'.join(lines) + '\n')


def main():
    """Command-line entry point: analyse a dataset directory and print its statistics.

    Positional ``directory`` defaults to
    ``/workspace/flux-project/datasets/identity/images``. With ``-o/--output``
    a plain-text summary is also written to that file.

    Exit status:
        * 1 if the path does not exist, is not a directory, or the summary
          file cannot be written.
        * 0 early if the directory contains no images or captions.
    """
    parser = argparse.ArgumentParser(description="Generate dataset statistics for Flux LoRA training")
    parser.add_argument(
        "directory",
        nargs="?",
        default="/workspace/flux-project/datasets/identity/images",
        help="Directory containing training images and captions"
    )
    parser.add_argument(
        "-o", "--output",
        help="Save statistics to file"
    )
    args = parser.parse_args()
    dataset_dir = Path(args.directory)
    # Escape for Rich markup so brackets in the path are shown literally.
    shown_dir = escape(str(dataset_dir))
    if not dataset_dir.exists():
        console.print(f"[red]Error: Directory not found: {shown_dir}[/red]")
        sys.exit(1)
    if not dataset_dir.is_dir():
        # scan_files() calls iterdir(), which fails on a regular file.
        console.print(f"[red]Error: Not a directory: {shown_dir}[/red]")
        sys.exit(1)
    console.print(f"[bold]Scanning: {shown_dir}[/bold]\n")

    images, captions = scan_files(dataset_dir)
    if not images and not captions:
        console.print("[yellow]No images or captions found in directory.[/yellow]")
        sys.exit(0)

    image_infos = _analyze_paths(images, get_image_info, "images")
    caption_infos = _analyze_paths(captions, get_caption_info, "captions")

    print_stats(images, captions, image_infos, caption_infos, dataset_dir)

    if args.output:
        try:
            _write_summary(args.output, dataset_dir, images, captions, image_infos, caption_infos)
        except OSError as exc:
            console.print(
                f"[red]Error: Could not write statistics to {escape(args.output)}: "
                f"{escape(str(exc))}[/red]"
            )
            sys.exit(1)
        console.print(f"\n[dim]Statistics saved to: {escape(args.output)}[/dim]")


if __name__ == "__main__":
    main()