Spaces:
Running
Running
| """ | |
| Gradio UI for the Research → Interactive Explainer Environment. | |
| Two modes: | |
| 1. LLM Mode: LLM drives exploration + generation, human watches step-by-step | |
| 2. Human Mode: human types queries and code, sees rewards in real-time | |
| Environment service is the same OpenEnv server that hosts this UI. | |
| LLM configuration is resolved from API_URL, HF_TOKEN/API_KEY, and MODEL_NAME. | |
| """ | |
| import ast | |
| import json | |
| import os | |
| import re | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from dotenv import load_dotenv | |
| # Load .env from project root | |
| PROJECT_ROOT = Path(__file__).parent | |
| load_dotenv(PROJECT_ROOT / ".env") | |
| try: | |
| from .client import ExplainerEnv | |
| from .constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score | |
| from .dashboard_prompts import ( | |
| SYSTEM_PROMPT, | |
| build_explore_prompt, | |
| build_generate_prompt, | |
| build_repair_prompt, | |
| parse_explore_response, | |
| parse_generate_response, | |
| ) | |
| from .models import ExplainerAction | |
| from .task_bank import ALL_TASKS | |
| except ImportError: # pragma: no cover - supports direct execution from env root | |
| from client import ExplainerEnv | |
| from constants import SUCCESS_SCORE_THRESHOLD, normalized_episode_score | |
| from dashboard_prompts import ( | |
| SYSTEM_PROMPT, | |
| build_explore_prompt, | |
| build_generate_prompt, | |
| build_repair_prompt, | |
| parse_explore_response, | |
| parse_generate_response, | |
| ) | |
| from models import ExplainerAction | |
| from task_bank import ALL_TASKS | |
# Base URL of the environment server: the same server that hosts this UI,
# listening on $PORT (8000 when unset).
SELF_ENV_BASE_URL = "http://127.0.0.1:" + os.getenv("PORT", "8000")
# Model used when MODEL_NAME is not set in the environment.
DEFAULT_MODEL_NAME = "bedrock-qwen3-coder-30b-a3b"
# ---------------------------------------------------------------------------
# Task catalog (reference only)
# ---------------------------------------------------------------------------
def _task_label(task: Any) -> str:
    """Render the dropdown label for one task-bank entry."""
    return f"{task.topic} [{task.difficulty}, {task.tier}]"


# Dropdown entries: a "(random)" sentinel followed by one label per task.
TASK_CHOICES = ["(random)", *map(_task_label, ALL_TASKS)]
# Map dropdown label -> topic name for reset(topic=...)
_TASK_LABEL_TO_TOPIC: dict[str, str] = {_task_label(t): t.topic for t in ALL_TASKS}
# ---------------------------------------------------------------------------
# Session manager
# ---------------------------------------------------------------------------
class SessionManager:
    """Module-level registry mapping session_id -> connected ExplainerEnv client.

    Each session id owns at most one client. Requesting a session with a base
    URL different from the one it was created with closes the old client and
    connects a new one.
    """

    def __init__(self):
        # session_id -> connected client
        self._clients: dict[str, ExplainerEnv] = {}
        # session_id -> base URL that client was created with
        self._urls: dict[str, str] = {}

    async def get_or_create(self, session_id: str, base_url: str) -> ExplainerEnv:
        """Return the connected client for ``session_id``, connecting on first use."""
        url_changed = session_id in self._clients and self._urls.get(session_id) != base_url
        if url_changed:
            await self.close(session_id)
        client = self._clients.get(session_id)
        if client is None:
            client = ExplainerEnv(base_url=base_url.rstrip("/"))
            await client.connect()
            self._clients[session_id] = client
            self._urls[session_id] = base_url
        return client

    async def close(self, session_id: str) -> None:
        """Forget ``session_id`` and disconnect its client, if any (best effort)."""
        self._urls.pop(session_id, None)
        client = self._clients.pop(session_id, None)
        if not client:
            return
        try:
            await client.disconnect()
        except Exception:
            # Disconnecting is best-effort; a dead connection must not block cleanup.
            pass


# Single registry shared by every dashboard session in this process.
SESSION_MGR = SessionManager()
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _resolve_env_url() -> str:
    """Return the base URL of the environment server the dashboard talks to.

    This is always ``SELF_ENV_BASE_URL``, the locally co-hosted server.
    """
    env_url = SELF_ENV_BASE_URL
    return env_url
| def _resolve_llm() -> tuple[str, str, str]: | |
| api_url = (os.getenv("API_URL") or os.getenv("API_BASE_URL") or "").rstrip("/") | |
| api_key = os.getenv("HF_TOKEN") or os.getenv("API_KEY") | |
| model = os.getenv("MODEL_NAME") or DEFAULT_MODEL_NAME | |
| return api_url, api_key, model | |
def call_llm_or_raise(client: Any, user_prompt: str, *, model: str, max_tokens: int) -> str:
    """Send one non-streaming chat completion and return the stripped reply.

    The conversation is ``SYSTEM_PROMPT`` followed by ``user_prompt``. Provider
    exceptions are deliberately not caught, so the dashboard can show them.
    A ``None`` message content is returned as ``""``.
    """
    conversation = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=0.7,
        max_tokens=max_tokens,
        stream=False,
    )
    reply = completion.choices[0].message.content
    return (reply or "").strip()
| def _format_llm_exception(exc: Exception, api_url: str, model: str) -> str: | |
| cause = getattr(exc, "__cause__", None) | |
| detail = str(cause or exc).strip() or exc.__class__.__name__ | |
| return f"{exc.__class__.__name__} from {api_url} using model {model}: {detail}" | |
def empty_state() -> dict[str, Any]:
    """Return a fresh dashboard state with a new random session id.

    Every call builds new list objects, so two states never share mutable data.
    """
    return dict(
        session_id=str(uuid.uuid4()),
        obs=None,
        step=0,
        rewards=[],
        reward_details=[],
        log=[],
        done=False,
        phase="not_started",
        explored_context="",
        topic="",
        tier="",
        keywords="",
        content="",
        data_available=False,
        last_code="",
        last_format="marimo",
        generated_response="",
        parsed_response="",
        top_chunks=[],
    )
def build_reward_matrix(reward_details: list[dict[str, Any]]) -> gr.update:
    """Lay rewards out as a table with one row per reward name and one column per step.

    Rows follow the order in which each reward name first appears. Columns are
    the distinct step numbers in ascending order. A step without components
    contributes a placeholder ``total`` row with an empty cell, and missing
    cells render as "". Returns a ``gr.update`` that also pins the Dataframe's
    column count.
    """
    step_numbers = sorted({detail["step"] for detail in reward_details})
    grid: dict[str, dict[int, Any]] = {}
    for detail in reward_details:
        parts = detail.get("components", {}) or {"total": ""}
        for reward_name, reward_value in parts.items():
            grid.setdefault(reward_name, {})[detail["step"]] = reward_value

    headers = ["Reward", *(f"Step {n}" for n in step_numbers)]
    rows = []
    for reward_name, by_step in grid.items():
        cells = [by_step.get(n, "") for n in step_numbers]
        rows.append([reward_name] + [_fmt_component(c) if c != "" else "" for c in cells])

    return gr.update(
        headers=headers,
        value=rows,
        column_count=(len(headers), "fixed"),
    )
def build_reward_summary(reward_details: list[dict[str, Any]]) -> str:
    """Render one bold Markdown line per step showing its number, phase and total.

    The total is the first of ``explore_total`` / ``generate_total`` /
    ``repair_total`` found in the step's components, or "n/a" if none is.
    """
    if not reward_details:
        return "*No rewards yet.*"
    total_keys = ("explore_total", "generate_total", "repair_total")
    lines = []
    for detail in reward_details:
        total = _first_present(detail.get("components", {}), total_keys, default="n/a")
        lines.append(f"**Step {detail['step']} · {detail['phase']} · total {_fmt_component(total)}**")
    return "\n\n".join(lines)
def build_top_chunks_df(chunks: list[dict[str, Any]]) -> list[list[Any]]:
    """Turn up to five chunk dicts into rows for the "Top chunks" Dataframe.

    Columns are rank, source, title, score, url and snippet. Missing fields
    become "". The snippet is whitespace-collapsed and trimmed to 700 characters.
    """
    columns = ("rank", "source", "title", "score", "url")
    return [
        [*(chunk.get(col, "") for col in columns), _trim_display_text(str(chunk.get("snippet", "")), 700)]
        for chunk in chunks[:5]
    ]
def extract_top_chunks(obs_dict: dict[str, Any], search_results: str) -> list[dict[str, Any]]:
    """Return the retrieval chunks attached to an observation.

    The lookup order is:

    1. A top-level ``top_chunks`` field.
    2. ``metadata["top_chunks"]``.
    3. If both are missing or empty, chunks parsed from the rendered
       ``search_results`` text.
    """
    structured = obs_dict.get("top_chunks")
    if not structured:
        structured = (obs_dict.get("metadata") or {}).get("top_chunks")
    if structured:
        return structured
    return parse_rendered_chunks(search_results)
def parse_rendered_chunks(search_results: str) -> list[dict[str, Any]]:
    """Recover up to five chunk dicts from rendered research results.

    This is a fallback used when the observation carries no structured chunks.

    - Sections are separated by a blank-line / ``---`` / blank-line divider.
    - Each section must start with a ``[rank] source: title`` header. Sections
      without a valid header are skipped.
    - An optional ``URL:`` line may follow the header.
    - The remaining non-blank lines form the snippet.
    - ``score`` is always "".
    """
    parsed: list[dict[str, Any]] = []
    for section in re.split(r"\n\n---\n\n", search_results or ""):
        content_lines = list(filter(str.strip, section.splitlines()))
        if not content_lines:
            continue
        header = re.match(r"\[(\d+)\]\s+([^:]+):\s+(.+)", content_lines[0])
        if header is None:
            continue
        rank, source, title = header.groups()
        rest = content_lines[1:]
        url = ""
        if rest and rest[0].startswith("URL:"):
            url = rest[0].removeprefix("URL:").strip()
            rest = rest[1:]
        parsed.append({
            "rank": int(rank),
            "source": source.strip(),
            "title": title.strip(),
            "url": url,
            "score": "",
            "snippet": "\n".join(rest).strip(),
        })
    return parsed[:5]
| def _trim_display_text(text: str, max_chars: int) -> str: | |
| text = re.sub(r"\s+", " ", text).strip() | |
| return text if len(text) <= max_chars else text[:max_chars].rstrip() + "..." | |
| def _first_present(mapping: dict[str, Any], keys: tuple[str, ...], default: Any = None) -> Any: | |
| for key in keys: | |
| if key in mapping: | |
| return mapping[key] | |
| return default | |
| def _fmt_component(value: Any) -> str: | |
| return f"{value:.3f}" if isinstance(value, float) else str(value) | |
| _NON_REWARD_METADATA_KEYS = frozenset({ | |
| "step", | |
| "phase", | |
| "tool", | |
| "source_count", | |
| "error", | |
| "explore_steps_used", | |
| "repair_steps_used", | |
| "sandbox_message", | |
| "error_codes", | |
| }) | |
| _VISIBLE_REWARD_COMPONENTS = { | |
| "explore": ( | |
| "query_quality", | |
| "evidence_quality", | |
| "information_gain", | |
| "efficiency", | |
| "explore_total", | |
| ), | |
| "generate": ( | |
| "validity", | |
| "task_alignment", | |
| "structure", | |
| "research_usage", | |
| "generate_total", | |
| ), | |
| "repair": ( | |
| "repair_success", | |
| "fixed_prior_errors", | |
| "changed_code", | |
| "repair_total", | |
| ), | |
| } | |
def parse_reward_components(feedback: str) -> dict[str, Any]:
    """Fallback parser for older observations that lack reward metadata.

    The parser looks for a ``Reward:`` marker in ``feedback``:

    - If a dict literal follows the marker, it is evaluated with
      ``ast.literal_eval``, which accepts literals only and never runs code.
      The result is returned without its ``step``/``phase`` keys.
    - Otherwise, the text after the marker is parsed as comma-separated
      ``key=value`` pairs.

    Returns ``{}`` when there is no marker or ``feedback`` is empty/None.
    """
    text = feedback or ""
    dict_match = re.search(r"Reward:\s*(\{.+\})", text)
    if dict_match:
        try:
            parsed = ast.literal_eval(dict_match.group(1))
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # literal_eval documents all of these for malformed input (TypeError
            # e.g. for an unhashable dict key). Fall through to key=value parsing
            # instead of crashing the reward display.
            parsed = None
        if isinstance(parsed, dict):
            return {k: v for k, v in parsed.items() if k not in ("step", "phase")}
    kv_match = re.search(r"Reward:\s*(.+)", text)
    if kv_match:
        return _parse_key_value_components(kv_match.group(1))
    return {}
| def _parse_key_value_components(text: str) -> dict[str, Any]: | |
| components: dict[str, Any] = {} | |
| for part in text.split(","): | |
| if "=" not in part: | |
| continue | |
| key, value = part.strip().split("=", 1) | |
| try: | |
| components[key.strip()] = float(value.strip()) | |
| except ValueError: | |
| components[key.strip()] = value.strip() | |
| return components | |
def reward_components(obs_dict: dict[str, Any], feedback: str) -> dict[str, Any]:
    """Extract the reward components to display for one step.

    Candidates are the numeric (non-bool) metadata values whose keys are not
    in ``_NON_REWARD_METADATA_KEYS``. The result is chosen as follows:

    - If the step's phase has a whitelist in ``_VISIBLE_REWARD_COMPONENTS``
      and at least one whitelisted key is present, only those keys are
      returned, in whitelist order.
    - Otherwise all candidates are returned.
    - If there are no candidates, ``feedback`` is parsed as a fallback.
    """
    metadata = obs_dict.get("metadata") or {}
    candidates = {}
    for key, value in metadata.items():
        if key in _NON_REWARD_METADATA_KEYS:
            continue
        if isinstance(value, bool) or not isinstance(value, (int, float)):
            continue
        candidates[key] = value

    step_phase = metadata.get("phase") or obs_dict.get("phase")
    whitelist = _VISIBLE_REWARD_COMPONENTS.get(str(step_phase))
    if whitelist:
        shown = {key: candidates[key] for key in whitelist if key in candidates}
        if shown:
            return shown
    if candidates:
        return candidates
    return parse_reward_components(feedback)
def to_obs_dict(obs: Any) -> dict[str, Any]:
    """Convert an observation object to a plain dict.

    Objects with a ``model_dump()`` method (pydantic-style) are dumped with it.
    Anything else falls back to ``vars(obs)``, which is its live ``__dict__``,
    not a copy.
    """
    if hasattr(obs, "model_dump"):
        return obs.model_dump()
    return vars(obs)
def fmt_log(log_entries: list[str]) -> str:
    """Render log lines as a fenced ``text`` block, or a placeholder when empty."""
    if log_entries:
        body = "\n".join(log_entries)
        return f"```text\n{body}\n```"
    return "*No events yet.*"
def obs_summary(obs: dict[str, Any]) -> str:
    """Render the key observation fields as bold-labelled Markdown lines.

    Missing fields show as empty, except explore steps (0) and data
    availability (False).
    """
    rows = [
        f"**Topic:** {obs.get('topic', '')}",
        f"**Tier:** {obs.get('tier', '')}",
        f"**Phase:** {obs.get('phase', '')}",
        f"**Explore steps left:** {obs.get('explore_steps_left', 0)}",
        f"**Keywords:** {obs.get('keywords', '')}",
        f"**Data available:** {obs.get('data_available', False)}",
    ]
    return "\n".join(rows)
def fenced_json(data: dict[str, Any]) -> str:
    """Pretty-print ``data`` as JSON inside a ```json Markdown fence.

    Uses a 2-space indent and keeps non-ASCII characters as-is.
    """
    pretty = json.dumps(data, indent=2, ensure_ascii=False)
    return f"```json\n{pretty}\n```"
def format_explore_action_md(tool: str, query: str, intent: str) -> str:
    """Show an explore action (tool, query, intent) as a fenced JSON block."""
    action = dict(tool=tool, query=query, intent=intent)
    return fenced_json(action)
def format_code_text(code: str) -> str:
    """Return ``code`` for display, mapping ``None`` or empty input to ``""``."""
    return code if code else ""
def common_outputs(
    state: dict[str, Any],
    status: str = "",
    obs_md: str = "",
    feedback: str = "",
    search: str = "",
) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]:
    """Assemble the ten values bound to the dashboard's output components.

    The order is:

    1. state
    2. rendered log
    3. observation Markdown
    4. feedback
    5. generated response
    6. parsed response
    7. search results
    8. top-chunk rows
    9. reward summary
    10. reward-matrix update

    ``status`` is accepted for call-site compatibility but is not part of the
    returned tuple.
    """
    log_md = fmt_log(state["log"])
    generated = state.get("generated_response", "")
    parsed = state.get("parsed_response", "")
    chunk_rows = build_top_chunks_df(state.get("top_chunks", []))
    summary_md = build_reward_summary(state["reward_details"])
    matrix_update = build_reward_matrix(state["reward_details"])
    return (
        state,
        log_md,
        obs_md,
        feedback,
        generated,
        parsed,
        search,
        chunk_rows,
        summary_md,
        matrix_update,
    )
def llm_outputs(
    state: dict[str, Any],
    status: str = "",
    obs_md: str = "",
    feedback: str = "",
    search: str = "",
) -> tuple[dict[str, Any], str, str, str, str, str, str, list[list[Any]], str, Any]:
    """Output tuple for LLM-driven handlers. This is an alias of ``common_outputs``."""
    return common_outputs(state, status, obs_md, feedback, search)
async def do_reset(task_label, state):
    """Start a new episode on a fresh session, replacing any previous one.

    The previous session's client (if any) is closed and a brand-new state is
    created. When ``task_label`` matches a dropdown entry, its topic is passed
    to ``reset``; "(random)" or an unknown label lets the server choose.
    Connection and reset failures are logged as ``[ERROR]`` lines instead of
    being raised.
    """
    previous_sid = state.get("session_id", "")
    if previous_sid:
        await SESSION_MGR.close(previous_sid)

    state = empty_state()
    env_url = _resolve_env_url()
    topic = _TASK_LABEL_TO_TOPIC.get(task_label)
    # Build reset kwargs: only pass a topic when a specific task was selected.
    reset_kwargs: dict[str, Any] = {"topic": topic} if topic else {}
    try:
        env = await SESSION_MGR.get_or_create(state["session_id"], env_url)
        result = await env.reset(**reset_kwargs)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection/reset failed: {e}")
        return common_outputs(state, status=f"Error: {e}")

    obs = result.observation
    obs_dict = to_obs_dict(obs)
    state.update(
        obs=obs_dict,
        phase=obs.phase,
        topic=obs.topic,
        tier=obs.tier,
        keywords=obs.keywords,
        content=obs.content,
        data_available=obs.data_available,
        generated_response="",
        parsed_response="",
        last_code="",
        top_chunks=[],
    )
    state["log"].append(f"[START] topic={obs.topic} tier={obs.tier} phase={obs.phase}")
    return common_outputs(
        state,
        status=f"Reset OK — assigned: {obs.topic} [{obs.tier}]",
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
    )
async def do_explore(tool, query, intent, state):
    """Execute one explore step against the environment and record the result.

    Args:
        tool: Research tool name forwarded in the ``ExplainerAction``.
        query: Search query. A blank query is rejected without contacting the
            environment.
        intent: Free-text intent forwarded with the action.
        state: Per-session dashboard state, mutated in place.

    Returns:
        The 10-tuple from ``common_outputs``. Connection, action-construction
        and step failures are logged as ``[ERROR]`` lines rather than raised,
        matching ``do_reset``; ``do_llm_auto`` stops on such a line.
    """
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return common_outputs(state, status="Episode already done.", feedback="Episode already done.")
    if not query.strip():
        return common_outputs(state, status="Empty query — nothing sent.")

    sid = state.get("session_id", "")
    env_url = _resolve_env_url()
    try:
        env = await SESSION_MGR.get_or_create(sid, env_url)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection failed: {e}")
        return common_outputs(state, status=f"Error: {e}")

    try:
        action = ExplainerAction(
            action_type="explore",
            tool=tool,
            query=query.strip(),
            intent=intent.strip(),
        )
        result = await env.step(action)
    except Exception as e:
        # Previously unguarded: a dropped connection or an invalid action
        # crashed the Gradio handler instead of showing up in the timeline.
        state["log"].append(f"[ERROR] Explore step failed: {e}")
        return common_outputs(state, status=f"Error: {e}")

    obs = result.observation
    reward = result.reward or 0.0
    obs_dict = to_obs_dict(obs)

    state["step"] += 1
    state["rewards"].append(reward)
    state["obs"] = obs_dict
    state["phase"] = obs.phase
    state["done"] = result.done
    state["explored_context"] = obs.explored_context
    state["parsed_response"] = format_explore_action_md(tool, query.strip(), intent.strip())
    state["top_chunks"] = extract_top_chunks(obs_dict, obs.search_results)
    state["reward_details"].append({
        "step": state["step"],
        "phase": "explore",
        "components": reward_components(obs_dict, obs.feedback),
    })
    state["log"].append(
        f'[STEP] step={state["step"]} action=explore:{tool}:"{query[:60]}" reward={reward:.3f} done={result.done}'
    )
    return common_outputs(
        state,
        status=f"Step {state['step']} explore — reward: {reward:.3f}",
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
        search=obs.search_results,
    )
async def do_generate(fmt, code, narration, state):
    """Submit generated or repaired code to the environment and record the result.

    The action type is ``repair`` when the session is in the repair phase and
    ``generate`` otherwise. The ``[END]`` summary line (normalized episode
    score and success flag) is logged only when the environment reports the
    episode as done. A generate step that sends the episode into repair
    therefore no longer logs a premature ``[END]``.

    Args:
        fmt: Output format forwarded in the action (e.g. "marimo").
        code: Generated source code.
        narration: Accompanying narration text (may be empty or None).
        state: Per-session dashboard state, mutated in place.

    Returns:
        The 10-tuple from ``common_outputs``. Connection, action-construction
        and step failures are logged as ``[ERROR]`` lines rather than raised.
    """
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return common_outputs(state, status="Episode already done.", feedback="Episode already done.")

    sid = state.get("session_id", "")
    env_url = _resolve_env_url()
    try:
        env = await SESSION_MGR.get_or_create(sid, env_url)
    except Exception as e:
        state["log"].append(f"[ERROR] Connection failed: {e}")
        return common_outputs(state, status=f"Error: {e}")

    action_type = "repair" if state.get("phase") == "repair" else "generate"
    try:
        action = ExplainerAction(
            action_type=action_type,
            format=fmt,
            code=code,
            narration=narration,
        )
        result = await env.step(action)
    except Exception as e:
        # Report in the timeline (and stop Auto Run) instead of crashing the handler.
        state["log"].append(f"[ERROR] {action_type.capitalize()} step failed: {e}")
        return common_outputs(state, status=f"Error: {e}")

    obs = result.observation
    reward = result.reward or 0.0
    obs_dict = to_obs_dict(obs)

    state["step"] += 1
    state["rewards"].append(reward)
    state["obs"] = obs_dict
    state["phase"] = obs.phase
    state["done"] = result.done
    state["last_code"] = code
    state["last_format"] = fmt
    state["generated_response"] = format_code_text(code)
    state["parsed_response"] = fenced_json({
        "action_type": action_type,
        "format": fmt,
        "code_len": len(code or ""),
        "narration_len": len(narration or ""),
    })
    state["reward_details"].append({
        "step": state["step"],
        "phase": action_type,
        "components": reward_components(obs_dict, obs.feedback),
    })
    state["log"].append(
        f"[STEP] step={state['step']} action={action_type}:{fmt} reward={reward:.3f} done={result.done}"
    )

    if result.done:
        total_score = normalized_episode_score(sum(state["rewards"]))
        state["log"].append(
            f"[END] success={total_score >= SUCCESS_SCORE_THRESHOLD} steps={state['step']} "
            f"score={total_score:.3f} rewards={','.join(f'{r:.2f}' for r in state['rewards'])}"
        )
        status = f"Episode done — score: {total_score:.3f} (generate reward: {reward:.3f})"
    else:
        # Not finished yet (e.g. the environment moved to the repair phase).
        status = f"Step {state['step']} {action_type} — reward: {reward:.3f}"

    return common_outputs(
        state,
        status=status,
        obs_md=obs_summary(obs_dict),
        feedback=obs.feedback,
    )
def _llm_error_outputs(state: dict[str, Any], message: str):
    """Record an LLM-side failure and return the dashboard outputs.

    Appends ``[ERROR] <message>`` to the log (``do_llm_auto`` stops on it),
    shows the message in the parsed-response panel and re-renders the last
    observation, if there is one.
    """
    state["log"].append(f"[ERROR] {message}")
    state["parsed_response"] = f"**LLM error:** {message}"
    last_obs = state.get("obs") or {}
    return llm_outputs(
        state,
        obs_md=obs_summary(last_obs) if last_obs else "",
        feedback=last_obs.get("feedback", ""),
    )
async def do_llm_step(state):
    """Let the LLM take the next step of the episode (explore, generate or repair).

    Behaviour depends on the phase:

    - ``explore``: the LLM proposes a research query, which is sent via
      ``do_explore``. If it answers ``SKIP``, the dashboard moves straight to
      generation.
    - ``generate`` / ``repair`` / ``done``: the LLM writes code, which is sent
      via ``do_generate``.

    The following are reported through ``_llm_error_outputs`` (an ``[ERROR]``
    log line) instead of being raised: missing configuration, provider
    errors, empty replies, and calls made with no active episode.

    Returns:
        The 10-tuple of dashboard outputs (see ``common_outputs``).
    """
    if state.get("done"):
        state["log"].append("[WARN] Episode already done.")
        return llm_outputs(
            state,
            feedback="Episode already done.",
        )

    phase = state.get("phase", "explore")
    if phase not in ("explore", "generate", "repair", "done"):
        # E.g. "not_started" (no Reset yet). This used to be a silent no-op,
        # which made Auto Run loop forever.
        return _llm_error_outputs(state, f"No active episode (phase={phase!r}); click Reset Episode first.")

    import asyncio

    from openai import OpenAI

    api_url, api_key, model = _resolve_llm()
    if not api_url:
        return _llm_error_outputs(state, "API_URL is not configured.")
    if not api_key:
        return _llm_error_outputs(state, "HF_TOKEN or API_KEY is not configured.")
    if not model:
        return _llm_error_outputs(state, "MODEL_NAME is not configured.")
    client = OpenAI(base_url=api_url, api_key=api_key, timeout=60.0)

    # state["obs"] is None before the first reset; `get("obs", {})` would return None.
    obs_data = state.get("obs") or {}

    if phase == "explore":
        prompt = build_explore_prompt(
            topic=state["topic"],
            content=state["content"],
            tier=state["tier"],
            keywords=state["keywords"],
            step=state["step"] + 1,
            steps_left=obs_data.get("explore_steps_left", 0),
            explored_context=state.get("explored_context", ""),
            feedback=obs_data.get("feedback", ""),
        )
        max_tokens = 256
    elif phase == "repair":
        prompt = build_repair_prompt(
            topic=state["topic"],
            tier=state["tier"],
            fmt=state.get("last_format", "marimo"),
            previous_code=state.get("last_code", ""),
            last_errors=obs_data.get("last_errors", ""),
        )
        max_tokens = 4096
    else:
        prompt = build_generate_prompt(
            topic=state["topic"],
            content=state["content"],
            tier=state["tier"],
            keywords=state["keywords"],
            data_available=state.get("data_available", False),
            explored_context=state.get("explored_context", ""),
        )
        max_tokens = 4096

    try:
        # The OpenAI client is synchronous. Run it in a worker thread so a slow
        # provider (up to the 60 s timeout) does not block the event loop that
        # serves every other dashboard session.
        llm_response = await asyncio.to_thread(
            call_llm_or_raise, client, prompt, model=model, max_tokens=max_tokens
        )
    except Exception as exc:
        return _llm_error_outputs(state, _format_llm_exception(exc, api_url, model))
    if not llm_response:
        return _llm_error_outputs(
            state,
            f"LLM call failed or returned an empty response from {api_url} using model {model}.",
        )

    if phase == "explore":
        if llm_response.strip().upper() == "SKIP":
            state["log"].append("[LLM] Decided to skip exploration. Moving to generate.")
            state["phase"] = "generate"
            state["generated_response"] = llm_response
            state["parsed_response"] = "`SKIP`"
            return llm_outputs(
                state,
                obs_md=obs_summary(obs_data),
                feedback=obs_data.get("feedback", ""),
            )
        tool, query, intent = parse_explore_response(llm_response, state["topic"])
        state["generated_response"] = llm_response
        state["parsed_response"] = format_explore_action_md(tool, query, intent)
        state["log"].append(f'[LLM] Explore tool={tool} query="{query[:80]}"')
        return await do_explore(tool, query, intent, state)

    fmt, code, narration = parse_generate_response(llm_response)
    state["generated_response"] = format_code_text(code)
    state["parsed_response"] = fenced_json({
        "format": fmt,
        "code_len": len(code),
        "narration_len": len(narration),
    })
    state["log"].append(f"[LLM] Generate: format={fmt}, code_len={len(code)}")
    return await do_generate(fmt, code, narration, state)
async def do_llm_auto(state):
    """Run the episode to completion with the LLM driving every step.

    ``do_llm_step`` is called repeatedly until one of the following happens:

    - the episode is done;
    - a step logs an ``[ERROR]`` line;
    - a step makes no progress (neither the step counter nor the phase
      changed).

    The no-progress guard prevents an endless loop when a step is a no-op,
    e.g. Auto Run clicked before Reset, or an explore query that was empty
    and therefore never sent.
    """
    outputs = None
    while not state.get("done"):
        before = (state.get("step"), state.get("phase"))
        outputs = await do_llm_step(state)
        state = outputs[0]
        if state.get("log") and str(state["log"][-1]).startswith("[ERROR]"):
            break
        if (state.get("step"), state.get("phase")) == before:
            state["log"].append("[WARN] Auto run stopped: the last step made no progress.")
            # Re-render the timeline so the warning is visible; keep every
            # other output from the last step unchanged.
            outputs = (state, fmt_log(state["log"]), *outputs[2:])
            break
    return outputs if outputs else llm_outputs(state, status="No steps taken.")
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
def build_ui():
    """Build the Gradio Blocks app used to step through one episode.

    The layout, from top to bottom, is:

    - task dropdown;
    - Reset / Next Step / Auto Run buttons;
    - Observation and LLM panels;
    - Research and Rewards panels;
    - event timeline.

    All three buttons write the same ten output components, in the order
    produced by ``common_outputs``.

    Returns:
        The (not yet launched) ``gr.Blocks`` app.
    """
    with gr.Blocks(title="Explainer Env — Interactive Runner") as demo:
        # Episode state passed into and returned from every handler.
        # NOTE(review): empty_state() runs once here, so every visitor starts
        # with the same session_id until their first Reset (do_reset replaces
        # it). Confirm this is intended.
        session_state = gr.State(empty_state())
        # Header
        gr.Markdown("# Explainer Episode Inspector")
        # =====================================================================
        # Controls
        # =====================================================================
        with gr.Row(equal_height=True):
            # "(random)" lets the server pick; other labels map to a topic via
            # _TASK_LABEL_TO_TOPIC in do_reset.
            task_dd = gr.Dropdown(
                choices=TASK_CHOICES,
                value="(random)",
                label="Task",
                scale=1,
            )
        with gr.Row(equal_height=True):
            reset_btn = gr.Button("Reset Episode", variant="primary")
            llm_step_btn = gr.Button("Next Step", variant="secondary")
            llm_auto_btn = gr.Button("Auto Run", variant="primary")
        # =====================================================================
        # Inspector panels
        # =====================================================================
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Observation")
                    obs_md = gr.Markdown("*Click Reset Episode to begin.*")
                    feedback_box = gr.Textbox(
                        label="Latest feedback",
                        lines=8,
                        max_lines=8,
                        interactive=False,
                    )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### LLM")
                    with gr.Tabs():
                        with gr.Tab("Parsed"):
                            parsed_response_box = gr.Markdown("*No parsed response yet.*")
                        with gr.Tab("Response / code"):
                            generated_response_box = gr.Textbox(
                                label="Raw response or generated code",
                                value="No response yet.",
                                lines=16,
                                max_lines=16,
                                interactive=False,
                                buttons=["copy"],
                            )
        with gr.Row(equal_height=False):
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Research")
                    search_box = gr.Textbox(
                        label="Latest search results",
                        lines=8,
                        max_lines=8,
                        interactive=False,
                    )
                    # Column order must match the rows built by build_top_chunks_df.
                    top_chunks_table = gr.Dataframe(
                        headers=["Rank", "Source", "Title", "Score", "URL", "Snippet"],
                        interactive=False,
                        column_count=(6, "fixed"),
                        label="Top chunks",
                    )
            with gr.Column(scale=1):
                with gr.Group():
                    gr.Markdown("### Rewards")
                    reward_summary = gr.Markdown("*No rewards yet.*")
                    # Starts with a single column; build_reward_matrix returns an
                    # update that resets headers and column count per step.
                    rewards_table = gr.Dataframe(
                        headers=["Reward"],
                        interactive=False,
                        column_count=(1, "fixed"),
                        label="Reward matrix",
                    )
        # =====================================================================
        # Timeline
        # =====================================================================
        with gr.Group():
            gr.Markdown("### Timeline")
            log_box = gr.Markdown("*No events yet.*")
        # =====================================================================
        # Wiring
        # =====================================================================
        # Common outputs, in common_outputs() order: state, log, observation,
        # feedback, generated response, parsed response, search results,
        # top chunks, reward summary, reward matrix.
        common_output_components = [
            session_state,
            log_box,
            obs_md,
            feedback_box,
            generated_response_box,
            parsed_response_box,
            search_box,
            top_chunks_table,
            reward_summary,
            rewards_table,
        ]
        reset_btn.click(
            fn=do_reset,
            inputs=[task_dd, session_state],
            outputs=common_output_components,
        )
        llm_step_btn.click(
            fn=do_llm_step,
            inputs=[session_state],
            outputs=common_output_components,
        )
        llm_auto_btn.click(
            fn=do_llm_auto,
            inputs=[session_state],
            outputs=common_output_components,
        )
    return demo
if __name__ == "__main__":
    # Standalone launch: serve the dashboard on all interfaces at port 7860.
    # The environment server is still expected at SELF_ENV_BASE_URL.
    demo = build_ui()
    demo.launch(server_port=7860, server_name="0.0.0.0")