cccat6 committed on
Commit
ccf9f1b
·
verified ·
1 Parent(s): dd3258b

Update FlowMo-WM code and static flow protocol

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +3 -0
  2. .gitignore +36 -0
  3. .hfignore +15 -0
  4. README.md +3 -4
  5. data/paper/dataset_card.md +33 -14
  6. data/paper/diagnostic_seen_flow.npz +0 -3
  7. data/paper/generation_config.json +9 -14
  8. data/paper/test_unseen_boat_params.npz +0 -3
  9. data/paper/test_unseen_flow.npz +0 -3
  10. data/paper/train.npz +0 -3
  11. driftwm/data/generate.py +11 -19
  12. driftwm/sim/env.py +0 -1
  13. driftwm/sim/flow.py +195 -123
  14. driftwm/sim/sanity.py +1 -1
  15. experiments/EXPERIMENT_MATRIX.md +25 -10
  16. experiments/README.md +4 -2
  17. experiments/TASK_PLAN.md +22 -9
  18. experiments/docs/EXPERIMENT_PROTOCOL.md +31 -13
  19. experiments/evaluate_flowmo_latent_probes.py +1 -3
  20. experiments/evaluate_image_planning.py +13 -27
  21. experiments/evaluate_image_world_models.py +1 -1
  22. experiments/flowmo/checkpoint/paper.pt +0 -3
  23. experiments/flowmo/checkpoint/paper_step_002000.pt +0 -3
  24. experiments/flowmo/checkpoint/paper_step_004000.pt +0 -3
  25. experiments/flowmo/checkpoint/paper_step_006000.pt +0 -3
  26. experiments/flowmo/checkpoint/paper_step_008000.pt +0 -3
  27. experiments/flowmo/checkpoint/paper_step_010000.pt +0 -3
  28. experiments/flowmo/checkpoint/paper_step_012000.pt +0 -3
  29. experiments/flowmo/checkpoint/paper_step_014000.pt +0 -3
  30. experiments/flowmo/checkpoint/paper_step_016000.pt +0 -3
  31. experiments/flowmo/checkpoint/paper_step_018000.pt +0 -3
  32. experiments/flowmo/checkpoint/paper_step_020000.pt +0 -3
  33. experiments/flowmo/result/paper_training.json +0 -43
  34. experiments/flowmo/result/paper_training_trace.jsonl +0 -100
  35. experiments/flowmo/result/parameter_count.json +0 -11
  36. experiments/leworldmodel/checkpoint/paper.pt +0 -3
  37. experiments/leworldmodel/checkpoint/paper_step_002000.pt +0 -3
  38. experiments/leworldmodel/checkpoint/paper_step_004000.pt +0 -3
  39. experiments/leworldmodel/checkpoint/paper_step_006000.pt +0 -3
  40. experiments/leworldmodel/checkpoint/paper_step_008000.pt +0 -3
  41. experiments/leworldmodel/checkpoint/paper_step_010000.pt +0 -3
  42. experiments/leworldmodel/checkpoint/paper_step_012000.pt +0 -3
  43. experiments/leworldmodel/checkpoint/paper_step_014000.pt +0 -3
  44. experiments/leworldmodel/checkpoint/paper_step_016000.pt +0 -3
  45. experiments/leworldmodel/checkpoint/paper_step_018000.pt +0 -3
  46. experiments/leworldmodel/checkpoint/paper_step_020000.pt +0 -3
  47. experiments/leworldmodel/result/paper_training.json +0 -43
  48. experiments/leworldmodel/result/paper_training_trace.jsonl +0 -100
  49. experiments/leworldmodel/result/parameter_count.json +0 -7
  50. experiments/planet/checkpoint/paper.pt +0 -3
.gitattributes CHANGED
@@ -259,3 +259,6 @@ experiments/reports/paper_planning/gifs/image_planning_flowmo_inferred_triangle_
259
  experiments/reports/paper_planning/gifs/image_planning_tdmpc2_inferred_triangle_passive_to_active_ep000.gif filter=lfs diff=lfs merge=lfs -text
260
  experiments/reports/paper_planning/gifs/image_planning_pid_los_controller_inferred_triangle_reach_uniform_ep000.gif filter=lfs diff=lfs merge=lfs -text
261
  experiments/reports/paper_planning/gifs/image_planning_tdmpc2_inferred_triangle_counterflow_ep002.gif filter=lfs diff=lfs merge=lfs -text
 
 
 
 
259
  experiments/reports/paper_planning/gifs/image_planning_tdmpc2_inferred_triangle_passive_to_active_ep000.gif filter=lfs diff=lfs merge=lfs -text
260
  experiments/reports/paper_planning/gifs/image_planning_pid_los_controller_inferred_triangle_reach_uniform_ep000.gif filter=lfs diff=lfs merge=lfs -text
261
  experiments/reports/paper_planning/gifs/image_planning_tdmpc2_inferred_triangle_counterflow_ep002.gif filter=lfs diff=lfs merge=lfs -text
262
+ experiments/reports/figures/flow_family_atlas.png filter=lfs diff=lfs merge=lfs -text
263
+ experiments/reports/figures/flow_family_panels/clean/random_fourier.png filter=lfs diff=lfs merge=lfs -text
264
+ experiments/reports/figures/flow_family_panels/labeled/random_fourier.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.pyc
3
+ .pytest_cache/
4
+
5
+ # Legacy and generated experiment artifacts
6
+ outputs/
7
+ logs/
8
+ external/
9
+ configs/
10
+ scripts/
11
+ docs/
12
+
13
+ # Large generated datasets and image caches
14
+ data/*.npz
15
+ !data/paper/
16
+ !data/paper/*.npz
17
+ experiments/shared/result/image_cache*/
18
+ experiments/shared/result/image_observation_sweep/
19
+ experiments/shared/result/image_scale_sweep/
20
+
21
+ # Checkpoints are generated by smoke and full training runs
22
+ experiments/*/checkpoint/*.pt
23
+ experiments/*/result/*.json
24
+ experiments/*/result/*.jsonl
25
+
26
+ # Per-run planning GIF/JSON outputs are regenerated by the pipeline
27
+ experiments/reports/image_paper_160_v2p5_planning/
28
+ experiments/reports/image_paper_160_v2p5_planning_smoke/
29
+ experiments/reports/*.json
30
+ experiments/reports/*.md
31
+ !experiments/reports/README.md
32
+ experiments/gifs/*.gif
33
+ experiments/figures/*.png
34
+ experiments/figures/*.pdf
35
+ experiments/tables/*.md
36
+ !experiments/tables/README.md
.hfignore CHANGED
@@ -8,6 +8,21 @@ outputs/
8
  external/
9
  flowmo_remote_outputs/
10
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  experiments/shared/result/image_cache*/
12
  experiments/shared/result/image_observation_sweep/
13
  experiments/shared/result/image_scale_sweep/
 
8
  external/
9
  flowmo_remote_outputs/
10
 
11
+ data/*.npz
12
+ data/paper/*.npz
13
+ experiments/*/checkpoint/*.pt
14
+ experiments/*/result/*.json
15
+ experiments/*/result/*.jsonl
16
+ experiments/reports/*.json
17
+ experiments/reports/*.md
18
+ !experiments/reports/README.md
19
+ experiments/reports/paper_planning/
20
+ experiments/gifs/*.gif
21
+ experiments/figures/*.png
22
+ experiments/figures/*.pdf
23
+ experiments/tables/*.md
24
+ !experiments/tables/README.md
25
+
26
  experiments/shared/result/image_cache*/
27
  experiments/shared/result/image_observation_sweep/
28
  experiments/shared/result/image_scale_sweep/
README.md CHANGED
@@ -13,7 +13,7 @@ tags:
13
 
14
  FlowMo is a clean-image world-model benchmark for surface vehicles under hidden water drift. The proposed model separates short-history endogenous state and momentum from long-history exogenous drift context, then evaluates whether that factorization improves rollout prediction and closed-loop planning.
15
 
16
- This repository currently contains the public code, tests, configuration, and canonical paper datasets. Official checkpoints, generated GIFs, tables, and full experiment reports will be uploaded after the paper-scale training and evaluation runs finish.
17
 
18
  ## Paper Pipeline
19
 
@@ -26,15 +26,14 @@ python -m experiments.run_paper_image_pipeline
26
  The default command trains all learned world models, evaluates prediction, runs FlowMo latent probes, evaluates planning on all configured tasks and boat morphologies, generates GIFs, and writes:
27
 
28
  ```text
29
- experiments/reports/paper_prediction_seen_flow_diagnostic.json
30
- experiments/reports/paper_prediction_unseen_flow.json
31
- experiments/reports/paper_prediction_unseen_boat_params.json
32
  experiments/reports/paper_flowmo_latent_probes.json
33
  experiments/reports/paper_planning/
34
  experiments/reports/paper_report.md
35
  ```
36
 
37
  Images are rendered online from simulator states. Model inputs are clean top-down RGB frames with no flow arrows, no goal markers, no velocity vectors, and no trajectory overlays.
 
38
 
39
  ## Compared Methods
40
 
 
13
 
14
  FlowMo is a clean-image world-model benchmark for surface vehicles under hidden water drift. The proposed model separates short-history endogenous state and momentum from long-history exogenous drift context, then evaluates whether that factorization improves rollout prediction and closed-loop planning.
15
 
16
+ This repository contains the public code, tests, configuration, canonical paper datasets, checkpoints, generated GIFs, tables, and experiment reports.
17
 
18
  ## Paper Pipeline
19
 
 
26
  The default command trains all learned world models, evaluates prediction, runs FlowMo latent probes, evaluates planning on all configured tasks and boat morphologies, generates GIFs, and writes:
27
 
28
  ```text
29
+ experiments/reports/paper_prediction.json
 
 
30
  experiments/reports/paper_flowmo_latent_probes.json
31
  experiments/reports/paper_planning/
32
  experiments/reports/paper_report.md
33
  ```
34
 
35
  Images are rendered online from simulator states. Model inputs are clean top-down RGB frames with no flow arrows, no goal markers, no velocity vectors, and no trajectory overlays.
36
+ The train split, test split, and final planning evaluation use the same paper flow-family set.
37
 
38
  ## Compared Methods
39
 
data/paper/dataset_card.md CHANGED
@@ -9,39 +9,58 @@ suffixes; when the dataset is regenerated, these files are replaced in place.
9
  | File | Role |
10
  | --- | --- |
11
  | `train.npz` | Training split shared by all learned world models. |
12
- | `test_unseen_flow.npz` | Primary split with unseen flow families. |
13
- | `test_unseen_boat_params.npz` | Primary split with unseen boat dynamics. |
14
- | `diagnostic_seen_flow.npz` | Seen-flow-family diagnostic split used only for optimization sanity checks. |
15
 
16
  ## Sizes
17
 
18
  | File | Episodes | Steps per episode |
19
  | --- | ---: | ---: |
20
  | `train.npz` | 2400 | 300 |
21
- | `diagnostic_seen_flow.npz` | 480 | 300 |
22
- | `test_unseen_flow.npz` | 480 | 300 |
23
- | `test_unseen_boat_params.npz` | 480 | 300 |
24
 
25
  ## Stored Arrays
26
 
27
  Each `.npz` stores low-dimensional simulator state and metadata. Image-input
28
  models receive clean rendered images generated online from the same states.
29
  The image observation contains only the boat and clean workspace; flow vectors,
30
- velocity arrows, and visualization overlays are not part of the model input.
 
31
 
32
  All learned world models use the same split files, the same window sampling
33
  rules, the same image renderer, and the same train/evaluation budgets.
34
 
35
  ## Flow Families
36
 
37
- The training split and seen-flow-family diagnostic split use `noflow`, `uniform`,
38
- `slowly_varying`, `vortex_center`, `gradient`, and `turbulent_patch` flows.
39
- The unseen-flow split uses `noflow`, `shear`, `moving_vortex`, and
40
- `random_fourier` flows. The unseen-boat-dynamics split uses the training flow
41
- families with held-out boat mass, drag, inertia, and actuator-delay ranges.
42
 
43
- All splits use the fixed paper flow-strength constants in `driftwm/sim/flow.py`,
44
- both boat morphologies, and clean image observations without flow overlays.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  `flow_pool_size=80` means that each nonzero flow family is represented by 80
46
  hidden flow conditions before trajectories are sampled; it is a dataset
47
  diversity constant, not a model input or experiment mode.
 
 
 
 
 
 
 
 
 
 
 
 
9
  | File | Role |
10
  | --- | --- |
11
  | `train.npz` | Training split shared by all learned world models. |
12
+ | `test.npz` | Evaluation split shared by prediction, probes, and downstream planning protocol design. |
 
 
13
 
14
  ## Sizes
15
 
16
  | File | Episodes | Steps per episode |
17
  | --- | ---: | ---: |
18
  | `train.npz` | 2400 | 300 |
19
+ | `test.npz` | 480 | 300 |
 
 
20
 
21
  ## Stored Arrays
22
 
23
  Each `.npz` stores low-dimensional simulator state and metadata. Image-input
24
  models receive clean rendered images generated online from the same states.
25
  The image observation contains only the boat and clean workspace; flow vectors,
26
+ velocity arrows, goal markers, and visualization overlays are not part of the
27
+ model input.
28
 
29
  All learned world models use the same split files, the same window sampling
30
  rules, the same image renderer, and the same train/evaluation budgets.
31
 
32
  ## Flow Families
33
 
34
+ The train split, test split, and final planning evaluation use the same paper
35
+ flow-family set:
 
 
 
36
 
37
+ ```text
38
+ noflow
39
+ uniform
40
+ vortex_center
41
+ double_gyre
42
+ source_sink
43
+ source_sink_pair
44
+ gradient
45
+ shear
46
+ turbulent_patch
47
+ random_fourier
48
+ ```
49
+
50
+ The test split uses independently sampled episodes and hidden flow conditions,
51
+ but the paper reports a single flow-regime evaluation rather than separate
52
+ flow-family categories.
53
  `flow_pool_size=80` means that each nonzero flow family is represented by 80
54
  hidden flow conditions before trajectories are sampled; it is a dataset
55
  diversity constant, not a model input or experiment mode.
56
+
57
+ `uniform` is spatially constant. The other nonzero families contain spatial
58
+ structure: shear, affine gradients, fixed vortices, explicit double-gyre
59
+ recirculation, source/sink radial currents, localized turbulent patches, and
60
+ divergence-free random Fourier currents. All paper flow fields are static:
61
+ their velocity at a fixed position is independent of time.
62
+
63
+ Localized structures are sampled near common task routes and waypoint corridors
64
+ rather than uniformly over the whole workspace. This keeps hidden-flow
65
+ variation relevant to the boat trajectories instead of placing most structure
66
+ in regions the vehicle rarely visits.
data/paper/diagnostic_seen_flow.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e8366dd25e98b81b3b159a600e4e158b8b719e0ad41a0e3fbc915b6af34ce924
3
- size 6601980
 
 
 
 
data/paper/generation_config.json CHANGED
@@ -3,13 +3,17 @@
3
  "twin",
4
  "triangle"
5
  ],
6
- "train_flow_types": [
7
  "noflow",
8
  "uniform",
9
- "slowly_varying",
10
  "vortex_center",
 
 
 
11
  "gradient",
12
- "turbulent_patch"
 
 
13
  ],
14
  "trajectory_types": [
15
  "noflow_random_action",
@@ -20,18 +24,14 @@
20
  ],
21
  "episodes": {
22
  "train": 2400,
23
- "diagnostic_seen_flow": 480,
24
- "test_unseen_flow": 480,
25
- "test_unseen_boat_params": 480
26
  },
27
  "steps": 300,
28
  "flow_pool_size": 80,
29
  "boundary": "terminate",
30
  "seeds": {
31
  "train": 4301,
32
- "diagnostic_seen_flow": 4302,
33
- "test_unseen_flow": 4303,
34
- "test_unseen_boat_params": 4304
35
  },
36
  "image_size": 160,
37
  "visual_scale": 2.5,
@@ -40,10 +40,5 @@
40
  10.0,
41
  0.0,
42
  10.0
43
- ],
44
- "unseen_flow_types": [
45
- "shear",
46
- "moving_vortex",
47
- "random_fourier"
48
  ]
49
  }
 
3
  "twin",
4
  "triangle"
5
  ],
6
+ "flow_families": [
7
  "noflow",
8
  "uniform",
 
9
  "vortex_center",
10
+ "double_gyre",
11
+ "source_sink",
12
+ "source_sink_pair",
13
  "gradient",
14
+ "shear",
15
+ "turbulent_patch",
16
+ "random_fourier"
17
  ],
18
  "trajectory_types": [
19
  "noflow_random_action",
 
24
  ],
25
  "episodes": {
26
  "train": 2400,
27
+ "test": 480
 
 
28
  },
29
  "steps": 300,
30
  "flow_pool_size": 80,
31
  "boundary": "terminate",
32
  "seeds": {
33
  "train": 4301,
34
+ "test": 4302
 
 
35
  },
36
  "image_size": 160,
37
  "visual_scale": 2.5,
 
40
  10.0,
41
  0.0,
42
  10.0
 
 
 
 
 
43
  ]
44
  }
data/paper/test_unseen_boat_params.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fa4161b8917935a7a159f3e1604d4551c455f55160d8dd83b234bd726fc154a4
3
- size 6400836
 
 
 
 
data/paper/test_unseen_flow.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3beab06303edd2a81c802972d1055fed3ceffa54ac90f4686b11b08479682b87
3
- size 6493086
 
 
 
 
data/paper/train.npz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:3632528ec850fd526e028f49bf1641d8f9f9399dc6c07e495b9351df29f76f2a
3
- size 32273806
 
 
 
 
driftwm/data/generate.py CHANGED
@@ -17,13 +17,13 @@ ID_TO_BOAT = {v: k for k, v in BOAT_TO_ID.items()}
17
  FLOW_TO_ID = {
18
  "noflow": 0,
19
  "uniform": 1,
20
- "slowly_varying": 2,
21
- "vortex": 3,
22
- "vortex_center": 4,
23
- "gradient": 5,
24
- "turbulent_patch": 6,
25
  "shear": 7,
26
- "moving_vortex": 8,
27
  "random_fourier": 9,
28
  }
29
  ID_TO_FLOW = {v: k for k, v in FLOW_TO_ID.items()}
@@ -103,7 +103,6 @@ def generate_dataset(
103
  boundary: str = "terminate",
104
  randomize_params: bool = True,
105
  flow_pool_size: int = PAPER_FLOW_POOL_SIZE,
106
- unseen_boat_params: bool = False,
107
  ) -> None:
108
  rng = np.random.default_rng(seed)
109
  out = Path(out)
@@ -134,7 +133,9 @@ def generate_dataset(
134
  if traj_type.startswith("noflow"):
135
  flow_type = "noflow"
136
  else:
137
- available = [ft for ft in flow_types if ft != "noflow"] or ["uniform"]
 
 
138
  flow_type = available[int(rng.integers(0, len(available)))]
139
  flow_template = flow_pool[flow_type][int(rng.integers(0, len(flow_pool[flow_type])))]
140
  flow = copy.deepcopy(flow_template)
@@ -146,14 +147,6 @@ def generate_dataset(
146
  random_velocity=random_velocity,
147
  randomize_params=randomize_params,
148
  )
149
- if unseen_boat_params:
150
- for key in list(env.params):
151
- if key in {"mass", "inertia", "actuator_tau"}:
152
- factor = rng.choice([rng.uniform(0.55, 0.72), rng.uniform(1.45, 1.85)])
153
- else:
154
- factor = rng.choice([rng.uniform(0.45, 0.70), rng.uniform(1.45, 1.90)])
155
- env.params[key] = float(env.params[key] * factor)
156
-
157
  if traj_type == "noflow_random_action":
158
  planned_actions = smooth_random_actions(rng, steps, env.action_dim, scale=1.0)
159
  elif traj_type == "noflow_action_then_zero":
@@ -214,7 +207,6 @@ def generate_dataset(
214
  "seed": seed,
215
  "max_action_dim": 3,
216
  "flow_pool_size": flow_pool_size,
217
- "unseen_boat_params": bool(unseen_boat_params),
218
  }
219
  np.savez_compressed(
220
  out,
@@ -240,8 +232,8 @@ def main() -> None:
240
  parser.add_argument("--out", required=True)
241
  parser.add_argument("--seed", type=int, default=0)
242
  parser.add_argument("--boundary", choices=["terminate", "bounce", "clip"], default="terminate")
 
243
  parser.add_argument("--no-randomize-params", action="store_true")
244
- parser.add_argument("--unseen-boat-params", action="store_true")
245
  args = parser.parse_args()
246
  generate_dataset(
247
  boats=args.boats,
@@ -251,8 +243,8 @@ def main() -> None:
251
  out=args.out,
252
  seed=args.seed,
253
  boundary=args.boundary,
 
254
  randomize_params=not args.no_randomize_params,
255
- unseen_boat_params=args.unseen_boat_params,
256
  )
257
 
258
 
 
17
  FLOW_TO_ID = {
18
  "noflow": 0,
19
  "uniform": 1,
20
+ "vortex_center": 2,
21
+ "double_gyre": 3,
22
+ "source_sink": 4,
23
+ "source_sink_pair": 5,
24
+ "gradient": 6,
25
  "shear": 7,
26
+ "turbulent_patch": 8,
27
  "random_fourier": 9,
28
  }
29
  ID_TO_FLOW = {v: k for k, v in FLOW_TO_ID.items()}
 
103
  boundary: str = "terminate",
104
  randomize_params: bool = True,
105
  flow_pool_size: int = PAPER_FLOW_POOL_SIZE,
 
106
  ) -> None:
107
  rng = np.random.default_rng(seed)
108
  out = Path(out)
 
133
  if traj_type.startswith("noflow"):
134
  flow_type = "noflow"
135
  else:
136
+ available = [ft for ft in flow_types if ft != "noflow"]
137
+ if not available:
138
+ raise ValueError("flow_types must include at least one nonzero flow family")
139
  flow_type = available[int(rng.integers(0, len(available)))]
140
  flow_template = flow_pool[flow_type][int(rng.integers(0, len(flow_pool[flow_type])))]
141
  flow = copy.deepcopy(flow_template)
 
147
  random_velocity=random_velocity,
148
  randomize_params=randomize_params,
149
  )
 
 
 
 
 
 
 
 
150
  if traj_type == "noflow_random_action":
151
  planned_actions = smooth_random_actions(rng, steps, env.action_dim, scale=1.0)
152
  elif traj_type == "noflow_action_then_zero":
 
207
  "seed": seed,
208
  "max_action_dim": 3,
209
  "flow_pool_size": flow_pool_size,
 
210
  }
211
  np.savez_compressed(
212
  out,
 
232
  parser.add_argument("--out", required=True)
233
  parser.add_argument("--seed", type=int, default=0)
234
  parser.add_argument("--boundary", choices=["terminate", "bounce", "clip"], default="terminate")
235
+ parser.add_argument("--flow-pool-size", type=int, default=PAPER_FLOW_POOL_SIZE)
236
  parser.add_argument("--no-randomize-params", action="store_true")
 
237
  args = parser.parse_args()
238
  generate_dataset(
239
  boats=args.boats,
 
243
  out=args.out,
244
  seed=args.seed,
245
  boundary=args.boundary,
246
+ flow_pool_size=args.flow_pool_size,
247
  randomize_params=not args.no_randomize_params,
 
248
  )
249
 
250
 
driftwm/sim/env.py CHANGED
@@ -121,7 +121,6 @@ class SurfaceBoatEnv:
121
  self.config.workspace,
122
  self.config.boundary,
123
  )
124
- self.flow.step(self.config.dt, self.rng)
125
  self.t += 1
126
  self.time += self.config.dt
127
  timeout = self.t >= self.config.episode_steps
 
121
  self.config.workspace,
122
  self.config.boundary,
123
  )
 
124
  self.t += 1
125
  self.time += self.config.dt
126
  timeout = self.t >= self.config.episode_steps
driftwm/sim/flow.py CHANGED
@@ -9,20 +9,29 @@ import numpy as np
9
  PAPER_FLOW = {
10
  "uniform_min": 0.03,
11
  "uniform_max": 0.24,
12
- "slow_max": 0.26,
13
- "slow_noise": 0.0035,
14
- "vortex_base_max": 0.12,
15
- "vortex_gamma": 0.14,
16
- "vortex_max": 0.34,
17
- "gradient_base_min": 0.01,
18
- "gradient_base_max": 0.16,
19
- "gradient_matrix_std": 0.022,
20
  "gradient_max": 0.34,
21
- "turbulent_base_max": 0.12,
22
- "turbulent_vector_std": 0.075,
23
  "turbulent_max": 0.34,
 
 
24
  "shear_max": 0.38,
25
- "moving_vortex_max": 0.38,
 
 
 
 
 
 
 
26
  "random_fourier_max": 0.38,
27
  }
28
 
@@ -35,9 +44,6 @@ class Flow:
35
  def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
36
  raise NotImplementedError
37
 
38
- def step(self, dt: float, rng: np.random.Generator) -> None:
39
- return None
40
-
41
  def metadata(self) -> dict[str, Any]:
42
  return {"flow_type": self.name, "flow_id": int(self.flow_id)}
43
 
@@ -68,45 +74,6 @@ class UniformFlow(Flow):
68
  out["vector"] = self.vector.astype(float).tolist()
69
  return out
70
 
71
-
72
- @dataclass
73
- class SlowlyVaryingFlow(Flow):
74
- vector: np.ndarray
75
- rho: float
76
- noise_std: float
77
- max_speed: float
78
-
79
- def __init__(
80
- self,
81
- vector: np.ndarray,
82
- flow_id: int,
83
- rho: float = 0.995,
84
- noise_std: float = 0.005,
85
- max_speed: float = 0.35,
86
- ):
87
- super().__init__("slowly_varying", flow_id)
88
- self.vector = np.asarray(vector, dtype=np.float32)
89
- self.rho = float(rho)
90
- self.noise_std = float(noise_std)
91
- self.max_speed = float(max_speed)
92
-
93
- def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
94
- pos = np.asarray(pos, dtype=np.float32)
95
- return np.broadcast_to(self.vector, pos.shape).astype(np.float32)
96
-
97
- def step(self, dt: float, rng: np.random.Generator) -> None:
98
- noise = rng.normal(0.0, self.noise_std, size=2).astype(np.float32)
99
- self.vector = self.rho * self.vector + np.sqrt(max(0.0, 1.0 - self.rho**2)) * noise
100
- speed = float(np.linalg.norm(self.vector))
101
- if speed > self.max_speed:
102
- self.vector = self.vector / speed * self.max_speed
103
-
104
- def metadata(self) -> dict[str, Any]:
105
- out = super().metadata()
106
- out.update({"vector": self.vector.astype(float).tolist(), "rho": self.rho})
107
- return out
108
-
109
-
110
  @dataclass
111
  class VortexFlow(Flow):
112
  base: np.ndarray
@@ -178,7 +145,13 @@ class GradientFlow(Flow):
178
 
179
  def metadata(self) -> dict[str, Any]:
180
  out = super().metadata()
181
- out.update({"base": self.base.astype(float).tolist(), "matrix": self.matrix.astype(float).tolist()})
 
 
 
 
 
 
182
  return out
183
 
184
 
@@ -251,39 +224,91 @@ class ShearFlow(Flow):
251
 
252
 
253
  @dataclass
254
- class MovingVortexFlow(VortexFlow):
255
- center_velocity: np.ndarray
 
256
  workspace: tuple[float, float, float, float]
 
257
 
258
  def __init__(
259
  self,
260
- base: np.ndarray,
261
- center: np.ndarray,
262
- gamma: float,
263
- center_velocity: np.ndarray,
264
  flow_id: int,
265
  workspace: tuple[float, float, float, float],
266
- radius_eps: float = 0.35,
267
- max_speed: float = 0.60,
268
  ):
269
- super().__init__(base=base, center=center, gamma=gamma, flow_id=flow_id, radius_eps=radius_eps, max_speed=max_speed, name="moving_vortex")
270
- self.center_velocity = np.asarray(center_velocity, dtype=np.float32)
 
271
  self.workspace = workspace
 
272
 
273
- def step(self, dt: float, rng: np.random.Generator) -> None:
274
- del rng
275
- self.center = self.center + dt * self.center_velocity
276
  xmin, xmax, ymin, ymax = self.workspace
277
- if self.center[0] < xmin + 1.0 or self.center[0] > xmax - 1.0:
278
- self.center_velocity[0] *= -1.0
279
- if self.center[1] < ymin + 1.0 or self.center[1] > ymax - 1.0:
280
- self.center_velocity[1] *= -1.0
281
- self.center[0] = np.clip(self.center[0], xmin + 1.0, xmax - 1.0)
282
- self.center[1] = np.clip(self.center[1], ymin + 1.0, ymax - 1.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
  def metadata(self) -> dict[str, Any]:
285
  out = super().metadata()
286
- out["center_velocity"] = self.center_velocity.astype(float).tolist()
 
 
 
 
 
 
287
  return out
288
 
289
 
@@ -293,7 +318,6 @@ class RandomFourierFlow(Flow):
293
  k: np.ndarray
294
  amp: np.ndarray
295
  phase: np.ndarray
296
- temporal: np.ndarray
297
  max_speed: float
298
 
299
  def __init__(
@@ -302,7 +326,6 @@ class RandomFourierFlow(Flow):
302
  k: np.ndarray,
303
  amp: np.ndarray,
304
  phase: np.ndarray,
305
- temporal: np.ndarray,
306
  flow_id: int,
307
  max_speed: float = 0.60,
308
  ):
@@ -311,13 +334,12 @@ class RandomFourierFlow(Flow):
311
  self.k = np.asarray(k, dtype=np.float32)
312
  self.amp = np.asarray(amp, dtype=np.float32)
313
  self.phase = np.asarray(phase, dtype=np.float32)
314
- self.temporal = np.asarray(temporal, dtype=np.float32)
315
  self.max_speed = float(max_speed)
316
 
317
  def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
318
  pos = np.asarray(pos, dtype=np.float32)
319
  flat = pos.reshape(-1, 2)
320
- arg = flat @ self.k.T + self.phase[None, :] + float(t) * self.temporal[None, :]
321
  # Divergence-free field via stream function psi: v=(dpsi/dy, -dpsi/dx).
322
  coeff = self.amp[None, :] * np.cos(arg)
323
  vx = np.sum(coeff * self.k[None, :, 1], axis=1)
@@ -339,6 +361,59 @@ def _sample_uniform_vector(rng: np.random.Generator, min_speed: float = 0.05, ma
339
  return np.array([speed * np.cos(direction), speed * np.sin(direction)], dtype=np.float32)
340
 
341
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
342
  def sample_flow(
343
  flow_type: str,
344
  rng: np.random.Generator,
@@ -351,77 +426,74 @@ def sample_flow(
351
  return NoFlow(flow_id=0)
352
  if flow_type == "uniform":
353
  return UniformFlow(_sample_uniform_vector(rng, profile["uniform_min"], profile["uniform_max"]), flow_id=flow_id)
354
- if flow_type in {"slow", "slowly_varying", "ou"}:
355
- return SlowlyVaryingFlow(
356
- _sample_uniform_vector(rng, profile["uniform_min"], profile["uniform_max"]),
357
- flow_id=flow_id,
358
- noise_std=profile["slow_noise"],
359
- max_speed=profile["slow_max"],
360
- )
361
- if flow_type in {"vortex", "vortex_center"}:
362
- xmin, xmax, ymin, ymax = workspace
363
- if flow_type == "vortex_center":
364
- center = np.array([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0], dtype=np.float32)
365
- else:
366
- center = np.array([rng.uniform(xmin + 2.0, xmax - 2.0), rng.uniform(ymin + 2.0, ymax - 2.0)], dtype=np.float32)
367
  base = _sample_uniform_vector(rng, 0.0, profile["vortex_base_max"])
368
- gamma = float(rng.uniform(-profile["vortex_gamma"], profile["vortex_gamma"]))
369
  return VortexFlow(base=base, center=center, gamma=gamma, flow_id=flow_id, max_speed=profile["vortex_max"], name=flow_type)
370
- if flow_type in {"gradient", "gradient_flow"}:
371
- xmin, xmax, ymin, ymax = workspace
372
- center = np.array([(xmin + xmax) / 2.0, (ymin + ymax) / 2.0], dtype=np.float32)
373
  base = _sample_uniform_vector(rng, profile["gradient_base_min"], profile["gradient_base_max"])
374
- mat = rng.normal(0.0, profile["gradient_matrix_std"], size=(2, 2)).astype(np.float32)
 
 
375
  return GradientFlow(base=base, center=center, matrix=mat, flow_id=flow_id, max_speed=profile["gradient_max"])
376
- if flow_type in {"turbulent", "turbulent_patch", "patch"}:
377
- xmin, xmax, ymin, ymax = workspace
378
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
379
- centers = np.stack(
380
- [
381
- rng.uniform([xmin + 1.0, ymin + 1.0], [xmax - 1.0, ymax - 1.0])
382
- for _ in range(5)
383
- ],
384
- axis=0,
385
- ).astype(np.float32)
386
  vectors = rng.normal(0.0, profile["turbulent_vector_std"], size=(5, 2)).astype(np.float32)
387
  return TurbulentPatchFlow(base=base, centers=centers, vectors=vectors, flow_id=flow_id, max_speed=profile["turbulent_max"])
388
- if flow_type in {"shear", "shear_flow"}:
389
  xmin, xmax, ymin, ymax = workspace
390
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
391
  center_y = 0.5 * (ymin + ymax)
392
- shear = float(rng.uniform(-0.08, 0.08))
393
  return ShearFlow(base=base, center_y=center_y, shear=shear, flow_id=flow_id, max_speed=profile["shear_max"])
394
- if flow_type in {"moving_vortex", "moving-vortex"}:
395
- xmin, xmax, ymin, ymax = workspace
396
- center = np.array([rng.uniform(xmin + 2.0, xmax - 2.0), rng.uniform(ymin + 2.0, ymax - 2.0)], dtype=np.float32)
397
- base = _sample_uniform_vector(rng, 0.0, profile["vortex_base_max"])
398
- gamma = float(rng.uniform(-profile["vortex_gamma"], profile["vortex_gamma"]))
399
- center_velocity = _sample_uniform_vector(rng, 0.02, 0.08)
400
- return MovingVortexFlow(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
401
  base=base,
402
- center=center,
403
- gamma=gamma,
404
- center_velocity=center_velocity,
405
  flow_id=flow_id,
406
- workspace=workspace,
407
- max_speed=profile["moving_vortex_max"],
408
  )
409
- if flow_type in {"random_fourier", "fourier", "divfree"}:
410
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
411
  modes = 8
412
  k = rng.integers(1, 5, size=(modes, 2)).astype(np.float32)
413
  signs = rng.choice([-1.0, 1.0], size=(modes, 2)).astype(np.float32)
414
  k = signs * k * (2.0 * np.pi / 10.0)
415
- amp_std = 0.028
416
  amp = rng.normal(0.0, amp_std, size=(modes,)).astype(np.float32)
417
  phase = rng.uniform(0.0, 2.0 * np.pi, size=(modes,)).astype(np.float32)
418
- temporal = rng.normal(0.0, 0.12, size=(modes,)).astype(np.float32)
419
  return RandomFourierFlow(
420
  base=base,
421
  k=k,
422
  amp=amp,
423
  phase=phase,
424
- temporal=temporal,
425
  flow_id=flow_id,
426
  max_speed=profile["random_fourier_max"],
427
  )
 
9
  PAPER_FLOW = {
10
  "uniform_min": 0.03,
11
  "uniform_max": 0.24,
12
+ "vortex_base_max": 0.05,
13
+ "vortex_gamma_min": 0.12,
14
+ "vortex_gamma_max": 0.24,
15
+ "vortex_max": 0.36,
16
+ "gradient_base_min": 0.00,
17
+ "gradient_base_max": 0.08,
18
+ "gradient_matrix_min": 0.018,
19
+ "gradient_matrix_max": 0.040,
20
  "gradient_max": 0.34,
21
+ "turbulent_base_max": 0.05,
22
+ "turbulent_vector_std": 0.105,
23
  "turbulent_max": 0.34,
24
+ "shear_min": 0.035,
25
+ "shear_max_rate": 0.085,
26
  "shear_max": 0.38,
27
+ "double_gyre_amp_min": 0.20,
28
+ "double_gyre_amp_max": 0.34,
29
+ "double_gyre_max": 0.36,
30
+ "source_base_max": 0.03,
31
+ "source_strength_min": 0.16,
32
+ "source_strength_max": 0.30,
33
+ "source_max": 0.36,
34
+ "random_fourier_amp_std": 0.045,
35
  "random_fourier_max": 0.38,
36
  }
37
 
 
44
  def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
45
  raise NotImplementedError
46
 
 
 
 
47
  def metadata(self) -> dict[str, Any]:
48
  return {"flow_type": self.name, "flow_id": int(self.flow_id)}
49
 
 
74
  out["vector"] = self.vector.astype(float).tolist()
75
  return out
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  @dataclass
78
  class VortexFlow(Flow):
79
  base: np.ndarray
 
145
 
146
  def metadata(self) -> dict[str, Any]:
147
  out = super().metadata()
148
+ out.update(
149
+ {
150
+ "base": self.base.astype(float).tolist(),
151
+ "center": self.center.astype(float).tolist(),
152
+ "matrix": self.matrix.astype(float).tolist(),
153
+ }
154
+ )
155
  return out
156
 
157
 
 
224
 
225
 
226
  @dataclass
227
+ class DoubleGyreFlow(Flow):
228
+ amp: float
229
+ phase: float
230
  workspace: tuple[float, float, float, float]
231
+ max_speed: float
232
 
233
  def __init__(
234
  self,
235
+ amp: float,
236
+ phase: float,
 
 
237
  flow_id: int,
238
  workspace: tuple[float, float, float, float],
239
+ max_speed: float = 0.50,
 
240
  ):
241
+ super().__init__("double_gyre", flow_id)
242
+ self.amp = float(amp)
243
+ self.phase = float(phase)
244
  self.workspace = workspace
245
+ self.max_speed = float(max_speed)
246
 
247
+ def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
248
+ pos = np.asarray(pos, dtype=np.float32)
 
249
  xmin, xmax, ymin, ymax = self.workspace
250
+ x = (pos[..., 0] - xmin) / max(xmax - xmin, 1e-6)
251
+ y = (pos[..., 1] - ymin) / max(ymax - ymin, 1e-6)
252
+ sx = np.sin(np.pi * x + self.phase)
253
+ cx = np.cos(np.pi * x + self.phase)
254
+ sy = np.sin(2.0 * np.pi * y)
255
+ cy = np.cos(2.0 * np.pi * y)
256
+ vel = np.stack([self.amp * sx * cy, -0.5 * self.amp * cx * sy], axis=-1)
257
+ speed = np.linalg.norm(vel, axis=-1, keepdims=True)
258
+ scale = np.minimum(1.0, self.max_speed / np.maximum(speed, 1e-6))
259
+ return (vel * scale).astype(np.float32)
260
+
261
+ def metadata(self) -> dict[str, Any]:
262
+ out = super().metadata()
263
+ out.update({"amp": self.amp, "phase": self.phase})
264
+ return out
265
+
266
+
267
+ @dataclass
268
+ class SourceSinkFlow(Flow):
269
+ base: np.ndarray
270
+ centers: np.ndarray
271
+ strengths: np.ndarray
272
+ radius_eps: float
273
+ max_speed: float
274
+
275
+ def __init__(
276
+ self,
277
+ name: str,
278
+ base: np.ndarray,
279
+ centers: np.ndarray,
280
+ strengths: np.ndarray,
281
+ flow_id: int,
282
+ radius_eps: float = 0.45,
283
+ max_speed: float = 0.50,
284
+ ):
285
+ super().__init__(name, flow_id)
286
+ self.base = np.asarray(base, dtype=np.float32)
287
+ self.centers = np.asarray(centers, dtype=np.float32)
288
+ self.strengths = np.asarray(strengths, dtype=np.float32)
289
+ self.radius_eps = float(radius_eps)
290
+ self.max_speed = float(max_speed)
291
+
292
+ def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
293
+ pos = np.asarray(pos, dtype=np.float32)
294
+ rel = pos[..., None, :] - self.centers
295
+ denom = np.sum(rel * rel, axis=-1, keepdims=True) + self.radius_eps**2
296
+ strength_shape = (1,) * (rel.ndim - 2) + (self.strengths.shape[0], 1)
297
+ radial = np.sum(self.strengths.reshape(strength_shape) * rel / denom, axis=-2)
298
+ vel = self.base + radial
299
+ speed = np.linalg.norm(vel, axis=-1, keepdims=True)
300
+ scale = np.minimum(1.0, self.max_speed / np.maximum(speed, 1e-6))
301
+ return (vel * scale).astype(np.float32)
302
 
303
  def metadata(self) -> dict[str, Any]:
304
  out = super().metadata()
305
+ out.update(
306
+ {
307
+ "base": self.base.astype(float).tolist(),
308
+ "centers": self.centers.astype(float).tolist(),
309
+ "strengths": self.strengths.astype(float).tolist(),
310
+ }
311
+ )
312
  return out
313
 
314
 
 
318
  k: np.ndarray
319
  amp: np.ndarray
320
  phase: np.ndarray
 
321
  max_speed: float
322
 
323
  def __init__(
 
326
  k: np.ndarray,
327
  amp: np.ndarray,
328
  phase: np.ndarray,
 
329
  flow_id: int,
330
  max_speed: float = 0.60,
331
  ):
 
334
  self.k = np.asarray(k, dtype=np.float32)
335
  self.amp = np.asarray(amp, dtype=np.float32)
336
  self.phase = np.asarray(phase, dtype=np.float32)
 
337
  self.max_speed = float(max_speed)
338
 
339
  def velocity(self, pos: np.ndarray, t: float = 0.0) -> np.ndarray:
340
  pos = np.asarray(pos, dtype=np.float32)
341
  flat = pos.reshape(-1, 2)
342
+ arg = flat @ self.k.T + self.phase[None, :]
343
  # Divergence-free field via stream function psi: v=(dpsi/dy, -dpsi/dx).
344
  coeff = self.amp[None, :] * np.cos(arg)
345
  vx = np.sum(coeff * self.k[None, :, 1], axis=1)
 
361
  return np.array([speed * np.cos(direction), speed * np.sin(direction)], dtype=np.float32)
362
 
363
 
364
+ def _sample_signed_uniform(rng: np.random.Generator, min_abs: float, max_abs: float) -> float:
365
+ sign = -1.0 if rng.random() < 0.5 else 1.0
366
+ return float(sign * rng.uniform(min_abs, max_abs))
367
+
368
+
369
+ def _route_anchors(workspace: tuple[float, float, float, float]) -> np.ndarray:
370
+ xmin, xmax, ymin, ymax = workspace
371
+ w = xmax - xmin
372
+ h = ymax - ymin
373
+ points = np.array(
374
+ [
375
+ [0.20, 0.20],
376
+ [0.35, 0.35],
377
+ [0.50, 0.50],
378
+ [0.65, 0.65],
379
+ [0.80, 0.80],
380
+ [0.20, 0.80],
381
+ [0.35, 0.65],
382
+ [0.50, 0.50],
383
+ [0.65, 0.35],
384
+ [0.80, 0.20],
385
+ [0.50, 0.25],
386
+ [0.50, 0.75],
387
+ [0.25, 0.50],
388
+ [0.75, 0.50],
389
+ ],
390
+ dtype=np.float32,
391
+ )
392
+ points[:, 0] = xmin + points[:, 0] * w
393
+ points[:, 1] = ymin + points[:, 1] * h
394
+ return points
395
+
396
+
397
+ def _sample_route_center(
398
+ rng: np.random.Generator,
399
+ workspace: tuple[float, float, float, float],
400
+ jitter: float = 0.65,
401
+ ) -> np.ndarray:
402
+ xmin, xmax, ymin, ymax = workspace
403
+ anchors = _route_anchors(workspace)
404
+ center = anchors[int(rng.integers(0, len(anchors)))] + rng.normal(0.0, jitter, size=2).astype(np.float32)
405
+ return np.array([np.clip(center[0], xmin + 1.2, xmax - 1.2), np.clip(center[1], ymin + 1.2, ymax - 1.2)], dtype=np.float32)
406
+
407
+
408
+ def _sample_route_centers(
409
+ rng: np.random.Generator,
410
+ workspace: tuple[float, float, float, float],
411
+ count: int,
412
+ jitter: float = 0.65,
413
+ ) -> np.ndarray:
414
+ return np.stack([_sample_route_center(rng, workspace, jitter=jitter) for _ in range(count)], axis=0).astype(np.float32)
415
+
416
+
417
  def sample_flow(
418
  flow_type: str,
419
  rng: np.random.Generator,
 
426
  return NoFlow(flow_id=0)
427
  if flow_type == "uniform":
428
  return UniformFlow(_sample_uniform_vector(rng, profile["uniform_min"], profile["uniform_max"]), flow_id=flow_id)
429
+ if flow_type == "vortex_center":
430
+ center = _sample_route_center(rng, workspace, jitter=0.45)
 
 
 
 
 
 
 
 
 
 
 
431
  base = _sample_uniform_vector(rng, 0.0, profile["vortex_base_max"])
432
+ gamma = _sample_signed_uniform(rng, profile["vortex_gamma_min"], profile["vortex_gamma_max"])
433
  return VortexFlow(base=base, center=center, gamma=gamma, flow_id=flow_id, max_speed=profile["vortex_max"], name=flow_type)
434
+ if flow_type == "gradient":
435
+ center = _sample_route_center(rng, workspace, jitter=0.35)
 
436
  base = _sample_uniform_vector(rng, profile["gradient_base_min"], profile["gradient_base_max"])
437
+ scale = rng.uniform(profile["gradient_matrix_min"], profile["gradient_matrix_max"])
438
+ mat = rng.normal(0.0, 1.0, size=(2, 2)).astype(np.float32)
439
+ mat = mat / max(float(np.linalg.norm(mat)), 1e-6) * scale
440
  return GradientFlow(base=base, center=center, matrix=mat, flow_id=flow_id, max_speed=profile["gradient_max"])
441
+ if flow_type == "turbulent_patch":
 
442
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
443
+ centers = _sample_route_centers(rng, workspace, count=5, jitter=0.85)
 
 
 
 
 
 
444
  vectors = rng.normal(0.0, profile["turbulent_vector_std"], size=(5, 2)).astype(np.float32)
445
  return TurbulentPatchFlow(base=base, centers=centers, vectors=vectors, flow_id=flow_id, max_speed=profile["turbulent_max"])
446
+ if flow_type == "shear":
447
  xmin, xmax, ymin, ymax = workspace
448
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
449
  center_y = 0.5 * (ymin + ymax)
450
+ shear = _sample_signed_uniform(rng, profile["shear_min"], profile["shear_max_rate"])
451
  return ShearFlow(base=base, center_y=center_y, shear=shear, flow_id=flow_id, max_speed=profile["shear_max"])
452
+ if flow_type == "double_gyre":
453
+ amp = float(rng.uniform(profile["double_gyre_amp_min"], profile["double_gyre_amp_max"]))
454
+ phase = float(rng.uniform(0.0, 2.0 * np.pi))
455
+ return DoubleGyreFlow(amp=amp, phase=phase, flow_id=flow_id, workspace=workspace, max_speed=profile["double_gyre_max"])
456
+ if flow_type == "source_sink":
457
+ base = _sample_uniform_vector(rng, 0.0, profile["source_base_max"])
458
+ center = _sample_route_centers(rng, workspace, count=1, jitter=0.65)
459
+ strength = np.array([_sample_signed_uniform(rng, profile["source_strength_min"], profile["source_strength_max"])], dtype=np.float32)
460
+ return SourceSinkFlow(
461
+ name="source_sink",
462
+ base=base,
463
+ centers=center,
464
+ strengths=strength,
465
+ flow_id=flow_id,
466
+ max_speed=profile["source_max"],
467
+ )
468
+ if flow_type == "source_sink_pair":
469
+ base = _sample_uniform_vector(rng, 0.0, profile["source_base_max"])
470
+ centers = _sample_route_centers(rng, workspace, count=2, jitter=0.75)
471
+ strength = float(rng.uniform(profile["source_strength_min"], profile["source_strength_max"]))
472
+ if rng.random() < 0.5:
473
+ strength = -strength
474
+ strengths = np.array([strength, -strength], dtype=np.float32)
475
+ return SourceSinkFlow(
476
+ name="source_sink_pair",
477
  base=base,
478
+ centers=centers,
479
+ strengths=strengths,
 
480
  flow_id=flow_id,
481
+ max_speed=profile["source_max"],
 
482
  )
483
+ if flow_type == "random_fourier":
484
  base = _sample_uniform_vector(rng, 0.0, profile["turbulent_base_max"])
485
  modes = 8
486
  k = rng.integers(1, 5, size=(modes, 2)).astype(np.float32)
487
  signs = rng.choice([-1.0, 1.0], size=(modes, 2)).astype(np.float32)
488
  k = signs * k * (2.0 * np.pi / 10.0)
489
+ amp_std = profile["random_fourier_amp_std"]
490
  amp = rng.normal(0.0, amp_std, size=(modes,)).astype(np.float32)
491
  phase = rng.uniform(0.0, 2.0 * np.pi, size=(modes,)).astype(np.float32)
 
492
  return RandomFourierFlow(
493
  base=base,
494
  k=k,
495
  amp=amp,
496
  phase=phase,
 
497
  flow_id=flow_id,
498
  max_speed=profile["random_fourier_max"],
499
  )
driftwm/sim/sanity.py CHANGED
@@ -68,7 +68,7 @@ def run_sanity(
68
  def main() -> None:
69
  parser = argparse.ArgumentParser()
70
  parser.add_argument("--boat", choices=["twin", "triangle"], default="twin")
71
- parser.add_argument("--flow", choices=["noflow", "uniform", "slowly_varying", "vortex"], default="noflow")
72
  parser.add_argument("--scenario", choices=["auto", "slide", "drift", "thruster", "random"], default="auto")
73
  parser.add_argument("--steps", type=int, default=200)
74
  parser.add_argument("--seed", type=int, default=0)
 
68
  def main() -> None:
69
  parser = argparse.ArgumentParser()
70
  parser.add_argument("--boat", choices=["twin", "triangle"], default="twin")
71
+ parser.add_argument("--flow", choices=["noflow", "uniform", "vortex_center", "double_gyre", "source_sink", "source_sink_pair", "gradient", "shear", "turbulent_patch", "random_fourier"], default="noflow")
72
  parser.add_argument("--scenario", choices=["auto", "slide", "drift", "thruster", "random"], default="auto")
73
  parser.add_argument("--steps", type=int, default=200)
74
  parser.add_argument("--seed", type=int, default=0)
experiments/EXPERIMENT_MATRIX.md CHANGED
@@ -12,14 +12,16 @@ Image size: 160 x 160
12
  Visual scale: 2.5
13
  Forbidden image cues: flow arrows, velocity vectors, trajectory overlays, goal marker
14
  Train split: data/paper/train.npz
15
- Primary unseen-flow split: data/paper/test_unseen_flow.npz
16
- Primary unseen-boat-dynamics split: data/paper/test_unseen_boat_params.npz
17
- Diagnostic seen-flow-family split: data/paper/diagnostic_seen_flow.npz
18
  Config: experiments/shared/config/paper_image.json
19
  Checkpoint: paper.pt
20
  Intermediate checkpoints: paper_step_XXXXXX.pt
21
  ```
22
 
 
 
 
23
  Formal training budget:
24
 
25
  ```text
@@ -32,6 +34,8 @@ steps: 20000
32
  checkpoint_interval: 2000
33
  num_workers: 4
34
  render_mode: device
 
 
35
  ```
36
 
37
  Precision policy:
@@ -53,12 +57,10 @@ Purpose: measure world-model quality directly. The key question is whether FlowM
53
  | `planet` | RSSM WM baseline | Whether generic recurrent latent memory can represent momentum and drift without a separate context factor. |
54
  | `tdmpc2` | Compact latent-dynamics WM baseline | Whether a compact action-conditioned latent transition matches FlowMo under equal supervision. |
55
 
56
- Prediction datasets:
57
 
58
  ```text
59
- test_unseen_flow
60
- test_unseen_boat_params
61
- diagnostic_seen_flow
62
  ```
63
 
64
  Prediction metrics:
@@ -115,10 +117,8 @@ Traditional non-WM controllers:
115
  Planning tasks:
116
 
117
  ```text
118
- reach_uniform
119
- counterflow
120
  station_keeping
121
- passive_to_active
122
  waypoint_square
123
  waypoint_zigzag
124
  ```
@@ -130,6 +130,21 @@ twin
130
  triangle
131
  ```
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  Planning metrics:
134
 
135
  ```text
 
12
  Visual scale: 2.5
13
  Forbidden image cues: flow arrows, velocity vectors, trajectory overlays, goal marker
14
  Train split: data/paper/train.npz
15
+ Test split: data/paper/test.npz
16
+ Flow families: noflow, uniform, vortex_center, double_gyre, source_sink, source_sink_pair, gradient, shear, turbulent_patch, random_fourier
 
17
  Config: experiments/shared/config/paper_image.json
18
  Checkpoint: paper.pt
19
  Intermediate checkpoints: paper_step_XXXXXX.pt
20
  ```
21
 
22
+ All flow fields are static. Localized flow structures are sampled near the
23
+ route corridors used by the training controllers and final planning tasks.
24
+
25
  Formal training budget:
26
 
27
  ```text
 
34
  checkpoint_interval: 2000
35
  num_workers: 4
36
  render_mode: device
37
+ training_parallel_jobs: 2
38
+ planning_parallel_jobs: 3
39
  ```
40
 
41
  Precision policy:
 
57
  | `planet` | RSSM WM baseline | Whether generic recurrent latent memory can represent momentum and drift without a separate context factor. |
58
  | `tdmpc2` | Compact latent-dynamics WM baseline | Whether a compact action-conditioned latent transition matches FlowMo under equal supervision. |
59
 
60
+ Prediction dataset:
61
 
62
  ```text
63
+ test
 
 
64
  ```
65
 
66
  Prediction metrics:
 
117
  Planning tasks:
118
 
119
  ```text
120
+ reach_target
 
121
  station_keeping
 
122
  waypoint_square
123
  waypoint_zigzag
124
  ```
 
130
  triangle
131
  ```
132
 
133
+ Flow families:
134
+
135
+ ```text
136
+ noflow
137
+ uniform
138
+ vortex_center
139
+ double_gyre
140
+ source_sink
141
+ source_sink_pair
142
+ gradient
143
+ shear
144
+ turbulent_patch
145
+ random_fourier
146
+ ```
147
+
148
  Planning metrics:
149
 
150
  ```text
experiments/README.md CHANGED
@@ -62,7 +62,8 @@ Formal clean-image configuration:
62
  image_size=160
63
  visual_scale=2.5
64
  train=data/paper/train.npz
65
- test=data/paper/test_unseen_flow.npz and data/paper/test_unseen_boat_params.npz
 
66
  ```
67
 
68
  Full paper-facing image pipeline:
@@ -72,6 +73,7 @@ python -m experiments.run_paper_image_pipeline
72
  ```
73
 
74
  The default command runs the paper configuration end to end: train all learned world models, evaluate long rollout prediction, run FlowMo latent probes, evaluate closed-loop planning against traditional controllers, generate GIFs, and write the final report. Images are rendered online from simulator states, so no separate image-cache preparation step is required.
 
75
 
76
  Manual image training:
77
 
@@ -79,6 +81,6 @@ Manual image training:
79
  python -m experiments.train_image_world_models
80
  python -m experiments.evaluate_image_world_models
81
  python -m experiments.evaluate_flowmo_latent_probes
82
- python -m experiments.evaluate_image_planning --task reach_uniform --boat twin
83
  python -m experiments.summarize_paper_image_results
84
  ```
 
62
  image_size=160
63
  visual_scale=2.5
64
  train=data/paper/train.npz
65
+ test=data/paper/test.npz
66
+ flow_families=noflow, uniform, vortex_center, double_gyre, source_sink, source_sink_pair, gradient, shear, turbulent_patch, random_fourier
67
  ```
68
 
69
  Full paper-facing image pipeline:
 
73
  ```
74
 
75
  The default command runs the paper configuration end to end: train all learned world models, evaluate long rollout prediction, run FlowMo latent probes, evaluate closed-loop planning against traditional controllers, generate GIFs, and write the final report. Images are rendered online from simulator states, so no separate image-cache preparation step is required.
76
+ All flow fields are static. Localized flow structures are sampled near task routes so that boat trajectories encounter non-uniform current in the shared train, test, and final-planning protocol.
77
 
78
  Manual image training:
79
 
 
81
  python -m experiments.train_image_world_models
82
  python -m experiments.evaluate_image_world_models
83
  python -m experiments.evaluate_flowmo_latent_probes
84
+ python -m experiments.evaluate_image_planning --task reach_target --boat twin
85
  python -m experiments.summarize_paper_image_results
86
  ```
experiments/TASK_PLAN.md CHANGED
@@ -12,8 +12,10 @@ Shared setup:
12
  Input: clean top-down boat images plus action history
13
  No image cues: no flow arrows, no velocity vector, no goal marker
14
  Training data: data/paper/train.npz
15
- Primary evaluation data: data/paper/test_unseen_flow.npz, data/paper/test_unseen_boat_params.npz
16
- Diagnostic data: data/paper/diagnostic_seen_flow.npz
 
 
17
  Training budget: shared optimizer, batch size, rollout horizon, step count, and checkpoint schedule
18
  Training precision: BF16 model autocast, FP32 losses and metrics
19
  Prediction precision: BF16 model autocast, FP32 metrics
@@ -46,9 +48,7 @@ experiments/<method>/checkpoint/paper.pt
46
  experiments/<method>/checkpoint/paper_step_*.pt
47
  experiments/<method>/result/parameter_count.json
48
  experiments/<method>/result/paper_training.json
49
- experiments/reports/paper_prediction_seen_flow_diagnostic.json
50
- experiments/reports/paper_prediction_unseen_flow.json
51
- experiments/reports/paper_prediction_unseen_boat_params.json
52
  experiments/reports/paper_flowmo_latent_probes.json
53
  ```
54
 
@@ -56,7 +56,7 @@ Core A conclusions:
56
 
57
  ```text
58
  1. Whether FlowMo has lower long-horizon rollout error.
59
- 2. Whether the gain is strongest under unseen flow families and unseen boat dynamics.
60
  3. Whether explicit drift context helps beyond ordinary recurrent history.
61
  4. Whether the same architecture works for both twin and triangle boats.
62
  5. Whether frozen linear probes recover object momentum from `z_t` and ambient drift from `c_t`.
@@ -87,10 +87,8 @@ Compared methods:
87
  Planning tasks:
88
 
89
  ```text
90
- reach_uniform
91
- counterflow
92
  station_keeping
93
- passive_to_active
94
  waypoint_square
95
  waypoint_zigzag
96
  ```
@@ -102,6 +100,21 @@ twin
102
  triangle
103
  ```
104
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  Required B outputs:
106
 
107
  ```text
 
12
  Input: clean top-down boat images plus action history
13
  No image cues: no flow arrows, no velocity vector, no goal marker
14
  Training data: data/paper/train.npz
15
+ Evaluation data: data/paper/test.npz
16
+ Flow families: noflow, uniform, vortex_center, double_gyre, source_sink, source_sink_pair, gradient, shear, turbulent_patch, random_fourier
17
+ All flow fields are static. Localized flow structures are sampled near common
18
+ task routes so the boat encounters non-uniform current during rollout.
19
  Training budget: shared optimizer, batch size, rollout horizon, step count, and checkpoint schedule
20
  Training precision: BF16 model autocast, FP32 losses and metrics
21
  Prediction precision: BF16 model autocast, FP32 metrics
 
48
  experiments/<method>/checkpoint/paper_step_*.pt
49
  experiments/<method>/result/parameter_count.json
50
  experiments/<method>/result/paper_training.json
51
+ experiments/reports/paper_prediction.json
 
 
52
  experiments/reports/paper_flowmo_latent_probes.json
53
  ```
54
 
 
56
 
57
  ```text
58
  1. Whether FlowMo has lower long-horizon rollout error.
59
+ 2. Whether the gain holds across the full paper flow-family set.
60
  3. Whether explicit drift context helps beyond ordinary recurrent history.
61
  4. Whether the same architecture works for both twin and triangle boats.
62
  5. Whether frozen linear probes recover object momentum from `z_t` and ambient drift from `c_t`.
 
87
  Planning tasks:
88
 
89
  ```text
90
+ reach_target
 
91
  station_keeping
 
92
  waypoint_square
93
  waypoint_zigzag
94
  ```
 
100
  triangle
101
  ```
102
 
103
+ Flow families:
104
+
105
+ ```text
106
+ noflow
107
+ uniform
108
+ vortex_center
109
+ double_gyre
110
+ source_sink
111
+ source_sink_pair
112
+ gradient
113
+ shear
114
+ turbulent_patch
115
+ random_fourier
116
+ ```
117
+
118
  Required B outputs:
119
 
120
  ```text
experiments/docs/EXPERIMENT_PROTOCOL.md CHANGED
@@ -36,13 +36,20 @@ All methods use the same splits:
36
 
37
  ```text
38
  train: data/paper/train.npz
39
- unseen_flow_test: data/paper/test_unseen_flow.npz
40
- unseen_boat_dynamics_test: data/paper/test_unseen_boat_params.npz
41
- seen_flow_diagnostic: data/paper/diagnostic_seen_flow.npz
42
  dataset_card: data/paper/dataset_card.md
43
  generation_config: data/paper/generation_config.json
44
  ```
45
 
 
 
 
 
 
 
 
 
 
46
  Observation protocol:
47
 
48
  ```text
@@ -64,6 +71,8 @@ steps: 20000
64
  checkpoint_interval: 2000
65
  num_workers: 4
66
  render_mode: device
 
 
67
  ```
68
 
69
  Precision policy:
@@ -78,12 +87,10 @@ The precision split is intentional: BF16 speeds up image encoding and latent rol
78
 
79
  ## Prediction Evaluation
80
 
81
- Datasets:
82
 
83
  ```text
84
- test_unseen_flow
85
- test_unseen_boat_params
86
- diagnostic_seen_flow
87
  ```
88
 
89
  Metrics:
@@ -139,10 +146,8 @@ oracle_flow_mpc
139
  Tasks:
140
 
141
  ```text
142
- reach_uniform
143
- counterflow
144
  station_keeping
145
- passive_to_active
146
  waypoint_square
147
  waypoint_zigzag
148
  ```
@@ -154,6 +159,21 @@ twin
154
  triangle
155
  ```
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  Metrics:
158
 
159
  ```text
@@ -179,9 +199,7 @@ experiments/<method>/result/paper_training_trace.jsonl
179
  Evaluation outputs:
180
 
181
  ```text
182
- experiments/reports/paper_prediction_unseen_flow.json
183
- experiments/reports/paper_prediction_unseen_boat_params.json
184
- experiments/reports/paper_prediction_seen_flow_diagnostic.json
185
  experiments/reports/paper_flowmo_latent_probes.json
186
  experiments/reports/paper_planning/*.json
187
  experiments/reports/paper_planning/gifs/*.gif
 
36
 
37
  ```text
38
  train: data/paper/train.npz
39
+ test: data/paper/test.npz
 
 
40
  dataset_card: data/paper/dataset_card.md
41
  generation_config: data/paper/generation_config.json
42
  ```
43
 
44
+ The train split, test split, and final planning evaluation use the same paper
45
+ flow-family set: `noflow`, `uniform`, `vortex_center`, `double_gyre`,
46
+ `source_sink`, `source_sink_pair`, `gradient`, `shear`, `turbulent_patch`, and
47
+ `random_fourier`.
48
+
49
+ All paper flow fields are static. Localized structures are sampled near common
50
+ task routes and waypoint corridors so that non-uniform flow is encountered by
51
+ the boat during both training trajectories and final planning tasks.
52
+
53
  Observation protocol:
54
 
55
  ```text
 
71
  checkpoint_interval: 2000
72
  num_workers: 4
73
  render_mode: device
74
+ training_parallel_jobs: 2
75
+ planning_parallel_jobs: 3
76
  ```
77
 
78
  Precision policy:
 
87
 
88
  ## Prediction Evaluation
89
 
90
+ Dataset:
91
 
92
  ```text
93
+ test
 
 
94
  ```
95
 
96
  Metrics:
 
146
  Tasks:
147
 
148
  ```text
149
+ reach_target
 
150
  station_keeping
 
151
  waypoint_square
152
  waypoint_zigzag
153
  ```
 
159
  triangle
160
  ```
161
 
162
+ Flow families:
163
+
164
+ ```text
165
+ noflow
166
+ uniform
167
+ vortex_center
168
+ double_gyre
169
+ source_sink
170
+ source_sink_pair
171
+ gradient
172
+ shear
173
+ turbulent_patch
174
+ random_fourier
175
+ ```
176
+
177
  Metrics:
178
 
179
  ```text
 
199
  Evaluation outputs:
200
 
201
  ```text
202
+ experiments/reports/paper_prediction.json
 
 
203
  experiments/reports/paper_flowmo_latent_probes.json
204
  experiments/reports/paper_planning/*.json
205
  experiments/reports/paper_planning/gifs/*.gif
experiments/evaluate_flowmo_latent_probes.py CHANGED
@@ -203,9 +203,7 @@ def main() -> None:
203
  parser.add_argument("--train-episodes", type=int, default=2400)
204
  parser.add_argument("--train-windows", type=int, default=32768)
205
  parser.add_argument("--eval-splits", nargs="+", default=[
206
- "unseen_flow:data/paper/test_unseen_flow.npz:480",
207
- "unseen_boat_params:data/paper/test_unseen_boat_params.npz:480",
208
- "seen_flow_diagnostic:data/paper/diagnostic_seen_flow.npz:480",
209
  ])
210
  parser.add_argument("--eval-windows", type=int, default=8192)
211
  parser.add_argument("--history-len", type=int, default=32)
 
203
  parser.add_argument("--train-episodes", type=int, default=2400)
204
  parser.add_argument("--train-windows", type=int, default=32768)
205
  parser.add_argument("--eval-splits", nargs="+", default=[
206
+ "test:data/paper/test.npz:480",
 
 
207
  ])
208
  parser.add_argument("--eval-windows", type=int, default=8192)
209
  parser.add_argument("--history-len", type=int, default=32)
experiments/evaluate_image_planning.py CHANGED
@@ -13,7 +13,7 @@ import torch
13
  import torch.nn.functional as F
14
 
15
  from driftwm.sim.env import SurfaceBoatEnv
16
- from driftwm.sim.flow import UniformFlow, sample_flow
17
  from driftwm.sim.render import render_frame, save_gif
18
  from experiments.shared.src.methods import PAPER_LEARNED_METHODS, TRADITIONAL_METHODS
19
  from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
@@ -56,31 +56,23 @@ def task_goals(task: str, rng: np.random.Generator) -> np.ndarray:
56
  return np.array([[2.5, 7.0], [4.2, 3.0], [5.8, 7.0], [7.5, 3.0]], dtype=np.float32)
57
  if task == "station_keeping":
58
  return np.array([[5.0, 5.0]], dtype=np.float32)
59
- if task == "counterflow":
60
- return np.array([[8.4, 5.0]], dtype=np.float32)
61
  return np.array([[8.0, 8.0]], dtype=np.float32)
62
 
63
 
 
 
 
 
 
64
  def reset_task(env: SurfaceBoatEnv, task: str, flow_type: str, rng: np.random.Generator) -> None:
65
- if task == "counterflow":
66
- env.reset(
67
- flow_type="uniform",
68
- flow=UniformFlow(np.array([-0.22, 0.0], dtype=np.float32), flow_id=7001),
69
- random_velocity=False,
70
- )
71
- env.state[:6] = np.array([2.0, 5.0, 0.0, 0.0, 0.0, 0.0], dtype=np.float32)
72
- return
73
  if task == "station_keeping":
74
- env.reset(
75
- flow_type="uniform",
76
- flow=UniformFlow(np.array([0.16, 0.10], dtype=np.float32), flow_id=7002),
77
- random_velocity=False,
78
- )
79
- env.state[:6] = np.array([5.0, 5.0, 0.3, 0.0, 0.0, 0.0], dtype=np.float32)
80
  return
81
  flow = sample_flow(flow_type, rng, flow_id=10_000 + int(rng.integers(1, 1_000_000)), workspace=env.workspace)
82
  env.reset(flow_type=flow_type, flow=flow, random_velocity=False)
83
- env.state[:6] = np.array([2.0, 2.0, float(rng.uniform(-np.pi, np.pi)), 0.0, 0.0, 0.0], dtype=np.float32)
84
 
85
 
86
  def rollout_latent(model, z: torch.Tensor, c: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
@@ -420,16 +412,11 @@ def evaluate_one_method(method: str, args) -> dict:
420
  energy = 0.0
421
  reached_times: list[int] = []
422
  min_goal_dists = np.full((len(goals),), np.inf, dtype=np.float32)
423
- passive_steps = args.passive_steps if args.task == "passive_to_active" else 0
424
  planned = None
425
  learned_plan_mean = None
426
  for t in range(args.max_steps):
427
  goal = goals[goal_idx]
428
- if t < passive_steps:
429
- action = np.zeros((env.action_dim,), dtype=np.float32)
430
- planned = None
431
- learned_plan_mean = None
432
- elif learned:
433
  action, planned, learned_plan_mean = learned_plan(
434
  model,
435
  image_history,
@@ -532,12 +519,11 @@ def summarize(method: str, args, results: list[dict]) -> dict:
532
  def main() -> None:
533
  parser = argparse.ArgumentParser()
534
  parser.add_argument("--methods", nargs="+", default=LEARNED_METHODS + TRADITIONAL_METHODS)
535
- parser.add_argument("--task", choices=["reach_uniform", "counterflow", "station_keeping", "passive_to_active", "waypoint_square", "waypoint_zigzag"], default="reach_uniform")
536
  parser.add_argument("--boat", choices=["twin", "triangle"], default="twin")
537
- parser.add_argument("--flow-type", choices=["uniform", "slowly_varying", "vortex_center", "gradient", "turbulent_patch"], default="uniform")
538
  parser.add_argument("--episodes", type=int, default=50)
539
  parser.add_argument("--max-steps", type=int, default=420)
540
- parser.add_argument("--passive-steps", type=int, default=25)
541
  parser.add_argument("--history-len", type=int, default=32)
542
  parser.add_argument("--image-size", type=int, default=160)
543
  parser.add_argument("--visual-scale", type=float, default=2.5)
 
13
  import torch.nn.functional as F
14
 
15
  from driftwm.sim.env import SurfaceBoatEnv
16
+ from driftwm.sim.flow import sample_flow
17
  from driftwm.sim.render import render_frame, save_gif
18
  from experiments.shared.src.methods import PAPER_LEARNED_METHODS, TRADITIONAL_METHODS
19
  from experiments.shared.src.vision.clean_renderer import render_clean_boat_array
 
56
  return np.array([[2.5, 7.0], [4.2, 3.0], [5.8, 7.0], [7.5, 3.0]], dtype=np.float32)
57
  if task == "station_keeping":
58
  return np.array([[5.0, 5.0]], dtype=np.float32)
 
 
59
  return np.array([[8.0, 8.0]], dtype=np.float32)
60
 
61
 
62
+ def set_task_state(env: SurfaceBoatEnv, state: np.ndarray) -> None:
63
+ env.state[:6] = np.asarray(state, dtype=np.float32)
64
+ env.last_flow_velocity = env.flow_at(env.state[:2]).astype(np.float32)
65
+
66
+
67
  def reset_task(env: SurfaceBoatEnv, task: str, flow_type: str, rng: np.random.Generator) -> None:
 
 
 
 
 
 
 
 
68
  if task == "station_keeping":
69
+ flow = sample_flow(flow_type, rng, flow_id=10_000 + int(rng.integers(1, 1_000_000)), workspace=env.workspace)
70
+ env.reset(flow_type=flow_type, flow=flow, random_velocity=False)
71
+ set_task_state(env, np.array([5.0, 5.0, 0.3, 0.0, 0.0, 0.0], dtype=np.float32))
 
 
 
72
  return
73
  flow = sample_flow(flow_type, rng, flow_id=10_000 + int(rng.integers(1, 1_000_000)), workspace=env.workspace)
74
  env.reset(flow_type=flow_type, flow=flow, random_velocity=False)
75
+ set_task_state(env, np.array([2.0, 2.0, float(rng.uniform(-np.pi, np.pi)), 0.0, 0.0, 0.0], dtype=np.float32))
76
 
77
 
78
  def rollout_latent(model, z: torch.Tensor, c: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
 
412
  energy = 0.0
413
  reached_times: list[int] = []
414
  min_goal_dists = np.full((len(goals),), np.inf, dtype=np.float32)
 
415
  planned = None
416
  learned_plan_mean = None
417
  for t in range(args.max_steps):
418
  goal = goals[goal_idx]
419
+ if learned:
 
 
 
 
420
  action, planned, learned_plan_mean = learned_plan(
421
  model,
422
  image_history,
 
519
  def main() -> None:
520
  parser = argparse.ArgumentParser()
521
  parser.add_argument("--methods", nargs="+", default=LEARNED_METHODS + TRADITIONAL_METHODS)
522
+ parser.add_argument("--task", choices=["reach_target", "station_keeping", "waypoint_square", "waypoint_zigzag"], default="reach_target")
523
  parser.add_argument("--boat", choices=["twin", "triangle"], default="twin")
524
+ parser.add_argument("--flow-type", choices=["noflow", "uniform", "vortex_center", "double_gyre", "source_sink", "source_sink_pair", "gradient", "shear", "turbulent_patch", "random_fourier"], default="uniform")
525
  parser.add_argument("--episodes", type=int, default=50)
526
  parser.add_argument("--max-steps", type=int, default=420)
 
527
  parser.add_argument("--history-len", type=int, default=32)
528
  parser.add_argument("--image-size", type=int, default=160)
529
  parser.add_argument("--visual-scale", type=float, default=2.5)
experiments/evaluate_image_world_models.py CHANGED
@@ -199,7 +199,7 @@ def summarize(pos_mean: np.ndarray, heading_mean: np.ndarray, steps: list[int])
199
  def main() -> None:
200
  parser = argparse.ArgumentParser()
201
  parser.add_argument("--methods", nargs="+", default=METHODS)
202
- parser.add_argument("--test-source", default="data/paper/test_unseen_flow.npz")
203
  parser.add_argument("--test-episodes", type=int, default=256)
204
  parser.add_argument("--history-len", type=int, default=32)
205
  parser.add_argument("--horizon", type=int, default=60)
 
199
  def main() -> None:
200
  parser = argparse.ArgumentParser()
201
  parser.add_argument("--methods", nargs="+", default=METHODS)
202
+ parser.add_argument("--test-source", default="data/paper/test.npz")
203
  parser.add_argument("--test-episodes", type=int, default=256)
204
  parser.add_argument("--history-len", type=int, default=32)
205
  parser.add_argument("--horizon", type=int, default=60)
experiments/flowmo/checkpoint/paper.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:ee9afdebdb83b911c66c0c7fe11d04c710325b59699bf2315cb6429be3cb8048
3
- size 2668751
 
 
 
 
experiments/flowmo/checkpoint/paper_step_002000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:926b64053ac63c43dd5d5e8e814bcd11d3fb3144ed856bf348c21294c1d75891
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_004000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:81a1748bb45882ec85e6259e086b9bb99133bceadfcb506b79c7bafe581af7a8
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_006000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f2866a920b5f48b2278b62205707c40edfed673bfabe18a76163679cee286e27
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_008000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:24359235bc7b043c31206b1abadabd2359ac4c352ac4531dcc675239f5028e7b
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_010000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:4038685bebbe8eedac78242618e6da2d8a36c63f0287dd454e97568977e53b3d
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_012000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:0f5e3047f0a82d16701d743aeb76b8f9746e1ee3f1d74ea4c6a93b9d3f222c6e
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_014000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:091dd0aab8b72fc075279dfc5efce56965c8c9d998ec814042c115a8347a61b9
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_016000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:22f284b27e461ff1abde23bd22cd2f13c38fe3144dedeeb88614bf805679c085
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_018000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:797b2df1d7a20bf54d65059726f4f722cf878374a276c6239a9160a54bea0522
3
- size 2671607
 
 
 
 
experiments/flowmo/checkpoint/paper_step_020000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:78c1bd5c1e5d4663c12e105b243d406eecea5f2e65201e5ee2445ea477818e0f
3
- size 2671607
 
 
 
 
experiments/flowmo/result/paper_training.json DELETED
@@ -1,43 +0,0 @@
1
- {
2
- "method": "flowmo",
3
- "steps": 20000,
4
- "batch_size": 256,
5
- "train_samples": 5120000,
6
- "final_train_loss": 0.004372047260403633,
7
- "total_parameters": 663964,
8
- "target_mode": "absolute_normalized",
9
- "position_scale": 5.0,
10
- "heading_weight": 2.0,
11
- "current_pose_weight": 1.0,
12
- "motion_weight": 0.5,
13
- "precision": "bf16",
14
- "checkpoint_name": "paper.pt",
15
- "final_checkpoint": "paper.pt",
16
- "intermediate_checkpoints": [
17
- "paper_step_002000.pt",
18
- "paper_step_004000.pt",
19
- "paper_step_006000.pt",
20
- "paper_step_008000.pt",
21
- "paper_step_010000.pt",
22
- "paper_step_012000.pt",
23
- "paper_step_014000.pt",
24
- "paper_step_016000.pt",
25
- "paper_step_018000.pt",
26
- "paper_step_020000.pt"
27
- ],
28
- "checkpoint_interval": 2000,
29
- "prediction": {
30
- "pos1": 0.07179122391001631,
31
- "heading1": 0.045128354569897056,
32
- "pos3": 0.07811223280926545,
33
- "heading3": 0.04785821299689511,
34
- "pos6": 0.08915455282355349,
35
- "heading6": 0.05167978972895071,
36
- "pos8": 0.09605406736955047,
37
- "heading8": 0.054166459323217474,
38
- "pos10": 0.10336457837062578,
39
- "heading10": 0.05640091137805333,
40
- "pos20": 0.1460142444508771,
41
- "heading20": 0.06746077448284875
42
- }
43
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/flowmo/result/paper_training_trace.jsonl DELETED
@@ -1,100 +0,0 @@
1
- {"method": "flowmo", "step": 200, "loss": 2.5686612129211426}
2
- {"method": "flowmo", "step": 400, "loss": 2.5599124431610107}
3
- {"method": "flowmo", "step": 600, "loss": 2.1556336879730225}
4
- {"method": "flowmo", "step": 800, "loss": 1.8842406272888184}
5
- {"method": "flowmo", "step": 1000, "loss": 1.8033578395843506}
6
- {"method": "flowmo", "step": 1200, "loss": 1.6628687381744385}
7
- {"method": "flowmo", "step": 1400, "loss": 1.5624256134033203}
8
- {"method": "flowmo", "step": 1600, "loss": 1.4740279912948608}
9
- {"method": "flowmo", "step": 1800, "loss": 1.3885352611541748}
10
- {"method": "flowmo", "step": 2000, "loss": 1.324019193649292}
11
- {"method": "flowmo", "step": 2200, "loss": 1.0553849935531616}
12
- {"method": "flowmo", "step": 2400, "loss": 0.6030915975570679}
13
- {"method": "flowmo", "step": 2600, "loss": 0.20210689306259155}
14
- {"method": "flowmo", "step": 2800, "loss": 0.14605776965618134}
15
- {"method": "flowmo", "step": 3000, "loss": 0.09821392595767975}
16
- {"method": "flowmo", "step": 3200, "loss": 0.07479490339756012}
17
- {"method": "flowmo", "step": 3400, "loss": 0.059080999344587326}
18
- {"method": "flowmo", "step": 3600, "loss": 0.051076825708150864}
19
- {"method": "flowmo", "step": 3800, "loss": 0.042308710515499115}
20
- {"method": "flowmo", "step": 4000, "loss": 0.040146660059690475}
21
- {"method": "flowmo", "step": 4200, "loss": 0.03381121903657913}
22
- {"method": "flowmo", "step": 4400, "loss": 0.033231284469366074}
23
- {"method": "flowmo", "step": 4600, "loss": 0.028457675129175186}
24
- {"method": "flowmo", "step": 4800, "loss": 0.029077233746647835}
25
- {"method": "flowmo", "step": 5000, "loss": 0.02207356132566929}
26
- {"method": "flowmo", "step": 5200, "loss": 0.020034978166222572}
27
- {"method": "flowmo", "step": 5400, "loss": 0.019785162061452866}
28
- {"method": "flowmo", "step": 5600, "loss": 0.018391719087958336}
29
- {"method": "flowmo", "step": 5800, "loss": 0.02175654098391533}
30
- {"method": "flowmo", "step": 6000, "loss": 0.015171783976256847}
31
- {"method": "flowmo", "step": 6200, "loss": 0.01452728919684887}
32
- {"method": "flowmo", "step": 6400, "loss": 0.013214356265962124}
33
- {"method": "flowmo", "step": 6600, "loss": 0.051673468202352524}
34
- {"method": "flowmo", "step": 6800, "loss": 0.018827352672815323}
35
- {"method": "flowmo", "step": 7000, "loss": 0.012735347263514996}
36
- {"method": "flowmo", "step": 7200, "loss": 0.011451991274952888}
37
- {"method": "flowmo", "step": 7400, "loss": 0.010433687828481197}
38
- {"method": "flowmo", "step": 7600, "loss": 0.010923548601567745}
39
- {"method": "flowmo", "step": 7800, "loss": 0.010971073061227798}
40
- {"method": "flowmo", "step": 8000, "loss": 0.009853748604655266}
41
- {"method": "flowmo", "step": 8200, "loss": 0.09088479727506638}
42
- {"method": "flowmo", "step": 8400, "loss": 0.034223418682813644}
43
- {"method": "flowmo", "step": 8600, "loss": 0.014456425793468952}
44
- {"method": "flowmo", "step": 8800, "loss": 0.009422067552804947}
45
- {"method": "flowmo", "step": 9000, "loss": 0.00858109537512064}
46
- {"method": "flowmo", "step": 9200, "loss": 0.00857796985656023}
47
- {"method": "flowmo", "step": 9400, "loss": 0.008296442218124866}
48
- {"method": "flowmo", "step": 9600, "loss": 0.008247998543083668}
49
- {"method": "flowmo", "step": 9800, "loss": 0.008240980096161366}
50
- {"method": "flowmo", "step": 10000, "loss": 0.008153271861374378}
51
- {"method": "flowmo", "step": 10200, "loss": 0.012404488399624825}
52
- {"method": "flowmo", "step": 10400, "loss": 0.013864593580365181}
53
- {"method": "flowmo", "step": 10600, "loss": 0.01010044477880001}
54
- {"method": "flowmo", "step": 10800, "loss": 0.00767604261636734}
55
- {"method": "flowmo", "step": 11000, "loss": 0.007007307838648558}
56
- {"method": "flowmo", "step": 11200, "loss": 0.0070138657465577126}
57
- {"method": "flowmo", "step": 11400, "loss": 0.007243836764246225}
58
- {"method": "flowmo", "step": 11600, "loss": 0.006900576408952475}
59
- {"method": "flowmo", "step": 11800, "loss": 0.0068667978048324585}
60
- {"method": "flowmo", "step": 12000, "loss": 0.006599605083465576}
61
- {"method": "flowmo", "step": 12200, "loss": 0.007158435881137848}
62
- {"method": "flowmo", "step": 12400, "loss": 0.045721929520368576}
63
- {"method": "flowmo", "step": 12600, "loss": 0.006790271960198879}
64
- {"method": "flowmo", "step": 12800, "loss": 0.0060927667655050755}
65
- {"method": "flowmo", "step": 13000, "loss": 0.005786360241472721}
66
- {"method": "flowmo", "step": 13200, "loss": 0.00602421211078763}
67
- {"method": "flowmo", "step": 13400, "loss": 0.005942641757428646}
68
- {"method": "flowmo", "step": 13600, "loss": 0.006074435543268919}
69
- {"method": "flowmo", "step": 13800, "loss": 0.021174009889364243}
70
- {"method": "flowmo", "step": 14000, "loss": 0.006621338427066803}
71
- {"method": "flowmo", "step": 14200, "loss": 0.005491666030138731}
72
- {"method": "flowmo", "step": 14400, "loss": 0.0051383040845394135}
73
- {"method": "flowmo", "step": 14600, "loss": 0.005117133259773254}
74
- {"method": "flowmo", "step": 14800, "loss": 0.0053353263065218925}
75
- {"method": "flowmo", "step": 15000, "loss": 0.00533561734482646}
76
- {"method": "flowmo", "step": 15200, "loss": 0.005121554713696241}
77
- {"method": "flowmo", "step": 15400, "loss": 0.005291329696774483}
78
- {"method": "flowmo", "step": 15600, "loss": 0.00511613953858614}
79
- {"method": "flowmo", "step": 15800, "loss": 0.005213129799813032}
80
- {"method": "flowmo", "step": 16000, "loss": 0.005071689374744892}
81
- {"method": "flowmo", "step": 16200, "loss": 0.0057407282292842865}
82
- {"method": "flowmo", "step": 16400, "loss": 0.0054640620946884155}
83
- {"method": "flowmo", "step": 16600, "loss": 0.005037755239754915}
84
- {"method": "flowmo", "step": 16800, "loss": 0.004956530407071114}
85
- {"method": "flowmo", "step": 17000, "loss": 0.24728184938430786}
86
- {"method": "flowmo", "step": 17200, "loss": 0.0358046218752861}
87
- {"method": "flowmo", "step": 17400, "loss": 0.005404628813266754}
88
- {"method": "flowmo", "step": 17600, "loss": 0.004861537832766771}
89
- {"method": "flowmo", "step": 17800, "loss": 0.0046697030775249004}
90
- {"method": "flowmo", "step": 18000, "loss": 0.00479076337069273}
91
- {"method": "flowmo", "step": 18200, "loss": 0.0045429919846355915}
92
- {"method": "flowmo", "step": 18400, "loss": 0.004368708468973637}
93
- {"method": "flowmo", "step": 18600, "loss": 0.004182927776128054}
94
- {"method": "flowmo", "step": 18800, "loss": 0.004190036095678806}
95
- {"method": "flowmo", "step": 19000, "loss": 0.004384973086416721}
96
- {"method": "flowmo", "step": 19200, "loss": 0.006023879628628492}
97
- {"method": "flowmo", "step": 19400, "loss": 0.004392072558403015}
98
- {"method": "flowmo", "step": 19600, "loss": 0.004413294140249491}
99
- {"method": "flowmo", "step": 19800, "loss": 0.004305647686123848}
100
- {"method": "flowmo", "step": 20000, "loss": 0.004372047260403633}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/flowmo/result/parameter_count.json DELETED
@@ -1,11 +0,0 @@
1
- {
2
- "encoder": 340416,
3
- "state_history": 75648,
4
- "context_history": 75648,
5
- "to_z": 30960,
6
- "to_c": 17544,
7
- "base_delta": 45808,
8
- "residual_delta": 46448,
9
- "decoder": 31492,
10
- "total": 663964
11
- }
 
 
 
 
 
 
 
 
 
 
 
 
experiments/leworldmodel/checkpoint/paper.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8837aa8f630fb2cc87f31b47c4df3f0cf207202247e1fffd066d668206fb8b5f
3
- size 2667147
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_002000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8670c6b88de48a2f724e55062f7650b4d47f4481e198a47656555957749c18c3
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_004000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f09493171c08aa2ae742644344e12e1b8cc5515830cdead18bcf0684583d48cc
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_006000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:53fa0a4b771dc3d9443309f449306b28a6175006220ae5a7a7261bf0f8bd0678
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_008000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c7fca57df185e5a6cc9d03fb85c8d229a680a4dc6bbdf83d5a328b11bc073b0d
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_010000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:43d43615527e310c1e29b5c9856a129b6fcd8401e9d55a14d37dfa29c154f4ef
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_012000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8e6ee23f7ba9d61726ff278bef72b5e7ce578884d86acccb441599954dfaa927
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_014000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:86ef98d7c0a6b51735b0d7ab996380b365b910849d824deca60d52cd4138d4d6
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_016000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9670fe4385c12cb1a3411446d546ee1462fe57a1aea29810ee625e2d382a2564
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_018000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:fe55552dc42dc83bcb47872f43b05a1909ea1f69d59f6fe9ba977c1ca9d5140f
3
- size 2669155
 
 
 
 
experiments/leworldmodel/checkpoint/paper_step_020000.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:641ef4793bde88c7570d3811bf0d931adae7455cfc9d477dadb0b088878ea37b
3
- size 2669155
 
 
 
 
experiments/leworldmodel/result/paper_training.json DELETED
@@ -1,43 +0,0 @@
1
- {
2
- "method": "leworldmodel",
3
- "steps": 20000,
4
- "batch_size": 256,
5
- "train_samples": 5120000,
6
- "final_train_loss": 0.018198927864432335,
7
- "total_parameters": 664612,
8
- "target_mode": "absolute_normalized",
9
- "position_scale": 5.0,
10
- "heading_weight": 2.0,
11
- "current_pose_weight": 1.0,
12
- "motion_weight": 0.5,
13
- "precision": "bf16",
14
- "checkpoint_name": "paper.pt",
15
- "final_checkpoint": "paper.pt",
16
- "intermediate_checkpoints": [
17
- "paper_step_002000.pt",
18
- "paper_step_004000.pt",
19
- "paper_step_006000.pt",
20
- "paper_step_008000.pt",
21
- "paper_step_010000.pt",
22
- "paper_step_012000.pt",
23
- "paper_step_014000.pt",
24
- "paper_step_016000.pt",
25
- "paper_step_018000.pt",
26
- "paper_step_020000.pt"
27
- ],
28
- "checkpoint_interval": 2000,
29
- "prediction": {
30
- "pos1": 0.10107660254773994,
31
- "heading1": 0.0479962204505379,
32
- "pos3": 0.11733688049328823,
33
- "heading3": 0.054719595277371504,
34
- "pos6": 0.14297669551645717,
35
- "heading6": 0.0649993756475548,
36
- "pos8": 0.1607180543554326,
37
- "heading8": 0.07113518511566023,
38
- "pos10": 0.1783415130339563,
39
- "heading10": 0.07669965278667708,
40
- "pos20": 0.25807287978629273,
41
- "heading20": 0.0995667139844348
42
- }
43
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/leworldmodel/result/paper_training_trace.jsonl DELETED
@@ -1,100 +0,0 @@
1
- {"method": "leworldmodel", "step": 200, "loss": 2.569669246673584}
2
- {"method": "leworldmodel", "step": 400, "loss": 2.5372421741485596}
3
- {"method": "leworldmodel", "step": 600, "loss": 2.037346363067627}
4
- {"method": "leworldmodel", "step": 800, "loss": 1.8627333641052246}
5
- {"method": "leworldmodel", "step": 1000, "loss": 1.8265268802642822}
6
- {"method": "leworldmodel", "step": 1200, "loss": 1.7777857780456543}
7
- {"method": "leworldmodel", "step": 1400, "loss": 1.753949522972107}
8
- {"method": "leworldmodel", "step": 1600, "loss": 1.7200227975845337}
9
- {"method": "leworldmodel", "step": 1800, "loss": 1.6717208623886108}
10
- {"method": "leworldmodel", "step": 2000, "loss": 1.391864538192749}
11
- {"method": "leworldmodel", "step": 2200, "loss": 0.7442419528961182}
12
- {"method": "leworldmodel", "step": 2400, "loss": 0.2299843579530716}
13
- {"method": "leworldmodel", "step": 2600, "loss": 0.13865336775779724}
14
- {"method": "leworldmodel", "step": 2800, "loss": 0.09015313535928726}
15
- {"method": "leworldmodel", "step": 3000, "loss": 0.06993035972118378}
16
- {"method": "leworldmodel", "step": 3200, "loss": 0.05781794339418411}
17
- {"method": "leworldmodel", "step": 3400, "loss": 0.05077105388045311}
18
- {"method": "leworldmodel", "step": 3600, "loss": 0.043729424476623535}
19
- {"method": "leworldmodel", "step": 3800, "loss": 0.040191106498241425}
20
- {"method": "leworldmodel", "step": 4000, "loss": 0.03818749263882637}
21
- {"method": "leworldmodel", "step": 4200, "loss": 0.03584111109375954}
22
- {"method": "leworldmodel", "step": 4400, "loss": 0.03268589824438095}
23
- {"method": "leworldmodel", "step": 4600, "loss": 0.030512923374772072}
24
- {"method": "leworldmodel", "step": 4800, "loss": 0.028514141216874123}
25
- {"method": "leworldmodel", "step": 5000, "loss": 0.026582585647702217}
26
- {"method": "leworldmodel", "step": 5200, "loss": 0.02657574787735939}
27
- {"method": "leworldmodel", "step": 5400, "loss": 0.03360811248421669}
28
- {"method": "leworldmodel", "step": 5600, "loss": 0.024245228618383408}
29
- {"method": "leworldmodel", "step": 5800, "loss": 0.025151818990707397}
30
- {"method": "leworldmodel", "step": 6000, "loss": 0.02470393292605877}
31
- {"method": "leworldmodel", "step": 6200, "loss": 0.022554941475391388}
32
- {"method": "leworldmodel", "step": 6400, "loss": 0.021832682192325592}
33
- {"method": "leworldmodel", "step": 6600, "loss": 0.02138935960829258}
34
- {"method": "leworldmodel", "step": 6800, "loss": 0.021714694797992706}
35
- {"method": "leworldmodel", "step": 7000, "loss": 0.022361399605870247}
36
- {"method": "leworldmodel", "step": 7200, "loss": 0.021948281675577164}
37
- {"method": "leworldmodel", "step": 7400, "loss": 0.023297373205423355}
38
- {"method": "leworldmodel", "step": 7600, "loss": 0.01953437551856041}
39
- {"method": "leworldmodel", "step": 7800, "loss": 0.019761236384510994}
40
- {"method": "leworldmodel", "step": 8000, "loss": 0.018553584814071655}
41
- {"method": "leworldmodel", "step": 8200, "loss": 0.01823526993393898}
42
- {"method": "leworldmodel", "step": 8400, "loss": 0.01854352466762066}
43
- {"method": "leworldmodel", "step": 8600, "loss": 0.019154751673340797}
44
- {"method": "leworldmodel", "step": 8800, "loss": 0.018928799778223038}
45
- {"method": "leworldmodel", "step": 9000, "loss": 0.01909957453608513}
46
- {"method": "leworldmodel", "step": 9200, "loss": 0.018046768382191658}
47
- {"method": "leworldmodel", "step": 9400, "loss": 0.016244517639279366}
48
- {"method": "leworldmodel", "step": 9600, "loss": 0.016833283007144928}
49
- {"method": "leworldmodel", "step": 9800, "loss": 0.01727704517543316}
50
- {"method": "leworldmodel", "step": 10000, "loss": 0.017903970554471016}
51
- {"method": "leworldmodel", "step": 10200, "loss": 0.01650647632777691}
52
- {"method": "leworldmodel", "step": 10400, "loss": 0.015822188928723335}
53
- {"method": "leworldmodel", "step": 10600, "loss": 0.021761486306786537}
54
- {"method": "leworldmodel", "step": 10800, "loss": 0.01572641171514988}
55
- {"method": "leworldmodel", "step": 11000, "loss": 0.014744052663445473}
56
- {"method": "leworldmodel", "step": 11200, "loss": 0.014771764166653156}
57
- {"method": "leworldmodel", "step": 11400, "loss": 0.015161859802901745}
58
- {"method": "leworldmodel", "step": 11600, "loss": 0.015039228834211826}
59
- {"method": "leworldmodel", "step": 11800, "loss": 0.01450162474066019}
60
- {"method": "leworldmodel", "step": 12000, "loss": 0.014639433473348618}
61
- {"method": "leworldmodel", "step": 12200, "loss": 0.014432272873818874}
62
- {"method": "leworldmodel", "step": 12400, "loss": 0.05746567249298096}
63
- {"method": "leworldmodel", "step": 12600, "loss": 0.01567252166569233}
64
- {"method": "leworldmodel", "step": 12800, "loss": 0.013239766471087933}
65
- {"method": "leworldmodel", "step": 13000, "loss": 0.01337014976888895}
66
- {"method": "leworldmodel", "step": 13200, "loss": 0.013945686630904675}
67
- {"method": "leworldmodel", "step": 13400, "loss": 0.013215066865086555}
68
- {"method": "leworldmodel", "step": 13600, "loss": 0.013161891140043736}
69
- {"method": "leworldmodel", "step": 13800, "loss": 0.013161612674593925}
70
- {"method": "leworldmodel", "step": 14000, "loss": 0.013272494077682495}
71
- {"method": "leworldmodel", "step": 14200, "loss": 0.012501145713031292}
72
- {"method": "leworldmodel", "step": 14400, "loss": 0.01319703459739685}
73
- {"method": "leworldmodel", "step": 14600, "loss": 0.01253820862621069}
74
- {"method": "leworldmodel", "step": 14800, "loss": 0.013268169946968555}
75
- {"method": "leworldmodel", "step": 15000, "loss": 0.012286090292036533}
76
- {"method": "leworldmodel", "step": 15200, "loss": 0.012689301744103432}
77
- {"method": "leworldmodel", "step": 15400, "loss": 0.018598034977912903}
78
- {"method": "leworldmodel", "step": 15600, "loss": 0.011223368346691132}
79
- {"method": "leworldmodel", "step": 15800, "loss": 0.011263682506978512}
80
- {"method": "leworldmodel", "step": 16000, "loss": 0.011280846782028675}
81
- {"method": "leworldmodel", "step": 16200, "loss": 0.011259369552135468}
82
- {"method": "leworldmodel", "step": 16400, "loss": 0.012492901645600796}
83
- {"method": "leworldmodel", "step": 16600, "loss": 0.011446918360888958}
84
- {"method": "leworldmodel", "step": 16800, "loss": 0.01105540618300438}
85
- {"method": "leworldmodel", "step": 17000, "loss": 0.010708491317927837}
86
- {"method": "leworldmodel", "step": 17200, "loss": 0.01076575368642807}
87
- {"method": "leworldmodel", "step": 17400, "loss": 0.010710487142205238}
88
- {"method": "leworldmodel", "step": 17600, "loss": 0.010570250451564789}
89
- {"method": "leworldmodel", "step": 17800, "loss": 0.010576384142041206}
90
- {"method": "leworldmodel", "step": 18000, "loss": 0.011023864150047302}
91
- {"method": "leworldmodel", "step": 18200, "loss": 0.010448881424963474}
92
- {"method": "leworldmodel", "step": 18400, "loss": 0.010450177825987339}
93
- {"method": "leworldmodel", "step": 18600, "loss": 0.010127122513949871}
94
- {"method": "leworldmodel", "step": 18800, "loss": 0.010028455406427383}
95
- {"method": "leworldmodel", "step": 19000, "loss": 0.010146670043468475}
96
- {"method": "leworldmodel", "step": 19200, "loss": 0.010228784754872322}
97
- {"method": "leworldmodel", "step": 19400, "loss": 0.009929011575877666}
98
- {"method": "leworldmodel", "step": 19600, "loss": 0.009595287963747978}
99
- {"method": "leworldmodel", "step": 19800, "loss": 0.00958178285509348}
100
- {"method": "leworldmodel", "step": 20000, "loss": 0.018198927864432335}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
experiments/leworldmodel/result/parameter_count.json DELETED
@@ -1,7 +0,0 @@
1
- {
2
- "encoder": 471584,
3
- "to_z": 43328,
4
- "transition": 87104,
5
- "decoder": 62596,
6
- "total": 664612
7
- }
 
 
 
 
 
 
 
 
experiments/planet/checkpoint/paper.pt DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:6ad00b1fb32dbdf8bd8e3346f33809449b3e322b2d8c63215122b4d2b6037281
3
- size 2667969