ianalin123 Claude Sonnet 4.6 commited on
Commit
9e670bb
Β·
1 Parent(s): 9221fb1

fix: prevent numpy.bool serialization error in FastAPI responses

Browse files

numpy.bool_ leaks from check_kawasaki when sector angles are numpy
float64 (from np.arctan2), making total_violation < threshold return
numpy.bool_ which pydantic v2 can't serialize.

- engine/validation.py: cast bool() in check_kawasaki return
- server/app.py + openenv_server/app.py: add NumpyJSONResponse that
uses json.dumps with a custom default encoder for all numpy scalars
and arrays, applied to /targets, /episode/demo, /episode/replay

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>

Files changed (3) hide show
  1. engine/validation.py +1 -1
  2. openenv_server/app.py +27 -8
  3. server/app.py +32 -13
engine/validation.py CHANGED
@@ -97,7 +97,7 @@ def check_kawasaki(paper: Paper) -> tuple[bool, float]:
97
  violation = abs(even_sum - odd_sum)
98
  total_violation += violation
99
 
100
- is_valid = total_violation < 1e-4
101
  return is_valid, float(total_violation)
102
 
103
 
 
97
  violation = abs(even_sum - odd_sum)
98
  total_violation += violation
99
 
100
+ is_valid = bool(total_violation < 1e-4)
101
  return is_valid, float(total_violation)
102
 
103
 
openenv_server/app.py CHANGED
@@ -1,9 +1,28 @@
1
  from __future__ import annotations
2
 
 
3
  from pathlib import Path
4
 
5
- from fastapi.responses import HTMLResponse
 
6
  from fastapi.staticfiles import StaticFiles
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from openenv.core.env_server.http_server import create_app
8
 
9
  from openenv_runtime.environment import OpenEnvOrigamiEnvironment
@@ -60,8 +79,8 @@ DEMO_SEQUENCES: dict[str, list[dict]] = {
60
  # API routes β€” must be registered BEFORE the StaticFiles catch-all mount
61
  # ---------------------------------------------------------------------------
62
 
63
- @app.get("/targets", include_in_schema=True)
64
- def get_targets() -> dict:
65
  """Return available task names and metadata for the frontend."""
66
  from server.tasks import get_task_by_name, available_task_names
67
 
@@ -76,11 +95,11 @@ def get_targets() -> dict:
76
  "difficulty": t.get("difficulty", 1),
77
  "material": t.get("material", "paper"),
78
  }
79
- return result
80
 
81
 
82
- @app.get("/episode/demo", include_in_schema=True)
83
- def demo_episode(target: str = "half_fold") -> dict:
84
  """Return a pre-solved demo episode for the given task."""
85
  from server.origami_environment import OrigamiEnvironment
86
  from server.models import OrigamiAction as NewOrigamiAction
@@ -119,12 +138,12 @@ def demo_episode(target: str = "half_fold") -> dict:
119
 
120
  task_def = get_task_by_name(target) if target else {}
121
 
122
- return {
123
  "task_name": target,
124
  "task": task_def,
125
  "steps": steps,
126
  "final_metrics": obs.metrics if steps else {},
127
- }
128
 
129
 
130
  # ---------------------------------------------------------------------------
 
1
  from __future__ import annotations
2
 
3
+ import json
4
  from pathlib import Path
5
 
6
+ import numpy as np
7
+ from fastapi.responses import HTMLResponse, JSONResponse
8
  from fastapi.staticfiles import StaticFiles
9
+
10
+
11
+ def _np_default(obj):
12
+ if isinstance(obj, np.bool_):
13
+ return bool(obj)
14
+ if isinstance(obj, np.integer):
15
+ return int(obj)
16
+ if isinstance(obj, np.floating):
17
+ return float(obj)
18
+ if isinstance(obj, np.ndarray):
19
+ return obj.tolist()
20
+ raise TypeError(f"Not serializable: {type(obj)}")
21
+
22
+
23
+ class NumpyJSONResponse(JSONResponse):
24
+ def render(self, content) -> bytes:
25
+ return json.dumps(content, default=_np_default).encode("utf-8")
26
  from openenv.core.env_server.http_server import create_app
27
 
28
  from openenv_runtime.environment import OpenEnvOrigamiEnvironment
 
79
  # API routes β€” must be registered BEFORE the StaticFiles catch-all mount
80
  # ---------------------------------------------------------------------------
81
 
82
+ @app.get("/targets", include_in_schema=True, response_class=NumpyJSONResponse)
83
+ def get_targets():
84
  """Return available task names and metadata for the frontend."""
85
  from server.tasks import get_task_by_name, available_task_names
86
 
 
95
  "difficulty": t.get("difficulty", 1),
96
  "material": t.get("material", "paper"),
97
  }
98
+ return NumpyJSONResponse(result)
99
 
100
 
101
+ @app.get("/episode/demo", include_in_schema=True, response_class=NumpyJSONResponse)
102
+ def demo_episode(target: str = "half_fold"):
103
  """Return a pre-solved demo episode for the given task."""
104
  from server.origami_environment import OrigamiEnvironment
105
  from server.models import OrigamiAction as NewOrigamiAction
 
138
 
139
  task_def = get_task_by_name(target) if target else {}
140
 
141
+ return NumpyJSONResponse({
142
  "task_name": target,
143
  "task": task_def,
144
  "steps": steps,
145
  "final_metrics": obs.metrics if steps else {},
146
+ })
147
 
148
 
149
  # ---------------------------------------------------------------------------
server/app.py CHANGED
@@ -10,14 +10,33 @@ Usage in training:
10
  """
11
  from __future__ import annotations
12
 
 
13
  from pathlib import Path
14
 
 
15
  import uvicorn
16
  from fastapi import FastAPI, HTTPException, WebSocket
17
  from fastapi.middleware.cors import CORSMiddleware
18
- from fastapi.responses import HTMLResponse
19
  from fastapi.staticfiles import StaticFiles
20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  from server.training_broadcast import TrainingBroadcastServer
22
 
23
  app = FastAPI(title="Optigami Training Server", version="1.0")
@@ -59,10 +78,10 @@ def health() -> dict:
59
 
60
  # ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
61
 
62
- @app.get("/targets")
63
- def get_targets() -> dict:
64
  from server.tasks import available_task_names, get_task_by_name
65
- return {
66
  name: {
67
  "name": name,
68
  "level": t["difficulty"],
@@ -73,7 +92,7 @@ def get_targets() -> dict:
73
  }
74
  for name in available_task_names()
75
  if (t := get_task_by_name(name))
76
- }
77
 
78
 
79
  _DEMO_SEQUENCES: dict[str, list[dict]] = {
@@ -90,8 +109,8 @@ _DEMO_SEQUENCES: dict[str, list[dict]] = {
90
  }
91
 
92
 
93
- @app.get("/episode/demo")
94
- def demo_episode(target: str = "half_fold") -> dict:
95
  from server.origami_environment import OrigamiEnvironment
96
  from server.models import OrigamiAction as NewAction
97
  from server.tasks import get_task_by_name
@@ -114,23 +133,23 @@ def demo_episode(target: str = "half_fold") -> dict:
114
  if obs.done:
115
  break
116
 
117
- return {"task_name": target, "task": get_task_by_name(target) or {},
118
- "steps": steps, "final_metrics": obs.metrics if steps else {}}
119
 
120
 
121
- @app.get("/episode/replay/{ep_id}")
122
- def replay_episode(ep_id: str) -> dict:
123
  """Return a stored training episode in the same format as /episode/demo."""
124
  from server.tasks import get_task_by_name
125
  ep = broadcast._registry.get(ep_id)
126
  if not ep:
127
  raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")
128
- return {
129
  "task_name": ep.task_name,
130
  "task": get_task_by_name(ep.task_name) or {},
131
  "steps": ep.steps,
132
  "final_metrics": ep.final_metrics or (ep.steps[-1]["metrics"] if ep.steps else {}),
133
- }
134
 
135
 
136
  # ── Static files β€” viewer first, then React app (LAST, catch-all) ──
 
10
  """
11
  from __future__ import annotations
12
 
13
+ import json
14
  from pathlib import Path
15
 
16
+ import numpy as np
17
  import uvicorn
18
  from fastapi import FastAPI, HTTPException, WebSocket
19
  from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import HTMLResponse, JSONResponse
21
  from fastapi.staticfiles import StaticFiles
22
 
23
+
24
+ def _np_default(obj):
25
+ if isinstance(obj, np.bool_):
26
+ return bool(obj)
27
+ if isinstance(obj, np.integer):
28
+ return int(obj)
29
+ if isinstance(obj, np.floating):
30
+ return float(obj)
31
+ if isinstance(obj, np.ndarray):
32
+ return obj.tolist()
33
+ raise TypeError(f"Not serializable: {type(obj)}")
34
+
35
+
36
+ class NumpyJSONResponse(JSONResponse):
37
+ def render(self, content) -> bytes:
38
+ return json.dumps(content, default=_np_default).encode("utf-8")
39
+
40
  from server.training_broadcast import TrainingBroadcastServer
41
 
42
  app = FastAPI(title="Optigami Training Server", version="1.0")
 
78
 
79
  # ── Demo endpoints (same as openenv_server/app.py so the React UI works) ──
80
 
81
+ @app.get("/targets", response_class=NumpyJSONResponse)
82
+ def get_targets():
83
  from server.tasks import available_task_names, get_task_by_name
84
+ return NumpyJSONResponse({
85
  name: {
86
  "name": name,
87
  "level": t["difficulty"],
 
92
  }
93
  for name in available_task_names()
94
  if (t := get_task_by_name(name))
95
+ })
96
 
97
 
98
  _DEMO_SEQUENCES: dict[str, list[dict]] = {
 
109
  }
110
 
111
 
112
+ @app.get("/episode/demo", response_class=NumpyJSONResponse)
113
+ def demo_episode(target: str = "half_fold"):
114
  from server.origami_environment import OrigamiEnvironment
115
  from server.models import OrigamiAction as NewAction
116
  from server.tasks import get_task_by_name
 
133
  if obs.done:
134
  break
135
 
136
+ return NumpyJSONResponse({"task_name": target, "task": get_task_by_name(target) or {},
137
+ "steps": steps, "final_metrics": obs.metrics if steps else {}})
138
 
139
 
140
+ @app.get("/episode/replay/{ep_id}", response_class=NumpyJSONResponse)
141
+ def replay_episode(ep_id: str):
142
  """Return a stored training episode in the same format as /episode/demo."""
143
  from server.tasks import get_task_by_name
144
  ep = broadcast._registry.get(ep_id)
145
  if not ep:
146
  raise HTTPException(status_code=404, detail=f"Episode '{ep_id}' not found in registry")
147
+ return NumpyJSONResponse({
148
  "task_name": ep.task_name,
149
  "task": get_task_by_name(ep.task_name) or {},
150
  "steps": ep.steps,
151
  "final_metrics": ep.final_metrics or (ep.steps[-1]["metrics"] if ep.steps else {}),
152
+ })
153
 
154
 
155
  # ── Static files β€” viewer first, then React app (LAST, catch-all) ──