|
|
""" |
|
|
Main pipeline runner for temporal reasoning audio dataset generation. |
|
|
|
|
|
This script orchestrates the generation of all task datasets. |
|
|
""" |
|
|
|
|
|
import argparse |
|
|
import sys |
|
|
import yaml |
|
|
from pathlib import Path |
|
|
from typing import List, Optional |
|
|
|
|
|
|
|
|
# Add this script's own directory to the END of the import path so the local
# modules imported below (utils, tasks.*) can be resolved. Presumably this is
# for cases where the file is imported or run without its directory already
# on sys.path. Because it is appended, same-named modules found earlier on
# sys.path still take precedence.
sys.path.append(str(Path(__file__).parent))
|
|
|
|
|
from utils import setup_logger, set_random_seed |
|
|
from tasks.task_count import CountTaskGenerator |
|
|
from tasks.task_duration import DurationTaskGenerator |
|
|
from tasks.task_order import OrderTaskGenerator |
|
|
from tasks.task_volume import VolumeTaskGenerator |
|
|
|
|
|
|
|
|
def load_config(config_path: str) -> dict:
    """Load the pipeline configuration from a YAML file.

    Args:
        config_path: Path to the YAML configuration file.

    Returns:
        The parsed configuration mapping.

    Raises:
        OSError: If the file cannot be opened.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the top-level YAML value is not a mapping (e.g. an
            empty file, which ``yaml.safe_load`` parses as ``None``).
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # Callers index config['tasks'], config['output'], ... directly; fail here
    # with a clear message rather than an obscure TypeError/KeyError later.
    if not isinstance(config, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping at the "
            f"top level, got {type(config).__name__}"
        )
    return config
|
|
|
|
|
|
|
|
def run_count_task(config: dict, logger): |
|
|
"""Run the count task generation.""" |
|
|
if not config['tasks']['count']['enabled']: |
|
|
logger.info("Count task is disabled, skipping...") |
|
|
return |
|
|
|
|
|
logger.info("=" * 80) |
|
|
logger.info("STARTING COUNT TASK GENERATION") |
|
|
logger.info("=" * 80) |
|
|
|
|
|
generator = CountTaskGenerator(config, logger) |
|
|
generator.dataset.reset_category_usage() |
|
|
generator.generate_dataset() |
|
|
|
|
|
|
|
|
usage_stats = generator.dataset.get_category_usage_stats() |
|
|
sorted_stats = sorted(usage_stats.items(), key=lambda x: x[1], reverse=True) |
|
|
logger.info("Category usage statistics (as answers):") |
|
|
logger.info(f" Min usage: {sorted_stats[-1][1]} (category: {sorted_stats[-1][0]})") |
|
|
logger.info(f" Max usage: {sorted_stats[0][1]} (category: {sorted_stats[0][0]})") |
|
|
logger.info(f" Mean usage: {sum(usage_stats.values()) / len(usage_stats):.2f}") |
|
|
|
|
|
logger.info("Count task completed successfully!") |
|
|
|
|
|
|
|
|
def run_duration_task(config: dict, logger): |
|
|
"""Run the duration task generation.""" |
|
|
if not config['tasks']['duration']['enabled']: |
|
|
logger.info("Duration task is disabled, skipping...") |
|
|
return |
|
|
|
|
|
logger.info("=" * 80) |
|
|
logger.info("STARTING DURATION TASK GENERATION") |
|
|
logger.info("=" * 80) |
|
|
|
|
|
generator = DurationTaskGenerator(config, logger) |
|
|
generator.dataset.reset_category_usage() |
|
|
generator.generate_dataset() |
|
|
|
|
|
|
|
|
usage_stats = generator.dataset.get_category_usage_stats() |
|
|
sorted_stats = sorted(usage_stats.items(), key=lambda x: x[1], reverse=True) |
|
|
logger.info("Category usage statistics (as longest/shortest answers):") |
|
|
logger.info(f" Min usage: {sorted_stats[-1][1]} (category: {sorted_stats[-1][0]})") |
|
|
logger.info(f" Max usage: {sorted_stats[0][1]} (category: {sorted_stats[0][0]})") |
|
|
logger.info(f" Mean usage: {sum(usage_stats.values()) / len(usage_stats):.2f}") |
|
|
|
|
|
logger.info("Duration task completed successfully!") |
|
|
|
|
|
|
|
|
def run_order_task(config: dict, logger): |
|
|
"""Run the order task generation.""" |
|
|
if not config['tasks']['order']['enabled']: |
|
|
logger.info("Order task is disabled, skipping...") |
|
|
return |
|
|
|
|
|
logger.info("=" * 80) |
|
|
logger.info("STARTING ORDER TASK GENERATION") |
|
|
logger.info("=" * 80) |
|
|
|
|
|
generator = OrderTaskGenerator(config, logger) |
|
|
generator.dataset.reset_category_usage() |
|
|
generator.generate_dataset() |
|
|
|
|
|
|
|
|
usage_stats = generator.dataset.get_category_usage_stats() |
|
|
sorted_stats = sorted(usage_stats.items(), key=lambda x: x[1], reverse=True) |
|
|
logger.info("Category usage statistics (as first/last/after/before answers):") |
|
|
logger.info(f" Min usage: {sorted_stats[-1][1]} (category: {sorted_stats[-1][0]})") |
|
|
logger.info(f" Max usage: {sorted_stats[0][1]} (category: {sorted_stats[0][0]})") |
|
|
logger.info(f" Mean usage: {sum(usage_stats.values()) / len(usage_stats):.2f}") |
|
|
|
|
|
logger.info("Order task completed successfully!") |
|
|
|
|
|
|
|
|
def run_volume_task(config: dict, logger): |
|
|
"""Run the volume task generation.""" |
|
|
if not config['tasks']['volume']['enabled']: |
|
|
logger.info("Volume task is disabled, skipping...") |
|
|
return |
|
|
|
|
|
logger.info("=" * 80) |
|
|
logger.info("STARTING VOLUME TASK GENERATION") |
|
|
logger.info("=" * 80) |
|
|
|
|
|
generator = VolumeTaskGenerator(config, logger) |
|
|
generator.dataset.reset_category_usage() |
|
|
generator.generate_dataset() |
|
|
|
|
|
|
|
|
usage_stats = generator.dataset.get_category_usage_stats() |
|
|
sorted_stats = sorted(usage_stats.items(), key=lambda x: x[1], reverse=True) |
|
|
logger.info("Category usage statistics (as loudest/softest answers):") |
|
|
logger.info(f" Min usage: {sorted_stats[-1][1]} (category: {sorted_stats[-1][0]})") |
|
|
logger.info(f" Max usage: {sorted_stats[0][1]} (category: {sorted_stats[0][0]})") |
|
|
logger.info(f" Mean usage: {sum(usage_stats.values()) / len(usage_stats):.2f}") |
|
|
|
|
|
logger.info("Volume task completed successfully!") |
|
|
|
|
|
|
|
|
def run_pipeline(
    config_path: str,
    tasks: Optional[List[str]] = None,
    output_path: Optional[str] = None
):
    """
    Run the complete dataset generation pipeline.

    The steps are, in order:

    1. Load the config and apply the optional output override.
    2. Create the output directory.
    3. Seed the RNG via ``set_random_seed``.
    4. Set up the ``'pipeline'`` logger, whose log file lives inside the
       output directory.
    5. Run the selected tasks in the fixed order count, duration, order,
       volume. Each task still honours its own ``tasks.<name>.enabled`` flag.

    Args:
        config_path: Path to configuration YAML file
        tasks: Optional list of specific tasks to run (default: all enabled
            tasks). Names that are not known tasks are ignored with a warning.
        output_path: Optional custom output path (overrides config)

    Raises:
        Exception: Any error raised by a task is logged with its traceback
            and re-raised; remaining tasks are not run.
    """
    config = load_config(config_path)

    # Caller/CLI override takes precedence over the config file.
    if output_path:
        config['output']['base_path'] = output_path

    output_base = Path(config['output']['base_path'])
    output_base.mkdir(parents=True, exist_ok=True)

    set_random_seed(config['random_seed'])

    logger = setup_logger(
        'pipeline',
        log_file=str(output_base / config['logging']['log_file']),
        level=config['logging']['level'],
        console_output=config['logging']['console_output']
    )

    logger.info("=" * 80)
    logger.info("TEMPORAL REASONING AUDIO DATASET GENERATION PIPELINE")
    logger.info("=" * 80)
    logger.info(f"Configuration: {config_path}")
    logger.info(f"Output directory: {output_base}")
    logger.info(f"Random seed: {config['random_seed']}")
    logger.info(f"ESC-50 audio path: {config['esc50']['audio_path']}")
    logger.info(f"ESC-50 metadata path: {config['esc50']['metadata_path']}")

    # Insertion order defines execution order.
    task_map = {
        'count': run_count_task,
        'duration': run_duration_task,
        'order': run_order_task,
        'volume': run_volume_task
    }

    if tasks:
        # argparse restricts CLI choices, but programmatic callers can pass
        # anything; previously unknown names were dropped silently.
        unknown = [name for name in dict.fromkeys(tasks) if name not in task_map]
        if unknown:
            logger.warning(
                f"Ignoring unknown task(s): {', '.join(unknown)} "
                f"(valid tasks: {', '.join(task_map)})"
            )
        tasks_to_run = {k: v for k, v in task_map.items() if k in tasks}
        # Log what will actually run, not the raw request.
        logger.info(f"Running specific tasks: {', '.join(tasks_to_run)}")
    else:
        tasks_to_run = task_map
        logger.info("Running all enabled tasks")

    for task_name, task_func in tasks_to_run.items():
        try:
            task_func(config, logger)
        except Exception as e:
            logger.error(f"Error running {task_name} task: {e}", exc_info=True)
            raise

    logger.info("=" * 80)
    logger.info("PIPELINE COMPLETED SUCCESSFULLY!")
    logger.info("=" * 80)
    logger.info(f"All outputs saved to: {output_base}")
|
|
|
|
|
|
|
|
def main():
    """Main entry point with argument parsing.

    Parses CLI arguments, locates the config file and runs the pipeline.
    The config is looked up as given (relative to the current directory)
    and then relative to this script's directory. On a missing config or
    a pipeline failure, it prints an error to stderr and exits with status 1.
    """
    parser = argparse.ArgumentParser(
        description="Temporal Reasoning Audio Dataset Generation Pipeline",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
# Run all tasks with default config
python main.py

# Run with custom config
python main.py --config my_config.yaml

# Run specific tasks only
python main.py --tasks count duration

# Use custom output directory
python main.py --output /path/to/output

# Combine options
python main.py --config custom.yaml --tasks count order --output ./my_dataset
"""
    )

    parser.add_argument(
        '--config', '-c',
        type=str,
        default='config.yaml',
        help='Path to configuration YAML file (default: config.yaml)'
    )

    parser.add_argument(
        '--tasks', '-t',
        nargs='+',
        choices=['count', 'duration', 'order', 'volume'],
        help='Specific tasks to run (default: all enabled tasks)'
    )

    parser.add_argument(
        '--output', '-o',
        type=str,
        help='Custom output directory (overrides config)'
    )

    args = parser.parse_args()

    # Try the path as given first, then relative to this script's directory
    # so the default config.yaml is found when run from another directory.
    config_path = Path(args.config)
    if not config_path.exists():
        fallback_path = Path(__file__).parent / args.config
        if not fallback_path.exists():
            print(
                f"Error: Config file not found: {args.config} "
                f"(also tried {fallback_path})",
                file=sys.stderr,
            )
            sys.exit(1)
        config_path = fallback_path

    try:
        run_pipeline(
            config_path=str(config_path),
            tasks=args.tasks,
            output_path=args.output
        )
    except Exception as e:
        # Top-level boundary: a short summary for the user. When the failure
        # occurs inside a task, run_pipeline has already logged the traceback.
        print(f"Pipeline failed with error: {e}", file=sys.stderr)
        sys.exit(1)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Only launch the CLI when executed as a script, not when imported.
    main()
|
|
|