jshah13 commited on
Commit
0fbcf4f
·
verified ·
1 Parent(s): b9ae496

Upload server/app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/app.py +340 -0
server/app.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RoboReplan server — OpenEnv HTTP protocol + metrics endpoint.
3
+ """
4
+ import os
5
+ import re
6
+ from pathlib import Path
7
+
8
+ from fastapi import FastAPI
9
+ from fastapi.middleware.cors import CORSMiddleware
10
+ from fastapi.responses import HTMLResponse, RedirectResponse
11
+ from pydantic import BaseModel
12
+ from openenv.core.env_server import create_fastapi_app
13
+
14
+ from .openenv_env import RoboReplanEnv, RoboAction, RoboObservation, RoboState
15
+ from .models import Action as EnvAction
16
+
17
# Difficulty for the shared environment; overridable via the DIFFICULTY env var.
difficulty = os.environ.get("DIFFICULTY", "easy")

# Shared env instance (metrics persist across requests)
_env_instance = RoboReplanEnv(difficulty=difficulty)

# OpenEnv-protocol FastAPI app. The env factory always returns the same
# shared instance, so every request sees the same accumulated metrics.
app = create_fastapi_app(
    env=lambda: _env_instance,
    action_cls=RoboAction,
    observation_cls=RoboObservation,
)

# Allow browser-based demo/viz pages from any origin to call the API.
app.add_middleware(CORSMiddleware, allow_origins=["*"],
                   allow_methods=["*"], allow_headers=["*"])

# Standalone visualization page, read once at import time from the repo root
# (two levels up from this file).
_VIZ_HTML = (Path(__file__).parent.parent / "viz_standalone.html").read_text()
32
+
33
+
34
+ @app.get("/")
35
+ def root():
36
+ return RedirectResponse(url="/viz")
37
+
38
+
39
+ @app.get("/viz", response_class=HTMLResponse)
40
+ def viz():
41
+ return _VIZ_HTML
42
+
43
+
44
+ @app.get("/metrics")
45
+ def metrics():
46
+ """Live training metrics: success rate, reward curve, failure breakdown, oracle agreement."""
47
+ return _env_instance.metrics
48
+
49
+
50
# ── Demo endpoints — judges can interact live ──────────────────────────

# Lazily-created resources for the interactive demo endpoints.
_demo_env = None     # per-demo environment; (re)created by /demo/reset
_policy_pipe = None  # HF text-generation pipeline; built on first use
_POLICY_MODEL = os.environ.get("DEMO_POLICY_MODEL", "jshah13/robo-replan-grpo")
_VALID_ACTIONS = [a.value for a in EnvAction]  # canonical action vocabulary
56
+
57
+
58
def _extract_action(text: str) -> str:
    """Pull a canonical action name out of raw model output.

    Strips any <think>...</think> reasoning block first, then matches
    action tokens (longest names first) either verbatim or with
    underscores written as spaces. Falls back to SCAN_SCENE when
    nothing matches.
    """
    no_think = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
    upper = no_think.strip().upper()
    words_only = re.sub(r"\s+", " ", re.sub(r"[^A-Z_ ]+", " ", upper)).strip()

    # First pass: exact token substring, longest action names first so
    # e.g. MOVE_TO_RED wins over any shorter action contained within it.
    by_length = sorted(_VALID_ACTIONS, key=len, reverse=True)
    hit = next((a for a in by_length if a in upper), None)
    if hit is not None:
        return hit

    # Second pass: accept "MOVE TO RED"-style spellings (underscores → spaces).
    for candidate in _VALID_ACTIONS:
        if candidate.replace("_", " ") in words_only:
            return candidate
    return "SCAN_SCENE"
70
+
71
+
72
def _parse_valid_actions_from_prompt(prompt: str) -> list[str]:
    """Extract the comma-separated "Valid now: ..." action list from a prompt.

    Returns [] when the line is absent or reads "any" (meaning
    unrestricted). Names not in the canonical vocabulary are dropped.
    """
    match = re.search(r"Valid now:\s*(.*)", prompt, flags=re.IGNORECASE)
    if match is None:
        return []
    listing = match.group(1).strip()
    if listing.lower() == "any":
        return []
    candidates = (piece.strip().upper() for piece in listing.split(","))
    return [name for name in candidates if name and name in _VALID_ACTIONS]
81
+
82
+
83
+ def _fallback_action(valid: list[str]) -> str:
84
+ if not valid:
85
+ return "SCAN_SCENE"
86
+ priority = [
87
+ "PLACE_BIN_A", "PLACE_BIN_B",
88
+ "PICK",
89
+ "CLEAR_BLOCKER",
90
+ "MOVE_TO_RED", "MOVE_TO_BLUE", "MOVE_TO_GREEN", "MOVE_TO_YELLOW", "MOVE_TO_PURPLE",
91
+ "MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST",
92
+ "ROTATE_LEFT", "ROTATE_RIGHT",
93
+ "SCAN_SCENE",
94
+ ]
95
+ for p in priority:
96
+ if p in valid:
97
+ return p
98
+ return valid[0]
99
+
100
+
101
+ def _prompt_line(prompt: str, key: str) -> str:
102
+ m = re.search(rf"{re.escape(key)}:\s*(.*)", prompt, flags=re.IGNORECASE)
103
+ return m.group(1).strip() if m else ""
104
+
105
+
106
def _smart_fallback_action(valid: list[str], prompt: str) -> str:
    """Choose a fallback action using cheap state cues parsed from the prompt.

    Reads the "Holding:" and "Last:" lines to break the common failure
    loops: place before anything else while holding an object, pick right
    after a successful targeted move, and reposition/clear instead of
    retrying a failed PICK. Returns SCAN_SCENE for an empty list.
    """
    if not valid:
        return "SCAN_SCENE"

    allowed = set(valid)
    carried = _prompt_line(prompt, "Holding").lower()
    recent = _prompt_line(prompt, "Last")

    # Parse "ACTION -> RESULT" out of the Last line, if present.
    prev_action = prev_result = ""
    parsed = re.match(r"\s*([A-Z_]+)\s*->\s*([A-Z_]+)", recent.upper())
    if parsed:
        prev_action, prev_result = parsed.groups()

    # Holding something? Place it before doing anything else.
    if carried and carried not in ("nothing", "none", "null"):
        for bin_action in ("PLACE_BIN_A", "PLACE_BIN_B"):
            if bin_action in allowed:
                return bin_action

    # Just arrived at a target successfully → grab it.
    if prev_action.startswith("MOVE_TO_") and prev_result == "SUCCESS" and "PICK" in allowed:
        return "PICK"

    # A failed PICK means reposition or clear rather than retrying blindly.
    if prev_action == "PICK" and prev_result.startswith("FAILED"):
        targeted = next((a for a in valid if a.startswith("MOVE_TO_")), None)
        if targeted is not None:
            return targeted
        if "CLEAR_BLOCKER" in allowed:
            return "CLEAR_BLOCKER"

    # In top-down mode a blind PICK tends to loop; prefer a targeted move.
    targeted = next((a for a in valid if a.startswith("MOVE_TO_")), None)
    if targeted is not None:
        return targeted

    return _fallback_action(valid)
146
+
147
+
148
def _get_policy_pipe():
    """Return the lazily constructed HF text-generation pipeline.

    Loads the demo policy model once per process: fp16 with device_map
    on GPU hosts, fp32 pinned to CPU otherwise. NOTE(review): the lazy
    init is not locked, so two concurrent first calls could each build a
    pipeline — wasteful but not incorrect; confirm if that matters here.
    """
    global _policy_pipe
    if _policy_pipe is not None:
        return _policy_pipe

    import torch
    from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

    tokenizer = AutoTokenizer.from_pretrained(_POLICY_MODEL)
    tokenizer.padding_side = "left"
    # Model has no dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token

    if torch.cuda.is_available():
        model = AutoModelForCausalLM.from_pretrained(
            _POLICY_MODEL, torch_dtype=torch.float16, device_map="auto",
        )
        placement = {"device_map": "auto"}
    else:
        model = AutoModelForCausalLM.from_pretrained(
            _POLICY_MODEL, torch_dtype=torch.float32,
        )
        placement = {"device": "cpu"}

    # Silence pad-token warnings and let max_new_tokens govern length.
    model.generation_config.pad_token_id = tokenizer.pad_token_id
    model.generation_config.max_length = None

    _policy_pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **placement)
    return _policy_pipe
174
+
175
+
176
class PolicyActionRequest(BaseModel):
    """Request body for /demo/policy_action."""
    # Full environment prompt to feed the policy model.
    prompt: str
    # Optional whitelist of currently-valid actions; when empty, the server
    # falls back to parsing the "Valid now:" line out of the prompt.
    # (Pydantic copies field defaults per instance, so the [] is safe here.)
    valid_actions: list[str] = []
179
+
180
+
181
+ def _format_demo_step_response(step_out):
182
+ """
183
+ Compat for env implementations that return either:
184
+ - StepResult(observation, reward, done, info), or
185
+ - Observation(done, reward, ...)
186
+ """
187
+ if hasattr(step_out, "observation"):
188
+ obs = step_out.observation
189
+ return {
190
+ "observation": obs.model_dump(),
191
+ "reward": float(getattr(step_out, "reward", 0.0) or 0.0),
192
+ "done": bool(getattr(step_out, "done", False)),
193
+ "info": getattr(step_out, "info", {}) or {},
194
+ }
195
+ obs = step_out
196
+ return {
197
+ "observation": obs.model_dump(),
198
+ "reward": float(getattr(obs, "reward", 0.0) or 0.0),
199
+ "done": bool(getattr(obs, "done", False)),
200
+ "info": {},
201
+ }
202
+
203
+
204
+ @app.post("/demo/reset")
205
+ def demo_reset(difficulty: str = "easy", scenario_pack: str = "default"):
206
+ """Start a fresh demo episode. scenario_pack: default | pharmacy | warehouse | lab"""
207
+ global _demo_env
208
+ _demo_env = RoboReplanEnv(difficulty=difficulty)
209
+ # Apply scenario pack by patching the internal env config before reset
210
+ if scenario_pack != "default":
211
+ _demo_env._env.cfg.task.scenario_pack = scenario_pack
212
+ obs = _demo_env.reset()
213
+ return {"observation": obs.model_dump(), "done": False, "reward": 0.0}
214
+
215
+
216
+ @app.post("/demo/step")
217
+ def demo_step(action: str):
218
+ """Take one step in the demo episode."""
219
+ global _demo_env
220
+ if _demo_env is None:
221
+ _demo_env = RoboReplanEnv(difficulty="easy")
222
+ _demo_env.reset()
223
+ result = _demo_env.step(RoboAction(action=action))
224
+ return _format_demo_step_response(result)
225
+
226
+
227
def _oracle_reasoning(env, action: str) -> str:
    """Generate a human-readable <think> narration for the oracle's chosen action."""
    # Best-effort: any attribute mismatch with the inner env falls through
    # to the generic "Executing ..." message in the except handler below.
    try:
        inner = env._env
        state = inner.sim.get_state()
        holding = state.holding
        placements = inner._required_placements
        failures = inner._known_failures
        constraints = inner._active_constraints
        completed = set(inner._completed_subgoals)
        instruction = inner._instruction  # NOTE(review): currently unused here

        # Build context strings
        # NOTE(review): blocked_objs is computed but never referenced below.
        blocked_objs = [n for n, o in state.objects.items() if not o.reachable and o.in_bin is None]
        # Objects whose required placement is neither recorded as a
        # completed subgoal nor already reflected in the sim state.
        remaining = [n for n, b in placements.items()
                     if f"placed_{n}_in_bin_{b}" not in completed and
                     not (state.objects.get(n) and state.objects[n].in_bin == b)]
        constraint_str = f" Constraint: {', '.join(constraints)}." if constraints else ""
        deadline_status = inner._deadline_status()
        deadline_str = f" Deadlines active: {deadline_status}." if deadline_status else ""

        # One narration template per action family, checked most specific first.
        if action == "SCAN_SCENE":
            return (f"I need to survey the scene to reveal hidden object traits "
                    f"before planning.{constraint_str}")
        if action == "CLEAR_BLOCKER":
            # Narrate the first reachable blocker found.
            for n, o in state.objects.items():
                if o.blocking and o.reachable:
                    return (f"Plan: CLEAR_BLOCKER → MOVE_TO_{o.blocking.replace('_block','').upper()} → PICK → PLACE. "
                            f"{n} is blocking {o.blocking}. Clearing it first.{constraint_str}")
        if holding and action.startswith("PLACE_BIN_"):
            bin_name = action.split("_")[-1]
            target_bin = placements.get(holding, "?")
            correct = target_bin == bin_name
            return (f"I am holding {holding}. Target bin is {target_bin}. "
                    f"{'Placing correctly' if correct else 'Placing in available bin'}.{constraint_str}")
        if action.startswith("MOVE_TO_"):
            # Object names follow the "<color>_block" convention.
            color = action.replace("MOVE_TO_", "").lower()
            target = f"{color}_block"
            if failures:
                return (f"Previous failure: {failures[-1]}. Re-navigating to {target} to retry.{constraint_str}")
            return (f"Plan: MOVE_TO_{color.upper()} → PICK → PLACE_BIN_{placements.get(target,'?')}. "
                    f"Moving to {target} now.{constraint_str}{deadline_str}")
        if action == "PICK":
            if failures and any("FAILED_SLIP" in f for f in failures):
                return (f"Previous grasp slip detected ({failures[-1]}). "
                        f"Retrying PICK — repositioning and attempting again.")
            return (f"Gripper is adjacent to target. Attempting to grasp.{constraint_str}")
        if action in ("MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST"):
            if remaining:
                target = remaining[0]
                return (f"Navigating toward {target} "
                        f"(target bin: {placements.get(target,'?')}).{constraint_str}{deadline_str}")
        return f"Executing {action}.{constraint_str}"
    except Exception:
        # Narration is cosmetic; never let it break the demo endpoint.
        return f"Executing {action}."
282
+
283
+
284
+ @app.get("/demo/oracle")
285
+ def demo_oracle():
286
+ """Step using the oracle policy — shows optimal behavior for demo."""
287
+ global _demo_env
288
+ if _demo_env is None:
289
+ _demo_env = RoboReplanEnv(difficulty="easy")
290
+ _demo_env.reset()
291
+ oracle = _demo_env._env._oracle_action() or "SCAN_SCENE"
292
+ reasoning = _oracle_reasoning(_demo_env, oracle)
293
+ result = _demo_env.step(RoboAction(action=oracle))
294
+ payload = _format_demo_step_response(result)
295
+ payload["action_taken"] = oracle
296
+ payload["reasoning"] = reasoning
297
+ return payload
298
+
299
+
300
+ @app.post("/demo/policy_action")
301
+ def demo_policy_action(req: PolicyActionRequest):
302
+ """
303
+ Returns one model-predicted action for a given prompt.
304
+ The visualization can use this to drive the environment step-by-step.
305
+ """
306
+ try:
307
+ pipe = _get_policy_pipe()
308
+ out = pipe(
309
+ req.prompt,
310
+ return_full_text=False,
311
+ max_new_tokens=128,
312
+ do_sample=True,
313
+ temperature=0.7,
314
+ top_p=0.9,
315
+ repetition_penalty=1.05,
316
+ )[0]["generated_text"]
317
+ raw_action = _extract_action(out)
318
+ action = raw_action
319
+ valid = [a for a in req.valid_actions if a in _VALID_ACTIONS] or _parse_valid_actions_from_prompt(req.prompt)
320
+ if valid and action not in valid:
321
+ action = _smart_fallback_action(valid, req.prompt)
322
+ # Avoid no-op scan loops when other valid actions exist.
323
+ if valid and action == "SCAN_SCENE" and any(v != "SCAN_SCENE" for v in valid):
324
+ action = _smart_fallback_action([v for v in valid if v != "SCAN_SCENE"], req.prompt)
325
+ # Extract <think>...</think> reasoning separately for display and env reward
326
+ import re as _re
327
+ _m = _re.search(r'<think>(.*?)</think>', out, _re.DOTALL)
328
+ reasoning = _m.group(1).strip() if _m else ""
329
+ return {
330
+ "action": action,
331
+ "reasoning": reasoning,
332
+ "raw_output": out,
333
+ "raw_action": raw_action,
334
+ "valid_actions_used": valid,
335
+ }
336
+ except Exception as exc:
337
+ # Fail soft so the UI can still run with manual/scripted controls.
338
+ valid = [a for a in req.valid_actions if a in _VALID_ACTIONS] or _parse_valid_actions_from_prompt(req.prompt)
339
+ action = _smart_fallback_action([v for v in valid if v != "SCAN_SCENE"], req.prompt) if valid else "SCAN_SCENE"
340
+ return {"action": action, "error": str(exc), "valid_actions_used": valid}