|
|
""" |
|
|
Task 1: Count - Generate counting questions |
|
|
|
|
|
This task joins multiple audio sources and asks questions about counting |
|
|
the number of unique sound sources in the audio. |
|
|
""" |
|
|
|
|
|
import csv |
|
|
import random |
|
|
from pathlib import Path |
|
|
from typing import Dict, List |
|
|
|
|
|
import sys |
|
|
sys.path.append(str(Path(__file__).parent.parent)) |
|
|
|
|
|
from utils import ( |
|
|
AudioProcessor, ESC50Dataset, QuestionGenerator, LLMQuestionGenerator, |
|
|
setup_logger, set_random_seed, generate_sample_durations_for_task, |
|
|
generate_single_clip_duration, build_count_task_audio, |
|
|
get_max_clip_num_to_be_joined |
|
|
) |
|
|
|
|
|
|
|
|
class CountTaskGenerator: |
|
|
"""Generator for counting task dataset.""" |
|
|
|
|
|
def __init__(self, config: Dict, logger): |
|
|
""" |
|
|
Initialize count task generator. |
|
|
|
|
|
Args: |
|
|
config: Configuration dictionary |
|
|
logger: Logger instance |
|
|
""" |
|
|
self.config = config |
|
|
self.logger = logger |
|
|
self.task_config = config['tasks']['count'] |
|
|
|
|
|
|
|
|
self.dataset = ESC50Dataset( |
|
|
config['esc50']['metadata_path'], |
|
|
config['esc50']['audio_path'], |
|
|
config |
|
|
) |
|
|
self.audio_processor = AudioProcessor( |
|
|
crossfade_duration=config['audio']['crossfade_duration'], |
|
|
silence_duration=config['audio']['silence_duration'], |
|
|
with_silence=config['audio']['with_silence'], |
|
|
normalize=config['audio']['normalize'], |
|
|
normalize_target_dBFS=config['audio']['normalize_target_dBFS'], |
|
|
synthetic_silence_path=config['synthetic_silence']['path'] |
|
|
) |
|
|
self.question_generator = QuestionGenerator( |
|
|
num_options=config['mcq']['num_options'], |
|
|
option_labels=config['mcq']['option_labels'], |
|
|
distractor_strategy=config['mcq']['distractor_strategy'] |
|
|
) |
|
|
|
|
|
|
|
|
self.llm_enabled = config.get('llm', {}).get('enabled', False) |
|
|
self.llm_generator = LLMQuestionGenerator( |
|
|
enabled=self.llm_enabled, |
|
|
template_questions=self.task_config |
|
|
) |
|
|
if self.llm_enabled: |
|
|
logger.info("LLM question generation enabled (local Llama 3.1 8B)") |
|
|
else: |
|
|
logger.info("Using template-based question generation") |
|
|
|
|
|
|
|
|
self.min_clip_duration = config['audio']['min_clip_duration'] |
|
|
self.max_clip_duration = config['audio']['max_clip_duration'] |
|
|
self.source_clip_duration = config['audio'].get('source_clip_duration', 5.0) |
|
|
self.min_silence_ms = config['audio'].get('min_silence_duration', 100) |
|
|
self.max_extra_silence_per_gap_ms = config['audio'].get('max_extra_silence_per_gap', 500) |
|
|
|
|
|
self.crossfade_within_source_ms = config['audio'].get('crossfade_within_source', 50) |
|
|
self.task_duration_hours = self.task_config['task_duration_size'] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
self.ordering_mode = self.task_config.get('ordering_mode', 'random') |
|
|
logger.info(f"Count task ordering mode: {self.ordering_mode}") |
|
|
|
|
|
|
|
|
self.output_base = Path(config['output']['base_path']) / 'count' |
|
|
self.output_base.mkdir(parents=True, exist_ok=True) |
|
|
self.audio_output = self.output_base / 'audios' |
|
|
self.audio_output.mkdir(parents=True, exist_ok=True) |
|
|
|
|
|
def create_sampling_list(self, parent_list: List, n_sampling: int) -> List: |
|
|
""" |
|
|
Sample elements from parent list with replacement. |
|
|
|
|
|
Args: |
|
|
parent_list: List to sample from |
|
|
n_sampling: Number of samples |
|
|
|
|
|
Returns: |
|
|
List of sampled elements |
|
|
""" |
|
|
return [random.choice(parent_list) for _ in range(n_sampling)] |
|
|
|
|
|
    def generate_sample(self, sample_id: int, target_unique_count: int = None, target_duration_seconds: float = None) -> Dict:
        """
        Generate a single count task sample.

        Pipeline for COUNT task:
        1. Use pre-generated target duration (or generate if not provided)
        2. Calculate max clips that can fit
        3. Pick N unique classes (N <= max_clips, since each source needs at least 1 clip)
        4. For each class, sample one audio clip
        5. Calculate repetitions to fill target duration
        6. Based on ordering_mode:
           - "random": Shuffle clips (A B A C B A C) - tests recognition
           - "consecutive": Group same-class (AAA BBB CCC) - easier
        7. Insert silences between clips
        8. Distribute remainder as random extra silences

        Args:
            sample_id: Sample ID number (also used as the output wav filename)
            target_unique_count: Target number of unique sounds (for balanced distribution);
                None means pick uniformly at random up to capacity
            target_duration_seconds: Pre-generated target duration (from generate_sample_durations_for_task)

        Returns:
            Dictionary with sample metadata (audio path, questions, answers, provenance)

        Raises:
            ValueError: If the target duration cannot fit even a single clip.
        """
        # Step 1: resolve the overall target duration for this sample.
        if target_duration_seconds is not None:
            clip_duration_seconds = target_duration_seconds
        else:
            clip_duration_seconds = generate_single_clip_duration(
                self.min_clip_duration,
                self.max_clip_duration
            )

        # Step 2: capacity check — how many source clips (with minimum
        # silences between them) fit in the target duration.
        # remainder_seconds is intentionally unused here: leftover time is
        # distributed as extra silence inside build_count_task_audio.
        max_clips, remainder_seconds = get_max_clip_num_to_be_joined(
            clip_duration_seconds,
            self.source_clip_duration,
            self.min_silence_ms
        )

        # Guarantee at least one clip so very short targets still produce audio.
        max_clips = max(1, max_clips)

        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)

        # Upper bound on distinct classes: limited by capacity, by the
        # per-sample config cap, and by how many categories the dataset has.
        max_unique_for_sample = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))

        if max_unique_for_sample < 1:
            raise ValueError(
                f"Sample {sample_id}: Cannot generate sample - max_unique_for_sample={max_unique_for_sample}. "
                f"max_clips={max_clips}, max_clips_per_sample={max_clips_per_sample}, "
                f"available_categories={len(self.dataset.CATEGORIES)}, duration={clip_duration_seconds:.1f}s. "
                f"Increase min_clip_duration or reduce max_clips_per_sample."
            )

        # Step 3: decide how many unique sound classes this sample contains.
        if target_unique_count is not None:
            # Balanced-distribution mode: honor the requested answer where
            # possible, clamping to what the duration can actually hold.
            n_unique_audios = min(target_unique_count, max_unique_for_sample)

            if n_unique_audios != target_unique_count:
                self.logger.debug(
                    f"Sample {sample_id}: Clamped target from {target_unique_count} to {n_unique_audios} "
                    f"(duration={clip_duration_seconds:.1f}s can only fit {max_clips} clips)"
                )
        else:
            # Unconstrained mode: uniform random count up to capacity.
            n_unique_audios = random.randint(1, max_unique_for_sample)

        self.logger.debug(
            f"Sample {sample_id}: target={clip_duration_seconds:.1f}s, max_clips={max_clips}, "
            f"n_unique_audios={n_unique_audios}"
        )

        # Prefer under-used categories to keep coverage even across the dataset.
        selected_categories = self.dataset.get_least_used_categories(n_unique_audios)

        # Record usage so future samples rotate through other categories.
        for cat in selected_categories:
            self.dataset.category_usage_counts[cat] += 1

        # Step 4: pick one concrete audio file per selected category.
        source_files = []
        source_paths = []
        source_categories = []

        for category in selected_categories:
            filename, filepath = self.dataset.sample_file_from_category(category)
            source_files.append(filename)
            source_paths.append(filepath)
            source_categories.append(category)

        # Load each chosen source file into memory.
        source_audios = []
        for file_path in source_paths:
            audio = self.audio_processor.load_audio(file_path)
            source_audios.append(audio)

        # Steps 5-8: repetition, ordering, silences, and final assembly are
        # delegated to build_count_task_audio.
        final_audio, clip_sequence, build_metadata = build_count_task_audio(
            source_audios,
            source_categories,
            clip_duration_seconds,
            ordering_mode=self.ordering_mode,
            source_clip_duration_seconds=self.source_clip_duration,
            min_silence_ms=self.min_silence_ms,
            max_extra_silence_per_gap_ms=self.max_extra_silence_per_gap_ms,
            crossfade_within_source_ms=self.crossfade_within_source_ms
        )

        # Persist the joined audio as <sample_id>.wav.
        output_audio_path = self.audio_output / f"{sample_id}.wav"
        final_audio.export(str(output_audio_path), format="wav")

        # Question text: LLM-phrased when enabled, otherwise random template.
        if self.llm_enabled and self.llm_generator:
            llm_questions = self.llm_generator.generate_count_questions(
                correct_count=n_unique_audios,
                categories_present=list(set(clip_sequence))
            )
            mcq_question_text = llm_questions.get('mcq_question')
            open_text_question_text = llm_questions.get('open_text_question')
        else:
            mcq_question_text = random.choice(self.task_config['mcq_questions'])
            open_text_question_text = random.choice(self.task_config['open_text_questions'])

        # Build the MCQ (options + correct label) around the true count.
        mcq_data = self.question_generator.generate_count_mcq(
            mcq_question_text,
            n_unique_audios,
            self.dataset.CATEGORIES
        )

        # Build the open-text question/answer pair.
        open_text_data = self.question_generator.generate_count_open_text(
            open_text_question_text,
            n_unique_audios
        )

        # Assemble the per-sample metadata record consumed by the CSV writers.
        metadata = {
            'id': sample_id,
            # Path stored relative to the output root so the dataset is relocatable.
            'audio_path': str(output_audio_path.relative_to(self.output_base.parent)),
            'n_unique_sounds': n_unique_audios,
            'total_clips': build_metadata['total_clips'],
            'repetitions_per_source': build_metadata['repetitions_per_source'],
            'ordering_mode': self.ordering_mode,
            'source_files': source_files,
            'source_categories': source_categories,
            'clip_sequence': clip_sequence,
            'unique_categories': sorted(list(set(source_categories))),
            'target_duration_seconds': clip_duration_seconds,
            # final_audio length is in milliseconds (pydub convention — confirm).
            'actual_duration_seconds': len(final_audio) / 1000.0,
            'mcq_question': mcq_data['question'],
            'mcq_options': mcq_data['options'],
            'mcq_correct_answer': mcq_data['correct_answer'],
            'open_text_question': open_text_data['question'],
            'open_text_answer': open_text_data['correct_answer'],
            'llm_generated': self.llm_enabled
        }

        self.logger.info(
            f"Generated count sample {sample_id}: {n_unique_audios} unique sounds, "
            f"{build_metadata['total_clips']} clips, {len(final_audio)/1000:.1f}s"
        )

        return metadata
|
|
|
|
|
    def generate_dataset(self) -> tuple:
        """
        Generate the complete count task dataset.

        Draws per-sample durations to hit the configured task-hours budget,
        assigns each sample a balanced target answer (capacity-aware), then
        generates every sample and writes the MCQ, open-text, and metadata CSVs.

        Returns:
            Tuple of (mcq_csv_path, open_text_csv_path)
        """
        # Pre-draw one target duration per sample so the total approximates
        # the configured number of task hours.
        sample_durations = generate_sample_durations_for_task(
            self.task_duration_hours,
            self.min_clip_duration,
            self.max_clip_duration
        )
        num_samples = len(sample_durations)
        self.logger.info(f"Generating {num_samples} count task samples (target: {self.task_duration_hours}h, actual: {sum(sample_durations)/3600:.2f}h)...")

        # Compute each sample's capacity: the max number of unique sounds its
        # duration can hold (bounded by config cap and category count).
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        sample_max_clips = []
        for duration in sample_durations:
            max_clips, _ = get_max_clip_num_to_be_joined(
                duration,
                self.source_clip_duration,
                self.min_silence_ms
            )
            max_for_sample = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))
            sample_max_clips.append(max_for_sample)

        # Target answers 1..max, split as evenly as num_samples allows.
        possible_answers = list(range(1, max_clips_per_sample + 1))
        samples_per_answer = num_samples // len(possible_answers)
        remainder = num_samples % len(possible_answers)

        # (original index, duration, capacity) per sample.
        sample_info = [(i, sample_durations[i], sample_max_clips[i]) for i in range(num_samples)]

        # Sort by capacity descending so high-capacity samples are available
        # for the large target answers.
        sample_info.sort(key=lambda x: x[2], reverse=True)

        balanced_assignments = [None] * num_samples
        assignment_pool = []

        # Build the pool of target answers; the first `remainder` answers get
        # one extra slot each. Pool size ends up exactly num_samples.
        for answer in possible_answers:
            count = samples_per_answer + (1 if remainder > 0 else 0)
            assignment_pool.extend([answer] * count)
            remainder = max(0, remainder - 1)

        # Largest targets first, matched against largest-capacity samples.
        assignment_pool.sort(reverse=True)

        for idx, (sample_idx, duration, capacity) in enumerate(sample_info):
            # Clamp to capacity: a short sample can't host a large answer.
            target = min(assignment_pool[idx], capacity)
            balanced_assignments[sample_idx] = target

        # Log the realized answer distribution (may deviate from perfectly
        # balanced after capacity clamping).
        from collections import Counter
        distribution = Counter(balanced_assignments)
        self.logger.info(f"Balanced answer distribution (after capacity-aware assignment): {dict(sorted(distribution.items()))}")

        all_metadata = []

        # Generate every sample with its pre-assigned answer and duration.
        for i in range(num_samples):
            metadata = self.generate_sample(
                i,
                target_unique_count=balanced_assignments[i],
                target_duration_seconds=sample_durations[i]
            )
            all_metadata.append(metadata)

        # Write the three output CSVs.
        mcq_csv_path = self.output_base / 'count_mcq.csv'
        self._save_mcq_csv(all_metadata, mcq_csv_path)

        open_text_csv_path = self.output_base / 'count_open_text.csv'
        self._save_open_text_csv(all_metadata, open_text_csv_path)

        metadata_csv_path = self.output_base / 'count_metadata.csv'
        self._save_metadata_csv(all_metadata, metadata_csv_path)

        self.logger.info(f"Count task dataset generation complete!")
        self.logger.info(f" - MCQ CSV: {mcq_csv_path}")
        self.logger.info(f" - Open-text CSV: {open_text_csv_path}")
        self.logger.info(f" - Metadata CSV: {metadata_csv_path}")
        self.logger.info(f" - Audio files: {self.audio_output}")

        return mcq_csv_path, open_text_csv_path
|
|
|
|
|
def _save_mcq_csv(self, metadata_list: List[Dict], output_path: Path): |
|
|
"""Save MCQ format CSV.""" |
|
|
with open(output_path, 'w', newline='') as f: |
|
|
writer = csv.writer(f) |
|
|
|
|
|
writer.writerow([ |
|
|
'question', 'id', 'audio_path', |
|
|
'optionA', 'optionB', 'optionC', 'optionD', |
|
|
'correct', 'source_wavs', 'source_categories' |
|
|
]) |
|
|
|
|
|
|
|
|
for meta in metadata_list: |
|
|
writer.writerow([ |
|
|
meta['mcq_question'], |
|
|
meta['id'], |
|
|
meta['audio_path'], |
|
|
meta['mcq_options']['A'], |
|
|
meta['mcq_options']['B'], |
|
|
meta['mcq_options']['C'], |
|
|
meta['mcq_options']['D'], |
|
|
meta['mcq_correct_answer'], |
|
|
str(meta['source_files']), |
|
|
str(meta['unique_categories']) |
|
|
]) |
|
|
|
|
|
def _save_open_text_csv(self, metadata_list: List[Dict], output_path: Path): |
|
|
"""Save open-text format CSV.""" |
|
|
with open(output_path, 'w', newline='') as f: |
|
|
writer = csv.writer(f) |
|
|
|
|
|
writer.writerow([ |
|
|
'question', 'id', 'audio_path', 'answer', |
|
|
'source_wavs', 'source_categories' |
|
|
]) |
|
|
|
|
|
|
|
|
for meta in metadata_list: |
|
|
writer.writerow([ |
|
|
meta['open_text_question'], |
|
|
meta['id'], |
|
|
meta['audio_path'], |
|
|
meta['open_text_answer'], |
|
|
str(meta['source_files']), |
|
|
str(meta['unique_categories']) |
|
|
]) |
|
|
|
|
|
def _save_metadata_csv(self, metadata_list: List[Dict], output_path: Path): |
|
|
"""Save detailed metadata CSV.""" |
|
|
with open(output_path, 'w', newline='') as f: |
|
|
writer = csv.writer(f) |
|
|
|
|
|
writer.writerow([ |
|
|
'id', 'audio_path', 'total_clips', 'n_unique_sounds', |
|
|
'source_files', 'source_categories', 'unique_categories', |
|
|
'ordering_mode', 'target_duration_s', 'actual_duration_s', 'llm_generated' |
|
|
]) |
|
|
|
|
|
|
|
|
for meta in metadata_list: |
|
|
writer.writerow([ |
|
|
meta['id'], |
|
|
meta['audio_path'], |
|
|
meta['total_clips'], |
|
|
meta['n_unique_sounds'], |
|
|
str(meta['source_files']), |
|
|
str(meta['source_categories']), |
|
|
str(meta['unique_categories']), |
|
|
meta.get('ordering_mode', 'random'), |
|
|
meta.get('target_duration_seconds', 0), |
|
|
meta.get('actual_duration_seconds', 0), |
|
|
meta.get('llm_generated', False) |
|
|
]) |
|
|
|
|
|
|
|
|
def main(config_path: str = None):
    """
    Main entry point for count task generation.

    Loads the YAML config, seeds the RNG for reproducibility, sets up
    logging, and runs the full dataset generation.

    Args:
        config_path: Path to a YAML config file; defaults to config.yaml
            two directories above this script.
    """
    import yaml

    # Fall back to the repo-level config next to the package root.
    cfg_file = config_path if config_path is not None else Path(__file__).parent.parent / 'config.yaml'

    with open(cfg_file, 'r') as f:
        config = yaml.safe_load(f)

    # Seed all random sampling decisions for reproducible datasets.
    set_random_seed(config['random_seed'])

    logger = setup_logger(
        'count_task',
        log_file=str(Path(config['output']['base_path']) / config['logging']['log_file']),
        level=config['logging']['level'],
        console_output=config['logging']['console_output']
    )

    generator = CountTaskGenerator(config, logger)
    generator.generate_dataset()
|
|
|
|
|
|
|
|
# Allow running this module directly as a script.
if __name__ == '__main__':
    main()
|
|
|