TREA_2.0_codebase / utils /question_utils.py
malay-36's picture
Upload folder using huggingface_hub
fec9168 verified
"""
Question generation utilities for MCQ and open-text formats.
"""
import random
from typing import Dict, List, Optional, Tuple
from .logger import setup_logger
logger = setup_logger(__name__)
class QuestionGenerator:
    """Generates questions in MCQ and open-text formats.

    MCQ results are dictionaries with the keys ``question``, ``options``
    (label -> value), ``correct_answer`` (the label of the right option) and
    ``correct_value``. Open-text results only carry ``question`` and
    ``correct_answer``.

    All randomness comes from the module-level ``random`` generator, so
    callers can make the output reproducible with ``random.seed``.
    """

    def __init__(
        self,
        num_options: int = 4,
        option_labels: Optional[List[str]] = None,
        distractor_strategy: str = "balanced"
    ):
        """
        Initialize question generator.

        Args:
            num_options: Number of MCQ options.
            option_labels: Labels for options (e.g., ['A', 'B', 'C', 'D']).
                When omitted (or empty), labels 'A', 'B', ... are derived
                from ``num_options`` (supported for 1 to 26 options).
            distractor_strategy: Strategy for generating distractor options
                - "present_only": prefer sounds present in the audio, topping
                  up with absent ones only when there are not enough
                - "mixed": random mix of present and absent sounds
                - "balanced": 0, 1, or 2 present sounds, the rest absent
                Any other value behaves like "balanced".

        Raises:
            ValueError: If the number of labels does not match ``num_options``.
        """
        if option_labels:
            labels = option_labels
        elif isinstance(num_options, int) and 1 <= num_options <= 26:
            # Derive 'A', 'B', ... so a non-default num_options works without
            # the caller having to spell out the labels.
            labels = [chr(ord("A") + offset) for offset in range(num_options)]
        else:
            # Cannot derive sensible labels; fall back to the historical
            # default so the length check below reports the mismatch.
            labels = ["A", "B", "C", "D"]

        self.num_options = num_options
        self.option_labels = labels
        self.distractor_strategy = distractor_strategy

        if len(self.option_labels) != num_options:
            raise ValueError(f"Number of option labels must match num_options ({num_options})")

    def generate_count_mcq(
        self,
        question_template: str,
        correct_count: int,
        all_categories: List[str]
    ) -> Dict:
        """
        Generate an MCQ for counting task.

        Args:
            question_template: Question text template
            correct_count: Correct number of unique sounds
            all_categories: List of all available categories (currently
                unused; kept for interface compatibility)

        Returns:
            Dictionary with question, options, and correct answer
        """
        options = self._generate_count_options(correct_count)
        random.shuffle(options)
        return self._build_mcq(question_template, options, correct_count)

    def generate_count_open_text(
        self,
        question_template: str,
        correct_count: int
    ) -> Dict:
        """
        Generate an open-text question for counting task.

        Args:
            question_template: Question text template
            correct_count: Correct number of unique sounds

        Returns:
            Dictionary with question and correct answer (count as a string)
        """
        return {
            "question": question_template,
            "correct_answer": str(correct_count)
        }

    def generate_category_mcq(
        self,
        question_template: str,
        correct_category: str,
        present_categories: List[str],
        all_categories: List[str]
    ) -> Dict:
        """
        Generate an MCQ where answer is a sound category.

        Args:
            question_template: Question text template
            correct_category: Correct category
            present_categories: Categories present in the audio
            all_categories: All available categories

        Returns:
            Dictionary with question, options, and correct answer. If there
            are not enough distinct categories to fill every slot, fewer
            options are returned.
        """
        distractors = self._generate_category_distractors(
            correct_category,
            present_categories,
            all_categories,
            self.num_options - 1
        )
        options = [correct_category] + distractors
        random.shuffle(options)
        return self._build_mcq(question_template, options, correct_category)

    def generate_category_open_text(
        self,
        question_template: str,
        correct_category: str
    ) -> Dict:
        """
        Generate an open-text question where answer is a sound category.

        Args:
            question_template: Question text template
            correct_category: Correct category

        Returns:
            Dictionary with question and correct answer
        """
        return {
            "question": question_template,
            "correct_answer": correct_category
        }

    def generate_sequence_open_text(
        self,
        question_template: str,
        sequence: List[str]
    ) -> Dict:
        """
        Generate an open-text question for sequence/ordering.

        Args:
            question_template: Question text template
            sequence: List of categories in order

        Returns:
            Dictionary with question and correct answer (categories joined
            with ", ")
        """
        return {
            "question": question_template,
            "correct_answer": ", ".join(sequence)
        }

    def _build_mcq(self, question_template: str, options: List, correct_value) -> Dict:
        """Pair the (already shuffled) options with labels and locate the answer."""
        correct_label = self.option_labels[options.index(correct_value)]
        # zip stops at the shorter sequence, so a short option list simply
        # leaves the trailing labels unused.
        option_map = dict(zip(self.option_labels, options))
        return {
            "question": question_template,
            "options": option_map,
            "correct_answer": correct_label,
            "correct_value": correct_value
        }

    def _generate_count_options(self, correct_count: int) -> List[int]:
        """
        Generate count options including the correct count.

        Args:
            correct_count: Correct count value

        Returns:
            List of count options; the correct count is first, followed by
            distinct distractors
        """
        num_distractors = self.num_options - 1
        # Distractors start at 1 (a count of 0 is never offered). The upper
        # bound grows with num_options so there are always enough distinct
        # candidates; for num_options <= 11 this is the historical range.
        upper_bound = max(correct_count + 3, 12, self.num_options + 1)
        candidates = [value for value in range(1, upper_bound) if value != correct_count]
        distractors = random.sample(candidates, min(num_distractors, len(candidates)))
        return ([correct_count] + distractors)[:self.num_options]

    def _generate_category_distractors(
        self,
        correct_category: str,
        present_categories: List[str],
        all_categories: List[str],
        num_distractors: int
    ) -> List[str]:
        """
        Generate distractor categories based on strategy.

        Args:
            correct_category: Correct category
            present_categories: Categories present in audio
            all_categories: All available categories
            num_distractors: Number of distractors to generate

        Returns:
            List of distinct distractor categories, never containing
            ``correct_category``. May be shorter than ``num_distractors``
            when not enough categories are available.
        """
        if num_distractors <= 0:
            return []

        present_lookup = set(present_categories)
        # dict.fromkeys de-duplicates while keeping first-seen order, so a
        # repeated input category cannot become a repeated option.
        present_pool = list(dict.fromkeys(
            c for c in present_categories if c != correct_category
        ))
        # The correct category must be excluded here as well: it may be
        # missing from present_categories and would otherwise be eligible
        # as an "absent" distractor.
        absent_pool = list(dict.fromkeys(
            c for c in all_categories
            if c not in present_lookup and c != correct_category
        ))

        if self.distractor_strategy == "present_only":
            # Use as many present categories as possible.
            num_present = min(num_distractors, len(present_pool))
        elif self.distractor_strategy == "mixed":
            # Random proportion of present vs. absent categories.
            num_present = random.randint(0, min(len(present_pool), num_distractors))
        else:  # balanced (also the fallback for unknown strategies)
            num_present = min(random.choice([0, 1, 2]), len(present_pool), num_distractors)

        num_absent = min(num_distractors - num_present, len(absent_pool))
        distractors = random.sample(present_pool, num_present)
        distractors.extend(random.sample(absent_pool, num_absent))

        # Top up from whatever is left in all_categories when the chosen
        # split could not supply enough distractors.
        shortfall = num_distractors - len(distractors)
        if shortfall > 0:
            already_used = set(distractors)
            leftovers = list(dict.fromkeys(
                c for c in all_categories
                if c != correct_category and c not in already_used
            ))
            distractors.extend(random.sample(leftovers, min(shortfall, len(leftovers))))

        return distractors[:num_distractors]