"""
Task 3: Order - Generate temporal ordering questions

This task joins multiple audio sources and asks questions about their temporal order
(first, last, second, second-to-last, what comes after, what comes before).
"""
|
|
|
|
|
import csv
import math
import random
import sys
from collections import Counter
from pathlib import Path
from typing import Dict, List, Optional, Tuple

# Add the parent of this file's directory to the import path so that the
# `utils` import below resolves when this file is run directly.
sys.path.append(str(Path(__file__).parent.parent))

from utils import (
    AudioProcessor, ESC50Dataset, QuestionGenerator, LLMQuestionGenerator,
    setup_logger, set_random_seed, calculate_num_samples_for_task,
    generate_single_clip_duration, get_max_clip_num_to_be_joined,
    build_clip_sequence_with_silences, generate_sample_durations_for_task
)
|
|
|
|
|
|
|
|
class OrderTaskGenerator:
    """Generator for temporal ordering task dataset.

    Each sample joins several source clips into one audio file and produces
    three questions about it: a multiple-choice question, an open-text
    question (both about one position in the sequence) and an open-text
    question about the full sequence.
    """

    # Question types that are only offered once a sample has at least
    # `min_clips_for_second` clips.
    _ADVANCED_QUESTION_TYPES = ('second', 'second_last')
    # Question types usable on any sample with at least two clips.
    _BASIC_QUESTION_TYPES = ('first', 'last', 'after', 'before')

    def __init__(self, config: Dict, logger):
        """
        Initialize order task generator.

        Creates the output directories ``<base_path>/order`` and
        ``<base_path>/order/audios`` if they do not exist.

        Args:
            config: Configuration dictionary
            logger: Logger instance
        """
        self.config = config
        self.logger = logger
        self.task_config = config['tasks']['order']

        audio_config = config['audio']
        mcq_config = config['mcq']

        self.dataset = ESC50Dataset(
            config['esc50']['metadata_path'],
            config['esc50']['audio_path'],
            config
        )
        self.audio_processor = AudioProcessor(
            crossfade_duration=audio_config['crossfade_duration'],
            silence_duration=audio_config['silence_duration'],
            with_silence=audio_config['with_silence'],
            normalize=audio_config['normalize'],
            normalize_target_dBFS=audio_config['normalize_target_dBFS'],
            synthetic_silence_path=config['synthetic_silence']['path']
        )
        self.question_generator = QuestionGenerator(
            num_options=mcq_config['num_options'],
            option_labels=mcq_config['option_labels'],
            distractor_strategy=mcq_config['distractor_strategy']
        )

        self.llm_enabled = config.get('llm', {}).get('enabled', False)
        self.llm_generator = LLMQuestionGenerator(
            enabled=self.llm_enabled,
            template_questions=self.task_config
        )

        # Duration / silence parameters used when assembling a sample.
        self.min_clip_duration = audio_config['min_clip_duration']
        self.max_clip_duration = audio_config['max_clip_duration']
        self.source_clip_duration = audio_config.get('source_clip_duration', 5.0)
        self.min_silence_ms = audio_config.get('min_silence_duration', 100)
        self.max_extra_silence_per_gap_ms = audio_config.get('max_extra_silence_per_gap', 500)
        self.crossfade_ms = audio_config.get('crossfade_duration', 0)
        self.task_duration_hours = self.task_config['task_duration_size']

        self.allow_source_repetition = self.task_config.get('allow_source_repetition', False)
        self.min_clips_for_second = self.task_config.get('min_clips_for_second_questions', 4)

        self.output_base = Path(config['output']['base_path']) / 'order'
        self.output_base.mkdir(parents=True, exist_ok=True)
        self.audio_output = self.output_base / 'audios'
        self.audio_output.mkdir(parents=True, exist_ok=True)

    def _get_valid_question_types(self, n_clips: int) -> List[str]:
        """
        Get question types valid for the given number of clips.

        "second" and "second_last" require at least min_clips_for_second clips;
        "after" and "before" require at least 2 clips; every other configured
        type is always accepted.

        Args:
            n_clips: Number of clips in the sample

        Returns:
            List of valid question types, in configuration order. Falls back to
            ['first', 'last'] when no configured type qualifies.
        """
        valid_types = []
        for qtype in self.task_config['question_types']:
            if qtype in self._ADVANCED_QUESTION_TYPES:
                allowed = n_clips >= self.min_clips_for_second
            elif qtype in ('after', 'before'):
                allowed = n_clips >= 2
            else:
                allowed = True
            if allowed:
                valid_types.append(qtype)

        return valid_types if valid_types else ['first', 'last']

    def _clip_count_bounds(self, sample_id: int, max_clips: int, clip_duration_seconds: float,
                           target_question_type: Optional[str] = None) -> Tuple[int, int]:
        """
        Compute the inclusive range from which the clip count of a sample is drawn.

        Args:
            sample_id: Sample ID number (used in error messages only)
            max_clips: Maximum number of clips reported for the target duration
            clip_duration_seconds: Target duration of the sample (error messages only)
            target_question_type: Question type pre-assigned to this sample, if any

        Returns:
            Tuple (min_clips, max_clips) with 2 <= min_clips <= max_clips

        Raises:
            ValueError: If fewer than 2 clips are possible, or the lower bound
                exceeds the capped upper bound.
        """
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)

        min_clips_for_sample = max(2, max_clips - 3)
        max_clips_for_sample = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))

        if max_clips_for_sample < 2:
            raise ValueError(
                f"Sample {sample_id}: Cannot generate order task - need at least 2 clips. "
                f"max_clips={max_clips}, max_clips_per_sample={max_clips_per_sample}, "
                f"duration={clip_duration_seconds:.1f}s. Increase min_clip_duration."
            )

        if min_clips_for_sample > max_clips_for_sample:
            raise ValueError(
                f"Sample {sample_id}: Invalid clip range - min_clips ({min_clips_for_sample}) > max_clips ({max_clips_for_sample}). "
                f"max_clips={max_clips}, max_clips_per_sample={max_clips_per_sample}, duration={clip_duration_seconds:.1f}s"
            )

        # A pre-assigned "second"/"second_last" question was validated against
        # the sample's capacity, not against the clip count that is about to be
        # drawn. Raise the lower bound so the draw cannot invalidate it.
        if (target_question_type in self._ADVANCED_QUESTION_TYPES
                and self.min_clips_for_second <= max_clips_for_sample):
            min_clips_for_sample = max(min_clips_for_sample, self.min_clips_for_second)

        return min_clips_for_sample, max_clips_for_sample

    def _pick_answer_position(self, question_type: str, n_clips: int) -> int:
        """
        Return the 0-based index of the clip that answers the question.

        Assumes n_clips >= 2 (guaranteed by _clip_count_bounds). Any question
        type other than first/last/second/second_last/after is handled like
        "before".
        """
        if question_type == 'first':
            return 0
        if question_type == 'last':
            return n_clips - 1
        if question_type == 'second':
            return 1
        if question_type == 'second_last':
            return n_clips - 2
        if question_type == 'after':
            # The answer needs a predecessor, so it cannot be the first clip.
            return random.randint(1, n_clips - 1)
        # "before": the answer needs a successor, so it cannot be the last clip.
        return random.randint(0, n_clips - 2)

    def _select_categories(self, n_clips: int, answer_position: int) -> Tuple[str, List[str]]:
        """
        Choose the answer category and the remaining categories for a sample.

        Returns:
            Tuple (answer_category, selected_categories) where
            selected_categories has n_clips entries and holds answer_category
            at answer_position.
        """
        all_categories = self.dataset.CATEGORIES
        answer_category = self.dataset.get_least_used_categories(1)[0]

        other_categories = list(self.dataset.get_least_used_categories(
            min(n_clips - 1, len(all_categories) - 1),
            exclude=[answer_category]
        ))

        # Pad only if the dataset returned fewer categories than requested.
        # Never pad with the answer category: a second occurrence would make
        # the question ambiguous.
        if len(other_categories) < n_clips - 1:
            fillers = [c for c in all_categories if c != answer_category]
            while len(other_categories) < n_clips - 1:
                other_categories.append(random.choice(fillers))

        selected_categories = other_categories[:n_clips - 1]
        selected_categories.insert(answer_position, answer_category)
        return answer_category, selected_categories

    def _question_templates(self, question_type: str, selected_categories: List[str],
                            answer_position: int) -> Tuple[str, str]:
        """
        Build the MCQ and open-text question strings for a sample.

        "after" questions name the clip preceding the answer (``sound1``);
        "before" questions name the clip following it (``sound2``).

        Returns:
            Tuple (mcq_question, open_text_question)
        """
        mcq_templates = self.task_config['mcq_questions']
        open_text_templates = self.task_config['open_text_questions']

        if question_type in ('first', 'last', 'second', 'second_last'):
            return mcq_templates[question_type], open_text_templates[question_type]

        if question_type == 'after':
            reference_category = selected_categories[answer_position - 1]
            return (
                mcq_templates['after'].format(sound1=reference_category),
                open_text_templates['after'].format(sound1=reference_category),
            )

        reference_category = selected_categories[answer_position + 1]
        return (
            mcq_templates['before'].format(sound2=reference_category),
            open_text_templates['before'].format(sound2=reference_category),
        )

    def generate_sample(self, sample_id: int, target_question_type: Optional[str] = None,
                        target_duration_seconds: Optional[float] = None) -> Dict:
        """
        Generate a single order task sample.

        Pipeline: fix the target duration -> derive how many clips fit ->
        draw the clip count and question type -> pick the answer position and
        the categories -> load one file per category -> join the clips with
        silences -> export the WAV -> build the questions.

        Args:
            sample_id: Sample ID number
            target_question_type: Target question type for balanced distribution
            target_duration_seconds: Pre-generated target duration (from generate_sample_durations_for_task)

        Returns:
            Dictionary with sample metadata

        Raises:
            ValueError: If the duration cannot hold a valid number of clips, or
                target_question_type is not valid for the drawn clip count.
        """
        if target_duration_seconds is not None:
            clip_duration_seconds = target_duration_seconds
        else:
            clip_duration_seconds = generate_single_clip_duration(
                self.min_clip_duration,
                self.max_clip_duration
            )

        # Second return value (named remainder_seconds by the original code)
        # is not needed here.
        max_clips, _ = get_max_clip_num_to_be_joined(
            clip_duration_seconds,
            self.source_clip_duration,
            self.min_silence_ms
        )

        min_clips_for_sample, max_clips_for_sample = self._clip_count_bounds(
            sample_id, max_clips, clip_duration_seconds, target_question_type
        )
        n_clips = random.randint(min_clips_for_sample, max_clips_for_sample)

        valid_question_types = self._get_valid_question_types(n_clips)

        if target_question_type is not None:
            if target_question_type not in valid_question_types:
                raise ValueError(
                    f"Sample {sample_id}: target_question_type='{target_question_type}' not valid for n_clips={n_clips}. "
                    f"Valid types: {valid_question_types}. Balanced distribution should only assign valid types."
                )
            question_type = target_question_type
        else:
            question_type = random.choice(valid_question_types)

        answer_position = self._pick_answer_position(question_type, n_clips)
        correct_category, selected_categories = self._select_categories(n_clips, answer_position)

        # Only the answer category counts towards usage balancing.
        self.dataset.category_usage_counts[correct_category] += 1

        audio_segments = []
        filenames_list = []
        for category in selected_categories:
            filename, filepath = self.dataset.sample_file_from_category(category)
            audio_segments.append(self.audio_processor.load_audio(filepath))
            filenames_list.append(filename)

        output_audio_path = self.audio_output / f"{sample_id}.wav"
        final_audio = build_clip_sequence_with_silences(
            audio_segments,
            clip_duration_seconds,
            min_silence_ms=self.min_silence_ms,
            max_extra_silence_per_gap_ms=self.max_extra_silence_per_gap_ms,
            crossfade_ms=self.crossfade_ms
        )
        final_audio.export(str(output_audio_path), format="wav")

        mcq_question, open_text_question = self._question_templates(
            question_type, selected_categories, answer_position
        )

        mcq_data = self.question_generator.generate_category_mcq(
            mcq_question,
            correct_category,
            selected_categories,
            self.dataset.CATEGORIES
        )
        open_text_data = self.question_generator.generate_category_open_text(
            open_text_question,
            correct_category
        )
        sequence_data = self.question_generator.generate_sequence_open_text(
            self.task_config['open_text_questions']['sequence'],
            selected_categories
        )

        metadata = {
            'id': sample_id,
            'audio_path': str(output_audio_path.relative_to(self.output_base.parent)),
            'n_clips': n_clips,
            'question_type': question_type,
            'audio_sequence': selected_categories,
            'correct_answer_category': correct_category,
            'source_files': filenames_list,
            'mcq_question': mcq_data['question'],
            'mcq_options': mcq_data['options'],
            'mcq_correct_answer': mcq_data['correct_answer'],
            'open_text_question': open_text_data['question'],
            'open_text_answer': open_text_data['correct_answer'],
            'sequence_question': sequence_data['question'],
            'sequence_answer': sequence_data['correct_answer']
        }

        self.logger.info(f"Generated order sample {sample_id}: {question_type}, {n_clips} clips")

        return metadata

    def _assign_question_types(self, sample_effective_max_clips: List[int]) -> List[str]:
        """
        Assign one question type per sample, balanced over the configured types.

        The pool holds every configured type an equal number of times (the
        first ``remainder`` types get one extra), with "second"/"second_last"
        first so that they go to the samples with the largest clip capacity.
        A type that is not valid for a sample's capacity is replaced by a
        random valid one.

        Args:
            sample_effective_max_clips: Per-sample upper bound on the clip count

        Returns:
            List of question types, index-aligned with the input

        Raises:
            ValueError: If no question types are configured.
        """
        num_samples = len(sample_effective_max_clips)

        configured = list(dict.fromkeys(self.task_config['question_types']))
        if not configured:
            raise ValueError("tasks.order.question_types must list at least one question type")

        # Build the pool from the configured types only, so that its length
        # always equals num_samples.
        known = self._ADVANCED_QUESTION_TYPES + self._BASIC_QUESTION_TYPES
        ordered_types = (
            [t for t in self._ADVANCED_QUESTION_TYPES if t in configured]
            + [t for t in self._BASIC_QUESTION_TYPES if t in configured]
            + [t for t in configured if t not in known]
        )

        samples_per_type, remainder = divmod(num_samples, len(ordered_types))
        assignment_pool = []
        for position, qtype in enumerate(ordered_types):
            count = samples_per_type + (1 if position < remainder else 0)
            assignment_pool.extend([qtype] * count)

        # Highest capacity first; the sort is stable, so ties keep sample order.
        by_capacity = sorted(
            range(num_samples),
            key=lambda i: sample_effective_max_clips[i],
            reverse=True
        )

        balanced_assignments = [None] * num_samples
        for target_qtype, sample_idx in zip(assignment_pool, by_capacity):
            valid_types = self._get_valid_question_types(sample_effective_max_clips[sample_idx])

            if target_qtype not in valid_types:
                basic_valid = [t for t in self._BASIC_QUESTION_TYPES if t in valid_types]
                if target_qtype in self._ADVANCED_QUESTION_TYPES and basic_valid:
                    target_qtype = random.choice(basic_valid)
                else:
                    target_qtype = random.choice(valid_types)

            balanced_assignments[sample_idx] = target_qtype

        return balanced_assignments

    def generate_dataset(self) -> Tuple[Path, Path, Path]:
        """
        Generate the complete order task dataset.

        Sample durations are pre-generated with
        generate_sample_durations_for_task() from the target task duration and
        the min/max clip durations; one sample is generated per duration.
        Question types are assigned up front (see _assign_question_types).

        Writes the audio files plus four CSV files (MCQ, open-text, sequence,
        metadata) under the output directory.

        Returns:
            Tuple of (mcq_csv_path, open_text_csv_path, sequence_csv_path)
        """
        sample_durations = generate_sample_durations_for_task(
            self.task_duration_hours,
            self.min_clip_duration,
            self.max_clip_duration
        )
        num_samples = len(sample_durations)

        self.logger.info(f"Generating {num_samples} order task samples (target: {self.task_duration_hours}h, exact fill)...")

        # Upper bound on the clip count of each sample, mirroring the cap
        # applied inside generate_sample().
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        sample_effective_max_clips = []
        for duration in sample_durations:
            max_clips, _ = get_max_clip_num_to_be_joined(
                duration,
                self.source_clip_duration,
                self.min_silence_ms
            )
            sample_effective_max_clips.append(
                min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))
            )

        balanced_assignments = self._assign_question_types(sample_effective_max_clips)

        type_dist = Counter(balanced_assignments)
        self.logger.info(f"Balanced question type distribution (after capacity-aware assignment): {dict(sorted(type_dist.items()))}")

        all_metadata = [
            self.generate_sample(
                i,
                target_question_type=balanced_assignments[i],
                target_duration_seconds=target_duration
            )
            for i, target_duration in enumerate(sample_durations)
        ]

        mcq_csv_path = self.output_base / 'order_mcq.csv'
        self._save_mcq_csv(all_metadata, mcq_csv_path)

        open_text_csv_path = self.output_base / 'order_open_text.csv'
        self._save_open_text_csv(all_metadata, open_text_csv_path)

        sequence_csv_path = self.output_base / 'order_sequence.csv'
        self._save_sequence_csv(all_metadata, sequence_csv_path)

        metadata_csv_path = self.output_base / 'order_metadata.csv'
        self._save_metadata_csv(all_metadata, metadata_csv_path)

        self.logger.info("Order task dataset generation complete!")
        self.logger.info(f"  - MCQ CSV: {mcq_csv_path}")
        self.logger.info(f"  - Open-text CSV: {open_text_csv_path}")
        self.logger.info(f"  - Sequence CSV: {sequence_csv_path}")
        self.logger.info(f"  - Metadata CSV: {metadata_csv_path}")
        self.logger.info(f"  - Audio files: {self.audio_output}")

        return mcq_csv_path, open_text_csv_path, sequence_csv_path

    @staticmethod
    def _write_csv(output_path: Path, header: List[str], rows: List[List]) -> None:
        """Write a header row followed by data rows to a UTF-8 CSV file."""
        # newline='' lets the csv module control line endings itself.
        with open(output_path, 'w', newline='', encoding='utf-8') as f:
            writer = csv.writer(f)
            writer.writerow(header)
            writer.writerows(rows)

    def _save_mcq_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save MCQ format CSV (one row per sample, options A-D)."""
        header = [
            'question', 'id', 'audio_path',
            'optionA', 'optionB', 'optionC', 'optionD',
            'correct', 'question_type', 'audio_sequence'
        ]
        rows = [
            [
                meta['mcq_question'],
                meta['id'],
                meta['audio_path'],
                meta['mcq_options']['A'],
                meta['mcq_options']['B'],
                meta['mcq_options']['C'],
                meta['mcq_options']['D'],
                meta['mcq_correct_answer'],
                meta['question_type'],
                str(meta['audio_sequence'])
            ]
            for meta in metadata_list
        ]
        self._write_csv(output_path, header, rows)

    def _save_open_text_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save open-text format CSV."""
        header = [
            'question', 'id', 'audio_path', 'answer',
            'question_type', 'audio_sequence'
        ]
        rows = [
            [
                meta['open_text_question'],
                meta['id'],
                meta['audio_path'],
                meta['open_text_answer'],
                meta['question_type'],
                str(meta['audio_sequence'])
            ]
            for meta in metadata_list
        ]
        self._write_csv(output_path, header, rows)

    def _save_sequence_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save sequence question CSV."""
        header = ['question', 'id', 'audio_path', 'answer', 'audio_sequence']
        rows = [
            [
                meta['sequence_question'],
                meta['id'],
                meta['audio_path'],
                meta['sequence_answer'],
                str(meta['audio_sequence'])
            ]
            for meta in metadata_list
        ]
        self._write_csv(output_path, header, rows)

    def _save_metadata_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save detailed metadata CSV."""
        header = [
            'id', 'audio_path', 'n_clips', 'question_type',
            'audio_sequence', 'correct_answer', 'source_files'
        ]
        rows = [
            [
                meta['id'],
                meta['audio_path'],
                meta['n_clips'],
                meta['question_type'],
                str(meta['audio_sequence']),
                meta['correct_answer_category'],
                str(meta['source_files'])
            ]
            for meta in metadata_list
        ]
        self._write_csv(output_path, header, rows)
|
|
|
|
|
|
|
|
def main(config_path: str = None):
    """Main entry point for order task generation.

    Loads the YAML configuration, seeds the random generators, sets up the
    logger and runs OrderTaskGenerator.generate_dataset().

    Args:
        config_path: Path to the YAML config file. Defaults to ``config.yaml``
            in the parent of this file's directory.
    """
    import yaml

    if config_path is None:
        config_path = Path(__file__).parent.parent / 'config.yaml'

    # Explicit encoding so the config is read the same way on every platform.
    with open(config_path, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    set_random_seed(config['random_seed'])

    # The log file lives under the output base path. Create that directory
    # up front in case setup_logger does not (the generator creates it
    # anyway, but only after the logger has been set up).
    output_base = Path(config['output']['base_path'])
    output_base.mkdir(parents=True, exist_ok=True)

    logger = setup_logger(
        'order_task',
        log_file=str(output_base / config['logging']['log_file']),
        level=config['logging']['level'],
        console_output=config['logging']['console_output']
    )

    generator = OrderTaskGenerator(config, logger)
    generator.generate_dataset()


# Run with the default configuration when executed as a script.
if __name__ == '__main__':
    main()
|
|
|