Spaces:
Sleeping
Sleeping
| """ | |
| RoboReplan server β OpenEnv HTTP protocol + metrics endpoint. | |
| """ | |
| import os | |
| import re | |
| from pathlib import Path | |
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import HTMLResponse, RedirectResponse | |
| from pydantic import BaseModel | |
| from openenv.core.env_server import create_fastapi_app | |
| from .openenv_env import RoboReplanEnv, RoboAction, RoboObservation, RoboState | |
| from .models import Action as EnvAction | |
| difficulty = os.environ.get("DIFFICULTY", "easy") | |
| # Shared env instance (metrics persist across requests) | |
| _env_instance = RoboReplanEnv(difficulty=difficulty) | |
| app = create_fastapi_app( | |
| env=lambda: _env_instance, | |
| action_cls=RoboAction, | |
| observation_cls=RoboObservation, | |
| ) | |
| app.add_middleware(CORSMiddleware, allow_origins=["*"], | |
| allow_methods=["*"], allow_headers=["*"]) | |
| _VIZ_HTML = (Path(__file__).parent.parent / "viz_standalone.html").read_text() | |
| def root(): | |
| return RedirectResponse(url="/viz") | |
| def viz(): | |
| return _VIZ_HTML | |
| def metrics(): | |
| """Live training metrics: success rate, reward curve, failure breakdown, oracle agreement.""" | |
| return _env_instance.metrics | |
| # ββ Demo endpoints β judges can interact live ββββββββββββββββββββββββββ | |
| _demo_env = None | |
| _policy_pipe = None | |
| _POLICY_MODEL = os.environ.get("DEMO_POLICY_MODEL", "jshah13/robo-replan-grpo") | |
| _VALID_ACTIONS = [a.value for a in EnvAction] | |
| def _extract_action(text: str) -> str: | |
| clean = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip().upper() | |
| normalized = re.sub(r"[^A-Z_ ]+", " ", clean) | |
| normalized = re.sub(r"\s+", " ", normalized).strip() | |
| for action in sorted(_VALID_ACTIONS, key=len, reverse=True): | |
| if action in clean: | |
| return action | |
| spaced_map = {a.replace("_", " "): a for a in _VALID_ACTIONS} | |
| for spaced, action in spaced_map.items(): | |
| if spaced in normalized: | |
| return action | |
| return "SCAN_SCENE" | |
| def _parse_valid_actions_from_prompt(prompt: str) -> list[str]: | |
| m = re.search(r"Valid now:\s*(.*)", prompt, flags=re.IGNORECASE) | |
| if not m: | |
| return [] | |
| raw = m.group(1).strip() | |
| if raw.lower() == "any": | |
| return [] | |
| items = [x.strip().upper() for x in raw.split(",") if x.strip()] | |
| return [a for a in items if a in _VALID_ACTIONS] | |
| def _fallback_action(valid: list[str]) -> str: | |
| if not valid: | |
| return "SCAN_SCENE" | |
| priority = [ | |
| "PLACE_BIN_A", "PLACE_BIN_B", | |
| "PICK", | |
| "CLEAR_BLOCKER", | |
| "MOVE_TO_RED", "MOVE_TO_BLUE", "MOVE_TO_GREEN", "MOVE_TO_YELLOW", "MOVE_TO_PURPLE", | |
| "MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST", | |
| "ROTATE_LEFT", "ROTATE_RIGHT", | |
| "SCAN_SCENE", | |
| ] | |
| for p in priority: | |
| if p in valid: | |
| return p | |
| return valid[0] | |
| def _prompt_line(prompt: str, key: str) -> str: | |
| m = re.search(rf"{re.escape(key)}:\s*(.*)", prompt, flags=re.IGNORECASE) | |
| return m.group(1).strip() if m else "" | |
| def _smart_fallback_action(valid: list[str], prompt: str) -> str: | |
| """ | |
| Use lightweight state cues from prompt to avoid repetitive bad actions. | |
| """ | |
| if not valid: | |
| return "SCAN_SCENE" | |
| valid_set = set(valid) | |
| last_line = _prompt_line(prompt, "Last") | |
| holding = _prompt_line(prompt, "Holding").lower() | |
| last_action, last_result = "", "" | |
| m = re.match(r"\s*([A-Z_]+)\s*->\s*([A-Z_]+)", last_line.upper()) | |
| if m: | |
| last_action, last_result = m.group(1), m.group(2) | |
| # If holding something, prioritize placing over anything else. | |
| if holding and holding not in ("nothing", "none", "null"): | |
| if "PLACE_BIN_A" in valid_set: | |
| return "PLACE_BIN_A" | |
| if "PLACE_BIN_B" in valid_set: | |
| return "PLACE_BIN_B" | |
| # If we just moved to a target successfully, then pick. | |
| if last_action.startswith("MOVE_TO_") and last_result == "SUCCESS" and "PICK" in valid_set: | |
| return "PICK" | |
| # If pick just failed/was invalid, move or clear first instead of repeating pick. | |
| if last_action == "PICK" and last_result.startswith("FAILED"): | |
| for a in valid: | |
| if a.startswith("MOVE_TO_"): | |
| return a | |
| if "CLEAR_BLOCKER" in valid_set: | |
| return "CLEAR_BLOCKER" | |
| # In top-down mode, blind PICK tends to loop; prefer moving to a target first. | |
| for a in valid: | |
| if a.startswith("MOVE_TO_"): | |
| return a | |
| return _fallback_action(valid) | |
| def _get_policy_pipe(): | |
| global _policy_pipe | |
| if _policy_pipe is None: | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline | |
| tokenizer = AutoTokenizer.from_pretrained(_POLICY_MODEL) | |
| tokenizer.padding_side = "left" | |
| tokenizer.pad_token = tokenizer.eos_token | |
| has_gpu = torch.cuda.is_available() | |
| if has_gpu: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| _POLICY_MODEL, torch_dtype=torch.float16, device_map="auto", | |
| ) | |
| pipe_kwargs = {"device_map": "auto"} | |
| else: | |
| model = AutoModelForCausalLM.from_pretrained( | |
| _POLICY_MODEL, torch_dtype=torch.float32, | |
| ) | |
| pipe_kwargs = {"device": "cpu"} | |
| model.generation_config.pad_token_id = tokenizer.pad_token_id | |
| model.generation_config.max_length = None | |
| _policy_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **pipe_kwargs) | |
| return _policy_pipe | |
| class PolicyActionRequest(BaseModel): | |
| prompt: str | |
| valid_actions: list[str] = [] | |
| def _format_demo_step_response(step_out): | |
| """ | |
| Compat for env implementations that return either: | |
| - StepResult(observation, reward, done, info), or | |
| - Observation(done, reward, ...) | |
| """ | |
| if hasattr(step_out, "observation"): | |
| obs = step_out.observation | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": float(getattr(step_out, "reward", 0.0) or 0.0), | |
| "done": bool(getattr(step_out, "done", False)), | |
| "info": getattr(step_out, "info", {}) or {}, | |
| } | |
| obs = step_out | |
| return { | |
| "observation": obs.model_dump(), | |
| "reward": float(getattr(obs, "reward", 0.0) or 0.0), | |
| "done": bool(getattr(obs, "done", False)), | |
| "info": {}, | |
| } | |
| def demo_reset(difficulty: str = "easy", scenario_pack: str = "default"): | |
| """Start a fresh demo episode. scenario_pack: default | pharmacy | warehouse | lab""" | |
| global _demo_env | |
| _demo_env = RoboReplanEnv(difficulty=difficulty) | |
| # Apply scenario pack by patching the internal env config before reset | |
| if scenario_pack != "default": | |
| _demo_env._env.cfg.task.scenario_pack = scenario_pack | |
| obs = _demo_env.reset() | |
| return {"observation": obs.model_dump(), "done": False, "reward": 0.0} | |
| def demo_step(action: str): | |
| """Take one step in the demo episode.""" | |
| global _demo_env | |
| if _demo_env is None: | |
| _demo_env = RoboReplanEnv(difficulty="easy") | |
| _demo_env.reset() | |
| result = _demo_env.step(RoboAction(action=action)) | |
| return _format_demo_step_response(result) | |
| def _oracle_reasoning(env, action: str) -> str: | |
| """Generate a human-readable <think> narration for the oracle's chosen action.""" | |
| try: | |
| inner = env._env | |
| state = inner.sim.get_state() | |
| holding = state.holding | |
| placements = inner._required_placements | |
| failures = inner._known_failures | |
| constraints = inner._active_constraints | |
| completed = set(inner._completed_subgoals) | |
| instruction = inner._instruction | |
| # Build context strings | |
| blocked_objs = [n for n, o in state.objects.items() if not o.reachable and o.in_bin is None] | |
| remaining = [n for n, b in placements.items() | |
| if f"placed_{n}_in_bin_{b}" not in completed and | |
| not (state.objects.get(n) and state.objects[n].in_bin == b)] | |
| constraint_str = f" Constraint: {', '.join(constraints)}." if constraints else "" | |
| deadline_status = inner._deadline_status() | |
| deadline_str = f" Deadlines active: {deadline_status}." if deadline_status else "" | |
| if action == "SCAN_SCENE": | |
| return (f"I need to survey the scene to reveal hidden object traits " | |
| f"before planning.{constraint_str}") | |
| if action == "CLEAR_BLOCKER": | |
| for n, o in state.objects.items(): | |
| if o.blocking and o.reachable: | |
| return (f"Plan: CLEAR_BLOCKER β MOVE_TO_{o.blocking.replace('_block','').upper()} β PICK β PLACE. " | |
| f"{n} is blocking {o.blocking}. Clearing it first.{constraint_str}") | |
| if holding and action.startswith("PLACE_BIN_"): | |
| bin_name = action.split("_")[-1] | |
| target_bin = placements.get(holding, "?") | |
| correct = target_bin == bin_name | |
| return (f"I am holding {holding}. Target bin is {target_bin}. " | |
| f"{'Placing correctly' if correct else 'Placing in available bin'}.{constraint_str}") | |
| if action.startswith("MOVE_TO_"): | |
| color = action.replace("MOVE_TO_", "").lower() | |
| target = f"{color}_block" | |
| if failures: | |
| return (f"Previous failure: {failures[-1]}. Re-navigating to {target} to retry.{constraint_str}") | |
| return (f"Plan: MOVE_TO_{color.upper()} β PICK β PLACE_BIN_{placements.get(target,'?')}. " | |
| f"Moving to {target} now.{constraint_str}{deadline_str}") | |
| if action == "PICK": | |
| if failures and any("FAILED_SLIP" in f for f in failures): | |
| return (f"Previous grasp slip detected ({failures[-1]}). " | |
| f"Retrying PICK β repositioning and attempting again.") | |
| return (f"Gripper is adjacent to target. Attempting to grasp.{constraint_str}") | |
| if action in ("MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST"): | |
| if remaining: | |
| target = remaining[0] | |
| return (f"Navigating toward {target} " | |
| f"(target bin: {placements.get(target,'?')}).{constraint_str}{deadline_str}") | |
| return f"Executing {action}.{constraint_str}" | |
| except Exception: | |
| return f"Executing {action}." | |
| def demo_oracle(): | |
| """Step using the oracle policy β shows optimal behavior for demo.""" | |
| global _demo_env | |
| if _demo_env is None: | |
| _demo_env = RoboReplanEnv(difficulty="easy") | |
| _demo_env.reset() | |
| oracle = _demo_env._env._oracle_action() or "SCAN_SCENE" | |
| reasoning = _oracle_reasoning(_demo_env, oracle) | |
| result = _demo_env.step(RoboAction(action=oracle)) | |
| payload = _format_demo_step_response(result) | |
| payload["action_taken"] = oracle | |
| payload["reasoning"] = reasoning | |
| return payload | |
| def demo_policy_action(req: PolicyActionRequest): | |
| """ | |
| Returns one model-predicted action for a given prompt. | |
| The visualization can use this to drive the environment step-by-step. | |
| """ | |
| try: | |
| pipe = _get_policy_pipe() | |
| out = pipe( | |
| req.prompt, | |
| return_full_text=False, | |
| max_new_tokens=128, | |
| do_sample=True, | |
| temperature=0.7, | |
| top_p=0.9, | |
| repetition_penalty=1.05, | |
| )[0]["generated_text"] | |
| raw_action = _extract_action(out) | |
| action = raw_action | |
| valid = [a for a in req.valid_actions if a in _VALID_ACTIONS] or _parse_valid_actions_from_prompt(req.prompt) | |
| if valid and action not in valid: | |
| action = _smart_fallback_action(valid, req.prompt) | |
| # Avoid no-op scan loops when other valid actions exist. | |
| if valid and action == "SCAN_SCENE" and any(v != "SCAN_SCENE" for v in valid): | |
| action = _smart_fallback_action([v for v in valid if v != "SCAN_SCENE"], req.prompt) | |
| # Extract <think>...</think> reasoning separately for display and env reward | |
| import re as _re | |
| _m = _re.search(r'<think>(.*?)</think>', out, _re.DOTALL) | |
| reasoning = _m.group(1).strip() if _m else "" | |
| return { | |
| "action": action, | |
| "reasoning": reasoning, | |
| "raw_output": out, | |
| "raw_action": raw_action, | |
| "valid_actions_used": valid, | |
| } | |
| except Exception as exc: | |
| # Fail soft so the UI can still run with manual/scripted controls. | |
| valid = [a for a in req.valid_actions if a in _VALID_ACTIONS] or _parse_valid_actions_from_prompt(req.prompt) | |
| action = _smart_fallback_action([v for v in valid if v != "SCAN_SCENE"], req.prompt) if valid else "SCAN_SCENE" | |
| return {"action": action, "error": str(exc), "valid_actions_used": valid} | |