diabetesLLM / app.py
KS00Max's picture
updated
73fa6e6
import inspect
import logging
import uuid
from typing import List, Dict
import gradio as gr
from core.conversation import ConversationEngine
logging.basicConfig(level=logging.INFO)
# Build the conversation engine once at import time. A failure must not stop the
# UI from starting: `engine` then stays None and `engine_error` keeps the message
# so request handlers can report it instead of crashing.
engine = None
engine_error = None
try:
    engine = ConversationEngine()
except Exception as exc:  # noqa: BLE001
    engine_error = str(exc)
    logging.exception("Engine initialization failed: %s", exc)
def init_state():
    """Build a fresh per-session UI state dict.

    Returns:
        ``{"session_id": <id>, "pending": False}``. The id comes from the engine's
        session manager when the engine initialised successfully, otherwise it is
        a random UUID4 string. ``"pending"`` starts False; request handling flips
        it to True while a clarifying question awaits an answer.
    """
    # Explicit None check (matching the `engine is None` guard used elsewhere)
    # instead of truthiness, so an engine object that happens to be falsy still
    # gets to mint the session id.
    if engine is not None:
        session_id = engine.session_manager.new_session_id()
    else:
        session_id = str(uuid.uuid4())
    return {"session_id": session_id, "pending": False}
def _normalize_choice(choice_value: str | None) -> str | None:
if not choice_value:
return None
value = choice_value.strip()
return value.split(":")[0].strip() if ":" in value else value
def _extract_text(item) -> str:
"""Recursively extract plain text from any Gradio message format."""
if item is None:
return ""
if isinstance(item, str):
return item
if isinstance(item, list):
for elem in item:
text = _extract_text(elem)
if text:
return text
return ""
if isinstance(item, dict):
for key in ("content", "text", "value"):
if key in item:
text = _extract_text(item[key])
if text:
return text
return ""
if hasattr(item, "content"):
return _extract_text(getattr(item, "content"))
return str(item)
def _ensure_messages_format(history: List) -> List[Dict[str, str]]:
    """Normalize a chat history into messages format.

    Accepted entry shapes:
      * dicts that carry both ``"role"`` and ``"content"`` keys,
      * objects exposing ``role`` and ``content`` attributes (e.g. ChatMessage),
      * 2-element lists/tuples ``(user, assistant)`` (legacy tuples format).

    Content is flattened to plain text with ``_extract_text``. Entries whose text
    is empty are dropped, as are entries of any other shape.

    Returns:
        A new list of ``{"role": ..., "content": ...}`` dicts; ``[]`` for a falsy
        history.
    """
    normalized: List[Dict[str, str]] = []
    for entry in history or []:
        if isinstance(entry, dict) and "role" in entry and "content" in entry:
            # Already messages format: keep the role, flatten the content.
            text = _extract_text(entry.get("content"))
            if text:
                normalized.append({"role": str(entry["role"]), "content": text})
            continue
        if hasattr(entry, "role") and hasattr(entry, "content"):
            # ChatMessage-like object.
            text = _extract_text(entry.content)
            if text:
                normalized.append({"role": str(entry.role), "content": text})
            continue
        if isinstance(entry, (list, tuple)) and len(entry) == 2:
            # Legacy (user, assistant) pair; either side may be empty.
            user_text, bot_text = (_extract_text(part) for part in entry)
            for role, text in (("user", user_text), ("assistant", bot_text)):
                if text:
                    normalized.append({"role": role, "content": text})
    return normalized
def _append_message(history: List[Dict], role: str, content: str) -> List[Dict]:
"""Append a message in messages format."""
new_history = list(history)
if content:
new_history.append({"role": role, "content": content})
return new_history
def _append_turn(history: List[Dict], user_text: str, bot_text: str) -> List[Dict]:
    """Return *history* plus one user and one assistant message (empty texts skipped)."""
    history = _append_message(history, "user", user_text)
    return _append_message(history, "assistant", bot_text)


def respond(
    user_message: str,
    chat_history: List,
    app_state: dict,
    choice_value: str | None,
):
    """Handle one chat turn: a free-text question or a clarifying-choice answer.

    Args:
        user_message: Raw textbox contents; surrounding whitespace is ignored.
        chat_history: Chatbot value in any format ``_ensure_messages_format`` accepts.
        app_state: Per-session dict with ``"session_id"`` and ``"pending"``; a fresh
            one is built via ``init_state()`` when it is falsy.
        choice_value: Selected radio option (``"<id>: <text>"``) or None.

    Returns:
        ``(history, state, "", radio_update, citations, trace)``: the messages-format
        history, the updated state, an empty string to clear the textbox, a
        ``gr.update`` for the clarifying radio, the answer's citations (``[]`` unless
        an answer was produced) and the engine trace dict.
    """
    state = app_state or init_state()
    session_id = state.get("session_id") or init_state()["session_id"]
    state["session_id"] = session_id
    # Normalize history to clean messages format
    history = _ensure_messages_format(chat_history)
    # Whitespace-only input (e.g. a stray newline from the 2-line textbox) counts
    # as empty instead of being sent to the engine as a question.
    message = (user_message or "").strip()
    hidden_choices = gr.update(choices=[], value=None, visible=False)

    if engine is None:
        reply = f"初期化エラー: {engine_error}. OPENAI_API_KEY を確認してください。"
        history = _append_turn(history, message or "(入力なし)", reply)
        return history, state, "", hidden_choices, [], {}

    citations = []
    trace = {}
    if state.get("pending", False):
        # Handle clarifying question response; the radio selection wins over
        # typed text.
        choice_id = _normalize_choice(choice_value) or _normalize_choice(message)
        if not choice_id:
            # Nothing chosen yet: keep the radio visible and wait.
            return history, state, "", gr.update(visible=True), citations, trace
        result = engine.handle_clarifying_answer(session_id, choice_id)
        user_bubble = f"Clarifying回答: {choice_id}"
    else:
        if not message:
            warn = "質問を入力してください。"
            history = _append_turn(history, "(入力なし)", warn)
            return history, state, "", hidden_choices, citations, trace
        result = engine.handle_user_message(session_id, message)
        user_bubble = message

    result_type = result.get("type")
    if result_type == "clarify":
        state["pending"] = True
        options = [f"{c.id}: {c.text}" for c in result["question"].choices]
        trace = result.get("trace", {})
        history = _append_turn(history, user_bubble, result["reply"])
        return (
            history,
            state,
            "",
            gr.update(choices=options, value=None, visible=True),
            citations,
            trace,
        )
    if result_type == "answer":
        state["pending"] = False
        citations = result.get("citations", [])
        trace = result.get("trace", {})
        history = _append_turn(history, user_bubble, result["reply"])
        return history, state, "", hidden_choices, citations, trace
    # Error fallback: any result type other than "clarify"/"answer".
    history = _append_turn(history, user_bubble, result.get("reply", "エラーが発生しました。"))
    state["pending"] = False
    trace = result.get("trace", {})
    return history, state, "", hidden_choices, citations, trace
def reset():
    """Start a new session and clear every output component.

    Returns the 6-tuple matching the handler outputs: empty chat, fresh state
    from ``init_state()``, empty textbox, hidden clarifying radio, no sources,
    empty trace.
    """
    return (
        [],
        init_state(),
        "",
        gr.update(choices=[], value=None, visible=False),
        [],
        {},
    )
def _chatbot_kwargs() -> dict:
    """Keyword arguments for the main chat window, requesting messages format when selectable."""
    kwargs = {"label": "対話", "height": 520}
    try:
        params = inspect.signature(gr.Chatbot.__init__).parameters
    except (TypeError, ValueError):
        return kwargs
    # respond() produces {"role", "content"} dicts. Per the Gradio docs, Chatbot
    # versions that still have a `type` option default to "tuples", which rejects
    # those dicts, so opt into "messages" explicitly. Versions without the option
    # only speak messages format.
    if "type" in params:
        kwargs["type"] = "messages"
    return kwargs


with gr.Blocks() as demo:
    gr.HTML(
        """
<style>
:root { --font-main: "Noto Sans JP", "Segoe UI", "Helvetica Neue", Arial, sans-serif; }
.gradio-container { font-family: var(--font-main); }
#title-card { background: linear-gradient(135deg, #0f766e, #0ea5e9); color: #f8fafc; padding: 16px 18px; border-radius: 12px; margin-bottom: 8px; }
#title-card h1 { margin: 0; font-size: 1.4rem; }
#title-card p { margin: 4px 0 0 0; opacity: 0.9; }
</style>
<div id="title-card">
<h1>糖尿病ガイドライン RAG アシスタント</h1>
<p>シックデイ、低血糖、足病変などの場面別 Q&A を日本語で案内します。診断・治療を決める用途ではなく、受診判断やセルフケアの補助として利用してください。</p>
</div>
"""
    )
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(**_chatbot_kwargs())
            user_input = gr.Textbox(
                label="質問を入力",
                placeholder="例: 熱があって食欲がないときインスリンはどうする?",
                lines=2,
            )
            choice_radio = gr.Radio(label="Clarifying 選択肢", choices=[], visible=False)
            with gr.Row():
                send_btn = gr.Button("送信", variant="primary")
                reset_btn = gr.Button("新しいセッション", variant="secondary")
        with gr.Column(scale=2):
            gr.Markdown(
                "**注意**: 個人の診断や治療方針は必ず主治医に確認してください。"
                "救急が疑われる場合はただちに医療機関へ。"
            )
            sources = gr.JSON(label="根拠になった節・ページ", value=[])
            trace_view = gr.JSON(label="推論トレース", value={})
    # Start empty: respond() builds a fresh state (and session id) on the first
    # request of each browser session. gr.State(init_state()) evaluated
    # init_state() once at import, so every visitor got a copy of the same
    # session_id and shared one engine session.
    state = gr.State(value=None)
    respond_io = {
        "inputs": [user_input, chatbot, state, choice_radio],
        "outputs": [chatbot, state, user_input, choice_radio, sources, trace_view],
    }
    send_btn.click(respond, **respond_io)
    user_input.submit(respond, **respond_io)
    # `.input` fires only when the user picks an option. `.change` also fired when
    # respond() itself reset the radio value, re-running respond() with an empty
    # textbox and appending a spurious "質問を入力してください。" exchange.
    choice_radio.input(respond, **respond_io)
    reset_btn.click(reset, outputs=respond_io["outputs"])
if __name__ == "__main__":
    launch_opts = {}
    # Only pass ssr_mode=False when the installed Gradio's launch() accepts it,
    # so versions without that parameter still start with their defaults.
    try:
        launch_params = inspect.signature(gr.Blocks.launch).parameters
    except (TypeError, ValueError):
        # inspect.signature raises these when no signature can be determined;
        # launching with defaults beats failing to start.
        logging.debug("Could not inspect gr.Blocks.launch; using default launch options")
    else:
        if "ssr_mode" in launch_params:
            launch_opts["ssr_mode"] = False
    demo.queue().launch(**launch_opts)