import requests
import streamlit as st
import openai
import json


def main():
    st.title("Scientific Question Generation")

    st.write("This application is designed to generate a question given a piece of scientific text. \
We include the output from four different models: the [BART-Large](https://huggingface.co/dhmeltzer/bart-large_askscience-qg) and [FLAN-T5-Base](https://huggingface.co/dhmeltzer/flan-t5-base_askscience-qg) models \
fine-tuned on the r/AskScience split of the [ELI5 dataset](https://huggingface.co/datasets/eli5), as well as the zero-shot output \
of the [FLAN-T5-XXL](https://huggingface.co/google/flan-t5-xxl) model and the [GPT-3.5-turbo](https://platform.openai.com/docs/models/gpt-3-5) model. \
For more background, see this [report](https://wandb.ai/dmeltzer/Question_Generation/reports/Exploratory-Data-Analysis-for-r-AskScience--Vmlldzo0MjQwODg1?accessToken=fndbu2ar26mlbzqdphvb819847qqth2bxyi4hqhugbnv97607mj01qc7ed35v6w8) for an exploratory data analysis of the r/AskScience dataset and this \
[report](https://api.wandb.ai/links/dmeltzer/7an677es) for details on our training procedure.\
\n\nThe two fine-tuned models (BART-Large and FLAN-T5-Base) are hosted on AWS using a combination of AWS SageMaker, Lambda, and API Gateway. \
GPT-3.5 is called through the OpenAI API, and the FLAN-T5-XXL model is hosted by Hugging Face and called through their Inference API.\
\n\n**Disclaimer**: When this application is first run, it may take roughly 30 seconds for the first two responses to load because of the cold-start problem with AWS Lambda. \
If this happens, please resubmit the input; the models will respond more quickly on subsequent calls.")

    AWS_checkpoints = {}
    AWS_checkpoints['BART-Large'] = 'https://8hlnvys7bh.execute-api.us-east-1.amazonaws.com/beta/'
    AWS_checkpoints['FLAN-T5-Base'] = 'https://gnrxh05827.execute-api.us-east-1.amazonaws.com/beta/'
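    # Each URL is an API Gateway endpoint that triggers a Lambda function,
    # which in turn invokes the SageMaker endpoint hosting the fine-tuned model.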

    # Right now HF_checkpoints consists only of FLAN-T5-XXL, but we may add more models later.
    HF_checkpoints = ['google/flan-t5-xxl']

    # Token to access the Hugging Face Inference API
    HF_headers = {"Authorization": f"Bearer {st.secrets['HF_token']}"}

    # Token to access the OpenAI API
    openai.api_key = st.secrets['OpenAI_token']

    # Query a model hosted on Hugging Face through the Inference API.
    def query(checkpoint, payload):
        API_URL = f"https://api-inference.huggingface.co/models/{checkpoint}"
        response = requests.post(API_URL,
                                 headers=HF_headers,
                                 json=payload)
        return response.json()
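
    # For a text2text-generation checkpoint the Inference API typically returns
    # [{'generated_text': '...'}] on success, or a dict such as
    # {'error': '...', 'estimated_time': ...} while the model is still loading.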

    # User input
    user_input = st.text_area("Question Generator",
                              """Black holes can evaporate by emitting Hawking radiation.""")

    if user_input:
        # Query the two fine-tuned models hosted on AWS, passing the input
        # text as a JSON payload in the request body.
        for name, url in AWS_checkpoints.items():
            headers = {'x-api-key': st.secrets['aws-key']}
            input_data = json.dumps({'inputs': user_input})
            r = requests.get(url, data=input_data, headers=headers)
            try:
                output = r.json()[0]['generated_text']
                st.write(f'**{name}**: {output}')
            except Exception:
                st.write(f'**{name}**: There was an error when calling the model. Please resubmit the question.')
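
        # Query GPT-3.5 through the OpenAI chat completions API.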
        model_engine = "gpt-3.5-turbo"
        # Maximum number of tokens to generate
        max_tokens = 50
        # Prompt GPT-3.5 with an explicit instruction.
        prompt = f"generate a question: {user_input}"
        # The system message tells GPT-3.5 that its task is to generate questions from text.
        response = openai.ChatCompletion.create(
            model=model_engine,
            max_tokens=max_tokens,
            messages=[
                {"role": "system", "content": "You are a helpful assistant that generates questions from text."},
                {"role": "user", "content": prompt},
            ])
        output = response['choices'][0]['message']['content']
        st.write(f'**{model_engine}**: {output}')
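
        # Finally, query the zero-shot model(s) hosted on Hugging Face.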
        for checkpoint in HF_checkpoints:
            model_name = checkpoint.split('/')[1]
            # FLAN models need to be given an explicit instruction.
            if 'flan' in model_name.lower():
                prompt = 'generate a question: ' + user_input
            else:
                prompt = user_input
            # wait_for_model is passed inside 'options' so the Inference API
            # waits for the model to load rather than returning an error.
            output = query(checkpoint, {
                "inputs": prompt,
                "options": {"wait_for_model": True}})
            try:
                output = output[0]['generated_text']
            except Exception:
                # Surface the raw API response, which is usually an error message.
                st.write(output)
                return
            st.write(f'**{model_name}**: {output}')


if __name__ == "__main__":
    main()
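
# To run the app locally (filename assumed here; the HF_token, OpenAI_token,
# and aws-key secrets must be defined in .streamlit/secrets.toml):
#
#   streamlit run app.py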