Spaces:
Sleeping
Sleeping
edit gradio ui
Browse files
app.py
CHANGED
|
@@ -2,13 +2,13 @@ import gradio as gr
|
|
| 2 |
from inference import run_inference, reload_model # reload_model์ ๋ชจ๋ธ ์ฌ๋ก๋ฉ ํจ์
|
| 3 |
from utils_prompt import build_webtest_prompt
|
| 4 |
|
|
|
|
| 5 |
def gradio_infer(npc_id, npc_location, player_utt):
|
| 6 |
prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
|
| 7 |
result = run_inference(prompt)
|
| 8 |
return result["npc_output_text"], result["deltas"], result["flags_prob"]
|
| 9 |
|
| 10 |
-
|
| 11 |
-
# API ํธ์ถ์ฉ
|
| 12 |
def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
| 13 |
result = run_inference(prompt)
|
| 14 |
return {
|
|
@@ -20,7 +20,7 @@ def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
|
| 20 |
"thresholds": result["flags_thr"]
|
| 21 |
}
|
| 22 |
|
| 23 |
-
#
|
| 24 |
def ping_reload():
|
| 25 |
reload_model(branch="latest") # latest ๋ธ๋์น์์ ์ฌ๋ค์ด๋ก๋ & ๋ก๋
|
| 26 |
return {"status": "reloaded"}
|
|
@@ -36,10 +36,22 @@ with gr.Blocks() as demo:
|
|
| 36 |
deltas = gr.JSON(label="Deltas")
|
| 37 |
flags = gr.JSON(label="Flags Probabilities")
|
| 38 |
btn = gr.Button("Run Inference")
|
| 39 |
-
btn.click(fn=gradio_infer, inputs=[npc_id, npc_loc, player_utt], outputs=[npc_resp, deltas, flags])
|
| 40 |
|
| 41 |
-
|
| 42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
|
| 44 |
if __name__ == "__main__":
|
| 45 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|
|
|
|
| 2 |
from inference import run_inference, reload_model # reload_model์ ๋ชจ๋ธ ์ฌ๋ก๋ฉ ํจ์
|
| 3 |
from utils_prompt import build_webtest_prompt
|
| 4 |
|
| 5 |
+
# UI์์ ํธ์ถํ ํจ์
|
| 6 |
def gradio_infer(npc_id, npc_location, player_utt):
|
| 7 |
prompt = build_webtest_prompt(npc_id, npc_location, player_utt)
|
| 8 |
result = run_inference(prompt)
|
| 9 |
return result["npc_output_text"], result["deltas"], result["flags_prob"]
|
| 10 |
|
| 11 |
+
# API ํธ์ถ์ฉ ํจ์
|
|
|
|
| 12 |
def api_infer(session_id, npc_id, prompt, max_tokens=200):
|
| 13 |
result = run_inference(prompt)
|
| 14 |
return {
|
|
|
|
| 20 |
"thresholds": result["flags_thr"]
|
| 21 |
}
|
| 22 |
|
| 23 |
+
# ๋ชจ๋ธ ์ฌ๋ก๋ฉ์ฉ ํจ์
|
| 24 |
def ping_reload():
|
| 25 |
reload_model(branch="latest") # latest ๋ธ๋์น์์ ์ฌ๋ค์ด๋ก๋ & ๋ก๋
|
| 26 |
return {"status": "reloaded"}
|
|
|
|
| 36 |
deltas = gr.JSON(label="Deltas")
|
| 37 |
flags = gr.JSON(label="Flags Probabilities")
|
| 38 |
btn = gr.Button("Run Inference")
|
|
|
|
| 39 |
|
| 40 |
+
# UI ๋ฒํผ ํด๋ฆญ ์ API ์๋ํฌ์ธํธ๋ ์๋ ์์ฑ
|
| 41 |
+
btn.click(
|
| 42 |
+
fn=gradio_infer,
|
| 43 |
+
inputs=[npc_id, npc_loc, player_utt],
|
| 44 |
+
outputs=[npc_resp, deltas, flags],
|
| 45 |
+
api_name="predict_main" # /api/predict_main ์๋ํฌ์ธํธ ์์ฑ
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# ๋ณ๋์ UI ์์ด API๋ง ์ ๊ณตํ๋ ์๋ํฌ์ธํธ
|
| 49 |
+
gr.Button("Reload Model").click(
|
| 50 |
+
fn=ping_reload,
|
| 51 |
+
inputs=[],
|
| 52 |
+
outputs=[],
|
| 53 |
+
api_name="ping_reload" # /api/ping_reload ์๋ํฌ์ธํธ ์์ฑ
|
| 54 |
+
)
|
| 55 |
|
| 56 |
if __name__ == "__main__":
|
| 57 |
demo.launch(server_name="0.0.0.0", server_port=7860)
|