""" T5-based Question Generation for Time-Aware RAG Generates questions from passages with temporal awareness """ import os import json import yaml import torch import logging from typing import List, Dict, Any from transformers import T5ForConditionalGeneration, T5Tokenizer, AutoTokenizer from datasets import Dataset, load_dataset from torch.utils.data import DataLoader from tqdm import tqdm import pandas as pd logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class T5QuestionGenerator: def __init__(self, config: Dict[str, Any]): self.config = config self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # Load T5 model and tokenizer model_name = config['models']['t5_generator']['name'] self.tokenizer = T5Tokenizer.from_pretrained(model_name) self.model = T5ForConditionalGeneration.from_pretrained(model_name) self.model.to(self.device) self.model.eval() logger.info(f"Loaded T5 model: {model_name} on {self.device}") def generate_temporal_questions_batch(self, passages: List[str], years: List[str], max_new_tokens: int = 64) -> List[str]: """Generate temporal questions using T5 - matches notebook implementation exactly""" # Use the exact same prompt format as in the notebook prompts = [f"generate question about {y}: {p}" for p, y in zip(passages, years)] # Batch tokenization exactly like the notebook inputs = self.tokenizer(prompts, padding="longest", truncation=True, max_length=512, return_tensors="pt") inputs = {k: v.to(self.device) for k, v in inputs.items()} # Generate questions with exact same parameters as notebook with torch.no_grad(): with torch.autocast("cuda", dtype=torch.float16, enabled=torch.cuda.is_available()): outputs = self.model.generate(**inputs, max_length=max_new_tokens, num_beams=4, early_stopping=True) # Decode exactly like the notebook decoded = self.tokenizer.batch_decode(outputs, skip_special_tokens=True) # Clean up the outputs - remove the input prompts questions = [] for i, decoded_text in enumerate(decoded): # Remove the input prompt from the output prompt = prompts[i] if decoded_text.startswith(prompt): question = decoded_text[len(prompt):].strip() else: question = decoded_text.strip() # Ensure it ends with a question mark if question and not question.endswith('?'): question += '?' questions.append(question) return questions def get_years_from_text(self, text: str): """Extract years from text - matches notebook implementation""" import re # Use the same regex as in the notebook YEAR_REGEX = re.compile(r"\b(18[0-9]{2}|19[0-9]{2}|20[0-2][0-9])\b") return set(YEAR_REGEX.findall(text)) def generate_temporal_questions(self, passage: str, num_questions: int = 5) -> List[str]: """Generate temporal-aware questions from a passage - matches notebook workflow""" # Extract years exactly like the notebook years = self.get_years_from_text(passage) if not years: # If no years, skip this passage (like in notebook) return [] # Use first year like in the notebook first_year = sorted(list(years))[0] # Generate one question per passage (like notebook does) passages = [passage] year_batch = [first_year] # Generate questions using the batch function questions = self.generate_temporal_questions_batch(passages, year_batch) # If we need more questions, repeat with different years all_questions = [] years_list = sorted(list(years)) for i in range(num_questions): if i < len(questions) and questions[i]: all_questions.append(questions[i]) elif years_list: # Generate with different year year_to_use = years_list[i % len(years_list)] additional_q = self.generate_temporal_questions_batch([passage], [year_to_use]) if additional_q and additional_q[0]: all_questions.append(additional_q[0]) return all_questions[:num_questions] def process_dataset(self, dataset_path: str, output_path: str): """Process a dataset to generate questions""" logger.info(f"Processing dataset from {dataset_path}") # Load dataset (assuming it's in JSON format) if dataset_path.endswith('.json'): with open(dataset_path, 'r') as f: data = json.load(f) else: # Try to load as HuggingFace dataset data = load_dataset(dataset_path, split='train') generated_data = [] for idx, item in enumerate(tqdm(data, desc="Generating questions")): if isinstance(item, dict): passage = item.get('text', item.get('passage', '')) passage_id = item.get('id', f'passage_{idx}') else: passage = str(item) passage_id = f'passage_{idx}' if not passage: continue # Generate questions questions = self.generate_temporal_questions( passage, self.config['data']['generated_questions']['num_questions_per_passage'] ) for q_idx, question in enumerate(questions): generated_data.append({ 'passage_id': passage_id, 'passage': passage, 'question': question, 'question_id': f"{passage_id}_q{q_idx}", 'temporal_type': 'generated' }) # Save generated data os.makedirs(os.path.dirname(output_path), exist_ok=True) with open(output_path, 'w') as f: json.dump(generated_data, f, indent=2) logger.info(f"Generated {len(generated_data)} question-passage pairs") logger.info(f"Saved to {output_path}") return generated_data def main(): # Load configuration with open('configs/config.yaml', 'r') as f: config = yaml.safe_load(f) generator = T5QuestionGenerator(config) # Load FineWeb passages (matches notebook approach) fineweb_path = os.path.join(config['data']['fineweb']['output_path'], 'fineweb_passages.json') if os.path.exists(fineweb_path): print(f"Loading FineWeb passages from {fineweb_path}") with open(fineweb_path, 'r') as f: fineweb_data = json.load(f) # Convert to notebook format (id, text, title) - use all available passages all_passages = [(item['id'], item['text'], item['title']) for item in fineweb_data] # Sample for question generation like the notebook (NUM_QG_PASSAGES = 15000) num_qg_passages = min(15000, len(all_passages)) # notebook constant if len(all_passages) > num_qg_passages: import random random.seed(42) # For reproducibility sample_passages = random.sample(all_passages, num_qg_passages) print(f"Sampled {len(sample_passages)} passages from {len(all_passages)} total FineWeb passages for question generation") else: sample_passages = all_passages print(f"Using all {len(sample_passages)} FineWeb passages for question generation") else: print(f"FineWeb data not found at {fineweb_path}, using sample passages...") # Fallback to hardcoded samples if FineWeb not available sample_passages = [ (0, "The Renaissance began in Italy during the 14th century and lasted until the 17th century. It was characterized by a revival of classical learning and art.", "Renaissance"), (1, "World War II started on September 1, 1939, when Germany invaded Poland. The war lasted until September 2, 1945.", "WWII"), (2, "The Industrial Revolution began in Britain in the late 18th century and spread throughout Europe and North America during the 19th century.", "Industrial Revolution"), (3, "The American Civil War was fought from 1861 to 1865 between the northern and southern states.", "Civil War"), (4, "The Great Depression began in 1929 with the stock market crash and lasted throughout the 1930s.", "Great Depression"), (5, "The Cold War started in 1947 and lasted until 1991, representing geopolitical tension.", "Cold War") ] output_dir = config['data']['generated_questions']['output_path'] os.makedirs(output_dir, exist_ok=True) # Process exactly like the notebook with batching synthetic_pairs = [] # (question, passage_text, passage_id) passage_batch, passage_info, year_batch = [], [], [] QG_BATCH_SIZE = 8 # Match notebook batch size logger.info(f"Generating synthetic temporal questions from {len(sample_passages)} passages...") for (pid, text, title) in tqdm(sample_passages, desc="Processing passages"): years = generator.get_years_from_text(text) if not years: continue # Use first year like notebook first_year = sorted(list(years))[0] passage_batch.append(text) year_batch.append(first_year) passage_info.append((pid, text)) if len(passage_batch) >= QG_BATCH_SIZE: # Process batch generated_questions = generator.generate_temporal_questions_batch(passage_batch, year_batch) for i, q in enumerate(generated_questions): if q: p_id, p_text = passage_info[i] synthetic_pairs.append((q, p_text, p_id)) # Reset batch passage_batch, passage_info, year_batch = [], [], [] # Process remaining passages if passage_batch: generated_questions = generator.generate_temporal_questions_batch(passage_batch, year_batch) for i, q in enumerate(generated_questions): if q: p_id, p_text = passage_info[i] synthetic_pairs.append((q, p_text, p_id)) # Convert to the expected format all_generated = [] for i, (question, passage, passage_id) in enumerate(synthetic_pairs): all_generated.append({ 'passage_id': f'sample_{passage_id}', 'passage': passage, 'question': question, 'question_id': f'sample_{passage_id}_q{i}', 'temporal_type': 'generated' }) # Save sample generated data output_path = os.path.join(output_dir, 'sample_generated_questions.json') with open(output_path, 'w') as f: json.dump(all_generated, f, indent=2) logger.info(f"Generated sample questions saved to {output_path}") if __name__ == "__main__": main()