dsalwala commited on
Commit
d56db35
·
verified ·
1 Parent(s): 3707bf8

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +604 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,606 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
 
 
 
 
 
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import asyncio
2
+ import json
3
+ import os
4
+ import time
5
+ from datetime import datetime
6
+ from pathlib import Path
7
+ from typing import Dict, List
8
+
9
  import streamlit as st
10
+ from acp_sdk.client import Client
11
+ from acp_sdk.models import Message, MessagePart
12
+ from rich.console import Console
13
+
14
+ from gaf_guard.clients.stream_adaptors import get_adapter
15
+ from gaf_guard.core.models import WorkflowMessage
16
+ from gaf_guard.toolkit.enums import MessageType, Role, StreamStatus, UserInputType
17
+ from gaf_guard.toolkit.file_utils import resolve_file_paths
18
+
19
+
20
+ GAF_GUARD_ROOT = Path(__file__).parent.parent.absolute()
21
+
22
+ # Apply CSS to hide chat_input when app is running (processing)
23
+ st.markdown(
24
+ """
25
+ <style>
26
+ .header {
27
+ padding: 1rem;
28
+ background: linear-gradient(90deg, #667eea 0%, #764ba2 100%);
29
+ color: white;
30
+ text-align: center;
31
+ border-radius: 10px;
32
+ margin-bottom: 1rem;
33
+ }
34
+ .message-card {
35
+ padding: 1rem;
36
+ border-left: 4px solid #667eea;
37
+ background-color: #f8f9fa;
38
+ border-radius: 5px;
39
+ margin: 0.5rem 0;
40
+ }
41
+ .stApp[data-teststate=running] .stChatInput textarea,
42
+ .stApp[data-test-script-state=running] .stChatInput textarea {
43
+ display: none !important;
44
+ }
45
+ .stTextInput {{
46
+ position: fixed;
47
+ bottom: 3rem;
48
+ }}
49
+ .block-container {
50
+ padding-top: 1rem;
51
+ padding-bottom: 0rem;
52
+ padding-left: 5rem;
53
+ padding-right: 5rem;
54
+ }
55
+ </style>
56
+ """,
57
+ unsafe_allow_html=True,
58
+ )
59
+
60
+ # Declare global session variables
61
+ st.session_state.priority = ["low", "medium", "high"]
62
+ st.session_state.initial_risks_master = ["Toxic output", "Hallucination"]
63
+ st.set_page_config(
64
+ page_title="GAF Guard - A real-time monitoring system for risk assessment and drift monitoring.",
65
+ layout="wide", # This sets the app to wide mode
66
+ # initial_sidebar_state="expanded",
67
+ )
68
+ console = Console(log_time=True)
69
+ run_configs = {
70
+ "RiskGeneratorAgent": {
71
+ "risk_questionnaire_cot": os.path.join(
72
+ GAF_GUARD_ROOT, "chain_of_thought", "risk_questionnaire.json"
73
+ )
74
+ },
75
+ "DriftMonitoringAgent": {
76
+ "drift_monitoring_cot": os.path.join(
77
+ GAF_GUARD_ROOT, "chain_of_thought", "drift_monitoring.json"
78
+ ),
79
+ "drift_threshold": (
80
+ st.session_state.drift_threshold
81
+ if "drift_threshold" in st.session_state
82
+ else 8
83
+ ),
84
+ },
85
+ }
86
+ resolve_file_paths(run_configs)
87
+
88
+
89
+ def file_uploaded():
90
+ st.session_state.prompt_file = st.session_state.prompt_file_uploader.getvalue()
91
+ message = WorkflowMessage(
92
+ name="GAF Guard Client",
93
+ type=MessageType.GAF_GUARD_QUERY,
94
+ role=Role.SYSTEM,
95
+ content=f"**File uploaded successfully:** {st.session_state.prompt_file_uploader.name}",
96
+ accept=UserInputType.INPUT_PROMPT,
97
+ run_configs=run_configs,
98
+ )
99
+ st.session_state.messages.append(message)
100
+ render(message, simulate=True)
101
+
102
+
103
+ def play_button(adapter_type):
104
+ if st.session_state.setdefault(
105
+ "stream_adaptor",
106
+ get_adapter(
107
+ adapter_type,
108
+ config={"byte_data": st.session_state.prompt_file},
109
+ ),
110
+ ):
111
+ st.session_state.stream_status = StreamStatus.ACTIVE
112
+ else:
113
+ st.write("Selected adaptor is not available.")
114
+
115
+
116
+ def pause_button():
117
+ st.session_state.stream_status = StreamStatus.PAUSED
118
+ st.session_state.messages.append(
119
+ WorkflowMessage(
120
+ name="GAF Guard Client",
121
+ type=MessageType.GAF_GUARD_QUERY,
122
+ role=Role.SYSTEM,
123
+ content=f"**:red[Alert:]** Current input stream is paused. Please click on **Start** to resume.",
124
+ accept=UserInputType.INPUT_PROMPT,
125
+ run_configs=run_configs,
126
+ )
127
+ )
128
+
129
+
130
+ @st.fragment
131
+ def pause_fragment(adapter_type):
132
+ st.button(
133
+ "⏸️ Pause",
134
+ use_container_width=True,
135
+ disabled=(
136
+ adapter_type == "Select"
137
+ or st.session_state.stream_status
138
+ in [StreamStatus.STOPPED, StreamStatus.PAUSED]
139
+ ),
140
+ on_click=pause_button,
141
+ )
142
+
143
+
144
+ def add_sidebar():
145
+
146
+ with st.sidebar:
147
+ st.sidebar.title("⚙️ Settings")
148
+ if st.session_state.sidebar_display in ["settings_view", "input_prompt_source"]:
149
+ st.subheader(f":blue[Taxonomy:] {st.session_state.taxonomy}")
150
+ st.subheader(f":blue[Drift Threshold:] {st.session_state.drift_threshold}")
151
+ if st.session_state.sidebar_display == "settings_edit":
152
+ st.session_state.taxonomy = st.selectbox(
153
+ "Risk Taxonomy",
154
+ ("IBM Risk Atlas"),
155
+ )
156
+ st.session_state.drift_threshold = st.slider(
157
+ "Drift Threshold",
158
+ value=st.session_state.drift_threshold,
159
+ min_value=2,
160
+ max_value=10,
161
+ step=1,
162
+ )
163
+ if st.session_state.sidebar_display == "input_prompt_source":
164
+ st.sidebar.title("⚙️ Streaming Source")
165
+ adapter_type = st.selectbox(
166
+ "Select Input Prompt Source",
167
+ ["Select", "JSON"],
168
+ help="Choose your streaming source",
169
+ index=0,
170
+ disabled="stream_adaptor" in st.session_state,
171
+ )
172
+ if adapter_type == "JSON":
173
+ st.subheader("JSON File Source")
174
+ st.file_uploader(
175
+ "OK",
176
+ accept_multiple_files=False,
177
+ type="json",
178
+ label_visibility="collapsed",
179
+ on_change=file_uploaded,
180
+ key="prompt_file_uploader",
181
+ disabled="stream_adaptor" in st.session_state,
182
+ )
183
+
184
+ # Control buttons
185
+ col1, col2 = st.columns(2)
186
+ with col1:
187
+ st.button(
188
+ "▶️ Start",
189
+ use_container_width=True,
190
+ disabled=(
191
+ adapter_type == "Select"
192
+ or "prompt_file" not in st.session_state
193
+ or st.session_state.stream_status == StreamStatus.ACTIVE
194
+ ),
195
+ on_click=play_button,
196
+ args=(adapter_type,),
197
+ )
198
+ with col2:
199
+ pause_fragment(adapter_type)
200
+ st.markdown(
201
+ "**Note:** Pause button will temporarily halt the stream after processing the current prompt."
202
+ )
203
+
204
+ st.divider()
205
+ ai_atlas_button = st.container(
206
+ horizontal_alignment="center", vertical_alignment="bottom", height="stretch"
207
+ )
208
+ ai_atlas_button.markdown(":blue[Powered by:]", text_alignment="center")
209
+ ai_atlas_button.link_button(
210
+ "AI Atlas Nexus",
211
+ "https://github.com/IBM/ai-atlas-nexus",
212
+ icon=":material/thumb_up:",
213
+ type="secondary",
214
+ )
215
+ if hasattr(st.session_state, "client_session"):
216
+ ai_atlas_button.markdown(
217
+ f"Client Id: {str(st.session_state.client_session._session.id)[0:13]} \n :violet-badge[:material/rocket_launch: Connected to :yellow[GAF Guard] Server:] :orange-badge[:material/check: {st.session_state.host}:{st.session_state.port}]",
218
+ text_alignment="center",
219
+ )
220
+ else:
221
+ ai_atlas_button.markdown(
222
+ f":red-badge[:material/mimo_disconnect: Client Disconnected]",
223
+ text_alignment="center",
224
+ )
225
+
226
+
227
+ # render agent reponse from the server
228
+ def render(message: WorkflowMessage, simulate=False):
229
+
230
+ def simulate_agent_response(
231
+ role: Role,
232
+ message: str,
233
+ json_data: Dict = None,
234
+ simulate: bool = False,
235
+ accept: Dict = None,
236
+ ):
237
+ with st.chat_message(role):
238
+ if simulate:
239
+ message_placeholder = st.empty()
240
+ full_response = ""
241
+ for chunk in message.split():
242
+ full_response += chunk + " "
243
+ time.sleep(0.05)
244
+ message_placeholder.markdown(full_response + "▌")
245
+ message_placeholder.markdown(full_response)
246
+ else:
247
+ st.markdown(message)
248
+
249
+ if json_data:
250
+ st.json(json_data, expanded=4)
251
+ elif accept == UserInputType.INITIAL_RISKS:
252
+ st.button(
253
+ "Add Initial Risks",
254
+ on_click=initial_risks_selector,
255
+ disabled=hasattr(st.session_state, "initial_risks"),
256
+ )
257
+ st.session_state.disabled_input = False
258
+ elif accept == UserInputType.INPUT_PROMPT:
259
+ st.session_state.sidebar_display = "input_prompt_source"
260
+ st.session_state.disabled_input = True
261
+
262
+ if not message.display:
263
+ return False
264
+ if message.type == MessageType.GAF_GUARD_WF_STARTED:
265
+ return False
266
+ if message.type == MessageType.GAF_GUARD_WF_COMPLETED:
267
+ return False
268
+ elif message.type == MessageType.GAF_GUARD_STEP_STARTED:
269
+ simulate_agent_response(
270
+ role=message.role.value,
271
+ message=f"##### :blue[Workflow Step:] **{message.name}**",
272
+ simulate=simulate,
273
+ accept=message.accept,
274
+ )
275
+ elif message.type == MessageType.GAF_GUARD_STEP_COMPLETED:
276
+ # simulate_agent_response(
277
+ # role=message.role.value,
278
+ # message=f"##### :blue[Workflow Step:] **{message.name}** COMPLETED",
279
+ # simulate=simulate,
280
+ # accept=message.accept,
281
+ # )
282
+ return False
283
+ elif message.type == MessageType.GAF_GUARD_STEP_DATA:
284
+ if isinstance(message.content, dict):
285
+ if message.name == "Input Prompt":
286
+ simulate_agent_response(
287
+ role=message.role.value,
288
+ message=f"###### :yellow[**Prompt {message.content["prompt_index"]}**]: {message.content["prompt"]}",
289
+ simulate=simulate,
290
+ accept=message.accept,
291
+ )
292
+ else:
293
+ if len(message.content.items()) > 2:
294
+ data = []
295
+ for key, value in message.content.items():
296
+ data.append({key.title(): value})
297
+
298
+ simulate_agent_response(
299
+ role=message.role.value,
300
+ message="###### :yellow[Risk Report]",
301
+ json_data=data,
302
+ simulate=simulate,
303
+ accept=message.accept,
304
+ )
305
+ else:
306
+ for key, value in message.content.items():
307
+ if key == "identified_risks":
308
+ st.session_state.risks = value
309
+ if isinstance(value, List) or isinstance(value, Dict):
310
+ simulate_agent_response(
311
+ role=message.role.value,
312
+ message=f"###### :yellow[{key.replace('_', ' ').title()}]",
313
+ json_data=value,
314
+ simulate=simulate,
315
+ accept=message.accept,
316
+ )
317
+ elif isinstance(value, str) and key.endswith("alert"):
318
+ simulate_agent_response(
319
+ role=message.role.value,
320
+ message=f"###### :yellow[{key.replace('_', ' ').title()}]: :red[{value}]",
321
+ simulate=simulate,
322
+ accept=message.accept,
323
+ )
324
+ else:
325
+ simulate_agent_response(
326
+ role=message.role.value,
327
+ message=f"###### :yellow[{key.replace('_', ' ').title()}]: {value}",
328
+ simulate=simulate,
329
+ accept=message.accept,
330
+ )
331
+ elif message.type == MessageType.GAF_GUARD_QUERY:
332
+ simulate_agent_response(
333
+ role=message.role.value,
334
+ message=f":blue[{message.content}]",
335
+ simulate=simulate,
336
+ accept=message.accept,
337
+ )
338
+ else:
339
+ # raise Exception(f"Invalid message type: {message.type}")
340
+ if message.content:
341
+ simulate_agent_response(
342
+ role=message.role.value,
343
+ message=message.content,
344
+ simulate=simulate,
345
+ accept=message.accept,
346
+ )
347
+
348
+ return True
349
+
350
+
351
+ @st.dialog("Initial risks", width="medium")
352
+ def initial_risks_selector():
353
+
354
+ def add_row():
355
+ st.session_state.setdefault("initial_risks", {}).update(
356
+ {
357
+ str(len(st.session_state.initial_risks)): {
358
+ "risk": st.session_state.initial_risks_master[0],
359
+ "priority": "low",
360
+ "threshold": 0.01,
361
+ }
362
+ }
363
+ )
364
+
365
+ if "initial_risks" not in st.session_state:
366
+ add_row()
367
+
368
+ st.button("Add New Row", type="primary", on_click=add_row)
369
+ with st.form("input_form"):
370
+
371
+ # Create columns for the form inputs
372
+ col1, col2, col3 = st.columns(3)
373
+
374
+ for key, initial_risk in st.session_state.initial_risks.items():
375
+ with col1:
376
+ value = st.selectbox(
377
+ "Risk" if key == "0" else " ",
378
+ tuple(st.session_state.initial_risks_master),
379
+ key=f"col1{key}",
380
+ index=st.session_state.initial_risks_master.index(
381
+ initial_risk["risk"]
382
+ ),
383
+ )
384
+ st.session_state.initial_risks[key].update({"risk": value})
385
+ with col2:
386
+ value = st.selectbox(
387
+ "Priority" if key == "0" else " ",
388
+ tuple(st.session_state.priority),
389
+ key=f"col2{key}",
390
+ index=st.session_state.priority.index(initial_risk["priority"]),
391
+ )
392
+ st.session_state.initial_risks[key].update({"priority": value})
393
+ with col3:
394
+ threshold = st.number_input(
395
+ "Threshold" if key == "0" else " ",
396
+ key=f"col3{key}",
397
+ value=initial_risk["threshold"],
398
+ )
399
+ st.session_state.initial_risks[key].update({"threshold": threshold})
400
+
401
+ submitted = st.form_submit_button("Submit")
402
+
403
+ if submitted:
404
+ st.session_state.user_input = json.dumps(
405
+ list(st.session_state.initial_risks.values())
406
+ )
407
+ st.rerun()
408
+
409
+
410
+ @st.dialog(
411
+ "GAF Guard Connect",
412
+ width="medium",
413
+ dismissible=False,
414
+ icon=":material/login:",
415
+ )
416
+ def connect_screen_dialog():
417
+ if hasattr(st.session_state, "error"):
418
+ st.error(st.session_state.error, icon="🚨")
419
+ with st.form("login_form"):
420
+ input_host = st.text_input("GAF Guard Host", value="localhost")
421
+ input_port = st.number_input("GAF Guard Port", value=8000)
422
+ submitted = st.form_submit_button("Connect", type="primary")
423
+
424
+ if submitted:
425
+ if hasattr(st.session_state, "error"):
426
+ del st.session_state["error"]
427
+ st.session_state.host = input_host
428
+ st.session_state.port = input_port
429
+ st.rerun()
430
+
431
+
432
+ @st.dialog(
433
+ "GAF Guard Connect",
434
+ width="medium",
435
+ dismissible=False,
436
+ icon=":material/login:",
437
+ )
438
+ def connect():
439
+
440
+ async def ping_server(client):
441
+ await client.ping()
442
+
443
+ with st.status(
444
+ f"Connecting to GAF Guard using host: :blue[**{st.session_state.host}**] and port: :blue[**{st.session_state.port}**]",
445
+ expanded=True,
446
+ ) as status:
447
+ try:
448
+ client = Client(
449
+ base_url=f"http://{st.session_state.host}:{st.session_state.port}",
450
+ verify=True,
451
+ )
452
+ # asyncio.run(ping_server(client))
453
+ st.write("Client created...")
454
+ except Exception as e:
455
+ st.session_state.error = "Failed to connect. Check hostname and port."
456
+ st.rerun()
457
+
458
+ st.session_state.client_session = client.session()
459
+ st.write("Client session created...")
460
+
461
+ st.session_state.drift_threshold = 8
462
+ st.session_state.disabled_input = False
463
+ st.session_state.stream_status = StreamStatus.STOPPED
464
+ st.session_state.sidebar_display = "settings_edit"
465
+ st.session_state.messages = [
466
+ WorkflowMessage(
467
+ name="GAF Guard Client",
468
+ type=MessageType.CLIENT_INPUT,
469
+ role=Role.USER,
470
+ accept=UserInputType.USER_INTENT,
471
+ run_configs=run_configs,
472
+ )
473
+ ]
474
+ st.write("Client initialisation done...")
475
+
476
+ # print information in the client console window
477
+ console.print(
478
+ f"[[bold white]{datetime.now().strftime('%d-%m-%Y %H:%M:%S')}[/]] [italic bold white] :rocket: Connected to GAF Guard Server at[/italic bold white] [bold white]{st.session_state.host}:{st.session_state.port}[/bold white]"
479
+ )
480
+ console.print(
481
+ f"[[bold white]{datetime.now().strftime('%d-%m-%Y %H:%M:%S')}[/]] Client Id: {st.session_state.client_session._session.id}"
482
+ )
483
+ # console.print(
484
+ # f"""
485
+ # You can now view your Streamlit app in your browser.
486
+
487
+ # Local URL: http://{st.session_state.host}:{st.session_state.port}
488
+ # """
489
+ # )
490
+
491
+ status.update(
492
+ label=f":material/rocket_launch: Connected to :yellow[**GAF Guard**] Server: :orange-badge[:material/check: {st.session_state.host}:{st.session_state.port}]",
493
+ state="complete",
494
+ expanded=True,
495
+ )
496
+ time.sleep(1)
497
+
498
+ st.rerun()
499
+
500
+
501
+ def submit_input():
502
+ st.session_state.sidebar_display = "settings_view"
503
+
504
+
505
+ async def app():
506
+
507
+ st.title(f":yellow[GAF Guard]", text_alignment="center")
508
+ st.subheader(
509
+ "A real-time monitoring system for risk assessment and drift monitoring",
510
+ text_alignment="center",
511
+ divider=True,
512
+ )
513
+
514
+ # add sidebar and related components
515
+ add_sidebar()
516
+
517
+ # Display chat messages from history
518
+ for message in st.session_state.messages:
519
+ render(message)
520
+
521
+ last_message: WorkflowMessage = st.session_state.messages[-1]
522
+
523
+ if st.session_state.stream_status == StreamStatus.ACTIVE:
524
+ user_input = st.session_state.stream_adaptor.next()
525
+ if not user_input:
526
+ del st.session_state["stream_adaptor"]
527
+ st.session_state.stream_status = StreamStatus.STOPPED
528
+ st.session_state.messages.append(
529
+ WorkflowMessage(
530
+ name="GAF Guard Client",
531
+ type=MessageType.GAF_GUARD_QUERY,
532
+ role=Role.SYSTEM,
533
+ content=f"**The streaming input has ended. Please choose a streaming source and start again.**",
534
+ accept=UserInputType.INPUT_PROMPT,
535
+ run_configs=run_configs,
536
+ )
537
+ )
538
+ st.rerun()
539
+ else:
540
+ # Accept user input
541
+ user_input = st.chat_input(
542
+ placeholder="Enter your response here",
543
+ key="user_input",
544
+ disabled=st.session_state.disabled_input,
545
+ on_submit=submit_input,
546
+ )
547
+
548
+ if not user_input:
549
+ st.stop()
550
+ else:
551
+ COMPLETED = False
552
+ async for event in st.session_state.client_session.run_stream(
553
+ agent="orchestrator",
554
+ input=[
555
+ Message(
556
+ parts=[
557
+ MessagePart(
558
+ content=WorkflowMessage(
559
+ name="GAF Guard Client",
560
+ type=(
561
+ MessageType.CLIENT_RESPONSE
562
+ if last_message.type == MessageType.GAF_GUARD_QUERY
563
+ else MessageType.CLIENT_INPUT
564
+ ),
565
+ role=Role.USER,
566
+ content={last_message.accept: user_input},
567
+ run_configs=run_configs,
568
+ ).model_dump_json(),
569
+ content_type="text/plain",
570
+ )
571
+ ]
572
+ )
573
+ ],
574
+ ):
575
+ if event.type == "message.part":
576
+ message = WorkflowMessage(**json.loads(event.part.content))
577
+ if render(message, simulate=True):
578
+ st.session_state.messages.append(message)
579
+ elif event.type == "run.awaiting":
580
+ if hasattr(event, "run"):
581
+ message = WorkflowMessage(
582
+ **json.loads(event.run.await_request.message.parts[0].content)
583
+ )
584
+ if message.accept == UserInputType.INPUT_PROMPT:
585
+ if st.session_state.stream_status == StreamStatus.STOPPED:
586
+ render(message, simulate=True)
587
+ else:
588
+ message.display = False
589
+ else:
590
+ render(message, simulate=True)
591
+
592
+ st.session_state.messages.append(message)
593
+ st.session_state.disabled_input = True
594
+ st.rerun()
595
+
596
 
597
+ if hasattr(st.session_state, "client_session"):
598
+ asyncio.run(app())
599
+ elif (
600
+ not hasattr(st.session_state, "error")
601
+ and hasattr(st.session_state, "host")
602
+ and hasattr(st.session_state, "port")
603
+ ):
604
+ connect()
605
+ else:
606
+ connect_screen_dialog()