Commit ·
b0dda0c
0
Parent(s):
Initial commit
Browse files- README.md +82 -0
- physics_world.py +314 -0
README.md
ADDED
|
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Staticplay CurioDynamics
|
| 2 |
+
|
| 3 |
+
Curiosity-driven symbolic physics discovery from raw trajectories (position, velocity, time). The agent observes only state and learns a sparse symbolic model that recovers hidden dynamics such as gravity, drag, and wind.
|
| 4 |
+
|
| 5 |
+
## What It Does
|
| 6 |
+
- Generates a hidden physics world (projectile motion with drag + wind).
|
| 7 |
+
- Observes only `position`, `velocity`, and `time`.
|
| 8 |
+
- Learns a symbolic acceleration model via sparse regression.
|
| 9 |
+
- Recovers interpretable formulas like:
|
| 10 |
+
- `ax = wind − k * vx * |vx|`
|
| 11 |
+
- `ay = −g − k * vy * |vy|`
|
| 12 |
+
|
| 13 |
+
## Quick Start
|
| 14 |
+
Requires Python 3.10+.
|
| 15 |
+
|
| 16 |
+
```powershell
|
| 17 |
+
python .\physics_world.py
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## How It Works
|
| 21 |
+
1. Simulates a hidden world with gravity + drag + wind.
|
| 22 |
+
2. Collects state transitions only (no equations given).
|
| 23 |
+
3. Fits a sparse symbolic model over a small feature library.
|
| 24 |
+
4. Interprets the recovered constants as wind and gravity.
|
| 25 |
+
|
| 26 |
+
## Latest Test Results
|
| 27 |
+
Hidden world: quadratic drag + wind
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
g_true = 9.81
|
| 31 |
+
drag_true = 0.12
|
| 32 |
+
wind_true = 0.4
|
| 33 |
+
|
| 34 |
+
model_ax: ax = 0.4 − 0.12 * vx|vx|
|
| 35 |
+
model_ay: ay = −9.81 − 0.12 * vy|vy|
|
| 36 |
+
wind_est = 0.4
|
| 37 |
+
g_est = 9.81
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
## Install
|
| 41 |
+
No external dependencies beyond standard library.
|
| 42 |
+
|
| 43 |
+
If you want a virtual environment:
|
| 44 |
+
```powershell
|
| 45 |
+
python -m venv .venv
|
| 46 |
+
.\.venv\Scripts\activate
|
| 47 |
+
python .\physics_world.py
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
## Files
|
| 51 |
+
- `physics_world.py`: hidden physics world + curiosity agent + symbolic regression
|
| 52 |
+
|
| 53 |
+
## How To Run Different Scenarios
|
| 54 |
+
Edit `run_stress_test(...)` in `physics_world.py`:
|
| 55 |
+
- `wind`
|
| 56 |
+
- `drag`
|
| 57 |
+
- `drag_power`
|
| 58 |
+
- `episodes`
|
| 59 |
+
- `steps`
|
| 60 |
+
|
| 61 |
+
Example:
|
| 62 |
+
```python
|
| 63 |
+
result = run_stress_test(
|
| 64 |
+
episodes=50,
|
| 65 |
+
steps=200,
|
| 66 |
+
g=9.81,
|
| 67 |
+
drag=0.12,
|
| 68 |
+
wind=0.4,
|
| 69 |
+
drag_power=2.0,
|
| 70 |
+
dt=0.1
|
| 71 |
+
)
|
| 72 |
+
print(result)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Notes
|
| 76 |
+
This is intentionally minimal and deterministic to highlight discovery mechanics. The feature library is constrained to keep the model interpretable.
|
| 77 |
+
|
| 78 |
+
## Acknowledgements
|
| 79 |
+
- teliov/symcat-to-synthea
|
| 80 |
+
|
| 81 |
+
## Links
|
| 82 |
+
- https://staticplay.co.uk
|
physics_world.py
ADDED
|
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import random
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
class State:
    """Snapshot of the projectile's observable state.

    This is the complete observation available to the agent: 2D position
    and velocity. Time is implicit (one step() call advances by dt).
    """

    # position components
    x: float
    y: float
    # velocity components
    vx: float
    vy: float
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class ProjectileWorld:
    """Hidden physics environment: 2D projectile motion with drag and wind.

    The agent never sees these equations; it only observes the states
    produced by step().
    """

    def __init__(self, g=9.81, drag=0.12, wind=0.4, dt=0.1, drag_power=2.0):
        """Store the hidden world parameters.

        g           -- gravitational acceleration (acts along -y)
        drag        -- drag coefficient
        wind        -- constant horizontal acceleration added to ax
        dt          -- integration time step
        drag_power  -- exponent of the speed-dependent drag law
        """
        self.g = g
        self.drag = drag
        self.wind = wind
        self.dt = dt
        self.drag_power = drag_power

    def step(self, s: State) -> State:
        """Advance one state by dt under the hidden dynamics.

        BUG FIX: drag_power was previously stored but never used (drag was
        always quadratic), so the README's drag_power scenario knob had no
        effect. The drag term is now -drag * v * |v|**(drag_power - 1),
        which reduces exactly to the original quadratic law
        -drag * v * |v| at the default drag_power=2.0.
        """
        p = self.drag_power - 1.0
        ax = -self.drag * s.vx * abs(s.vx) ** p + self.wind
        ay = -self.g - self.drag * s.vy * abs(s.vy) ** p

        # Semi-implicit Euler: update velocity first, then integrate
        # position with the *new* velocity.
        vx = s.vx + ax * self.dt
        vy = s.vy + ay * self.dt
        x = s.x + vx * self.dt
        y = s.y + vy * self.dt
        return State(x=x, y=y, vx=vx, vy=vy)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CuriosityAgent:
    """Learns a symbolic acceleration model of a hidden world from raw states.

    The agent only ever sees (state, next_state) pairs. It measures how
    "surprised" it is by its own one-step predictions, invents named
    concepts when surprise stays persistently high, and periodically fits
    a sparse symbolic model (SINDy-style sequential thresholded least
    squares) over a tiny feature library {1, vx, vy, vx|vx|, vy|vy|}.
    """

    def __init__(self, dt=0.1):
        # dt must match the hidden world's integration step for the
        # finite-difference accelerations in update() to be consistent.
        self.dt = dt
        # learned symbolic model coefficients for ax and ay
        # (feature name -> coefficient; empty until fit_params() succeeds)
        self.model_ax = {}
        self.model_ay = {}
        # Invented concepts (name -> evidence dict), filled by _maybe_invent()
        self.invented = {}
        self.surprise_window = []       # rolling window of recent surprise values
        self.err_vx_window = []         # rolling window of signed vx prediction errors
        self.window_size = 20           # capacity of both rolling windows
        self.surprise_threshold = 0.06  # per-step surprise counted as "high"
        self.stable_threshold = 0.5     # fraction of high steps needed to invent a concept
        self.drag_power_window = []     # NOTE(review): never read or written elsewhere in this class
        # regression buffers per episode
        self.samples = []
        # means used to center the quadratic features (updated by fit_params)
        self.feature_means = {"vx_abs_vx": 0.0, "vy_abs_vy": 0.0}

    def predict(self, s: State) -> State:
        """Predict the next state via a semi-implicit Euler step of the
        current symbolic model (zero acceleration before any model is fit)."""
        # If no model yet, predict no acceleration
        if not self.model_ax and not self.model_ay:
            ax = 0.0
            ay = 0.0
        else:
            ax = self._eval_model(self.model_ax, s)
            ay = self._eval_model(self.model_ay, s)
        vx = s.vx + ax * self.dt
        vy = s.vy + ay * self.dt
        x = s.x + vx * self.dt
        y = s.y + vy * self.dt
        return State(x=x, y=y, vx=vx, vy=vy)

    def update(self, s: State, s_next: State):
        """Ingest one observed transition; return the surprise (velocity
        prediction error magnitude) it caused."""
        pred = self.predict(s)
        # prediction error on velocity
        err_vx = s_next.vx - pred.vx
        err_vy = s_next.vy - pred.vy
        surprise = math.sqrt(err_vx * err_vx + err_vy * err_vy)

        # keep rolling surprise / vx-error windows bounded at window_size
        self.surprise_window.append(surprise)
        if len(self.surprise_window) > self.window_size:
            self.surprise_window.pop(0)
        self.err_vx_window.append(err_vx)
        if len(self.err_vx_window) > self.window_size:
            self.err_vx_window.pop(0)

        # store samples for regression (ax, ay from finite differences)
        ax = (s_next.vx - s.vx) / self.dt
        ay = (s_next.vy - s.vy) / self.dt
        self.samples.append((s.vx, s.vy, ax, ay))

        self._maybe_invent(surprise)

        return surprise

    def _maybe_invent(self, surprise):
        """Invent named concepts once surprise has been persistently high
        over a full window. Each concept is invented at most once."""
        if len(self.surprise_window) < self.window_size:
            return
        high = sum(1 for s in self.surprise_window if s > self.surprise_threshold)
        ratio = high / self.window_size
        if ratio >= self.stable_threshold and "drag" not in self.invented:
            self.invented["drag"] = {
                "confidence": round(ratio, 3),
                "evidence_window": list(self.surprise_window),
            }
        # NOTE(review): comment said "wind discovery: persistent bias in vx
        # error", but this trigger only checks the surprise ratio;
        # err_vx_window is never consulted — confirm intended behavior.
        if "model_update" not in self.invented and ratio >= self.stable_threshold:
            self.invented["model_update"] = {"confidence": round(ratio, 3)}

    def fit_params(self):
        """Fit sparse symbolic models for ax and ay from buffered samples.

        For each axis, fits both a linear {1, v} and a mean-centered
        quadratic {1, v|v|} library and keeps whichever has lower MSE.
        Always clears the sample buffer; with fewer than 20 samples it
        only clears and leaves the models untouched.
        """
        if len(self.samples) < 20:
            self.samples.clear()
            return

        # Symbolic regression via sparse linear model on feature library
        features_ax_linear = []
        features_ax_quad = []
        features_ay_linear = []
        features_ay_quad = []
        targets_ax = []
        targets_ay = []
        # Center the quadratic features so the intercept separates cleanly
        # from the drag term; the means are kept for _eval_model and for
        # recovering the physical constants later.
        mean_vx_abs_vx = sum((vx * abs(vx)) for vx, _, _, _ in self.samples) / len(self.samples)
        mean_vy_abs_vy = sum((vy * abs(vy)) for _, vy, _, _ in self.samples) / len(self.samples)
        self.feature_means["vx_abs_vx"] = mean_vx_abs_vx
        self.feature_means["vy_abs_vy"] = mean_vy_abs_vy

        for vx, vy, ax, ay in self.samples:
            features_ax_linear.append({"1": 1.0, "vx": vx})
            features_ax_quad.append({"1": 1.0, "vx_abs_vx": (vx * abs(vx)) - mean_vx_abs_vx})
            features_ay_linear.append({"1": 1.0, "vy": vy})
            features_ay_quad.append({"1": 1.0, "vy_abs_vy": (vy * abs(vy)) - mean_vy_abs_vy})
            targets_ax.append(ax)
            targets_ay.append(ay)

        coeff_ax_lin, mse_ax_lin = self._fit_sparse(features_ax_linear, targets_ax, return_mse=True, center=True)
        coeff_ax_quad, mse_ax_quad = self._fit_sparse(features_ax_quad, targets_ax, return_mse=True, center=True)
        coeff_ay_lin, mse_ay_lin = self._fit_sparse(features_ay_linear, targets_ay, return_mse=True, center=True)
        coeff_ay_quad, mse_ay_quad = self._fit_sparse(features_ay_quad, targets_ay, return_mse=True, center=True)

        # Model selection: prefer the quadratic library only when it
        # actually fits better.
        coeff_ax = coeff_ax_quad if mse_ax_quad < mse_ax_lin else coeff_ax_lin
        coeff_ay = coeff_ay_quad if mse_ay_quad < mse_ay_lin else coeff_ay_lin
        self.model_ax = coeff_ax
        self.model_ay = coeff_ay

        if self.model_ax or self.model_ay:
            # setdefault: record the first discovered structure only
            self.invented.setdefault(
                "symbolic_model",
                {"terms_ax": list(self.model_ax.keys()), "terms_ay": list(self.model_ay.keys())},
            )

        self.samples.clear()

    def _fit_sparse(self, feature_rows, targets, return_mse=False, center=False):
        """Sequential Thresholded Least Squares (SINDy-style).

        feature_rows -- list of {feature name: value} dicts (must include "1")
        targets      -- list of scalar targets, same length
        return_mse   -- if True, return (coeffs, mse) instead of coeffs
        center       -- if True, subtract the target mean before fitting and
                        fold it back into the "1" coefficient at the end

        Returns a dict {feature name: rounded coefficient} containing only
        the terms that survived the 0.02 magnitude threshold.
        """
        keys = list(feature_rows[0].keys())
        n = len(feature_rows)

        # Build design matrix
        X = [[row[k] for k in keys] for row in feature_rows]
        y = targets[:]
        y_mean = 0.0
        if center:
            y_mean = sum(y) / len(y)
            y = [v - y_mean for v in y]

        # Normalize columns (except the constant) to zero mean / unit std so
        # the hard threshold acts on comparable scales.
        means = [0.0] * len(keys)
        stds = [1.0] * len(keys)
        for j, k in enumerate(keys):
            if k == "1":
                means[j] = 0.0
                stds[j] = 1.0
                continue
            col = [X[i][j] for i in range(n)]
            m = sum(col) / n
            v = sum((c - m) ** 2 for c in col) / n
            # guard against degenerate (near-constant) columns
            s = math.sqrt(v) if v > 1e-12 else 1.0
            means[j] = m
            stds[j] = s
            for i in range(n):
                X[i][j] = (X[i][j] - m) / s

        active = set(range(len(keys)))
        coeff = [0.0] * len(keys)

        def solve_least_squares(active_idx):
            """Least squares on the active columns via normal equations,
            solved iteratively with Gauss-Seidel; returns a full-length
            coefficient vector (zeros for inactive columns)."""
            a_idx = sorted(active_idx)
            m = len(a_idx)
            if m == 0:
                return [0.0] * len(keys)
            xtx = [[0.0 for _ in range(m)] for _ in range(m)]
            xty = [0.0 for _ in range(m)]
            for i in range(n):
                row = [X[i][j] for j in a_idx]
                for r in range(m):
                    xty[r] += row[r] * y[i]
                    for c in range(m):
                        xtx[r][c] += row[r] * row[c]
            # Gauss-Seidel
            beta = [0.0] * m
            for _ in range(30):
                for r in range(m):
                    # tiny-pivot guard keeps the sweep finite
                    denom = xtx[r][r] if abs(xtx[r][r]) > 1e-8 else 1e-8
                    num = xty[r]
                    for c in range(m):
                        if c == r:
                            continue
                        num -= xtx[r][c] * beta[c]
                    beta[r] = num / denom
            full = [0.0] * len(keys)
            for r, j in enumerate(a_idx):
                full[j] = beta[r]
            return full

        # Iterative thresholding: refit on the surviving terms until the
        # active set stabilizes (at most 6 rounds).
        for _ in range(6):
            coeff = solve_least_squares(active)
            # Unnormalize coefficients back to the original feature scale
            coeff_unnorm = coeff[:]
            for j, k in enumerate(keys):
                if k == "1":
                    continue
                coeff_unnorm[j] = coeff[j] / stds[j]
            # Threshold small coefficients out; always keep the intercept
            new_active = set(i for i, v in enumerate(coeff_unnorm) if abs(v) >= 0.02)
            new_active.add(keys.index("1"))
            if new_active == active:
                coeff = coeff_unnorm
                break
            active = new_active
            coeff = coeff_unnorm

        pruned = {k: round(v, 3) for k, v in zip(keys, coeff) if abs(v) >= 0.02}
        if center:
            # restore the subtracted target mean into the intercept
            pruned["1"] = round(pruned.get("1", 0.0) + y_mean, 3)

        if not return_mse:
            return pruned

        # compute mse of the pruned model against the raw (uncentered) targets
        mse = 0.0
        # NOTE(review): `y` here shadows the centered target list above;
        # harmless only because y is not used after this loop.
        for row, y in zip(feature_rows, targets):
            y_hat = sum(pruned.get(k, 0.0) * row[k] for k in row)
            mse += (y - y_hat) ** 2
        mse /= len(feature_rows)
        return pruned, mse

    def _eval_model(self, model, s: State):
        """Evaluate a fitted coefficient dict at state s over the shared
        feature library (quadratic terms centered by the stored means)."""
        features = {
            "1": 1.0,
            "vx": s.vx,
            "vy": s.vy,
            "vx_abs_vx": (s.vx * abs(s.vx)) - self.feature_means["vx_abs_vx"],
            "vy_abs_vy": (s.vy * abs(s.vy)) - self.feature_means["vy_abs_vy"],
        }
        return sum(model.get(k, 0.0) * features[k] for k in model)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
def run_stress_test(
    episodes=50,
    steps=200,
    g=9.81,
    drag=0.12,
    wind=0.4,
    dt=0.1,
    drag_power=2.0,
    seed=123
):
    """Run the discovery loop against a hidden world and report what was learned.

    Simulates `episodes` random launches in a ProjectileWorld, feeds every
    transition to a CuriosityAgent, refits the symbolic model after each
    episode, and interprets the recovered constants as wind / gravity.

    Returns a dict with the true parameters, the learned models, the
    wind/gravity estimates (None when the quadratic terms were not
    recovered), the invented concepts, and surprise statistics.
    """
    random.seed(seed)  # deterministic launches for reproducibility
    world = ProjectileWorld(g=g, drag=drag, wind=wind, dt=dt, drag_power=drag_power)
    agent = CuriosityAgent(dt=dt)

    surprises = []
    for _ in range(episodes):
        # random launch: speed 8-20 m/s, angle 20-70 degrees
        speed = random.uniform(8, 20)
        angle = random.uniform(20, 70) * math.pi / 180.0
        s = State(
            x=0.0,
            y=0.0,
            vx=speed * math.cos(angle),
            vy=speed * math.sin(angle),
        )
        for _ in range(steps):
            s_next = world.step(s)
            surprise = agent.update(s, s_next)
            surprises.append(surprise)
            s = s_next
            if s.y < 0.0:
                # projectile hit the ground; end the episode
                break
        agent.fit_params()

    # Interpret constants as wind and gravity when quadratic terms exist.
    # The quadratic features are mean-centered during fitting, so the raw
    # intercept must be shifted back by coeff * feature_mean to recover the
    # physical bias term.
    wind_est = None
    g_est = None
    if "vx_abs_vx" in agent.model_ax and "1" in agent.model_ax:
        wind_est = agent.model_ax["1"] - agent.model_ax["vx_abs_vx"] * agent.feature_means["vx_abs_vx"]
    if "vy_abs_vy" in agent.model_ay and "1" in agent.model_ay:
        g_est = -(agent.model_ay["1"] - agent.model_ay["vy_abs_vy"] * agent.feature_means["vy_abs_vy"])

    # Robustness fix: with episodes=0 or steps=0 there are no transitions;
    # previously this divided by zero / called max() on an empty list.
    avg_surprise = round(sum(surprises) / len(surprises), 3) if surprises else 0.0
    max_surprise = round(max(surprises), 3) if surprises else 0.0

    return {
        "g_true": g,
        "drag_true": drag,
        "wind_true": wind,
        "drag_power_true": drag_power,
        "model_ax": agent.model_ax,
        "model_ay": agent.model_ay,
        "wind_est": round(wind_est, 3) if wind_est is not None else None,
        "g_est": round(g_est, 3) if g_est is not None else None,
        "invented": agent.invented,
        "avg_surprise": avg_surprise,
        "max_surprise": max_surprise,
        "samples": len(surprises),
    }
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
if __name__ == "__main__":
    # Script entry point: run the default scenario and dump the summary dict.
    print(run_stress_test())
|