""" Dataset analysis and reporting utilities. Features: - Statistical analysis - Quality metrics - Visualization generation - Report generation """ import json import logging from pathlib import Path from typing import Any, Dict, List, Optional import numpy as np logger = logging.getLogger(__name__) class DatasetAnalyzer: """ Comprehensive dataset analysis and reporting. """ def __init__(self): """Initialize dataset analyzer.""" self.stats: Dict[str, Any] = {} def analyze_dataset( self, samples: List[Dict], compute_distributions: bool = True, compute_correlations: bool = True, ) -> Dict[str, Any]: """ Perform comprehensive dataset analysis. Args: samples: List of training samples compute_distributions: Compute error/weight distributions compute_correlations: Compute correlations between metrics Returns: Analysis report dictionary """ logger.info(f"Analyzing dataset with {len(samples)} samples...") if not samples: return {"error": "Empty dataset"} # Basic statistics self.stats = { "total_samples": len(samples), "sample_fields": list(samples[0].keys()) if samples else [], } # Extract metrics errors = [] weights = [] num_images = [] sequence_ids = [] for sample in samples: if "error" in sample: error = sample["error"] if isinstance(error, (np.ndarray, list)): error = float(error[0]) if len(error) > 0 else 0.0 errors.append(float(error)) if "weight" in sample: weight = sample["weight"] if isinstance(weight, (np.ndarray, list)): weight = float(weight[0]) if len(weight) > 0 else 1.0 weights.append(float(weight)) if "images" in sample: images = sample["images"] if isinstance(images, (list, tuple)): num_images.append(len(images)) elif isinstance(images, np.ndarray): num_images.append(images.shape[0]) if "sequence_id" in sample: sequence_ids.append(str(sample["sequence_id"])) # Error statistics if errors: self.stats["error_statistics"] = self._compute_statistics(errors, "error") if compute_distributions: self.stats["error_distribution"] = self._compute_distribution(errors, bins=50) # Weight statistics if weights: self.stats["weight_statistics"] = self._compute_statistics(weights, "weight") if compute_distributions: self.stats["weight_distribution"] = self._compute_distribution(weights, bins=50) # Image count statistics if num_images: self.stats["image_count_statistics"] = self._compute_statistics(num_images, "images") # Sequence statistics if sequence_ids: unique_sequences = set(sequence_ids) samples_per_sequence = {} for seq_id in unique_sequences: samples_per_sequence[seq_id] = sequence_ids.count(seq_id) self.stats["sequence_statistics"] = { "unique_sequences": len(unique_sequences), "samples_per_sequence": { "mean": float(np.mean(list(samples_per_sequence.values()))), "min": int(np.min(list(samples_per_sequence.values()))), "max": int(np.max(list(samples_per_sequence.values()))), "std": float(np.std(list(samples_per_sequence.values()))), }, } # Correlations if compute_correlations and errors and weights: correlation = np.corrcoef(errors, weights)[0, 1] self.stats["correlations"] = {"error_weight": float(correlation)} # Quality metrics self.stats["quality_metrics"] = self._compute_quality_metrics(samples, errors, weights) return self.stats def _compute_statistics(self, values: List[float], name: str) -> Dict[str, float]: """Compute statistical measures.""" arr = np.array(values) return { "mean": float(np.mean(arr)), "median": float(np.median(arr)), "std": float(np.std(arr)), "min": float(np.min(arr)), "max": float(np.max(arr)), "q25": float(np.percentile(arr, 25)), "q75": float(np.percentile(arr, 

    def _compute_statistics(self, values: List[float]) -> Dict[str, float]:
        """Compute summary statistics for a list of values."""
        arr = np.array(values)
        return {
            "mean": float(np.mean(arr)),
            "median": float(np.median(arr)),
            "std": float(np.std(arr)),
            "min": float(np.min(arr)),
            "max": float(np.max(arr)),
            "q25": float(np.percentile(arr, 25)),
            "q75": float(np.percentile(arr, 75)),
            "q90": float(np.percentile(arr, 90)),
            "q95": float(np.percentile(arr, 95)),
            "q99": float(np.percentile(arr, 99)),
        }

    def _compute_distribution(self, values: List[float], bins: int = 50) -> Dict[str, Any]:
        """Compute a histogram of the value distribution."""
        arr = np.array(values)
        hist, bin_edges = np.histogram(arr, bins=bins)
        return {
            "histogram": hist.tolist(),
            "bin_edges": bin_edges.tolist(),
            "bin_centers": ((bin_edges[:-1] + bin_edges[1:]) / 2).tolist(),
        }

    def _compute_quality_metrics(
        self,
        samples: List[Dict],
        errors: List[float],
        weights: List[float],
    ) -> Dict[str, Any]:
        """Compute dataset quality metrics."""
        metrics = {}

        if errors:
            # Error-based metrics (thresholds in degrees)
            metrics["low_error_ratio"] = sum(1 for e in errors if e < 2.0) / len(errors)
            metrics["medium_error_ratio"] = sum(1 for e in errors if 2.0 <= e < 30.0) / len(errors)
            metrics["high_error_ratio"] = sum(1 for e in errors if e >= 30.0) / len(errors)

        if weights:
            # Weight-based metrics
            metrics["weight_diversity"] = float(np.std(weights))
            metrics["uniform_weight_ratio"] = sum(
                1 for w in weights if abs(w - 1.0) < 0.1
            ) / len(weights)

        # Completeness: fraction of samples carrying each required field
        required_fields = ["images", "poses_target"]
        completeness = {}
        for field in required_fields:
            completeness[field] = sum(
                1 for s in samples if field in s and s[field] is not None
            ) / len(samples)
        metrics["completeness"] = completeness

        return metrics

    def generate_report(
        self,
        output_path: Optional[Path] = None,
        format: str = "json",
    ) -> str:
        """
        Generate a human-readable report.

        Args:
            output_path: Path to save report (optional)
            format: Report format ("json", "text", "markdown")

        Returns:
            Report string
        """
        if format == "json":
            report = json.dumps(self.stats, indent=2, default=str)
        elif format == "text":
            report = self._generate_text_report()
        elif format == "markdown":
            report = self._generate_markdown_report()
        else:
            raise ValueError(f"Unknown format: {format}")

        if output_path:
            output_path.parent.mkdir(parents=True, exist_ok=True)
            with open(output_path, "w") as f:
                f.write(report)
            logger.info(f"Report saved to: {output_path}")

        return report

    def _generate_text_report(self) -> str:
        """Generate plain-text report."""
        lines = []
        lines.append("=" * 80)
        lines.append("DATASET ANALYSIS REPORT")
        lines.append("=" * 80)
        lines.append("")
        lines.append(f"Total Samples: {self.stats.get('total_samples', 0)}")
        lines.append("")

        # Error statistics
        if "error_statistics" in self.stats:
            lines.append("Error Statistics:")
            err_stats = self.stats["error_statistics"]
            lines.append(f"  Mean: {err_stats['mean']:.4f}")
            lines.append(f"  Median: {err_stats['median']:.4f}")
            lines.append(f"  Std: {err_stats['std']:.4f}")
            lines.append(f"  Range: [{err_stats['min']:.4f}, {err_stats['max']:.4f}]")
            lines.append(f"  Q25: {err_stats['q25']:.4f}")
            lines.append(f"  Q75: {err_stats['q75']:.4f}")
            lines.append(f"  Q95: {err_stats['q95']:.4f}")
            lines.append("")

        # Quality metrics (labels match the thresholds in _compute_quality_metrics)
        if "quality_metrics" in self.stats:
            lines.append("Quality Metrics:")
            qm = self.stats["quality_metrics"]
            if "low_error_ratio" in qm:
                lines.append(f"  Low error (< 2°): {qm['low_error_ratio'] * 100:.1f}%")
                lines.append(f"  Medium error (2-30°): {qm['medium_error_ratio'] * 100:.1f}%")
                lines.append(f"  High error (≥ 30°): {qm['high_error_ratio'] * 100:.1f}%")
            lines.append("")

        # Sequence statistics
        if "sequence_statistics" in self.stats:
            lines.append("Sequence Statistics:")
            seq_stats = self.stats["sequence_statistics"]
            lines.append(f"  Unique sequences: {seq_stats['unique_sequences']}")
            sps = seq_stats["samples_per_sequence"]
            lines.append(f"  Samples per sequence: {sps['mean']:.1f} ± {sps['std']:.1f}")
            lines.append("")

        return "\n".join(lines)

    def _generate_markdown_report(self) -> str:
        """Generate markdown report."""
        lines = []
        lines.append("# Dataset Analysis Report")
        lines.append("")
        lines.append(f"**Total Samples:** {self.stats.get('total_samples', 0)}")
        lines.append("")

        # Error statistics table
        if "error_statistics" in self.stats:
            lines.append("## Error Statistics")
            lines.append("")
            lines.append("| Metric | Value |")
            lines.append("|--------|-------|")
            err_stats = self.stats["error_statistics"]
            for key, value in err_stats.items():
                lines.append(f"| {key} | {value:.4f} |")
            lines.append("")

        # Quality metrics
        if "quality_metrics" in self.stats:
            lines.append("## Quality Metrics")
            lines.append("")
            qm = self.stats["quality_metrics"]
            if "low_error_ratio" in qm:
                lines.append(f"- **Low error (< 2°):** {qm['low_error_ratio'] * 100:.1f}%")
                lines.append(f"- **Medium error (2-30°):** {qm['medium_error_ratio'] * 100:.1f}%")
                lines.append(f"- **High error (≥ 30°):** {qm['high_error_ratio'] * 100:.1f}%")
            lines.append("")

        return "\n".join(lines)
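

# Usage sketch for DatasetAnalyzer (illustrative only; the sample dicts simply
# use the optional field names the analyzer already looks for, and report.md is
# a made-up output path):
#
#   analyzer = DatasetAnalyzer()
#   stats = analyzer.analyze_dataset(samples)  # samples: List[Dict]
#   print(stats["quality_metrics"]["completeness"])
#   analyzer.generate_report(Path("report.md"), format="markdown")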


def analyze_dataset_file(
    dataset_path: Path,
    output_path: Optional[Path] = None,
    format: str = "json",
) -> Dict[str, Any]:
    """
    Analyze a saved dataset file.

    Args:
        dataset_path: Path to dataset file
        output_path: Path to save analysis report
        format: Report format

    Returns:
        Analysis results
    """
    logger.info(f"Analyzing dataset file: {dataset_path}")

    # Load dataset based on file extension
    if dataset_path.suffix in (".pkl", ".pickle"):
        import pickle

        with open(dataset_path, "rb") as f:
            samples = pickle.load(f)
    elif dataset_path.suffix == ".json":
        with open(dataset_path) as f:
            data = json.load(f)
        # Accept either {"samples": [...]} or a bare list of samples
        samples = data.get("samples", data) if isinstance(data, dict) else data
    elif dataset_path.suffix in (".h5", ".hdf5"):
        import h5py

        with h5py.File(dataset_path, "r") as f:
            samples = []
            num_samples = f["images"].shape[0]
            for i in range(num_samples):
                sample = {"images": f["images"][i]}
                if "poses" in f:
                    sample["poses_target"] = f["poses"][i]
                if "weights" in f:
                    sample["weight"] = float(f["weights"][i])
                samples.append(sample)
    else:
        raise ValueError(f"Unsupported dataset format: {dataset_path.suffix}")

    # Analyze
    analyzer = DatasetAnalyzer()
    results = analyzer.analyze_dataset(samples)

    # Generate report
    if output_path:
        analyzer.generate_report(output_path, format=format)

    return results
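

# Minimal smoke test (illustrative only): builds a handful of synthetic samples
# with the field names this module expects and prints a text report. The shapes
# and numeric ranges below are made up purely for the demo.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    rng = np.random.default_rng(0)
    demo_samples = [
        {
            "error": float(rng.uniform(0.0, 40.0)),
            "weight": float(rng.uniform(0.5, 1.5)),
            "images": np.zeros((2, 8, 8, 3)),  # placeholder image stack
            "poses_target": np.eye(4),  # placeholder pose
            "sequence_id": f"seq-{i % 3}",
        }
        for i in range(30)
    ]

    analyzer = DatasetAnalyzer()
    analyzer.analyze_dataset(demo_samples)
    print(analyzer.generate_report(format="text"))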