# 3d_model/ylff/utils/dataset_analysis.py
"""
Dataset analysis and reporting utilities.
Features:
- Statistical analysis
- Quality metrics
- Visualization generation
- Report generation
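
Example (an illustrative sketch; the workflow below assumes `samples` is a
list of dicts shaped like those consumed by DatasetAnalyzer.analyze_dataset):

    analyzer = DatasetAnalyzer()
    stats = analyzer.analyze_dataset(samples)
    print(analyzer.generate_report(format="text"))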
"""
import json
import logging
from collections import Counter
from pathlib import Path
from typing import Any, Dict, List, Optional
import numpy as np
logger = logging.getLogger(__name__)
class DatasetAnalyzer:
"""
Comprehensive dataset analysis and reporting.
"""
def __init__(self):
"""Initialize dataset analyzer."""
self.stats: Dict[str, Any] = {}
def analyze_dataset(
self,
samples: List[Dict],
compute_distributions: bool = True,
compute_correlations: bool = True,
) -> Dict[str, Any]:
"""
Perform comprehensive dataset analysis.
Args:
samples: List of training samples
compute_distributions: Compute error/weight distributions
compute_correlations: Compute correlations between metrics
Returns:
Analysis report dictionary
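        Example:
            A minimal, hypothetical sample dict (every field is optional):
                {"error": 1.5, "weight": 1.0, "images": [img_a, img_b],
                 "sequence_id": "seq_001"}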
"""
logger.info(f"Analyzing dataset with {len(samples)} samples...")
if not samples:
return {"error": "Empty dataset"}
        # Basic statistics (the empty case already returned above)
        self.stats = {
            "total_samples": len(samples),
            "sample_fields": list(samples[0].keys()),
        }
# Extract metrics
errors = []
weights = []
num_images = []
sequence_ids = []
        for sample in samples:
            if "error" in sample:
                error = sample["error"]
                if isinstance(error, (np.ndarray, list)):
                    # Flatten first so 0-d arrays do not break len()/indexing.
                    flat = np.ravel(error)
                    error = float(flat[0]) if flat.size > 0 else 0.0
                errors.append(float(error))
            if "weight" in sample:
                weight = sample["weight"]
                if isinstance(weight, (np.ndarray, list)):
                    flat = np.ravel(weight)
                    weight = float(flat[0]) if flat.size > 0 else 1.0
                weights.append(float(weight))
if "images" in sample:
images = sample["images"]
if isinstance(images, (list, tuple)):
num_images.append(len(images))
elif isinstance(images, np.ndarray):
num_images.append(images.shape[0])
if "sequence_id" in sample:
sequence_ids.append(str(sample["sequence_id"]))
# Error statistics
if errors:
self.stats["error_statistics"] = self._compute_statistics(errors, "error")
if compute_distributions:
self.stats["error_distribution"] = self._compute_distribution(errors, bins=50)
# Weight statistics
if weights:
self.stats["weight_statistics"] = self._compute_statistics(weights, "weight")
if compute_distributions:
self.stats["weight_distribution"] = self._compute_distribution(weights, bins=50)
# Image count statistics
if num_images:
self.stats["image_count_statistics"] = self._compute_statistics(num_images, "images")
# Sequence statistics
        if sequence_ids:
            # Counter yields per-sequence sample counts in one pass
            # (replacing the quadratic per-id list.count loop).
            samples_per_sequence = Counter(sequence_ids)
            counts = list(samples_per_sequence.values())
            self.stats["sequence_statistics"] = {
                "unique_sequences": len(samples_per_sequence),
                "samples_per_sequence": {
                    "mean": float(np.mean(counts)),
                    "min": int(np.min(counts)),
                    "max": int(np.max(counts)),
                    "std": float(np.std(counts)),
                },
            }
# Correlations
        # Correlations: only meaningful when every sample contributed both an
        # error and a weight; unequal lengths would misalign np.corrcoef.
        if compute_correlations and errors and weights and len(errors) == len(weights):
            correlation = np.corrcoef(errors, weights)[0, 1]
            self.stats["correlations"] = {"error_weight": float(correlation)}
# Quality metrics
self.stats["quality_metrics"] = self._compute_quality_metrics(samples, errors, weights)
return self.stats
def _compute_statistics(self, values: List[float], name: str) -> Dict[str, float]:
"""Compute statistical measures."""
arr = np.array(values)
return {
"mean": float(np.mean(arr)),
"median": float(np.median(arr)),
"std": float(np.std(arr)),
"min": float(np.min(arr)),
"max": float(np.max(arr)),
"q25": float(np.percentile(arr, 25)),
"q75": float(np.percentile(arr, 75)),
"q90": float(np.percentile(arr, 90)),
"q95": float(np.percentile(arr, 95)),
"q99": float(np.percentile(arr, 99)),
}
def _compute_distribution(self, values: List[float], bins: int = 50) -> Dict[str, Any]:
"""Compute value distribution."""
arr = np.array(values)
hist, bin_edges = np.histogram(arr, bins=bins)
return {
"histogram": hist.tolist(),
"bin_edges": bin_edges.tolist(),
"bin_centers": ((bin_edges[:-1] + bin_edges[1:]) / 2).tolist(),
}
def _compute_quality_metrics(
self,
samples: List[Dict],
errors: List[float],
weights: List[float],
) -> Dict[str, Any]:
"""Compute dataset quality metrics."""
metrics = {}
if errors:
# Error-based metrics
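            # Thresholds are in degrees (an assumption consistent with the
            # degree symbols used in the generated reports below).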
metrics["low_error_ratio"] = sum(1 for e in errors if e < 2.0) / len(errors)
metrics["medium_error_ratio"] = sum(1 for e in errors if 2.0 <= e < 30.0) / len(errors)
metrics["high_error_ratio"] = sum(1 for e in errors if e >= 30.0) / len(errors)
if weights:
# Weight-based metrics
metrics["weight_diversity"] = float(np.std(weights))
metrics["uniform_weight_ratio"] = sum(1 for w in weights if abs(w - 1.0) < 0.1) / len(
weights
)
# Completeness
required_fields = ["images", "poses_target"]
completeness = {}
for field in required_fields:
completeness[field] = sum(
1 for s in samples if field in s and s[field] is not None
) / len(samples)
metrics["completeness"] = completeness
return metrics
def generate_report(
self,
output_path: Optional[Path] = None,
format: str = "json",
) -> str:
"""
Generate human-readable report.
Args:
output_path: Path to save report (optional)
format: Report format ("json", "text", "markdown")
Returns:
Report string
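        Example (sketch; the path is hypothetical):
            report = analyzer.generate_report(Path("report.md"), format="markdown")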
"""
if format == "json":
report = json.dumps(self.stats, indent=2, default=str)
elif format == "text":
report = self._generate_text_report()
elif format == "markdown":
report = self._generate_markdown_report()
else:
raise ValueError(f"Unknown format: {format}")
if output_path:
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
f.write(report)
logger.info(f"Report saved to: {output_path}")
return report
def _generate_text_report(self) -> str:
"""Generate text report."""
lines = []
lines.append("=" * 80)
lines.append("DATASET ANALYSIS REPORT")
lines.append("=" * 80)
lines.append("")
lines.append(f"Total Samples: {self.stats.get('total_samples', 0)}")
lines.append("")
# Error statistics
if "error_statistics" in self.stats:
lines.append("Error Statistics:")
err_stats = self.stats["error_statistics"]
lines.append(f" Mean: {err_stats['mean']:.4f}")
lines.append(f" Median: {err_stats['median']:.4f}")
lines.append(f" Std: {err_stats['std']:.4f}")
lines.append(f" Range: [{err_stats['min']:.4f}, {err_stats['max']:.4f}]")
lines.append(f" Q25: {err_stats['q25']:.4f}")
lines.append(f" Q75: {err_stats['q75']:.4f}")
lines.append(f" Q95: {err_stats['q95']:.4f}")
lines.append("")
# Quality metrics
if "quality_metrics" in self.stats:
lines.append("Quality Metrics:")
qm = self.stats["quality_metrics"]
if "low_error_ratio" in qm:
lines.append(f" Low error (< 2°): {qm['low_error_ratio'] * 100:.1f}%")
lines.append(f" Medium error (2-30°): {qm['medium_error_ratio'] * 100:.1f}%")
lines.append(f" High error (> 30°): {qm['high_error_ratio'] * 100:.1f}%")
lines.append("")
# Sequence statistics
if "sequence_statistics" in self.stats:
lines.append("Sequence Statistics:")
seq_stats = self.stats["sequence_statistics"]
lines.append(f" Unique sequences: {seq_stats['unique_sequences']}")
sps = seq_stats["samples_per_sequence"]
lines.append(f" Samples per sequence: {sps['mean']:.1f} ± {sps['std']:.1f}")
lines.append("")
return "\n".join(lines)
def _generate_markdown_report(self) -> str:
"""Generate markdown report."""
lines = []
lines.append("# Dataset Analysis Report")
lines.append("")
lines.append(f"**Total Samples:** {self.stats.get('total_samples', 0)}")
lines.append("")
# Error statistics table
if "error_statistics" in self.stats:
lines.append("## Error Statistics")
lines.append("")
lines.append("| Metric | Value |")
lines.append("|--------|-------|")
err_stats = self.stats["error_statistics"]
for key, value in err_stats.items():
lines.append(f"| {key} | {value:.4f} |")
lines.append("")
# Quality metrics
if "quality_metrics" in self.stats:
lines.append("## Quality Metrics")
lines.append("")
qm = self.stats["quality_metrics"]
if "low_error_ratio" in qm:
lines.append(f"- **Low error (< 2°):** {qm['low_error_ratio']*100:.1f}%")
lines.append(f"- **Medium error (2-30°):** {qm['medium_error_ratio']*100:.1f}%")
lines.append(f"- **High error (> 30°):** {qm['high_error_ratio']*100:.1f}%")
lines.append("")
return "\n".join(lines)
def analyze_dataset_file(
dataset_path: Path,
output_path: Optional[Path] = None,
format: str = "json",
) -> Dict[str, Any]:
"""
Analyze a saved dataset file.
Args:
dataset_path: Path to dataset file
output_path: Path to save analysis report
format: Report format
Returns:
Analysis results
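    Example (sketch; paths are hypothetical):
        results = analyze_dataset_file(Path("train.pkl"), Path("report.json"))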
"""
logger.info(f"Analyzing dataset file: {dataset_path}")
# Load dataset
    if dataset_path.suffix in (".pkl", ".pickle"):
import pickle
with open(dataset_path, "rb") as f:
samples = pickle.load(f)
    elif dataset_path.suffix == ".json":
        with open(dataset_path) as f:
            data = json.load(f)
        # Accept either {"samples": [...]} or a bare top-level list;
        # calling .get on a list would raise AttributeError.
        samples = data.get("samples", data) if isinstance(data, dict) else data
elif dataset_path.suffix in [".h5", ".hdf5"]:
import h5py
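        # Assumed HDF5 layout (an assumption, not a documented schema):
        # a required "images" dataset indexed by sample, with optional
        # parallel "poses" and "weights" datasets.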
with h5py.File(dataset_path, "r") as f:
samples = []
num_samples = f["images"].shape[0]
for i in range(num_samples):
sample = {"images": f["images"][i]}
if "poses" in f:
sample["poses_target"] = f["poses"][i]
if "weights" in f:
sample["weight"] = float(f["weights"][i])
samples.append(sample)
else:
raise ValueError(f"Unsupported dataset format: {dataset_path.suffix}")
# Analyze
analyzer = DatasetAnalyzer()
results = analyzer.analyze_dataset(samples)
# Generate report
if output_path:
analyzer.generate_report(output_path, format=format)
return results
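

# Minimal command-line entry point: an illustrative sketch, not an existing
# CLI; the flag names here are assumptions.
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Analyze a saved dataset file.")
    parser.add_argument("dataset_path", type=Path, help="Path to .pkl/.json/.h5 dataset")
    parser.add_argument("--output", type=Path, default=None, help="Optional report path")
    parser.add_argument("--format", default="json", choices=["json", "text", "markdown"])
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)
    analyze_dataset_file(args.dataset_path, output_path=args.output, format=args.format)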