| | import gradio as gr |
| | import time |
| | import os |
| | from spinoza_project.source.backend.llm_utils import ( |
| | get_llm_api, |
| | get_vectorstore_api, |
| | ) |
| | from spinoza_project.source.frontend.utils import ( |
| | init_env, |
| | parse_output_llm_with_sources, |
| | ) |
| | from spinoza_project.source.frontend.gradio_utils import ( |
| | get_sources, |
| | set_prompts, |
| | get_config, |
| | get_prompts, |
| | get_assets, |
| | get_theme, |
| | get_init_prompt, |
| | get_synthesis_prompt, |
| | get_qdrants, |
| | get_qdrants_public, |
| | start_agents, |
| | end_agents, |
| | next_call, |
| | zip_longest_fill, |
| | reformulate, |
| | answer, |
| | ) |
| |
|
| | from assets.utils_javascript import ( |
| | accordion_trigger, |
| | accordion_trigger_end, |
| | accordion_trigger_spinoza, |
| | accordion_trigger_spinoza_end, |
| | update_footer, |
| | ) |
| |
|
# ---------------------------------------------------------------------------
# Module-level bootstrap: everything below runs once, at import time, and the
# resulting objects are captured as default arguments by the handlers further
# down in this file.
# ---------------------------------------------------------------------------

# Environment and configuration.
init_env()
config = get_config()

# Prompt templates (QA, reformulation and final synthesis).
print("Loading Prompts")
prompts = get_prompts(config)
chat_qa_prompts, chat_reformulation_prompts = set_prompts(prompts, config)
synthesis_prompt_template = get_synthesis_prompt(config)

# LLM client. When EKI_OPENAI_LLM_DEPLOYMENT_NAME is set, an empty Groq model
# name is handed to get_llm_api; otherwise the Groq model from the config is
# used. (Presumably the empty name selects the OpenAI deployment - confirm in
# get_llm_api.)
print("Building LLM")
if os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
    groq_model_name = ""
else:
    groq_model_name = config["groq_model_name"]
llm = get_llm_api(groq_model_name)

# Vector stores.
print("Loading Databases")
qdrants = get_qdrants(config)

if not os.getenv("EKI_OPENAI_LLM_DEPLOYMENT_NAME"):
    # No OpenAI deployment configured: add the public Qdrant stores (public
    # entries win on key clashes) and disable the press / AFP stores.
    qdrants_public = get_qdrants_public(config)
    qdrants = {**qdrants, **qdrants_public}
    bdd_presse = None
    bdd_afp = None
else:
    # OpenAI deployment configured: the press and AFP stores are available.
    bdd_presse = get_vectorstore_api("presse")
    bdd_afp = get_vectorstore_api("afp")

# Static front-end assets.
css, source_information = get_assets()
theme = get_theme()
init_prompt = get_init_prompt()
| |
|
| |
|
def reformulate_questions(
    question,
    llm=llm,
    chat_reformulation_prompts=chat_reformulation_prompts,
    config=config,
):
    """Stream the reformulated question for every configured tab.

    One ``reformulate`` call is issued per key of ``config["tabs"]``; the
    resulting iterables are combined by ``zip_longest_fill`` and each combined
    item is yielded after a 20 ms pause.

    Args:
        question: The question typed by the user.
        llm: LLM client forwarded to ``reformulate``.
        chat_reformulation_prompts: Prompts forwarded to ``reformulate``.
        config: Application configuration; its ``"tabs"`` entry lists the tabs.

    Yields:
        Whatever ``zip_longest_fill`` produces for each step (presumably one
        partial reformulation per tab - confirm against the helper).
    """
    per_tab_streams = []
    for tab in config["tabs"]:
        per_tab_streams.append(
            reformulate(llm, chat_reformulation_prompts, question, tab, config=config)
        )

    for step in zip_longest_fill(*per_tab_streams):
        # Short pause between updates pushed to the UI.
        time.sleep(0.02)
        yield step
| |
|
| |
|
def retrieve_sources(
    *questions,
    qdrants=qdrants,
    bdd_presse=bdd_presse,
    bdd_afp=bdd_afp,
    config=config,
):
    """Fetch the sources matching the per-tab questions.

    Thin wrapper around ``get_sources`` that flattens its two-part result into
    a single tuple, as expected by the event wiring in this file (one value
    for the sources panel followed by one value per tab).

    Args:
        *questions: The reformulated questions, one per tab.
        qdrants: Qdrant stores forwarded to ``get_sources``.
        bdd_presse: Press vector store, or ``None``.
        bdd_afp: AFP vector store, or ``None``.
        config: Application configuration.

    Returns:
        A tuple ``(display_sources, *per_tab_sources)``.
    """
    display_sources, per_tab_sources = get_sources(
        questions, qdrants, bdd_presse, bdd_afp, config
    )
    return (display_sources, *per_tab_sources)
| |
|
| |
|
def answer_questions(
    *questions_sources, llm=llm, chat_qa_prompts=chat_qa_prompts, config=config
):
    """Stream each tab's answer to its own question, as chatbot histories.

    The positional arguments hold the questions in their first half and the
    matching sources in their second half. One ``answer`` call is made per
    (question, source, tab) triple; the results are combined with
    ``zip_longest_fill`` and re-emitted after a 20 ms pause per step.

    Args:
        *questions_sources: Questions followed by their sources.
        llm: LLM client forwarded to ``answer``.
        chat_qa_prompts: Prompts forwarded to ``answer``.
        config: Application configuration; its ``"tabs"`` entry lists the tabs.

    Yields:
        A list with one single-turn history ``[(question, parsed_answer)]``
        per question.
    """
    midpoint = len(questions_sources) // 2
    questions = list(questions_sources[:midpoint])
    sources = list(questions_sources[midpoint:])

    answer_streams = [
        answer(llm, chat_qa_prompts, question, source, tab, config)
        for question, source, tab in zip(questions, sources, config["tabs"])
    ]

    for partial_answers in zip_longest_fill(*answer_streams):
        # Short pause between updates pushed to the UI.
        time.sleep(0.02)
        histories = []
        for question, partial in zip(questions, partial_answers):
            histories.append([(question, parse_output_llm_with_sources(partial))])
        yield histories
| |
|
| |
|
def get_synthesis(
    question,
    *answers,
    llm=llm,
    synthesis_prompt_template=synthesis_prompt_template,
    config=config,
):
    """Stream a synthesis of the per-tab answers as a chatbot history.

    This function is a generator. Every outcome, including the "no usable
    answer" fallback, is therefore *yielded*: a ``return <value>`` inside a
    generator is never seen by the consumer, so the fallback message would be
    silently dropped.

    Args:
        question: The question to display and to send to the synthesis prompt.
        *answers: One answer per tab, in the order of ``config["tabs"]``.
        llm: LLM client exposing ``stream(template, variables)``.
        synthesis_prompt_template: Prompt template used for the synthesis.
        config: Application configuration; its ``"tabs"`` entry lists the tabs.

    Yields:
        A single-turn history ``[(question, text)]`` where ``text`` is either
        the parsed streamed synthesis or the fallback message.
    """
    # Keep only answers whose string form has at least 100 characters
    # (presumably to drop empty or near-empty agent outputs - confirm the
    # threshold with the product owner). ``zip`` pairs each tab with its
    # answer and avoids an IndexError if fewer answers than tabs are passed.
    usable_answers = [
        f"{tab}\n{tab_answer}".replace("<p>", "").replace("</p>\n", "")
        for tab, tab_answer in zip(config["tabs"], answers)
        if len(str(tab_answer)) >= 100
    ]

    if not usable_answers:
        # Emit the fallback in the same shape as the streamed chunks so that
        # the chatbot actually displays it.
        yield [
            (
                question,
                "Aucune source n'a pu être identifiée pour répondre, veuillez modifier votre question",
            )
        ]
        return

    prompt_variables = {
        "question": question.replace("<p>", "").replace("</p>\n", ""),
        "answers": "\n\n".join(usable_answers),
    }
    for chunk in llm.stream(synthesis_prompt_template, prompt_variables):
        # Short pause between updates pushed to the UI.
        time.sleep(0.01)
        yield [(question, parse_output_llm_with_sources(chunk))]
| |
|
| |
|
# ---------------------------------------------------------------------------
# UI definition. Components are created in the order they appear below and
# the event chain at the end drives one full question/answer cycle.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="🔍 Spinoza",
    css=css,
    js=update_footer(),
    theme=theme,
) as demo:
    # Per-session state holders.
    chatbots = {}
    question = gr.State("")
    docs_textbox = gr.State([""])
    agent_questions = {name: gr.State("") for name in config["tabs"]}
    component_sources = {name: gr.State("") for name in config["tabs"]}
    text_sources = {name: gr.State("") for name in config["tabs"]}
    tab_states = {name: gr.State(name) for name in config["tabs"]}

    # ----------------------------- Q&A tab ---------------------------------
    with gr.Tab("Q&A", elem_id="main-component"):
        with gr.Row(elem_id="chatbot-row"):
            with gr.Column(scale=2, elem_id="center-panel"):
                with gr.Group(elem_id="chatbot-group"):
                    # One accordion + chatbot per configured tab, followed by
                    # the "Spinoza" synthesis agent.
                    for tab in [*config["tabs"].keys(), "Spinoza"]:
                        is_spinoza = tab == "Spinoza"
                        if not is_spinoza:
                            source_label = config["source_mapping"][tab]
                            agent_name = f"Agent {source_label}"
                            elem_id = f"accordion-{source_label}"
                            elem_classes = "accordion accordion-agent"
                        else:
                            agent_name = "Spinoza"
                            elem_id = f"accordion-{tab}"
                            elem_classes = "accordion accordion-agent spinoza-agent"

                        # Only the Spinoza accordion starts open.
                        with gr.Accordion(
                            agent_name,
                            open=is_spinoza,
                            elem_id=elem_id,
                            elem_classes=elem_classes,
                        ):
                            # Only Spinoza gets a welcome message and a bot
                            # avatar.
                            initial_history = (
                                [(None, init_prompt)] if is_spinoza else None
                            )
                            bot_avatar = (
                                "./assets/logos/spinoza.png" if is_spinoza else None
                            )
                            chatbots[tab] = gr.Chatbot(
                                value=initial_history,
                                show_copy_button=True,
                                show_share_button=False,
                                show_label=False,
                                elem_id=f"chatbot-{agent_name.lower().replace(' ', '-')}",
                                layout="panel",
                                avatar_images=(
                                    "./assets/logos/help.png",
                                    bot_avatar,
                                ),
                            )

                with gr.Row(elem_id="input-message"):
                    ask = gr.Textbox(
                        placeholder="Ask me anything here!",
                        show_label=False,
                        scale=7,
                        lines=1,
                        interactive=True,
                        elem_id="input-textbox",
                    )

            with gr.Column(scale=1, variant="panel", elem_id="right-panel"):
                with gr.TabItem("Sources", elem_id="tab-sources", id=0):
                    sources_textbox = gr.HTML(
                        show_label=False, elem_id="sources-textbox"
                    )

    # ----------------------- Source information tab ------------------------
    with gr.Tab("Source information", elem_id="source-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown(source_information)

    # ----------------------------- Contact tab -----------------------------
    with gr.Tab("Contact", elem_id="contact-component"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("For any issue contact **spinoza.support@ekimetrics.com**.")

    # ------------------------------ Event chain ----------------------------
    # Component lists reused by several steps, in ``config["tabs"]`` order.
    question_states = [agent_questions[tab] for tab in config["tabs"]]
    source_states = [text_sources[tab] for tab in config["tabs"]]
    agent_chatbots = [chatbots[tab] for tab in config["tabs"]]

    # NOTE(review): the synthesis step displays the reformulated question of
    # the *second* tab (index 1), which also requires at least two tabs in the
    # config - confirm this is intentional.
    synthesis_question = agent_questions[list(config["tabs"].keys())[1]]

    ask.submit(
        # 1. Reset the Spinoza chatbot and trigger the accordion animation.
        start_agents, inputs=[], outputs=[chatbots["Spinoza"]], js=accordion_trigger()
    ).then(
        # 2. Reformulate the question once per tab.
        fn=reformulate_questions,
        inputs=[ask],
        outputs=question_states,
    ).then(
        # 3. Retrieve sources for each reformulated question.
        fn=retrieve_sources,
        inputs=question_states,
        outputs=[sources_textbox] + source_states,
    ).then(
        # 4. Let every agent answer from its own sources.
        fn=answer_questions,
        inputs=question_states + source_states,
        outputs=agent_chatbots,
    ).then(
        # 5-6. Front-end only steps: run the accordion JavaScript hooks.
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_end()
    ).then(
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza()
    ).then(
        # 7. Synthesize the agents' answers in the Spinoza chatbot.
        fn=get_synthesis,
        inputs=[synthesis_question] + agent_chatbots,
        outputs=[chatbots["Spinoza"]],
    ).then(
        # 8. Final accordion hook, then end-of-run bookkeeping.
        fn=next_call, inputs=[], outputs=[], js=accordion_trigger_spinoza_end()
    ).then(
        fn=end_agents, inputs=[], outputs=[]
    )
| |
|
| |
|
if __name__ == "__main__":
    # Script entry point: turn on Gradio's queue (the handlers above are
    # generators that stream their output), then launch with debug enabled.
    queued_demo = demo.queue()
    queued_demo.launch(debug=True)
| |
|