|
|
""" |
|
|
Dataset validation and quality checking utilities. |
|
|
|
|
|
Features: |
|
|
- Data integrity checks |
|
|
- Quality metrics computation |
|
|
- Statistical analysis |
|
|
- Outlier detection |
|
|
- Dataset health reporting |
|
|
""" |
|
|
|
|
|
import json |
|
|
import logging |
|
|
from pathlib import Path |
|
|
from typing import Any, Dict, List |
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class DatasetValidator:
    """
    Comprehensive dataset validation and quality checking.

    Validates per-sample image data, pose matrices and optional metadata,
    accumulates issues and statistics, and produces a summary report.
    Each issue record is a dict with ``sample_idx``, ``field``,
    ``severity`` ("error" / "warning") and ``message`` keys.
    """

    def __init__(self, strict: bool = False):
        """
        Initialize dataset validator.

        Args:
            strict: If True, fail validation on any critical issues
        """
        self.strict = strict
        # Issues collected during the most recent validate_dataset() call.
        self.issues: List[Dict[str, Any]] = []
        # Aggregate counters / statistics for the most recent run.
        self.stats: Dict[str, Any] = {}

    @staticmethod
    def _issue(idx: int, field: str, severity: str, message: str) -> Dict[str, Any]:
        """Build one issue record; keeps the record shape in a single place."""
        return {
            "sample_idx": idx,
            "field": field,
            "severity": severity,
            "message": message,
        }

    def validate_dataset(
        self,
        samples: List[Dict],
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> Dict[str, Any]:
        """
        Validate entire dataset.

        Args:
            samples: List of training samples
            check_images: Validate image data
            check_poses: Validate pose data
            check_metadata: Validate metadata

        Returns:
            Validation report dictionary

        Raises:
            ValueError: In strict mode, if any error-severity issue was found.
        """
        # Lazy %-style args avoid f-string formatting when INFO is disabled.
        logger.info("Validating dataset with %d samples...", len(samples))

        self.issues = []
        self.stats = {
            "total_samples": len(samples),
            "valid_samples": 0,
            "invalid_samples": 0,
            "warnings": 0,
            "errors": 0,
        }

        for i, sample in enumerate(samples):
            sample_issues = self._validate_sample(
                sample,
                idx=i,
                check_images=check_images,
                check_poses=check_poses,
                check_metadata=check_metadata,
            )

            if sample_issues:
                self.issues.extend(sample_issues)
                self.stats["invalid_samples"] += 1
                self.stats["errors"] += sum(
                    1 for issue in sample_issues if issue["severity"] == "error"
                )
                self.stats["warnings"] += sum(
                    1 for issue in sample_issues if issue["severity"] == "warning"
                )
            else:
                self.stats["valid_samples"] += 1

        self._compute_statistics(samples)

        report = self._generate_report()

        if self.strict and self.stats["errors"] > 0:
            raise ValueError(f"Dataset validation failed with {self.stats['errors']} errors")

        return report

    def _validate_sample(
        self,
        sample: Dict,
        idx: int,
        check_images: bool = True,
        check_poses: bool = True,
        check_metadata: bool = True,
    ) -> List[Dict[str, Any]]:
        """Validate a single sample; returns a (possibly empty) issue list."""
        issues: List[Dict[str, Any]] = []

        # Structural check first: without the required fields the per-field
        # validators below cannot run meaningfully.
        required_fields = ["images", "poses_target"]
        for field in required_fields:
            if field not in sample:
                issues.append(
                    self._issue(idx, field, "error", f"Missing required field: {field}")
                )

        if issues:
            return issues

        if check_images:
            issues.extend(self._validate_images(sample["images"], idx))

        if check_poses:
            issues.extend(self._validate_poses(sample.get("poses_target"), idx))

        if check_metadata:
            issues.extend(self._validate_metadata(sample, idx))

        return issues

    def _validate_images(self, images: Any, idx: int) -> List[Dict[str, Any]]:
        """Validate image data given as paths, numpy arrays, or a 4D tensor."""
        issues: List[Dict[str, Any]] = []

        if images is None:
            issues.append(self._issue(idx, "images", "error", "Images is None"))
            return issues

        if isinstance(images, (list, tuple)):
            if len(images) == 0:
                issues.append(self._issue(idx, "images", "error", "Empty image list"))
                return issues

            # Fix: check every entry — previously only images[0] was
            # inspected, so a corrupt later frame went unnoticed.
            for img in images:
                if isinstance(img, (str, Path)):
                    img_path = Path(img)
                    if not img_path.exists():
                        issues.append(
                            self._issue(
                                idx,
                                "images",
                                "error",
                                f"Image file not found: {img_path}",
                            )
                        )
                elif isinstance(img, np.ndarray):
                    # Expected in-memory layout is HWC uint8.
                    if img.ndim != 3 or img.shape[2] != 3:
                        issues.append(
                            self._issue(
                                idx,
                                "images",
                                "error",
                                f"Invalid image shape: {img.shape}, expected (H, W, 3)",
                            )
                        )
                    if img.dtype != np.uint8:
                        issues.append(
                            self._issue(
                                idx,
                                "images",
                                "warning",
                                f"Image dtype is {img.dtype}, expected uint8",
                            )
                        )
        elif isinstance(images, torch.Tensor):
            # Batched tensors are expected channel-first: (N, 3, H, W).
            if images.ndim != 4 or images.shape[1] != 3:
                issues.append(
                    self._issue(
                        idx,
                        "images",
                        "error",
                        f"Invalid tensor shape: {images.shape}, expected (N, 3, H, W)",
                    )
                )
        else:
            issues.append(
                self._issue(idx, "images", "error", f"Unknown image type: {type(images)}")
            )

        return issues

    def _validate_poses(self, poses: Any, idx: int) -> List[Dict[str, Any]]:
        """Validate pose matrices: shape, finiteness, rotation determinant."""
        issues: List[Dict[str, Any]] = []

        if poses is None:
            issues.append(self._issue(idx, "poses_target", "error", "Poses is None"))
            return issues

        if isinstance(poses, torch.Tensor):
            poses = poses.cpu().numpy()

        if not isinstance(poses, np.ndarray):
            issues.append(
                self._issue(
                    idx,
                    "poses_target",
                    "error",
                    f"Poses must be numpy array or tensor, got {type(poses)}",
                )
            )
            return issues

        # Fix: the last dimension must be exactly 4. The previous check
        # (shape[2] not in [3, 4]) also accepted (N, 3, 3) / (N, 4, 3),
        # contradicting the error message below.
        if poses.ndim != 3 or poses.shape[1] not in (3, 4) or poses.shape[2] != 4:
            issues.append(
                self._issue(
                    idx,
                    "poses_target",
                    "error",
                    (
                        f"Invalid pose shape: {poses.shape}, "
                        f"expected (N, 3, 4) or (N, 4, 4)"
                    ),
                )
            )
            return issues

        if not np.all(np.isfinite(poses)):
            issues.append(
                self._issue(idx, "poses_target", "error", "Poses contain NaN or Inf values")
            )
            # Determinants of non-finite matrices are meaningless; stop here
            # instead of emitting spurious rotation warnings on top.
            return issues

        # Rotation sanity check. The rotation block is pose[:3, :3] for both
        # (N, 3, 4) and (N, 4, 4) layouts (previously only 4x4 was checked).
        for i, pose in enumerate(poses):
            det = float(np.linalg.det(pose[:3, :3]))
            if not np.isclose(det, 1.0, atol=1e-3):
                issues.append(
                    self._issue(
                        idx,
                        "poses_target",
                        "warning",
                        (
                            f"Pose {i} rotation matrix determinant is {det:.6f}, "
                            f"expected ~1.0"
                        ),
                    )
                )

        return issues

    def _validate_metadata(self, sample: Dict, idx: int) -> List[Dict[str, Any]]:
        """Validate optional metadata fields: weight, error, sequence_id."""
        issues: List[Dict[str, Any]] = []

        # "weight" and "error" share one contract: a non-negative scalar.
        for field in ("weight", "error"):
            if field not in sample:
                continue
            value = sample[field]
            if isinstance(value, (torch.Tensor, np.ndarray)):
                value = float(value)
            if not isinstance(value, (int, float)) or value < 0:
                issues.append(
                    self._issue(idx, field, "warning", f"Invalid {field} value: {value}")
                )

        if "sequence_id" in sample and sample["sequence_id"] is None:
            issues.append(
                self._issue(idx, "sequence_id", "warning", "sequence_id is None")
            )

        return issues

    def _compute_statistics(self, samples: List[Dict]):
        """Compute dataset-level statistics and merge them into self.stats."""
        if not samples:
            return

        num_images: List[int] = []
        image_shapes: List[tuple] = []
        pose_errors: List[float] = []
        weights: List[float] = []

        # One pass over samples gathers everything the statistics need.
        for sample in samples:
            images = sample.get("images")
            if isinstance(images, (list, tuple)):
                num_images.append(len(images))
                if images and isinstance(images[0], np.ndarray):
                    image_shapes.append(tuple(images[0].shape[:2]))
            elif isinstance(images, torch.Tensor):
                num_images.append(int(images.shape[0]))
                # Normalize torch.Size to a plain tuple so the set below
                # de-duplicates across numpy- and torch-derived shapes.
                image_shapes.append(tuple(images.shape[2:]))

            if "error" in sample:
                error = sample["error"]
                if isinstance(error, (torch.Tensor, np.ndarray)):
                    error = float(error)
                pose_errors.append(error)
            if "weight" in sample:
                weight = sample["weight"]
                if isinstance(weight, (torch.Tensor, np.ndarray)):
                    weight = float(weight)
                weights.append(weight)

        self.stats.update(
            {
                "num_images": {
                    "mean": float(np.mean(num_images)) if num_images else 0,
                    "min": int(np.min(num_images)) if num_images else 0,
                    "max": int(np.max(num_images)) if num_images else 0,
                    "std": float(np.std(num_images)) if num_images else 0,
                },
                # sorted() gives a deterministic order (set iteration is not).
                "image_shapes": sorted(set(image_shapes)),
                # The per-entry "if ... else 0" guards of the old code were
                # redundant: both dicts are already gated on non-emptiness.
                "pose_errors": (
                    {
                        "mean": float(np.mean(pose_errors)),
                        "median": float(np.median(pose_errors)),
                        "min": float(np.min(pose_errors)),
                        "max": float(np.max(pose_errors)),
                        "std": float(np.std(pose_errors)),
                        "q25": float(np.percentile(pose_errors, 25)),
                        "q75": float(np.percentile(pose_errors, 75)),
                    }
                    if pose_errors
                    else {}
                ),
                "weights": (
                    {
                        "mean": float(np.mean(weights)),
                        "min": float(np.min(weights)),
                        "max": float(np.max(weights)),
                        "std": float(np.std(weights)),
                    }
                    if weights
                    else {}
                ),
            }
        )

    def _generate_report(self) -> Dict[str, Any]:
        """Generate the validation report from accumulated issues and stats."""
        return {
            "validation_passed": self.stats["errors"] == 0,
            "statistics": self.stats,
            "issues": self.issues,
            "summary": {
                "total_samples": self.stats["total_samples"],
                "valid_samples": self.stats["valid_samples"],
                "invalid_samples": self.stats["invalid_samples"],
                "error_count": self.stats["errors"],
                "warning_count": self.stats["warnings"],
                "validity_rate": self.stats["valid_samples"] / max(self.stats["total_samples"], 1),
            },
        }
|
|
|
|
|
|
|
|
def validate_dataset_file(dataset_path: Path, strict: bool = False) -> Dict[str, Any]:
    """
    Validate a saved dataset file.

    Args:
        dataset_path: Path to dataset file (pickle, json, or hdf5)
        strict: If True, raise exception on validation failure

    Returns:
        Validation report

    Raises:
        FileNotFoundError: If the dataset file does not exist.
        ValueError: If the file extension is unsupported, or an HDF5 file
            lacks the required "images" dataset.
    """
    logger.info("Validating dataset file: %s", dataset_path)

    if not dataset_path.exists():
        raise FileNotFoundError(f"Dataset file not found: {dataset_path}")

    # Case-insensitive suffix match so ".PKL" etc. are also accepted.
    suffix = dataset_path.suffix.lower()

    if suffix in (".pkl", ".pickle"):
        import pickle

        # SECURITY: pickle.load executes arbitrary code during
        # deserialization — only load dataset files from trusted sources.
        with open(dataset_path, "rb") as f:
            samples = pickle.load(f)
    elif suffix == ".json":
        with open(dataset_path) as f:
            data = json.load(f)
        # Fix: accept both a {"samples": [...]} wrapper object and a bare
        # top-level list (the old data.get(...) crashed on lists).
        samples = data.get("samples", data) if isinstance(data, dict) else data
    elif suffix in (".h5", ".hdf5"):
        import h5py

        with h5py.File(dataset_path, "r") as f:
            # Fail with a clear message instead of an opaque KeyError.
            if "images" not in f:
                raise ValueError(f"HDF5 dataset missing 'images': {dataset_path}")

            samples = []
            num_samples = f["images"].shape[0]
            for i in range(num_samples):
                sample = {"images": f["images"][i]}
                if "poses" in f:
                    sample["poses_target"] = f["poses"][i]
                if "weights" in f:
                    sample["weight"] = float(f["weights"][i])
                samples.append(sample)
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_path.suffix}")

    validator = DatasetValidator(strict=strict)
    return validator.validate_dataset(samples)
|
|
|
|
|
|
|
|
def check_dataset_integrity(
    dataset_dir: Path,
    check_files: bool = True,
    check_consistency: bool = True,
) -> Dict[str, Any]:
    """
    Check dataset directory integrity.

    Each immediate subdirectory of ``dataset_dir`` is treated as one sample;
    a sample is expected to contain image files (*.jpg / *.png) and a
    ``ba_poses.npy`` with one pose per image.

    Args:
        dataset_dir: Directory containing training samples
        check_files: Check if all required files exist
        check_consistency: Check consistency between samples

    Returns:
        Integrity check report
    """
    logger.info("Checking dataset integrity: %s", dataset_dir)

    issues: List[Dict[str, Any]] = []
    stats = {
        "total_samples": 0,
        "valid_samples": 0,
        "missing_files": 0,
        "inconsistent_samples": 0,
    }

    sample_dirs = [d for d in dataset_dir.iterdir() if d.is_dir()]
    required_files = ["ba_poses.npy"]

    for sample_dir in sample_dirs:
        stats["total_samples"] += 1
        sample_key = str(sample_dir)

        # Fix: collect image files unconditionally. They are also needed by
        # the consistency check below; previously image_files was only
        # defined inside the check_files branch, so calling with
        # check_files=False raised NameError.
        image_files = list(sample_dir.glob("*.jpg")) + list(sample_dir.glob("*.png"))

        if check_files:
            if not image_files:
                issues.append(
                    {
                        "sample": sample_key,
                        "severity": "error",
                        "message": "No images found",
                    }
                )
                stats["missing_files"] += 1
                continue

            for req_file in required_files:
                if not (sample_dir / req_file).exists():
                    issues.append(
                        {
                            "sample": sample_key,
                            "severity": "error",
                            "message": f"Missing required file: {req_file}",
                        }
                    )
                    stats["missing_files"] += 1

        if check_consistency:
            try:
                poses = np.load(sample_dir / "ba_poses.npy")
                num_poses = poses.shape[0]
                num_images = len(image_files)

                if num_poses != num_images:
                    issues.append(
                        {
                            "sample": sample_key,
                            "severity": "warning",
                            "message": f"Pose count ({num_poses}) != image count ({num_images})",
                        }
                    )
                    stats["inconsistent_samples"] += 1
            except Exception as e:
                # Best-effort scan: record the failure and keep going rather
                # than aborting the whole integrity check.
                issues.append(
                    {
                        "sample": sample_key,
                        "severity": "error",
                        "message": f"Failed to check consistency: {e}",
                    }
                )

        # A sample is valid iff no issue above referenced it (all([]) is
        # True, so the old "not issues or" guard was redundant).
        if all(issue["sample"] != sample_key for issue in issues):
            stats["valid_samples"] += 1

    return {
        # NOTE: mirrors the original contract — only missing files fail
        # integrity; consistency problems are reported but non-fatal.
        "integrity_passed": stats["missing_files"] == 0,
        "statistics": stats,
        "issues": issues,
        "summary": {
            "total_samples": stats["total_samples"],
            "valid_samples": stats["valid_samples"],
            "missing_files": stats["missing_files"],
            "inconsistent_samples": stats["inconsistent_samples"],
        },
    }
|
|
|