from transformers import T5Tokenizer, T5ForConditionalGeneration

# Load the fine-tuned model and tokenizer from the local checkpoint directory
model_name = "./T5base_Question_Generation"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
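# Not in the original listing: switching to eval mode is a safe, minimal
# addition for inference-only use (it disables dropout). `generate` already
# runs without gradient tracking, so no torch.no_grad() wrapper is needed.
model.eval()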
def get_question(tag, difficulty, context, answer="", num_questions=1,
                 use_beam_search=False, num_beams=3, max_length=150):
    """
    Generate questions using the fine-tuned T5 model.

    Parameters:
    - tag: type of question (e.g., "short answer", "multiple choice question",
      "true or false question")
    - difficulty: "easy", "medium", or "hard"
    - context: supporting context or passage
    - answer: optional; provide a target answer for answer-aware question generation
    - num_questions: number of diverse questions to generate
    - use_beam_search: if True (and num_questions == 1), decode with beam search
    - num_beams: beam width used when use_beam_search is True
    - max_length: maximum token length of the generated output

    Returns:
    - list of generated questions as strings
    """
    # Format the input. The prompt reuses T5 sentinel tokens (<extra_id_97>,
    # <extra_id_98>, <extra_id_99>) as field separators for the tag, the
    # difficulty, and the optional bracketed answer, matching the format the
    # model was fine-tuned on.
    answer_part = f"[{answer}]" if answer else ""
    input_text = f"<extra_id_97>{tag} <extra_id_98>{difficulty} <extra_id_99>{answer_part} {context}"

    # Tokenize the input
    features = tokenizer([input_text], return_tensors='pt', truncation=True, padding=True)
    # Decide generation strategy
    if num_questions == 1:
        if use_beam_search:
            # Beam search: track num_beams candidate sequences in parallel
            output = model.generate(
                input_ids=features['input_ids'],
                attention_mask=features['attention_mask'],
                max_length=max_length,
                num_beams=num_beams,
                early_stopping=False
            )
        else:
            # Greedy decoding: pick the most likely token at each step
            output = model.generate(
                input_ids=features['input_ids'],
                attention_mask=features['attention_mask'],
                max_length=max_length,
                do_sample=False
            )
    else:
        # Top-p/top-k sampling to produce multiple diverse questions
        output = model.generate(
            input_ids=features['input_ids'],
            attention_mask=features['attention_mask'],
            max_length=max_length,
            num_return_sequences=num_questions,
            do_sample=True,
            top_p=0.95,
            top_k=50
        )
    # Decode each generated sequence back to text
    questions = [tokenizer.decode(out, skip_special_tokens=True) for out in output]
    return questions
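# Example usage (a sketch: the context string below is illustrative, and the
# tag/difficulty values assume the checkpoint was fine-tuned on the values
# named in the docstring).
if __name__ == "__main__":
    context = (
        "Photosynthesis is the process by which green plants use sunlight, "
        "water, and carbon dioxide to produce glucose and oxygen."
    )

    # Single question with greedy decoding
    print(get_question("short answer", "easy", context))

    # Three diverse questions via sampling, targeted at a given answer
    print(get_question("short answer", "medium", context,
                       answer="oxygen", num_questions=3))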