Spaces:
Runtime error
Runtime error
| import wikipedia | |
| import transformers | |
| import spacy | |
| from transformers import AutoModelWithLMHead, AutoTokenizer | |
| import random | |
| import gradio as gr | |
# Question-generation checkpoint shared by the tokenizer and the model.
_QG_MODEL_NAME = "mrm8488/t5-base-finetuned-question-generation-ap"
tokenizer = AutoTokenizer.from_pretrained(_QG_MODEL_NAME)
# T5 is an encoder-decoder model, so it is loaded with the seq2seq auto class.
# AutoModelWithLMHead is deprecated in transformers (its FutureWarning directs
# encoder-decoder models to AutoModelForSeq2SeqLM).
model = transformers.AutoModelForSeq2SeqLM.from_pretrained(_QG_MODEL_NAME)
# English spaCy pipeline; greet() uses its named entities as candidate answers.
nlp = spacy.load("en_core_web_sm")
def get_question(answer, context, max_length=64):
    """Generate a question whose answer is *answer*, given *context*.

    Builds the prompt ``"answer: <answer> context: <context> </s>"``, runs the
    module-level ``model.generate`` on it (capped at *max_length* tokens) and
    decodes the first generated sequence with special tokens removed. No other
    post-processing is applied to the decoded text.
    """
    prompt = f"answer: {answer!s} context: {context!s} </s>"
    encoded = tokenizer([prompt], return_tensors='pt')
    generated = model.generate(
        input_ids=encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
    )
    first_sequence = generated[0]
    return tokenizer.decode(
        first_sequence,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
| import gradio as gr | |
def _retry_response():
    """Return handler outputs that ask the user to try a different topic."""
    return ["Please Type a Different Topic", gr.update(visible=True), gr.update(value="", visible=False)]


def _fetch_summary(titles):
    """Return ``(title, summary)`` for the first title whose summary loads, else ``None``.

    A title that hits a disambiguation page is resolved by fetching one of the
    page's listed options at random. If that option also fails (missing page
    or another disambiguation page), or there are no options, the title is
    skipped instead of crashing the handler.
    """
    for title in titles:
        try:
            return title, wikipedia.summary(title)
        except wikipedia.DisambiguationError as e:
            if not e.options:
                continue
            try:
                return title, wikipedia.summary(random.choice(e.options))
            except (wikipedia.DisambiguationError, wikipedia.PageError):
                continue
        except wikipedia.PageError:
            continue
    return None


def _pick_answer(doc, entered_topic):
    """Return a random named entity of *doc* that does not overlap *entered_topic*, or ``None``.

    Entities whose text contains the topic, or is contained in it
    (case-insensitively), are excluded.
    """
    topic_lower = entered_topic.lower()
    candidates = [
        ent for ent in doc.ents
        if ent.text.lower() not in topic_lower and topic_lower not in ent.text.lower()
    ]
    return random.choice(candidates) if candidates else None


def greet(entered_topic):
    """Generate a quiz question about *entered_topic* from a Wikipedia summary.

    Takes up to three search results in random order, uses the first whose
    summary loads, picks a named entity from that summary as the answer, and
    asks the question-generation model for a matching question.

    Returns a three-element list for the outputs
    ``[question, input_answer, gold_answer]``: the question text, an update
    revealing the answer box, and an update storing the (still hidden) correct
    answer text. If the topic is blank, no page loads, or the summary has no
    usable entity, the first element is "Please Type a Different Topic" and the
    stored answer is cleared.
    """
    print("Entered topic: ", entered_topic)
    # A blank topic would match every entity in the overlap filter below.
    if not (entered_topic or "").strip():
        return _retry_response()

    topics = wikipedia.search(entered_topic)[:3]
    random.shuffle(topics)
    fetched = _fetch_summary(topics)
    if fetched is None:
        return _retry_response()
    topic, summary = fetched
    print("Selected topic: ", topic)
    print("Summary: ", summary)

    # Join paragraphs with a space so sentences on either side of a line break
    # are not glued together ("end.Start").
    summary = summary.replace("\n", " ")
    answer = _pick_answer(nlp(summary), entered_topic)
    if answer is None:
        return _retry_response()

    question = get_question(answer.text, summary)
    # Strip a leading "question:" label if the model emits one.
    if question.startswith("question:"):
        question = question[len("question:"):]
    question = question.strip()
    print("Question: ", question)
    print("Answer: ", answer.text)
    # Pass the entity's text (a plain str), not the spaCy Span, to the Textbox.
    return [question, gr.update(visible=True), gr.update(value=answer.text, visible=False)]
def get_answer(input_answer, gold_answer):
    """Reveal the stored correct answer.

    The user's answer is only printed; it is not compared with *gold_answer*.
    Returns an update that shows the gold-answer box with *gold_answer* as its value.
    """
    print("Entered Answer: ", input_answer)
    reveal = gr.update(value=gold_answer, visible=True)
    return reveal
with gr.Blocks() as demo:
    topic = gr.Textbox(label="Topic")
    greet_btn = gr.Button("Ask a Question")
    question = gr.Textbox(label="Question")
    # Hidden at start; greet() returns an update that makes it visible.
    input_answer = gr.Textbox(label="Your Answer", visible=False)
    answer_btn = gr.Button("Show Answer")
    # Holds the expected answer; hidden until get_answer() reveals it.
    gold_answer = gr.Textbox(label="Correct Answer", visible=False)

    greet_btn.click(
        fn=greet,
        inputs=topic,
        outputs=[question, input_answer, gold_answer],
    )
    answer_btn.click(
        fn=get_answer,
        inputs=[input_answer, gold_answer],
        outputs=gold_answer,
    )

demo.launch()