Spaces:
Sleeping
Sleeping
| # -*- coding: utf-8 -*- | |
| """risk_demo.ipynb | |
| Automatically generated by Colaboratory. | |
| Original file is located at | |
| https://colab.research.google.com/drive/10O8RqzRNTUw5fZd-V7dCvS22oAFkTc1i | |
| """ | |
| import json | |
| import random | |
| import time | |
| import gradio as gr | |
| from prompts import load_dict, save_dict | |
| from report import alert, company_analysis, risk, summary, title, translate | |
# Maps the model alias shown in the UI radio buttons to the model identifier
# that gen_report forwards to the report helpers.
GPT_MODEL_DICT = {
    "4o-mini": "gpt-4o-mini",  # default selection in the UI
    "4o": "gpt-4o",
}
def _print_token_usage(responses):
    """Print completion and prompt token counts for a sequence of LLM calls.

    Each element of *responses* is read as ``response[1]`` (completion
    tokens) and ``response[2]`` (prompt tokens), matching how gen_report
    indexes the tuples returned by the ``report`` helpers.
    """
    print("completion_tokens_num:", [response[1] for response in responses])
    print("prompt_tokens_num:", [response[2] for response in responses])
    print("*" * 20)


def gen_report(text, gpt, risk_dict, company_info, progress=gr.Progress()):
    """Generate a risk-alert report for a news article.

    Pipeline: company analysis -> risk classification -> (only when a risk is
    detected) title, summary and alert generation -> translation -> rendering
    through the ``report_msg`` template of the prompt dictionary.

    Args:
        text: News article text entered by the user.
        gpt: Model alias; must be a key of ``GPT_MODEL_DICT``.
        risk_dict: Risk-list text from the UI. When it differs from the stored
            ``risk_dict`` entry it is written back with ``save_dict``.
        company_info: Free-text company information from the UI.
        progress: Gradio progress tracker. The ``gr.Progress()`` default is
            how Gradio recognises the parameter; do not change it.

    Returns:
        Markdown string: the rendered report, or ``"### ノーリスク"`` when the
        risk step does not answer ``"yes"``.

    Raises:
        KeyError: if ``gpt`` is not a key of ``GPT_MODEL_DICT``.
    """
    prompt_dict = load_dict()
    # Persist an edited risk list; skip the write when nothing changed.
    if risk_dict != prompt_dict["risk_dict"]:
        prompt_dict["risk_dict"] = risk_dict
        save_dict(prompt_dict)

    print("time:", time.ctime(time.time()))
    gpt_model = GPT_MODEL_DICT[gpt]
    print("GPT:", gpt)
    print("input:", text)
    print("company_info:", company_info)
    print("risk_dict:", risk_dict)

    progress(0, desc="Starting")
    # Shuffled 1-based indices over the stored positive/negative examples.
    # They are handed to the report helpers, presumably to vary the few-shot
    # example order per request -- confirm in report.py.
    ex_t_rnd = list(range(1, len(prompt_dict["risk_ex_t"]) + 1))
    ex_f_rnd = list(range(1, len(prompt_dict["risk_ex_f"]) + 1))
    random.shuffle(ex_t_rnd)
    random.shuffle(ex_f_rnd)
    print("seed:", ex_t_rnd, ex_f_rnd)

    progress(0.05, desc="company analysis")
    com_response = company_analysis(ex_t_rnd, risk_dict, company_info, gpt_model)
    progress(0.2, desc="risk")
    risk_response = risk(
        text, company_info, com_response[0], ex_t_rnd, ex_f_rnd, gpt_model
    )
    risk_result = risk_response[0]
    if risk_result["risk"] != "yes":
        # No risk detected: skip the remaining LLM calls entirely.
        print(risk_result)
        print("completion_tokens_num:", risk_response[1])
        print("prompt_tokens_num:", risk_response[2])
        print("*" * 20)
        return "### ノーリスク"
    risk_key = risk_result["risk_key"]

    progress(0.3, desc="title")
    title_response = title(text, company_info, ex_t_rnd, risk_key, gpt_model)
    progress(0.5, desc="summary")
    summary_response = summary(text, ex_t_rnd, gpt_model)
    progress(0.7, desc="alert")
    alert_response = alert(text, company_info, ex_t_rnd, risk_key, gpt_model)

    res_dict = {
        "title": title_response[0]["title"],
        "summary": summary_response[0]["summary"],
        "alert": alert_response[0]["alert"],
        "news": text,
        "risk_key": risk_key,
        "reason": risk_result["reason"],
        "company_risk": json.dumps(com_response[0], ensure_ascii=False),
    }

    # Translate every report field in a single call, then render the template.
    progress(0.85, desc="check")
    translate_response = translate(res_dict, gpt_model)
    progress(0.90, desc="over")
    translated = translate_response[0]
    res_msg = prompt_dict["report_msg"].format(
        title=translated["title"],
        summary=translated["summary"],
        alert=translated["alert"],
        news=translated["news"],
        risk=translated["risk_key"],
        reason=translated["reason"],
        company_risk=translated["company_risk"],
    )
    print(translated)
    _print_token_usage(
        [
            com_response,
            risk_response,
            title_response,
            summary_response,
            alert_response,
            translate_response,
        ]
    )
    return res_msg
def example_f(input, choice, risk, info):
    """Render the stored report for one of the bundled sample articles.

    Serves as the ``fn`` of ``gr.Examples``. It makes no LLM call; it formats
    the pre-computed fields of the matching entry in
    ``prompt_dict["risk_ex_t"]`` through the ``report_msg`` template.

    Args:
        input: News text of the selected example. It is matched against the
            ``news`` field of each stored example.
        choice: Model alias from the example row (unused).
        risk: Risk-list text from the example row (unused; shadows the
            imported ``risk`` function only inside this function).
        info: Company info from the example row (unused).
            The unused parameters stay because Gradio passes one positional
            argument per input component.

    Returns:
        The stored example rendered through ``prompt_dict["report_msg"]``.

    Raises:
        ValueError: if ``input`` matches no stored example.
    """
    prompt_dict = load_dict()
    res = next(
        (
            example
            for example in prompt_dict["risk_ex_t"].values()
            if example.get("news") == input
        ),
        None,
    )
    if res is None:
        print("error")
        raise ValueError("example_f: input does not match any stored example news")

    # title/summary/alert/risk are stored as JSON strings; "risk" carries both
    # the risk key and the reason, so parse it once.
    risk_result = json.loads(res["risk"])
    return prompt_dict["report_msg"].format(
        title=json.loads(res["title"])["title"],
        summary=json.loads(res["summary"])["summary"],
        alert=json.loads(res["alert"])["alert"],
        news=res["news"],
        risk=risk_result["risk_key"],
        reason=risk_result["reason"],
        company_risk=res["risk_list"],
    )
# Build the Gradio UI: input controls, the rendered report, cached sample
# rows, and a generate button wired to gen_report.
# NOTE(review): this code was recovered from a copy with its indentation
# stripped; which components sit inside gr.Row / gr.Column below is a
# reconstruction -- confirm against the deployed layout.
with gr.Blocks(title="アラート生成POC", theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# アラート生成POC")
    gr.Markdown("GPTを通じてアラートメッセージを生成する")
    # Load the stored prompt dictionary: it supplies the default risk list
    # and the sample rows below.
    prompt_dict = load_dict()
    with gr.Row():
        with gr.Column():
            # Left column: model choice plus the three text inputs that
            # gen_report reads.
            choice = gr.Radio(
                choices=["4o-mini", "4o"], value="4o-mini", label="GPTモデル"
            )
            # NOTE(review): this module-level name shadows the built-in input().
            input = gr.Textbox(label="ニュース", lines=7)
            risk_dict = gr.Textbox(
                label="リスクリスト", lines=10, value=prompt_dict["risk_dict"]
            )
            company_info = gr.Textbox(label="会社情報", lines=10)
        with gr.Column():
            # Right column: Markdown area that receives the generated report.
            gr.Markdown("アウトプット")
            output = gr.Markdown(label="レポート")
    gen_btn = gr.Button("生成")
    # Clears the news, report and company-info fields; the model choice and
    # the risk-list textbox are not in the list, so they keep their values.
    gr.ClearButton([input, output, company_info], value="クリア")
    # Sample rows, in the same order as `inputs`. With cache_examples=True,
    # Gradio runs example_f on each row ahead of time and shows the cached
    # output when a row is clicked.
    # NOTE(review): the risk-list column uses prompt_dict["company_risk_list"]
    # while the textbox default is prompt_dict["risk_dict"] -- confirm intended.
    gr.Examples(
        examples=[
            [
                prompt_dict["risk_ex_t"]["ex1"]["news"],
                "4o-mini",
                prompt_dict["company_risk_list"],
                prompt_dict["risk_ex_t"]["ex1"]["company_info"],
            ],
            [
                prompt_dict["risk_ex_t"]["ex2"]["news"],
                "4o-mini",
                prompt_dict["company_risk_list"],
                prompt_dict["risk_ex_t"]["ex2"]["company_info"],
            ],
            [
                prompt_dict["risk_ex_t"]["ex3"]["news"],
                "4o-mini",
                prompt_dict["company_risk_list"],
                prompt_dict["risk_ex_t"]["ex3"]["company_info"],
            ],
        ],
        inputs=[input, choice, risk_dict, company_info],
        outputs=[output],
        fn=example_f,
        cache_examples=True,
        label="サンプルデータ",
    )
    # Generate: run the full LLM pipeline on the current input values.
    gen_btn.click(
        fn=gen_report, inputs=[input, choice, risk_dict, company_info], outputs=output
    )
# Launches at import time (no __main__ guard). share=True requests a public
# link; per Gradio's launch() docs, debug=True blocks the main thread (needed
# in Colab to surface errors in the cell output).
demo.launch(inline=False, share=True, debug=True)
# For a plain local launch without a share link, use: demo.launch()