| |
| """Pre-process images for SPAI: normalize dimensions and quality assessment.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import csv |
| import json |
| from pathlib import Path |
| from typing import Any |
|
|
| from PIL import Image |
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Parse and return the command-line arguments for the pre-processing run."""
    parser = argparse.ArgumentParser(
        description="Pre-process images for SPAI model input"
    )
    parser.add_argument(
        "--input-dir",
        type=Path,
        required=True,
        help="Input directory with images (organized by category/subcategory)",
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        required=True,
        help="Output directory for normalized images",
    )
    parser.add_argument(
        "--metadata-csv",
        type=Path,
        required=True,
        help="Input metadata CSV (from quality_filtered)",
    )
    parser.add_argument(
        "--target-size",
        type=int,
        default=224,
        help="Target dimension for square resize (SPAI default: 224)",
    )
    parser.add_argument(
        "--min-upscale-ratio",
        type=float,
        default=0.5,
        # BUG FIX: argparse %-formats help text (for %(default)s etc.), so a
        # bare "%" made `--help` raise ValueError. Escape it as "%%".
        help="Minimum ratio before upscaling (0.5 = don't upscale if < 50%% of target)",
    )
    parser.add_argument(
        "--filter-small",
        action="store_true",
        help="Exclude images smaller than target_size after analysis",
    )
    parser.add_argument(
        "--quality-loss-threshold",
        type=float,
        default=0.3,
        help="Warn if quality loss > N (0-1)",
    )
    return parser.parse_args()
|
|
|
|
def assess_upscale_quality(original_size: tuple[int, int], target_size: int) -> dict[str, Any]:
    """Assess how much quality will be lost by upscaling to ``target_size``.

    Args:
        original_size: (width, height) of the source image in pixels.
        target_size: Edge length of the intended square output, in pixels.

    Returns:
        Dict with the original size string ("WxH"), original megapixels,
        the upscale ratio of the smaller dimension, a quality score in
        (0, 1], the recommended PIL interpolation name, and a boolean
        quality-concern flag.
    """
    orig_w, orig_h = original_size
    orig_mp = (orig_w * orig_h) / 1e6

    # Ratio by which the smaller dimension must grow to reach target_size.
    # Guard zero-sized inputs instead of dividing by zero.
    smaller_dim = min(orig_w, orig_h)
    upscale_ratio = target_size / smaller_dim if smaller_dim > 0 else 1.0

    # BUG FIX: the raw inverse exceeded 1.0 for downscaled images
    # (upscale_ratio < 1), inflating "score out of 1.0" averages; cap at 1.0.
    quality_score = min(1.0, 1.0 / upscale_ratio)

    # Choose interpolation by how aggressive the upscale is: LANCZOS for mild
    # resizes, BICUBIC for moderate ones, BILINEAR past 3x (with concern
    # flagged above 2x).
    if upscale_ratio <= 1.5:
        method = "LANCZOS"
        quality_concern = False
    elif upscale_ratio <= 2.0:
        method = "BICUBIC"
        quality_concern = False
    elif upscale_ratio <= 3.0:
        method = "BICUBIC"
        quality_concern = True
    else:
        method = "BILINEAR"
        quality_concern = True

    return {
        "original_size": f"{orig_w}x{orig_h}",
        "original_mp": round(orig_mp, 3),
        "upscale_ratio": round(upscale_ratio, 2),
        "quality_score": round(quality_score, 2),
        "interpolation": method,
        "quality_concern": quality_concern,
    }
|
|
|
|
def normalize_image(
    input_path: Path,
    output_path: Path,
    target_size: int = 224,
) -> dict[str, Any]:
    """Normalize an image to a white-padded target_size x target_size JPEG.

    The image is shrunk to fit inside the square (aspect ratio preserved),
    centered on a white RGB canvas, and saved as JPEG quality 95.

    Args:
        input_path: Source image path.
        output_path: Destination path; parent directories are created.
        target_size: Edge length of the square output, in pixels.

    Returns:
        On success, the assess_upscale_quality() dict extended with
        "status": "success", "final_size", and "output_path". On failure,
        {"status": "error", "error": <exception class>, "message": <text>}.
    """
    try:
        with Image.open(input_path) as img:
            assessment = assess_upscale_quality(img.size, target_size)

            # Fit within the target square while preserving aspect ratio.
            # NOTE(review): thumbnail() only shrinks; images smaller than
            # target_size are padded, not upscaled, despite the assessment's
            # interpolation recommendation — confirm this is intended.
            img_resized = img.copy()
            img_resized.thumbnail((target_size, target_size), Image.Resampling.LANCZOS)

            # Center the resized image on a white square canvas.
            img_padded = Image.new("RGB", (target_size, target_size), (255, 255, 255))
            paste_x = (target_size - img_resized.width) // 2
            paste_y = (target_size - img_resized.height) // 2

            if img_resized.mode == "RGBA":
                # Use the alpha channel as a paste mask so transparent areas
                # composite against white instead of showing raw RGB values.
                img_padded.paste(img_resized, (paste_x, paste_y), img_resized)
            else:
                img_padded.paste(img_resized, (paste_x, paste_y))

            output_path.parent.mkdir(parents=True, exist_ok=True)
            img_padded.save(output_path, "JPEG", quality=95)

        assessment["status"] = "success"
        # BUG FIX: final_size was hard-coded "224x224" and reported the wrong
        # dimensions whenever --target-size was not 224.
        assessment["final_size"] = f"{target_size}x{target_size}"
        assessment["output_path"] = str(output_path)
        return assessment

    except Exception as e:
        # Best-effort batch processing: report the failure to the caller
        # instead of aborting the whole run.
        return {
            "status": "error",
            "error": str(type(e).__name__),
            "message": str(e),
        }
|
|
|
|
def process_images(
    input_dir: Path,
    output_dir: Path,
    metadata_csv: Path,
    target_size: int = 224,
    min_upscale_ratio: float = 0.5,
    filter_small: bool = False,
    quality_loss_threshold: float = 0.3,
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
    """Normalize every image listed in metadata_csv.

    Args:
        input_dir: Root directory holding the source images.
        output_dir: Directory for normalized output images.
        metadata_csv: CSV with a "quality_filtered_path" column per row.
        target_size: Edge length of the square output, in pixels.
        min_upscale_ratio: Images whose smaller dimension is below this
            fraction of target_size (i.e. upscale ratio > 1/min_upscale_ratio)
            are flagged, and dropped when filter_small is True.
        filter_small: When True, exclude images needing too much upscaling.
        quality_loss_threshold: Print a warning when the estimated quality
            loss (1 - quality_score) exceeds this value.

    Returns:
        (success_rows, failed_rows) — each a list of metadata dicts; success
        rows gain normalized_path/upscale_ratio/quality_score/interpolation/
        quality_concern keys, failed rows gain a "reason" key.
    """
    with metadata_csv.open("r", encoding="utf-8", newline="") as f:
        rows = list(csv.DictReader(f))

    success_rows: list[dict[str, Any]] = []
    failed_rows: list[dict[str, Any]] = []

    for row in rows:
        # Resolve the image path; relative paths are rooted at input_dir's parent.
        image_path = Path(row.get("quality_filtered_path", ""))
        if not image_path.is_absolute():
            image_path = input_dir.parent / image_path

        if not image_path.exists():
            # Fall back to looking the file up by bare name inside input_dir.
            alt_path = input_dir / Path(row.get("quality_filtered_path", "")).name
            if alt_path.exists():
                image_path = alt_path
            else:
                failed_rows.append({**row, "reason": f"File not found: {image_path}"})
                continue

        assessment = normalize_image(image_path, output_dir / image_path.name, target_size)

        if assessment["status"] != "success":
            failed_rows.append({**row, "reason": assessment.get("message", "Unknown error")})
            continue

        upscale_ratio = assessment.get("upscale_ratio", 1.0)
        quality_score = assessment.get("quality_score", 1.0)

        # min_upscale_ratio = 0.5 means: flag when the smaller dimension is
        # below 50% of target, i.e. the required upscale exceeds 2x.
        if upscale_ratio > 1.0 / min_upscale_ratio:
            if filter_small:
                failed_rows.append({
                    **row,
                    "reason": f"Too much upscaling required ({upscale_ratio}x) - filtered",
                    **assessment,
                })
                continue

        # BUG FIX: quality_loss_threshold was accepted but never used, despite
        # the CLI promising "Warn if quality loss > N". Emit the warning here.
        quality_loss = 1.0 - quality_score
        if quality_loss > quality_loss_threshold:
            print(
                f"WARNING: {image_path.name}: estimated quality loss "
                f"{quality_loss:.2f} exceeds threshold {quality_loss_threshold}"
            )

        row["normalized_path"] = assessment["output_path"]
        row["upscale_ratio"] = assessment["upscale_ratio"]
        row["quality_score"] = assessment["quality_score"]
        row["interpolation"] = assessment["interpolation"]
        row["quality_concern"] = assessment["quality_concern"]

        success_rows.append(row)

    return success_rows, failed_rows
|
|
|
|
def main() -> None:
    """CLI entry point: normalize all images, write updated metadata, report stats."""
    args = parse_args()

    print("Pre-processing images for SPAI model")
    print(f"Target size: {args.target_size}x{args.target_size}")
    print(f"Min upscale ratio: {args.min_upscale_ratio}")
    print(f"Filter small images: {args.filter_small}")
    print()

    kept, dropped = process_images(
        input_dir=args.input_dir,
        output_dir=args.output_dir,
        metadata_csv=args.metadata_csv,
        target_size=args.target_size,
        min_upscale_ratio=args.min_upscale_ratio,
        filter_small=args.filter_small,
        quality_loss_threshold=args.quality_loss_threshold,
    )

    # Updated metadata CSV lives next to the normalized images.
    out_csv = args.output_dir / "normalized_metadata.csv"
    out_csv.parent.mkdir(parents=True, exist_ok=True)

    if kept:
        # Union of every row's keys, sorted, so no column is silently dropped.
        columns: set[str] = set()
        for record in kept:
            columns.update(record.keys())
        with out_csv.open("w", encoding="utf-8", newline="") as handle:
            writer = csv.DictWriter(handle, fieldnames=sorted(columns))
            writer.writeheader()
            writer.writerows(kept)

    banner = "=" * 70
    print(banner)
    print("RESULTS")
    print(banner)
    print(f"✓ Processed: {len(kept)}")
    print(f"✗ Failed/Filtered: {len(dropped)}")

    if kept:
        ratios = [float(r["upscale_ratio"]) for r in kept if "upscale_ratio" in r]
        scores = [float(r.get("quality_score", 1)) for r in kept]
        concern_count = sum(1 for r in kept if r.get("quality_concern"))

        print("\nQuality Metrics:")
        print(f"  Avg upscale ratio: {sum(ratios)/len(ratios):.2f}x")
        print(f"  Avg quality score: {sum(scores)/len(scores):.2f}/1.0")
        print(f"  Images with concerns: {concern_count} ({100*concern_count/len(kept):.1f}%)")

    print(f"\nNormalized images: {args.output_dir}")
    print(f"Updated metadata: {out_csv}")

    if dropped:
        print("\nFailed/Filtered reasons:")
        # Tally identical reasons, then report most-frequent first.
        tally: dict[str, int] = {}
        for record in dropped:
            why = record.get("reason", "unknown")
            tally[why] = tally.get(why, 0) + 1
        for why, count in sorted(tally.items(), key=lambda item: item[1], reverse=True):
            print(f"  {why}: {count}")
|
|
|
|
# Run only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()
|
|