# Alert_Demo / app.py
# Uploaded by yeelou ("Upload 6 files", commit b281bf6, verified).
# -*- coding: utf-8 -*-
"""risk_demo.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/10O8RqzRNTUw5fZd-V7dCvS22oAFkTc1i
"""
import json
import random
import time
import gradio as gr
from prompts import load_dict, save_dict
from report import alert, company_analysis, risk, summary, title, translate
# Maps the short model labels shown in the UI radio button to the
# OpenAI model identifiers passed to the report helpers.
GPT_MODEL_DICT = dict(
    [
        ("4o-mini", "gpt-4o-mini"),
        ("4o", "gpt-4o"),
    ]
)
def gen_report(text, gpt, risk_dict, company_info, progress=gr.Progress()):
    """Run the GPT pipeline on one news article and render the report.

    Steps, in order: company_analysis -> risk -> (title, summary, alert)
    -> translate. The translated fields are formatted with the
    ``report_msg`` template from the prompt dictionary.

    Args:
        text: News article text entered by the user.
        gpt: Key into ``GPT_MODEL_DICT`` (e.g. "4o-mini" or "4o").
        risk_dict: Risk-list text from the UI. It is written back to the
            prompt store via ``save_dict`` when it differs from the stored
            ``risk_dict``.
        company_info: Free-text company information.
        progress: Gradio progress tracker (injected by Gradio).

    Returns:
        The formatted report string, or "### ノーリスク" when the risk step
        does not answer "yes".
    """
    prompt_dict = load_dict()
    # Persist the risk list edited in the UI; skip the write when unchanged.
    if risk_dict != prompt_dict["risk_dict"]:
        prompt_dict["risk_dict"] = risk_dict
        save_dict(prompt_dict)

    print("time:", time.ctime(time.time()))
    gpt_model = GPT_MODEL_DICT[gpt]
    print("GPT:", gpt)
    print("input:", text)
    print("company_info:", company_info)
    print("risk_dict:", risk_dict)
    progress(0, desc="Starting")

    # Random 1-based orderings of the "risk_ex_t" / "risk_ex_f" example sets
    # (presumably positive / negative few-shot examples -- TODO confirm).
    # These orderings are passed to the GPT helpers below.
    ex_t_rnd = list(range(1, len(prompt_dict["risk_ex_t"]) + 1))
    ex_f_rnd = list(range(1, len(prompt_dict["risk_ex_f"]) + 1))
    random.shuffle(ex_t_rnd)
    random.shuffle(ex_f_rnd)
    print("seed:", ex_t_rnd, ex_f_rnd)

    # Each helper's response is indexed below as
    # (result, completion token count, prompt token count).
    progress(0.05, desc="company analysis")
    com_response = company_analysis(ex_t_rnd, risk_dict, company_info, gpt_model)
    progress(0.2, desc="risk")
    risk_response = risk(
        text, company_info, com_response[0], ex_t_rnd, ex_f_rnd, gpt_model
    )
    risk_result = risk_response[0]
    if risk_result["risk"] != "yes":
        # No risk detected: log the risk step only and stop early.
        print(risk_result)
        print("completion_tokens_num:", risk_response[1])
        print("prompt_tokens_num:", risk_response[2])
        print("*" * 20)
        return "### ノーリスク"
    risk_key = risk_result["risk_key"]

    progress(0.3, desc="title")
    title_response = title(text, company_info, ex_t_rnd, risk_key, gpt_model)
    progress(0.5, desc="summary")
    summary_response = summary(text, ex_t_rnd, gpt_model)
    progress(0.7, desc="alert")
    alert_response = alert(text, company_info, ex_t_rnd, risk_key, gpt_model)

    res_dict = {
        "title": title_response[0]["title"],
        "summary": summary_response[0]["summary"],
        "alert": alert_response[0]["alert"],
        "news": text,
        "risk_key": risk_key,
        "reason": risk_result["reason"],
        "company_risk": json.dumps(com_response[0], ensure_ascii=False),
    }
    progress(0.85, desc="check")
    translate_response = translate(res_dict, gpt_model)
    progress(0.90, desc="over")
    translated = translate_response[0]
    res_msg = prompt_dict["report_msg"].format(
        title=translated["title"],
        summary=translated["summary"],
        alert=translated["alert"],
        news=translated["news"],
        risk=translated["risk_key"],
        reason=translated["reason"],
        company_risk=translated["company_risk"],
    )
    print(translated)

    # Token usage of every GPT call made for this request, in call order.
    responses = (
        com_response,
        risk_response,
        title_response,
        summary_response,
        alert_response,
        translate_response,
    )
    print("completion_tokens_num:", [r[1] for r in responses])
    print("prompt_tokens_num:", [r[2] for r in responses])
    print("*" * 20)
    return res_msg
def example_f(input, choice, risk, info):
    """Render the stored report for one of the built-in sample articles.

    Used as ``fn`` of ``gr.Examples``. The report is built from JSON
    strings stored under the matching ``prompt_dict["risk_ex_t"]`` entry;
    no GPT call is made.

    Args:
        input: News text. It must equal the ``news`` field of ex1, ex2 or
            ex3 in ``prompt_dict["risk_ex_t"]``. The name shadows the
            builtin but is kept for interface compatibility.
        choice: Unused; passed because gr.Examples supplies every input.
        risk: Unused; passed because gr.Examples supplies every input.
        info: Unused; passed because gr.Examples supplies every input.

    Returns:
        The ``report_msg`` template filled from the stored example.

    Raises:
        ValueError: If ``input`` matches none of the sample articles.
    """
    prompt_dict = load_dict()
    samples = prompt_dict["risk_ex_t"]
    # First sample (checked in ex1, ex2, ex3 order) whose news text matches.
    res = next(
        (samples[key] for key in ("ex1", "ex2", "ex3") if input == samples[key]["news"]),
        None,
    )
    if res is None:
        print("error")
        raise ValueError("input does not match any sample news article (ex1-ex3)")

    # The stored "risk" field is JSON holding both "risk_key" and "reason".
    risk_info = json.loads(res["risk"])
    res_msg = prompt_dict["report_msg"].format(
        title=json.loads(res["title"])["title"],
        summary=json.loads(res["summary"])["summary"],
        alert=json.loads(res["alert"])["alert"],
        news=res["news"],
        risk=risk_info["risk_key"],
        reason=risk_info["reason"],
        company_risk=res["risk_list"],
    )
    return res_msg
# Gradio UI: inputs on the left, generated report on the right, plus
# sample examples whose outputs are produced by example_f.
with gr.Blocks(title="アラート生成POC", theme="bethecloud/storj_theme") as demo:
    gr.Markdown("# アラート生成POC")
    gr.Markdown("GPTを通じてアラートメッセージを生成する")
    # Prompt dictionary: supplies the default risk list and the sample data.
    prompt_dict = load_dict()
    with gr.Row():
        with gr.Column():
            # Choices mirror GPT_MODEL_DICT so gen_report can always map them.
            choice = gr.Radio(
                choices=list(GPT_MODEL_DICT), value="4o-mini", label="GPTモデル"
            )
            # Named news_input (not `input`) to avoid shadowing the builtin.
            news_input = gr.Textbox(label="ニュース", lines=7)
            risk_dict = gr.Textbox(
                label="リスクリスト", lines=10, value=prompt_dict["risk_dict"]
            )
            company_info = gr.Textbox(label="会社情報", lines=10)
        with gr.Column():
            gr.Markdown("アウトプット")
            output = gr.Markdown(label="レポート")
    gen_btn = gr.Button("生成")
    gr.ClearButton([news_input, output, company_info], value="クリア")
    gr.Examples(
        # One row per sample article: [news, model, risk list, company info].
        examples=[
            [
                prompt_dict["risk_ex_t"][key]["news"],
                "4o-mini",
                prompt_dict["company_risk_list"],
                prompt_dict["risk_ex_t"][key]["company_info"],
            ]
            for key in ("ex1", "ex2", "ex3")
        ],
        inputs=[news_input, choice, risk_dict, company_info],
        outputs=[output],
        fn=example_f,
        cache_examples=True,
        label="サンプルデータ",
    )
    gen_btn.click(
        fn=gen_report,
        inputs=[news_input, choice, risk_dict, company_info],
        outputs=output,
    )
demo.launch(inline=False, share=True, debug=True)