jshah13 committed on
Commit
8c8ea52
·
verified ·
1 Parent(s): 2abbfbd

Upload server/robosim/randomizer.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. server/robosim/randomizer.py +198 -0
server/robosim/randomizer.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Domain randomization for the tabletop planning environment.
3
+
4
+ Randomizes everything that can vary in a real tabletop scene:
5
+ - number of objects
6
+ - which object is the target
7
+ - which bin is the target
8
+ - how many blockers, and what they block
9
+ - object positions (within reachable workspace)
10
+ - task instruction (generated from the sampled scene)
11
+ - distractor objects (present but irrelevant to the task)
12
+ - constraint type (fragile first, heavy last, etc.)
13
+
14
+ The model must generalize across all of this — not memorize one layout.
15
+ """
16
+ import random
17
+ import string
18
+ from dataclasses import dataclass, field
19
+ from typing import Optional
20
+
21
+
22
# Objects available in the "default" scenario pack.
OBJECT_NAMES = [
    "red_block",
    "blue_block",
    "green_block",
    "yellow_block",
    "purple_block",
]
# Display colour of each default block: its name with the "_block" suffix removed.
OBJECT_COLORS = {block: block.rsplit("_", 1)[0] for block in OBJECT_NAMES}
# Bins a target object can be assigned to.
BINS = ["A", "B"]
# Pool sampled uniformly when choosing a constraint; the three None entries
# (None = no constraint) make "no constraint" as likely as all named ones combined.
CONSTRAINTS = ["fragile_first", "heavy_last", "urgent_first"] + [None] * 3
27
+
28
+
29
@dataclass
class ScenarioConfig:
    """Description of one sampled tabletop scene and the task defined on it.

    Every field defaults to an empty container (or None / "" for the scalar
    fields), so a bare ``ScenarioConfig()`` is a valid empty scenario.
    """

    # Objects actually present in the scene
    objects: list[str] = field(default_factory=list)

    # Which objects are targets (must be placed in a bin): obj_name -> bin
    targets: dict[str, str] = field(default_factory=dict)

    # Blocking relationships: blocker -> blocked
    blockers: dict[str, str] = field(default_factory=dict)

    # Distractors: present but not part of the task
    distractors: list[str] = field(default_factory=list)

    # Active constraint (None means no constraint applies)
    constraint: Optional[str] = None

    # Generated instruction string
    instruction: str = ""

    # Object positions on the table (x, y) — workspace is roughly ±0.25
    positions: dict[str, tuple[float, float]] = field(default_factory=dict)
    # Hidden traits, revealed via scan or proximity: obj_name -> trait label
    hidden_traits: dict[str, str] = field(default_factory=dict)
    # Optional deadlines in steps for selected target objects: obj_name -> step
    deadlines: dict[str, int] = field(default_factory=dict)
55
+
56
+
57
def randomize_scenario(
    n_objects: Optional[int] = None,
    n_targets: Optional[int] = None,
    n_blockers: Optional[int] = None,
    force_blocked: bool = False,
    scenario_pack: str = "default",
) -> ScenarioConfig:
    """
    Generate a fully randomized scenario.

    n_objects: total objects on table (default: random 2-5). None and 0 both
        mean "pick randomly"; the value is capped at the size of the pack.
    n_targets: how many must be placed in bins (default: random 1-2). None and
        0 both mean "pick randomly"; capped at the number of objects present.
    n_blockers: how many blocking relationships (default: random 0-2, never
        more than the number of distractors). Each blocker is a distractor
        paired with a distinct target, so the effective count is also limited
        by the number of distractors and targets.
    force_blocked: always have at least one blocker (good for training
        recovery). To make that possible, the target count is reduced so that
        at least one object stays a distractor. With a single object in the
        scene no blocker can exist and none is created.
    scenario_pack: key into SCENARIO_PACKS selecting the object names; an
        unknown key falls back to OBJECT_NAMES.

    Returns a ScenarioConfig describing objects, targets, blockers,
    distractors, constraint, instruction, positions, hidden traits and
    deadlines.

    Raises ValueError (from random.sample) if n_objects or n_targets is
    negative.
    """
    # Sample object count
    pack = SCENARIO_PACKS.get(scenario_pack, OBJECT_NAMES)
    total = n_objects or random.randint(2, 5)
    total = min(total, len(pack))

    # Pick which objects appear
    present = random.sample(pack, total)

    # Pick targets (subset of present objects)
    max_targets = min(n_targets or random.randint(1, 2), len(present))
    if force_blocked and len(present) >= 2:
        # Only a distractor can act as a blocker, so keep at least one object
        # out of the target set; otherwise force_blocked could not be honored.
        max_targets = min(max_targets, len(present) - 1)
    targets_list = random.sample(present, max_targets)
    target_bins = {obj: random.choice(BINS) for obj in targets_list}

    # Distractors = present but not targets
    distractors = [o for o in present if o not in target_bins]

    # Build blocking relationships
    if n_blockers is not None:
        n_block = n_blockers
    else:
        n_block = random.randint(0, min(2, len(distractors)))
    if force_blocked:
        n_block = max(1, n_block)

    # A blocker must be a non-target (distractor) blocking a target
    available_blockers = list(distractors)
    available_targets = list(targets_list)
    random.shuffle(available_blockers)
    random.shuffle(available_targets)
    n_pairs = min(n_block, len(available_blockers), len(available_targets))
    # zip pairs the shuffled lists one-to-one; a negative n_pairs yields no pairs.
    blockers = dict(zip(available_blockers[:max(n_pairs, 0)], available_targets))

    # Positions: every present object gets its own shuffled x slot; objects
    # that will be blocked sit further back (y = -0.05), the rest at y = 0.05.
    positions = {}
    x_slots = [-0.15, 0.0, 0.15, -0.08, 0.08]
    random.shuffle(x_slots)
    blocked_objects = set(blockers.values())
    for slot_idx, obj in enumerate(present):
        x = x_slots[slot_idx % len(x_slots)]
        positions[obj] = (x, -0.05 if obj in blocked_objects else 0.05)

    # Blocker slightly in front of what it blocks (overrides its slot position)
    for blocker, blocked in blockers.items():
        tx, ty = positions[blocked]
        positions[blocker] = (tx + random.uniform(-0.03, 0.03), ty + 0.08)

    # Constraint
    constraint = random.choice(CONSTRAINTS)
    # Keep trait labels simple and interpretable for LLM reasoning.
    hidden_traits = {
        obj: random.choice(["fragile", "heavy", "standard"]) for obj in targets_list
    }

    # With probability 0.6, give one random target a deadline of 5-10 steps.
    deadlines = {}
    if targets_list and random.random() < 0.6:
        urgent_obj = random.choice(targets_list)
        deadlines[urgent_obj] = random.randint(5, 10)

    # Generate instruction
    instruction = _build_instruction(target_bins, constraint, hidden_traits, deadlines)

    return ScenarioConfig(
        objects=present,
        targets=target_bins,
        blockers=blockers,
        distractors=distractors,
        constraint=constraint,
        instruction=instruction,
        positions=positions,
        hidden_traits=hidden_traits,
        deadlines=deadlines,
    )
146
+
147
+
148
+ def _build_instruction(target_bins: dict[str, str], constraint: Optional[str],
149
+ hidden_traits: dict[str, str], deadlines: dict[str, int]) -> str:
150
+ parts = []
151
+ for obj, bin_ in target_bins.items():
152
+ display = OBJECT_COLORS.get(obj, obj.replace("_block", ""))
153
+ # Use bare display name for non-block objects (professional packs)
154
+ label = f"the {display} block" if obj.endswith("_block") else f"the {display}"
155
+ parts.append(f"{label} in bin {bin_}")
156
+
157
+ if len(parts) == 1:
158
+ base = f"Place {parts[0]}."
159
+ else:
160
+ base = "Place " + ", then ".join(parts) + "."
161
+
162
+ if constraint == "fragile_first":
163
+ base += " Handle fragile items first."
164
+ elif constraint == "heavy_last":
165
+ base += " Move heavy items last."
166
+ elif constraint == "urgent_first":
167
+ base += " Prioritize urgent items first."
168
+
169
+ if deadlines:
170
+ for obj, step in deadlines.items():
171
+ display = OBJECT_COLORS.get(obj, obj.replace("_block", ""))
172
+ label = f"the {display} block" if obj.endswith("_block") else f"the {display}"
173
+ base += f" Place {label} by step {step}."
174
+
175
+ if hidden_traits:
176
+ base += " Some object traits are hidden until you inspect the scene."
177
+
178
+ return base
179
# Object-name packs selectable by name. The professional task skins share the
# default pack's mechanics and differ only in domain-appropriate names.
SCENARIO_PACKS = dict(
    default=OBJECT_NAMES,
    warehouse=[
        "fragile_package",
        "heavy_pallet",
        "urgent_parcel",
        "standard_box",
        "hazmat_drum",
    ],
    pharmacy=[
        "morphine_vial",
        "saline_bag",
        "insulin_pen",
        "blood_sample",
        "contrast_agent",
    ],
    lab=[
        "reagent_alpha",
        "catalyst_beta",
        "sample_gamma",
        "solvent_delta",
        "enzyme_epsilon",
    ],
)

# Color/display name for each object in each pack.
# Warehouse and pharmacy objects display as their name with spaces for underscores.
OBJECT_COLORS.update(
    (obj, obj.replace("_", " "))
    for obj in SCENARIO_PACKS["warehouse"] + SCENARIO_PACKS["pharmacy"]
)
# Lab objects display with a Greek-letter suffix instead of the spelled-out word.
OBJECT_COLORS.update(
    reagent_alpha="reagent-α",
    catalyst_beta="catalyst-β",
    sample_gamma="sample-γ",
    solvent_delta="solvent-δ",
    enzyme_epsilon="enzyme-ε",
)