Spaces:
Running
Running
| """ | |
| Custom Gradio UI for the AWM environment. | |
| """ | |
| from __future__ import annotations | |
| import asyncio | |
| import inspect | |
| import json | |
| import os | |
| from typing import Any | |
| import gradio as gr | |
| from openenv.core.env_server.serialization import serialize_observation | |
| from .data_loader import AWMDataLoader | |
| from .prompts import DEFAULT_SYSTEM_PROMPT | |
| from .web_agent import AwmAgent | |
| from .config import DEFAULT_REWARD_CONFIG | |
# Keep in sync with DEFAULT_REWARD_CONFIG in config.py.
_DEFAULT_REWARD_JSON = json.dumps(DEFAULT_REWARD_CONFIG, indent=2)
| def _format_obs_md(payload: dict | None) -> str: | |
| if not payload: | |
| return "*No observation yet.*" | |
| obs = payload.get("observation") if isinstance(payload, dict) else None | |
| if obs is None: | |
| obs = payload | |
| reward = payload.get("reward") if isinstance(payload, dict) else None | |
| done = payload.get("done") if isinstance(payload, dict) else None | |
| lines: list[str] = [] | |
| if reward is not None: | |
| lines.append(f"**reward**: `{reward}`") | |
| if done is not None: | |
| lines.append(f"**done**: `{done}`") | |
| if isinstance(obs, dict): | |
| for key in ( | |
| "reward_type", | |
| "scenario", | |
| "task", | |
| "task_idx", | |
| "num_tools", | |
| "tool_name", | |
| "error", | |
| "warning", | |
| ): | |
| v = obs.get(key) | |
| if v is not None and v != "": | |
| lines.append(f"**{key}**: `{v}`") | |
| if obs.get("tool_result") is not None: | |
| tr = obs["tool_result"] | |
| tr_text = ( | |
| tr if isinstance(tr, str) else json.dumps(tr, indent=2, default=str) | |
| ) | |
| if len(tr_text) > 2000: | |
| tr_text = tr_text[:2000] + "\n... (truncated)" | |
| lines.append("\n**tool_result:**") | |
| lines.append(f"```\n{tr_text}\n```") | |
| if obs.get("verify_result"): | |
| vr = obs["verify_result"] | |
| vr_text = json.dumps(vr, indent=2, default=str) | |
| if len(vr_text) > 2000: | |
| vr_text = vr_text[:2000] + "\n... (truncated)" | |
| lines.append("\n**verify_result:**") | |
| lines.append(f"```json\n{vr_text}\n```") | |
| if obs.get("trajectory_path"): | |
| lines.append(f"\n**trajectory_path**: `{obs['trajectory_path']}`") | |
| return "\n\n".join(lines) if lines else "*Empty observation.*" | |
| def _make_args_template(input_schema: dict | None) -> str: | |
| if not input_schema or not isinstance(input_schema, dict): | |
| return "{}" | |
| props = input_schema.get("properties") or {} | |
| template: dict[str, Any] = {} | |
| for name, info in props.items(): | |
| ty = (info or {}).get("type", "string") | |
| template[name] = { | |
| "string": "", | |
| "integer": 0, | |
| "number": 0.0, | |
| "boolean": False, | |
| "array": [], | |
| "object": {}, | |
| }.get(ty, None) | |
| return json.dumps(template, indent=2) | |
def build_awm_gradio_app(
    web_manager: Any,
    action_fields: list[dict] | None = None,
    metadata: Any = None,
    is_chat_env: bool = False,
    title: str = "AWM Environment",
    quick_start_md: str | None = None,
) -> gr.Blocks:
    """Build the custom Gradio console for the AWM environment.

    Args:
        web_manager: openenv web-interface manager wrapping the environment.
            Must expose ``reset_environment``, ``step_environment``, ``env``
            and ``episode_state``.
        action_fields: Accepted for interface compatibility; unused in this
            body — TODO confirm against the generic UI builder it mirrors.
        metadata: Optional object; when it carries ``readme_content``, a
            README tab is rendered from that Markdown.
        is_chat_env: Accepted for interface compatibility; unused here.
        title: Suffix for the browser tab title.
        quick_start_md: Accepted for interface compatibility; unused here.

    Returns:
        A ``gr.Blocks`` app with Setup controls plus Human, Agent and
        Trajectory tabs (and an optional README tab).
    """
    data_loader = AWMDataLoader(cache_dir=os.environ.get("AWM_DATA_DIR"))
    readme_md = ""
    if metadata is not None and getattr(metadata, "readme_content", None):
        readme_md = metadata.readme_content
    # openenv-core 0.2.3 added a ``reset_kwargs`` parameter to
    # ``WebInterfaceManager.reset_environment``. PyPI's 0.2.1 takes no args,
    # which silently drops scenario/task_idx — fall back to calling env.reset
    # directly and replicate the episode-state updates the manager would do.
    _reset_env_supports_kwargs = (
        len(
            [
                p
                for p in inspect.signature(
                    web_manager.reset_environment
                ).parameters.values()
                if p.name != "self"
            ]
        )
        > 0
    )

    async def _safe_reset(reset_kwargs: dict[str, Any]) -> dict[str, Any]:
        """Reset the env, tolerating the older no-kwargs manager signature.

        Returns the serialized observation envelope in both code paths.
        """
        if _reset_env_supports_kwargs:
            return await web_manager.reset_environment(reset_kwargs)
        # Legacy path: call env.reset directly, forwarding only the kwargs
        # its signature accepts (or all of them if it takes **kwargs).
        env = web_manager.env
        params = inspect.signature(env.reset).parameters
        has_var_kw = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
        )
        valid = {k: v for k, v in reset_kwargs.items() if has_var_kw or k in params}
        loop = asyncio.get_event_loop()
        # env.reset is presumably blocking — run it off the event loop thread.
        observation = await loop.run_in_executor(None, lambda: env.reset(**valid))
        serialized = serialize_observation(observation)
        # Replicate the episode-state bookkeeping the manager would have done.
        es = web_manager.episode_state
        es.episode_id = env.state.episode_id
        es.step_count = 0
        es.current_observation = serialized["observation"]
        es.action_logs = []
        es.is_reset = True
        try:
            # Best-effort UI push; private manager API, may be absent.
            await web_manager._send_state_update()
        except Exception:
            pass
        return serialized

    async def _do_reset(
        scenario: str,
        task_idx: int,
        llm_base_url: str,
        llm_api_key: str,
        llm_model: str,
        reward_config_json: str,
    ):
        """Reset-button handler.

        Returns a 6-tuple matching the wired outputs: (status markdown,
        task markdown, observation markdown, raw JSON, tool-dropdown update,
        serialized tool-lookup for ``tools_state``).
        """
        if not scenario:
            return (
                "❌ Pick a scenario first.",
                "",
                "*Reset failed.*",
                "{}",
                gr.update(choices=[], value=None),
                "{}",
            )
        reset_kwargs: dict[str, Any] = {
            "scenario": scenario,
            "task_idx": int(task_idx),
        }
        # LLM settings are optional — forward only the fields that were filled.
        if llm_base_url:
            reset_kwargs["llm_base_url"] = llm_base_url
        if llm_api_key:
            reset_kwargs["llm_api_key"] = llm_api_key
        if llm_model:
            reset_kwargs["llm_model"] = llm_model
        if reward_config_json and reward_config_json.strip():
            try:
                reset_kwargs["reward_config"] = json.loads(reward_config_json)
            except json.JSONDecodeError as e:
                return (
                    f"❌ Invalid reward_config JSON: {e}",
                    "",
                    "*Reset failed.*",
                    "{}",
                    gr.update(choices=[], value=None),
                    "{}",
                )
        try:
            result = await _safe_reset(reset_kwargs)
        except Exception as e:
            return (
                f"❌ Reset error: {e}",
                "",
                f"*Reset failed: {e}*",
                "{}",
                gr.update(choices=[], value=None),
                "{}",
            )
        obs = result.get("observation", {}) or {}
        rt = obs.get("reward_type")
        # The env signals reset success via reward_type on the observation.
        ok = rt in ("reset_ok", "reset_warning")
        status = "✅ Reset OK." if ok else f"❌ Reset failed: {obs.get('error') or rt}"
        task_md = f"**Task** (`{scenario}`, idx={task_idx}):\n\n{obs.get('task') or '*(no task description)*'}"
        tool_names: list[str] = []
        tool_lookup: dict[str, dict] = {}
        if ok:
            # After a successful reset, pre-populate the tool dropdown;
            # a failure here only degrades the UI, so it is non-fatal.
            try:
                tools_result = await web_manager.step_environment(
                    {"type": "list_tools"}
                )
                tools = (tools_result.get("observation", {}) or {}).get("tools") or []
                for t in tools:
                    if isinstance(t, dict):
                        n = t.get("name", "")
                        tool_names.append(n)
                        tool_lookup[n] = t
                    else:
                        # Tool objects — normalize to plain dicts so the
                        # lookup survives JSON round-tripping via gr.State.
                        n = getattr(t, "name", "")
                        tool_names.append(n)
                        tool_lookup[n] = {
                            "name": n,
                            "description": getattr(t, "description", ""),
                            "input_schema": getattr(t, "input_schema", {}),
                        }
            except Exception as e:
                status += f" (list_tools warning: {e})"
        tool_choice = gr.update(
            choices=tool_names,
            value=(tool_names[0] if tool_names else None),
        )
        return (
            status,
            task_md,
            _format_obs_md(result),
            json.dumps(result, indent=2, default=str),
            tool_choice,
            json.dumps(tool_lookup),
        )

    async def _refresh_scenarios():
        """Populate the scenario dropdown from the data loader."""
        try:
            scens = data_loader.list_scenarios()
            names = sorted(s["name"] for s in scens)
            return gr.update(
                choices=names,
                value=(names[0] if names else None),
            ), f"Loaded {len(names)} scenarios."
        except Exception as e:
            return gr.update(choices=[]), f"❌ Failed to load scenarios: {e}"

    async def _on_tool_change(tool_name: str, tool_lookup_json: str):
        """On tool selection: return (args JSON template, tool description)."""
        try:
            lookup = json.loads(tool_lookup_json or "{}")
        except json.JSONDecodeError:
            lookup = {}
        if not tool_name or tool_name not in lookup:
            return "{}", ""
        meta = lookup[tool_name]
        # Accept both snake_case and camelCase schema keys.
        schema = meta.get("input_schema") or meta.get("inputSchema") or {}
        return _make_args_template(schema), meta.get("description", "")

    async def _do_call_tool(tool_name: str, args_json: str):
        """Invoke one tool with user-supplied JSON args; returns (md, json)."""
        if not tool_name:
            return "*Pick a tool first.*", "{}"
        try:
            args = json.loads(args_json) if args_json.strip() else {}
        except json.JSONDecodeError as e:
            return f"❌ Invalid args JSON: {e}", "{}"
        try:
            result = await web_manager.step_environment(
                {"type": "call_tool", "tool_name": tool_name, "arguments": args}
            )
        except Exception as e:
            return f"❌ step error: {e}", "{}"
        return _format_obs_md(result), json.dumps(result, indent=2, default=str)

    async def _do_list_tools():
        """Step the env with a list_tools action; returns (md, json)."""
        try:
            result = await web_manager.step_environment({"type": "list_tools"})
        except Exception as e:
            return f"❌ {e}", "{}"
        return _format_obs_md(result), json.dumps(result, indent=2, default=str)

    async def _do_verify(verifier_mode: Any, final_answer: str):
        """Run the env's ``verify`` tool; returns (md, json).

        ``verifier_mode`` is sanitized to "code"/"sql" because the input is a
        free-form textbox (and some gradio versions deliver dict payloads).
        """
        if isinstance(verifier_mode, dict):
            verifier_mode = next(iter(verifier_mode.keys()), "code")
        verifier_mode = str(verifier_mode).strip().lower() or "code"
        if verifier_mode not in ("code", "sql"):
            verifier_mode = "code"
        args: dict[str, Any] = {"verifier_mode": verifier_mode}
        if final_answer:
            args["final_answer"] = final_answer
        try:
            result = await web_manager.step_environment(
                {"type": "call_tool", "tool_name": "verify", "arguments": args}
            )
        except Exception as e:
            return f"❌ verify error: {e}", "{}"
        return _format_obs_md(result), json.dumps(result, indent=2, default=str)

    async def _do_done(keep_session: bool):
        """Run the ``done`` tool; returns (md, json, trajectory file path).

        The third element feeds the download widget — only a path that
        exists on local disk is returned, otherwise ``None``.
        """
        try:
            result = await web_manager.step_environment(
                {
                    "type": "call_tool",
                    "tool_name": "done",
                    "arguments": {"keep_session": bool(keep_session)},
                }
            )
        except Exception as e:
            return f"❌ done error: {e}", "{}", None
        obs = result.get("observation", {}) or {}
        traj_path = obs.get("trajectory_path")
        return (
            _format_obs_md(result),
            json.dumps(result, indent=2, default=str),
            traj_path if traj_path and os.path.exists(traj_path) else None,
        )

    async def _do_list_scenarios_via_tool():
        """List scenarios via the env's pseudo-tool; returns (md, json)."""
        try:
            result = await web_manager.step_environment(
                {
                    "type": "call_tool",
                    "tool_name": "__list_scenarios__",
                    "arguments": {},
                }
            )
        except Exception as e:
            return f"❌ {e}", "{}"
        return _format_obs_md(result), json.dumps(result, indent=2, default=str)

    # Mutable closure state shared between the start and stop handlers:
    # "agent" holds the running AwmAgent (or None), "stop" is a bool flag.
    agent_state: dict[str, Any] = {"agent": None, "stop": False}

    async def _do_run_agent(
        scenario: str,
        task_idx: int,
        verifier_mode: Any,
        llm_base_url: str,
        llm_api_key: str,
        llm_model: str,
        system_prompt: str,
        max_iter: int,
        temperature: float,
        max_tokens: int,
        auto_verify: bool,
        auto_done: bool,
    ):
        """Async generator driving an LLM agent through the env.

        Resets the episode, then streams (log markdown, trajectory path or
        None) tuples to the UI as agent events arrive.
        """
        # Same textbox sanitization as _do_verify.
        if isinstance(verifier_mode, dict):
            verifier_mode = next(iter(verifier_mode.keys()), "code")
        verifier_mode = str(verifier_mode).strip().lower() or "code"
        if verifier_mode not in ("code", "sql"):
            verifier_mode = "code"
        if not scenario:
            yield "❌ Pick a scenario first.", None
            return
        if not (llm_base_url and llm_api_key and llm_model):
            yield (
                "❌ LLM config required for Agent mode (base_url + api_key + model).",
                None,
            )
            return
        try:
            reset_result = await _safe_reset(
                {
                    "scenario": scenario,
                    "task_idx": int(task_idx),
                    "llm_base_url": llm_base_url,
                    "llm_api_key": llm_api_key,
                    "llm_model": llm_model,
                }
            )
        except Exception as e:
            yield f"❌ Reset failed: {e}", None
            return
        obs = reset_result.get("observation", {}) or {}
        if obs.get("reward_type") not in ("reset_ok", "reset_warning"):
            yield (
                f"❌ Reset returned reward_type={obs.get('reward_type')}, error={obs.get('error')}",
                None,
            )
            return
        task_text = obs.get("task") or ""
        agent = AwmAgent(
            web_manager=web_manager,
            llm_base_url=llm_base_url,
            llm_api_key=llm_api_key,
            llm_model=llm_model,
            system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT,
            max_iterations=int(max_iter),
            temperature=float(temperature),
            max_tokens=int(max_tokens),
        )
        agent_state["agent"] = agent
        agent_state["stop"] = False
        log_lines: list[str] = [
            "### Agent run started",
            f"**scenario**: `{scenario}` **task_idx**: `{task_idx}`",
            f"**task**: {task_text}",
            "",
        ]
        # Two outputs: (log_markdown, trajectory_file_path_or_None)
        yield "\n".join(log_lines), None
        traj_path: str | None = None
        async for ev in agent.run(
            task=task_text,
            verifier_mode=verifier_mode,
            auto_verify=auto_verify,
            auto_done=auto_done,
        ):
            if agent_state["stop"]:
                # Cooperative cancellation: the agent exits before its
                # next iteration after request_stop().
                agent.request_stop()
            if ev.kind == "info":
                log_lines.append(f"_ℹ️ {ev.text}_")
            elif ev.kind == "llm_response":
                step = ev.payload.get("step", "?")
                log_lines.append(f"\n**[Step {step}] LLM response:**")
                log_lines.append(f"```\n{ev.text[:4000]}\n```")
            elif ev.kind == "tool_call":
                log_lines.append(f"→ **tool_call**: {ev.text}")
            elif ev.kind == "tool_result":
                log_lines.append("← **tool_result:**")
                log_lines.append(f"```\n{ev.text[:1500]}\n```")
            elif ev.kind == "verify":
                log_lines.append(f"\n🧪 **Verify**: {ev.text}")
                if ev.payload.get("verify_result"):
                    vr = json.dumps(ev.payload["verify_result"], indent=2, default=str)[
                        :1500
                    ]
                    log_lines.append(f"```json\n{vr}\n```")
            elif ev.kind == "done":
                log_lines.append(f"\n🏁 **Done**: {ev.text}")
                # Capture the trajectory path for download
                p = ev.payload.get("trajectory_path")
                if p and os.path.exists(p):
                    traj_path = p
            elif ev.kind == "error":
                log_lines.append(f"\n❌ **Error**: {ev.text}")
            # Stream the accumulated log after every event.
            yield "\n\n".join(log_lines), traj_path
        log_lines.append("\n_Agent run finished._")
        yield "\n\n".join(log_lines), traj_path

    def _do_stop_agent() -> str:
        """Stop-button handler: flag the running agent to stop cooperatively."""
        agent_state["stop"] = True
        a = agent_state["agent"]
        if a is not None:
            a.request_stop()
        return "🛑 Stop requested. The agent will exit before its next iteration."

    async def _load_trajectory_from_state():
        """Build Dataframe rows from the episode's action log.

        Returns (rows, None); the second element feeds the trajectory
        download widget and is always cleared here.
        """
        try:
            logs = web_manager.episode_state.action_logs
        except Exception:
            return [], None
        rows: list[list[Any]] = []
        for i, log in enumerate(logs, 1):
            action = getattr(log, "action", {}) or {}
            obs = getattr(log, "observation", {}) or {}
            tool_name = action.get("tool_name") if isinstance(action, dict) else ""
            atype = action.get("type") if isinstance(action, dict) else ""
            rt = obs.get("reward_type") if isinstance(obs, dict) else ""
            # Short preview: tool result if present, else the error text.
            preview = ""
            if isinstance(obs, dict):
                tr = obs.get("tool_result")
                if isinstance(tr, str):
                    preview = tr[:200]
                elif tr is not None:
                    preview = json.dumps(tr, default=str)[:200]
                elif obs.get("error"):
                    preview = f"ERROR: {obs['error']}"[:200]
            rows.append(
                [
                    i,
                    atype or "",
                    tool_name or "",
                    rt or "",
                    getattr(log, "reward", None),
                    preview,
                ]
            )
        return rows, None

    # ---- UI layout -------------------------------------------------------
    with gr.Blocks(title=f"AWM — {title}") as blocks:
        # Serialized {tool_name: tool_meta} map shared between handlers.
        tools_state = gr.State("{}")
        gr.Markdown("# 🤖 Agent World Model — Web Console")
        gr.Markdown(
            "Pick a scenario, set LLM credentials (only needed for SQL verifier "
            "or Agent mode), then explore via Human or Agent mode."
        )
        with gr.Group():
            gr.Markdown("## ⚙️ Setup")
            with gr.Row():
                scenario_dd = gr.Dropdown(
                    choices=[],
                    value=None,
                    label="Scenario",
                    info="1,000 scenarios — click 'Load' first",
                    elem_id="awm_scenario_dd",
                    interactive=True,
                    allow_custom_value=False,
                )
                load_scen_btn = gr.Button(
                    "🔄 Load scenarios", scale=0, elem_id="awm_load_scen"
                )
                task_idx_slider = gr.Slider(
                    minimum=0,
                    maximum=9,
                    step=1,
                    value=0,
                    label="Task idx (0-9)",
                    elem_id="awm_task_idx",
                )
                # Plain textbox (not a Radio) — _do_verify/_do_run_agent
                # sanitize whatever the user types.
                verifier_mode_radio = gr.Textbox(
                    value="code",
                    label="Verifier mode (code or sql)",
                    elem_id="awm_verifier_mode",
                    info="Type 'code' or 'sql'. SQL mode requires LLM config above.",
                )
            with gr.Accordion(
                "LLM config (for SQL verifier and Agent mode)", open=False
            ):
                llm_base_url_in = gr.Textbox(
                    label="LLM base_url",
                    placeholder="https://...",
                    value="",
                    elem_id="awm_llm_base_url",
                )
                llm_api_key_in = gr.Textbox(
                    label="LLM api_key",
                    type="password",
                    value="",
                    elem_id="awm_llm_api_key",
                )
                llm_model_in = gr.Textbox(
                    label="LLM model",
                    placeholder="gpt-4.1, gpt-5, ...",
                    value="",
                    elem_id="awm_llm_model",
                )
            with gr.Accordion("Reward config (advanced)", open=False):
                reward_cfg_in = gr.Code(
                    language="json",
                    value=_DEFAULT_REWARD_JSON,
                    label="reward_config (JSON)",
                    elem_id="awm_reward_cfg",
                )
            with gr.Row():
                reset_btn = gr.Button(
                    "🔄 Reset", variant="primary", elem_id="awm_reset_btn"
                )
        status_box = gr.Markdown("Status: *idle*", elem_id="awm_status_box")
        task_md = gr.Markdown("*No task loaded yet.*", elem_id="awm_task_md")
        with gr.Tabs():
            with gr.Tab("👤 Human Mode"):
                with gr.Row():
                    list_tools_btn = gr.Button(
                        "📋 List Tools", elem_id="awm_human_list_tools"
                    )
                    list_scenarios_btn = gr.Button(
                        "🌐 List Scenarios", elem_id="awm_human_list_scenarios"
                    )
                with gr.Row():
                    tool_dd = gr.Dropdown(
                        choices=[],
                        value=None,
                        label="Tool",
                        elem_id="awm_human_tool_dd",
                        interactive=True,
                        allow_custom_value=False,
                    )
                    tool_desc_md = gr.Markdown("", elem_id="awm_human_tool_desc")
                tool_args_in = gr.Code(
                    language="json",
                    value="{}",
                    label="Tool arguments (JSON)",
                    elem_id="awm_human_tool_args",
                )
                with gr.Row():
                    call_tool_btn = gr.Button(
                        "▶️ Call Tool", variant="primary", elem_id="awm_human_call_tool"
                    )
                with gr.Group():
                    gr.Markdown("### Episode controls")
                    final_answer_in = gr.Textbox(
                        label="Final answer (optional, used in code-mode verify)",
                        value="",
                        elem_id="awm_human_final_answer",
                    )
                    with gr.Row():
                        verify_btn = gr.Button("🧪 Verify", elem_id="awm_human_verify")
                        keep_session_cb = gr.Checkbox(
                            value=True,
                            label="keep_session on done",
                            elem_id="awm_human_keep_session",
                        )
                        done_btn = gr.Button(
                            "🏁 Done", variant="stop", elem_id="awm_human_done"
                        )
                gr.Markdown("### Latest observation")
                obs_md = gr.Markdown("*No action yet.*", elem_id="awm_human_obs_md")
                with gr.Accordion("Raw JSON", open=False):
                    obs_json = gr.Code(
                        language="json",
                        value="{}",
                        elem_id="awm_human_obs_json",
                    )
                trajectory_file_dl = gr.File(
                    label="trajectory.json (after Done)",
                    interactive=False,
                    elem_id="awm_human_traj_dl",
                )
            with gr.Tab("🤖 Agent Mode"):
                gr.Markdown(
                    "Drives an LLM agent through the env. Reset is done"
                    " automatically using the scenario/task_idx selected above."
                )
                system_prompt_in = gr.Textbox(
                    label="System prompt",
                    value=DEFAULT_SYSTEM_PROMPT,
                    lines=8,
                    max_lines=20,
                    elem_id="awm_agent_system_prompt",
                )
                with gr.Row():
                    max_iter_slider = gr.Slider(
                        minimum=1,
                        maximum=30,
                        step=1,
                        value=10,
                        label="Max iterations",
                        elem_id="awm_agent_max_iter",
                    )
                    temperature_slider = gr.Slider(
                        minimum=0.0,
                        maximum=2.0,
                        step=0.1,
                        value=1.0,
                        label="Temperature",
                        elem_id="awm_agent_temperature",
                    )
                    max_tokens_slider = gr.Slider(
                        minimum=256,
                        maximum=8192,
                        step=128,
                        value=2048,
                        label="Max tokens / call",
                        elem_id="awm_agent_max_tokens",
                    )
                with gr.Row():
                    auto_verify_cb = gr.Checkbox(
                        value=True,
                        label="Auto verify at end",
                        elem_id="awm_agent_auto_verify",
                    )
                    auto_done_cb = gr.Checkbox(
                        value=True,
                        label="Auto done at end (keep_session=True)",
                        elem_id="awm_agent_auto_done",
                    )
                with gr.Row():
                    start_agent_btn = gr.Button(
                        "▶️ Start Agent", variant="primary", elem_id="awm_agent_start"
                    )
                    stop_agent_btn = gr.Button(
                        "⏹ Stop", variant="stop", elem_id="awm_agent_stop"
                    )
                stop_status = gr.Markdown("", elem_id="awm_agent_stop_status")
                agent_log_md = gr.Markdown(
                    "_Agent log will appear here._", elem_id="awm_agent_log_md"
                )
                agent_traj_dl = gr.File(
                    label="trajectory.json (auto-populated when agent finishes)",
                    interactive=False,
                    elem_id="awm_agent_traj_dl",
                )
            with gr.Tab("📜 Trajectory"):
                gr.Markdown(
                    "Step-by-step history of the current episode. "
                    "Click **Refresh** after actions to update."
                )
                refresh_traj_btn = gr.Button("🔄 Refresh", elem_id="awm_traj_refresh")
                traj_table = gr.Dataframe(
                    headers=["#", "type", "tool", "reward_type", "reward", "preview"],
                    datatype=["number", "str", "str", "str", "number", "str"],
                    interactive=False,
                    elem_id="awm_traj_table",
                )
            # Only shown when the metadata supplied README content.
            if readme_md:
                with gr.Tab("📖 README"):
                    gr.Markdown(readme_md)

        # ---- Event wiring ------------------------------------------------
        load_scen_btn.click(
            _refresh_scenarios, inputs=None, outputs=[scenario_dd, status_box]
        )
        reset_btn.click(
            _do_reset,
            inputs=[
                scenario_dd,
                task_idx_slider,
                llm_base_url_in,
                llm_api_key_in,
                llm_model_in,
                reward_cfg_in,
            ],
            outputs=[status_box, task_md, obs_md, obs_json, tool_dd, tools_state],
        )
        list_tools_btn.click(_do_list_tools, inputs=None, outputs=[obs_md, obs_json])
        list_scenarios_btn.click(
            _do_list_scenarios_via_tool, inputs=None, outputs=[obs_md, obs_json]
        )
        tool_dd.change(
            _on_tool_change,
            inputs=[tool_dd, tools_state],
            outputs=[tool_args_in, tool_desc_md],
        )
        call_tool_btn.click(
            _do_call_tool,
            inputs=[tool_dd, tool_args_in],
            outputs=[obs_md, obs_json],
        )
        verify_btn.click(
            _do_verify,
            inputs=[verifier_mode_radio, final_answer_in],
            outputs=[obs_md, obs_json],
        )
        done_btn.click(
            _do_done,
            inputs=[keep_session_cb],
            outputs=[obs_md, obs_json, trajectory_file_dl],
        )
        start_agent_btn.click(
            _do_run_agent,
            inputs=[
                scenario_dd,
                task_idx_slider,
                verifier_mode_radio,
                llm_base_url_in,
                llm_api_key_in,
                llm_model_in,
                system_prompt_in,
                max_iter_slider,
                temperature_slider,
                max_tokens_slider,
                auto_verify_cb,
                auto_done_cb,
            ],
            outputs=[agent_log_md, agent_traj_dl],
        )
        stop_agent_btn.click(_do_stop_agent, inputs=None, outputs=[stop_status])
        refresh_traj_btn.click(
            _load_trajectory_from_state,
            inputs=None,
            outputs=[traj_table, trajectory_file_dl],
        )
    return blocks