""" Custom Gradio UI for the AWM environment. """ from __future__ import annotations import asyncio import inspect import json import os from typing import Any import gradio as gr from openenv.core.env_server.serialization import serialize_observation from .data_loader import AWMDataLoader from .prompts import DEFAULT_SYSTEM_PROMPT from .web_agent import AwmAgent from .config import DEFAULT_REWARD_CONFIG # Keep in sync with DEFAULT_REWARD_CONFIG in config.py. _DEFAULT_REWARD_JSON = json.dumps( DEFAULT_REWARD_CONFIG, indent=2 ) def _format_obs_md(payload: dict | None) -> str: if not payload: return "*No observation yet.*" obs = payload.get("observation") if isinstance(payload, dict) else None if obs is None: obs = payload reward = payload.get("reward") if isinstance(payload, dict) else None done = payload.get("done") if isinstance(payload, dict) else None lines: list[str] = [] if reward is not None: lines.append(f"**reward**: `{reward}`") if done is not None: lines.append(f"**done**: `{done}`") if isinstance(obs, dict): for key in ( "reward_type", "scenario", "task", "task_idx", "num_tools", "tool_name", "error", "warning", ): v = obs.get(key) if v is not None and v != "": lines.append(f"**{key}**: `{v}`") if obs.get("tool_result") is not None: tr = obs["tool_result"] tr_text = ( tr if isinstance(tr, str) else json.dumps(tr, indent=2, default=str) ) if len(tr_text) > 2000: tr_text = tr_text[:2000] + "\n... (truncated)" lines.append("\n**tool_result:**") lines.append(f"```\n{tr_text}\n```") if obs.get("verify_result"): vr = obs["verify_result"] vr_text = json.dumps(vr, indent=2, default=str) if len(vr_text) > 2000: vr_text = vr_text[:2000] + "\n... (truncated)" lines.append("\n**verify_result:**") lines.append(f"```json\n{vr_text}\n```") if obs.get("trajectory_path"): lines.append(f"\n**trajectory_path**: `{obs['trajectory_path']}`") return "\n\n".join(lines) if lines else "*Empty observation.*" def _make_args_template(input_schema: dict | None) -> str: if not input_schema or not isinstance(input_schema, dict): return "{}" props = input_schema.get("properties") or {} template: dict[str, Any] = {} for name, info in props.items(): ty = (info or {}).get("type", "string") template[name] = { "string": "", "integer": 0, "number": 0.0, "boolean": False, "array": [], "object": {}, }.get(ty, None) return json.dumps(template, indent=2) def build_awm_gradio_app( web_manager: Any, action_fields: list[dict] | None = None, metadata: Any = None, is_chat_env: bool = False, title: str = "AWM Environment", quick_start_md: str | None = None, ) -> gr.Blocks: data_loader = AWMDataLoader(cache_dir=os.environ.get("AWM_DATA_DIR")) readme_md = "" if metadata is not None and getattr(metadata, "readme_content", None): readme_md = metadata.readme_content # openenv-core 0.2.3 added a ``reset_kwargs`` parameter to # ``WebInterfaceManager.reset_environment``. PyPI's 0.2.1 takes no args, # which silently drops scenario/task_idx — fall back to calling env.reset # directly and replicate the episode-state updates the manager would do. _reset_env_supports_kwargs = ( len( [ p for p in inspect.signature( web_manager.reset_environment ).parameters.values() if p.name != "self" ] ) > 0 ) async def _safe_reset(reset_kwargs: dict[str, Any]) -> dict[str, Any]: if _reset_env_supports_kwargs: return await web_manager.reset_environment(reset_kwargs) env = web_manager.env params = inspect.signature(env.reset).parameters has_var_kw = any( p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values() ) valid = {k: v for k, v in reset_kwargs.items() if has_var_kw or k in params} loop = asyncio.get_event_loop() observation = await loop.run_in_executor(None, lambda: env.reset(**valid)) serialized = serialize_observation(observation) es = web_manager.episode_state es.episode_id = env.state.episode_id es.step_count = 0 es.current_observation = serialized["observation"] es.action_logs = [] es.is_reset = True try: await web_manager._send_state_update() except Exception: pass return serialized async def _do_reset( scenario: str, task_idx: int, llm_base_url: str, llm_api_key: str, llm_model: str, reward_config_json: str, ): if not scenario: return ( "❌ Pick a scenario first.", "", "*Reset failed.*", "{}", gr.update(choices=[], value=None), "{}", ) reset_kwargs: dict[str, Any] = { "scenario": scenario, "task_idx": int(task_idx), } if llm_base_url: reset_kwargs["llm_base_url"] = llm_base_url if llm_api_key: reset_kwargs["llm_api_key"] = llm_api_key if llm_model: reset_kwargs["llm_model"] = llm_model if reward_config_json and reward_config_json.strip(): try: reset_kwargs["reward_config"] = json.loads(reward_config_json) except json.JSONDecodeError as e: return ( f"❌ Invalid reward_config JSON: {e}", "", "*Reset failed.*", "{}", gr.update(choices=[], value=None), "{}", ) try: result = await _safe_reset(reset_kwargs) except Exception as e: return ( f"❌ Reset error: {e}", "", f"*Reset failed: {e}*", "{}", gr.update(choices=[], value=None), "{}", ) obs = result.get("observation", {}) or {} rt = obs.get("reward_type") ok = rt in ("reset_ok", "reset_warning") status = "✅ Reset OK." if ok else f"❌ Reset failed: {obs.get('error') or rt}" task_md = f"**Task** (`{scenario}`, idx={task_idx}):\n\n{obs.get('task') or '*(no task description)*'}" tool_names: list[str] = [] tool_lookup: dict[str, dict] = {} if ok: try: tools_result = await web_manager.step_environment( {"type": "list_tools"} ) tools = (tools_result.get("observation", {}) or {}).get("tools") or [] for t in tools: if isinstance(t, dict): n = t.get("name", "") tool_names.append(n) tool_lookup[n] = t else: n = getattr(t, "name", "") tool_names.append(n) tool_lookup[n] = { "name": n, "description": getattr(t, "description", ""), "input_schema": getattr(t, "input_schema", {}), } except Exception as e: status += f" (list_tools warning: {e})" tool_choice = gr.update( choices=tool_names, value=(tool_names[0] if tool_names else None), ) return ( status, task_md, _format_obs_md(result), json.dumps(result, indent=2, default=str), tool_choice, json.dumps(tool_lookup), ) async def _refresh_scenarios(): try: scens = data_loader.list_scenarios() names = sorted(s["name"] for s in scens) return gr.update( choices=names, value=(names[0] if names else None), ), f"Loaded {len(names)} scenarios." except Exception as e: return gr.update(choices=[]), f"❌ Failed to load scenarios: {e}" async def _on_tool_change(tool_name: str, tool_lookup_json: str): try: lookup = json.loads(tool_lookup_json or "{}") except json.JSONDecodeError: lookup = {} if not tool_name or tool_name not in lookup: return "{}", "" meta = lookup[tool_name] schema = meta.get("input_schema") or meta.get("inputSchema") or {} return _make_args_template(schema), meta.get("description", "") async def _do_call_tool(tool_name: str, args_json: str): if not tool_name: return "*Pick a tool first.*", "{}" try: args = json.loads(args_json) if args_json.strip() else {} except json.JSONDecodeError as e: return f"❌ Invalid args JSON: {e}", "{}" try: result = await web_manager.step_environment( {"type": "call_tool", "tool_name": tool_name, "arguments": args} ) except Exception as e: return f"❌ step error: {e}", "{}" return _format_obs_md(result), json.dumps(result, indent=2, default=str) async def _do_list_tools(): try: result = await web_manager.step_environment({"type": "list_tools"}) except Exception as e: return f"❌ {e}", "{}" return _format_obs_md(result), json.dumps(result, indent=2, default=str) async def _do_verify(verifier_mode: Any, final_answer: str): if isinstance(verifier_mode, dict): verifier_mode = next(iter(verifier_mode.keys()), "code") verifier_mode = str(verifier_mode).strip().lower() or "code" if verifier_mode not in ("code", "sql"): verifier_mode = "code" args: dict[str, Any] = {"verifier_mode": verifier_mode} if final_answer: args["final_answer"] = final_answer try: result = await web_manager.step_environment( {"type": "call_tool", "tool_name": "verify", "arguments": args} ) except Exception as e: return f"❌ verify error: {e}", "{}" return _format_obs_md(result), json.dumps(result, indent=2, default=str) async def _do_done(keep_session: bool): try: result = await web_manager.step_environment( { "type": "call_tool", "tool_name": "done", "arguments": {"keep_session": bool(keep_session)}, } ) except Exception as e: return f"❌ done error: {e}", "{}", None obs = result.get("observation", {}) or {} traj_path = obs.get("trajectory_path") return ( _format_obs_md(result), json.dumps(result, indent=2, default=str), traj_path if traj_path and os.path.exists(traj_path) else None, ) async def _do_list_scenarios_via_tool(): try: result = await web_manager.step_environment( { "type": "call_tool", "tool_name": "__list_scenarios__", "arguments": {}, } ) except Exception as e: return f"❌ {e}", "{}" return _format_obs_md(result), json.dumps(result, indent=2, default=str) agent_state: dict[str, AwmAgent | None] = {"agent": None, "stop": False} async def _do_run_agent( scenario: str, task_idx: int, verifier_mode: Any, llm_base_url: str, llm_api_key: str, llm_model: str, system_prompt: str, max_iter: int, temperature: float, max_tokens: int, auto_verify: bool, auto_done: bool, ): if isinstance(verifier_mode, dict): verifier_mode = next(iter(verifier_mode.keys()), "code") verifier_mode = str(verifier_mode).strip().lower() or "code" if verifier_mode not in ("code", "sql"): verifier_mode = "code" if not scenario: yield "❌ Pick a scenario first.", None return if not (llm_base_url and llm_api_key and llm_model): yield ( "❌ LLM config required for Agent mode (base_url + api_key + model).", None, ) return try: reset_result = await _safe_reset( { "scenario": scenario, "task_idx": int(task_idx), "llm_base_url": llm_base_url, "llm_api_key": llm_api_key, "llm_model": llm_model, } ) except Exception as e: yield f"❌ Reset failed: {e}", None return obs = reset_result.get("observation", {}) or {} if obs.get("reward_type") not in ("reset_ok", "reset_warning"): yield ( f"❌ Reset returned reward_type={obs.get('reward_type')}, error={obs.get('error')}", None, ) return task_text = obs.get("task") or "" agent = AwmAgent( web_manager=web_manager, llm_base_url=llm_base_url, llm_api_key=llm_api_key, llm_model=llm_model, system_prompt=system_prompt or DEFAULT_SYSTEM_PROMPT, max_iterations=int(max_iter), temperature=float(temperature), max_tokens=int(max_tokens), ) agent_state["agent"] = agent agent_state["stop"] = False log_lines: list[str] = [ "### Agent run started", f"**scenario**: `{scenario}`    **task_idx**: `{task_idx}`", f"**task**: {task_text}", "", ] # Two outputs: (log_markdown, trajectory_file_path_or_None) yield "\n".join(log_lines), None traj_path: str | None = None async for ev in agent.run( task=task_text, verifier_mode=verifier_mode, auto_verify=auto_verify, auto_done=auto_done, ): if agent_state["stop"]: agent.request_stop() if ev.kind == "info": log_lines.append(f"_ℹ️ {ev.text}_") elif ev.kind == "llm_response": step = ev.payload.get("step", "?") log_lines.append(f"\n**[Step {step}] LLM response:**") log_lines.append(f"```\n{ev.text[:4000]}\n```") elif ev.kind == "tool_call": log_lines.append(f"→ **tool_call**: {ev.text}") elif ev.kind == "tool_result": log_lines.append("← **tool_result:**") log_lines.append(f"```\n{ev.text[:1500]}\n```") elif ev.kind == "verify": log_lines.append(f"\n🧪 **Verify**: {ev.text}") if ev.payload.get("verify_result"): vr = json.dumps(ev.payload["verify_result"], indent=2, default=str)[ :1500 ] log_lines.append(f"```json\n{vr}\n```") elif ev.kind == "done": log_lines.append(f"\n🏁 **Done**: {ev.text}") # Capture the trajectory path for download p = ev.payload.get("trajectory_path") if p and os.path.exists(p): traj_path = p elif ev.kind == "error": log_lines.append(f"\n❌ **Error**: {ev.text}") yield "\n\n".join(log_lines), traj_path log_lines.append("\n_Agent run finished._") yield "\n\n".join(log_lines), traj_path def _do_stop_agent(): agent_state["stop"] = True a = agent_state["agent"] if a is not None: a.request_stop() return "🛑 Stop requested. The agent will exit before its next iteration." async def _load_trajectory_from_state(): try: logs = web_manager.episode_state.action_logs except Exception: return [], None rows: list[list[Any]] = [] for i, log in enumerate(logs, 1): action = getattr(log, "action", {}) or {} obs = getattr(log, "observation", {}) or {} tool_name = action.get("tool_name") if isinstance(action, dict) else "" atype = action.get("type") if isinstance(action, dict) else "" rt = obs.get("reward_type") if isinstance(obs, dict) else "" preview = "" if isinstance(obs, dict): tr = obs.get("tool_result") if isinstance(tr, str): preview = tr[:200] elif tr is not None: preview = json.dumps(tr, default=str)[:200] elif obs.get("error"): preview = f"ERROR: {obs['error']}"[:200] rows.append( [ i, atype or "", tool_name or "", rt or "", getattr(log, "reward", None), preview, ] ) return rows, None with gr.Blocks(title=f"AWM — {title}") as blocks: tools_state = gr.State("{}") gr.Markdown("# 🤖 Agent World Model — Web Console") gr.Markdown( "Pick a scenario, set LLM credentials (only needed for SQL verifier " "or Agent mode), then explore via Human or Agent mode." ) with gr.Group(): gr.Markdown("## ⚙️ Setup") with gr.Row(): scenario_dd = gr.Dropdown( choices=[], value=None, label="Scenario", info="1,000 scenarios — click 'Load' first", elem_id="awm_scenario_dd", interactive=True, allow_custom_value=False, ) load_scen_btn = gr.Button( "🔄 Load scenarios", scale=0, elem_id="awm_load_scen" ) task_idx_slider = gr.Slider( minimum=0, maximum=9, step=1, value=0, label="Task idx (0-9)", elem_id="awm_task_idx", ) verifier_mode_radio = gr.Textbox( value="code", label="Verifier mode (code or sql)", elem_id="awm_verifier_mode", info="Type 'code' or 'sql'. SQL mode requires LLM config above.", ) with gr.Accordion( "LLM config (for SQL verifier and Agent mode)", open=False ): llm_base_url_in = gr.Textbox( label="LLM base_url", placeholder="https://...", value="", elem_id="awm_llm_base_url", ) llm_api_key_in = gr.Textbox( label="LLM api_key", type="password", value="", elem_id="awm_llm_api_key", ) llm_model_in = gr.Textbox( label="LLM model", placeholder="gpt-4.1, gpt-5, ...", value="", elem_id="awm_llm_model", ) with gr.Accordion("Reward config (advanced)", open=False): reward_cfg_in = gr.Code( language="json", value=_DEFAULT_REWARD_JSON, label="reward_config (JSON)", elem_id="awm_reward_cfg", ) with gr.Row(): reset_btn = gr.Button( "🔄 Reset", variant="primary", elem_id="awm_reset_btn" ) status_box = gr.Markdown("Status: *idle*", elem_id="awm_status_box") task_md = gr.Markdown("*No task loaded yet.*", elem_id="awm_task_md") with gr.Tabs(): with gr.Tab("👤 Human Mode"): with gr.Row(): list_tools_btn = gr.Button( "📋 List Tools", elem_id="awm_human_list_tools" ) list_scenarios_btn = gr.Button( "🌐 List Scenarios", elem_id="awm_human_list_scenarios" ) with gr.Row(): tool_dd = gr.Dropdown( choices=[], value=None, label="Tool", elem_id="awm_human_tool_dd", interactive=True, allow_custom_value=False, ) tool_desc_md = gr.Markdown("", elem_id="awm_human_tool_desc") tool_args_in = gr.Code( language="json", value="{}", label="Tool arguments (JSON)", elem_id="awm_human_tool_args", ) with gr.Row(): call_tool_btn = gr.Button( "▶️ Call Tool", variant="primary", elem_id="awm_human_call_tool" ) with gr.Group(): gr.Markdown("### Episode controls") final_answer_in = gr.Textbox( label="Final answer (optional, used in code-mode verify)", value="", elem_id="awm_human_final_answer", ) with gr.Row(): verify_btn = gr.Button("🧪 Verify", elem_id="awm_human_verify") keep_session_cb = gr.Checkbox( value=True, label="keep_session on done", elem_id="awm_human_keep_session", ) done_btn = gr.Button( "🏁 Done", variant="stop", elem_id="awm_human_done" ) gr.Markdown("### Latest observation") obs_md = gr.Markdown("*No action yet.*", elem_id="awm_human_obs_md") with gr.Accordion("Raw JSON", open=False): obs_json = gr.Code( language="json", value="{}", elem_id="awm_human_obs_json", ) trajectory_file_dl = gr.File( label="trajectory.json (after Done)", interactive=False, elem_id="awm_human_traj_dl", ) with gr.Tab("🤖 Agent Mode"): gr.Markdown( "Drives an LLM agent through the env. Reset is done" " automatically using the scenario/task_idx selected above." ) system_prompt_in = gr.Textbox( label="System prompt", value=DEFAULT_SYSTEM_PROMPT, lines=8, max_lines=20, elem_id="awm_agent_system_prompt", ) with gr.Row(): max_iter_slider = gr.Slider( minimum=1, maximum=30, step=1, value=10, label="Max iterations", elem_id="awm_agent_max_iter", ) temperature_slider = gr.Slider( minimum=0.0, maximum=2.0, step=0.1, value=1.0, label="Temperature", elem_id="awm_agent_temperature", ) max_tokens_slider = gr.Slider( minimum=256, maximum=8192, step=128, value=2048, label="Max tokens / call", elem_id="awm_agent_max_tokens", ) with gr.Row(): auto_verify_cb = gr.Checkbox( value=True, label="Auto verify at end", elem_id="awm_agent_auto_verify", ) auto_done_cb = gr.Checkbox( value=True, label="Auto done at end (keep_session=True)", elem_id="awm_agent_auto_done", ) with gr.Row(): start_agent_btn = gr.Button( "▶️ Start Agent", variant="primary", elem_id="awm_agent_start" ) stop_agent_btn = gr.Button( "⏹ Stop", variant="stop", elem_id="awm_agent_stop" ) stop_status = gr.Markdown("", elem_id="awm_agent_stop_status") agent_log_md = gr.Markdown( "_Agent log will appear here._", elem_id="awm_agent_log_md" ) agent_traj_dl = gr.File( label="trajectory.json (auto-populated when agent finishes)", interactive=False, elem_id="awm_agent_traj_dl", ) with gr.Tab("📜 Trajectory"): gr.Markdown( "Step-by-step history of the current episode. " "Click **Refresh** after actions to update." ) refresh_traj_btn = gr.Button("🔄 Refresh", elem_id="awm_traj_refresh") traj_table = gr.Dataframe( headers=["#", "type", "tool", "reward_type", "reward", "preview"], datatype=["number", "str", "str", "str", "number", "str"], interactive=False, elem_id="awm_traj_table", ) if readme_md: with gr.Tab("📖 README"): gr.Markdown(readme_md) load_scen_btn.click( _refresh_scenarios, inputs=None, outputs=[scenario_dd, status_box] ) reset_btn.click( _do_reset, inputs=[ scenario_dd, task_idx_slider, llm_base_url_in, llm_api_key_in, llm_model_in, reward_cfg_in, ], outputs=[status_box, task_md, obs_md, obs_json, tool_dd, tools_state], ) list_tools_btn.click(_do_list_tools, inputs=None, outputs=[obs_md, obs_json]) list_scenarios_btn.click( _do_list_scenarios_via_tool, inputs=None, outputs=[obs_md, obs_json] ) tool_dd.change( _on_tool_change, inputs=[tool_dd, tools_state], outputs=[tool_args_in, tool_desc_md], ) call_tool_btn.click( _do_call_tool, inputs=[tool_dd, tool_args_in], outputs=[obs_md, obs_json], ) verify_btn.click( _do_verify, inputs=[verifier_mode_radio, final_answer_in], outputs=[obs_md, obs_json], ) done_btn.click( _do_done, inputs=[keep_session_cb], outputs=[obs_md, obs_json, trajectory_file_dl], ) start_agent_btn.click( _do_run_agent, inputs=[ scenario_dd, task_idx_slider, verifier_mode_radio, llm_base_url_in, llm_api_key_in, llm_model_in, system_prompt_in, max_iter_slider, temperature_slider, max_tokens_slider, auto_verify_cb, auto_done_cb, ], outputs=[agent_log_md, agent_traj_dl], ) stop_agent_btn.click(_do_stop_agent, inputs=None, outputs=[stop_status]) refresh_traj_btn.click( _load_trajectory_from_state, inputs=None, outputs=[traj_table, trajectory_file_dl], ) return blocks