Spaces:
Sleeping
Sleeping
| import os | |
| from knowledge_storm import STORMWikiRunnerArguments, STORMWikiRunner, STORMWikiLMConfigs | |
| from knowledge_storm.lm import OpenAIModel | |
| from knowledge_storm.rm import YouRM | |
| import spaces | |
| import gradio as gr | |
| import json | |
| import re | |
def convert_references_to_links(text, json_data):
    """Turn inline ``[n]`` citation markers into Markdown links and append a reference list.

    Args:
        text: Article text that may contain citation markers such as ``[3]``.
        json_data: Parsed reference data; only its ``'url_to_unified_index'`` entry
            (a mapping of URL -> citation index) is used.

    Returns:
        ``text`` with every ``[n]`` whose index appears in the mapping rewritten as
        ``[[n]](url)``, followed by a blank line and one ``[n] url`` entry per URL,
        ordered by index, with entries separated by blank lines. Markers with no
        matching URL are left unchanged.

    Raises:
        KeyError: if ``json_data`` has no ``'url_to_unified_index'`` key.
    """
    url_mapping = json_data['url_to_unified_index']

    # Invert the mapping once (index text -> URL) so each citation lookup is O(1)
    # instead of a linear scan over every URL per match. setdefault keeps the first
    # URL seen for an index, matching the previous first-match behaviour.
    index_to_url = {}
    for url, index in url_mapping.items():
        index_to_url.setdefault(str(index), url)

    def replace_reference(match):
        """Wrap one ``[n]`` marker in a Markdown link if ``n`` has a known URL."""
        url = index_to_url.get(match.group(1))
        if url:
            return f'[{match.group(0)}]({url})'
        return match.group(0)

    processed_text = re.sub(r'\[(\d+)\]', replace_reference, text)

    reference_list = [
        f"[{index}] {url}"
        for url, index in sorted(url_mapping.items(), key=lambda item: item[1])
    ]
    # The result is rendered as Markdown, where a single newline is only a soft break;
    # a blank line between entries keeps each reference on its own line.
    return f"{processed_text}\n\n" + "\n\n".join(reference_list)
lm_configs = STORMWikiLMConfigs()

# Settings shared by every OpenAI-backed model below. The API key is read from the
# OPENAI_API_KEY environment variable (None if it is unset).
openai_kwargs = dict(
    api_key=os.getenv("OPENAI_API_KEY"),
    temperature=1.0,
    top_p=0.9,
)

# STORM lets each pipeline stage run on its own model to balance cost and quality.
# The usual advice is a cheaper/faster model for `conv_simulator_lm` (splitting queries,
# synthesizing answers in the conversation) and a stronger one for `article_gen_lm`
# (verifiable text with citations).
# NOTE(review): this demo routes every stage to GPT-4o; `gpt_35` is constructed but
# never assigned to a stage.
gpt_35 = OpenAIModel(model='gpt-3.5-turbo', max_tokens=500, **openai_kwargs)
gpt_4 = OpenAIModel(model='gpt-4o', max_tokens=3000, **openai_kwargs)

for _set_stage_lm in (
    lm_configs.set_conv_simulator_lm,
    lm_configs.set_question_asker_lm,
    lm_configs.set_outline_gen_lm,
    lm_configs.set_article_gen_lm,
    lm_configs.set_article_polish_lm,
):
    _set_stage_lm(gpt_4)
del _set_stage_lm

# Run arguments; outputs are written under "outputs" (generate_article reads from there).
# See STORMWikiRunnerArguments for further options.
engine_args = STORMWikiRunnerArguments("outputs")
# You.com retriever; k is tied to the runner's search_top_k so the two settings agree
# (presumably the number of results kept per search — confirm against YouRM).
rm = YouRM(ydc_api_key=os.getenv('YDC_API_KEY'), k=engine_args.search_top_k)
runner = STORMWikiRunner(engine_args, lm_configs, rm)
def generate_article(prompt, progress=gr.Progress(track_tqdm=True)):
    """Run the full STORM pipeline on ``prompt`` and return the article as Markdown.

    Args:
        prompt: Topic to research and write about.
        progress: Gradio progress tracker supplied by Gradio; not used directly in
            this body.

    Returns:
        The polished article prefixed with a ``# {prompt}`` heading, with inline
        citation markers linked and a reference list appended by
        ``convert_references_to_links``.

    Raises:
        FileNotFoundError: if ``storm_gen_article.txt`` or ``url_to_info.json`` is
            missing from the article's output directory.
    """
    runner.run(
        topic=prompt,
        do_research=True,
        do_generate_outline=True,
        do_generate_article=True,
        do_polish_article=True,
    )
    runner.post_run()
    runner.summary()

    # Prefer the directory the runner itself chose for this topic rather than
    # re-deriving it. NOTE(review): assumes STORMWikiRunner exposes
    # `article_output_dir` after run(); confirm for the pinned knowledge_storm
    # version. The fallback is rooted at the configured output dir (not a hard-coded
    # "outputs") and also replaces "/", which would otherwise yield a nested path.
    article_dir = getattr(runner, 'article_output_dir', None) or os.path.join(
        engine_args.output_dir, prompt.replace(" ", "_").replace("/", "_")
    )

    with open(os.path.join(article_dir, 'storm_gen_article.txt'), 'r', encoding='utf-8') as file:
        content = file.read()
    with open(os.path.join(article_dir, 'url_to_info.json'), 'r', encoding='utf-8') as file:
        references_json = json.load(file)

    return convert_references_to_links(f'# {prompt}\n\n' + content, references_json)
# Gradio front end: components render top-to-bottom in creation order, so the
# heading, prompt box, output pane and button must be created in this sequence.
with gr.Blocks() as demo:
    gr.Markdown("# Omnipedia article generation demo (Storm GPT-4 + You)")
    prompt = gr.Textbox(label="Prompt")
    output = gr.Markdown(label="Output")
    btn = gr.Button("Generate")
    # Clicking the button feeds the prompt text to generate_article and shows the
    # returned Markdown in the output pane.
    btn.click(
        fn=generate_article,
        inputs=prompt,
        outputs=output,
    )


if __name__ == "__main__":
    # Start the server only when run as a script, not when the module is imported.
    demo.launch()