""" Gradio UI for the Research → Interactive Explainer Environment. Two modes: 1. LLM Mode: LLM drives exploration + generation, human watches step-by-step 2. Human Mode: human types queries and code, sees rewards in real-time Environment service is the same OpenEnv server that hosts this UI. LLM configuration is resolved from API_URL, HF_TOKEN/API_KEY, and MODEL_NAME. """ import ast import json import os import re import uuid from pathlib import Path from typing import Any import gradio as gr from dotenv import load_dotenv # Load .env from project root PROJECT_ROOT = Path(__file__).parent load_dotenv(PROJECT_ROOT / ".env") try: from .client import ExplainerEnv from .constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score from .dashboard_prompts import ( SYSTEM_PROMPT, build_explore_prompt, build_generate_prompt, build_repair_prompt, parse_explore_response, parse_generate_response, ) from .models import ExplainerAction from .task_bank import ALL_TASKS except ImportError: # pragma: no cover - supports direct execution from env root from client import ExplainerEnv from constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score from dashboard_prompts import ( SYSTEM_PROMPT, build_explore_prompt, build_generate_prompt, build_repair_prompt, parse_explore_response, parse_generate_response, ) from models import ExplainerAction from task_bank import ALL_TASKS SELF_ENV_BASE_URL = f"http://127.0.0.1:{os.getenv('PORT', '8000')}" DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b" # --------------------------------------------------------------------------- # Task catalog (reference only) # --------------------------------------------------------------------------- TASK_CHOICES = ["(random)"] + [f"{t.topic} [{t.difficulty}, {t.tier}]" for t in ALL_TASKS] # Map dropdown label -> topic name for reset(topic=...) _TASK_LABEL_TO_TOPIC: dict[str, str] = {f"{t.topic} [{t.difficulty}, {t.tier}]": t.topic for t in ALL_TASKS} # --------------------------------------------------------------------------- # Session manager # --------------------------------------------------------------------------- class SessionManager: """Module-level registry mapping session_id -> connected ExplainerEnv client.""" def __init__(self): self._clients: dict[str, ExplainerEnv] = {} self._urls: dict[str, str] = {} async def get_or_create(self, session_id: str, base_url: str) -> ExplainerEnv: if session_id in self._clients and self._urls.get(session_id) != base_url: await self.close(session_id) if session_id not in self._clients: client = ExplainerEnv(base_url=base_url.rstrip("/")) await client.connect() self._clients[session_id] = client self._urls[session_id] = base_url return self._clients[session_id] async def close(self, session_id: str) -> None: client = self._clients.pop(session_id, None) self._urls.pop(session_id, None) if client: try: await client.disconnect() except Exception: pass SESSION_MGR = SessionManager() # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _resolve_env_url() -> str: return SELF_ENV_BASE_URL def _resolve_llm() -> tuple[str, str, str]: api_url = (os.getenv("API_URL") or os.getenv("API_BASE_URL") or "").rstrip("/") api_key = os.getenv("HF_TOKEN") or os.getenv("API_KEY") model = os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME return api_url, api_key, model def call_llm_or_raise(client: Any, user_prompt: str, *, model: str, max_tokens: int) -> str: """Call the LLM and preserve provider errors for the dashboard.""" completion = client.chat.completions.create( model=model, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_prompt}, ], temperature=0.7, max_tokens=max_tokens, stream=False, ) return (completion.choices[0].message.content or "").strip() def _format_llm_exception(exc: Exception, api_url: str, model: str) -> str: cause = getattr(exc, "__cause__", None) detail = str(cause or exc).strip() or exc.__class__.__name__ return f"{exc.__class__.__name__} from {api_url} using model {model}: {detail}" def empty_state() -> dict[str, Any]: return { "session_id": str(uuid.uuid4()), "obs": None, "step": 0, "rewards": [], "reward_details": [], "log": [], "done": False, "phase": "not_started", "explored_context": "", "topic": "", "tier": "", "keywords": "", "content": "", "data_available": False, "last_code": "", "last_format": "marimo", "generated_response": "", "parsed_response": "", "top_chunks": [], } def build_reward_matrix(reward_details: list[dict[str, Any]]) -> gr.update: """Build a reward matrix with reward names as rows and steps as columns.""" steps = sorted({entry["step"] for entry in reward_details}) reward_names: list[str] = [] cells: dict[tuple[str, int], Any] = {} for entry in reward_details: step = entry["step"] components = entry.get("components", {}) if not components: components = {"total": ""} for name, value in components.items(): if name not in reward_names: reward_names.append(name) cells[(name, step)] = value headers = ["Reward"] + [f"Step {step}" for step in steps] rows = [] for name in reward_names: row = [name] for step in steps: value = cells.get((name, step), "") row.append(_fmt_component(value) if value != "" else "") rows.append(row) return gr.update( headers=headers, value=rows, column_count=(len(headers), "fixed"), ) def build_reward_summary(reward_details: list[dict[str, Any]]) -> str: if not reward_details: return "*No rewards yet.*" sections = [] for entry in reward_details: components = entry.get("components", {}) total = _first_present( components, ("explore_total", "generate_total", "repair_total"), default="n/a", ) sections.append(f"**Step {entry['step']} · {entry['phase']} · total {_fmt_component(total)}**") return "\n\n".join(sections) def build_top_chunks_df(chunks: list[dict[str, Any]]) -> list[list[Any]]: rows = [] for chunk in chunks[:5]: rows.append([ chunk.get("rank", ""), chunk.get("source", ""), chunk.get("title", ""), chunk.get("score", ""), chunk.get("url", ""), _trim_display_text(str(chunk.get("snippet", "")), 700), ]) return rows def extract_top_chunks(obs_dict: dict[str, Any], search_results: str) -> list[dict[str, Any]]: metadata = obs_dict.get("metadata") or {} chunks = obs_dict.get("top_chunks") or metadata.get("top_chunks") or [] return chunks or parse_rendered_chunks(search_results) def parse_rendered_chunks(search_results: str) -> list[dict[str, Any]]: """Fallback parser for rendered research results if structured fields are absent.""" chunks = [] for part in re.split(r"\n\n---\n\n", search_results or ""): lines = [line for line in part.splitlines() if line.strip()] if not lines: continue match = re.match(r"\[(\d+)\]\s+([^:]+):\s+(.+)", lines[0]) if not match: continue url = "" body_start = 1 if len(lines) > 1 and lines[1].startswith("URL:"): url = lines[1].removeprefix("URL:").strip() body_start = 2 chunks.append({ "rank": int(match.group(1)), "source": match.group(2).strip(), "title": match.group(3).strip(), "url": url, "score": "", "snippet": "\n".join(lines[body_start:]).strip(), }) return chunks[:5] def _trim_display_text(text: str, max_chars: int) -> str: text = re.sub(r"\s+", " ", text).strip() return text if len(text) <= max_chars else text[:max_chars].rstrip() + "..." def _first_present(mapping: dict[str, Any], keys: tuple[str, ...], default: Any = None) -> Any: for key in keys: if key in mapping: return mapping[key] return default def _fmt_component(value: Any) -> str: return f"{value:.3f}" if isinstance(value, float) else str(value) _NON_REWARD_METADATA_KEYS = frozenset({ "step", "phase", "tool", "source_count", "error", "explore_steps_used", "repair_steps_used", "sandbox_message", "error_codes", }) _VISIBLE_REWARD_COMPONENTS = { "explore": ( "query_quality", "evidence_quality", "information_gain", "efficiency", "explore_total", ), "generate": ( "validity", "task_alignment", "structure", "research_usage", "generate_total", ), "repair": ( "repair_success", "fixed_prior_errors", "changed_code", "repair_total", ), } def parse_reward_components(feedback: str) -> dict[str, Any]: """Fallback parser for older observations that lack reward metadata.""" dict_match = re.search(r"Reward:\s*(\{.+\})", feedback) if dict_match: try: parsed = ast.literal_eval(dict_match.group(1)) except (SyntaxError, ValueError): pass else: if isinstance(parsed, dict): return {k: v for k, v in parsed.items() if k not in ("step", "phase")} kv_match = re.search(r"Reward:\s*(.+)", feedback) if kv_match: return _parse_key_value_components(kv_match.group(1)) return {} def _parse_key_value_components(text: str) -> dict[str, Any]: components: dict[str, Any] = {} for part in text.split(","): if "=" not in part: continue key, value = part.strip().split("=", 1) try: components[key.strip()] = float(value.strip()) except ValueError: components[key.strip()] = value.strip() return components def reward_components(obs_dict: dict[str, Any], feedback: str) -> dict[str, Any]: metadata = obs_dict.get("metadata") or {} components = { key: value for key, value in metadata.items() if key not in _NON_REWARD_METADATA_KEYS and isinstance(value, (int, float)) and not isinstance(value, bool) } phase = metadata.get("phase") or obs_dict.get("phase") allowed = _VISIBLE_REWARD_COMPONENTS.get(str(phase)) if allowed: visible = {key: components[key] for key in allowed if key in components} if visible: return visible return components or parse_reward_components(feedback) def to_obs_dict(obs: Any) -> dict[str, Any]: return obs.model_dump() if hasattr(obs, "model_dump") else vars(obs) def fmt_log(log_entries: list[str]) -> str: if not log_entries: return "*No events yet.*" return "```text\n" + "\n".join(log_entries) + "\n```" def obs_summary(obs: dict[str, Any]) -> str: return ( f"**Topic:** {obs.get('topic', '')}\n" f"**Tier:** {obs.get('tier', '')}\n" f"**Phase:** {obs.get('phase', '')}\n" f"**Explore steps left:** {obs.get('explore_steps_left', 0)}\n" f"**Keywords:** {obs.get('keywords', '')}\n" f"**Data available:** {obs.get('data_available', False)}" ) def fenced_json(data: dict[str, Any]) -> str: return "```json\n" + json.dumps(data, indent=2, ensure_ascii=False) + "\n```" def format_explore_action_md(tool: str, query: str, intent: str) -> str: return fenced_json({"tool": tool, "query": query, "intent": intent}) def format_code_text(code: str) -> str: return code or "" def common_outputs( state: dict[str, Any], status: str = "", obs_md: str = "", feedback: str = "", search: str = "", ) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]: return ( state, fmt_log(state["log"]), obs_md, feedback, state.get("generated_response", ""), state.get("parsed_response", ""), search, build_top_chunks_df(state.get("top_chunks", [])), build_reward_summary(state["reward_details"]), build_reward_matrix(state["reward_details"]), ) def llm_outputs( state: dict[str, Any], status: str = "", obs_md: str = "", feedback: str = "", search: str = "", ) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]: return common_outputs(state, status=status, obs_md=obs_md, feedback=feedback, search=search) async def do_reset(task_label, state): """Reset the environment and start a new episode.""" old_sid = state.get("session_id", "") if old_sid: await SESSION_MGR.close(old_sid) state = empty_state() sid = state["session_id"] env_url = _resolve_env_url() # Build reset kwargs — pass topic if a specific task was selected reset_kwargs: dict[str, Any] = {} topic = _TASK_LABEL_TO_TOPIC.get(task_label) if topic: reset_kwargs["topic"] = topic try: env = await SESSION_MGR.get_or_create(sid, env_url) result = await env.reset(**reset_kwargs) except Exception as e: state["log"].append(f"[ERROR] Connection/reset failed: {e}") return common_outputs(state, status=f"Error: {e}") obs = result.observation obs_dict = to_obs_dict(obs) state["obs"] = obs_dict state["phase"] = obs.phase state["topic"] = obs.topic state["tier"] = obs.tier state["keywords"] = obs.keywords state["content"] = obs.content state["data_available"] = obs.data_available state["generated_response"] = "" state["parsed_response"] = "" state["last_code"] = "" state["top_chunks"] = [] state["log"].append(f"[START] topic={obs.topic} tier={obs.tier} phase={obs.phase}") status = f"Reset OK — assigned: {obs.topic} [{obs.tier}]" return common_outputs( state, status=status, obs_md=obs_summary(obs_dict), feedback=obs.feedback, ) async def do_explore(tool, query, intent, state): """Execute an explore step.""" if state.get("done"): state["log"].append("[WARN] Episode already done.") return common_outputs(state, status="Episode already done.", feedback="Episode already done.") if not query.strip(): return common_outputs(state, status="Empty query — nothing sent.") sid = state.get("session_id", "") env_url = _resolve_env_url() try: env = await SESSION_MGR.get_or_create(sid, env_url) except Exception as e: state["log"].append(f"[ERROR] Connection failed: {e}") return common_outputs(state, status=f"Error: {e}") action = ExplainerAction( action_type="explore", tool=tool, query=query.strip(), intent=intent.strip(), ) result = await env.step(action) obs = result.observation reward = result.reward or 0.0 obs_dict = to_obs_dict(obs) state["step"] += 1 state["rewards"].append(reward) state["obs"] = obs_dict state["phase"] = obs.phase state["done"] = result.done state["explored_context"] = obs.explored_context state["parsed_response"] = format_explore_action_md(tool, query.strip(), intent.strip()) state["top_chunks"] = extract_top_chunks(obs_dict, obs.search_results) components = reward_components(obs_dict, obs.feedback) state["reward_details"].append({ "step": state["step"], "phase": "explore", "components": components, }) state["log"].append( f'[STEP] step={state["step"]} action=explore:{tool}:"{query[:60]}" reward={reward:.3f} done={result.done}' ) status = f"Step {state['step']} explore — reward: {reward:.3f}" return common_outputs( state, status=status, obs_md=obs_summary(obs_dict), feedback=obs.feedback, search=obs.search_results, ) async def do_generate(fmt, code, narration, state): """Execute a generate step.""" if state.get("done"): state["log"].append("[WARN] Episode already done.") return common_outputs(state, status="Episode already done.", feedback="Episode already done.") sid = state.get("session_id", "") env_url = _resolve_env_url() try: env = await SESSION_MGR.get_or_create(sid, env_url) except Exception as e: state["log"].append(f"[ERROR] Connection failed: {e}") return common_outputs(state, status=f"Error: {e}") action_type = "repair" if state.get("phase") == "repair" else "generate" action = ExplainerAction( action_type=action_type, format=fmt, code=code, narration=narration, ) result = await env.step(action) obs = result.observation reward = result.reward or 0.0 obs_dict = to_obs_dict(obs) state["step"] += 1 state["rewards"].append(reward) state["obs"] = obs_dict state["phase"] = obs.phase state["done"] = result.done state["last_code"] = code state["last_format"] = fmt state["generated_response"] = format_code_text(code) state["parsed_response"] = fenced_json({ "action_type": action_type, "format": fmt, "code_len": len(code), "narration_len": len(narration or ""), }) components = reward_components(obs_dict, obs.feedback) state["reward_details"].append({ "step": state["step"], "phase": action_type, "components": components, }) total_score = normalized_episode_score(sum(state["rewards"])) state["log"].append( f"[STEP] step={state['step']} action={action_type}:{fmt} reward={reward:.3f} done={result.done}" ) state["log"].append( f"[END] success={total_score >= SUCCESS_SCORE_THRESHOLD} steps={state['step']} " f"score={total_score:.3f} rewards={','.join(f'{r:.2f}' for r in state['rewards'])}" ) status = f"Episode done — score: {total_score:.3f} (generate reward: {reward:.3f})" return common_outputs( state, status=status, obs_md=obs_summary(obs_dict), feedback=obs.feedback, ) def _llm_error_outputs(state: dict[str, Any], message: str): state["log"].append(f"[ERROR] {message}") state["parsed_response"] = f"**LLM error:** {message}" return llm_outputs( state, obs_md=obs_summary(state.get("obs") or {}) if state.get("obs") else "", feedback=(state.get("obs") or {}).get("feedback", ""), ) async def do_llm_step(state): """Let the LLM take the next step (explore or generate).""" if state.get("done"): state["log"].append("[WARN] Episode already done.") return llm_outputs( state, feedback="Episode already done.", ) from openai import OpenAI api_url, api_key, model = _resolve_llm() if not api_url: return _llm_error_outputs(state, "API_URL is not configured.") if not api_key: return _llm_error_outputs(state, "HF_TOKEN or API_KEY is not configured.") if not model: return _llm_error_outputs(state, "MODEL_NAME is not configured.") client = OpenAI(base_url=api_url, api_key=api_key, timeout=60.0) obs_data = state.get("obs", {}) phase = state.get("phase", "explore") llm_response = "" if phase == "explore": prompt = build_explore_prompt( topic=state["topic"], content=state["content"], tier=state["tier"], keywords=state["keywords"], step=state["step"] + 1, steps_left=obs_data.get("explore_steps_left", 0), explored_context=state.get("explored_context", ""), feedback=obs_data.get("feedback", ""), ) try: llm_response = call_llm_or_raise(client, prompt, model=model, max_tokens=256) except Exception as exc: return _llm_error_outputs(state, _format_llm_exception(exc, api_url, model)) if not llm_response: return _llm_error_outputs( state, f"LLM call failed or returned an empty response from {api_url} using model {model}.", ) if llm_response.strip().upper() == "SKIP": state["log"].append("[LLM] Decided to skip exploration. Moving to generate.") state["phase"] = "generate" state["generated_response"] = llm_response state["parsed_response"] = "`SKIP`" return llm_outputs( state, obs_md=obs_summary(obs_data), feedback=obs_data.get("feedback", ""), ) tool, query, intent = parse_explore_response(llm_response, state["topic"]) state["generated_response"] = llm_response state["parsed_response"] = format_explore_action_md(tool, query, intent) state["log"].append(f'[LLM] Explore tool={tool} query="{query[:80]}"') ( s, log, obs_md, feedback, generated_response, parsed_response, search, top_chunks, reward_summary, rewards_table, ) = await do_explore( tool, query, intent, state, ) return ( s, log, obs_md, feedback, generated_response, parsed_response, search, top_chunks, reward_summary, rewards_table, ) elif phase in ("generate", "repair", "done"): if phase == "repair": prompt = build_repair_prompt( topic=state["topic"], tier=state["tier"], fmt=state.get("last_format", "marimo"), previous_code=state.get("last_code", ""), last_errors=obs_data.get("last_errors", ""), ) else: prompt = build_generate_prompt( topic=state["topic"], content=state["content"], tier=state["tier"], keywords=state["keywords"], data_available=state.get("data_available", False), explored_context=state.get("explored_context", ""), ) try: llm_response = call_llm_or_raise(client, prompt, model=model, max_tokens=4096) except Exception as exc: return _llm_error_outputs(state, _format_llm_exception(exc, api_url, model)) if not llm_response: return _llm_error_outputs( state, f"LLM call failed or returned an empty response from {api_url} using model {model}.", ) fmt, code, narration = parse_generate_response(llm_response) state["generated_response"] = format_code_text(code) state["parsed_response"] = fenced_json({ "format": fmt, "code_len": len(code), "narration_len": len(narration), }) state["log"].append(f"[LLM] Generate: format={fmt}, code_len={len(code)}") ( s, log, obs_md, feedback, generated_response, parsed_response, search, top_chunks, reward_summary, rewards_table, ) = await do_generate( fmt, code, narration, state, ) return ( s, log, obs_md, feedback, generated_response, parsed_response, search, top_chunks, reward_summary, rewards_table, ) return llm_outputs(state) async def do_llm_auto(state): """Run full episode automatically with LLM (explore + generate).""" outputs = None while not state.get("done"): outputs = await do_llm_step(state) state = outputs[0] if state.get("log") and str(state["log"][-1]).startswith("[ERROR]"): break return outputs if outputs else llm_outputs(state, status="No steps taken.") # --------------------------------------------------------------------------- # Gradio UI # --------------------------------------------------------------------------- def build_ui(): with gr.Blocks(title="Explainer Env — Interactive Runner") as demo: session_state = gr.State(empty_state()) # Header gr.Markdown("# Explainer Episode Inspector") # ===================================================================== # Controls # ===================================================================== with gr.Row(equal_height=True): task_dd = gr.Dropdown( choices=TASK_CHOICES, value="(random)", label="Task", scale=1, ) with gr.Row(equal_height=True): reset_btn = gr.Button("Reset Episode", variant="primary") llm_step_btn = gr.Button("Next Step", variant="secondary") llm_auto_btn = gr.Button("Auto Run", variant="primary") # ===================================================================== # Inspector panels # ===================================================================== with gr.Row(equal_height=False): with gr.Column(scale=1): with gr.Group(): gr.Markdown("### Observation") obs_md = gr.Markdown("*Click Reset Episode to begin.*") feedback_box = gr.Textbox( label="Latest feedback", lines=8, max_lines=8, interactive=False, ) with gr.Column(scale=1): with gr.Group(): gr.Markdown("### LLM") with gr.Tabs(): with gr.Tab("Parsed"): parsed_response_box = gr.Markdown("*No parsed response yet.*") with gr.Tab("Response / code"): generated_response_box = gr.Textbox( label="Raw response or generated code", value="No response yet.", lines=16, max_lines=16, interactive=False, buttons=["copy"], ) with gr.Row(equal_height=False): with gr.Column(scale=1): with gr.Group(): gr.Markdown("### Research") search_box = gr.Textbox( label="Latest search results", lines=8, max_lines=8, interactive=False, ) top_chunks_table = gr.Dataframe( headers=["Rank", "Source", "Title", "Score", "URL", "Snippet"], interactive=False, column_count=(6, "fixed"), label="Top chunks", ) with gr.Column(scale=1): with gr.Group(): gr.Markdown("### Rewards") reward_summary = gr.Markdown("*No rewards yet.*") rewards_table = gr.Dataframe( headers=["Reward"], interactive=False, column_count=(1, "fixed"), label="Reward matrix", ) # ===================================================================== # Timeline # ===================================================================== with gr.Group(): gr.Markdown("### Timeline") log_box = gr.Markdown("*No events yet.*") # ===================================================================== # Wiring # ===================================================================== # Common outputs: state, log, obs, feedback, search, rewards common_output_components = [ session_state, log_box, obs_md, feedback_box, generated_response_box, parsed_response_box, search_box, top_chunks_table, reward_summary, rewards_table, ] reset_btn.click( fn=do_reset, inputs=[task_dd, session_state], outputs=common_output_components, ) llm_step_btn.click( fn=do_llm_step, inputs=[session_state], outputs=common_output_components, ) llm_auto_btn.click( fn=do_llm_auto, inputs=[session_state], outputs=common_output_components, ) return demo if __name__ == "__main__": demo = build_ui() demo.launch(server_name="0.0.0.0", server_port=7860)