| """ |
| Dataset curation utilities. |
| |
| Features: |
| - Quality-based filtering |
| - Dataset balancing |
| - Smart sampling strategies |
| - Outlier removal |
| - Dataset splitting |
| """ |
|
|
| import logging |
| from typing import Any, Dict, List, Optional, Tuple |
| import numpy as np |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class DatasetCurator:
    """
    Advanced dataset curation and filtering.

    Samples are plain dicts.  Quality-related methods read the optional
    "error", "weight", and "images" fields and degrade gracefully (treat the
    criterion as not applicable) when a field is absent.
    """

    def __init__(self):
        """Initialize dataset curator."""
        # Scratch space for callers that want to stash curation statistics;
        # the methods below return their stats instead of writing here.
        self.stats = {}

    def filter_by_quality(
        self,
        samples: List[Dict],
        min_error: Optional[float] = None,
        max_error: Optional[float] = None,
        min_weight: Optional[float] = None,
        max_weight: Optional[float] = None,
        min_images: Optional[int] = None,
        max_images: Optional[int] = None,
    ) -> Tuple[List[Dict], Dict[str, int]]:
        """
        Filter samples by quality metrics.

        Criteria are checked in the order error -> weight -> images, and a
        rejected sample is counted exactly once, under the first criterion
        that rejects it.  A threshold left as ``None`` is not applied.

        Args:
            samples: List of training samples
            min_error: Minimum error threshold (samples below are removed)
            max_error: Maximum error threshold (samples above are removed)
            min_weight: Minimum weight threshold
            max_weight: Maximum weight threshold
            min_images: Minimum number of images
            max_images: Maximum number of images

        Returns:
            Tuple of (filtered_samples, filter_stats)
        """
        logger.info(f"Filtering {len(samples)} samples by quality...")

        filtered = []
        stats = {
            "original_count": len(samples),
            "filtered_count": 0,
            "removed_by_error": 0,
            "removed_by_weight": 0,
            "removed_by_images": 0,
        }

        for sample in samples:
            removed = False

            if "error" in sample:
                error = float(sample["error"])
                # elif prevents double-counting one sample under both bounds
                # (only possible when min_error > max_error, but the counter
                # should never exceed the number of removed samples).
                if min_error is not None and error < min_error:
                    stats["removed_by_error"] += 1
                    removed = True
                elif max_error is not None and error > max_error:
                    stats["removed_by_error"] += 1
                    removed = True

            if not removed and "weight" in sample:
                weight = float(sample["weight"])
                if min_weight is not None and weight < min_weight:
                    stats["removed_by_weight"] += 1
                    removed = True
                elif max_weight is not None and weight > max_weight:
                    stats["removed_by_weight"] += 1
                    removed = True

            if not removed:
                # NOTE(review): samples without an "images" field bypass the
                # image-count bounds entirely — consistent with how missing
                # "error"/"weight" fields are treated, but worth confirming.
                images = sample.get("images")
                if images is not None:
                    if isinstance(images, (list, tuple)):
                        num_images = len(images)
                    elif isinstance(images, np.ndarray):
                        num_images = images.shape[0]
                    else:
                        # Unknown container type: treated as zero images.
                        num_images = 0

                    if min_images is not None and num_images < min_images:
                        stats["removed_by_images"] += 1
                        removed = True
                    elif max_images is not None and num_images > max_images:
                        stats["removed_by_images"] += 1
                        removed = True

            if not removed:
                filtered.append(sample)

        stats["filtered_count"] = len(filtered)
        logger.info(f"Filtered {len(filtered)}/{len(samples)} samples")

        return filtered, stats

    def remove_outliers(
        self,
        samples: List[Dict],
        error_percentile: float = 95.0,
        method: str = "error",
    ) -> Tuple[List[Dict], Dict[str, int]]:
        """
        Remove outlier samples.

        If the field required by ``method`` is missing from any sample, or
        ``method`` is unknown, the samples are returned unchanged.

        Args:
            samples: List of training samples
            error_percentile: Percentile threshold for outliers (samples
                strictly above it are dropped; also applied to "weight")
            method: Outlier detection method ("error", "weight",
                "statistical").  "statistical" uses the 1.5*IQR rule on the
                "error" field.

        Returns:
            Tuple of (filtered_samples, stats)
        """
        logger.info(f"Removing outliers from {len(samples)} samples...")

        # Guard the empty dataset: all(...) is vacuously True on [], and
        # np.percentile([]) raises.
        if not samples:
            return samples, {"original_count": 0, "filtered_count": 0, "removed": 0}

        if method == "error" and all("error" in s for s in samples):
            errors = [float(s["error"]) for s in samples]
            threshold = np.percentile(errors, error_percentile)
            filtered = [s for s in samples if float(s["error"]) <= threshold]
        elif method == "weight" and all("weight" in s for s in samples):
            weights = [float(s["weight"]) for s in samples]
            threshold = np.percentile(weights, error_percentile)
            filtered = [s for s in samples if float(s["weight"]) <= threshold]
        elif method == "statistical":
            # Tukey's fences: drop samples outside [Q1 - 1.5*IQR, Q3 + 1.5*IQR].
            if all("error" in s for s in samples):
                errors = np.array([float(s["error"]) for s in samples])
                q1, q3 = np.percentile(errors, [25, 75])
                iqr = q3 - q1
                lower_bound = q1 - 1.5 * iqr
                upper_bound = q3 + 1.5 * iqr
                filtered = [s for s in samples if lower_bound <= float(s["error"]) <= upper_bound]
            else:
                filtered = samples
        else:
            filtered = samples

        stats = {
            "original_count": len(samples),
            "filtered_count": len(filtered),
            "removed": len(samples) - len(filtered),
        }

        logger.info(f"Removed {stats['removed']} outliers")

        return filtered, stats

    def balance_dataset(
        self,
        samples: List[Dict],
        strategy: str = "error_bins",
        num_bins: int = 10,
        max_samples_per_bin: Optional[int] = None,
    ) -> Tuple[List[Dict], Dict[str, Any]]:
        """
        Balance dataset by error distribution.

        Args:
            samples: List of training samples
            strategy: Balancing strategy ("error_bins", "uniform", "weighted")
            num_bins: Number of error bins for binning strategies
            max_samples_per_bin: Maximum samples per bin (None = no limit;
                "error_bins" strategy only)

        Returns:
            Tuple of (balanced_samples, stats)
        """
        logger.info(f"Balancing {len(samples)} samples using {strategy} strategy...")

        if not samples or "error" not in samples[0]:
            logger.warning("Cannot balance: no error field in samples")
            return samples, {"original_count": len(samples), "balanced_count": len(samples)}

        errors = [float(s["error"]) for s in samples]
        min_error = min(errors)
        max_error = max(errors)

        if strategy == "error_bins":
            # Assign each sample to an equal-width error bin, then cap each
            # bin at max_samples_per_bin by uniform subsampling.
            bins = np.linspace(min_error, max_error, num_bins + 1)
            binned_samples = [[] for _ in range(num_bins)]

            for sample in samples:
                error = float(sample["error"])
                # Clamp so the min/max-error samples land in the first/last bin.
                bin_idx = np.digitize(error, bins) - 1
                bin_idx = max(0, min(bin_idx, num_bins - 1))
                binned_samples[bin_idx].append(sample)

            balanced = []
            bin_counts = []
            for bin_samples in binned_samples:
                if max_samples_per_bin and len(bin_samples) > max_samples_per_bin:
                    indices = np.random.choice(
                        len(bin_samples), max_samples_per_bin, replace=False
                    )
                    bin_samples = [bin_samples[i] for i in indices]
                balanced.extend(bin_samples)
                bin_counts.append(len(bin_samples))

            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "bin_counts": bin_counts,
                "strategy": strategy,
            }

        elif strategy == "uniform":
            # Cap every bin at an equal share of the original dataset.
            target_count = len(samples) // num_bins
            bins = np.linspace(min_error, max_error, num_bins + 1)
            balanced = []

            for i in range(num_bins):
                if i == num_bins - 1:
                    # Close the last bin on the right so samples at the
                    # maximum error are not silently dropped.
                    bin_samples = [s for s in samples if bins[i] <= float(s["error"]) <= bins[i + 1]]
                else:
                    bin_samples = [s for s in samples if bins[i] <= float(s["error"]) < bins[i + 1]]
                if len(bin_samples) > target_count:
                    indices = np.random.choice(len(bin_samples), target_count, replace=False)
                    balanced.extend([bin_samples[j] for j in indices])
                else:
                    balanced.extend(bin_samples)

            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": strategy,
            }

        elif strategy == "weighted":
            # Resample (with replacement) in proportion to inverse error, so
            # low-error samples appear more often; epsilon guards error == 0.
            weights = [1.0 / (float(s["error"]) + 1e-6) for s in samples]
            weights = np.array(weights)
            weights = weights / weights.sum()

            indices = np.random.choice(len(samples), len(samples), p=weights, replace=True)
            balanced = [samples[i] for i in indices]

            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": strategy,
            }

        else:
            logger.warning(f"Unknown strategy: {strategy}, returning original samples")
            balanced = samples
            stats = {
                "original_count": len(samples),
                "balanced_count": len(balanced),
                "strategy": "none",
            }

        logger.info(f"Balanced dataset: {len(balanced)} samples")

        return balanced, stats

    def split_dataset(
        self,
        samples: List[Dict],
        train_ratio: float = 0.8,
        val_ratio: float = 0.1,
        test_ratio: float = 0.1,
        stratify_by: Optional[str] = "error",
        random_seed: Optional[int] = None,
    ) -> Tuple[List[Dict], List[Dict], List[Dict], Dict[str, Any]]:
        """
        Split dataset into train/val/test sets.

        The input list is not modified; shuffling happens on a copy.

        Args:
            samples: List of training samples
            train_ratio: Training set ratio
            val_ratio: Validation set ratio
            test_ratio: Test set ratio
            stratify_by: Field to stratify by ("error", "weight", None)
            random_seed: Random seed for reproducibility

        Returns:
            Tuple of (train_samples, val_samples, test_samples, stats)

        Raises:
            ValueError: If the three ratios do not sum to 1.0.
        """
        if abs(train_ratio + val_ratio + test_ratio - 1.0) > 1e-6:
            raise ValueError("Ratios must sum to 1.0")

        if random_seed is not None:
            np.random.seed(random_seed)

        logger.info(
            f"Splitting {len(samples)} samples: "
            f"{train_ratio:.1%} train, {val_ratio:.1%} val, {test_ratio:.1%} test"
        )

        # Guard the empty dataset: samples[0] below would raise IndexError,
        # and the ratio stats would divide by zero.
        if not samples:
            stats = {
                "total": 0,
                "train": 0,
                "val": 0,
                "test": 0,
                "train_ratio": 0.0,
                "val_ratio": 0.0,
                "test_ratio": 0.0,
            }
            return [], [], [], stats

        if stratify_by and stratify_by in samples[0]:
            # Stratified split: bin by the chosen field and split each bin
            # with the requested ratios so all splits share the distribution.
            values = [float(s[stratify_by]) for s in samples]
            # max(1, ...): len(samples) // 10 is 0 for small datasets, which
            # would yield zero bins and therefore three EMPTY splits.
            num_bins = max(1, min(10, len(samples) // 10))
            bins = np.linspace(min(values), max(values), num_bins + 1)

            train_samples = []
            val_samples = []
            test_samples = []

            for i in range(num_bins):
                bin_samples = [
                    s for s in samples if bins[i] <= float(s[stratify_by]) < bins[i + 1]
                ]
                if i == num_bins - 1:
                    # Last bin is right-closed so the max-value sample is kept.
                    bin_samples = [s for s in samples if float(s[stratify_by]) >= bins[i]]

                np.random.shuffle(bin_samples)

                n = len(bin_samples)
                n_train = int(n * train_ratio)
                n_val = int(n * val_ratio)

                train_samples.extend(bin_samples[:n_train])
                val_samples.extend(bin_samples[n_train : n_train + n_val])
                test_samples.extend(bin_samples[n_train + n_val :])

        else:
            # Plain random split; shuffle a copy so the caller's list is
            # never mutated in place.
            shuffled = list(samples)
            np.random.shuffle(shuffled)

            n = len(shuffled)
            n_train = int(n * train_ratio)
            n_val = int(n * val_ratio)

            train_samples = shuffled[:n_train]
            val_samples = shuffled[n_train : n_train + n_val]
            test_samples = shuffled[n_train + n_val :]

        stats = {
            "total": len(samples),
            "train": len(train_samples),
            "val": len(val_samples),
            "test": len(test_samples),
            "train_ratio": len(train_samples) / len(samples),
            "val_ratio": len(val_samples) / len(samples),
            "test_ratio": len(test_samples) / len(samples),
        }

        logger.info(
            f"Split: {len(train_samples)} train, "
            f"{len(val_samples)} val, {len(test_samples)} test"
        )

        return train_samples, val_samples, test_samples, stats

    def sample_dataset(
        self,
        samples: List[Dict],
        num_samples: int,
        strategy: str = "random",
        weights: Optional[List[float]] = None,
    ) -> List[Dict]:
        """
        Sample subset of dataset.

        Args:
            samples: List of training samples
            num_samples: Number of samples to select; if >= len(samples),
                the original list is returned unchanged
            strategy: Sampling strategy ("random", "weighted", "error_based")
            weights: Custom weights for weighted sampling (defaults to
                inverse error when the samples carry an "error" field)

        Returns:
            List of sampled samples

        Raises:
            ValueError: If the strategy is unknown.
        """
        if num_samples >= len(samples):
            return samples

        logger.info(f"Sampling {num_samples} from {len(samples)} samples using {strategy}")

        if strategy == "random":
            indices = np.random.choice(len(samples), num_samples, replace=False)
            return [samples[i] for i in indices]

        elif strategy == "weighted":
            if weights is None:
                # Default: favour low-error samples via inverse-error weights.
                if "error" in samples[0]:
                    weights = [1.0 / (float(s["error"]) + 1e-6) for s in samples]
                else:
                    weights = [1.0] * len(samples)

            weights = np.array(weights)
            weights = weights / weights.sum()
            indices = np.random.choice(len(samples), num_samples, p=weights, replace=False)
            return [samples[i] for i in indices]

        elif strategy == "error_based":
            # One pick per equal-width error bin, spreading the subset
            # uniformly across the error range.  May return fewer than
            # num_samples when some bins are empty.
            if "error" in samples[0]:
                errors = [float(s["error"]) for s in samples]
                min_error = min(errors)
                max_error = max(errors)
                bins = np.linspace(min_error, max_error, num_samples + 1)

                sampled = []
                for i in range(num_samples):
                    if i == num_samples - 1:
                        # Right-closed last bin so the max-error sample is
                        # eligible for selection.
                        bin_samples = [
                            s for s in samples if bins[i] <= float(s["error"]) <= bins[i + 1]
                        ]
                    else:
                        bin_samples = [
                            s for s in samples if bins[i] <= float(s["error"]) < bins[i + 1]
                        ]
                    if bin_samples:
                        # Index-based pick: np.random.choice on a list of
                        # dicts would rely on object-array conversion.
                        sampled.append(bin_samples[np.random.randint(len(bin_samples))])
                return sampled
            else:
                indices = np.random.choice(len(samples), num_samples, replace=False)
                return [samples[i] for i in indices]

        else:
            raise ValueError(f"Unknown sampling strategy: {strategy}")
|
|