| import streamlit as st |
| from transformers import T5TokenizerFast, T5ForConditionalGeneration |
| import nltk |
| import math |
| import torch |
|
|
# Hugging Face Hub id of the fine-tuned end-to-end question-generation checkpoint.
model_name = "abokbot/t5-end2end-questions-generation"

st.header("Generate questions for short Wikipedia-like articles")

# Transient status line; it is blanked out once the model has finished loading.
st_model_load = st.text('Loading question generator model...')
|
|
# Fix: st.cache(allow_output_mutation=True) is deprecated (removed in Streamlit >= 1.36).
# st.cache_resource is the documented replacement for unserializable global resources
# such as models and tokenizers.
@st.cache_resource
def load_model():
    """Load and cache the tokenizer, model, and NLTK data once per process.

    Returns:
        tuple: (T5TokenizerFast, T5ForConditionalGeneration) — the base T5
        tokenizer and the fine-tuned question-generation model.
    """
    print("Loading model...")
    tokenizer = T5TokenizerFast.from_pretrained("t5-base")
    model = T5ForConditionalGeneration.from_pretrained(model_name)
    # Sentence-tokenizer data; downloaded only once thanks to the resource cache.
    nltk.download('punkt')
    print("Model loaded!")
    return tokenizer, model
|
|
# Fetch the (cached) tokenizer/model pair and clear the loading placeholder.
tokenizer, model = load_model()
st.success('Model loaded!')
st_model_load.text("")

# Seed the session with an empty article on the first run.
st.session_state.setdefault('text', "")
st_text_area = st.text_area(
    'Text to generate the questions for',
    value=st.session_state.text,
    height=500,
)
|
|
def generate_questions():
    """Button callback: generate questions for the current text-area content.

    Runs the T5 end-to-end question-generation model with beam search,
    splits the decoded output on '?', and stores the cleaned question
    strings in ``st.session_state.questions``.
    """
    # Persist the input so it survives the rerun triggered by the button click.
    st.session_state.text = st_text_area

    generator_args = {
        "max_length": 256,
        "num_beams": 4,
        "length_penalty": 1.5,
        "no_repeat_ngram_size": 3,
        "early_stopping": True,
    }
    # NOTE(review): the explicit " </s>" matches the model card's prompt format,
    # but T5TokenizerFast also appends an EOS token, so the input ends with a
    # duplicate EOS — kept as-is to match how the checkpoint was used/trained.
    input_string = "generate questions: " + st_text_area + " </s>"
    # Bug fix: truncate to T5's 512-token input limit so long articles do not
    # overflow the model's maximum sequence length.
    input_ids = tokenizer.encode(
        input_string, return_tensors="pt", truncation=True, max_length=512
    )
    res = model.generate(input_ids, **generator_args)
    output = tokenizer.batch_decode(res, skip_special_tokens=True)
    # Bug fix: filter on the *stripped* segment — the original kept
    # whitespace-only fragments, which produced bare "?" entries.
    questions = [
        segment.strip() + "?"
        for segment in output[0].split("?")
        if segment.strip()
    ]

    st.session_state.questions = questions
|
|
| |
# Make sure the questions list exists before the first render / callback.
if 'questions' not in st.session_state:
    st.session_state.questions = []

# The on_click callback runs (and populates session state) before the rerun.
st_generate_button = st.button('Generate questions', on_click=generate_questions)
|
|
# Show results only once at least one question has been generated.
if st.session_state.questions:
    with st.container():
        st.subheader("Generated questions")
        for question in st.session_state.questions:
            st.markdown(f"__{question}__")
|
|