# Source path: 3d_model/ylff/utils/dataset_curation.py
# Author: Azan
# Clean deployment build (squashed); commit 7a87926
"""
Dataset curation utilities.
Features:
- Quality-based filtering
- Dataset balancing
- Smart sampling strategies
- Outlier removal
- Dataset splitting
"""
import logging
from typing import Any, Dict, List, Optional, Tuple
import numpy as np
logger = logging.getLogger(__name__)
class DatasetCurator:
    """
    Advanced dataset curation and filtering.

    Operates on datasets represented as lists of sample dicts. Samples may
    carry optional quality fields ("error", "weight") and an "images" field
    (list/tuple or ``np.ndarray``); every method degrades gracefully when a
    field is absent or the dataset is empty.
    """

    def __init__(self):
        """Initialize dataset curator."""
        # Scratch dict for callers that want to stash curation stats.
        self.stats: Dict[str, Any] = {}

    @staticmethod
    def _count_images(images) -> int:
        """Best-effort image count: len() for sequences, first dim for arrays."""
        if isinstance(images, (list, tuple)):
            return len(images)
        if isinstance(images, np.ndarray):
            return int(images.shape[0])
        return 0

    def filter_by_quality(
        self,
        samples: List[Dict],
        min_error: Optional[float] = None,
        max_error: Optional[float] = None,
        min_weight: Optional[float] = None,
        max_weight: Optional[float] = None,
        min_images: Optional[int] = None,
        max_images: Optional[int] = None,
    ) -> Tuple[List[Dict], Dict[str, int]]:
        """
        Filter samples by quality metrics.

        A bound is only enforced when it is not None and the sample actually
        carries the corresponding field. Bounds are inclusive (a sample is
        removed only when strictly outside [min, max]).

        Args:
            samples: List of training samples.
            min_error: Minimum error threshold.
            max_error: Maximum error threshold.
            min_weight: Minimum weight threshold.
            max_weight: Maximum weight threshold.
            min_images: Minimum number of images.
            max_images: Maximum number of images.

        Returns:
            Tuple of (filtered_samples, filter_stats). Each removed sample is
            counted exactly once, in the first category that rejected it.
        """
        logger.info(f"Filtering {len(samples)} samples by quality...")
        filtered = []
        stats = {
            "original_count": len(samples),
            "filtered_count": 0,
            "removed_by_error": 0,
            "removed_by_weight": 0,
            "removed_by_images": 0,
        }
        for sample in samples:
            removed = False
            # Filter by error. `elif` ensures a sample never increments the
            # counter twice (e.g. with pathological min > max thresholds).
            if "error" in sample:
                error = float(sample["error"])
                if min_error is not None and error < min_error:
                    stats["removed_by_error"] += 1
                    removed = True
                elif max_error is not None and error > max_error:
                    stats["removed_by_error"] += 1
                    removed = True
            # Filter by weight (only for samples that survived the error check).
            if not removed and "weight" in sample:
                weight = float(sample["weight"])
                if min_weight is not None and weight < min_weight:
                    stats["removed_by_weight"] += 1
                    removed = True
                elif max_weight is not None and weight > max_weight:
                    stats["removed_by_weight"] += 1
                    removed = True
            # Filter by image count.
            if not removed:
                images = sample.get("images")
                if images is not None:
                    num_images = self._count_images(images)
                    if min_images is not None and num_images < min_images:
                        stats["removed_by_images"] += 1
                        removed = True
                    elif max_images is not None and num_images > max_images:
                        stats["removed_by_images"] += 1
                        removed = True
            if not removed:
                filtered.append(sample)
        stats["filtered_count"] = len(filtered)
        logger.info(f"Filtered {len(filtered)}/{len(samples)} samples")
        return filtered, stats

    def remove_outliers(
        self,
        samples: List[Dict],
        error_percentile: float = 95.0,
        method: str = "error",
    ) -> Tuple[List[Dict], Dict[str, int]]:
        """
        Remove outlier samples.

        Args:
            samples: List of training samples.
            error_percentile: Percentile threshold for the "error"/"weight"
                methods; samples above that percentile are dropped.
            method: Outlier detection method ("error", "weight",
                "statistical"). "statistical" applies Tukey's 1.5*IQR rule
                to the "error" field.

        Returns:
            Tuple of (filtered_samples, stats). Samples are returned
            unchanged when the required field is missing from any sample,
            and an empty dataset is returned as-is (np.percentile would
            raise on empty input).
        """
        logger.info(f"Removing outliers from {len(samples)} samples...")
        if not samples:
            # Nothing to remove; also avoids np.percentile([]) raising.
            filtered = []
        elif method == "error" and all("error" in s for s in samples):
            errors = [float(s["error"]) for s in samples]
            threshold = np.percentile(errors, error_percentile)
            filtered = [s for s in samples if float(s["error"]) <= threshold]
        elif method == "weight" and all("weight" in s for s in samples):
            weights = [float(s["weight"]) for s in samples]
            threshold = np.percentile(weights, error_percentile)
            filtered = [s for s in samples if float(s["weight"]) <= threshold]
        elif method == "statistical" and all("error" in s for s in samples):
            # Tukey's fences: keep samples within [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
            errors = np.array([float(s["error"]) for s in samples])
            q1, q3 = np.percentile(errors, [25, 75])
            iqr = q3 - q1
            lower_bound = q1 - 1.5 * iqr
            upper_bound = q3 + 1.5 * iqr
            filtered = [
                s for s in samples if lower_bound <= float(s["error"]) <= upper_bound
            ]
        else:
            filtered = samples
        stats = {
            "original_count": len(samples),
            "filtered_count": len(filtered),
            "removed": len(samples) - len(filtered),
        }
        logger.info(f"Removed {stats['removed']} outliers")
        return filtered, stats

    def balance_dataset(
        self,
        samples: List[Dict],
        strategy: str = "error_bins",
        num_bins: int = 10,
        max_samples_per_bin: Optional[int] = None,
    ) -> Tuple[List[Dict], Dict[str, Any]]:
        """
        Balance dataset by error distribution.

        Args:
            samples: List of training samples.
            strategy: Balancing strategy ("error_bins", "uniform", "weighted").
            num_bins: Number of error bins for binning strategies.
            max_samples_per_bin: Maximum samples per bin (None = no limit;
                only used by "error_bins").

        Returns:
            Tuple of (balanced_samples, stats).
        """
        logger.info(f"Balancing {len(samples)} samples using {strategy} strategy...")
        if not samples or "error" not in samples[0]:
            logger.warning("Cannot balance: no error field in samples")
            return samples, {"original_count": len(samples), "balanced_count": len(samples)}
        errors = [float(s["error"]) for s in samples]
        min_error = min(errors)
        max_error = max(errors)
        if strategy == "error_bins":
            # Bin samples by error, then cap each bin.
            bins = np.linspace(min_error, max_error, num_bins + 1)
            binned_samples = [[] for _ in range(num_bins)]
            for sample in samples:
                error = float(sample["error"])
                # Clamp so the min/max error lands in the first/last bin.
                bin_idx = np.digitize(error, bins) - 1
                bin_idx = max(0, min(bin_idx, num_bins - 1))
                binned_samples[bin_idx].append(sample)
            balanced = []
            bin_counts = []
            for bin_samples in binned_samples:
                if max_samples_per_bin and len(bin_samples) > max_samples_per_bin:
                    # Randomly down-sample the over-full bin.
                    indices = np.random.choice(
                        len(bin_samples), max_samples_per_bin, replace=False
                    )
                    bin_samples = [bin_samples[i] for i in indices]
                balanced.extend(bin_samples)
                bin_counts.append(len(bin_samples))
            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "bin_counts": bin_counts,
                "strategy": strategy,
            }
        elif strategy == "uniform":
            # Keep at most target_count samples per equal-width error bin.
            # max(1, ...) guards against len(samples) < num_bins, where an
            # integer division of 0 would silently drop every sample.
            target_count = max(1, len(samples) // num_bins)
            bins = np.linspace(min_error, max_error, num_bins + 1)
            balanced = []
            for i in range(num_bins):
                if i == num_bins - 1:
                    # Last bin is closed on the right so max-error samples
                    # are not silently excluded.
                    bin_samples = [
                        s for s in samples if bins[i] <= float(s["error"]) <= bins[i + 1]
                    ]
                else:
                    bin_samples = [
                        s for s in samples if bins[i] <= float(s["error"]) < bins[i + 1]
                    ]
                if len(bin_samples) > target_count:
                    indices = np.random.choice(len(bin_samples), target_count, replace=False)
                    balanced.extend([bin_samples[j] for j in indices])
                else:
                    balanced.extend(bin_samples)
            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": strategy,
            }
        elif strategy == "weighted":
            # Inverse-error weighted resampling (with replacement): low-error
            # samples are drawn more often. Epsilon avoids division by zero.
            weights = [1.0 / (float(s["error"]) + 1e-6) for s in samples]
            weights = np.array(weights)
            weights = weights / weights.sum()
            indices = np.random.choice(len(samples), len(samples), p=weights, replace=True)
            balanced = [samples[i] for i in indices]
            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": strategy,
            }
        else:
            logger.warning(f"Unknown strategy: {strategy}, returning original samples")
            balanced = samples
            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": "none",
            }
        logger.info(f"Balanced dataset: {len(balanced)} samples")
        return balanced, stats

    def split_dataset(
        self,
        samples: List[Dict],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
        stratify_by: Optional[str] = "error",
        random_seed: Optional[int] = None,
    ) -> Tuple[List[Dict], List[Dict], List[Dict], Dict[str, Any]]:
        """
        Split dataset into train/val/test sets.

        The input list is never mutated; shuffling happens on copies.

        Args:
            samples: List of training samples.
            train_ratio: Training set ratio.
            val_ratio: Validation set ratio.
            test_ratio: Test set ratio.
            stratify_by: Field to stratify by ("error", "weight", None).
                Falls back to a plain random split when the field is absent
                or the dataset is empty.
            random_seed: Random seed for reproducibility.

        Returns:
            Tuple of (train_samples, val_samples, test_samples, stats).

        Raises:
            ValueError: If the three ratios do not sum to 1.0.
        """
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError("Ratios must sum to 1.0")
        if random_seed is not None:
            np.random.seed(random_seed)
        logger.info(
            f"Splitting {len(samples)} samples: "
            f"{train_ratio:.1%} train, {val_ratio:.1%} val, {test_ratio:.1%} test"
        )
        if samples and stratify_by and stratify_by in samples[0]:
            # Stratified split: bin by the stratify field, split each bin.
            values = [float(s[stratify_by]) for s in samples]
            # Clamp to at least one bin -- for datasets smaller than 10
            # samples, len(samples) // 10 is 0 and every split would be empty.
            num_bins = max(1, min(10, len(samples) // 10))
            bins = np.linspace(min(values), max(values), num_bins + 1)
            train_samples = []
            val_samples = []
            test_samples = []
            for i in range(num_bins):
                if i == num_bins - 1:
                    # Last bin includes the upper bound so no sample is lost.
                    bin_samples = [s for s in samples if float(s[stratify_by]) >= bins[i]]
                else:
                    bin_samples = [
                        s for s in samples if bins[i] <= float(s[stratify_by]) < bins[i + 1]
                    ]
                # Shuffling a freshly-built bin list leaves `samples` intact.
                np.random.shuffle(bin_samples)
                n = len(bin_samples)
                n_train = int(n * train_ratio)
                n_val = int(n * val_ratio)
                train_samples.extend(bin_samples[:n_train])
                val_samples.extend(bin_samples[n_train : n_train + n_val])
                test_samples.extend(bin_samples[n_train + n_val :])
        else:
            # Random split on a copy so the caller's list is not reordered.
            shuffled = list(samples)
            np.random.shuffle(shuffled)
            n = len(shuffled)
            n_train = int(n * train_ratio)
            n_val = int(n * val_ratio)
            train_samples = shuffled[:n_train]
            val_samples = shuffled[n_train : n_train + n_val]
            test_samples = shuffled[n_train + n_val :]
        total = len(samples)
        stats = {
            "total": total,
            "train": len(train_samples),
            "val": len(val_samples),
            "test": len(test_samples),
            # Guard divisions for the degenerate empty-dataset case.
            "train_ratio": len(train_samples) / total if total else 0.0,
            "val_ratio": len(val_samples) / total if total else 0.0,
            "test_ratio": len(test_samples) / total if total else 0.0,
        }
        logger.info(
            f"Split: {len(train_samples)} train, "
            f"{len(val_samples)} val, {len(test_samples)} test"
        )
        return train_samples, val_samples, test_samples, stats

    def sample_dataset(
        self,
        samples: List[Dict],
        num_samples: int,
        strategy: str = "random",
        weights: Optional[List[float]] = None,
    ) -> List[Dict]:
        """
        Sample subset of dataset.

        Args:
            samples: List of training samples.
            num_samples: Number of samples to select. When it is >= the
                dataset size, the original list is returned unchanged.
            strategy: Sampling strategy ("random", "weighted", "error_based").
            weights: Custom weights for weighted sampling (defaults to
                inverse-error weights when samples carry an "error" field).

        Returns:
            List of sampled samples. "error_based" may return fewer than
            num_samples when some error bins are empty.

        Raises:
            ValueError: If the strategy is unknown.
        """
        if num_samples >= len(samples):
            return samples
        logger.info(f"Sampling {num_samples} from {len(samples)} samples using {strategy}")
        if strategy == "random":
            indices = np.random.choice(len(samples), num_samples, replace=False)
            return [samples[i] for i in indices]
        elif strategy == "weighted":
            if weights is None:
                if "error" in samples[0]:
                    # Inverse-error weights favour low-error samples;
                    # epsilon avoids division by zero.
                    weights = [1.0 / (float(s["error"]) + 1e-6) for s in samples]
                else:
                    weights = [1.0] * len(samples)
            probs = np.asarray(weights, dtype=float)
            probs = probs / probs.sum()
            indices = np.random.choice(len(samples), num_samples, p=probs, replace=False)
            return [samples[i] for i in indices]
        elif strategy == "error_based":
            # One pick per equal-width error bin -> roughly uniform coverage
            # of the error range.
            if "error" in samples[0]:
                errors = [float(s["error"]) for s in samples]
                bins = np.linspace(min(errors), max(errors), num_samples + 1)
                sampled = []
                for i in range(num_samples):
                    if i == num_samples - 1:
                        # Closed last bin: the max-error sample stays reachable.
                        bin_samples = [
                            s for s in samples if bins[i] <= float(s["error"]) <= bins[i + 1]
                        ]
                    else:
                        bin_samples = [
                            s for s in samples if bins[i] <= float(s["error"]) < bins[i + 1]
                        ]
                    if bin_samples:
                        # Index-based pick: np.random.choice on a list of dicts
                        # would force an object-array conversion.
                        sampled.append(bin_samples[np.random.randint(len(bin_samples))])
                return sampled
            else:
                indices = np.random.choice(len(samples), num_samples, replace=False)
                return [samples[i] for i in indices]
        else:
            raise ValueError(f"Unknown sampling strategy: {strategy}")