Spaces:
Runtime error
Runtime error
| import os | |
| from typing import Dict | |
| import gradio as gr | |
| import pandas as pd | |
| from chat_task.chat import generate_chat | |
| from doc_qa_task.doc_qa import generate_doc_qa | |
| from examples import ( | |
| load_examples, | |
| preprocess_docqa_examples, | |
| preprocess_extraction_examples, | |
| preprocess_qa_generator_examples, | |
| ) | |
| from extract_data_task.extract import extract_slots | |
| from plugin_task.api import api_plugin_chat | |
| from qa_generator_task.generate_qa import generate_qa_pairs | |
| from plugin_task.plugins import PLUGIN_JSON_SCHEMA | |
# Resolve the bundled image directory relative to this file so asset paths do
# not depend on the current working directory.
abs_path = os.path.abspath(__file__)
current_dir = os.path.split(abs_path)[0]
statistic_path = os.path.join(current_dir, "images")

# Runs at import time, before the UI below is built.
load_examples()
def clear_session():
    """
    Reset the basic-chat tab.

    :return: an empty string for the input textbox and ``None`` for the chat
        history.
    """
    blank_input = ""
    no_history = None
    return blank_input, no_history
def clear_plugin_session(session: Dict):
    """
    Reset the plugin tab.

    :param session: per-user plugin state; it is emptied in place.
    :return: ``(session, None, None)`` — the same (now empty) dict object,
        followed by blanks for the input textbox and the chat history.
    """
    session.clear()
    outputs = (session, None, None)
    return outputs
def show_custom_fallback_textbox(x):
    """
    Toggle visibility of the custom-fallback row based on the selected option.

    NOTE(review): not referenced anywhere in this file.

    :param x: the selected fallback option.
    :return: ``[gr.Row(visible=...), gr.Textbox()]``; the row is visible only
        when ``x`` is "自定义话术".
    """
    wants_custom_reply = x == "自定义话术"
    return [gr.Row(visible=wants_custom_reply), gr.Textbox()]
def validate_field_word_count(
    input_text: str, description: str, max_word_count: int = 3000
):
    """
    Validate that a user-supplied field is non-blank and not too long.

    Despite the name, the limit is measured in characters (``len``), not words.

    :param input_text: text to validate; ``None`` is treated as empty.
    :param description: field name shown in the error message.
    :param max_word_count: maximum allowed number of characters (inclusive).
    :raises gr.Error: if the text is empty/whitespace-only or longer than
        ``max_word_count``.
    """
    text = input_text or ""
    # Whitespace-only text counts as empty: callers build texts by joining
    # table cells with spaces, which would otherwise never look empty.
    if not text.strip():
        raise gr.Error(f"{description}不能为空")
    if len(text) > max_word_count:
        raise gr.Error(f"{description}字数不能超过{max_word_count}字")
def validate_chat(input_text: str):
    """
    Check a basic-chat message before it is sent to the model.

    :param input_text: the user's message; checked against a 500-character
        limit by ``validate_field_word_count``.
    :raises gr.Error: (via ``validate_field_word_count``) when the check fails.
    """
    chat_limit = 500
    validate_field_word_count(input_text, "输入", max_word_count=chat_limit)
def validate_doc_qa(
    input_text: str,
    doc_df: "pd.DataFrame",
    fallback_ratio: str,
    fallback_text_input: str,
):
    """
    Validate the inputs of the document-QA tab before a question is answered.

    :param input_text: the user's question (max 500 characters).
    :param doc_df: document table with "文档片段名称" (snippet name) and
        "文档片段内容" (snippet content) columns. Its non-blank cells together
        must be non-empty and at most 2500 characters. ``None`` is treated as
        an empty table.
    :param fallback_ratio: the selected out-of-scope reply mode (value of the
        fallback radio); when it is "自定义话术" the custom reply is validated.
    :param fallback_text_input: custom fallback reply (max 100 characters).
    :raises gr.Error: via ``validate_field_word_count`` on the first failing
        field.
    """
    # The custom fallback reply only needs checking when that mode is selected.
    if fallback_ratio == "自定义话术":
        validate_field_word_count(fallback_text_input, "自定义话术", 100)
    validate_field_word_count(input_text, "输入", 500)

    cells = []
    if doc_df is not None:
        # fillna/str keep missing or non-string cells from breaking str.join.
        cells = [
            str(cell)
            for column in ("文档片段名称", "文档片段内容")
            for cell in doc_df[column].fillna("").tolist()
        ]
    # Join only non-blank cells (all names first, then all contents). Joining
    # every cell leaves separator spaces behind for an empty table, which
    # would make the emptiness check below impossible to trigger.
    page_content_full_text = " ".join(cell for cell in cells if cell.strip())
    validate_field_word_count(page_content_full_text, "文档信息", 2500)
def validate_qa_pair_generator(input_text: str):
    """
    Check the source text for QA-pair generation.

    No explicit limit is passed, so ``validate_field_word_count``'s default
    maximum length applies.

    :param input_text: the text to generate question/answer pairs from.
    :return: whatever ``validate_field_word_count`` returns.
    :raises gr.Error: (via ``validate_field_word_count``) when the check fails.
    """
    outcome = validate_field_word_count(input_text, description="输入")
    return outcome
def validate_extraction(
    input_text: str,
    extraction_df: "pd.DataFrame",
):
    """
    Validate the inputs of the data-extraction tab.

    :param input_text: source text to extract from (max 1500 characters).
    :param extraction_df: table with "字段名称" (field name) and "字段描述"
        (field description) columns. Its non-blank cells together must be
        non-empty and at most 1500 characters. ``None`` is treated as an empty
        table.
    :raises gr.Error: via ``validate_field_word_count`` on the first failing
        field.
    """
    cells = []
    if extraction_df is not None:
        # fillna/str keep missing or non-string cells from breaking str.join.
        cells = [
            str(cell)
            for column in ("字段名称", "字段描述")
            for cell in extraction_df[column].fillna("").tolist()
        ]
    # Join only non-blank cells (all names first, then all descriptions).
    # Joining every cell leaves separator spaces behind for an empty table,
    # which would make the emptiness check impossible to trigger.
    extraction_full_text = " ".join(cell for cell in cells if cell.strip())
    validate_field_word_count(input_text, "输入", 1500)
    validate_field_word_count(extraction_full_text, "待抽取字段描述", 1500)
def validate_plugin(input_text: str):
    """
    Check a plugin-tab message before it is sent to the plugin chat.

    :param input_text: the user's message; checked against a 500-character
        limit by ``validate_field_word_count``.
    :raises gr.Error: (via ``validate_field_word_count``) when the check fails.
    """
    plugin_limit = 500
    validate_field_word_count(input_text, "输入", max_word_count=plugin_limit)
# NOTE(review): the nesting of the layout contexts below was reconstructed from
# a copy whose indentation had been stripped; confirm against the deployed app.

# Inline CSS for the small grey captions that head groups of controls.
_SECTION_LABEL_STYLE = (
    "color:rgba(0, 0, 0, 0.5); font-size: 14px; font-weight: 400; "
    "line-height: 28px; letter-spacing: 0em; text-align: left; width: 42px; "
    "height: 14px; left: 36px; top: 255px;"
)


def _section_label(text: str):
    """Render a small grey caption (e.g. "配置项") in the current layout."""
    return gr.Markdown(f'<span style="{_SECTION_LABEL_STYLE}">{text}</span>')


def _action_buttons(clear_label: str = "清除历史"):
    """
    Create the compact clear/send button pair shown next to each tab's examples.

    Must be called inside an open ``gr.Blocks`` layout; the buttons are placed
    in a ``scale=1`` column after whatever the caller created before.

    :param clear_label: text of the clear button.
    :return: ``(clear_button, send_button)``.
    """
    with gr.Column(scale=1):
        with gr.Row(variant="compact"):
            clear_button = gr.Button(
                clear_label,
                min_width=17,  # gr.Button declares min_width as an int
                size="sm",
                scale=1,
                icon=os.path.join(statistic_path, "clear.png"),
            )
            send_button = gr.Button(
                "发送",
                variant="primary",
                min_width=17,
                size="sm",
                scale=1,
                icon=os.path.join(statistic_path, "send.svg"),
            )
    return clear_button, send_button


with gr.Blocks(
    title="Orion-14B",
    theme="shivi/calm_seafoam@>=0.0.1,<1.0.0",
) as demo:

    def user(user_message, history):
        """
        Append the user's message to the chat history with an empty reply slot.

        :return: the unchanged message (the next step in each event chain reads
            it from the textbox) and the extended history.
        """
        return user_message, (history or []) + [[user_message, ""]]

    gr.Markdown(
        """
<div style="overflow: hidden;color:#fff;display: flex;flex-direction: column;align-items: center; position: relative; width: 100%; height: 180px;background-size: cover; background-image: url(https://www.orionstar.com/res/orics/down/ow001_20240119_8369eca9013416109a2303bf4e329140.png);">
<img style="width: 130px;height: 60px;position: absolute;top:10px;left:10px" src="https://www.orionstar.com/res/orics/down/ow001_20240119_1236eba7ea0ac15931f4518d7f211d47.png"/>
<img style="min-width: 1416px; width: 1416px;height: 100px;margin-top: 30px;" src="https://www.orionstar.com/res/orics/down/ow001_20240119_10c5ca12a57116bda0e35916a28b247f.png"/>
<span style="margin-top: 10px;font-size: 12px;">请在<a href="https://github.com/OrionStarAI/Orion" style="color: white;">Github</a>点击Star支持我们,加入<a href="https://www.orionstar.com/res/orics/down/ow001_20240122_d87e5b4ea66a31493c38fcffe7bdb453.png" style="color: white;">官方微信交流群</a></span>
</div>
"""
    )

    # ------------------------------------------------------------------ #
    # Tab: basic chat
    # ------------------------------------------------------------------ #
    with gr.Tab("基础能力"):
        chatbot = gr.Chatbot(
            label="Orion-14B-Chat",
            elem_classes="control-height",
            show_copy_button=True,
            min_width=1368,
            height=416,
        )
        chat_text_input = gr.Textbox(label="输入", min_width=1368)
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    [
                        "可以给我讲个笑话吗?",
                        "什么是伟大的诗歌?",
                        "你知道李白吗?",
                        "黑洞是如何工作的?",
                        "在表中插入一条数据,id为1,name为张三,age为18,请问SQL语句是什么?",
                    ],
                    chat_text_input,
                    label="试试问",
                )
            clear_history, submit = _action_buttons()

        # Enter and the send button run the same validate -> echo -> generate
        # chain; only the button's chain is published as the "chat" API.
        for trigger, api_name in (
            (chat_text_input.submit, None),
            (submit.click, "chat"),
        ):
            trigger(
                fn=validate_chat, inputs=[chat_text_input], outputs=[], queue=False
            ).success(
                user,
                [chat_text_input, chatbot],
                [chat_text_input, chatbot],
                queue=False,
            ).success(
                fn=generate_chat,
                inputs=[chat_text_input, chatbot],
                outputs=[chat_text_input, chatbot],
                api_name=api_name,
            )
        clear_history.click(
            fn=clear_session, inputs=[], outputs=[chat_text_input, chatbot], queue=False
        )

    # ------------------------------------------------------------------ #
    # Tab: question answering over user-supplied documents
    # ------------------------------------------------------------------ #
    with gr.Tab("基于文档问答"):
        with gr.Row():
            with gr.Column(scale=3, min_width=357, variant="panel"):
                _section_label("配置项")
                citations_radio = gr.Radio(
                    ["开启引用", "关闭引用"], label="引用", value="关闭引用"
                )
                fallback_radio = gr.Radio(
                    ["使用大模型知识", "自定义话术"],
                    label="超纲问题回复",
                    value="自定义话术",
                )
                fallback_text_input = gr.Textbox(
                    label="自定义话术",
                    value="抱歉,我还在学习中,暂时无法回答您的问题。",
                )
                _section_label("文档信息")
                doc_df = gr.Dataframe(
                    headers=["文档片段内容", "文档片段名称"],
                    datatype=["str", "str"],
                    row_count=6,
                    col_count=(2, "fixed"),
                    label="",
                    interactive=True,
                    wrap=True,
                    elem_classes="control-height",
                    height=300,
                )
            with gr.Column(scale=2, min_width=430):
                chatbot = gr.Chatbot(
                    label="适用场景:预期LLM通过自由知识回答",
                    elem_classes="control-height",
                    show_copy_button=True,
                    min_width=999,
                    height=419,
                )
                doc_qa_input = gr.Textbox(label="输入", min_width=999, max_lines=10)
                with gr.Row():
                    with gr.Column(scale=2):
                        gr.Examples(
                            [
                                "哪些情况下不能超车?",
                                "参观须知",
                                "青岛啤酒酒精含量是多少?",
                            ],
                            doc_qa_input,
                            label="试试问",
                            cache_examples=True,
                            fn=preprocess_docqa_examples,
                            outputs=[doc_df],
                        )
                    clear_history, submit = _action_buttons()

        # Enter and the send button share one chain; here it is the Enter
        # chain that is published as the "doc_qa" API.
        for trigger, api_name in (
            (doc_qa_input.submit, "doc_qa"),
            (submit.click, None),
        ):
            trigger(
                fn=validate_doc_qa,
                inputs=[
                    doc_qa_input,
                    doc_df,
                    fallback_radio,
                    fallback_text_input,
                ],
                outputs=[],
                queue=False,
            ).success(
                user, [doc_qa_input, chatbot], [doc_qa_input, chatbot], queue=False
            ).success(
                fn=generate_doc_qa,
                inputs=[
                    doc_qa_input,
                    chatbot,
                    doc_df,
                    fallback_radio,
                    fallback_text_input,
                    citations_radio,
                ],
                outputs=[doc_qa_input, chatbot],
                scroll_to_output=True,
                api_name=api_name,
            )
        clear_history.click(
            # ``inputs`` is empty, so Gradio calls the callback with no
            # arguments; it must therefore take none.
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[doc_df, doc_qa_input, chatbot],
            queue=False,
        )

    # ------------------------------------------------------------------ #
    # Tab: plugin / API calling
    # ------------------------------------------------------------------ #
    with gr.Tab("插件能力"):
        with gr.Row():
            with gr.Column(scale=1):
                _section_label("配置项")
                # One on/off radio per plugin declared in PLUGIN_JSON_SCHEMA;
                # their values are passed positionally to api_plugin_chat.
                radio_plugins = [
                    gr.Radio(
                        ["开启", "关闭"],
                        label=plugin_json["name_for_human"],
                        value="开启",
                    )
                    for plugin_json in PLUGIN_JSON_SCHEMA
                ]
            with gr.Column(scale=3):
                session = gr.State(value=dict())
                chatbot = gr.Chatbot(
                    label="适用场景:需要LLM调用API解决问题",
                    elem_classes="control-height",
                    show_copy_button=True,
                )
                plugin_text_input = gr.Textbox(label="输入")
                with gr.Row():
                    with gr.Column(scale=2):
                        gr.Examples(
                            [
                                "北京天气怎么样?",
                                "查询物流信息",
                                "每日壁纸",
                                "bing今天的壁纸是什么",
                                "查询手机号码归属地",
                            ],
                            plugin_text_input,
                            label="试试问",
                        )
                    clear_history, submit = _action_buttons()

        # Enter and the send button share one chain; only the button's chain
        # is published as the "plugin" API.
        for trigger, api_name in (
            (plugin_text_input.submit, None),
            (submit.click, "plugin"),
        ):
            trigger(
                fn=validate_plugin,
                inputs=[plugin_text_input],
                outputs=[],
                queue=False,
            ).success(
                user,
                [plugin_text_input, chatbot],
                [plugin_text_input, chatbot],
                scroll_to_output=True,
            ).success(
                fn=api_plugin_chat,
                inputs=[session, plugin_text_input, chatbot, *radio_plugins],
                outputs=[session, plugin_text_input, chatbot],
                scroll_to_output=True,
                api_name=api_name,
            )
        clear_history.click(
            fn=clear_plugin_session,
            inputs=[session],
            outputs=[session, plugin_text_input, chatbot],
            queue=False,
        )

    # ------------------------------------------------------------------ #
    # Tab: generate question/answer pairs from a passage
    # ------------------------------------------------------------------ #
    with gr.Tab("生成QA对"):
        with gr.Row(equal_height=True):
            qa_generator_output = gr.Code(
                language="json",
                show_label=False,
                min_width=1368,
            )
        with gr.Row():
            qa_generator_input = gr.Textbox(
                label="输入",
                show_label=True,
                info="",
                min_width=1368,
                lines=5,
                max_lines=10,
            )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    [
                        "第一章 总 则 \n第...",
                        "金字塔,在建筑学上是...",
                        "山西老陈醋是以高粱、...",
                        "室内装饰构造虚拟仿真...",
                        "猎户星空(Orion...",
                    ],
                    qa_generator_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_qa_generator_examples,
                    outputs=[qa_generator_input],
                )
            clear, submit = _action_buttons("清除")

        submit.click(
            fn=validate_qa_pair_generator,
            inputs=[qa_generator_input],
            outputs=[],
        ).success(
            fn=generate_qa_pairs,
            inputs=[qa_generator_input],
            outputs=[qa_generator_output, qa_generator_input],
            scroll_to_output=True,
            api_name="qa_generator",
        )
        clear.click(
            # No inputs -> called with no arguments.
            fn=lambda: ("", ""),
            inputs=[],
            outputs=[qa_generator_input, qa_generator_output],
            queue=False,
        )

    # ------------------------------------------------------------------ #
    # Tab: structured field extraction
    # ------------------------------------------------------------------ #
    with gr.Tab("抽取数据"):
        extract_outpu_df = gr.Dataframe(
            label="",
            headers=["字段名称", "字段抽取结果"],
            datatype=["str", "str"],
            col_count=(2, "fixed"),
            wrap=True,
            elem_classes="control-height",
            height=234,
            row_count=5,
        )
        extract_input = gr.Textbox(label="输入", lines=5, min_width=1368, max_lines=10)
        extraction_df = gr.Dataframe(
            headers=["字段名称", "字段描述"],
            datatype=["str", "str"],
            row_count=3,
            col_count=(2, "fixed"),
            label="",
            interactive=True,
            wrap=True,
            elem_classes="control-height",
            height=180,
        )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Examples(
                    ["第一条合同当...", "发票编号: IN...", "发件人:John..."],
                    extract_input,
                    label="试试问",
                    cache_examples=True,
                    fn=preprocess_extraction_examples,
                    outputs=[extract_input, extraction_df],
                )
            clear, submit = _action_buttons()

        submit.click(
            fn=validate_extraction,
            inputs=[extract_input, extraction_df],
            outputs=[],
        ).success(
            fn=extract_slots,
            inputs=[extract_input, extraction_df],
            outputs=[extract_outpu_df],
            scroll_to_output=True,
            api_name="extract",
        )
        clear.click(
            # No inputs -> called with no arguments.
            fn=lambda: ("", None, None),
            inputs=[],
            outputs=[
                extract_input,
                extraction_df,
                extract_outpu_df,
            ],
            queue=False,
        )
if __name__ == "__main__":
    # Enable the request queue (at most 40 queued events, REST API closed),
    # then serve on all interfaces without a public share link.
    queued_demo = demo.queue(api_open=False, max_size=40)
    queued_demo.launch(
        server_name="0.0.0.0",
        share=False,
        show_api=False,
        max_threads=4,
        height=800,
    )