# 3d_model/ylff/utils/dataset_validation.py
# Author: Azan
# Last commit: 7a87926 — "Clean deployment build (Squashed)"
"""
Dataset validation and quality checking utilities.
Features:
- Data integrity checks
- Quality metrics computation
- Statistical analysis
- Outlier detection
- Dataset health reporting
"""
import json
import logging
from pathlib import Path
from typing import Any, Dict, List
import numpy as np
import torch
logger = logging.getLogger(__name__)
class DatasetValidator:
    """
    Comprehensive dataset validation and quality checking.

    Walks a list of training-sample dicts, records per-sample issues, and
    aggregates dataset-level statistics. Each issue is a dict with keys
    ``sample_idx`` / ``field`` / ``severity`` ("error" or "warning") /
    ``message``. ``validation_passed`` in the report depends on errors only;
    warnings do not fail validation (but do mark a sample as invalid).
    """

    def __init__(self, strict: bool = False):
        """
        Initialize dataset validator.

        Args:
            strict: If True, validate_dataset raises ValueError when any
                error-severity issue is found.
        """
        self.strict = strict
        # Issue dicts accumulated by the most recent validate_dataset() call.
        self.issues: List[Dict[str, Any]] = []
        # Counters and aggregate statistics from the most recent run.
        self.stats: Dict[str, Any] = {}

    @staticmethod
    def _coerce_scalar(value: Any) -> Any:
        """Convert a number / 0-d tensor / size-1 array to float, or None.

        float() raises on multi-element arrays and tensors; returning None
        lets callers report malformed metadata as an issue instead of
        crashing the whole validation pass.
        """
        try:
            return float(value)
        except (TypeError, ValueError, RuntimeError):
            return None

    def validate_dataset(
        self,
        samples: List[Dict],
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> Dict[str, Any]:
        """
        Validate entire dataset.

        Args:
            samples: List of training samples
            check_images: Validate image data
            check_poses: Validate pose data
            check_metadata: Validate metadata

        Returns:
            Validation report dictionary (see _generate_report)

        Raises:
            ValueError: In strict mode, when any error-severity issue exists.
        """
        logger.info(f"Validating dataset with {len(samples)} samples...")
        self.issues = []
        self.stats = {
            "total_samples": len(samples),
            "valid_samples": 0,
            "invalid_samples": 0,
            "warnings": 0,
            "errors": 0,
        }
        # Validate each sample
        for i, sample in enumerate(samples):
            sample_issues = self._validate_sample(
                sample,
                idx=i,
                check_images=check_images,
                check_poses=check_poses,
                check_metadata=check_metadata,
            )
            if sample_issues:
                # NOTE: a sample with only warnings still counts as
                # "invalid" here; validation_passed depends on errors only.
                self.issues.extend(sample_issues)
                self.stats["invalid_samples"] += 1
                self.stats["errors"] += sum(
                    1 for issue in sample_issues if issue["severity"] == "error"
                )
                self.stats["warnings"] += sum(
                    1 for issue in sample_issues if issue["severity"] == "warning"
                )
            else:
                self.stats["valid_samples"] += 1
        # Compute overall statistics
        self._compute_statistics(samples)
        # Generate report
        report = self._generate_report()
        if self.strict and self.stats["errors"] > 0:
            raise ValueError(f"Dataset validation failed with {self.stats['errors']} errors")
        return report

    def _validate_sample(
        self,
        sample: Dict,
        idx: int,
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> List[Dict[str, Any]]:
        """Validate a single sample; return its list of issue dicts."""
        issues = []
        # Check required fields
        required_fields = ["images", "poses_target"]
        for field in required_fields:
            if field not in sample:
                issues.append(
                    {
                        "sample_idx": idx,
                        "field": field,
                        "severity": "error",
                        "message": f"Missing required field: {field}",
                    }
                )
        if issues:
            return issues  # Skip further checks if missing required fields
        # Validate images
        if check_images:
            img_issues = self._validate_images(sample["images"], idx)
            issues.extend(img_issues)
        # Validate poses
        if check_poses:
            pose_issues = self._validate_poses(sample.get("poses_target"), idx)
            issues.extend(pose_issues)
        # Validate metadata
        if check_metadata:
            meta_issues = self._validate_metadata(sample, idx)
            issues.extend(meta_issues)
        return issues

    def _validate_images(self, images: Any, idx: int) -> List[Dict[str, Any]]:
        """Validate image data.

        Accepts a list/tuple of file paths or (H, W, 3) numpy arrays, or an
        (N, 3, H, W) tensor. For lists, only the FIRST image is inspected
        (cheap spot check, not exhaustive).
        """
        issues = []
        if images is None:
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "images",
                    "severity": "error",
                    "message": "Images is None",
                }
            )
            return issues
        # Handle different image formats
        if isinstance(images, (list, tuple)):
            if len(images) == 0:
                issues.append(
                    {
                        "sample_idx": idx,
                        "field": "images",
                        "severity": "error",
                        "message": "Empty image list",
                    }
                )
                return issues
            # Check first image
            img = images[0]
            if isinstance(img, (str, Path)):
                # Path to image file
                img_path = Path(img)
                if not img_path.exists():
                    issues.append(
                        {
                            "sample_idx": idx,
                            "field": "images",
                            "severity": "error",
                            "message": f"Image file not found: {img_path}",
                        }
                    )
            elif isinstance(img, np.ndarray):
                # Numpy array: expect HWC layout with 3 channels
                if img.ndim != 3 or img.shape[2] != 3:
                    issues.append(
                        {
                            "sample_idx": idx,
                            "field": "images",
                            "severity": "error",
                            "message": f"Invalid image shape: {img.shape}, expected (H, W, 3)",
                        }
                    )
                if img.dtype != np.uint8:
                    issues.append(
                        {
                            "sample_idx": idx,
                            "field": "images",
                            "severity": "warning",
                            "message": f"Image dtype is {img.dtype}, expected uint8",
                        }
                    )
        elif isinstance(images, torch.Tensor):
            # Tensor format: expect NCHW layout with 3 channels
            if images.ndim != 4 or images.shape[1] != 3:
                issues.append(
                    {
                        "sample_idx": idx,
                        "field": "images",
                        "severity": "error",
                        "message": f"Invalid tensor shape: {images.shape}, expected (N, 3, H, W)",
                    }
                )
        else:
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "images",
                    "severity": "error",
                    "message": f"Unknown image type: {type(images)}",
                }
            )
        return issues

    def _validate_poses(self, poses: Any, idx: int) -> List[Dict[str, Any]]:
        """Validate pose data: shape (N, 3/4, 3/4), finite values, and (for
        4x4 poses) that each rotation block has determinant ~1."""
        issues = []
        if poses is None:
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "poses_target",
                    "severity": "error",
                    "message": "Poses is None",
                }
            )
            return issues
        # Convert to numpy if tensor
        if isinstance(poses, torch.Tensor):
            poses = poses.cpu().numpy()
        if not isinstance(poses, np.ndarray):
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "poses_target",
                    "severity": "error",
                    "message": f"Poses must be numpy array or tensor, got {type(poses)}",
                }
            )
            return issues
        # Check shape
        if poses.ndim != 3 or poses.shape[1] not in [3, 4] or poses.shape[2] not in [3, 4]:
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "poses_target",
                    "severity": "error",
                    "message": (
                        f"Invalid pose shape: {poses.shape}, expected (N, 3, 4) or (N, 4, 4)"
                    ),
                }
            )
            return issues
        # Check for NaN or Inf
        if np.any(np.isnan(poses)) or np.any(np.isinf(poses)):
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "poses_target",
                    "severity": "error",
                    "message": "Poses contain NaN or Inf values",
                }
            )
        # Check rotation matrix validity (if 4x4)
        if poses.shape[1] == 4 and poses.shape[2] == 4:
            for i, pose in enumerate(poses):
                rot = pose[:3, :3]
                det = np.linalg.det(rot)
                # A proper rotation has det == +1; loose tolerance for noise.
                if not np.isclose(det, 1.0, atol=1e-3):
                    issues.append(
                        {
                            "sample_idx": idx,
                            "field": "poses_target",
                            "severity": "warning",
                            "message": (
                                f"Pose {i} rotation matrix determinant is {det:.6f}, "
                                f"expected ~1.0"
                            ),
                        }
                    )
        return issues

    def _validate_metadata(self, sample: Dict, idx: int) -> List[Dict[str, Any]]:
        """Validate optional metadata fields (weight, error, sequence_id)."""
        issues = []
        # Check weight if present
        if "weight" in sample:
            weight = sample["weight"]
            if isinstance(weight, (torch.Tensor, np.ndarray)):
                # float() raises on multi-element arrays; coerce defensively
                # so malformed metadata is reported instead of crashing.
                weight = self._coerce_scalar(weight)
            if not isinstance(weight, (int, float)) or weight < 0:
                issues.append(
                    {
                        "sample_idx": idx,
                        "field": "weight",
                        "severity": "warning",
                        "message": (
                            f"Invalid weight value: "
                            f"{sample['weight'] if weight is None else weight}"
                        ),
                    }
                )
        # Check error if present
        if "error" in sample:
            error = sample["error"]
            if isinstance(error, (torch.Tensor, np.ndarray)):
                # Same defensive coercion as for weight above.
                error = self._coerce_scalar(error)
            if not isinstance(error, (int, float)) or error < 0:
                issues.append(
                    {
                        "sample_idx": idx,
                        "field": "error",
                        "severity": "warning",
                        "message": (
                            f"Invalid error value: "
                            f"{sample['error'] if error is None else error}"
                        ),
                    }
                )
        # Check sequence_id if present
        if "sequence_id" in sample and sample["sequence_id"] is None:
            issues.append(
                {
                    "sample_idx": idx,
                    "field": "sequence_id",
                    "severity": "warning",
                    "message": "sequence_id is None",
                }
            )
        return issues

    def _compute_statistics(self, samples: List[Dict]):
        """Compute dataset-level statistics into self.stats (in place)."""
        if not samples:
            return
        # Image statistics
        num_images = []
        image_shapes = []
        for sample in samples:
            images = sample.get("images")
            if images is not None:
                if isinstance(images, (list, tuple)):
                    num_images.append(len(images))
                    if images and isinstance(images[0], np.ndarray):
                        image_shapes.append(tuple(images[0].shape[:2]))
                elif isinstance(images, torch.Tensor):
                    num_images.append(images.shape[0])
                    # tuple() so numpy and torch shapes hash/compare uniformly
                    # in the set() below.
                    image_shapes.append(tuple(images.shape[2:]))
        # Pose statistics: collect only values representable as floats so the
        # np.mean/np.percentile calls below cannot blow up on bad metadata.
        pose_errors = []
        weights = []
        for sample in samples:
            if "error" in sample:
                error = sample["error"]
                if isinstance(error, (torch.Tensor, np.ndarray)):
                    error = self._coerce_scalar(error)
                if isinstance(error, (int, float)):
                    pose_errors.append(error)
            if "weight" in sample:
                weight = sample["weight"]
                if isinstance(weight, (torch.Tensor, np.ndarray)):
                    weight = self._coerce_scalar(weight)
                if isinstance(weight, (int, float)):
                    weights.append(weight)
        self.stats.update(
            {
                "num_images": {
                    "mean": float(np.mean(num_images)) if num_images else 0,
                    "min": int(np.min(num_images)) if num_images else 0,
                    "max": int(np.max(num_images)) if num_images else 0,
                    "std": float(np.std(num_images)) if num_images else 0,
                },
                "image_shapes": list(set(image_shapes)) if image_shapes else [],
                "pose_errors": (
                    {
                        "mean": float(np.mean(pose_errors)) if pose_errors else 0,
                        "median": float(np.median(pose_errors)) if pose_errors else 0,
                        "min": float(np.min(pose_errors)) if pose_errors else 0,
                        "max": float(np.max(pose_errors)) if pose_errors else 0,
                        "std": float(np.std(pose_errors)) if pose_errors else 0,
                        "q25": float(np.percentile(pose_errors, 25)) if pose_errors else 0,
                        "q75": float(np.percentile(pose_errors, 75)) if pose_errors else 0,
                    }
                    if pose_errors
                    else {}
                ),
                "weights": (
                    {
                        "mean": float(np.mean(weights)) if weights else 1.0,
                        "min": float(np.min(weights)) if weights else 1.0,
                        "max": float(np.max(weights)) if weights else 1.0,
                        "std": float(np.std(weights)) if weights else 1.0,
                    }
                    if weights
                    else {}
                ),
            }
        )

    def _generate_report(self) -> Dict[str, Any]:
        """Assemble the final report from self.stats and self.issues.

        ``validation_passed`` is True iff no error-severity issues exist;
        ``validity_rate`` guards against division by zero on empty datasets.
        """
        return {
            "validation_passed": self.stats["errors"] == 0,
            "statistics": self.stats,
            "issues": self.issues,
            "summary": {
                "total_samples": self.stats["total_samples"],
                "valid_samples": self.stats["valid_samples"],
                "invalid_samples": self.stats["invalid_samples"],
                "error_count": self.stats["errors"],
                "warning_count": self.stats["warnings"],
                "validity_rate": self.stats["valid_samples"] / max(self.stats["total_samples"], 1),
            },
        }
def validate_dataset_file(dataset_path: Path, strict: bool = False) -> Dict[str, Any]:
    """
    Validate a saved dataset file.

    Supported formats (chosen by file extension, case-insensitive):
      - ``.pkl`` / ``.pickle``: a pickled list of sample dicts
      - ``.json``: either ``{"samples": [...]}`` or a bare top-level list
      - ``.h5`` / ``.hdf5``: HDF5 file with an ``images`` dataset and
        optional ``poses`` / ``weights`` datasets

    Args:
        dataset_path: Path to dataset file (pickle, json, or hdf5)
        strict: If True, raise exception on validation failure

    Returns:
        Validation report

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: For an unsupported extension (or, with strict=True,
            when validation finds errors).
    """
    logger.info(f"Validating dataset file: {dataset_path}")
    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")
    # Load dataset based on extension
    suffix = dataset_path.suffix.lower()
    if suffix in (".pkl", ".pickle"):
        import pickle

        # SECURITY: pickle.load executes arbitrary code during
        # deserialization -- only load dataset files from trusted sources.
        with open(dataset_path, "rb") as f:
            samples = pickle.load(f)
    elif suffix == ".json":
        with open(dataset_path) as f:
            data = json.load(f)
        # Handle both {"samples": [...]} and a bare top-level list.
        # json.load may return a list, which has no .get method.
        samples = data.get("samples", data) if isinstance(data, dict) else data
    elif suffix in (".h5", ".hdf5"):
        import h5py

        with h5py.File(dataset_path, "r") as f:
            # Rebuild per-sample dicts from the columnar HDF5 layout.
            samples = []
            num_samples = f["images"].shape[0]
            for i in range(num_samples):
                sample = {"images": f["images"][i]}
                if "poses" in f:
                    sample["poses_target"] = f["poses"][i]
                if "weights" in f:
                    sample["weight"] = float(f["weights"][i])
                samples.append(sample)
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_path.suffix}")
    # Validate
    validator = DatasetValidator(strict=strict)
    return validator.validate_dataset(samples)
def check_dataset_integrity(
    dataset_dir: Path,
    check_files: bool = True,
    check_consistency: bool = True,
) -> Dict[str, Any]:
    """
    Check dataset directory integrity.

    Each immediate subdirectory of ``dataset_dir`` is treated as one sample
    holding image files (``*.jpg`` / ``*.png``) and a ``ba_poses.npy`` pose
    array whose first dimension should equal the image count.

    Args:
        dataset_dir: Directory containing training samples
        check_files: Check if all required files exist
        check_consistency: Check consistency between samples

    Returns:
        Integrity check report with ``integrity_passed``, ``statistics``,
        ``issues`` and ``summary`` keys. ``integrity_passed`` reflects
        missing files only; count mismatches are warnings.
    """
    logger.info(f"Checking dataset integrity: {dataset_dir}")
    issues: List[Dict[str, Any]] = []
    stats = {
        "total_samples": 0,
        "valid_samples": 0,
        "missing_files": 0,
        "inconsistent_samples": 0,
    }
    # Find all sample directories
    sample_dirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
    for sample_dir in sample_dirs:
        stats["total_samples"] += 1
        # Snapshot so we can tell below whether THIS sample added any issues.
        issues_before = len(issues)
        # Collect images up front: both the file check and the consistency
        # check need them. (Previously this list was only built inside the
        # check_files branch, so check_files=False raised NameError or
        # reused a stale list from an earlier sample.)
        image_files = list(sample_dir.glob("*.jpg")) + list(sample_dir.glob("*.png"))
        # Check required files
        if check_files:
            if not image_files:
                issues.append(
                    {
                        "sample": str(sample_dir),
                        "severity": "error",
                        "message": "No images found",
                    }
                )
                stats["missing_files"] += 1
                continue
            for req_file in ("ba_poses.npy",):
                if not (sample_dir / req_file).exists():
                    issues.append(
                        {
                            "sample": str(sample_dir),
                            "severity": "error",
                            "message": f"Missing required file: {req_file}",
                        }
                    )
                    stats["missing_files"] += 1
        # Check consistency
        if check_consistency:
            try:
                poses = np.load(sample_dir / "ba_poses.npy")
                num_poses = poses.shape[0]
                num_images = len(image_files)
                if num_poses != num_images:
                    issues.append(
                        {
                            "sample": str(sample_dir),
                            "severity": "warning",
                            "message": f"Pose count ({num_poses}) != image count ({num_images})",
                        }
                    )
                    stats["inconsistent_samples"] += 1
            except Exception as e:
                # np.load failures (missing/corrupt file) are reported as
                # issues rather than aborting the whole scan.
                issues.append(
                    {
                        "sample": str(sample_dir),
                        "severity": "error",
                        "message": f"Failed to check consistency: {e}",
                    }
                )
        # Sample is valid iff no issue was recorded for it in this iteration.
        if len(issues) == issues_before:
            stats["valid_samples"] += 1
    return {
        "integrity_passed": stats["missing_files"] == 0,
        "statistics": stats,
        "issues": issues,
        "summary": {
            "total_samples": stats["total_samples"],
            "valid_samples": stats["valid_samples"],
            "missing_files": stats["missing_files"],
            "inconsistent_samples": stats["inconsistent_samples"],
        },
    }