KS00Max commited on
Commit
0aeb8c8
ยท
1 Parent(s): 706cf7c

fixed for demo

Browse files
Files changed (3) hide show
  1. app.py +150 -27
  2. core/conversation.py +23 -6
  3. core/state.py +1 -0
app.py CHANGED
@@ -1,3 +1,4 @@
 
1
  import logging
2
  import uuid
3
  from typing import List
@@ -8,6 +9,9 @@ from core.conversation import ConversationEngine
8
 
9
  logging.basicConfig(level=logging.INFO)
10
 
 
 
 
11
  try:
12
  engine = ConversationEngine()
13
  engine_error = None
@@ -19,7 +23,7 @@ except Exception as exc: # noqa: BLE001
19
 
20
  def init_state():
21
  sid = engine.session_manager.new_session_id() if engine else str(uuid.uuid4())
22
- return {"session_id": sid, "pending": False}
23
 
24
 
25
  def _normalize_choice(choice_value: str | None) -> str | None:
@@ -29,33 +33,55 @@ def _normalize_choice(choice_value: str | None) -> str | None:
29
  return value.split(":")[0].strip() if ":" in value else value
30
 
31
 
32
- def _ensure_message_history(history: List | None) -> list[dict[str, str]]:
33
- """Convert legacy tuple chat history into Gradio message format."""
34
  if not history:
35
  return []
36
 
37
- messages: list[dict[str, str]] = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  for item in history:
39
- if isinstance(item, dict) and "role" in item and "content" in item:
40
- messages.append({"role": str(item["role"]), "content": str(item["content"])})
41
- elif isinstance(item, (list, tuple)) and len(item) == 2:
42
- user, assistant = item
43
- if user is not None:
44
- messages.append({"role": "user", "content": str(user)})
45
- if assistant is not None:
46
- messages.append({"role": "assistant", "content": str(assistant)})
47
- return messages
48
-
49
-
50
- def _append_exchange(
51
- history: list[dict[str, str]], user_text: str | None, assistant_text: str | None
52
- ) -> list[dict[str, str]]:
53
- updated = list(history)
54
- if user_text is not None:
55
- updated.append({"role": "user", "content": user_text})
56
- if assistant_text is not None:
57
- updated.append({"role": "assistant", "content": assistant_text})
58
- return updated
 
 
 
 
 
 
 
59
 
60
 
61
  def respond(
@@ -68,7 +94,14 @@ def respond(
68
  session_id = state.get("session_id") or init_state()["session_id"]
69
  state["session_id"] = session_id
70
 
71
- history = _ensure_message_history(chat_history)
 
 
 
 
 
 
 
72
 
73
  if engine is None:
74
  reply = f"ๅˆๆœŸๅŒ–ใ‚จใƒฉใƒผ: {engine_error}. OPENAI_API_KEY ใ‚’็ขบ่ชใ—ใฆใใ ใ•ใ„ใ€‚"
@@ -95,8 +128,23 @@ def respond(
95
  result = engine.handle_user_message(session_id, user_message)
96
  user_bubble = user_message
97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98
  if result.get("type") == "clarify":
99
  state["pending"] = True
 
100
  options = [f"{c.id}: {c.text}" for c in result["question"].choices]
101
  trace = result.get("trace", {})
102
  history = _append_exchange(history, user_bubble, result["reply"])
@@ -111,6 +159,7 @@ def respond(
111
 
112
  if result.get("type") == "answer":
113
  state["pending"] = False
 
114
  citations = result.get("citations", [])
115
  trace = result.get("trace", {})
116
  history = _append_exchange(history, user_bubble, result["reply"])
@@ -126,6 +175,7 @@ def respond(
126
  # error fallback
127
  history = _append_exchange(history, user_bubble, result.get("reply", "ใ‚จใƒฉใƒผใŒ็™บ็”Ÿใ—ใพใ—ใŸใ€‚"))
128
  state["pending"] = False
 
129
  trace = result.get("trace", {})
130
  return history, state, "", gr.update(choices=[], value=None, visible=False), citations, trace
131
 
@@ -135,6 +185,62 @@ def reset():
135
  return [], new_state, "", gr.update(choices=[], value=None, visible=False), [], {}
136
 
137
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  with gr.Blocks() as demo:
139
  gr.HTML(
140
  """
@@ -153,7 +259,10 @@ with gr.Blocks() as demo:
153
  )
154
  with gr.Row():
155
  with gr.Column(scale=3):
156
- chatbot = gr.Chatbot(label="ๅฏพ่ฉฑ", height=520)
 
 
 
157
  user_input = gr.Textbox(
158
  label="่ณชๅ•ใ‚’ๅ…ฅๅŠ›",
159
  placeholder="ไพ‹: ็†ฑใŒใ‚ใฃใฆ้ฃŸๆฌฒใŒใชใ„ใจใใ‚คใƒณใ‚นใƒชใƒณใฏใฉใ†ใ™ใ‚‹๏ผŸ",
@@ -162,6 +271,7 @@ with gr.Blocks() as demo:
162
  choice_radio = gr.Radio(label="Clarifying ้ธๆŠž่‚ข", choices=[], visible=False)
163
  with gr.Row():
164
  send_btn = gr.Button("้€ไฟก", variant="primary")
 
165
  reset_btn = gr.Button("ๆ–ฐใ—ใ„ใ‚ปใƒƒใ‚ทใƒงใƒณ", variant="secondary")
166
  with gr.Column(scale=2):
167
  gr.Markdown(
@@ -183,8 +293,21 @@ with gr.Blocks() as demo:
183
  inputs=[user_input, chatbot, state, choice_radio],
184
  outputs=[chatbot, state, user_input, choice_radio, sources, trace_view],
185
  )
 
 
 
 
 
186
  reset_btn.click(reset, outputs=[chatbot, state, user_input, choice_radio, sources, trace_view])
187
 
188
 
189
  if __name__ == "__main__":
190
- demo.queue().launch()
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
  import logging
3
  import uuid
4
  from typing import List
 
9
 
10
  logging.basicConfig(level=logging.INFO)
11
 
12
+ # Gradio 3.x and 4.x differ in Chatbot payload formats.
13
+ _CHATBOT_SUPPORTS_MESSAGES = "type" in inspect.signature(gr.Chatbot.__init__).parameters
14
+
15
  try:
16
  engine = ConversationEngine()
17
  engine_error = None
 
23
 
24
  def init_state():
25
  sid = engine.session_manager.new_session_id() if engine else str(uuid.uuid4())
26
+ return {"session_id": sid, "pending": False, "triage_hold": False}
27
 
28
 
29
  def _normalize_choice(choice_value: str | None) -> str | None:
 
33
  return value.split(":")[0].strip() if ":" in value else value
34
 
35
 
36
+ def _ensure_chat_history(history: List | None):
37
+ """Normalize chat history for the running Gradio Chatbot."""
38
  if not history:
39
  return []
40
 
41
+ if _CHATBOT_SUPPORTS_MESSAGES:
42
+ messages: list[dict[str, str]] = []
43
+ for item in history:
44
+ if isinstance(item, dict) and "role" in item and "content" in item:
45
+ messages.append({"role": str(item["role"]), "content": str(item["content"])})
46
+ elif isinstance(item, (list, tuple)) and len(item) == 2:
47
+ user, assistant = item
48
+ if user is not None:
49
+ messages.append({"role": "user", "content": str(user)})
50
+ if assistant is not None:
51
+ messages.append({"role": "assistant", "content": str(assistant)})
52
+ return messages
53
+
54
+ # Gradio 3.x tuple format
55
+ pairs: list[tuple[str | None, str | None]] = []
56
+ pending_user: str | None = None
57
  for item in history:
58
+ if isinstance(item, (list, tuple)) and len(item) == 2:
59
+ pairs.append((item[0], item[1]))
60
+ elif isinstance(item, dict) and "role" in item and "content" in item:
61
+ role = item["role"]
62
+ if role == "user":
63
+ pending_user = str(item["content"])
64
+ elif role == "assistant":
65
+ pairs.append((pending_user, str(item["content"])))
66
+ pending_user = None
67
+ if pending_user is not None:
68
+ pairs.append((pending_user, None))
69
+ return pairs
70
+
71
+
72
+ def _append_exchange(history, user_text: str | None, assistant_text: str | None):
73
+ """Append one exchange in the correct Chatbot format."""
74
+ if _CHATBOT_SUPPORTS_MESSAGES:
75
+ updated = list(history)
76
+ if user_text is not None:
77
+ updated.append({"role": "user", "content": user_text})
78
+ if assistant_text is not None:
79
+ updated.append({"role": "assistant", "content": assistant_text})
80
+ return updated
81
+
82
+ updated_pairs = list(history)
83
+ updated_pairs.append((user_text, assistant_text))
84
+ return updated_pairs
85
 
86
 
87
  def respond(
 
94
  session_id = state.get("session_id") or init_state()["session_id"]
95
  state["session_id"] = session_id
96
 
97
+ history = _ensure_chat_history(chat_history)
98
+ triage_hold = bool(state.get("triage_hold", False))
99
+
100
+ if triage_hold:
101
+ warn = "็ทŠๆ€ฅๆ€งใฎ่ญฆๅ‘ŠใŒ่กจ็คบใ•ใ‚Œใฆใ„ใพใ™ใ€‚ใ€Œ่ญฆๅ‘Šใ‚’ๆ‰ฟ็Ÿฅใ—ใฆ็ถšใ‘ใ‚‹ใ€ใ‚’ๆŠผใ™ใจ็ถš่กŒใ—ใพใ™ใ€‚"
102
+ if user_message:
103
+ history = _append_exchange(history, user_message, warn)
104
+ return history, state, "", gr.update(choices=[], value=None, visible=False), [], {}
105
 
106
  if engine is None:
107
  reply = f"ๅˆๆœŸๅŒ–ใ‚จใƒฉใƒผ: {engine_error}. OPENAI_API_KEY ใ‚’็ขบ่ชใ—ใฆใใ ใ•ใ„ใ€‚"
 
128
  result = engine.handle_user_message(session_id, user_message)
129
  user_bubble = user_message
130
 
131
+ if result.get("type") == "triage_warning":
132
+ state["triage_hold"] = True
133
+ state["pending"] = False
134
+ trace = result.get("trace", {})
135
+ history = _append_exchange(history, user_bubble, result["reply"])
136
+ return (
137
+ history,
138
+ state,
139
+ "",
140
+ gr.update(choices=[], value=None, visible=False),
141
+ [],
142
+ trace,
143
+ )
144
+
145
  if result.get("type") == "clarify":
146
  state["pending"] = True
147
+ state["triage_hold"] = False
148
  options = [f"{c.id}: {c.text}" for c in result["question"].choices]
149
  trace = result.get("trace", {})
150
  history = _append_exchange(history, user_bubble, result["reply"])
 
159
 
160
  if result.get("type") == "answer":
161
  state["pending"] = False
162
+ state["triage_hold"] = False
163
  citations = result.get("citations", [])
164
  trace = result.get("trace", {})
165
  history = _append_exchange(history, user_bubble, result["reply"])
 
175
  # error fallback
176
  history = _append_exchange(history, user_bubble, result.get("reply", "ใ‚จใƒฉใƒผใŒ็™บ็”Ÿใ—ใพใ—ใŸใ€‚"))
177
  state["pending"] = False
178
+ state["triage_hold"] = False
179
  trace = result.get("trace", {})
180
  return history, state, "", gr.update(choices=[], value=None, visible=False), citations, trace
181
 
 
185
  return [], new_state, "", gr.update(choices=[], value=None, visible=False), [], {}
186
 
187
 
188
+ def continue_after_warning(
189
+ chat_history: List[dict[str, str]] | List[tuple[str, str]],
190
+ app_state: dict,
191
+ ):
192
+ state = app_state or init_state()
193
+ session_id = state.get("session_id") or init_state()["session_id"]
194
+ state["session_id"] = session_id
195
+
196
+ history = _ensure_chat_history(chat_history)
197
+
198
+ if engine is None:
199
+ reply = f"ๅˆๆœŸๅŒ–ใ‚จใƒฉใƒผ: {engine_error}. OPENAI_API_KEY ใ‚’็ขบ่ชใ—ใฆใใ ใ•ใ„ใ€‚"
200
+ history = _append_exchange(history, None, reply)
201
+ state["triage_hold"] = False
202
+ return history, state, "", gr.update(choices=[], value=None, visible=False), [], {}
203
+
204
+ result = engine.continue_after_warning(session_id)
205
+
206
+ if result.get("type") == "clarify":
207
+ state["pending"] = True
208
+ state["triage_hold"] = False
209
+ options = [f"{c.id}: {c.text}" for c in result["question"].choices]
210
+ trace = result.get("trace", {})
211
+ history = _append_exchange(history, None, result["reply"])
212
+ return (
213
+ history,
214
+ state,
215
+ "",
216
+ gr.update(choices=options, value=None, visible=True),
217
+ [],
218
+ trace,
219
+ )
220
+
221
+ if result.get("type") == "answer":
222
+ state["pending"] = False
223
+ state["triage_hold"] = False
224
+ citations = result.get("citations", [])
225
+ trace = result.get("trace", {})
226
+ history = _append_exchange(history, None, result["reply"])
227
+ return (
228
+ history,
229
+ state,
230
+ "",
231
+ gr.update(choices=[], value=None, visible=False),
232
+ citations,
233
+ trace,
234
+ )
235
+
236
+ # error or unexpected
237
+ history = _append_exchange(history, None, result.get("reply", "ใ‚จใƒฉใƒผใŒ็™บ็”Ÿใ—ใพใ—ใŸใ€‚"))
238
+ state["pending"] = False
239
+ state["triage_hold"] = False
240
+ trace = result.get("trace", {})
241
+ return history, state, "", gr.update(choices=[], value=None, visible=False), [], trace
242
+
243
+
244
  with gr.Blocks() as demo:
245
  gr.HTML(
246
  """
 
259
  )
260
  with gr.Row():
261
  with gr.Column(scale=3):
262
+ chatbot_kwargs = {"label": "ๅฏพ่ฉฑ", "height": 520}
263
+ if _CHATBOT_SUPPORTS_MESSAGES:
264
+ chatbot_kwargs["type"] = "messages"
265
+ chatbot = gr.Chatbot(**chatbot_kwargs)
266
  user_input = gr.Textbox(
267
  label="่ณชๅ•ใ‚’ๅ…ฅๅŠ›",
268
  placeholder="ไพ‹: ็†ฑใŒใ‚ใฃใฆ้ฃŸๆฌฒใŒใชใ„ใจใใ‚คใƒณใ‚นใƒชใƒณใฏใฉใ†ใ™ใ‚‹๏ผŸ",
 
271
  choice_radio = gr.Radio(label="Clarifying ้ธๆŠž่‚ข", choices=[], visible=False)
272
  with gr.Row():
273
  send_btn = gr.Button("้€ไฟก", variant="primary")
274
+ continue_btn = gr.Button("่ญฆๅ‘Šใ‚’ๆ‰ฟ็Ÿฅใ—ใฆ็ถšใ‘ใ‚‹", variant="secondary")
275
  reset_btn = gr.Button("ๆ–ฐใ—ใ„ใ‚ปใƒƒใ‚ทใƒงใƒณ", variant="secondary")
276
  with gr.Column(scale=2):
277
  gr.Markdown(
 
293
  inputs=[user_input, chatbot, state, choice_radio],
294
  outputs=[chatbot, state, user_input, choice_radio, sources, trace_view],
295
  )
296
+ continue_btn.click(
297
+ continue_after_warning,
298
+ inputs=[chatbot, state],
299
+ outputs=[chatbot, state, user_input, choice_radio, sources, trace_view],
300
+ )
301
  reset_btn.click(reset, outputs=[chatbot, state, user_input, choice_radio, sources, trace_view])
302
 
303
 
304
  if __name__ == "__main__":
305
+ launch_opts = {}
306
+ try:
307
+ launch_sig = inspect.signature(gr.Blocks.launch)
308
+ if "ssr_mode" in launch_sig.parameters:
309
+ launch_opts["ssr_mode"] = False # disable experimental SSR to avoid shutdown quirks
310
+ except Exception:
311
+ pass
312
+
313
+ demo.queue().launch(**launch_opts)
core/conversation.py CHANGED
@@ -82,7 +82,13 @@ class ConversationEngine:
82
  )
83
  return cites
84
 
85
- def _run_pipeline(self, session: SessionState, user_input: str, append_user: bool) -> Dict:
 
 
 
 
 
 
86
  trace: Dict[str, object] = {"state": dict(session.patient_state)}
87
  if append_user:
88
  session.messages.append({"role": "user", "content": user_input})
@@ -91,14 +97,17 @@ class ConversationEngine:
91
  # Step 3: emergency triage
92
  emergency, label = self.llm.triage_emergency(session.messages)
93
  trace["triage"] = label
94
- if emergency:
95
  reply = (
96
- "็ทŠๆ€ฅๆ€งใŒ็–‘ใ‚ใ‚Œใพใ™ใ€‚ๆ„่ญ˜ใŒใผใ‚“ใ‚„ใ‚Šใ™ใ‚‹ใ€ใ‘ใ„ใ‚Œใ‚“ใ€"
97
- "ๅผทใ„ๆฏๅˆ‡ใ‚Œใ‚„่ƒธ็—›ใŒใ‚ใ‚‹ๅ ดๅˆใฏใ€ใŸใ ใกใซๆ•‘ๆ€ฅ่ฆ่ซ‹ใ‚„ๅ—่จบใ‚’ใ—ใฆใใ ใ•ใ„ใ€‚"
98
- "ใ”ๆœฌไบบใŒๅฃใ‹ใ‚‰ๆ‘‚ๅ–ใงใใชใ„ๅ ดๅˆใ‚‚ๅŒๆง˜ใงใ™ใ€‚"
99
  )
100
  session.messages.append({"role": "assistant", "content": reply})
101
- return {"type": "answer", "reply": reply, "citations": [], "trace": trace}
 
 
 
102
 
103
  # Step 4: PageIndex search
104
  query_text = self._build_query(user_input, session.patient_state, None)
@@ -176,6 +185,14 @@ class ConversationEngine:
176
  session = self.session_manager.get(session_id)
177
  return self._run_pipeline(session, user_input, append_user=True)
178
 
 
 
 
 
 
 
 
 
179
  def handle_clarifying_answer(self, session_id: str, choice_id: str) -> Dict:
180
  session = self.session_manager.get(session_id)
181
  pending = session.pending_clarifying
 
82
  )
83
  return cites
84
 
85
+ def _run_pipeline(
86
+ self,
87
+ session: SessionState,
88
+ user_input: str,
89
+ append_user: bool,
90
+ force_continue: bool = False,
91
+ ) -> Dict:
92
  trace: Dict[str, object] = {"state": dict(session.patient_state)}
93
  if append_user:
94
  session.messages.append({"role": "user", "content": user_input})
 
97
  # Step 3: emergency triage
98
  emergency, label = self.llm.triage_emergency(session.messages)
99
  trace["triage"] = label
100
+ if emergency and not force_continue:
101
  reply = (
102
+ "็ทŠๆ€ฅๆ€งใŒ็–‘ใ‚ใ‚Œใพใ™ใ€‚ๆ„่ญ˜ใŒใผใ‚“ใ‚„ใ‚Šใ™ใ‚‹ใ€ใ‘ใ„ใ‚Œใ‚“ใ€ๅผทใ„ๆฏๅˆ‡ใ‚Œใ‚„่ƒธ็—›ใŒใ‚ใ‚‹ๅ ดๅˆใฏใ€"
103
+ "ใŸใ ใกใซๆ•‘ๆ€ฅ่ฆ่ซ‹ใ‚„ๅ—่จบใ‚’ใ—ใฆใใ ใ•ใ„ใ€‚ใ”ๆœฌไบบใŒๅฃใ‹ใ‚‰ๆ‘‚ๅ–ใงใใชใ„ๅ ดๅˆใ‚‚ๅŒๆง˜ใงใ™ใ€‚"
104
+ "\n\n่ญฆๅ‘Šใ‚’็ขบ่ชใ—ใ€็ถšใ‘ใฆ็ขบ่ชใ™ใ‚‹ๅ ดๅˆใฏ็”ป้ขไธ‹ใฎใ€Œ่ญฆๅ‘Šใ‚’ๆ‰ฟ็Ÿฅใ—ใฆ็ถšใ‘ใ‚‹ใ€ใ‚’ๆŠผใ—ใฆใใ ใ•ใ„ใ€‚"
105
  )
106
  session.messages.append({"role": "assistant", "content": reply})
107
+ session.pending_emergency = True
108
+ return {"type": "triage_warning", "reply": reply, "citations": [], "trace": trace}
109
+
110
+ session.pending_emergency = False
111
 
112
  # Step 4: PageIndex search
113
  query_text = self._build_query(user_input, session.patient_state, None)
 
185
  session = self.session_manager.get(session_id)
186
  return self._run_pipeline(session, user_input, append_user=True)
187
 
188
+ def continue_after_warning(self, session_id: str) -> Dict:
189
+ session = self.session_manager.get(session_id)
190
+ if not session.pending_emergency:
191
+ return {"type": "error", "reply": "็พๅœจ่กจ็คบไธญใฎ็ทŠๆ€ฅๆ€ง่ญฆๅ‘Šใฏใ‚ใ‚Šใพใ›ใ‚“ใ€‚"}
192
+ if not session.last_user_query:
193
+ return {"type": "error", "reply": "ๅ‰ใฎ่ณชๅ•ใŒ่ฆ‹ใคใ‹ใ‚Šใพใ›ใ‚“ใ€‚ๅ†ๅบฆ่ณชๅ•ใ—ใฆใใ ใ•ใ„ใ€‚"}
194
+ return self._run_pipeline(session, session.last_user_query, append_user=False, force_continue=True)
195
+
196
  def handle_clarifying_answer(self, session_id: str, choice_id: str) -> Dict:
197
  session = self.session_manager.get(session_id)
198
  pending = session.pending_clarifying
core/state.py CHANGED
@@ -14,6 +14,7 @@ class SessionState:
14
  patient_state: Dict[str, object] = field(default_factory=dict)
15
  last_user_query: str | None = None
16
  pending_clarifying: dict | None = None # {"question_id": str, "choice_map": {...}}
 
17
 
18
 
19
  class SessionManager:
 
14
  patient_state: Dict[str, object] = field(default_factory=dict)
15
  last_user_query: str | None = None
16
  pending_clarifying: dict | None = None # {"question_id": str, "choice_map": {...}}
17
+ pending_emergency: bool = False # triage warning state
18
 
19
 
20
  class SessionManager: