"""
Task 1: Count - Generate counting questions

This task joins multiple audio sources and asks questions about counting
the number of unique sound sources in the audio.
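
Run this module directly to generate the dataset: main() loads config.yaml from
one directory above this script by default, then writes the joined audio files
and MCQ / open-text / metadata CSVs under <output.base_path>/count/.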
"""

import csv
import random
from pathlib import Path
from typing import Dict, List, Optional

import sys
sys.path.append(str(Path(__file__).parent.parent))

from utils import (
    AudioProcessor, ESC50Dataset, QuestionGenerator, LLMQuestionGenerator,
    setup_logger, set_random_seed, generate_sample_durations_for_task,
    generate_single_clip_duration, build_count_task_audio,
    get_max_clip_num_to_be_joined
)


class CountTaskGenerator:
    """Generator for counting task dataset."""
    
    def __init__(self, config: Dict, logger):
        """
        Initialize count task generator.
        
        Args:
            config: Configuration dictionary
            logger: Logger instance
        """
        self.config = config
        self.logger = logger
        self.task_config = config['tasks']['count']
        
        # Initialize components
        self.dataset = ESC50Dataset(
            config['esc50']['metadata_path'],
            config['esc50']['audio_path'],
            config  # Pass config for class subset loading
        )
        self.audio_processor = AudioProcessor(
            crossfade_duration=config['audio']['crossfade_duration'],
            silence_duration=config['audio']['silence_duration'],
            with_silence=config['audio']['with_silence'],
            normalize=config['audio']['normalize'],
            normalize_target_dBFS=config['audio']['normalize_target_dBFS'],
            synthetic_silence_path=config['synthetic_silence']['path']
        )
        self.question_generator = QuestionGenerator(
            num_options=config['mcq']['num_options'],
            option_labels=config['mcq']['option_labels'],
            distractor_strategy=config['mcq']['distractor_strategy']
        )
        
        # Initialize LLM question generator
        self.llm_enabled = config.get('llm', {}).get('enabled', False)
        self.llm_generator = LLMQuestionGenerator(
            enabled=self.llm_enabled,
            template_questions=self.task_config
        )
        if self.llm_enabled:
            logger.info("LLM question generation enabled (local Llama 3.1 8B)")
        else:
            logger.info("Using template-based question generation")
        
        # Duration settings from config
        self.min_clip_duration = config['audio']['min_clip_duration']
        self.max_clip_duration = config['audio']['max_clip_duration']
        self.source_clip_duration = config['audio'].get('source_clip_duration', 5.0)
        self.min_silence_ms = config['audio'].get('min_silence_duration', 100)
        self.max_extra_silence_per_gap_ms = config['audio'].get('max_extra_silence_per_gap', 500)
        # Small crossfade within same-source repetitions (for consecutive mode)
        self.crossfade_within_source_ms = config['audio'].get('crossfade_within_source', 50)
        self.task_duration_hours = self.task_config['task_duration_size']
        
        # Ordering mode: "random" or "consecutive"
        # random: Clips shuffled (A B A C B A C) - tests sound recognition
        # consecutive: Same-source grouped (AAA BBB CCC) - easier
        self.ordering_mode = self.task_config.get('ordering_mode', 'random')
        logger.info(f"Count task ordering mode: {self.ordering_mode}")
        
        # Set up output paths
        self.output_base = Path(config['output']['base_path']) / 'count'
        self.output_base.mkdir(parents=True, exist_ok=True)
        self.audio_output = self.output_base / 'audios'
        self.audio_output.mkdir(parents=True, exist_ok=True)
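        # Resulting layout (paths relative to output.base_path, for reference):
        #   count/audios/<id>.wav          -- joined audio per sample
        #   count/count_mcq.csv            -- written by generate_dataset()
        #   count/count_open_text.csv
        #   count/count_metadata.csv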
        
    def create_sampling_list(self, parent_list: List, n_sampling: int) -> List:
        """
        Sample elements from parent list with replacement.
        
        Args:
            parent_list: List to sample from
            n_sampling: Number of samples
            
        Returns:
            List of sampled elements
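
        Example (sampling is random; the values shown are illustrative):
            create_sampling_list(['dog', 'rain'], 3) -> ['dog', 'rain', 'dog']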
        """
        return [random.choice(parent_list) for _ in range(n_sampling)]
    
    def generate_sample(self, sample_id: int, target_unique_count: Optional[int] = None, target_duration_seconds: Optional[float] = None) -> Dict:
        """
        Generate a single count task sample.
        
        Pipeline for COUNT task:
        1. Use pre-generated target duration (or generate if not provided)
        2. Calculate max clips that can fit
        3. Pick N unique classes (N <= max_clips, since each source needs at least 1 clip)
        4. For each class, sample one audio clip
        5. Calculate repetitions to fill target duration
        6. Based on ordering_mode:
           - "random": Shuffle clips (A B A C B A C) - tests recognition
           - "consecutive": Group same-class (AAA BBB CCC) - easier
        7. Insert silences between clips
        8. Distribute remainder as random extra silences
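
        Example (illustrative numbers): with a ~30 s target duration, 5 s source
        clips, and a 0.1 s minimum gap, roughly 5 clips fit, so n_unique_audios
        is capped at min(5, max_clips_per_sample, number of available categories).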
        
        Args:
            sample_id: Sample ID number
            target_unique_count: Target number of unique sounds (for balanced distribution)
            target_duration_seconds: Pre-generated target duration (from generate_sample_durations_for_task)
            
        Returns:
            Dictionary with sample metadata
        """
        # Use pre-generated duration or generate one (backward compatibility)
        if target_duration_seconds is not None:
            clip_duration_seconds = target_duration_seconds
        else:
            clip_duration_seconds = generate_single_clip_duration(
                self.min_clip_duration,
                self.max_clip_duration
            )
        
        # Calculate max clips that can fit in target duration
        max_clips, remainder_seconds = get_max_clip_num_to_be_joined(
            clip_duration_seconds,
            self.source_clip_duration,
            self.min_silence_ms
        )
        
        # Ensure at least 1 clip
        max_clips = max(1, max_clips)
        
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        
        # Calculate valid range: n_unique_audios can be 1 to max_clips_per_sample,
        # but cannot exceed what physically fits in the duration or the number of available categories
        max_unique_for_sample = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))
        
        if max_unique_for_sample < 1:
            raise ValueError(
                f"Sample {sample_id}: Cannot generate sample - max_unique_for_sample={max_unique_for_sample}. "
                f"max_clips={max_clips}, max_clips_per_sample={max_clips_per_sample}, "
                f"available_categories={len(self.dataset.CATEGORIES)}, duration={clip_duration_seconds:.1f}s. "
                f"Increase min_clip_duration or reduce max_clips_per_sample."
            )
        
        # Determine n_unique_audios - use target from balanced distribution or random
        if target_unique_count is not None:
            # Clamp target to what this specific sample duration can fit
            # Short samples can't fit all possible answers, so we clamp down
            n_unique_audios = min(target_unique_count, max_unique_for_sample)
            
            if n_unique_audios != target_unique_count:
                self.logger.debug(
                    f"Sample {sample_id}: Clamped target from {target_unique_count} to {n_unique_audios} "
                    f"(duration={clip_duration_seconds:.1f}s can only fit {max_clips} clips)"
                )
        else:
            # No target specified - randomly select from valid range
            n_unique_audios = random.randint(1, max_unique_for_sample)
        
        self.logger.debug(
            f"Sample {sample_id}: target={clip_duration_seconds:.1f}s, max_clips={max_clips}, "
            f"n_unique_audios={n_unique_audios}"
        )
        
        # Sample unique categories - use least-used categories for balanced distribution
        selected_categories = self.dataset.get_least_used_categories(n_unique_audios)
        
        # Track usage of all selected categories
        for cat in selected_categories:
            self.dataset.category_usage_counts[cat] += 1
        
        # Sample one file from each unique category
        source_files = []
        source_paths = []
        source_categories = []
        
        for category in selected_categories:
            filename, filepath = self.dataset.sample_file_from_category(category)
            source_files.append(filename)
            source_paths.append(filepath)
            source_categories.append(category)
        
        # Load unique source audios
        source_audios = []
        for file_path in source_paths:
            audio = self.audio_processor.load_audio(file_path)
            source_audios.append(audio)
        
        # Build audio using configured ordering mode
        final_audio, clip_sequence, build_metadata = build_count_task_audio(
            source_audios,
            source_categories,
            clip_duration_seconds,
            ordering_mode=self.ordering_mode,
            source_clip_duration_seconds=self.source_clip_duration,
            min_silence_ms=self.min_silence_ms,
            max_extra_silence_per_gap_ms=self.max_extra_silence_per_gap_ms,
            crossfade_within_source_ms=self.crossfade_within_source_ms
        )
        
        # Save the audio
        output_audio_path = self.audio_output / f"{sample_id}.wav"
        final_audio.export(str(output_audio_path), format="wav")
        
        # Generate questions (using LLM if enabled)
        if self.llm_enabled and self.llm_generator:
            llm_questions = self.llm_generator.generate_count_questions(
                correct_count=n_unique_audios,
                categories_present=list(set(clip_sequence))
            )
            mcq_question_text = llm_questions.get('mcq_question')
            open_text_question_text = llm_questions.get('open_text_question')
        else:
            mcq_question_text = random.choice(self.task_config['mcq_questions'])
            open_text_question_text = random.choice(self.task_config['open_text_questions'])
        
        # Generate MCQ with options
        mcq_data = self.question_generator.generate_count_mcq(
            mcq_question_text,
            n_unique_audios,
            self.dataset.CATEGORIES
        )
        
        # Generate open-text answer
        open_text_data = self.question_generator.generate_count_open_text(
            open_text_question_text,
            n_unique_audios
        )
        
        # Create metadata
        metadata = {
            'id': sample_id,
            'audio_path': str(output_audio_path.relative_to(self.output_base.parent)),
            'n_unique_sounds': n_unique_audios,
            'total_clips': build_metadata['total_clips'],
            'repetitions_per_source': build_metadata['repetitions_per_source'],
            'ordering_mode': self.ordering_mode,
            'source_files': source_files,
            'source_categories': source_categories,
            'clip_sequence': clip_sequence,
            'unique_categories': sorted(list(set(source_categories))),
            'target_duration_seconds': clip_duration_seconds,
            'actual_duration_seconds': len(final_audio) / 1000.0,
            'mcq_question': mcq_data['question'],
            'mcq_options': mcq_data['options'],
            'mcq_correct_answer': mcq_data['correct_answer'],
            'open_text_question': open_text_data['question'],
            'open_text_answer': open_text_data['correct_answer'],
            'llm_generated': self.llm_enabled
        }
        
        self.logger.info(
            f"Generated count sample {sample_id}: {n_unique_audios} unique sounds, "
            f"{build_metadata['total_clips']} clips, {len(final_audio)/1000:.1f}s"
        )
        
        return metadata
    
    def generate_dataset(self) -> tuple:
        """
        Generate the complete count task dataset.
        
        Returns:
            Tuple of (mcq_csv_path, open_text_csv_path)
        """
        # Generate sample durations upfront to exactly fill target duration
        sample_durations = generate_sample_durations_for_task(
            self.task_duration_hours,
            self.min_clip_duration,
            self.max_clip_duration
        )
        num_samples = len(sample_durations)
        self.logger.info(f"Generating {num_samples} count task samples (target: {self.task_duration_hours}h, actual: {sum(sample_durations)/3600:.2f}h)...")
        
        # Calculate max clips each sample can fit based on duration
        max_clips_per_sample = self.task_config.get('max_clips_per_sample', 10)
        sample_max_clips = []
        for duration in sample_durations:
            max_clips, _ = get_max_clip_num_to_be_joined(
                duration,
                self.source_clip_duration,
                self.min_silence_ms
            )
            # Limit to config max and available categories
            max_for_sample = min(max_clips, max_clips_per_sample, len(self.dataset.CATEGORIES))
            sample_max_clips.append(max_for_sample)
        
        # Create balanced distribution by assigning targets based on sample capacity
        # Sort samples by capacity to assign higher targets to samples that can fit them
        possible_answers = list(range(1, max_clips_per_sample + 1))
        samples_per_answer = num_samples // len(possible_answers)
        remainder = num_samples % len(possible_answers)
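        # e.g. (illustrative): 105 samples over answers 1..10 gives 10 samples per
        # answer, with the 5 leftover samples spread one-per-answer over the first
        # 5 answers in the loop below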
        
        # Create list of (sample_idx, duration, max_clips_capacity)
        sample_info = [(i, sample_durations[i], sample_max_clips[i]) for i in range(num_samples)]
        
        # Sort by capacity (descending) - assign high targets to high-capacity samples
        sample_info.sort(key=lambda x: x[2], reverse=True)
        
        # Assign targets: distribute each answer count across samples
        balanced_assignments = [None] * num_samples
        assignment_pool = []
        
        for answer in possible_answers:
            count = samples_per_answer + (1 if remainder > 0 else 0)
            assignment_pool.extend([answer] * count)
            remainder = max(0, remainder - 1)
        
        # Reverse pool so we assign high targets first (to high-capacity samples)
        assignment_pool.sort(reverse=True)
        
        for idx, (sample_idx, duration, capacity) in enumerate(sample_info):
            # Assign target, clamped to sample's capacity
            target = min(assignment_pool[idx], capacity)
            balanced_assignments[sample_idx] = target
        
        # Log the actual distribution after capacity clamping
        from collections import Counter
        distribution = Counter(balanced_assignments)
        self.logger.info(f"Balanced answer distribution (after capacity-aware assignment): {dict(sorted(distribution.items()))}")
        
        all_metadata = []
        
        for i in range(num_samples):
            metadata = self.generate_sample(
                i, 
                target_unique_count=balanced_assignments[i],
                target_duration_seconds=sample_durations[i]
            )
            all_metadata.append(metadata)
        
        # Save MCQ CSV
        mcq_csv_path = self.output_base / 'count_mcq.csv'
        self._save_mcq_csv(all_metadata, mcq_csv_path)
        
        # Save open-text CSV
        open_text_csv_path = self.output_base / 'count_open_text.csv'
        self._save_open_text_csv(all_metadata, open_text_csv_path)
        
        # Save metadata CSV
        metadata_csv_path = self.output_base / 'count_metadata.csv'
        self._save_metadata_csv(all_metadata, metadata_csv_path)
        
        self.logger.info(f"Count task dataset generation complete!")
        self.logger.info(f"  - MCQ CSV: {mcq_csv_path}")
        self.logger.info(f"  - Open-text CSV: {open_text_csv_path}")
        self.logger.info(f"  - Metadata CSV: {metadata_csv_path}")
        self.logger.info(f"  - Audio files: {self.audio_output}")
        
        return mcq_csv_path, open_text_csv_path
    
    def _save_mcq_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save MCQ format CSV."""
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            # Header
            writer.writerow([
                'question', 'id', 'audio_path',
                'optionA', 'optionB', 'optionC', 'optionD',
                'correct', 'source_wavs', 'source_categories'
            ])
            
            # Data rows
            for meta in metadata_list:
                writer.writerow([
                    meta['mcq_question'],
                    meta['id'],
                    meta['audio_path'],
                    meta['mcq_options']['A'],
                    meta['mcq_options']['B'],
                    meta['mcq_options']['C'],
                    meta['mcq_options']['D'],
                    meta['mcq_correct_answer'],
                    str(meta['source_files']),
                    str(meta['unique_categories'])
                ])
    
    def _save_open_text_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save open-text format CSV."""
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            # Header
            writer.writerow([
                'question', 'id', 'audio_path', 'answer',
                'source_wavs', 'source_categories'
            ])
            
            # Data rows
            for meta in metadata_list:
                writer.writerow([
                    meta['open_text_question'],
                    meta['id'],
                    meta['audio_path'],
                    meta['open_text_answer'],
                    str(meta['source_files']),
                    str(meta['unique_categories'])
                ])
    
    def _save_metadata_csv(self, metadata_list: List[Dict], output_path: Path):
        """Save detailed metadata CSV."""
        with open(output_path, 'w', newline='') as f:
            writer = csv.writer(f)
            # Header
            writer.writerow([
                'id', 'audio_path', 'total_clips', 'n_unique_sounds',
                'source_files', 'source_categories', 'unique_categories',
                'ordering_mode', 'target_duration_s', 'actual_duration_s', 'llm_generated'
            ])
            
            # Data rows
            for meta in metadata_list:
                writer.writerow([
                    meta['id'],
                    meta['audio_path'],
                    meta['total_clips'],
                    meta['n_unique_sounds'],
                    str(meta['source_files']),
                    str(meta['source_categories']),
                    str(meta['unique_categories']),
                    meta.get('ordering_mode', 'random'),
                    meta.get('target_duration_seconds', 0),
                    meta.get('actual_duration_seconds', 0),
                    meta.get('llm_generated', False)
                ])


def main(config_path: Optional[str] = None):
    """Main entry point for count task generation."""
    import yaml
    
    # Load configuration
    if config_path is None:
        config_path = Path(__file__).parent.parent / 'config.yaml'
    
    with open(config_path, 'r') as f:
        config = yaml.safe_load(f)
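
    # A sketch of the config shape, inferred from the keys this module reads.
    # The values shown are illustrative placeholders, not authoritative defaults
    # (except where the code itself supplies a default via .get()):
    #
    #   random_seed: 42
    #   esc50:
    #     metadata_path: <ESC-50 meta csv>
    #     audio_path: <ESC-50 audio dir>
    #   synthetic_silence:
    #     path: <silence clips dir>
    #   audio:
    #     crossfade_duration: ...
    #     silence_duration: ...
    #     with_silence: ...
    #     normalize: ...
    #     normalize_target_dBFS: ...
    #     min_clip_duration: ...
    #     max_clip_duration: ...
    #     source_clip_duration: 5.0       # default used by this module
    #     min_silence_duration: 100       # ms, default
    #     max_extra_silence_per_gap: 500  # ms, default
    #     crossfade_within_source: 50     # ms, default
    #   mcq:
    #     num_options: 4
    #     option_labels: [A, B, C, D]
    #     distractor_strategy: ...
    #   llm:
    #     enabled: false
    #   tasks:
    #     count:
    #       task_duration_size: <hours of audio to generate>
    #       max_clips_per_sample: 10      # default
    #       ordering_mode: random         # or "consecutive"
    #       mcq_questions: [...]
    #       open_text_questions: [...]
    #   output:
    #     base_path: <output dir>
    #   logging:
    #     log_file: <log file name>
    #     level: INFO
    #     console_output: true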
    
    # Set random seed
    set_random_seed(config['random_seed'])
    
    # Setup logger
    logger = setup_logger(
        'count_task',
        log_file=str(Path(config['output']['base_path']) / config['logging']['log_file']),
        level=config['logging']['level'],
        console_output=config['logging']['console_output']
    )
    
    # Generate dataset
    generator = CountTaskGenerator(config, logger)
    generator.generate_dataset()


if __name__ == '__main__':
    main()