# explainer-env/dashboard.py — uploaded via huggingface_hub (rev ac7572a, user kgdrathan)
"""
Gradio UI for the Research → Interactive Explainer Environment.
Two modes:
1. LLM Mode: LLM drives exploration + generation, human watches step-by-step
2. Human Mode: human types queries and code, sees rewards in real-time
Environment service is the same OpenEnv server that hosts this UI.
LLM configuration is resolved from API_URL, HF_TOKEN/API_KEY, and MODEL_NAME.
"""
import ast
import json
import os
import re
import uuid
from pathlib import Path
from typing import Any
import gradio as gr
from dotenv import load_dotenv
# Load .env from project root
PROJECT_ROOT = Path(__file__).parent
load_dotenv(PROJECT_ROOT / ".env")
try:
from .client import ExplainerEnv
from .constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score
from .dashboard_prompts import (
SYSTEM_PROMPT,
build_explore_prompt,
build_generate_prompt,
build_repair_prompt,
parse_explore_response,
parse_generate_response,
)
from .models import ExplainerAction
from .task_bank import ALL_TASKS
except ImportError: # pragma: no cover - supports direct execution from env root
from client import ExplainerEnv
from constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score
from dashboard_prompts import (
SYSTEM_PROMPT,
build_explore_prompt,
build_generate_prompt,
build_repair_prompt,
parse_explore_response,
parse_generate_response,
)
from models import ExplainerAction
from task_bank import ALL_TASKS
# The environment service is co-hosted with this UI; PORT defaults to 8000.
SELF_ENV_BASE_URL = f"http://127.0.0.1:{os.getenv('PORT', '8000')}"
# Fallback model when MODEL_NAME is not set in the environment.
DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b"
# ---------------------------------------------------------------------------
# Task catalog (reference only)
# ---------------------------------------------------------------------------
# Dropdown labels; "(random)" lets the environment pick a task on reset.
TASK_CHOICES = ["(random)"] + [f"{t.topic} [{t.difficulty}, {t.tier}]" for t in ALL_TASKS]
# Map dropdown label -> topic name for reset(topic=...)
_TASK_LABEL_TO_TOPIC: dict[str, str] = {f"{t.topic} [{t.difficulty}, {t.tier}]": t.topic for t in ALL_TASKS}
# ---------------------------------------------------------------------------
# Session manager
# ---------------------------------------------------------------------------
class SessionManager:
    """Module-level registry mapping session_id -> connected ExplainerEnv client."""

    def __init__(self):
        # session_id -> live client, and session_id -> the base URL it was built for.
        self._clients: dict[str, ExplainerEnv] = {}
        self._urls: dict[str, str] = {}

    async def get_or_create(self, session_id: str, base_url: str) -> ExplainerEnv:
        """Return the cached client for this session, reconnecting if the URL changed."""
        stale = session_id in self._clients and self._urls.get(session_id) != base_url
        if stale:
            await self.close(session_id)
        if session_id not in self._clients:
            fresh = ExplainerEnv(base_url=base_url.rstrip("/"))
            await fresh.connect()
            self._clients[session_id] = fresh
            self._urls[session_id] = base_url
        return self._clients[session_id]

    async def close(self, session_id: str) -> None:
        """Drop a session's client and disconnect it; disconnect errors are best-effort."""
        dropped = self._clients.pop(session_id, None)
        self._urls.pop(session_id, None)
        if dropped:
            try:
                await dropped.disconnect()
            except Exception:
                pass  # best-effort cleanup; the session is gone either way


# Single shared registry for the whole dashboard process.
SESSION_MGR = SessionManager()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_env_url() -> str:
    """The environment service is co-hosted with this UI, so always talk to ourselves."""
    return SELF_ENV_BASE_URL
def _resolve_llm() -> tuple[str, str | None, str]:
    """Resolve LLM connection settings from environment variables.

    Returns:
        (api_url, api_key, model):
        - api_url: from API_URL or API_BASE_URL, trailing slash stripped,
          '' when neither is set (callers must check before use).
        - api_key: from HF_TOKEN or API_KEY; None when neither is set —
          the previous ``str`` annotation was wrong about this.
        - model: from MODEL_NAME, falling back to DEFAULT_MODEL_NAME.
    """
    api_url = (os.getenv("API_URL") or os.getenv("API_BASE_URL") or "").rstrip("/")
    api_key = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
    model = os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME
    return api_url, api_key, model
def call_llm_or_raise(client: Any, user_prompt: str, *, model: str, max_tokens: int) -> str:
    """Call the LLM and preserve provider errors for the dashboard."""
    chat_messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=chat_messages,
        temperature=0.7,
        max_tokens=max_tokens,
        stream=False,
    )
    text = completion.choices[0].message.content or ""
    return text.strip()
def _format_llm_exception(exc: Exception, api_url: str, model: str) -> str:
    """Render a provider exception (preferring its chained cause) as a single line."""
    root = getattr(exc, "__cause__", None) or exc
    detail = str(root).strip() or exc.__class__.__name__
    return f"{exc.__class__.__name__} from {api_url} using model {model}: {detail}"
def empty_state() -> dict[str, Any]:
    """Fresh per-episode UI state keyed by a brand-new session id."""
    fresh: dict[str, Any] = {"session_id": str(uuid.uuid4()), "obs": None, "step": 0}
    fresh.update(rewards=[], reward_details=[], log=[], done=False, phase="not_started")
    fresh.update(
        explored_context="",
        topic="",
        tier="",
        keywords="",
        content="",
        data_available=False,
        last_code="",
        last_format="marimo",
        generated_response="",
        parsed_response="",
        top_chunks=[],
    )
    return fresh
def build_reward_matrix(reward_details: list[dict[str, Any]]) -> gr.update:
    """Build a reward matrix with reward names as rows and steps as columns."""
    steps = sorted({entry["step"] for entry in reward_details})
    # Preserve first-seen order of component names across all steps.
    names: list[str] = []
    cells: dict[tuple[str, int], Any] = {}
    for entry in reward_details:
        components = entry.get("components", {}) or {"total": ""}
        for comp_name, comp_value in components.items():
            if comp_name not in names:
                names.append(comp_name)
            cells[(comp_name, entry["step"])] = comp_value
    rows: list[list[Any]] = []
    for comp_name in names:
        row: list[Any] = [comp_name]
        for step in steps:
            cell = cells.get((comp_name, step), "")
            row.append("" if cell == "" else _fmt_component(cell))
        rows.append(row)
    headers = ["Reward"] + [f"Step {step}" for step in steps]
    return gr.update(
        headers=headers,
        value=rows,
        column_count=(len(headers), "fixed"),
    )
def build_reward_summary(reward_details: list[dict[str, Any]]) -> str:
    """One markdown line per step showing its phase and total reward."""
    if not reward_details:
        return "*No rewards yet.*"

    def step_line(entry: dict[str, Any]) -> str:
        total = _first_present(
            entry.get("components", {}),
            ("explore_total", "generate_total", "repair_total"),
            default="n/a",
        )
        return f"**Step {entry['step']} · {entry['phase']} · total {_fmt_component(total)}**"

    return "\n\n".join(step_line(entry) for entry in reward_details)
def build_top_chunks_df(chunks: list[dict[str, Any]]) -> list[list[Any]]:
    """Project the first five chunks into dataframe rows (snippet trimmed to 700 chars)."""
    plain_columns = ("rank", "source", "title", "score", "url")
    return [
        [chunk.get(col, "") for col in plain_columns]
        + [_trim_display_text(str(chunk.get("snippet", "")), 700)]
        for chunk in chunks[:5]
    ]
def extract_top_chunks(obs_dict: dict[str, Any], search_results: str) -> list[dict[str, Any]]:
    """Prefer structured chunks from the observation; fall back to parsing rendered text."""
    meta = obs_dict.get("metadata") or {}
    structured = obs_dict.get("top_chunks") or meta.get("top_chunks")
    if structured:
        return structured
    return parse_rendered_chunks(search_results)
def parse_rendered_chunks(search_results: str) -> list[dict[str, Any]]:
    """Fallback parser for rendered research results if structured fields are absent."""
    parsed: list[dict[str, Any]] = []
    for section in re.split(r"\n\n---\n\n", search_results or ""):
        nonblank = [ln for ln in section.splitlines() if ln.strip()]
        if not nonblank:
            continue
        # Header shape: "[rank] source: title"
        header = re.match(r"\[(\d+)\]\s+([^:]+):\s+(.+)", nonblank[0])
        if header is None:
            continue
        url = ""
        body_at = 1
        if len(nonblank) > 1 and nonblank[1].startswith("URL:"):
            url = nonblank[1].removeprefix("URL:").strip()
            body_at = 2
        parsed.append({
            "rank": int(header.group(1)),
            "source": header.group(2).strip(),
            "title": header.group(3).strip(),
            "url": url,
            "score": "",
            "snippet": "\n".join(nonblank[body_at:]).strip(),
        })
    return parsed[:5]
def _trim_display_text(text: str, max_chars: int) -> str:
    """Collapse all whitespace runs to single spaces and ellipsize beyond max_chars."""
    collapsed = re.sub(r"\s+", " ", text).strip()
    if len(collapsed) <= max_chars:
        return collapsed
    return collapsed[:max_chars].rstrip() + "..."
def _first_present(mapping: dict[str, Any], keys: tuple[str, ...], default: Any = None) -> Any:
    """Return the value for the first key present in *mapping*, else *default*."""
    return next((mapping[key] for key in keys if key in mapping), default)
def _fmt_component(value: Any) -> str:
    """Floats render with three decimals; everything else is stringified as-is."""
    if isinstance(value, float):
        return f"{value:.3f}"
    return str(value)
# Observation-metadata keys that are bookkeeping, not reward components —
# excluded when scraping numeric reward values out of metadata.
_NON_REWARD_METADATA_KEYS = frozenset({
    "step",
    "phase",
    "tool",
    "source_count",
    "error",
    "explore_steps_used",
    "repair_steps_used",
    "sandbox_message",
    "error_codes",
})
# Per-phase whitelist of reward components shown in the reward matrix,
# in display order; the "*_total" entry is always last.
_VISIBLE_REWARD_COMPONENTS = {
    "explore": (
        "query_quality",
        "evidence_quality",
        "information_gain",
        "efficiency",
        "explore_total",
    ),
    "generate": (
        "validity",
        "task_alignment",
        "structure",
        "research_usage",
        "generate_total",
    ),
    "repair": (
        "repair_success",
        "fixed_prior_errors",
        "changed_code",
        "repair_total",
    ),
}
def parse_reward_components(feedback: str) -> dict[str, Any]:
    """Fallback parser for older observations that lack reward metadata."""
    # Preferred shape: "Reward: {'a': 0.5, ...}" — a Python-literal dict.
    dict_match = re.search(r"Reward:\s*(\{.+\})", feedback)
    if dict_match:
        try:
            candidate = ast.literal_eval(dict_match.group(1))
        except (SyntaxError, ValueError):
            candidate = None
        if isinstance(candidate, dict):
            return {k: v for k, v in candidate.items() if k not in ("step", "phase")}
    # Legacy shape: "Reward: a=0.5, b=1" key=value pairs.
    kv_match = re.search(r"Reward:\s*(.+)", feedback)
    if kv_match:
        return _parse_key_value_components(kv_match.group(1))
    return {}
def _parse_key_value_components(text: str) -> dict[str, Any]:
    """Parse 'a=1, b=x' pairs; values become floats when possible, else strings."""
    out: dict[str, Any] = {}
    for token in text.split(","):
        if "=" not in token:
            continue
        raw_key, raw_value = token.strip().split("=", 1)
        key, value = raw_key.strip(), raw_value.strip()
        try:
            out[key] = float(value)
        except ValueError:
            out[key] = value
    return out
def reward_components(obs_dict: dict[str, Any], feedback: str) -> dict[str, Any]:
    """Extract numeric reward components from observation metadata.

    Prefers the per-phase whitelist; falls back to all numeric metadata,
    then to parsing the feedback text for older observations.
    """
    metadata = obs_dict.get("metadata") or {}
    numeric = {
        key: value
        for key, value in metadata.items()
        if key not in _NON_REWARD_METADATA_KEYS
        and isinstance(value, (int, float))
        and not isinstance(value, bool)  # bool is an int subclass; exclude flags
    }
    phase = metadata.get("phase") or obs_dict.get("phase")
    allowed = _VISIBLE_REWARD_COMPONENTS.get(str(phase))
    if allowed:
        visible = {key: numeric[key] for key in allowed if key in numeric}
        if visible:
            return visible
    if numeric:
        return numeric
    return parse_reward_components(feedback)
def to_obs_dict(obs: Any) -> dict[str, Any]:
    """Pydantic models dump to dicts; plain objects fall back to their __dict__."""
    if hasattr(obs, "model_dump"):
        return obs.model_dump()
    return vars(obs)
def fmt_log(log_entries: list[str]) -> str:
    """Render the event log as a fenced text block, or a placeholder when empty."""
    if log_entries:
        return "```text\n" + "\n".join(log_entries) + "\n```"
    return "*No events yet.*"
def obs_summary(obs: dict[str, Any]) -> str:
    """Markdown summary of the key observation fields."""
    rows = (
        ("Topic", obs.get("topic", "")),
        ("Tier", obs.get("tier", "")),
        ("Phase", obs.get("phase", "")),
        ("Explore steps left", obs.get("explore_steps_left", 0)),
        ("Keywords", obs.get("keywords", "")),
        ("Data available", obs.get("data_available", False)),
    )
    return "\n".join(f"**{label}:** {value}" for label, value in rows)
def fenced_json(data: dict[str, Any]) -> str:
    """Pretty-print *data* as a fenced JSON markdown block."""
    body = json.dumps(data, indent=2, ensure_ascii=False)
    return f"```json\n{body}\n```"
def format_explore_action_md(tool: str, query: str, intent: str) -> str:
    """Render an explore action as fenced JSON for the 'Parsed' tab."""
    payload = {"tool": tool, "query": query, "intent": intent}
    return fenced_json(payload)
def format_code_text(code: str) -> str:
    """Pass code through; empty or None renders as an empty string."""
    if code:
        return code
    return ""
def common_outputs(
    state: dict[str, Any],
    status: str = "",
    obs_md: str = "",
    feedback: str = "",
    search: str = "",
) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]:
    """Assemble the 10-tuple bound to the shared Gradio output components.

    NOTE(review): ``status`` is accepted (and passed by every caller) but is
    not part of the returned tuple — the UI currently has no status widget.
    Confirm whether a status component should be added.
    """
    details = state["reward_details"]
    return (
        state,
        fmt_log(state["log"]),
        obs_md,
        feedback,
        state.get("generated_response", ""),
        state.get("parsed_response", ""),
        search,
        build_top_chunks_df(state.get("top_chunks", [])),
        build_reward_summary(details),
        build_reward_matrix(details),
    )
def llm_outputs(
    state: dict[str, Any],
    status: str = "",
    obs_md: str = "",
    feedback: str = "",
    search: str = "",
) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]:
    """Alias of common_outputs kept for LLM-mode call sites."""
    return common_outputs(state, status=status, obs_md=obs_md, feedback=feedback, search=search)
async def do_reset(task_label, state):
    """Reset the environment and start a new episode."""
    # Tear down any previous session before creating a fresh one.
    previous_sid = state.get("session_id", "")
    if previous_sid:
        await SESSION_MGR.close(previous_sid)
    state = empty_state()
    # Pass a topic only when a specific task (not "(random)") was selected.
    reset_kwargs: dict[str, Any] = {}
    chosen_topic = _TASK_LABEL_TO_TOPIC.get(task_label)
    if chosen_topic:
        reset_kwargs["topic"] = chosen_topic
    try:
        env = await SESSION_MGR.get_or_create(state["session_id"], _resolve_env_url())
        result = await env.reset(**reset_kwargs)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection/reset failed: {e}")
        return common_outputs(state, status=f"Error: {e}")
    obs = result.observation
    obs_dict = to_obs_dict(obs)
    state.update({
        "obs": obs_dict,
        "phase": obs.phase,
        "topic": obs.topic,
        "tier": obs.tier,
        "keywords": obs.keywords,
        "content": obs.content,
        "data_available": obs.data_available,
        "generated_response": "",
        "parsed_response": "",
        "last_code": "",
        "top_chunks": [],
    })
    state["log"].append(f"[START] topic={obs.topic} tier={obs.tier} phase={obs.phase}")
    return common_outputs(
        state,
        status=f"Reset OK — assigned: {obs.topic} [{obs.tier}]",
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
    )
async def do_explore(tool, query, intent, state):
    """Execute an explore step.

    Sends an 'explore' action to the environment, records the reward and
    observation into ``state``, and returns the shared 10-tuple of outputs.
    """
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return common_outputs(state, status="Episode already done.", feedback="Episode already done.")
    if not query.strip():
        return common_outputs(state, status="Empty query — nothing sent.")
    sid = state.get("session_id", "")
    env_url = _resolve_env_url()
    try:
        env = await SESSION_MGR.get_or_create(sid, env_url)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection failed: {e}")
        return common_outputs(state, status=f"Error: {e}")
    action = ExplainerAction(
        action_type="explore",
        tool=tool,
        query=query.strip(),
        intent=intent.strip(),
    )
    # Fix: step errors previously propagated to Gradio as unhandled exceptions;
    # surface them in the log/status like connection failures instead.
    try:
        result = await env.step(action)
    except Exception as e:
        state["log"].append(f"[ERROR] Step failed: {e}")
        return common_outputs(state, status=f"Error: {e}")
    obs = result.observation
    reward = result.reward or 0.0
    obs_dict = to_obs_dict(obs)
    state["step"] += 1
    state["rewards"].append(reward)
    state["obs"] = obs_dict
    state["phase"] = obs.phase
    state["done"] = result.done
    state["explored_context"] = obs.explored_context
    state["parsed_response"] = format_explore_action_md(tool, query.strip(), intent.strip())
    state["top_chunks"] = extract_top_chunks(obs_dict, obs.search_results)
    components = reward_components(obs_dict, obs.feedback)
    state["reward_details"].append({
        "step": state["step"],
        "phase": "explore",
        "components": components,
    })
    state["log"].append(
        f'[STEP] step={state["step"]} action=explore:{tool}:"{query[:60]}" reward={reward:.3f} done={result.done}'
    )
    status = f"Step {state['step']} explore — reward: {reward:.3f}"
    return common_outputs(
        state,
        status=status,
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
        search=obs.search_results,
    )
async def do_generate(fmt, code, narration, state):
    """Execute a generate (or repair) step.

    The action type is 'repair' when the episode is already in the repair
    phase; otherwise 'generate'. Records the reward and the episode score.
    """
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return common_outputs(state, status="Episode already done.", feedback="Episode already done.")
    sid = state.get("session_id", "")
    env_url = _resolve_env_url()
    try:
        env = await SESSION_MGR.get_or_create(sid, env_url)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection failed: {e}")
        return common_outputs(state, status=f"Error: {e}")
    action_type = "repair" if state.get("phase") == "repair" else "generate"
    action = ExplainerAction(
        action_type=action_type,
        format=fmt,
        code=code,
        narration=narration,
    )
    # Fix: step errors previously propagated to Gradio as unhandled exceptions;
    # surface them in the log/status like connection failures instead.
    try:
        result = await env.step(action)
    except Exception as e:
        state["log"].append(f"[ERROR] Step failed: {e}")
        return common_outputs(state, status=f"Error: {e}")
    obs = result.observation
    reward = result.reward or 0.0
    obs_dict = to_obs_dict(obs)
    state["step"] += 1
    state["rewards"].append(reward)
    state["obs"] = obs_dict
    state["phase"] = obs.phase
    state["done"] = result.done
    state["last_code"] = code
    state["last_format"] = fmt
    state["generated_response"] = format_code_text(code)
    state["parsed_response"] = fenced_json({
        "action_type": action_type,
        "format": fmt,
        "code_len": len(code),
        "narration_len": len(narration or ""),
    })
    components = reward_components(obs_dict, obs.feedback)
    state["reward_details"].append({
        "step": state["step"],
        "phase": action_type,
        "components": components,
    })
    total_score = normalized_episode_score(sum(state["rewards"]))
    state["log"].append(
        f"[STEP] step={state['step']} action={action_type}:{fmt} reward={reward:.3f} done={result.done}"
    )
    # NOTE(review): the [END] line and "Episode done" status are emitted even
    # when result.done is False (e.g. a generate that transitions to repair) —
    # confirm whether these should be gated on result.done.
    state["log"].append(
        f"[END] success={total_score >= SUCCESS_SCORE_THRESHOLD} steps={state['step']} "
        f"score={total_score:.3f} rewards={','.join(f'{r:.2f}' for r in state['rewards'])}"
    )
    status = f"Episode done — score: {total_score:.3f} (generate reward: {reward:.3f})"
    return common_outputs(
        state,
        status=status,
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
    )
def _llm_error_outputs(state: dict[str, Any], message: str):
    """Log an LLM failure, surface it in the 'Parsed' tab, and build the output tuple."""
    state["log"].append(f"[ERROR] {message}")
    state["parsed_response"] = f"**LLM error:** {message}"
    current = state.get("obs") or {}
    return llm_outputs(
        state,
        obs_md=obs_summary(current) if current else "",
        feedback=current.get("feedback", ""),
    )
async def do_llm_step(state):
    """Let the LLM take the next step (explore or generate)."""
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return llm_outputs(state, feedback="Episode already done.")
    from openai import OpenAI  # local import keeps the dashboard importable without openai
    api_url, api_key, model = _resolve_llm()
    if not api_url:
        return _llm_error_outputs(state, "API_URL is not configured.")
    if not api_key:
        return _llm_error_outputs(state, "HF_TOKEN or API_KEY is not configured.")
    if not model:
        return _llm_error_outputs(state, "MODEL_NAME is not configured.")
    client = OpenAI(base_url=api_url, api_key=api_key, timeout=60.0)
    obs_data = state.get("obs", {})
    phase = state.get("phase", "explore")
    if phase == "explore":
        prompt = build_explore_prompt(
            topic=state["topic"],
            content=state["content"],
            tier=state["tier"],
            keywords=state["keywords"],
            step=state["step"] + 1,
            steps_left=obs_data.get("explore_steps_left", 0),
            explored_context=state.get("explored_context", ""),
            feedback=obs_data.get("feedback", ""),
        )
        try:
            llm_response = call_llm_or_raise(client, prompt, model=model, max_tokens=256)
        except Exception as exc:
            return _llm_error_outputs(state, _format_llm_exception(exc, api_url, model))
        if not llm_response:
            return _llm_error_outputs(
                state,
                f"LLM call failed or returned an empty response from {api_url} using model {model}.",
            )
        if llm_response.strip().upper() == "SKIP":
            # The model declined further exploration; advance the local phase only.
            state["log"].append("[LLM] Decided to skip exploration. Moving to generate.")
            state["phase"] = "generate"
            state["generated_response"] = llm_response
            state["parsed_response"] = "`SKIP`"
            return llm_outputs(
                state,
                obs_md=obs_summary(obs_data),
                feedback=obs_data.get("feedback", ""),
            )
        tool, query, intent = parse_explore_response(llm_response, state["topic"])
        state["generated_response"] = llm_response
        state["parsed_response"] = format_explore_action_md(tool, query, intent)
        state["log"].append(f'[LLM] Explore tool={tool} query="{query[:80]}"')
        # do_explore already produces the full output tuple; return it directly.
        return await do_explore(tool, query, intent, state)
    if phase in ("generate", "repair", "done"):
        if phase == "repair":
            prompt = build_repair_prompt(
                topic=state["topic"],
                tier=state["tier"],
                fmt=state.get("last_format", "marimo"),
                previous_code=state.get("last_code", ""),
                last_errors=obs_data.get("last_errors", ""),
            )
        else:
            prompt = build_generate_prompt(
                topic=state["topic"],
                content=state["content"],
                tier=state["tier"],
                keywords=state["keywords"],
                data_available=state.get("data_available", False),
                explored_context=state.get("explored_context", ""),
            )
        try:
            llm_response = call_llm_or_raise(client, prompt, model=model, max_tokens=4096)
        except Exception as exc:
            return _llm_error_outputs(state, _format_llm_exception(exc, api_url, model))
        if not llm_response:
            return _llm_error_outputs(
                state,
                f"LLM call failed or returned an empty response from {api_url} using model {model}.",
            )
        fmt, code, narration = parse_generate_response(llm_response)
        state["generated_response"] = format_code_text(code)
        state["parsed_response"] = fenced_json({
            "format": fmt,
            "code_len": len(code),
            "narration_len": len(narration),
        })
        state["log"].append(f"[LLM] Generate: format={fmt}, code_len={len(code)}")
        # do_generate already produces the full output tuple; return it directly.
        return await do_generate(fmt, code, narration, state)
    # Unknown phase: return the current state unchanged.
    return llm_outputs(state)
async def do_llm_auto(state):
    """Run full episode automatically with LLM (explore + generate).

    Loops do_llm_step until the episode reports done, an [ERROR] entry is
    logged, or a safety cap is reached. The cap guards against an infinite
    loop if the environment never sets done and no error is ever logged
    (the previous unbounded ``while`` could spin forever in that case).
    """
    max_iterations = 32  # well above any explore budget + generate/repair steps
    outputs = None
    for _ in range(max_iterations):
        if state.get("done"):
            break
        outputs = await do_llm_step(state)
        state = outputs[0]
        if state.get("log") and str(state["log"][-1]).startswith("[ERROR]"):
            break
    return outputs if outputs else llm_outputs(state, status="No steps taken.")
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_ui():
    """Construct the Gradio Blocks app and wire episode controls to the handlers."""
    with gr.Blocks(title="Explainer Env — Interactive Runner") as demo:
        session_state = gr.State(empty_state())
        # Header
        gr.Markdown("# Explainer Episode Inspector")
        # =====================================================================
        # Controls
        # =====================================================================
        with gr.Row(equal_height=True):
            task_dd = gr.Dropdown(
                choices=TASK_CHOICES,
                value="(random)",
                label="Task",
                scale=1,
            )
        with gr.Row(equal_height=True):
            reset_btn = gr.Button("Reset Episode", variant="primary")
            llm_step_btn = gr.Button("Next Step", variant="secondary")
            llm_auto_btn = gr.Button("Auto Run", variant="primary")
        # =====================================================================
        # Inspector panels
        # =====================================================================
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Observation")
                    obs_md = gr.Markdown("*Click Reset Episode to begin.*")
                    feedback_box = gr.Textbox(
                        label="Latest feedback",
                        lines=8,
                        max_lines=8,
                        interactive=False,
                    )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### LLM")
                    with gr.Tabs():
                        with gr.Tab("Parsed"):
                            parsed_response_box = gr.Markdown("*No parsed response yet.*")
                        with gr.Tab("Response / code"):
                            generated_response_box = gr.Textbox(
                                label="Raw response or generated code",
                                value="No response yet.",
                                lines=16,
                                max_lines=16,
                                interactive=False,
                                buttons=["copy"],
                            )
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Research")
                    search_box = gr.Textbox(
                        label="Latest search results",
                        lines=8,
                        max_lines=8,
                        interactive=False,
                    )
                    top_chunks_table = gr.Dataframe(
                        headers=["Rank", "Source", "Title", "Score", "URL", "Snippet"],
                        interactive=False,
                        column_count=(6, "fixed"),
                        label="Top chunks",
                    )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Rewards")
                    reward_summary = gr.Markdown("*No rewards yet.*")
                    rewards_table = gr.Dataframe(
                        headers=["Reward"],
                        interactive=False,
                        column_count=(1, "fixed"),
                        label="Reward matrix",
                    )
        # =====================================================================
        # Timeline
        # =====================================================================
        with gr.Group():
            gr.Markdown("### Timeline")
            log_box = gr.Markdown("*No events yet.*")
        # =====================================================================
        # Wiring
        # =====================================================================
        # Common outputs: state, log, obs, feedback, search, rewards
        # Order must match the 10-tuple returned by common_outputs / llm_outputs.
        common_output_components = [
            session_state,
            log_box,
            obs_md,
            feedback_box,
            generated_response_box,
            parsed_response_box,
            search_box,
            top_chunks_table,
            reward_summary,
            rewards_table,
        ]
        reset_btn.click(
            fn=do_reset,
            inputs=[task_dd, session_state],
            outputs=common_output_components,
        )
        llm_step_btn.click(
            fn=do_llm_step,
            inputs=[session_state],
            outputs=common_output_components,
        )
        llm_auto_btn.click(
            fn=do_llm_auto,
            inputs=[session_state],
            outputs=common_output_components,
        )
    return demo
if __name__ == "__main__":
    # Bind to all interfaces so the hosting container/Space can route traffic.
    demo = build_ui()
    demo.launch(server_name="0.0.0.0", server_port=7860)