"""
Dataset validation and quality checking utilities.

Features:
- Per-sample integrity checks (required fields, images, poses, metadata)
- Summary statistics (image counts and sizes, pose errors, sample weights)
- Validation of saved dataset files (pickle, JSON, HDF5)
- Directory-level integrity checks for on-disk sample folders
- Dataset health reporting
"""

import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Union

import numpy as np
import torch

logger = logging.getLogger(__name__)

# Per-frame pose matrix shapes accepted by the validator (matches the error message below).
_VALID_POSE_SHAPES = ((3, 4), (4, 4))


def _issue(idx: int, field: Optional[str], severity: str, message: str, **extra: Any) -> Dict[str, Any]:
    """Build a per-sample issue record in the format used by DatasetValidator."""
    return {"sample_idx": idx, "field": field, "severity": severity, "message": message, **extra}


def _to_float(value: Any) -> Optional[float]:
    """
    Convert a scalar-like value to a Python float.

    Accepts Python ints/floats/bools, NumPy scalars, and single-element NumPy arrays
    or torch tensors with a boolean/integer/float dtype. Returns None for anything
    else (None, strings, multi-element arrays, complex values, ...).
    """
    if isinstance(value, torch.Tensor):
        if value.numel() != 1 or value.is_complex():
            return None
        return float(value.detach().cpu().item())
    if isinstance(value, (np.ndarray, np.generic)):
        arr = np.asarray(value)
        # Only bool/int/uint/float dtypes convert cleanly; strings, objects, complex do not.
        if arr.size != 1 or arr.dtype.kind not in "biuf":
            return None
        return float(arr.reshape(-1)[0])
    if isinstance(value, (int, float)):
        return float(value)
    return None


def _finite_values(samples: List[Any], key: str) -> List[float]:
    """Collect finite float values of ``key`` across samples, skipping unusable entries."""
    values: List[float] = []
    for sample in samples:
        if isinstance(sample, Mapping) and key in sample:
            value = _to_float(sample[key])
            if value is not None and np.isfinite(value):
                values.append(value)
    return values


class DatasetValidator:
    """
    Comprehensive dataset validation and quality checking.

    Each call to :meth:`validate_dataset` resets ``issues`` (a list of dicts with
    ``sample_idx``, ``field``, ``severity`` and ``message`` keys, plus ``image_idx``
    for per-image problems) and ``stats``.
    """

    def __init__(self, strict: bool = False):
        """
        Initialize dataset validator.

        Args:
            strict: If True, validate_dataset raises ValueError when any error is found
        """
        self.strict = strict
        self.issues: List[Dict[str, Any]] = []
        self.stats: Dict[str, Any] = {}

    def validate_dataset(
        self,
        samples: List[Dict],
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> Dict[str, Any]:
        """
        Validate entire dataset.

        Args:
            samples: List of training samples
            check_images: Validate image data
            check_poses: Validate pose data
            check_metadata: Validate metadata

        Returns:
            Validation report dictionary (see ``_generate_report``)

        Raises:
            ValueError: If ``strict`` is set and at least one error was found
        """
        logger.info("Validating dataset with %d samples...", len(samples))

        self.issues = []
        self.stats = {
            "total_samples": len(samples),
            "valid_samples": 0,
            "invalid_samples": 0,
            "warnings": 0,
            "errors": 0,
        }

        for i, sample in enumerate(samples):
            sample_issues = self._validate_sample(
                sample,
                idx=i,
                check_images=check_images,
                check_poses=check_poses,
                check_metadata=check_metadata,
            )
            if not sample_issues:
                self.stats["valid_samples"] += 1
                continue

            # A sample with any issue (warning or error) counts as invalid.
            self.issues.extend(sample_issues)
            self.stats["invalid_samples"] += 1
            for issue in sample_issues:
                if issue["severity"] == "error":
                    self.stats["errors"] += 1
                elif issue["severity"] == "warning":
                    self.stats["warnings"] += 1

        self._compute_statistics(samples)
        report = self._generate_report()

        if self.strict and self.stats["errors"] > 0:
            raise ValueError(f"Dataset validation failed with {self.stats['errors']} errors")

        return report

    def _validate_sample(
        self,
        sample: Dict,
        idx: int,
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> List[Dict[str, Any]]:
        """Validate a single sample and return its issues (empty list if clean)."""
        if not isinstance(sample, Mapping):
            return [_issue(idx, "sample", "error", f"Sample must be a dict, got {type(sample)}")]

        issues = [
            _issue(idx, field, "error", f"Missing required field: {field}")
            for field in ("images", "poses_target")
            if field not in sample
        ]
        if issues:
            return issues  # Skip further checks if missing required fields

        if check_images:
            issues.extend(self._validate_images(sample["images"], idx))
        if check_poses:
            issues.extend(self._validate_poses(sample["poses_target"], idx))
        if check_metadata:
            issues.extend(self._validate_metadata(sample, idx))
        return issues

    def _validate_images(self, images: Any, idx: int) -> List[Dict[str, Any]]:
        """
        Validate image data.

        Supported layouts: a list/tuple of file paths or (H, W, 3) arrays (every
        entry is checked), a (N, 3, H, W) tensor, or a stacked (N, H, W, 3) /
        (N, 3, H, W) NumPy array such as the one produced by the HDF5 loader.
        """
        field = "images"
        if images is None:
            return [_issue(idx, field, "error", "Images is None")]

        if isinstance(images, (list, tuple)):
            if len(images) == 0:
                return [_issue(idx, field, "error", "Empty image list")]
            issues: List[Dict[str, Any]] = []
            for image_idx, img in enumerate(images):
                issues.extend(self._validate_single_image(img, idx, image_idx))
            return issues

        if isinstance(images, torch.Tensor):
            if images.ndim != 4 or images.shape[1] != 3:
                return [
                    _issue(
                        idx,
                        field,
                        "error",
                        f"Invalid tensor shape: {images.shape}, expected (N, 3, H, W)",
                    )
                ]
            if images.shape[0] == 0:
                return [_issue(idx, field, "error", "Empty image tensor")]
            return []

        if isinstance(images, np.ndarray):
            if images.ndim != 4 or 3 not in (images.shape[1], images.shape[3]):
                return [
                    _issue(
                        idx,
                        field,
                        "error",
                        f"Invalid image array shape: {images.shape}, "
                        f"expected (N, H, W, 3) or (N, 3, H, W)",
                    )
                ]
            if images.shape[0] == 0:
                return [_issue(idx, field, "error", "Empty image array")]
            return []

        return [_issue(idx, field, "error", f"Unknown image type: {type(images)}")]

    @staticmethod
    def _validate_single_image(img: Any, idx: int, image_idx: int) -> List[Dict[str, Any]]:
        """Validate one entry of an image list: a file path or an (H, W, 3) array."""
        issues: List[Dict[str, Any]] = []
        if isinstance(img, (str, Path)):
            img_path = Path(img)
            if not img_path.exists():
                issues.append(
                    _issue(
                        idx,
                        "images",
                        "error",
                        f"Image file not found: {img_path}",
                        image_idx=image_idx,
                    )
                )
        elif isinstance(img, np.ndarray):
            if img.ndim != 3 or img.shape[2] != 3:
                issues.append(
                    _issue(
                        idx,
                        "images",
                        "error",
                        f"Invalid image shape: {img.shape}, expected (H, W, 3)",
                        image_idx=image_idx,
                    )
                )
            if img.dtype != np.uint8:
                issues.append(
                    _issue(
                        idx,
                        "images",
                        "warning",
                        f"Image dtype is {img.dtype}, expected uint8",
                        image_idx=image_idx,
                    )
                )
        # Other entry types (e.g. nested lists from JSON) are not inspected.
        return issues

    def _validate_poses(self, poses: Any, idx: int) -> List[Dict[str, Any]]:
        """
        Validate pose data: shape (N, 3, 4) or (N, 4, 4), finite values, and a
        top-left 3x3 block with determinant ~1.
        """
        field = "poses_target"
        if poses is None:
            return [_issue(idx, field, "error", "Poses is None")]

        if isinstance(poses, torch.Tensor):
            # detach(): .numpy() refuses grad-tracking tensors; float64: .numpy() refuses bfloat16.
            poses = poses.detach().cpu().to(torch.float64).numpy()
        elif isinstance(poses, (list, tuple)):
            # Nested lists, e.g. poses loaded from a JSON dataset file.
            try:
                poses = np.asarray(poses, dtype=np.float64)
            except (TypeError, ValueError):
                return [
                    _issue(idx, field, "error", "Poses list could not be converted to a numeric array")
                ]

        if not isinstance(poses, np.ndarray):
            return [
                _issue(
                    idx, field, "error", f"Poses must be numpy array or tensor, got {type(poses)}"
                )
            ]
        if poses.dtype.kind not in "biuf":
            return [_issue(idx, field, "error", f"Poses must have a numeric dtype, got {poses.dtype}")]

        if poses.ndim != 3 or poses.shape[1:] not in _VALID_POSE_SHAPES:
            return [
                _issue(
                    idx,
                    field,
                    "error",
                    f"Invalid pose shape: {poses.shape}, expected (N, 3, 4) or (N, 4, 4)",
                )
            ]

        if not np.all(np.isfinite(poses)):
            # Determinants of non-finite matrices are meaningless; stop here.
            return [_issue(idx, field, "error", "Poses contain NaN or Inf values")]

        # The top-left 3x3 block is the rotation in both the 3x4 and 4x4 layouts.
        dets = np.linalg.det(poses[:, :3, :3].astype(np.float64))
        return [
            _issue(
                idx,
                field,
                "warning",
                f"Pose {i} rotation matrix determinant is {dets[i]:.6f}, expected ~1.0",
            )
            for i in np.flatnonzero(~np.isclose(dets, 1.0, atol=1e-3))
        ]

    def _validate_metadata(self, sample: Dict, idx: int) -> List[Dict[str, Any]]:
        """Validate optional metadata fields (weight, error, sequence_id)."""
        issues: List[Dict[str, Any]] = []

        for field in ("weight", "error"):
            if field not in sample:
                continue
            raw = sample[field]
            value = _to_float(raw)
            # Reject NaN/Inf explicitly: ``nan < 0`` is False, so a sign check alone misses it.
            if value is None or not np.isfinite(value) or value < 0:
                shown = value if value is not None and isinstance(raw, (torch.Tensor, np.ndarray)) else raw
                issues.append(_issue(idx, field, "warning", f"Invalid {field} value: {shown}"))

        if "sequence_id" in sample and sample["sequence_id"] is None:
            issues.append(_issue(idx, "sequence_id", "warning", "sequence_id is None"))

        return issues

    def _compute_statistics(self, samples: List[Dict]):
        """Compute dataset statistics into ``self.stats`` (unusable values are skipped)."""
        if not samples:
            return

        num_images: List[int] = []
        image_shapes: List[tuple] = []
        for sample in samples:
            if not isinstance(sample, Mapping):
                continue
            images = sample.get("images")
            if isinstance(images, (list, tuple)):
                num_images.append(len(images))
                if images and isinstance(images[0], np.ndarray) and images[0].ndim >= 2:
                    image_shapes.append(tuple(int(d) for d in images[0].shape[:2]))
            elif isinstance(images, (torch.Tensor, np.ndarray)) and images.ndim >= 1:
                num_images.append(int(images.shape[0]))
                if images.ndim == 4:
                    if isinstance(images, np.ndarray) and images.shape[3] == 3:
                        height_width = images.shape[1:3]  # channels-last (N, H, W, 3)
                    else:
                        height_width = images.shape[2:4]  # channels-first (N, 3, H, W)
                    image_shapes.append(tuple(int(d) for d in height_width))

        # Only finite numeric values feed the statistics; invalid ones were already
        # reported by _validate_metadata and would otherwise crash or poison np.mean.
        pose_errors = _finite_values(samples, "error")
        weights = _finite_values(samples, "weight")

        self.stats["num_images"] = {
            "mean": float(np.mean(num_images)) if num_images else 0,
            "min": int(np.min(num_images)) if num_images else 0,
            "max": int(np.max(num_images)) if num_images else 0,
            "std": float(np.std(num_images)) if num_images else 0,
        }
        # Sorted so the report is deterministic.
        self.stats["image_shapes"] = sorted(set(image_shapes))
        self.stats["pose_errors"] = (
            {
                "mean": float(np.mean(pose_errors)),
                "median": float(np.median(pose_errors)),
                "min": float(np.min(pose_errors)),
                "max": float(np.max(pose_errors)),
                "std": float(np.std(pose_errors)),
                "q25": float(np.percentile(pose_errors, 25)),
                "q75": float(np.percentile(pose_errors, 75)),
            }
            if pose_errors
            else {}
        )
        self.stats["weights"] = (
            {
                "mean": float(np.mean(weights)),
                "min": float(np.min(weights)),
                "max": float(np.max(weights)),
                "std": float(np.std(weights)),
            }
            if weights
            else {}
        )

    def _generate_report(self) -> Dict[str, Any]:
        """Generate validation report; passes when no error-severity issues were found."""
        return {
            "validation_passed": self.stats["errors"] == 0,
            "statistics": self.stats,
            "issues": self.issues,
            "summary": {
                "total_samples": self.stats["total_samples"],
                "valid_samples": self.stats["valid_samples"],
                "invalid_samples": self.stats["invalid_samples"],
                "error_count": self.stats["errors"],
                "warning_count": self.stats["warnings"],
                "validity_rate": self.stats["valid_samples"] / max(self.stats["total_samples"], 1),
            },
        }


def validate_dataset_file(dataset_path: Union[str, Path], strict: bool = False) -> Dict[str, Any]:
    """
    Validate a saved dataset file.

    Args:
        dataset_path: Path to dataset file (.pkl/.pickle, .json, or .h5/.hdf5;
            the suffix is matched case-insensitively)
        strict: If True, raise exception on validation failure

    Returns:
        Validation report

    Raises:
        FileNotFoundError: If the file does not exist
        ValueError: On an unsupported suffix, an unrecognized JSON layout, or (strict)
            a failed validation
    """
    dataset_path = Path(dataset_path)
    logger.info("Validating dataset file: %s", dataset_path)

    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    suffix = dataset_path.suffix.lower()
    if suffix in (".pkl", ".pickle"):
        import pickle

        # SECURITY: pickle.load can execute arbitrary code embedded in the file.
        # Only validate dataset files from trusted sources.
        with open(dataset_path, "rb") as f:
            samples = pickle.load(f)
    elif suffix == ".json":
        with open(dataset_path, encoding="utf-8") as f:
            data = json.load(f)
        # Two layouts are supported: a bare list of samples, or {"samples": [...]}.
        if isinstance(data, list):
            samples = data
        elif isinstance(data, dict) and "samples" in data:
            samples = data["samples"]
        else:
            raise ValueError(
                "JSON dataset must be a list of samples or an object with a 'samples' key"
            )
    elif suffix in (".h5", ".hdf5"):
        import h5py

        with h5py.File(dataset_path, "r") as f:
            samples = []
            num_samples = f["images"].shape[0]
            for i in range(num_samples):
                sample = {"images": f["images"][i]}
                if "poses" in f:
                    sample["poses_target"] = f["poses"][i]
                if "weights" in f:
                    sample["weight"] = float(f["weights"][i])
                samples.append(sample)
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_path.suffix}")

    validator = DatasetValidator(strict=strict)
    return validator.validate_dataset(samples)


def check_dataset_integrity(
    dataset_dir: Union[str, Path],
    check_files: bool = True,
    check_consistency: bool = True,
    image_patterns: Sequence[str] = ("*.jpg", "*.png"),
) -> Dict[str, Any]:
    """
    Check dataset directory integrity.

    Every immediate subdirectory of ``dataset_dir`` is treated as one sample that
    should contain images and a ``ba_poses.npy`` pose array.

    Args:
        dataset_dir: Directory containing training samples
        check_files: Check if all required files exist
        check_consistency: Check that the pose count matches the image count
        image_patterns: Glob patterns used to find a sample's image files

    Returns:
        Integrity check report; ``integrity_passed`` is False only when files are missing
    """
    dataset_dir = Path(dataset_dir)
    logger.info("Checking dataset integrity: %s", dataset_dir)

    required_files = ["ba_poses.npy"]
    issues: List[Dict[str, Any]] = []
    stats = {
        "total_samples": 0,
        "valid_samples": 0,
        "missing_files": 0,
        "inconsistent_samples": 0,
    }

    # Sorted for deterministic output; iterdir() order is filesystem-dependent.
    sample_dirs = sorted(d for d in dataset_dir.iterdir() if d.is_dir())

    for sample_dir in sample_dirs:
        stats["total_samples"] += 1
        sample_key = str(sample_dir)
        sample_issues: List[Dict[str, Any]] = []
        # Collected up front so the consistency check also works with check_files=False.
        image_files = sorted({p for pattern in image_patterns for p in sample_dir.glob(pattern)})
        poses_path = sample_dir / "ba_poses.npy"
        poses_reported_missing = False

        if check_files:
            if not image_files:
                issues.append(
                    {"sample": sample_key, "severity": "error", "message": "No images found"}
                )
                stats["missing_files"] += 1
                continue

            for req_file in required_files:
                if not (sample_dir / req_file).exists():
                    sample_issues.append(
                        {
                            "sample": sample_key,
                            "severity": "error",
                            "message": f"Missing required file: {req_file}",
                        }
                    )
                    stats["missing_files"] += 1
            poses_reported_missing = not poses_path.exists()

        # Skip the consistency check when the pose file was already reported missing,
        # so the same problem is not reported twice.
        if check_consistency and not poses_reported_missing:
            try:
                num_poses = np.load(poses_path).shape[0]
            except Exception as e:  # best-effort scan: record the failure, keep going
                sample_issues.append(
                    {
                        "sample": sample_key,
                        "severity": "error",
                        "message": f"Failed to check consistency: {e}",
                    }
                )
            else:
                num_images = len(image_files)
                if num_poses != num_images:
                    sample_issues.append(
                        {
                            "sample": sample_key,
                            "severity": "warning",
                            "message": f"Pose count ({num_poses}) != image count ({num_images})",
                        }
                    )
                    stats["inconsistent_samples"] += 1

        issues.extend(sample_issues)
        if not sample_issues:
            stats["valid_samples"] += 1

    return {
        "integrity_passed": stats["missing_files"] == 0,
        "statistics": stats,
        "issues": issues,
        "summary": {
            "total_samples": stats["total_samples"],
            "valid_samples": stats["valid_samples"],
            "missing_files": stats["missing_files"],
            "inconsistent_samples": stats["inconsistent_samples"],
        },
    }