|
|
""" |
|
|
ESC-50 dataset utilities for loading and sampling audio data. |
|
|
""" |
|
|
|
|
|
import csv |
|
|
import json |
|
|
import random |
|
|
from pathlib import Path |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
|
|
|
import pandas as pd |
|
|
|
|
|
from .logger import setup_logger |
|
|
|
|
|
logger = setup_logger(__name__) |
|
|
|
|
|
|
|
|
def load_or_create_class_subset(config: dict, all_categories: List[str]) -> List[str]:
    """
    Load a persisted class subset or create (and persist) a new one.

    Args:
        config: Configuration dictionary; reads dataset.use_class_subset,
            dataset.num_classes_subset, dataset.subset_persist_path and
            dataset.subset_seed.
        all_categories: List of all available categories

    Returns:
        List of category names to use (either subset or all)
    """
    dataset_config = config.get('dataset', {})
    use_subset = dataset_config.get('use_class_subset', False)

    if not use_subset:
        logger.info(f"Using all {len(all_categories)} classes")
        return all_categories

    num_classes = dataset_config.get('num_classes_subset', len(all_categories))
    persist_path = Path(dataset_config.get('subset_persist_path', 'class_subset.json'))
    subset_seed = dataset_config.get('subset_seed', 42)

    # Reuse a previously persisted subset when it is still valid for the
    # current configuration (expected size, all classes still known).
    if persist_path.exists():
        try:
            with open(persist_path, 'r') as f:
                data = json.load(f)
            subset = data.get('classes', [])

            if len(subset) == num_classes and all(c in all_categories for c in subset):
                logger.info(f"Loaded persisted class subset from {persist_path}: {len(subset)} classes")
                return subset
            else:
                logger.warning("Invalid persisted subset, regenerating...")
        except Exception as e:
            logger.warning(f"Failed to load persisted subset: {e}, regenerating...")

    # Use a dedicated RNG so seeding the subset selection does not clobber the
    # module-level random state used elsewhere.  random.Random(seed).sample(...)
    # yields the same result as random.seed(seed); random.sample(...), so any
    # previously generated subsets are reproduced unchanged.
    rng = random.Random(subset_seed)
    subset = rng.sample(all_categories, min(num_classes, len(all_categories)))
    subset.sort()

    # Persist the selection so subsequent runs reuse the same subset.
    persist_path.parent.mkdir(parents=True, exist_ok=True)
    with open(persist_path, 'w') as f:
        json.dump({
            'classes': subset,
            'num_classes': len(subset),
            'seed': subset_seed,
            'total_available': len(all_categories)
        }, f, indent=2)

    logger.info(f"Created and persisted new class subset: {len(subset)} classes to {persist_path}")
    return subset
|
|
|
|
|
|
|
|
class ESC50Dataset:
    """Handler for the ESC-50 environmental sound dataset.

    Loads the esc50.csv metadata, maintains category <-> target-ID mappings,
    and provides sampling helpers (uniform and usage-balanced) over the
    active class subset.
    """

    # All 50 ESC-50 category names.  Order here is arbitrary; numeric target
    # IDs come from the metadata CSV, not from this list's positions.
    ALL_CATEGORIES = [
        'dog', 'chirping_birds', 'vacuum_cleaner', 'thunderstorm', 'door_wood_knock',
        'can_opening', 'crow', 'clapping', 'fireworks', 'chainsaw', 'airplane',
        'mouse_click', 'pouring_water', 'train', 'sheep', 'water_drops', 'church_bells',
        'clock_alarm', 'keyboard_typing', 'wind', 'footsteps', 'frog', 'cow',
        'brushing_teeth', 'car_horn', 'crackling_fire', 'helicopter', 'drinking_sipping',
        'rain', 'insects', 'laughing', 'hen', 'engine', 'breathing', 'crying_baby',
        'hand_saw', 'coughing', 'glass_breaking', 'snoring', 'toilet_flush', 'pig',
        'washing_machine', 'clock_tick', 'sneezing', 'rooster', 'sea_waves', 'siren',
        'cat', 'door_wood_creaks', 'crickets'
    ]

    def __init__(self, metadata_path: str, audio_path: str, config: Optional[dict] = None):
        """
        Initialize ESC-50 dataset handler.

        Args:
            metadata_path: Path to esc50.csv metadata file
            audio_path: Path to audio directory
            config: Optional configuration dict with dataset.use_class_subset settings
        """
        self.metadata_path = Path(metadata_path)
        self.audio_path = Path(audio_path)
        self.config = config or {}
        self.df = None  # pandas DataFrame with esc50.csv contents, filled by load_metadata()
        self.category_to_target = {}  # category name -> numeric target ID
        self.target_to_category = {}  # numeric target ID -> category name

        # Active categories: either a persisted subset or all 50 classes.
        self.CATEGORIES = load_or_create_class_subset(self.config, self.ALL_CATEGORIES)
        # How many times each category has been chosen as the "answer"
        # (see sample_categories_balanced).
        self.category_usage_counts = {cat: 0 for cat in self.CATEGORIES}

        self.load_metadata()

    def load_metadata(self):
        """Load ESC-50 metadata CSV and build category/target lookup tables.

        Raises:
            Exception: re-raised if the CSV cannot be read.
        """
        try:
            self.df = pd.read_csv(self.metadata_path)
            logger.info(f"Loaded ESC-50 metadata: {len(self.df)} files")

            # Build bidirectional category <-> target maps from the CSV rows.
            for target, category in zip(self.df['target'], self.df['category']):
                self.category_to_target[category] = target
                self.target_to_category[target] = category

            logger.info(f"Found {len(self.category_to_target)} unique categories")
        except Exception as e:
            logger.error(f"Error loading metadata: {e}")
            raise

    def get_files_by_category(self, category: str) -> List[str]:
        """
        Get all audio files for a specific category.

        Args:
            category: Sound category name

        Returns:
            List of filenames for the category

        Raises:
            ValueError: If the category is not present in the metadata.
        """
        if category not in self.category_to_target:
            raise ValueError(f"Unknown category: {category}")

        target = self.category_to_target[category]
        files = self.df[self.df['target'] == target]['filename'].tolist()
        return files

    def get_files_by_target(self, target: int) -> List[str]:
        """
        Get all audio files for a specific target ID.

        Args:
            target: Target class ID (0-49)

        Returns:
            List of filenames for the target (empty if the ID is unknown)
        """
        files = self.df[self.df['target'] == target]['filename'].tolist()
        return files

    def sample_categories(self, n: int, exclude: Optional[List[str]] = None) -> List[str]:
        """
        Sample n unique random categories from the active subset.

        Args:
            n: Number of categories to sample
            exclude: Optional list of categories to exclude

        Returns:
            List of sampled category names

        Raises:
            ValueError: If n exceeds the number of available categories.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} categories from subset, only {len(available)} available (subset size: {len(self.CATEGORIES)})")
        return random.sample(available, n)

    def sample_targets(self, n: int, exclude: Optional[List[int]] = None) -> List[int]:
        """
        Sample n unique random targets from the active subset.

        Args:
            n: Number of targets to sample
            exclude: Optional list of targets to exclude

        Returns:
            List of sampled target IDs corresponding to categories in the subset

        Raises:
            ValueError: If n exceeds the number of available targets.
        """
        # Restrict the target pool to categories in the active subset.
        available_targets = [self.category_to_target[cat] for cat in self.CATEGORIES]
        available = [t for t in available_targets if t not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} targets from subset, only {len(available)} available (subset size: {len(self.CATEGORIES)})")
        return random.sample(available, n)

    def sample_file_from_category(self, category: str) -> Tuple[str, str]:
        """
        Sample a random audio file from a category.

        Args:
            category: Sound category name

        Returns:
            Tuple of (filename, full_path)

        Raises:
            ValueError: If the category is not present in the metadata.
        """
        files = self.get_files_by_category(category)
        filename = random.choice(files)
        full_path = str(self.audio_path / filename)
        return filename, full_path

    def sample_file_from_target(self, target: int) -> Tuple[str, str, str]:
        """
        Sample a random audio file from a target.

        Args:
            target: Target class ID

        Returns:
            Tuple of (filename, category, full_path)
        """
        files = self.get_files_by_target(target)
        filename = random.choice(files)
        category = self.target_to_category[target]
        full_path = str(self.audio_path / filename)
        return filename, category, full_path

    def get_category_from_filename(self, filename: str) -> str:
        """Get category name from filename.

        Raises:
            ValueError: If the filename is not present in the metadata.
        """
        row = self.df[self.df['filename'] == filename]
        if len(row) == 0:
            # Name the offending file in the error; the message previously
            # contained no placeholder, which made failures hard to trace.
            raise ValueError(f"Unknown filename: {filename}")
        return row.iloc[0]['category']

    def get_file_path(self, filename: str) -> str:
        """Get full path for a filename."""
        return str(self.audio_path / filename)

    def sample_categories_balanced(self, n: int, exclude: Optional[List[str]] = None,
                                   answer_category: Optional[str] = None) -> List[str]:
        """
        Sample n unique categories with balanced usage tracking.

        This method ensures that over many samples, all categories appear
        roughly equally as answers by preferentially sampling underused categories.

        Args:
            n: Number of categories to sample
            exclude: Optional list of categories to exclude
            answer_category: If provided, ensures this category is included and
                tracks it; must be a category in the active subset

        Returns:
            List of sampled category names with answer_category first if provided

        Raises:
            ValueError: If n exceeds the number of available categories.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot sample {n} categories, only {len(available)} available")

        if answer_category:
            # Record the answer usage so get_least_used_categories can
            # steer future answer selection toward underused classes.
            self.category_usage_counts[answer_category] += 1

            # Fill the remaining n-1 slots with distractors.
            available = [c for c in available if c != answer_category]
            other_categories = random.sample(available, n - 1)
            return [answer_category] + other_categories
        else:
            # No designated answer: plain uniform sampling.
            return random.sample(available, n)

    def get_least_used_categories(self, n: int, exclude: Optional[List[str]] = None) -> List[str]:
        """
        Get n categories that have been used least as answers.

        Args:
            n: Number of categories to get
            exclude: Optional list of categories to exclude

        Returns:
            List of least-used category names

        Raises:
            ValueError: If n exceeds the number of available categories.
        """
        available = [c for c in self.CATEGORIES if c not in (exclude or [])]
        if n > len(available):
            raise ValueError(f"Cannot get {n} categories, only {len(available)} available")

        # Order categories by how often they have served as answers.
        sorted_categories = sorted(available, key=lambda c: self.category_usage_counts[c])

        # All categories tied at the minimum usage count form the first tier.
        min_count = self.category_usage_counts[sorted_categories[0]]
        candidates = [c for c in sorted_categories if self.category_usage_counts[c] == min_count]

        if len(candidates) >= n:
            # Enough minimally-used categories; pick among them at random
            # to avoid a deterministic ordering bias.
            return random.sample(candidates, n)
        else:
            # Take the whole first tier, then top up from the next-least-used.
            result = candidates.copy()
            remaining = n - len(result)
            next_tier = [c for c in sorted_categories if c not in candidates][:remaining]
            result.extend(next_tier)
            return result

    def get_category_usage_stats(self) -> Dict[str, int]:
        """Get current category usage statistics."""
        return self.category_usage_counts.copy()

    def reset_category_usage(self):
        """Reset category usage tracking."""
        self.category_usage_counts = {cat: 0 for cat in self.CATEGORIES}
        logger.info("Reset category usage tracking")
|
|
|
|
|
|
|
|
class PreprocessedESC50Dataset(ESC50Dataset):
    """
    Handler for preprocessed ESC-50 dataset with effective durations.

    Extends ESC50Dataset to use trimmed audio files and effective duration
    metadata from amplitude-based preprocessing.
    """

    def __init__(
        self,
        metadata_path: str,
        audio_path: str,
        preprocessed_path: str,
        config: Optional[dict] = None
    ):
        """
        Initialize preprocessed ESC-50 dataset handler.

        Args:
            metadata_path: Path to original esc50.csv metadata file
            audio_path: Path to original audio directory (fallback)
            preprocessed_path: Path to preprocessed data directory
            config: Optional configuration dict with dataset.use_class_subset settings
        """
        super().__init__(metadata_path, audio_path, config)

        self.preprocessed_path = Path(preprocessed_path)
        # Layout produced by the preprocessing step: trimmed clips plus a CSV
        # of per-clip effective durations.
        self.trimmed_audio_path = self.preprocessed_path / "trimmed_audio"
        self.effective_durations_path = self.preprocessed_path / "effective_durations.csv"

        # Effective-duration metadata, filled by load_effective_durations().
        self.effective_df = None
        self.load_effective_durations()

    def load_effective_durations(self):
        """Load effective durations from preprocessed CSV and build lookups.

        Raises:
            Exception: re-raised if the CSV cannot be read.
        """
        try:
            self.effective_df = pd.read_csv(self.effective_durations_path)
            logger.info(f"Loaded effective durations for {len(self.effective_df)} clips")

            # Fast per-file lookup tables.
            self.filename_to_effective = dict(
                zip(self.effective_df['filename'], self.effective_df['effective_duration_s'])
            )
            self.filename_to_category = dict(
                zip(self.effective_df['filename'], self.effective_df['category'])
            )

            # Per-category summary statistics of effective durations.
            self.category_effective_stats = self.effective_df.groupby('category').agg({
                'effective_duration_s': ['mean', 'std', 'min', 'max', 'count']
            }).round(4)
            self.category_effective_stats.columns = ['mean', 'std', 'min', 'max', 'count']

            logger.info("Created effective duration lookup tables")

        except Exception as e:
            logger.error(f"Error loading effective durations: {e}")
            raise

    def get_effective_duration(self, filename: str) -> float:
        """
        Get effective duration for a specific file.

        Args:
            filename: Audio filename

        Returns:
            Effective duration in seconds (5.0 if the file is unknown)
        """
        if filename not in self.filename_to_effective:
            # Name the file in the warning; the message previously contained
            # no placeholder, which made the log entry useless for debugging.
            logger.warning(f"No effective duration for {filename}, using default 5.0s")
            return 5.0
        return self.filename_to_effective[filename]

    def get_category_effective_stats(self, category: str) -> Dict:
        """
        Get effective duration statistics for a category.

        Args:
            category: Category name

        Returns:
            Dict with mean, std, min, max, count (defaults for unknown categories)
        """
        if category not in self.category_effective_stats.index:
            return {'mean': 5.0, 'std': 0.0, 'min': 5.0, 'max': 5.0, 'count': 0}

        stats = self.category_effective_stats.loc[category]
        return {
            'mean': stats['mean'],
            'std': stats['std'],
            'min': stats['min'],
            'max': stats['max'],
            'count': int(stats['count'])
        }

    def get_files_by_category_with_durations(self, category: str) -> List[Dict]:
        """
        Get all files for a category with their effective durations.

        Args:
            category: Category name

        Returns:
            List of dicts with filename, effective_duration_s, filepath,
            raw_duration_s and peak_amplitude_db
        """
        cat_df = self.effective_df[self.effective_df['category'] == category]

        results = []
        for _, row in cat_df.iterrows():
            results.append({
                'filename': row['filename'],
                'effective_duration_s': row['effective_duration_s'],
                # Paths point at the trimmed clips, not the originals.
                'filepath': str(self.trimmed_audio_path / row['filename']),
                'raw_duration_s': row['raw_duration_s'],
                'peak_amplitude_db': row['peak_amplitude_db']
            })

        return results

    def sample_file_from_category_with_duration(
        self,
        category: str,
        min_effective_duration: Optional[float] = None,
        max_effective_duration: Optional[float] = None
    ) -> Tuple[str, str, float]:
        """
        Sample a file from category with optional duration constraints.

        Args:
            category: Category name
            min_effective_duration: Minimum effective duration (optional)
            max_effective_duration: Maximum effective duration (optional)

        Returns:
            Tuple of (filename, filepath, effective_duration_s)
        """
        files = self.get_files_by_category_with_durations(category)

        # Apply optional duration constraints.
        if min_effective_duration is not None:
            files = [f for f in files if f['effective_duration_s'] >= min_effective_duration]
        if max_effective_duration is not None:
            files = [f for f in files if f['effective_duration_s'] <= max_effective_duration]

        if not files:
            # Best-effort fallback: relax the constraints rather than fail.
            logger.warning(f"No files match duration constraints for {category}, using any file")
            files = self.get_files_by_category_with_durations(category)

        selected = random.choice(files)
        return selected['filename'], selected['filepath'], selected['effective_duration_s']

    def sample_files_from_category_to_reach_duration(
        self,
        category: str,
        target_duration_s: float,
        prefer_same_file: bool = True
    ) -> Tuple[List[str], List[str], float]:
        """
        Sample files from a category to reach a target total effective duration.

        Args:
            category: Category name
            target_duration_s: Target total effective duration
            prefer_same_file: If True, try repeating same file first

        Returns:
            Tuple of (filenames_list, filepaths_list, actual_total_duration_s)

        Raises:
            ValueError: If the category has no files.
        """
        files = self.get_files_by_category_with_durations(category)

        if not files:
            raise ValueError(f"No files found for category: {category}")

        selected_filenames = []
        selected_filepaths = []
        total_duration = 0.0

        if prefer_same_file:
            # Repeat the single longest clip; fewest repetitions needed.
            files_sorted = sorted(files, key=lambda x: x['effective_duration_s'], reverse=True)
            selected_file = files_sorted[0]

            # Upper bound on repetitions; the loop breaks as soon as the
            # target is reached, so over-estimating is harmless.
            reps_needed = max(1, int(target_duration_s / selected_file['effective_duration_s']) + 1)

            for _ in range(reps_needed):
                selected_filenames.append(selected_file['filename'])
                selected_filepaths.append(selected_file['filepath'])
                total_duration += selected_file['effective_duration_s']

                if total_duration >= target_duration_s:
                    break
        else:
            # Cycle through the clips in shuffled order for variety.
            random.shuffle(files)
            file_idx = 0

            while total_duration < target_duration_s:
                selected_file = files[file_idx % len(files)]
                selected_filenames.append(selected_file['filename'])
                selected_filepaths.append(selected_file['filepath'])
                total_duration += selected_file['effective_duration_s']
                file_idx += 1

                # Safety valve against pathological inputs (e.g. all clips
                # with ~0 effective duration).
                if file_idx > 100:
                    logger.warning(f"Hit safety limit when sampling files for {category}")
                    break

        return selected_filenames, selected_filepaths, total_duration

    def get_categories_sorted_by_effective_duration(self, ascending: bool = True) -> List[str]:
        """
        Get categories sorted by their mean effective duration.

        Args:
            ascending: If True, shortest first; if False, longest first

        Returns:
            List of category names sorted by mean effective duration
        """
        sorted_stats = self.category_effective_stats.sort_values('mean', ascending=ascending)
        return sorted_stats.index.tolist()
|
|
|
|
|
|