jshah13 committed on
Commit
1da6b5e
·
verified ·
1 Parent(s): 0fbcf4f

Upload server/environment.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/environment.py +798 -0
server/environment.py ADDED
@@ -0,0 +1,798 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ TabletopPlanningEnv — fully instrumented RL training environment.
3
+
4
+ Every knob lives in EnvConfig. Every step is logged. Curriculum auto-advances.
5
+ The observation tells the model everything it needs to plan well.
6
+ """
7
+ import random
8
+ from typing import Optional
9
+
10
+ from .config import EnvConfig, RealismConfig
11
+ from .logger import EpisodeLogger
12
+ from .models import Action, ObjectInfo, Observation, StepResult
13
+ from .robosim import SimWrapper
14
+ from .robosim.randomizer import randomize_scenario
15
+
16
+
17
+ class TabletopPlanningEnv:
18
def __init__(self, config: Optional[EnvConfig] = None, use_stub: bool = True) -> None:
    """Build the simulator wrapper and episode logger, then start the first episode.

    Args:
        config: Environment configuration. When omitted (or falsy),
            ``EnvConfig.easy()`` is used instead.
        use_stub: Passed through unchanged to ``SimWrapper``.
    """
    self.cfg = config or EnvConfig.easy()
    self.sim = SimWrapper(use_stub=use_stub)
    self.logger = EpisodeLogger(
        export_path=self.cfg.log.export_path,
        max_history=self.cfg.log.max_episode_history,
    )
    # Episode bookkeeping. _reset_internal() below re-initialises the
    # per-episode fields and bumps the episode counter to 1.
    self._episode_id = 0
    self._cumulative_reward = 0.0
    self._action_history: list[str] = []
    self._last_action: Optional[str] = None
    self._last_result: Optional[str] = None
    self._mid_task_changed = False
    self._reset_internal()
32
+
33
def _nav_enabled(self) -> bool:
    """Return True when the task config has a truthy ``navigation_mode`` attribute.

    A config without that attribute is treated as navigation disabled.
    """
    return bool(getattr(self.cfg.task, "navigation_mode", False))
35
+
36
+ def _gripper_cell(self) -> tuple[int, int]:
37
+ p = self.sim.get_state().gripper_pos
38
+ return int(round(float(p[0]) / 0.1)), int(round(float(p[1]) / 0.1))
39
+
40
+ def _object_cell(self, obj_name: str) -> Optional[tuple[int, int]]:
41
+ obj = self.sim.get_state().objects.get(obj_name)
42
+ if obj is None:
43
+ return None
44
+ return int(round(float(obj.pos[0]) / 0.1)), int(round(float(obj.pos[1]) / 0.1))
45
+
46
+ def _is_adjacent_to(self, obj_name: str) -> bool:
47
+ oc = self._object_cell(obj_name)
48
+ if oc is None:
49
+ return False
50
+ gx, gy = self._gripper_cell()
51
+ ox, oy = oc
52
+ return abs(gx - ox) + abs(gy - oy) <= 2
53
+
54
+ def _is_facing_object(self, obj_name: str) -> bool:
55
+ oc = self._object_cell(obj_name)
56
+ if oc is None:
57
+ return False
58
+ gx, gy = self._gripper_cell()
59
+ ox, oy = oc
60
+ dx, dy = (ox - gx), (oy - gy)
61
+ facing = self.sim.get_facing()
62
+ forward = {
63
+ "N": (0, 1),
64
+ "S": (0, -1),
65
+ "E": (1, 0),
66
+ "W": (-1, 0),
67
+ }.get(facing, (0, 1))
68
+ return (dx, dy) == forward
69
+
70
+ def _can_pick_object(self, obj_name: str) -> bool:
71
+ obj = self.sim.get_state().objects.get(obj_name)
72
+ if obj is None or not obj.reachable or obj.is_held or obj.in_bin is not None:
73
+ return False
74
+ if self._nav_enabled():
75
+ return self._is_adjacent_to(obj_name)
76
+ gp = self.sim.get_state().gripper_pos
77
+ dx = float(gp[0]) - float(obj.pos[0])
78
+ dy = float(gp[1]) - float(obj.pos[1])
79
+ return (dx * dx + dy * dy) ** 0.5 < 0.15
80
+
81
+ def _next_goal_cell(self) -> Optional[tuple[int, int]]:
82
+ state = self.sim.get_state()
83
+ for obj_name, bin_name in self._required_placements.items():
84
+ obj = state.objects.get(obj_name)
85
+ if not obj or obj.in_bin == bin_name:
86
+ continue
87
+ if obj.reachable:
88
+ return self._object_cell(obj_name)
89
+ for blocker in state.objects.values():
90
+ if blocker.blocking == obj_name and blocker.reachable and blocker.in_bin is None:
91
+ return self._object_cell(blocker.name)
92
+ return None
93
+
94
+ def _distance_to_next_goal(self) -> Optional[int]:
95
+ cell = self._next_goal_cell()
96
+ if cell is None:
97
+ return None
98
+ gx, gy = self._gripper_cell()
99
+ tx, ty = cell
100
+ return abs(gx - tx) + abs(gy - ty)
101
+
102
+ def _valid_actions_with_reasons(self) -> dict[str, str]:
103
+ state = self.sim.get_state()
104
+ reasons = {"SCAN_SCENE": "refresh scene understanding"}
105
+ if self._nav_enabled():
106
+ reasons.update({
107
+ "MOVE_NORTH": "move gripper one cell north",
108
+ "MOVE_SOUTH": "move gripper one cell south",
109
+ "MOVE_EAST": "move gripper one cell east",
110
+ "MOVE_WEST": "move gripper one cell west",
111
+ "ROTATE_LEFT": "rotate gripper orientation left",
112
+ "ROTATE_RIGHT": "rotate gripper orientation right",
113
+ })
114
+ else:
115
+ for obj in state.objects.values():
116
+ if obj.reachable and not obj.is_held and obj.in_bin is None:
117
+ color = obj.name.replace("_block", "").upper()
118
+ reasons[f"MOVE_TO_{color}"] = f"navigate directly to {obj.name}"
119
+
120
+ if state.holding:
121
+ reasons["PLACE_BIN_A"] = "place held object in bin A"
122
+ reasons["PLACE_BIN_B"] = "place held object in bin B"
123
+ else:
124
+ for obj in state.objects.values():
125
+ if not self._can_pick_object(obj.name):
126
+ continue
127
+ reasons["PICK"] = f"pick reachable object ({obj.name})"
128
+ break
129
+
130
+ for obj in state.objects.values():
131
+ if not (obj.blocking and obj.reachable):
132
+ continue
133
+ if self._nav_enabled() and not self._is_adjacent_to(obj.name):
134
+ continue
135
+ reasons["CLEAR_BLOCKER"] = f"clear blocker ({obj.name})"
136
+ break
137
+ return reasons
138
+
139
+ def _deadline_status(self) -> dict[str, int]:
140
+ status = {}
141
+ deadlines = getattr(self._scenario_cfg, "deadlines", {}) or {}
142
+ for obj_name, deadline_step in deadlines.items():
143
+ obj = self.sim.get_state().objects.get(obj_name)
144
+ target_bin = self._required_placements.get(obj_name)
145
+ done = bool(obj and target_bin and obj.in_bin == target_bin)
146
+ if done:
147
+ continue
148
+ status[obj_name] = int(deadline_step - self._steps)
149
+ return status
150
+
151
+ def _observability_map(self) -> list[str]:
152
+ gx, gy = self._gripper_cell()
153
+ lines = []
154
+ for y in range(3, -4, -1):
155
+ row = []
156
+ for x in range(-3, 4):
157
+ if (x, y) == (gx, gy):
158
+ row.append("G")
159
+ else:
160
+ row.append(".")
161
+ lines.append("".join(row))
162
+ return lines
163
+
164
+ def _nav_step_toward(self, target: tuple[int, int]) -> str:
165
+ """Navigate one step toward target cell (navigates all the way onto the cell)."""
166
+ gx, gy = self._gripper_cell()
167
+ tx, ty = target
168
+ dx, dy = tx - gx, ty - gy
169
+ # Already at target cell β€” nothing to do
170
+ if dx == 0 and dy == 0:
171
+ return "SCAN_SCENE"
172
+ # Move along the longer axis first
173
+ if abs(dx) >= abs(dy):
174
+ return "MOVE_EAST" if dx > 0 else "MOVE_WEST"
175
+ return "MOVE_NORTH" if dy > 0 else "MOVE_SOUTH"
176
+
177
+ # ── Public interface ────────────────────────────────────────────────
178
+
179
def reset(self) -> Observation:
    """Start a fresh episode and return its initial observation.

    The initial observation carries no last action or result.
    """
    self._reset_internal()
    return self._build_obs(last_action=None, last_result=None)
182
+
183
def step(self, action: str, reasoning: str = "") -> StepResult:
    """Execute one high-level action and return the resulting transition.

    Args:
        action: The high-level action string. Actions not currently listed
            by ``_valid_actions`` are not sent to the simulator and yield
            ``FAILED_INVALID``.
        reasoning: Optional <think>...</think> chain-of-thought from the
            model. Rewarded if it mentions the right objects and constraints.

    Returns:
        A ``StepResult`` with the new observation, the step reward
        (environment reward plus reasoning bonus), the done flag and an
        ``info`` dict of diagnostics.

    Raises:
        RuntimeError: If the episode is already done.
    """
    if self._done:
        raise RuntimeError("Episode is done. Call reset() first.")

    task = self.cfg.task

    # Inject mid-task instruction changes — can fire at multiple steps.
    change_steps = getattr(task, 'mid_task_change_steps', [task.mid_task_change_step])
    if (task.mid_task_change_prob > 0
            and self._steps in change_steps
            and self._steps not in self._changes_applied
            and random.random() < task.mid_task_change_prob):
        self._apply_mid_task_change()
        self._changes_applied.add(self._steps)

    pre_holding = self.sim.get_state().holding
    # Snapshot reachability BEFORE execution so the reasoning bonus can check
    # the pre-action state (e.g. "blue is blocking red" is true before
    # CLEAR_BLOCKER fires).
    pre_state_snapshot = {
        name: {"reachable": obj.reachable, "blocking": obj.blocking}
        for name, obj in self.sim.get_state().objects.items()
    }

    invalid_reason = None
    if action not in self._valid_actions():
        raw_result = "FAILED_INVALID"
        reasons = self._valid_actions_with_reasons()
        if reasons:
            invalid_reason = "invalid_now; choose one of: " + ", ".join(sorted(reasons.keys()))
    else:
        raw_result = self.sim.execute(action)
    result = self._apply_noise(action, raw_result)

    # A slipped grasp: the simulator thinks the object is held, so undo that.
    if result == "FAILED_SLIP" and raw_result == "SUCCESS" and action == "PICK":
        state = self.sim.get_state()
        if state.holding:
            state.objects[state.holding].is_held = False
            state.holding = None

    # SCAN reveals hidden traits of all currently reachable (or binned/held)
    # objects. The reveal set is computed now, before world drift, but it is
    # only committed AFTER the reward is computed: _compute_reward inspects
    # _scanned and _revealed_traits to decide whether this scan was the first
    # one and whether it was needed, so it must see the pre-scan values.
    scan_succeeded = action == "SCAN_SCENE" and result == "SUCCESS"
    newly_revealed: dict[str, str] = {}
    if scan_succeeded:
        hidden = getattr(self._scenario_cfg, 'hidden_traits', {}) or {}
        state = self.sim.get_state()
        for obj_name, trait in hidden.items():
            obj = state.objects.get(obj_name)
            if obj and (obj.reachable or obj.in_bin is not None or obj.is_held):
                newly_revealed[obj_name] = trait

    # FAILED_FRAGILE: picking an unscanned fragile object damages it.
    if (result == "SUCCESS" and action == "PICK"
            and getattr(task, 'require_scan_for_traits', False)):
        state = self.sim.get_state()
        picked = state.holding
        hidden = getattr(self._scenario_cfg, 'hidden_traits', {}) or {}
        if picked and hidden.get(picked) == "fragile" and picked not in self._revealed_traits:
            # Object is fragile but agent never scanned — it breaks.
            state.objects[picked].is_held = False
            state.holding = None
            result = "FAILED_FRAGILE"

    self._apply_world_drift()
    self._action_history.append(action)
    self._last_action = action
    self._last_result = result

    self._steps += 1
    reward = self._compute_reward(action, result, pre_holding=pre_holding,
                                  pre_state_snapshot=pre_state_snapshot)
    reward += self._reasoning_bonus(reasoning, action, result,
                                    pre_state_snapshot=pre_state_snapshot)

    # Commit the scan's effects now that the reward has been settled.
    if scan_succeeded:
        self._scanned = True
        self._revealed_traits.update(newly_revealed)

    self._cumulative_reward += reward
    self._update_planning_state(action, result)

    # Oracle hint for logging / observation.
    oracle = self._oracle_action()
    # Valid actions in the post-action state; nothing below mutates the
    # simulator, so one computation serves both the log and the info dict.
    valid_after = self._valid_actions()

    if self.cfg.log.log_every_step:
        self.logger.log_step(
            step=self._steps,
            action=action,
            result=result,
            reward=reward,
            cumulative_reward=self._cumulative_reward,
            valid_actions=valid_after,
            oracle_action=oracle if self.cfg.obs.include_oracle_hint else None,
            holding=self.sim.get_state().holding,
            n_failures=len(self._known_failures),
            n_subgoals=len(self._completed_subgoals),
        )

    done = self._check_done()
    if done:
        self.logger.end_episode(success=self._all_goals_complete())
        # NOTE(review): this stores the log export path in a field named
        # "_current_difficulty" — kept as-is, but confirm it is intended.
        self.logger.metrics._current_difficulty = self.cfg.log.export_path  # track level

    obs = self._build_obs(last_action=action, last_result=result)
    return StepResult(
        observation=obs,
        reward=reward,
        done=done,
        info={
            "result": result,
            "step": self._steps,
            "oracle_action": oracle,
            "valid_actions": valid_after,
            "action_preconditions": self._valid_actions_with_reasons(),
            "distance_to_next_goal": self._distance_to_next_goal(),
            "deadline_status": self._deadline_status(),
            "invalid_reason": invalid_reason,
            "goal_progress": self._goal_progress(),
            # _steps was already incremented, so the change (if any) fired
            # at step index _steps - 1.
            "mid_task_changed": (self._steps - 1) in self._changes_applied,
            "cumulative_reward": self._cumulative_reward,
        },
    )
302
+
303
@property
def metrics(self):
    """Return the logger's metrics as produced by ``logger.metrics.to_dict()``."""
    return self.logger.metrics.to_dict()
306
+
307
+ # ── Internal reset ──────────────────────────────────────────────────
308
+
309
def _reset_internal(self):
    """Sample a new scenario, load it into the simulator and reset episode state.

    Scenario sizes are drawn uniformly from the min/max ranges in the task
    config. Also increments the episode counter and opens a new episode in
    the logger.
    """
    tc = self.cfg.task
    # Random draws happen in this fixed order: blocked flag, then object,
    # target and blocker counts.
    force_blocked = random.random() < tc.force_blocked_prob
    scenario_cfg = randomize_scenario(
        n_objects=random.randint(tc.n_objects_min, tc.n_objects_max),
        n_targets=random.randint(tc.n_targets_min, tc.n_targets_max),
        n_blockers=random.randint(tc.n_blockers_min, tc.n_blockers_max),
        force_blocked=force_blocked,
        scenario_pack=getattr(tc, "scenario_pack", "default"),
    )

    self.sim._build_state_from_config(scenario_cfg)
    self._scenario_cfg = scenario_cfg

    # Per-episode counters and flags.
    self._steps = 0
    self._done = False
    self._scanned = False
    self._mid_task_changed = False
    self._changes_applied: set[int] = set()  # which change-steps have fired
    self._cumulative_reward = 0.0
    self._action_history = []
    self._last_action = None
    self._last_result = None
    self._completed_subgoals: list[str] = []
    self._known_failures: list[str] = []
    # The scenario's constraint (if any) is the only constraint at episode start.
    self._active_constraints: list[str] = ([scenario_cfg.constraint]
                                           if scenario_cfg.constraint else [])
    self._instruction = scenario_cfg.instruction
    # Copy so mid-task changes do not mutate the scenario config's targets.
    self._required_placements: dict[str, str] = dict(scenario_cfg.targets)
    # Per-object trait reveal: populated by SCAN_SCENE, enforced in PICK
    self._revealed_traits: dict[str, str] = {}

    self._episode_id += 1
    self.logger.begin_episode(
        episode_id=self._episode_id,
        instruction=self._instruction,
        difficulty="custom",
        n_objects=len(scenario_cfg.objects),
        n_blockers=len(scenario_cfg.blockers),
        n_targets=len(scenario_cfg.targets),
    )
350
+
351
+ # ── Reward ──────────────────────────────────────────────────────────
352
+
353
+ def _reasoning_bonus(self, reasoning: str, action: str, result: str,
354
+ pre_state_snapshot: Optional[dict] = None) -> float:
355
+ """
356
+ Bonus for reasoning that mentions relevant objects, constraints, and plans.
357
+
358
+ Uses pre-action state snapshot so CLEAR_BLOCKER reasoning ("X is blocking Y")
359
+ is rewarded correctly even though the blocker is already gone post-execution.
360
+
361
+ The cap scales with reasoning length β€” longer, more detailed chain-of-thought
362
+ can earn proportionally more reward (up to a hard ceiling of 1.5).
363
+ """
364
+ if not reasoning or len(reasoning) < 10:
365
+ return 0.0
366
+ bonus = 0.0
367
+ r = reasoning.lower()
368
+
369
+ # Use pre-action state for blocked-object checks so CLEAR_BLOCKER reasoning
370
+ # ("blue is blocking red") is rewarded even though the blocker is now cleared.
371
+ blocked_before = set()
372
+ if pre_state_snapshot:
373
+ for name, snap in pre_state_snapshot.items():
374
+ if not snap["reachable"]:
375
+ blocked_before.add(name.replace("_block", "").lower())
376
+ else:
377
+ for obj in self.sim.get_state().objects.values():
378
+ if not obj.reachable:
379
+ blocked_before.add(obj.name.replace("_block", "").lower())
380
+
381
+ # Mentions blocked objects correctly
382
+ for color in blocked_before:
383
+ if color in r:
384
+ bonus += 0.1
385
+
386
+ # Mentions the target object and correct bin
387
+ for obj_name, bin_name in self._required_placements.items():
388
+ color = obj_name.replace("_block", "")
389
+ if color in r and f"bin {bin_name.lower()}" in r:
390
+ bonus += 0.2
391
+
392
+ # Mentions relevant constraint
393
+ for c in self._active_constraints:
394
+ if c.replace("_", " ") in r:
395
+ bonus += 0.1
396
+
397
+ # Mentions the chosen action or its intent
398
+ action_words = {
399
+ "CLEAR_BLOCKER": ["clear", "move", "push", "unblock"],
400
+ "PICK": ["pick", "grab", "grasp", "lift"],
401
+ "PLACE_BIN_A": ["place", "put", "bin a"],
402
+ "PLACE_BIN_B": ["place", "put", "bin b"],
403
+ "SCAN_SCENE": ["scan", "look", "inspect", "check"],
404
+ }
405
+ for word in action_words.get(action, []):
406
+ if word in r:
407
+ bonus += 0.1
408
+ break
409
+
410
+ # Bonus for explicit multi-step plan in reasoning ("plan:" or "β†’" sequence)
411
+ if "plan:" in r or (" β†’ " in reasoning):
412
+ bonus += 0.15
413
+
414
+ # Token-length scaling: longer reasoning unlocks a higher reward cap.
415
+ # Every 50 chars of reasoning raises the cap by 0.1, up to max 1.5.
416
+ # This rewards richer chain-of-thought without rewarding padding.
417
+ length_scale = min(1.5, 0.5 + 0.1 * (len(reasoning) // 50))
418
+ return min(bonus, length_scale)
419
+
420
def _compute_reward(self, action: str, result: str, pre_holding: Optional[str] = None,
                    pre_state_snapshot: Optional[dict] = None) -> float:
    """Compute the environment reward for one executed action.

    Starts from the per-step cost and adds action-specific terms using the
    weights in ``self.cfg.reward``. Failed actions return early with only
    the step cost, navigation cost and failure penalties applied.

    Side effect: sets ``self._done`` when all goals are complete.

    Args:
        action: The action that was executed.
        result: The (noise-adjusted) result string.
        pre_holding: What the gripper held before the action; used to judge placements.
        pre_state_snapshot: Accepted for signature symmetry; not read here.

    Returns:
        The scalar reward for this step (without the reasoning bonus).
    """
    w = self.cfg.reward
    r = w.step_cost

    # Extra movement cost in navigation mode.
    if self._nav_enabled():
        if action in ("MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST"):
            r -= 0.03
        if action in ("ROTATE_LEFT", "ROTATE_RIGHT"):
            r -= 0.02

    if result not in ("SUCCESS", "PARTIAL_CLEAR"):
        # _known_failures has not yet been updated for this step, so a
        # failure is "repeated" only if the same action:result occurred before.
        failure_key = f"{action}:{result}"
        if result == "FAILED_FRAGILE":
            # Larger specific penalty — agent should have scanned first
            r += w.fragile_pick_penalty
            r += w.repeated_failure if failure_key in self._known_failures else w.first_failure
        else:
            r += w.repeated_failure if failure_key in self._known_failures else w.first_failure
        return r

    if action == "CLEAR_BLOCKER":
        r += w.blocker_cleared
    if action == "PICK":
        held = self.sim.get_state().holding
        # Reward only picks that move a required-yet-unfinished target.
        if held and held in self._required_placements:
            target_bin = self._required_placements[held]
            obj = self.sim.get_state().objects.get(held)
            already_done = bool(obj and obj.in_bin == target_bin)
            if not already_done:
                r += w.successful_pick
            else:
                r += w.wrong_pick
        else:
            r += w.wrong_pick
    if action in ("PLACE_BIN_A", "PLACE_BIN_B"):
        bin_name = "A" if action == "PLACE_BIN_A" else "B"
        placed_obj = pre_holding
        correct = bool(placed_obj and self._required_placements.get(placed_obj) == bin_name)
        r += w.correct_placement if correct else w.wrong_bin
        if not correct and self._active_constraints:
            r += w.constraint_violation  # extra hit for constraint violation
    if action == "SCAN_SCENE":
        if not self._scanned:
            r += w.useful_scan  # first scan only
        # Penalize avoidable scans — but NOT if scanning is currently needed
        # to reveal a required hidden trait (fragile/heavy) before picking.
        scan_is_needed = False
        if getattr(self.cfg.task, 'require_scan_for_traits', False):
            hidden = getattr(self._scenario_cfg, 'hidden_traits', {}) or {}
            state = self.sim.get_state()
            for obj_name in self._required_placements:
                obj = state.objects.get(obj_name)
                if (obj and obj.reachable and obj.in_bin is None
                        and obj_name in hidden
                        and obj_name not in self._revealed_traits):
                    scan_is_needed = True
                    break
        if not scan_is_needed:
            valid_now = self._valid_actions()
            if any(a != "SCAN_SCENE" for a in valid_now):
                r += w.useless_action
        # Penalize scan loops with increasing severity regardless.
        # NOTE(review): step() appends the current action to _action_history
        # before calling this method, so the streak includes this scan and is
        # at least 1 here — confirm a lone scan is meant to be penalised.
        streak = 0
        for a in reversed(self._action_history):
            if a == "SCAN_SCENE":
                streak += 1
            else:
                break
        if streak > 0:
            r -= min(1.5, 0.25 * streak)

    # First recovery after failure
    if self._known_failures and result == "SUCCESS" and action != "SCAN_SCENE":
        if "recovery" not in self._completed_subgoals:
            r += w.recovery_after_failure

    # Terminal
    if self._all_goals_complete():
        r += w.task_complete
        # Efficiency bonus is proportional to the fraction of the step budget left.
        steps_saved = self.cfg.task.max_steps - self._steps
        r += w.efficiency_bonus_max * (steps_saved / self.cfg.task.max_steps)
        self._done = True
    elif self._steps >= self.cfg.task.max_steps:
        # Timeout: explicit penalty so the model learns completing > timing out.
        r += w.timeout_failure

    # Deadline pressure: penalize each overdue unfinished target.
    for obj_name, remaining in self._deadline_status().items():
        if remaining < 0:
            r += (w.missed_deadline * 0.2)

    return r
514
+
515
+ # ── Planning state ──────────────────────────────────────────────────
516
+
517
+ def _update_planning_state(self, action: str, result: str):
518
+ if result not in ("SUCCESS", "PARTIAL_CLEAR"):
519
+ key = f"{action}:{result}"
520
+ if key not in self._known_failures:
521
+ self._known_failures.append(key)
522
+ else:
523
+ if action == "CLEAR_BLOCKER" and "cleared_blocker" not in self._completed_subgoals:
524
+ self._completed_subgoals.append("cleared_blocker")
525
+ if (self._known_failures and result == "SUCCESS"
526
+ and "recovery" not in self._completed_subgoals):
527
+ self._completed_subgoals.append("recovery")
528
+
529
+ state = self.sim.get_state()
530
+ for obj_name, bin_name in self._required_placements.items():
531
+ key = f"placed_{obj_name}_in_bin_{bin_name}"
532
+ if key not in self._completed_subgoals:
533
+ obj = state.objects.get(obj_name)
534
+ if obj and obj.in_bin == bin_name:
535
+ self._completed_subgoals.append(key)
536
+
537
+ if self._steps >= self.cfg.task.max_steps:
538
+ self._done = True
539
+
540
def _check_done(self) -> bool:
    """Return the episode's done flag (set by the reward and planning-state updates)."""
    return self._done
542
+
543
+ def _all_goals_complete(self) -> bool:
544
+ state = self.sim.get_state()
545
+ for name, bin_name in self._required_placements.items():
546
+ obj = state.objects.get(name)
547
+ if not obj or obj.in_bin != bin_name:
548
+ return False
549
+ return True
550
+
551
+ # ── Noise / dynamics ────────────────────────────────────────────────
552
+
553
+ def _apply_noise(self, action: str, result: str) -> str:
554
+ if result != "SUCCESS":
555
+ return result
556
+ rc = self.cfg.realism
557
+ if action == "PICK" and random.random() < rc.grasp_fail_prob:
558
+ return "FAILED_SLIP"
559
+ if action == "CLEAR_BLOCKER" and random.random() < rc.clear_partial_prob:
560
+ return "PARTIAL_CLEAR"
561
+ return result
562
+
563
+ def _apply_world_drift(self):
564
+ if random.random() < self.cfg.realism.object_drift_prob:
565
+ state = self.sim.get_state()
566
+ reachable = [o for o in state.objects.values()
567
+ if o.reachable and not o.is_held and o.in_bin is None]
568
+ if reachable:
569
+ obj = random.choice(reachable)
570
+ obj.reachable = False
571
+
572
+ # ── Mid-task instruction change ─────────────────────────────────────
573
+
574
def _apply_mid_task_change(self):
    """Swap one target's bin. Agent must replan.

    Picks a random required placement, reassigns it to the first bin in
    ``BINS`` that differs from its current one, appends an ``[UPDATE: ...]``
    note to the instruction and records the ``bin_change`` constraint.

    Does nothing when there are no targets or when no alternative bin
    exists. The ``bin_change`` constraint is recorded at most once, so
    repeated changes do not pile up duplicate entries in the observation.
    """
    # Local import, as in the rest of this method's history; both names live
    # in the same module.
    from .robosim.randomizer import BINS, OBJECT_COLORS

    targets = list(self._required_placements.items())
    if not targets:
        return
    obj_name, old_bin = random.choice(targets)

    # Guard against a single-bin setup instead of raising IndexError.
    new_bin = next((b for b in BINS if b != old_bin), None)
    if new_bin is None:
        return

    self._required_placements[obj_name] = new_bin
    self._mid_task_changed = True

    # Rebuild instruction to reflect change.
    color = OBJECT_COLORS.get(obj_name, obj_name.replace("_block", ""))
    change_note = f" [UPDATE: place the {color} block in bin {new_bin} instead.]"
    self._instruction = self._instruction + change_note
    if "bin_change" not in self._active_constraints:
        self._active_constraints.append("bin_change")
590
+
591
+ # ── Valid actions ────────────────────────────────────────────────────
592
+
593
+ def _valid_actions(self) -> list[str]:
594
+ """Which actions make sense right now given the current state."""
595
+ state = self.sim.get_state()
596
+ valid = ["SCAN_SCENE"]
597
+
598
+ if self._nav_enabled():
599
+ valid += ["MOVE_NORTH", "MOVE_SOUTH", "MOVE_EAST", "MOVE_WEST", "ROTATE_LEFT", "ROTATE_RIGHT"]
600
+ else:
601
+ for obj in state.objects.values():
602
+ if obj.reachable and not obj.is_held and obj.in_bin is None:
603
+ color = obj.name.replace("_block", "").upper()
604
+ valid.append(f"MOVE_TO_{color}")
605
+
606
+ if state.holding:
607
+ valid += ["PLACE_BIN_A", "PLACE_BIN_B"]
608
+ else:
609
+ has_pick = False
610
+ for obj in state.objects.values():
611
+ if self._can_pick_object(obj.name):
612
+ has_pick = True
613
+ break
614
+ if has_pick:
615
+ valid.append("PICK")
616
+
617
+ if not state.holding: # can't clear a blocker while holding something
618
+ for obj in state.objects.values():
619
+ if not (obj.blocking and obj.reachable):
620
+ continue
621
+ if self._nav_enabled() and not self._is_adjacent_to(obj.name):
622
+ continue
623
+ valid.append("CLEAR_BLOCKER")
624
+ break
625
+
626
+ return valid
627
+
628
+ # ── Goal progress ────────────────────────────────────────────────────
629
+
630
+ def _goal_progress(self) -> float:
631
+ if not self._required_placements:
632
+ return 1.0
633
+ state = self.sim.get_state()
634
+ done = sum(1 for name, bin_ in self._required_placements.items()
635
+ if state.objects.get(name) and state.objects[name].in_bin == bin_)
636
+ return done / len(self._required_placements)
637
+
638
+ # ── Oracle hint ──────────────────────────────────────────────────────
639
+
640
def _oracle_action(self) -> Optional[str]:
    """Scripted optimal action for current state (teaching signal).

    Works through a fixed priority ladder: required scan, pick after a
    successful MOVE_TO, place what is held, clear blockers after recorded
    failures, then the required placements in order. Falls back to
    ``SCAN_SCENE`` when nothing else applies.
    """
    state = self.sim.get_state()
    failures = set(self._known_failures)
    completed = set(self._completed_subgoals)
    last_action = self._last_action
    last_result = self._last_result

    def can_clear_now() -> bool:
        # True if some reachable blocker exists (and, in nav mode, is adjacent).
        for obj in state.objects.values():
            if not (obj.blocking and obj.reachable):
                continue
            if self._nav_enabled() and not self._is_adjacent_to(obj.name):
                continue
            return True
        return False

    def blocker_for_target(target_name: str) -> Optional[str]:
        # Name of the first reachable, un-binned object blocking target_name.
        for obj in state.objects.values():
            if obj.blocking == target_name and obj.reachable and obj.in_bin is None:
                return obj.name
        return None

    # If scan is required and next pick target is fragile+unscanned → scan first
    if getattr(self.cfg.task, 'require_scan_for_traits', False):
        hidden = getattr(self._scenario_cfg, 'hidden_traits', {}) or {}
        for obj_name in self._required_placements:
            obj = state.objects.get(obj_name)
            if (obj and obj.reachable and obj.in_bin is None
                    and hidden.get(obj_name) == "fragile"
                    and obj_name not in self._revealed_traits):
                return "SCAN_SCENE"

    # Just moved to something → pick it
    if last_action and last_action.startswith("MOVE_TO") and last_result == "SUCCESS":
        return "PICK"

    # Holding → place correctly
    if state.holding:
        target_bin = self._required_placements.get(state.holding)
        if target_bin:
            return f"PLACE_BIN_{target_bin}"
        # Held object is not a target: bin A is used as the default drop.
        return "PLACE_BIN_A"

    # Failed to reach a target → clear or re-navigate
    if any(f.startswith("MOVE_TO") and "FAILED_BLOCKED" in f for f in failures) and can_clear_now():
        return "CLEAR_BLOCKER"
    # PICK:FAILED_EMPTY means gripper is not adjacent to anything pickable.
    # In nav mode, re-navigate to the next target instead of looping on CLEAR_BLOCKER.
    if "PICK:FAILED_EMPTY" in failures:
        if self._nav_enabled():
            # Fall through to the placement-order loop below which will nav correctly.
            pass
        elif can_clear_now():
            return "CLEAR_BLOCKER"

    # Work through required placements in order
    for obj_name, bin_name in self._required_placements.items():
        key = f"placed_{obj_name}_in_bin_{bin_name}"
        if key in completed:
            continue
        obj = state.objects.get(obj_name)
        if not obj or obj.in_bin:
            continue
        if obj.reachable:
            if self._nav_enabled():
                obj_cell = self._object_cell(obj_name)
                gripper_cell = self._gripper_cell()
                # Navigate all the way to the object's cell so PICK grabs
                # the right object (not a closer distractor).
                if obj_cell is not None and gripper_cell == obj_cell:
                    return "PICK"
                if obj_cell is not None:
                    return self._nav_step_toward(obj_cell)
            color = obj_name.replace("_block", "").upper()
            return f"MOVE_TO_{color}"
        # Target is unreachable: deal with whatever blocks it.
        blocker = blocker_for_target(obj_name)
        if blocker is not None:
            if self._nav_enabled():
                if self._is_adjacent_to(blocker):
                    return "CLEAR_BLOCKER"
                blocker_cell = self._object_cell(blocker)
                if blocker_cell is not None:
                    return self._nav_step_toward(blocker_cell)
            return "CLEAR_BLOCKER"
        if can_clear_now():
            return "CLEAR_BLOCKER"
        return "SCAN_SCENE"

    return "SCAN_SCENE"
730
+
731
+ # ── Observation ──────────────────────────────────────────────────────
732
+
733
def _build_obs(self, last_action: Optional[str], last_result: Optional[str]) -> Observation:
    """Assemble the observation for the current state.

    Object reachability is reported with optional noise from the realism
    config, and optional sections are added according to the observation
    config (``self.cfg.obs``).

    Args:
        last_action: The action just executed, or None at episode start.
        last_result: Its result string, or None at episode start.

    Returns:
        The populated ``Observation``.
    """
    state = self.sim.get_state()
    oc = self.cfg.obs

    visible = []
    for obj in state.objects.values():
        # Apply observation noise
        reachable = obj.reachable
        # Before the first scan any object may be reported unreachable;
        # otherwise only truly reachable objects can be flipped by noise.
        if (not self._scanned and
                random.random() < self.cfg.realism.hidden_object_prob):
            reachable = False
        elif (obj.reachable and
                random.random() < self.cfg.realism.reachability_noise):
            reachable = False

        visible.append(ObjectInfo(
            name=obj.name,
            reachable=reachable,
            location="unknown" if not reachable else "table",
            blocking=obj.blocking,
            in_bin=obj.in_bin,
            is_held=obj.is_held,
        ))

    # Recent action history
    history = (self._action_history[-oc.include_action_history:]
               if oc.include_action_history > 0 else [])

    # Optional sections, each gated by its own observation-config flag.
    extra = {}
    if oc.include_valid_actions:
        extra["valid_actions"] = self._valid_actions()
        extra["action_preconditions"] = self._valid_actions_with_reasons()
    if oc.include_goal_progress:
        extra["goal_progress"] = round(self._goal_progress(), 2)
    if oc.include_oracle_hint:
        extra["oracle_hint"] = self._oracle_action()
    if oc.include_distance_to_goal:
        remaining = sum(1 for n, b in self._required_placements.items()
                        if not (state.objects.get(n) and state.objects[n].in_bin == b))
        extra["goals_remaining"] = remaining
        extra["distance_to_next_goal"] = self._distance_to_next_goal()
        goal_cell = self._next_goal_cell()
        if goal_cell is not None:
            extra["next_target_cell"] = f"{goal_cell[0]},{goal_cell[1]}"
    # The fields below are always included.
    extra["deadline_status"] = self._deadline_status()
    extra["object_deadlines"] = getattr(self._scenario_cfg, "deadlines", {}) or {}
    extra["observability_map"] = self._observability_map()
    # Show what traits have been revealed so far (empty until agent scans)
    extra["discovered_traits"] = dict(self._revealed_traits)

    return Observation(
        instruction=self._instruction,
        steps_remaining=self.cfg.task.max_steps - self._steps,
        visible_objects=visible,
        holding=state.holding,
        completed_subgoals=list(self._completed_subgoals),
        known_failures=list(self._known_failures),
        active_constraints=list(self._active_constraints),
        last_action=last_action,
        last_result=last_result,
        action_history=history,
        nav_mode=self._nav_enabled(),
        gripper_cell=f"{self._gripper_cell()[0]},{self._gripper_cell()[1]}",
        gripper_facing=self.sim.get_facing(),
        **extra,
    )