brian4dwell committed on
Commit 1c5aca1 · 1 Parent(s): 08c2845

initial worker working

.vscode/launch.json CHANGED
@@ -13,5 +13,17 @@
       "python": "/home/robot_op/miniconda3/envs/stream3r/bin/python",
       "justMyCode": true
     },
+    {
+      "name": "Python: STream3R Worker",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "stream3r.worker.main",
+      "args": ["--log-level", "INFO"],
+      "console": "integratedTerminal",
+      "cwd": "${workspaceFolder}",
+      "envFile": "${workspaceFolder}/.env",
+      "python": "/home/robot_op/miniconda3/envs/stream3r/bin/python",
+      "justMyCode": true
+    }
   ]
 }
design_docs/stream3r_api.md ADDED
@@ -0,0 +1,269 @@
1
+ # STream3R API — Job Orchestration and Integration Plan
2
+
3
+ ## Executive Summary
4
+
5
+ This document proposes a lightweight STream3R API service that wraps the async job system implemented by the RQ worker. The API exposes RESTful endpoints and Redis Stream subscriptions that allow upstream applications to submit reconstruction jobs, track progress, and retrieve completed artifacts. Responsibilities are split cleanly: the API handles request validation, persistence, and orchestration, while the GPU worker (documented in `design_docs/worker.md`) performs heavy inference and storage of artifacts.
6
+
7
+ **Key outcomes**
8
+ - Unified interface for both short `pose_pointmap` jobs and long-running `model_build` jobs.
9
+ - Consistent job lifecycle backed by Postgres and Redis, mirroring the worker contract.
10
+ - Environment-agnostic integration so other services can enqueue jobs or consume progress events without GPU access.
11
+
12
+ ---
13
+
14
+ ## 1. Scope & Goals
15
+
16
+ ### Goals
17
+ - Provide a REST API for submitting jobs and querying job status or results.
18
+ - Offer optional server-sent events (SSE) or WebSocket feeds for near-real-time updates on `pose_pointmap` jobs.
19
+ - Enforce validation and idempotency for job submissions.
20
+ - Expose artifacts written by the worker (S3 URLs, local paths) without re-hosting the files.
21
+
22
+ ### Non-Goals
23
+ - Implement GPU inference or artifact generation (handled by RQ worker).
24
+ - Manage long-term artifact retention or CDN delivery.
25
+ - Provide fine-grained authorization beyond bearer token / API key patterns (left to integration).
26
+
27
+ ---
28
+
29
+ ## 2. Architecture Overview
30
+
31
+ ```
32
+ Client ──HTTP──► STream3R API ──RQ Enqueue──► Redis Queue ──► Worker
+    │                  │                                          │
+    │                  ├── Postgres ◄──────────────────────────────┤
+    │                  ├── Redis Stream (events) ◄─────────────────┤
+    │                  └── S3 (artifact URLs) ◄─────────── Worker writes
+    │
+    └── Poll `/jobs/{id}` or subscribe SSE/WebSocket for progress updates
39
+ ```
40
+
41
+ Components:
42
+ - **FastAPI** service (recommended) running behind an ASGI server.
43
+ - **Redis**: shared with worker for queues and events.
44
+ - **Postgres**: `stream3r_jobs` table is the canonical job record.
45
+ - **S3/Backblaze** (or local storage): artifact URLs returned by worker.
46
+ - **RQ worker** (implemented separately) executing jobs and updating state.
47
+
48
+ ---
49
+
50
+ ## 3. Endpoints
51
+
52
+ ### `POST /jobs`
53
+ Submit a job for either `pose_pointmap` or `model_build`.
54
+
55
+ **Request body**
56
+ ```json
57
+ {
58
+ "job_type": "pose_pointmap",
59
+ "scene_id": "SCENE123",
60
+ "mode": "causal",
61
+ "streaming": true,
62
+ "frames": [
63
+ {"url": "https://.../frame_0000.jpg"},
64
+ {"path": "/data/captures/frame_0001.png"}
65
+ ],
66
+ "session_settings": {"prediction_mode": "Predicted Pointmap"},
67
+ "client_request_id": "optional-idempotency-key"
68
+ }
69
+ ```
70
+
71
+ **Behavior**
72
+ - Validate payload (non-empty frames, supported job type, etc.).
73
+ - If `client_request_id` is provided, search Postgres for an existing job with the same key to ensure idempotency.
74
+ - Assign `job_id` (UUID) and enqueue the payload into the appropriate RQ queue (`pose_pointmap` or `model_build`).
75
+ - Insert `stream3r_jobs` row with `status=queued`.
76
+ - Return `202 Accepted` with job metadata:
77
+
78
+ ```json
79
+ {
80
+ "job_id": "uuid",
81
+ "status": "queued",
82
+ "job_type": "pose_pointmap",
83
+ "scene_id": "SCENE123"
84
+ }
85
+ ```
86
+
87
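A minimal client-side sketch of this submission flow, assuming the service runs at `http://localhost:8000` behind bearer-token auth (both are placeholders); the payload mirrors the request contract above.

```python
import uuid
import requests

API_BASE = "http://localhost:8000"   # assumption: local deployment
TOKEN = "dev-token"                  # assumption: bearer-token auth

payload = {
    "job_type": "pose_pointmap",
    "scene_id": "SCENE123",
    "mode": "causal",
    "streaming": True,
    "frames": [{"url": "https://example.com/frame_0000.jpg"}],
    "session_settings": {"prediction_mode": "Predicted Pointmap"},
    # Reusing the same key on a retry lets the API return the existing job.
    "client_request_id": str(uuid.uuid4()),
}

resp = requests.post(
    f"{API_BASE}/jobs",
    json=payload,
    headers={"Authorization": f"Bearer {TOKEN}"},
    timeout=30,
)
resp.raise_for_status()              # expects 202 Accepted
job = resp.json()
print(job["job_id"], job["status"])  # e.g. "queued"
```
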
+ ### `GET /jobs/{job_id}`
88
+ Fetch job state and artifact references from Postgres.
89
+
90
+ **Response example** (`status=finished`)
91
+ ```json
92
+ {
93
+ "job_id": "uuid",
94
+ "job_type": "model_build",
95
+ "scene_id": "SCENE123",
96
+ "status": "finished",
97
+ "created_at": "...",
98
+ "started_at": "...",
99
+ "completed_at": "...",
100
+ "result": {
101
+ "result_url": "s3://bucket/scene/SCENE123/stream3r/models/summary.json",
102
+ "model_dir": "s3://bucket/scene/SCENE123/stream3r/models/",
103
+ "artifacts": {
104
+ "scene_glb_url": "...",
105
+ "poses_url": "...",
106
+ "pointmaps": [ {"frame_id": "frame_0000", "url": "..."} ]
107
+ }
108
+ },
109
+ "error": null
110
+ }
111
+ ```
112
+
113
+ ### `GET /jobs/{job_id}/events`
114
+ Server-Sent Events endpoint bridging Redis Streams.
115
+
116
+ - Uses `XREAD` on `stream3r:events` with `job_id` filter.
117
+ - Suitable for browser or gateway consumers needing near-real-time progress.
118
+ - Emits lines like:
119
+
120
+ ```
121
+ event: progress
122
+ data: {"progress": 60, "status": "progress"}
123
+
124
+ event: finished
125
+ data: {"result_url": "s3://..."}
126
+ ```
127
+
128
+ Optionally provide a WebSocket variant if SSE is insufficient.
129
+
130
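One possible way to implement this bridge with FastAPI and `redis.asyncio`; the filtering and `last_id` handling below are illustrative rather than a finished implementation.

```python
import json

import redis.asyncio as aioredis
from fastapi import FastAPI
from fastapi.responses import StreamingResponse

app = FastAPI()
r = aioredis.from_url("redis://localhost:6379/0", decode_responses=True)

@app.get("/jobs/{job_id}/events")
async def job_events(job_id: str):
    async def event_source():
        last_id = "$"  # only forward events published after the client connects
        while True:
            # Block up to 15s waiting for new entries on the worker's stream.
            batches = await r.xread({"stream3r:events": last_id}, block=15_000, count=100)
            for _stream, entries in batches:
                for entry_id, fields in entries:
                    last_id = entry_id
                    if fields.get("job_id") != job_id:
                        continue  # event belongs to another job
                    status = fields.get("status", "progress")
                    yield f"event: {status}\ndata: {json.dumps(fields)}\n\n"
                    if status in ("finished", "failed"):
                        return
    return StreamingResponse(event_source(), media_type="text/event-stream")
```
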
+ ### `GET /jobs`
131
+ Paged listing/filtering (optional but useful for dashboards).
132
+
133
+ Parameters: `scene_id`, `job_type`, `status`, pagination cursors.
134
+
135
+ ---
136
+
137
+ ## 4. Data Model & Persistence
138
+
139
+ ### Postgres (`stream3r_jobs`)
140
+
141
+ The API is the authoritative owner of the job record. It should create and migrate the following schema during startup (extend with `client_request_id` if idempotency keys are required):
142
+
143
+ ```sql
144
+ CREATE TABLE IF NOT EXISTS stream3r_jobs (
145
+ job_id UUID PRIMARY KEY,
146
+ job_type TEXT NOT NULL, -- 'pose_pointmap' | 'model_build'
147
+ scene_id TEXT NOT NULL,
148
+ status TEXT NOT NULL, -- 'queued' | 'started' | 'finished' | 'failed'
149
+ created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
150
+ started_at TIMESTAMPTZ,
151
+ completed_at TIMESTAMPTZ,
152
+ payload JSONB, -- enqueue-time payload
153
+ result JSONB, -- worker-published result bundle
154
+ error TEXT,
155
+ client_request_id TEXT UNIQUE -- optional idempotency key
156
+ );
157
+
158
+ CREATE INDEX IF NOT EXISTS stream3r_jobs_scene_id_idx ON stream3r_jobs(scene_id);
159
+ CREATE INDEX IF NOT EXISTS stream3r_jobs_status_idx ON stream3r_jobs(status);
160
+ ```
161
+
162
+ ### Redis
163
+
164
+ - **Queues**: two RQ queues exist—`pose_pointmap` for latency-sensitive pose extraction and `model_build` for long reconstruction jobs. The API selects the queue based on `job_type`.
165
+ - **Event Stream**: the worker pushes lifecycle updates to the Redis Stream `stream3r:events`. Every entry is a flat map with the following fields:
166
+
167
+ ```
168
+ job_id, job_type, scene_id,
169
+ status, progress, result_url, model_dir,
170
+ error, ts
171
+ ```
172
+
173
+ `status` takes `started`, `progress`, `finished`, or `failed`. `progress` is an integer percentage (0–100). `result_url` and `model_dir` mirror the URLs stored in Postgres when a job completes.
174
+
175
+ ### Artifact Storage Layout
176
+
177
+ The worker persists artifacts to S3/Backblaze (or local storage) under a deterministic folder hierarchy. The API does not move files but should surface these URLs verbatim so consumers know where to fetch results:
178
+
179
+ ```
180
+ scene/{scene_id}/stream3r/
+   results/
+     {job_id}.json            # per-job result JSON (pose_pointmap)
+   models/
+     kv_cache.pt              # serialized KV cache
+     predictions.npz          # packed model outputs
+     session_settings.json    # runtime/config settings
+     selected_frames.json     # frame subset indices (optional)
+     scene.glb                # fused 3D scene
+     poses.jsonl              # per-frame extrinsics
+     summary.json             # canonical model_build result JSON
+   pointmaps/
+     {frame_token}.npz        # per-frame world coords + confidence
193
+ ```
194
+
195
+ `pose_pointmap` jobs typically populate `results/{job_id}.json` plus `pointmaps/`; `model_build` jobs populate the `models/` subtree. All URLs returned by the worker use this structure.
196
+
197
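For reference, the worker composes these keys with the `StorageClient.build_key`/`build_uri` helpers added in `stream3r/worker/storage.py`; a short illustration (the job id below is a placeholder):

```python
from stream3r.worker.config import WorkerSettings
from stream3r.worker.storage import create_storage_client

settings = WorkerSettings.from_env()
storage = create_storage_client(settings)  # S3 client if a bucket is configured, else local storage

scene_id = "SCENE123"
job_id = "3f2c7e56-0000-4000-8000-000000000000"  # illustrative UUID

# results/{job_id}.json for a pose_pointmap job
result_key = storage.build_key(scene_id, settings.results_dir, f"{job_id}.json")
# models/summary.json for a model_build job
summary_key = storage.build_key(scene_id, settings.models_dir, settings.summary_results_filename)

print(storage.build_uri(result_key))
# e.g. s3://<bucket>/scene/SCENE123/stream3r/results/<job_id>.json (a local path for LocalStorageClient)
```
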
+ ---
198
+
199
+ ## 5. Job Lifecycle
200
+
201
+ 1. **Submit** (`POST /jobs`):
202
+ - Validate input, persist `queued` row, enqueue payload.
203
+ - Return `job_id`.
204
+ 2. **Worker processing**:
205
+ - Worker acquires GPU lock, runs inference, streams events, writes artifacts, and updates DB.
206
+ 3. **Status checks**:
207
+ - Clients poll `GET /jobs/{id}` or subscribe to `/jobs/{id}/events`.
208
+ 4. **Completion**:
209
+ - Job row contains `status=finished`, `result` JSON with URLs.
210
+ - API response is the source of truth for artifact discovery.
211
+ 5. **Failure**:
212
+ - Worker updates DB with `status=failed`, `error` string.
213
+ - API surfaces the error in `GET /jobs/{id}` and via events.
214
+
215
+ ---
216
+
217
+ ## 6. Request Validation Contracts
218
+
219
+ `POST /jobs` validation rules:
220
+ - `job_type` ∈ {`pose_pointmap`, `model_build`}.
221
+ - `scene_id` non-empty string.
222
+ - `frames` list size ≥ 1 (unless `frames_dir` provided).
223
+ - Each frame entry must have exactly one of `url`, `path`, or `content` (base64 image string).
224
+ - `mode` default `causal`; forbid `full` for streaming jobs to match worker behavior.
225
+ - Optional numeric fields converted to `int/float` before enqueueing.
226
+ - Enforce max frames (configurable) to avoid resource exhaustion.
227
+
228
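A sketch of these rules as a Pydantic v2 model; the class and field names are illustrative, not an existing module, and the frame limit stands in for a configurable value.

```python
from typing import Literal, Optional

from pydantic import BaseModel, Field, model_validator

MAX_FRAMES = 512  # illustrative limit; should come from configuration


class FrameRef(BaseModel):
    url: Optional[str] = None
    path: Optional[str] = None
    content: Optional[str] = None  # base64-encoded image

    @model_validator(mode="after")
    def exactly_one_source(self):
        provided = [v for v in (self.url, self.path, self.content) if v]
        if len(provided) != 1:
            raise ValueError("frame needs exactly one of url, path, or content")
        return self


class JobRequest(BaseModel):
    job_type: Literal["pose_pointmap", "model_build"]
    scene_id: str = Field(min_length=1)
    mode: str = "causal"
    streaming: bool = True
    frames: list[FrameRef] = Field(default_factory=list)
    frames_dir: Optional[str] = None
    session_settings: dict = Field(default_factory=dict)
    client_request_id: Optional[str] = None

    @model_validator(mode="after")
    def check_frames(self):
        if not self.frames and not self.frames_dir:
            raise ValueError("provide at least one frame or a frames_dir")
        if len(self.frames) > MAX_FRAMES:
            raise ValueError("too many frames for a single job")
        if self.streaming and self.mode == "full":
            raise ValueError("mode 'full' is not supported for streaming jobs")
        return self
```
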
+ ---
229
+
230
+ ## 7. Security & Authentication
231
+
232
+ - Deploy behind an API gateway that injects `X-Client-Id` or similar metadata for auditing.
233
+ - Support bearer token / API key auth via middleware; store hashed keys in Postgres if needed.
234
+ - Restrict access to the internal network where possible, since artifacts contain scene data.
235
+ - Sanitize inbound URLs to prevent SSRF; optionally proxy downloads through a whitelist.
236
+
237
+ ---
238
+
239
+ ## 8. Observability & Operations
240
+
241
+ - **Logging**: Structured logs capturing `job_id`, `scene_id`, `client_request_id`, and remote IP.
242
+ - **Metrics**: Track enqueue latency, job duration (from DB timestamps), queue depth, event lag.
243
+ - **Health checks**: `GET /healthz` verifying Redis and Postgres connectivity.
244
+ - **Backpressure**: Before accepting a new job, check queue length; if above threshold, return `429 Too Many Requests`.
245
+ - **Timeouts**: Configure HTTP request timeouts to avoid hanging on large payloads.
246
+
247
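As one way to implement the backpressure rule above, the enqueue handler can check RQ queue depth before accepting; the threshold below is illustrative (the task path exists in `stream3r/worker/tasks.py`).

```python
from fastapi import HTTPException
from redis import Redis
from rq import Queue

MAX_QUEUE_DEPTH = 50  # illustrative threshold; make it configurable


def enqueue_or_reject(redis_conn: Redis, queue_name: str, payload: dict) -> str:
    queue = Queue(queue_name, connection=redis_conn)
    if queue.count >= MAX_QUEUE_DEPTH:
        # Surface backpressure to the client instead of letting jobs pile up.
        raise HTTPException(status_code=429, detail="queue is full, retry later")
    job = queue.enqueue("stream3r.worker.tasks.pose_pointmap_job", payload)
    return job.id
```
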
+ ---
248
+
249
+ ## 9. Deployment Considerations
250
+
251
+ - Package as a standalone FastAPI app (e.g., `stream3r_api.main:app`).
252
+ - Run under Uvicorn/Gunicorn with workers sized for I/O-bound traffic.
253
+ - Configure the service with the same environment variables as the worker (`STREAM3R_REDIS_URL`, `STREAM3R_DB_DSN`, etc.).
254
+ - Use infrastructure-as-code to provision Redis, Postgres, and S3 credentials shared with worker.
255
+
256
+ ---
257
+
258
+ ## 10. Future Enhancements
259
+
260
+ - **Job cancellation**: Add `DELETE /jobs/{id}` to flag jobs for cancellation (requires worker support).
261
+ - **Scene-level dashboards**: Aggregate artifacts from multiple jobs for a scene.
262
+ - **Signed download URLs**: API could issue pre-signed URLs for public sharing, decoupled from worker credentials.
263
+ - **Batch submissions**: Support uploading a tar/zip and asynchronously unpacking/validating frames.
264
+
265
+ ---
266
+
267
+ **References**
268
+ - `design_docs/worker.md` — Worker design and artifact contracts
269
+ - `stream3r/worker/tasks.py` — Concrete payload fields exchanged with the API
design_docs/worker.md ADDED
@@ -0,0 +1,300 @@
1
+ # STream3R — Jobs, Events, and Storage Design
2
+
3
+ ## **Executive Summary**
4
+
5
+ The **STream3R Job System** provides an asynchronous GPU job orchestration layer for 3D scene reconstruction and perception tasks.
6
+ It standardizes how **pose and world-coordinate extraction** and **scene model building** are executed, stored, and tracked across services.
7
+
8
+ ### **Primary Goals**
9
+
10
+ - **Asynchronous GPU processing:**
11
+ All heavy inference runs on background RQ workers; FastAPI services only enqueue jobs and monitor progress.
12
+ - **Unified observability:**
13
+ - **Redis Streams** for job lifecycle events and progress (`stream3r:events`)
14
+ - **Postgres (`stream3r_jobs`)** as the canonical job record
15
+ - **S3/Backblaze** for durable artifacts and results
16
+ - **Two calling modes:**
17
+ - `get_pose_and_world_coords` → **Streams-based (Option A)** for near-real-time updates
18
+ - `create_model` → **Polling (Option B)** for long-running model generation
19
+ - **Consistent storage under** `/scene/{scene_id}/stream3r/`, containing:
20
+ - `kv_cache.pt` — serialized key/value cache state
21
+ - `predictions.npz` — packed outputs from the model build
22
+ - `session_settings.json` — runtime/config parameters
23
+ - `selected_frames.json` — frame subset selection
24
+ - `scene.glb` — final assembled scene model
25
+ - `poses.jsonl` — per-frame extrinsics (camera poses)
26
+ - `pointmaps/*.npz` — per-frame world coordinates + confidence maps
27
+
28
+ ### **Key Outcomes**
29
+ - Clean separation of API ↔ GPU worker responsibilities
30
+ - Event-driven feedback for quick jobs; reliable polling for long ones
31
+ - Durable, versioned scene data under a unified layout
32
+ - End-to-end traceability of all STream3R jobs via Redis + Postgres + S3
33
+
34
+ ---
35
+
36
+ ## 1. Queues, Streams, and Locks
37
+
38
+ | Component | Purpose | Notes |
39
+ |------------|----------|-------|
40
+ | `pose_pointmap` | RQ queue for latency-sensitive `pose_pointmap` jobs | |
41
+ | `model_build` | RQ queue for long `model_build` jobs | |
42
+ | `stream3r:events` | Redis Stream for all job events (`started`, `progress`, `finished`, `failed`) | trimmed periodically |
43
+ | `gpu:lock` | Redis lock ensuring single GPU job at a time per machine | |
44
+
45
+ Each Stream event is a flat map of strings:
46
+ ```
47
+
48
+ job_id, job_type, scene_id,
49
+ status, progress, result_url, model_dir,
50
+ error, ts
51
+
52
+ ```
53
+
54
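For reference, the worker publishes these entries with `XADD`; the sketch below mirrors `WorkerRuntime.emit_event` in `stream3r/worker/runtime.py`, with the default stream name and MAXLEN spelled out.

```python
import time

import redis

r = redis.Redis.from_url("redis://localhost:6379/0")


def emit_event(payload: dict, stream: str = "stream3r:events", maxlen: int = 50_000) -> None:
    # Drop None values and stringify everything: Stream fields are flat strings.
    data = {k: str(v) for k, v in payload.items() if v is not None}
    r.xadd(stream, data, maxlen=maxlen, approximate=True)


emit_event({
    "job_id": "uuid",
    "job_type": "pose_pointmap",
    "scene_id": "SCENE123",
    "status": "progress",
    "progress": 40,
    "ts": time.time(),
})
```
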
+ ---
55
+
56
+ ## 2. S3 / Backblaze Storage Layout
57
+
58
+ All STream3R artifacts live under a **scene folder**:
59
+
60
+ ```
61
+ s3://<bucket>/scene/{scene_id}/stream3r/
+   results/
+     {job_id}.json            # per-job result JSON (pose_pointmap)
+   models/
+     kv_cache.pt              # serialized KV cache
+     predictions.npz          # packed model outputs
+     session_settings.json    # runtime/config settings
+     selected_frames.json     # frame subset indices
+     scene.glb                # fused 3D scene
+     poses.jsonl              # per-frame extrinsics
+     summary.json             # canonical model_build result JSON
+   pointmaps/
+     {frame_token}.npz        # per-frame world_coords + confidence
+ ```
77
+
78
+ **Key Result URLs**
79
+ - Pose/pointmap job → `s3://.../scene/{scene_id}/stream3r/results/{job_id}.json`
80
+ - Model build job → `s3://.../scene/{scene_id}/stream3r/models/summary.json`
81
+
82
+ ---
83
+
84
+ ## 3. Database: `stream3r_jobs`
85
+
86
+ Canonical job table in Postgres.
87
+
88
+ ```sql
89
+ CREATE TABLE IF NOT EXISTS stream3r_jobs (
90
+ job_id UUID PRIMARY KEY,
91
+ job_type TEXT NOT NULL, -- 'pose_pointmap' | 'model_build'
92
+ scene_id TEXT NOT NULL,
93
+ status TEXT NOT NULL, -- 'queued' | 'started' | 'finished' | 'failed'
94
+ created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
95
+ started_at TIMESTAMPTZ,
96
+ completed_at TIMESTAMPTZ,
97
+ payload JSONB, -- enqueue-time payload
98
+ result JSONB, -- URLs / metrics
99
+ error TEXT
100
+ );
101
+
102
+ CREATE INDEX IF NOT EXISTS stream3r_jobs_scene_id_idx ON stream3r_jobs(scene_id);
103
+ CREATE INDEX IF NOT EXISTS stream3r_jobs_status_idx ON stream3r_jobs(status);
104
+ ```
105
+
106
+ **Upsert pattern:**
107
+
108
+ * Insert on enqueue (`queued`)
109
+ * Update on start → `started`
110
+ * Update on finish → `finished`, add `result`
111
+ * Update on failure → `failed`, add `error`
112
+
113
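`stream3r/worker/db.py` implements these transitions as a single `INSERT ... ON CONFLICT` statement; a condensed version of that upsert (some updated columns are omitted for brevity):

```python
import psycopg2
from psycopg2.extras import Json

UPSERT_SQL = """
INSERT INTO stream3r_jobs (job_id, job_type, scene_id, status, payload, result, error,
                           started_at, completed_at)
VALUES (%s, %s, %s, %s, %s, %s, %s,
        CASE WHEN %s = 'started' THEN now() ELSE NULL END,
        CASE WHEN %s IN ('finished', 'failed') THEN now() ELSE NULL END)
ON CONFLICT (job_id) DO UPDATE SET
    status       = EXCLUDED.status,
    result       = COALESCE(EXCLUDED.result, stream3r_jobs.result),
    error        = EXCLUDED.error,
    started_at   = COALESCE(stream3r_jobs.started_at, EXCLUDED.started_at),
    completed_at = COALESCE(stream3r_jobs.completed_at, EXCLUDED.completed_at);
"""


def upsert_job(dsn: str, job_id: str, job_type: str, scene_id: str, status: str,
               payload: dict | None = None, result: dict | None = None,
               error: str | None = None) -> None:
    # Connection context manager commits the transaction on successful exit.
    with psycopg2.connect(dsn) as conn, conn.cursor() as cur:
        cur.execute(UPSERT_SQL, (
            job_id, job_type, scene_id, status,
            Json(payload) if payload else None,
            Json(result) if result else None,
            error, status, status,
        ))
```
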
+ ---
114
+
115
+ ## 4. Result JSON Schemas
116
+
117
+ ### a. Pose + World Coords (per-frame)
118
+
119
+ `s3://…/scene/{scene_id}/stream3r/results/{job_id}.json`
120
+
121
+ ```json
122
+ {
123
+ "job_id": "uuid",
124
+ "job_type": "pose_pointmap",
125
+ "scene_id": "SCENE123",
126
+ "artifacts": {
127
+ "pointmap_url": "s3://.../scene/SCENE123/stream3r/pointmaps/frame_000010.npz"
128
+ },
129
+ "pose": { "R": [[...]], "t": [x, y, z] },
130
+ "intrinsics": { "fx":..., "fy":..., "cx":..., "cy":... },
131
+ "metrics": { "runtime_s": 1.23 },
132
+ "stream3r": { "cfg": "configs/stream3r_base.yaml", "commit": "<git_sha>" }
133
+ }
134
+ ```
135
+
136
+ ### b. Model Build (scene-level)
137
+
138
+ `s3://…/scene/{scene_id}/stream3r/models/summary.json`
139
+
140
+ ```json
141
+ {
142
+ "job_id": "uuid",
143
+ "job_type": "model_build",
144
+ "scene_id": "SCENE123",
145
+ "artifacts": {
146
+ "model_dir": "s3://.../scene/SCENE123/stream3r/models/",
147
+ "kv_cache": "s3://.../scene/SCENE123/stream3r/models/kv_cache.pt",
148
+ "predictions": "s3://.../scene/SCENE123/stream3r/models/predictions.npz",
149
+ "session_settings": "s3://.../scene/SCENE123/stream3r/models/session_settings.json",
150
+ "selected_frames": "s3://.../scene/SCENE123/stream3r/models/selected_frames.json",
151
+ "scene_glb": "s3://.../scene/SCENE123/stream3r/models/scene.glb",
152
+ "poses_jsonl": "s3://.../scene/SCENE123/stream3r/models/poses.jsonl"
153
+ },
154
+ "metrics": { "frames": 128, "runtime_s": 42.3 },
155
+ "stream3r": { "cfg": "configs/stream3r_base.yaml", "commit": "<git_sha>" }
156
+ }
157
+ ```
158
+
159
+ ---
160
+
161
+ ## 5. Caller API Responsibilities
162
+
163
+ ### `get_pose_and_world_coords` → **Option A (Streams)**
164
+
165
+ 1. Enqueue job → get `job_id`
166
+ 2. `XREAD BLOCK` on `stream3r:events` until `status=finished`
167
+ 3. On finish:
168
+
169
+ * Fetch `result_url`
170
+ * Load JSON → retrieve `pose`, `intrinsics`, and `pointmap_url`
171
+ * Download `.npz` to get `world_coords` + `confidence`
172
+
173
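A minimal consumer sketch for Option A, assuming the caller can reach the same Redis instance; the artifact download step is left out.

```python
import redis

r = redis.Redis.from_url("redis://localhost:6379/0", decode_responses=True)


def wait_for_result(job_id: str, stream: str = "stream3r:events", timeout_ms: int = 10_000) -> dict:
    """Block on the event stream until the given job finishes or fails."""
    last_id = "$"  # only consider events emitted after we start waiting
    while True:
        batches = r.xread({stream: last_id}, block=timeout_ms, count=100)
        for _name, entries in batches:
            for entry_id, fields in entries:
                last_id = entry_id
                if fields.get("job_id") != job_id:
                    continue
                if fields.get("status") == "finished":
                    return fields  # contains result_url (and model_dir for model_build)
                if fields.get("status") == "failed":
                    raise RuntimeError(fields.get("error", "job failed"))


# result = wait_for_result("your-job-id")
# result["result_url"] points at the JSON with pose, intrinsics, and pointmap_url
```
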
+ ### `create_model` → **Option B (Polling)**
174
+
175
+ 1. Enqueue job → return `job_id` immediately
176
+ 2. Periodically poll `GET /jobs/{job_id}`
177
+ 3. On `finished`:
178
+
179
+ * Read `result` with `result_url` + `model_dir`
180
+ * Download `summary.json` and listed model files
181
+
182
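A polling sketch for Option B against the REST API described in `design_docs/stream3r_api.md`; the base URL and intervals are placeholders.

```python
import time

import requests

API_BASE = "http://localhost:8000"  # assumption: API location


def wait_for_model(job_id: str, poll_interval_s: float = 10.0, timeout_s: float = 3600.0) -> dict:
    """Poll GET /jobs/{job_id} until the job reaches a terminal state."""
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        job = requests.get(f"{API_BASE}/jobs/{job_id}", timeout=30).json()
        if job["status"] == "finished":
            return job["result"]  # includes result_url and model_dir
        if job["status"] == "failed":
            raise RuntimeError(job.get("error") or "model build failed")
        time.sleep(poll_interval_s)
    raise TimeoutError(f"job {job_id} did not finish within {timeout_s}s")
```
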
+ ---
183
+
184
+ ## 6. Worker Event & Persistence Flow
185
+
186
+ 1. **Acquire GPU lock**
187
+ 2. **Emit** `started`
188
+ 3. **Upsert** DB row (`stream3r_jobs`)
189
+ 4. **Run inference**, emitting `progress` events (every N frames)
190
+ 5. **Save** artifacts to S3:
191
+
192
+ * `pointmaps/*.npz` with `{xyz, conf}`
193
+ * `poses.jsonl`
194
+ * Model outputs listed above
195
+ 6. **Write** result JSON → emit `finished`
196
+ 7. **Update** DB row → `status=finished, result=…`
197
+ 8. On error → emit `failed`, update DB
198
+
199
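A condensed sketch of this sequence using the helpers added in this commit (`WorkerRuntime.gpu_lock`, `emit_event`, and the database client); inference and artifact upload details are elided.

```python
import time

from stream3r.worker.runtime import get_runtime


def run_job(job_id: str, job_type: str, scene_id: str, payload: dict) -> None:
    runtime = get_runtime()
    base_event = {"job_id": job_id, "job_type": job_type, "scene_id": scene_id}

    with runtime.gpu_lock():                                        # 1. serialize GPU access
        runtime.emit_event({**base_event, "status": "started", "progress": 1, "ts": time.time()})
        runtime.db.upsert_job(job_id=job_id, job_type=job_type, scene_id=scene_id,
                              status="started", payload=payload)    # 2-3. event + DB row
        try:
            # 4-5. run inference, save artifacts, emit progress events (details elided)
            result = {"result_url": "...", "model_dir": "..."}       # placeholders
            runtime.emit_event({**base_event, "status": "finished", "progress": 100,
                                "ts": time.time(), **result})        # 6. publish completion
            runtime.db.upsert_job(job_id=job_id, job_type=job_type, scene_id=scene_id,
                                  status="finished", result=result)  # 7. persist result
        except Exception as exc:                                     # 8. failure path
            runtime.emit_event({**base_event, "status": "failed", "error": str(exc),
                                "ts": time.time()})
            runtime.db.upsert_job(job_id=job_id, job_type=job_type, scene_id=scene_id,
                                  status="failed", error=str(exc))
            raise
```
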
+ ---
200
+
201
+ ## 7. Example Event Payloads (Redis Stream)
202
+
203
+ **Started**
204
+
205
+ ```
206
+ job_id=uuid
207
+ job_type=pose_pointmap
208
+ scene_id=SCENE123
209
+ status=started
210
+ progress=1
211
+ ts=1730312345.12
212
+ ```
213
+
214
+ **Progress**
215
+
216
+ ```
217
+ job_id=uuid
218
+ job_type=model_build
219
+ scene_id=SCENE123
220
+ status=progress
221
+ progress=40
222
+ ts=1730312456.22
223
+ ```
224
+
225
+ **Finished**
226
+
227
+ ```
228
+ job_id=uuid
229
+ job_type=model_build
230
+ scene_id=SCENE123
231
+ status=finished
232
+ progress=100
233
+ result_url=s3://bucket/scene/SCENE123/stream3r/models/summary.json
234
+ model_dir=s3://bucket/scene/SCENE123/stream3r/models/
235
+ ts=1730312567.33
236
+ ```
237
+
238
+ **Failed**
239
+
240
+ ```
241
+ job_id=uuid
242
+ job_type=pose_pointmap
243
+ scene_id=SCENE123
244
+ status=failed
245
+ error=RuntimeError: CUDA OOM
246
+ ts=1730312570.00
247
+ ```
248
+
249
+ ---
250
+
251
+ ## 8. Operational Guidelines
252
+
253
+ | Concern | Best Practice |
254
+ | -------------------------- | --------------------------------------------------------------- |
255
+ | **GPU Safety** | Use `gpu:lock` to serialize jobs per GPU |
256
+ | **Redis Stream retention** | `XTRIM stream3r:events MAXLEN ~50000` |
257
+ | **Durability** | All artifacts and summaries must persist to S3/Backblaze |
258
+ | **DB Reliability** | Upsert on each transition; retry writes if DB unavailable |
259
+ | **Idempotency** | Support caller-supplied `job_id` or `request_id` |
260
+ | **Security** | Keep Redis internal; use signed or private S3 URLs |
261
+ | **Backpressure** | Enqueueing API should reject (`429`) when queue depth too large |
262
+
263
+ ---
264
+
265
+ ## 9. End-to-End Flows
266
+
267
+ ### 🔹 Pose + World Coords (short job)
268
+
269
+ 1. API enqueues job → returns `job_id`
270
+ 2. Client subscribes via Redis Stream (blocking XREAD)
271
+ 3. Worker runs inference → writes `pointmap.npz` + `result.json`
272
+ 4. Worker emits `finished` → client downloads results
273
+
274
+ ### 🔹 Model Build (long job)
275
+
276
+ 1. API enqueues → returns `job_id`
277
+ 2. Client polls `GET /jobs/{id}` or DB row
278
+ 3. Worker fuses frames → writes full scene model files
279
+ 4. Worker updates DB + emits `finished`
280
+ 5. Client retrieves `summary.json` + artifacts under `/scene/{scene_id}/stream3r/models/`
281
+
282
+ ---
283
+
284
+ ## 10. Summary
285
+
286
+ | Component | Responsibility | Persistence |
287
+ | ------------------------------ | --------------------------------------- | ------------------------------- |
288
+ | **FastAPI API** | Enqueue jobs, expose `/jobs/{id}` | DB (via worker), Redis (events) |
289
+ | **GPU Worker** | Execute STream3R inference, emit events | S3/Backblaze, DB |
290
+ | **Redis Streams** | Event bus for progress + completion | ephemeral |
291
+ | **Postgres (`stream3r_jobs`)** | Canonical job record | durable |
292
+ | **S3/Backblaze /scene/** | Scene artifacts, model data | durable |
293
+
294
+ ---
295
+
296
+ **Outcome:**
297
+ This design provides an **asynchronous, event-driven, and durable** framework for managing STream3R GPU jobs, with standardized scene storage, traceable job metadata, and clear integration points for both real-time and long-running workflows.
298
+
requirements.txt CHANGED
@@ -43,10 +43,17 @@ pyglet<2
 huggingface-hub[torch]>=0.22
 spaces
 
+# --------- worker --------- #
+redis
+rq
+boto3
+psycopg2-binary
+requests
+
 # --------- eval --------- #
 accelerate
 evo
 
 # --------- demo --------- #
 gradio==5.17.1
-onnxruntime
+onnxruntime
stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc CHANGED
Binary files a/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc differ
 
stream3r/models/components/utils/geometry.py CHANGED
@@ -32,8 +32,14 @@ def unproject_depth_map_to_point_map(
 
     world_points_list = []
     for frame_idx in range(depth_map.shape[0]):
+        intrinsic = intrinsics_cam[frame_idx]
+        if intrinsic.shape[-2:] != (3, 3):
+            intrinsic = intrinsic.reshape(-1, 3, 3)[0]
+        extrinsic = extrinsics_cam[frame_idx]
+        if extrinsic.shape[-2:] != (3, 4):
+            extrinsic = extrinsic.reshape(-1, 3, 4)[0]
         cur_world_points, _, _ = depth_to_world_coords_points(
-            depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
+            depth_map[frame_idx].squeeze(-1), extrinsic, intrinsic
         )
         world_points_list.append(cur_world_points)
     world_points_array = np.stack(world_points_list, axis=0)
stream3r/worker/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ """Worker utilities for running STream3R jobs via RQ."""
2
+
3
+ from .tasks import model_build_job, pose_pointmap_job
4
+
5
+ __all__ = [
6
+ "pose_pointmap_job",
7
+ "model_build_job",
8
+ ]
stream3r/worker/config.py ADDED
@@ -0,0 +1,213 @@
1
+ """Configuration helpers for the STream3R RQ worker."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ import os
7
+ from pathlib import Path
8
+
9
+
10
+ def _env_bool(name: str, default: bool) -> bool:
11
+ value = os.getenv(name)
12
+ if value is None:
13
+ return default
14
+ value = value.strip().lower()
15
+ if value in {"1", "true", "yes", "y", "on"}:
16
+ return True
17
+ if value in {"0", "false", "no", "n", "off"}:
18
+ return False
19
+ return default
20
+
21
+
22
+ def _env_bool_any(default: bool, *names: str) -> bool:
23
+ for name in names:
24
+ if not name:
25
+ continue
26
+ value = os.getenv(name)
27
+ if value is None:
28
+ continue
29
+ value = value.strip().lower()
30
+ if value in {"1", "true", "yes", "y", "on"}:
31
+ return True
32
+ if value in {"0", "false", "no", "n", "off"}:
33
+ return False
34
+ return default
35
+
36
+
37
+ def _env_int(name: str, default: int) -> int:
38
+ value = os.getenv(name)
39
+ if value is None:
40
+ return default
41
+ try:
42
+ return int(value)
43
+ except ValueError:
44
+ return default
45
+
46
+
47
+ def _env_value(primary: str, *aliases: str, default: str | None = None) -> str | None:
48
+ """Best-effort environment lookup with fallbacks for legacy names."""
49
+
50
+ names = (primary, *aliases)
51
+ for name in names:
52
+ if not name:
53
+ continue
54
+ value = os.getenv(name)
55
+ if value:
56
+ return value
57
+ return default
58
+
59
+
60
+ @dataclass(slots=True)
61
+ class WorkerSettings:
62
+ """Runtime configuration derived from environment variables."""
63
+
64
+ redis_url: str = "redis://localhost:6379/0"
65
+ redis_events_stream: str = "stream3r:events"
66
+ redis_stream_maxlen: int = 50_000
67
+ redis_healthcheck_interval: int = 30
68
+
69
+ pose_queue: str = "pose_pointmap"
70
+ model_queue: str = "model_build"
71
+
72
+ gpu_lock_key: str = "gpu:lock"
73
+ gpu_lock_timeout: int = 3600
74
+ gpu_lock_blocking_timeout: int = 600
75
+
76
+ storage_prefix: str = "scene"
77
+ s3_bucket: str | None = None
78
+ s3_endpoint_url: str | None = None
79
+ s3_region: str | None = None
80
+ s3_profile: str | None = None
81
+ s3_force_path_style: bool = False
82
+ aws_access_key_id: str | None = None
83
+ aws_secret_access_key: str | None = None
84
+ aws_session_token: str | None = None
85
+ local_storage_root: Path = field(default_factory=lambda: Path("storage"))
86
+
87
+ db_dsn: str | None = None
88
+ job_table: str = "stream3r_jobs"
89
+
90
+ model_id: str = "yslan/STream3R"
91
+ model_revision: str | None = None
92
+ model_device_preference: str | None = None
93
+ model_dtype: str | None = None
94
+
95
+ default_mode: str = "window"
96
+ default_streaming: bool = True
97
+ download_workers: int = 4
98
+
99
+ worker_name: str | None = None
100
+
101
+ session_cache_filename: str = "kv_cache.pt"
102
+ predictions_filename: str = "predictions.npz"
103
+ poses_filename: str = "poses.jsonl"
104
+ result_filename: str = "summary.json"
105
+ selected_frames_filename: str = "selected_frames.json"
106
+ scene_glb_filename: str = "scene.glb"
107
+ session_settings_filename: str = "session_settings.json"
108
+ summary_results_filename: str = "summary.json"
109
+
110
+ pointmap_dir: str = "pointmaps"
111
+ models_dir: str = "models"
112
+ results_dir: str = "results"
113
+
114
+ scene_media_api_base_url: str | None = None
115
+ scene_media_api_token: str | None = None
116
+ scene_media_page_size: int = 200
117
+ stream_window_size: int = 10
118
+ max_frames_per_job: int = 10
119
+
120
+ @classmethod
121
+ def from_env(cls) -> "WorkerSettings":
122
+ base = cls()
123
+
124
+ kwargs: dict[str, object] = {
125
+ "redis_url": _env_value("STREAM3R_REDIS_URL", "REDIS_URL", default=base.redis_url),
126
+ "redis_events_stream": os.getenv("STREAM3R_EVENTS_STREAM", base.redis_events_stream),
127
+ "redis_stream_maxlen": _env_int("STREAM3R_EVENTS_MAXLEN", base.redis_stream_maxlen),
128
+ "redis_healthcheck_interval": _env_int(
129
+ "STREAM3R_REDIS_HEALTHCHECK", base.redis_healthcheck_interval
130
+ ),
131
+ "pose_queue": os.getenv("STREAM3R_QUEUE_POSE", base.pose_queue),
132
+ "model_queue": os.getenv("STREAM3R_QUEUE_MODEL", base.model_queue),
133
+ "gpu_lock_key": os.getenv("STREAM3R_GPU_LOCK_KEY", base.gpu_lock_key),
134
+ "gpu_lock_timeout": _env_int("STREAM3R_GPU_LOCK_TIMEOUT", base.gpu_lock_timeout),
135
+ "gpu_lock_blocking_timeout": _env_int(
136
+ "STREAM3R_GPU_LOCK_BLOCK", base.gpu_lock_blocking_timeout
137
+ ),
138
+ "storage_prefix": os.getenv("STREAM3R_STORAGE_PREFIX", base.storage_prefix),
139
+ "s3_bucket": _env_value("STREAM3R_STORAGE_BUCKET", "S3_BUCKET", default=base.s3_bucket) or None,
140
+ "s3_endpoint_url": _env_value("STREAM3R_S3_ENDPOINT", "S3_ENDPOINT", default=base.s3_endpoint_url) or None,
141
+ "s3_region": _env_value("STREAM3R_S3_REGION", "AWS_REGION", default=base.s3_region) or None,
142
+ "s3_profile": os.getenv("STREAM3R_S3_PROFILE", base.s3_profile or "") or None,
143
+ "s3_force_path_style": _env_bool_any(
144
+ base.s3_force_path_style,
145
+ "STREAM3R_S3_FORCE_PATH",
146
+ "S3_FORCE_PATH_STYLE",
147
+ ),
148
+ "aws_access_key_id": _env_value("AWS_ACCESS_KEY_ID", default=base.aws_access_key_id) or None,
149
+ "aws_secret_access_key": _env_value("AWS_SECRET_ACCESS_KEY", default=base.aws_secret_access_key) or None,
150
+ "aws_session_token": _env_value("AWS_SESSION_TOKEN", default=base.aws_session_token) or None,
151
+ "local_storage_root": Path(
152
+ os.getenv("STREAM3R_LOCAL_STORAGE", str(base.local_storage_root))
153
+ ).resolve(),
154
+ "db_dsn": _env_value("STREAM3R_DB_DSN", "DATABASE_URL", default=base.db_dsn) or None,
155
+ "job_table": os.getenv("STREAM3R_JOBS_TABLE", base.job_table),
156
+ "model_id": os.getenv("STREAM3R_MODEL_ID", base.model_id),
157
+ "model_revision": os.getenv("STREAM3R_MODEL_REVISION", base.model_revision or "") or None,
158
+ "model_device_preference": os.getenv(
159
+ "STREAM3R_MODEL_DEVICE", base.model_device_preference or ""
160
+ )
161
+ or None,
162
+ "model_dtype": os.getenv("STREAM3R_MODEL_DTYPE", base.model_dtype or "") or None,
163
+ "default_mode": os.getenv("STREAM3R_DEFAULT_MODE", base.default_mode),
164
+ "default_streaming": _env_bool(
165
+ "STREAM3R_DEFAULT_STREAMING", base.default_streaming
166
+ ),
167
+ "download_workers": _env_int("STREAM3R_DOWNLOAD_WORKERS", base.download_workers),
168
+ "worker_name": os.getenv("STREAM3R_WORKER_NAME", base.worker_name or "") or None,
169
+ "session_cache_filename": os.getenv(
170
+ "STREAM3R_SESSION_CACHE", base.session_cache_filename
171
+ ),
172
+ "predictions_filename": os.getenv(
173
+ "STREAM3R_PREDICTIONS_FILE", base.predictions_filename
174
+ ),
175
+ "poses_filename": os.getenv("STREAM3R_POSES_FILE", base.poses_filename),
176
+ "result_filename": os.getenv("STREAM3R_RESULT_FILE", base.result_filename),
177
+ "selected_frames_filename": os.getenv(
178
+ "STREAM3R_SELECTED_FRAMES_FILE", base.selected_frames_filename
179
+ ),
180
+ "scene_glb_filename": os.getenv("STREAM3R_SCENE_GLB_FILE", base.scene_glb_filename),
181
+ "session_settings_filename": os.getenv(
182
+ "STREAM3R_SESSION_SETTINGS_FILE", base.session_settings_filename
183
+ ),
184
+ "summary_results_filename": os.getenv(
185
+ "STREAM3R_SUMMARY_FILE", base.summary_results_filename
186
+ ),
187
+ "pointmap_dir": os.getenv("STREAM3R_POINTMAP_DIR", base.pointmap_dir),
188
+ "models_dir": os.getenv("STREAM3R_MODELS_DIR", base.models_dir),
189
+ "results_dir": os.getenv("STREAM3R_RESULTS_DIR", base.results_dir),
190
+ "scene_media_api_base_url": _env_value(
191
+ "STREAM3R_MEDIA_API_BASE_URL",
192
+ "API_BASE_URL",
193
+ default=base.scene_media_api_base_url,
194
+ )
195
+ or None,
196
+ "scene_media_api_token": _env_value(
197
+ "STREAM3R_MEDIA_API_TOKEN",
198
+ "MEDIA_API_TOKEN",
199
+ default=base.scene_media_api_token,
200
+ )
201
+ or None,
202
+ "scene_media_page_size": _env_int(
203
+ "STREAM3R_MEDIA_PAGE_SIZE", base.scene_media_page_size
204
+ ),
205
+ "stream_window_size": _env_int(
206
+ "STREAM3R_WINDOW_SIZE", base.stream_window_size
207
+ ),
208
+ "max_frames_per_job": _env_int(
209
+ "STREAM3R_MAX_FRAMES", base.max_frames_per_job
210
+ ),
211
+ }
212
+
213
+ return cls(**kwargs)
stream3r/worker/db.py ADDED
@@ -0,0 +1,170 @@
1
+ """Database helpers for persisting job metadata."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import logging
7
+ from contextlib import contextmanager
8
+ from typing import Iterator, Mapping
9
+
10
+ try: # Optional dependency
11
+ import psycopg2
12
+ from psycopg2.extensions import connection as PGConnection
13
+ from psycopg2.extras import Json
14
+ except ModuleNotFoundError: # pragma: no cover - exercised when psycopg2 missing
15
+ psycopg2 = None # type: ignore[assignment]
16
+ PGConnection = None # type: ignore[assignment]
17
+ Json = None # type: ignore[assignment]
18
+
19
+ from .config import WorkerSettings
20
+
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
+ SCHEMA_SQL = """
26
+ CREATE TABLE IF NOT EXISTS {table_name} (
27
+ job_id UUID PRIMARY KEY,
28
+ job_type TEXT NOT NULL,
29
+ scene_id TEXT NOT NULL,
30
+ status TEXT NOT NULL,
31
+ created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
32
+ started_at TIMESTAMPTZ,
33
+ completed_at TIMESTAMPTZ,
34
+ payload JSONB,
35
+ result JSONB,
36
+ error TEXT
37
+ );
38
+
39
+ CREATE INDEX IF NOT EXISTS {table_name}_scene_id_idx ON {table_name}(scene_id);
40
+ CREATE INDEX IF NOT EXISTS {table_name}_status_idx ON {table_name}(status);
41
+ """
42
+
43
+
44
+ class DatabaseError(RuntimeError):
45
+ """Raised when database operations fail."""
46
+
47
+
48
+ class BaseDatabaseClient:
49
+ """Small interface for database persistence."""
50
+
51
+ def ensure_schema(self) -> None: # pragma: no cover - noop implementation
52
+ raise NotImplementedError
53
+
54
+ def upsert_job(
55
+ self,
56
+ *,
57
+ job_id: str,
58
+ job_type: str,
59
+ scene_id: str,
60
+ status: str,
61
+ payload: Mapping[str, object] | None = None,
62
+ result: Mapping[str, object] | None = None,
63
+ error: str | None = None,
64
+ ) -> None:
65
+ raise NotImplementedError
66
+
67
+ def close(self) -> None: # pragma: no cover - noop implementation
68
+ pass
69
+
70
+
71
+ class NoopDatabaseClient(BaseDatabaseClient):
72
+ """Fallback when no database configuration is provided."""
73
+
74
+ def ensure_schema(self) -> None: # pragma: no cover - nothing to do
75
+ logger.debug("Database is disabled; skipping schema creation")
76
+
77
+ def upsert_job(
78
+ self,
79
+ *,
80
+ job_id: str,
81
+ job_type: str,
82
+ scene_id: str,
83
+ status: str,
84
+ payload: Mapping[str, object] | None = None,
85
+ result: Mapping[str, object] | None = None,
86
+ error: str | None = None,
87
+ ) -> None:
88
+ logger.debug(
89
+ "Noop DB: job_id=%s job_type=%s scene_id=%s status=%s", job_id, job_type, scene_id, status
90
+ )
91
+
92
+
93
+ class DatabaseClient(BaseDatabaseClient):
94
+ """Postgres implementation using psycopg2."""
95
+
96
+ def __init__(self, settings: WorkerSettings):
97
+ if psycopg2 is None: # pragma: no cover - optional dependency guard
98
+ raise DatabaseError("psycopg2-binary is required for database support")
99
+
100
+ self.settings = settings
101
+
102
+ @contextmanager
103
+ def _connect(self) -> Iterator[PGConnection]:
104
+ conn = psycopg2.connect(self.settings.db_dsn) # type: ignore[arg-type]
105
+ try:
106
+ conn.autocommit = True
107
+ yield conn
108
+ finally:
109
+ conn.close()
110
+
111
+ def ensure_schema(self) -> None:
112
+ table_name = self.settings.job_table
113
+ with self._connect() as conn:
114
+ with conn.cursor() as cur:
115
+ cur.execute(SCHEMA_SQL.format(table_name=table_name))
116
+
117
+ def upsert_job(
118
+ self,
119
+ *,
120
+ job_id: str,
121
+ job_type: str,
122
+ scene_id: str,
123
+ status: str,
124
+ payload: Mapping[str, object] | None = None,
125
+ result: Mapping[str, object] | None = None,
126
+ error: str | None = None,
127
+ ) -> None:
128
+ table = self.settings.job_table
129
+ payload_json = Json(payload) if payload is not None and Json is not None else None
130
+ result_json = Json(result) if result is not None and Json is not None else None
131
+
132
+ with self._connect() as conn:
133
+ with conn.cursor() as cur:
134
+ cur.execute(
135
+ f"""
136
+ INSERT INTO {table} (job_id, job_type, scene_id, status, payload, result, error, started_at, completed_at)
137
+ VALUES (%s, %s, %s, %s, %s, %s, %s,
138
+ CASE WHEN %s = 'started' THEN now() ELSE NULL END,
139
+ CASE WHEN %s IN ('finished', 'failed') THEN now() ELSE NULL END)
140
+ ON CONFLICT (job_id)
141
+ DO UPDATE SET
142
+ job_type = EXCLUDED.job_type,
143
+ scene_id = EXCLUDED.scene_id,
144
+ status = EXCLUDED.status,
145
+ payload = COALESCE(EXCLUDED.payload, {table}.payload),
146
+ result = COALESCE(EXCLUDED.result, {table}.result),
147
+ error = EXCLUDED.error,
148
+ started_at = COALESCE({table}.started_at, EXCLUDED.started_at),
149
+ completed_at = COALESCE({table}.completed_at, EXCLUDED.completed_at)
150
+ """,
151
+ (
152
+ job_id,
153
+ job_type,
154
+ scene_id,
155
+ status,
156
+ payload_json,
157
+ result_json,
158
+ error,
159
+ status,
160
+ status,
161
+ ),
162
+ )
163
+
164
+
165
+ def create_database_client(settings: WorkerSettings) -> BaseDatabaseClient:
166
+ """Factory that returns a database client or a noop fallback."""
167
+
168
+ if not settings.db_dsn:
169
+ return NoopDatabaseClient()
170
+ return DatabaseClient(settings)
stream3r/worker/main.py ADDED
@@ -0,0 +1,59 @@
1
+ """CLI entrypoint to launch an RQ worker for STream3R jobs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import logging
7
+ from typing import Sequence
8
+
9
+ from rq import Queue, Worker
10
+
11
+ from .config import WorkerSettings
12
+ from .runtime import get_runtime
13
+
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ def _parse_args(default_queues: Sequence[str]) -> argparse.Namespace:
19
+ parser = argparse.ArgumentParser(description="Run the STream3R RQ worker")
20
+ parser.add_argument(
21
+ "--queue",
22
+ "--queues",
23
+ dest="queues",
24
+ action="append",
25
+ help="Queue names to listen to (can be repeated)",
26
+ )
27
+ parser.add_argument(
28
+ "--burst",
29
+ action="store_true",
30
+ help="Run in burst mode (exit when queues are empty)",
31
+ )
32
+ parser.add_argument(
33
+ "--log-level",
34
+ default="INFO",
35
+ help="Logging level",
36
+ )
37
+ args = parser.parse_args()
38
+ if not args.queues:
39
+ args.queues = list(default_queues)
40
+ return args
41
+
42
+
43
+ def main() -> None:
44
+ settings = WorkerSettings.from_env()
45
+ args = _parse_args([settings.pose_queue, settings.model_queue])
46
+ logging.basicConfig(level=getattr(logging, str(args.log_level).upper(), logging.INFO))
47
+
48
+ runtime = get_runtime()
49
+
50
+ queues = [Queue(name, connection=runtime.redis) for name in args.queues]
51
+ for queue in queues:
52
+ logger.info("Listening on queue '%s'", queue.name)
53
+
54
+ worker = Worker(queues, name=settings.worker_name)
55
+ worker.work(burst=args.burst)
56
+
57
+
58
+ if __name__ == "__main__": # pragma: no cover
59
+ main()
stream3r/worker/pipeline.py ADDED
@@ -0,0 +1,144 @@
1
+ """Inference pipeline utilities reused by worker jobs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from contextlib import nullcontext
6
+ from dataclasses import dataclass
7
+ from pathlib import Path
8
+ from typing import Callable, Iterable, Mapping
9
+
10
+ import numpy as np
11
+ import torch
12
+
13
+ from stream3r.models.components.utils.load_fn import load_and_preprocess_images
14
+ from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
15
+ from stream3r.stream_session import StreamSession
16
+
17
+ from .runtime import WorkerRuntime
18
+
19
+
20
+ ProgressCallback = Callable[[int, int], None]
21
+
22
+
23
+ @dataclass
24
+ class InferenceResult:
25
+ predictions: dict[str, np.ndarray]
26
+ total_frames: int
27
+ cache_path: Path | None
28
+
29
+
30
+ def _to_numpy(payload):
31
+ if isinstance(payload, torch.Tensor):
32
+ return payload.detach().cpu().numpy()
33
+ if isinstance(payload, dict):
34
+ return {k: _to_numpy(v) for k, v in payload.items()}
35
+ if isinstance(payload, (list, tuple)):
36
+ converted = [_to_numpy(item) for item in payload]
37
+ return type(payload)(converted)
38
+ return payload
39
+
40
+
41
+ def run_stream3r_inference(
42
+ *,
43
+ runtime: WorkerRuntime,
44
+ image_paths: Iterable[Path],
45
+ mode: str,
46
+ streaming: bool,
47
+ cache_output_path: Path | None,
48
+ progress_cb: ProgressCallback | None = None,
49
+ window_size: int | None = None,
50
+ ) -> InferenceResult:
51
+ """Execute STream3R inference for the provided frames."""
52
+
53
+ image_list = [Path(p) for p in image_paths]
54
+ if not image_list:
55
+ raise ValueError("No images provided to inference pipeline")
56
+
57
+ model = runtime.get_model()
58
+ device = runtime.model_device()
59
+
60
+ images = load_and_preprocess_images([str(path) for path in image_list])
61
+ total_frames = images.shape[0]
62
+
63
+ autocast_dtype = runtime.autocast_dtype()
64
+ autocast_ctx = (
65
+ torch.autocast(device_type=device.type, dtype=autocast_dtype)
66
+ if device.type == "cuda"
67
+ else nullcontext()
68
+ )
69
+
70
+ predictions: Mapping[str, torch.Tensor]
71
+ cache_path: Path | None = None
72
+
73
+ model.eval()
74
+
75
+ if window_size is not None and window_size <= 0:
76
+ window_size = None
77
+
78
+ with torch.no_grad():
79
+ if streaming:
80
+ session_kwargs = {"mode": mode}
81
+ if window_size is not None:
82
+ session_kwargs["window_size"] = window_size
83
+ session = StreamSession(model, **session_kwargs)
84
+ session.clear()
85
+ for idx in range(total_frames):
86
+ frame = images[idx : idx + 1].to(device)
87
+ with autocast_ctx:
88
+ session.forward_stream(frame)
89
+ if progress_cb is not None:
90
+ progress_cb(idx + 1, total_frames)
91
+
92
+ if cache_output_path is not None:
93
+ session.save_cache(str(cache_output_path))
94
+ cache_path = cache_output_path
95
+
96
+ predictions = session.get_all_predictions()
97
+ else:
98
+ inputs = images.to(device)
99
+ with autocast_ctx:
100
+ predictions = model(inputs, mode=mode)
101
+ if progress_cb is not None:
102
+ progress_cb(total_frames, total_frames)
103
+
104
+ predictions = dict(predictions)
105
+
106
+ # Augment predictions with pose matrices and world coordinates
107
+ height, width = images.shape[-2:]
108
+
109
+ pose_enc = predictions.get("pose_enc")
110
+ if pose_enc is None:
111
+ raise RuntimeError("Model predictions missing 'pose_enc'")
112
+
113
+ if not isinstance(pose_enc, torch.Tensor):
114
+ pose_enc = torch.as_tensor(pose_enc)
115
+
116
+ if pose_enc.dim() == 2: # streaming cache might drop batch dim
117
+ pose_enc = pose_enc.unsqueeze(0)
118
+
119
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, (height, width))
120
+ predictions["extrinsic"] = extrinsic
121
+ predictions["intrinsic"] = intrinsic
122
+
123
+ for key, value in list(predictions.items()):
124
+ if isinstance(value, torch.Tensor):
125
+ predictions[key] = value
126
+
127
+ predictions_np = {key: _to_numpy(value) for key, value in predictions.items()}
128
+
129
+ pose_enc_np = predictions_np.pop("pose_enc", None)
130
+ if pose_enc_np is not None and pose_enc_np.ndim >= 3:
131
+ predictions_np["pose_enc"] = pose_enc_np
132
+
133
+ # Remove batch dimension if present
134
+ for key, value in list(predictions_np.items()):
135
+ if isinstance(value, np.ndarray) and value.shape[0] == 1:
136
+ predictions_np[key] = np.squeeze(value, axis=0)
137
+
138
+ torch.cuda.empty_cache()
139
+
140
+ return InferenceResult(
141
+ predictions=predictions_np,
142
+ total_frames=total_frames,
143
+ cache_path=cache_path,
144
+ )
stream3r/worker/runtime.py ADDED
@@ -0,0 +1,159 @@
1
+ """Runtime registry for shared worker resources."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from contextlib import contextmanager
7
+ from threading import Lock
8
+ from typing import Any, Dict, Mapping
9
+
10
+ import redis
11
+ import torch
12
+
13
+ from stream3r.models.stream3r import STream3R
14
+
15
+ from .config import WorkerSettings
16
+ from .db import BaseDatabaseClient, create_database_client
17
+ from .storage import StorageClient, create_storage_client
18
+
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ class WorkerRuntime:
24
+ """Holds shared state reused across RQ jobs."""
25
+
26
+ def __init__(self, settings: WorkerSettings):
27
+ self.settings = settings
28
+ self._redis = redis.Redis.from_url(
29
+ settings.redis_url,
30
+ decode_responses=False,
31
+ health_check_interval=settings.redis_healthcheck_interval,
32
+ )
33
+ self.storage: StorageClient = create_storage_client(settings)
34
+ self.db: BaseDatabaseClient = create_database_client(settings)
35
+
36
+ try:
37
+ self.db.ensure_schema()
38
+ except Exception as exc: # pragma: no cover - depends on external DB
39
+ logger.warning("Failed to ensure job schema: %s", exc)
40
+
41
+ self._model: STream3R | None = None
42
+ self._model_lock = Lock()
43
+ self._device: torch.device | None = None
44
+ self._autocast_dtype: torch.dtype | None = None
45
+
46
+ # -----------------------------------------------------------------
47
+ # Redis helpers
48
+ # -----------------------------------------------------------------
49
+ @property
50
+ def redis(self) -> redis.Redis:
51
+ return self._redis
52
+
53
+ def emit_event(self, payload: Mapping[str, Any]) -> None:
54
+ try:
55
+ stream = self.settings.redis_events_stream
56
+ data = {k: str(v) for k, v in payload.items() if v is not None}
57
+ maxlen = self.settings.redis_stream_maxlen
58
+ self._redis.xadd(stream, data, maxlen=maxlen, approximate=True)
59
+ except redis.RedisError as exc: # pragma: no cover - depends on Redis
60
+ logger.warning("Failed to emit event to Redis: %s", exc)
61
+
62
+ @contextmanager
63
+ def gpu_lock(self) -> Any:
64
+ lock = self._redis.lock(
65
+ self.settings.gpu_lock_key,
66
+ timeout=self.settings.gpu_lock_timeout,
67
+ blocking_timeout=self.settings.gpu_lock_blocking_timeout,
68
+ )
69
+ acquired = False
70
+ try:
71
+ acquired = lock.acquire(blocking=True)
72
+ if not acquired:
73
+ raise TimeoutError("Timed out waiting for GPU lock")
74
+ yield
75
+ finally:
76
+ if acquired:
77
+ try:
78
+ lock.release()
79
+ except redis.RedisError: # pragma: no cover - depends on Redis
80
+ logger.debug("GPU lock already released")
81
+
82
+ # -----------------------------------------------------------------
83
+ # Model helpers
84
+ # -----------------------------------------------------------------
85
+ def _resolve_device(self) -> torch.device:
86
+ if self._device is not None:
87
+ return self._device
88
+
89
+ preference = self.settings.model_device_preference
90
+ if preference:
91
+ try:
92
+ device = torch.device(preference)
93
+ except (ValueError, RuntimeError):
94
+ logger.warning("Unknown device preference '%s', falling back to auto", preference)
95
+ device = None
96
+ else:
97
+ device = None
98
+
99
+ if device is None:
100
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
101
+
102
+ self._device = device
103
+ return device
104
+
105
+ def _resolve_autocast_dtype(self) -> torch.dtype:
106
+ if self._autocast_dtype is not None:
107
+ return self._autocast_dtype
108
+
109
+ dtype_name = self.settings.model_dtype
110
+ if dtype_name:
111
+ try:
112
+ self._autocast_dtype = getattr(torch, dtype_name)
113
+ return self._autocast_dtype
114
+ except AttributeError:
115
+ logger.warning("Unsupported dtype '%s', using default", dtype_name)
116
+
117
+ if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
118
+ self._autocast_dtype = torch.bfloat16
119
+ else:
120
+ self._autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
121
+ return self._autocast_dtype
122
+
123
+ def get_model(self) -> STream3R:
124
+ with self._model_lock:
125
+ if self._model is None:
126
+ logger.info("Loading STream3R model '%s'", self.settings.model_id)
127
+ model = STream3R.from_pretrained(
128
+ self.settings.model_id,
129
+ revision=self.settings.model_revision,
130
+ )
131
+ device = self._resolve_device()
132
+ model.to(device)
133
+ model.eval()
134
+ self._model = model
135
+ return self._model
136
+
137
+ def model_device(self) -> torch.device:
138
+ return self._resolve_device()
139
+
140
+ def autocast_dtype(self) -> torch.dtype:
141
+ return self._resolve_autocast_dtype()
142
+
143
+ # -----------------------------------------------------------------
144
+ def close(self) -> None:
145
+ try:
146
+ self.db.close()
147
+ except AttributeError:
148
+ pass
149
+
150
+
151
+ _RUNTIME: WorkerRuntime | None = None
152
+
153
+
154
+ def get_runtime() -> WorkerRuntime:
155
+ global _RUNTIME
156
+ if _RUNTIME is None:
157
+ settings = WorkerSettings.from_env()
158
+ _RUNTIME = WorkerRuntime(settings)
159
+ return _RUNTIME
stream3r/worker/storage.py ADDED
@@ -0,0 +1,180 @@
1
+ """Storage backends for persisting STream3R job artifacts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import json
6
+ import shutil
7
+ from pathlib import Path
8
+ from typing import Mapping
9
+
10
+ try: # Lazy import to keep optional dependency
11
+ import boto3
12
+ from botocore.client import Config as Boto3Config
13
+ from botocore.exceptions import BotoCoreError, ClientError
14
+ except ModuleNotFoundError: # pragma: no cover - fallback when boto3 missing
15
+ boto3 = None # type: ignore[assignment]
16
+ Boto3Config = None # type: ignore[assignment]
17
+ BotoCoreError = ClientError = Exception # type: ignore[assignment]
18
+
19
+ from .config import WorkerSettings
20
+
21
+
22
+ class StorageError(RuntimeError):
23
+ """Raised when artifact persistence fails."""
24
+
25
+
26
+ class StorageClient:
27
+ """Abstract base class providing a minimal upload interface."""
28
+
29
+ def __init__(self, settings: WorkerSettings):
30
+ self.settings = settings
31
+
32
+ # --- key builders -------------------------------------------------
33
+ def build_key(self, scene_id: str, *parts: str) -> str:
34
+ components = [str(self.settings.storage_prefix), str(scene_id), "stream3r"]
35
+ for part in parts:
36
+ if not part:
37
+ continue
38
+ components.append(str(part).strip("/"))
39
+ return "/".join(components)
40
+
41
+ def build_uri(self, key: str) -> str:
42
+ raise NotImplementedError
43
+
44
+ # --- upload primitives -------------------------------------------
45
+ def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str:
46
+ raise NotImplementedError
47
+
48
+ def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str:
49
+ raise NotImplementedError
50
+
51
+ def upload_json(self, payload: Mapping[str, object], key: str) -> str:
52
+ data = json.dumps(payload, allow_nan=False).encode("utf-8")
53
+ return self.upload_bytes(data, key, content_type="application/json")
54
+
55
+ def download_to_path(self, key: str, destination: Path) -> Path:
56
+ """Download an object identified by *key* into *destination*."""
57
+
58
+ raise NotImplementedError
59
+
60
+ # --- helpers ------------------------------------------------------
61
+ def ensure_dir(self, scene_id: str, *parts: str) -> str:
62
+ """Create a logical directory path under the storage prefix."""
63
+ key = self.build_key(scene_id, *parts)
64
+ if not key.endswith("/"):
65
+ key = f"{key}/"
66
+ return key
67
+
68
+
69
+ class S3StorageClient(StorageClient):
70
+ """S3-compatible storage backend."""
71
+
72
+ def __init__(self, settings: WorkerSettings):
73
+ if boto3 is None: # pragma: no cover - guarded by optional dependency
74
+ raise StorageError("boto3 is required for S3 storage but is not installed")
75
+
76
+ super().__init__(settings)
77
+
78
+ session_kwargs: dict[str, object] = {}
79
+ if settings.s3_profile:
80
+ session_kwargs["profile_name"] = settings.s3_profile
81
+ if settings.s3_region:
82
+ session_kwargs["region_name"] = settings.s3_region
83
+ if settings.aws_access_key_id and settings.aws_secret_access_key:
84
+ session_kwargs["aws_access_key_id"] = settings.aws_access_key_id
85
+ session_kwargs["aws_secret_access_key"] = settings.aws_secret_access_key
86
+ if settings.aws_session_token:
87
+ session_kwargs["aws_session_token"] = settings.aws_session_token
88
+
89
+ session = boto3.session.Session(**session_kwargs)
90
+ config = None
91
+ if settings.s3_force_path_style and Boto3Config is not None:
92
+ config = Boto3Config(s3={"addressing_style": "path"})
93
+
94
+ self._client = session.client(
95
+ "s3",
96
+ endpoint_url=settings.s3_endpoint_url,
97
+ config=config,
98
+ )
99
+
100
+ if not settings.s3_bucket:
101
+ raise StorageError("STREAM3R_STORAGE_BUCKET is required for S3 storage")
102
+
103
+ def build_uri(self, key: str) -> str:
104
+ bucket = self.settings.s3_bucket
105
+ return f"s3://{bucket}/{key}"
106
+
107
+ def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str:
108
+ extra_args = {"ContentType": content_type} if content_type else None
109
+ try:
110
+ self._client.upload_file(str(local_path), self.settings.s3_bucket, key, ExtraArgs=extra_args)
111
+ except (BotoCoreError, ClientError) as exc: # pragma: no cover - network side effects
112
+ raise StorageError(f"Failed to upload {local_path} to {key}: {exc}") from exc
113
+ return self.build_uri(key)
114
+
115
+ def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str:
116
+ extra_args = {"ContentType": content_type} if content_type else None
117
+ try:
118
+ self._client.put_object(
119
+ Bucket=self.settings.s3_bucket,
120
+ Key=key,
121
+ Body=data,
122
+ ContentType=extra_args.get("ContentType") if extra_args else None,
123
+ )
124
+ except (BotoCoreError, ClientError) as exc: # pragma: no cover - network side effects
125
+ raise StorageError(f"Failed to upload payload to {key}: {exc}") from exc
126
+ return self.build_uri(key)
127
+
128
+ def download_to_path(self, key: str, destination: Path) -> Path:
129
+ destination.parent.mkdir(parents=True, exist_ok=True)
130
+ try:
131
+ object_key = str(key).lstrip("/")
132
+ self._client.download_file(self.settings.s3_bucket, object_key, str(destination))
133
+ except (BotoCoreError, ClientError) as exc: # pragma: no cover - network side effects
134
+ raise StorageError(
135
+ f"Failed to download {object_key} from bucket {self.settings.s3_bucket} to {destination}: {exc}"
136
+ ) from exc
137
+ return destination
138
+
139
+
140
+ class LocalStorageClient(StorageClient):
141
+ """On-disk storage backend for development and testing."""
142
+
143
+ def __init__(self, settings: WorkerSettings):
144
+ super().__init__(settings)
145
+ self.root = settings.local_storage_root
146
+ self.root.mkdir(parents=True, exist_ok=True)
147
+
148
+ def _resolve(self, key: str) -> Path:
149
+ path = self.root.joinpath(*key.split("/"))
150
+ path.parent.mkdir(parents=True, exist_ok=True)
151
+ return path
152
+
153
+ def build_uri(self, key: str) -> str:
154
+ return str(self._resolve(key))
155
+
156
+ def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str: # noqa: ARG002
157
+ destination = self._resolve(key)
158
+ shutil.copyfile(local_path, destination)
159
+ return str(destination)
160
+
161
+ def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str: # noqa: ARG002
162
+ destination = self._resolve(key)
163
+ destination.write_bytes(data)
164
+ return str(destination)
165
+
166
+ def download_to_path(self, key: str, destination: Path) -> Path:
167
+ source = self._resolve(key)
168
+ if not source.exists():
169
+ raise StorageError(f"Local object not found for key: {key}")
170
+ destination.parent.mkdir(parents=True, exist_ok=True)
171
+ shutil.copyfile(source, destination)
172
+ return destination
173
+
174
+
175
+ def create_storage_client(settings: WorkerSettings) -> StorageClient:
176
+ """Instantiate the appropriate storage backend."""
177
+
178
+ if settings.s3_bucket:
179
+ return S3StorageClient(settings)
180
+ return LocalStorageClient(settings)
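+
+ # Minimal usage sketch (assumes a populated WorkerSettings; the scene id and key parts are illustrative):
+ #
+ #   settings = WorkerSettings()                      # S3 backend when s3_bucket is set, local otherwise
+ #   storage = create_storage_client(settings)
+ #   key = storage.build_key("scene-001", "models", "result.json")
+ #   uri = storage.upload_json({"status": "ok"}, key)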
stream3r/worker/tasks.py ADDED
@@ -0,0 +1,1036 @@
1
+ """RQ job entrypoints for STream3R worker."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import base64
6
+ import json
7
+ import logging
8
+ import re
9
+ import shutil
10
+ import tempfile
11
+ import traceback
12
+ import uuid
13
+ from dataclasses import dataclass, field
14
+ from datetime import datetime, timezone
15
+ from pathlib import Path
16
+ from typing import Any, Callable, Mapping
17
+
18
+ import numpy as np
19
+ import requests
20
+ from rq import get_current_job
21
+
22
+ from stream3r.utils.visual_utils import predictions_to_glb
23
+
24
+ from .pipeline import InferenceResult, run_stream3r_inference
25
+ from .runtime import WorkerRuntime, get_runtime
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+ IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
31
+ _SAFE_CHARS = re.compile(r"[^0-9A-Za-z_-]")
32
+
33
+
34
+ def _as_bool(value: Any, default: bool) -> bool:
35
+ if isinstance(value, bool):
36
+ return value
37
+ if isinstance(value, str):
38
+ lowered = value.strip().lower()
39
+ if lowered in {"1", "true", "yes", "y", "on"}:
40
+ return True
41
+ if lowered in {"0", "false", "no", "n", "off"}:
42
+ return False
43
+ return default
44
+
45
+
46
+ def _as_int(value: Any, default: int) -> int:
47
+ try:
48
+ return int(value)
49
+ except (TypeError, ValueError):
50
+ return default
51
+
52
+
53
+ @dataclass(slots=True)
54
+ class FrameRecord:
55
+ index: int
56
+ frame_id: str
57
+ path: Path
58
+ source: str | None = None
59
+ timestamp: str | None = None
60
+ metadata: dict[str, Any] = field(default_factory=dict)
61
+
62
+
63
+ class ProgressTracker:
64
+ """Aggregates frame progress to percentage updates."""
65
+
66
+ def __init__(self, runtime: WorkerRuntime, job_meta: Mapping[str, str | None]):
67
+ self.runtime = runtime
68
+ self.job_meta = job_meta
69
+ self.last_value = -1
70
+
71
+ def __call__(self, processed: int, total: int) -> None:
72
+ if total <= 0:
73
+ return
74
+ percent = int(round((processed / total) * 100))
75
+ percent = max(0, min(100, percent))
76
+ if percent == self.last_value:
77
+ return
78
+ self.last_value = percent
79
+ payload = {
80
+ **self.job_meta,
81
+ "status": "progress",
82
+ "progress": percent,
83
+ "ts": datetime.now(timezone.utc).timestamp(),
84
+ }
85
+ runtime_emit(self.runtime, payload)
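+ # Example: with total=40 frames, processed=10 maps to a single {"status": "progress", "progress": 25}
+ # event; subsequent calls that round to the same percentage are suppressed via last_value.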
86
+
87
+
88
+ def runtime_emit(runtime: WorkerRuntime, payload: Mapping[str, Any]) -> None:
89
+ runtime.emit_event(payload)
90
+
91
+
92
+ def _slugify(value: str, fallback: str) -> str:
93
+ candidate = _SAFE_CHARS.sub("_", value).strip("_")
94
+ if not candidate:
95
+ candidate = fallback
96
+ return candidate[:128]
97
+
98
+
99
+ def _is_url(value: str) -> bool:
100
+ return value.startswith("http://") or value.startswith("https://")
101
+
102
+
103
+ def _download_to_path(url: str, destination: Path) -> None:
104
+ response = requests.get(url, stream=True, timeout=60)
105
+ response.raise_for_status()
106
+ with destination.open("wb") as handle:
107
+ for chunk in response.iter_content(chunk_size=1 << 16):
108
+ if chunk:
109
+ handle.write(chunk)
110
+
111
+
112
+ def _write_base64(content: str, destination: Path) -> None:
113
+ data = base64.b64decode(content)
114
+ destination.write_bytes(data)
115
+
116
+
117
+ def _resolve_frame_entry(entry: Any, *, index: int, dest_dir: Path) -> FrameRecord:
118
+ metadata: dict[str, Any] = {}
119
+ timestamp = None
120
+ source = None
121
+ dest_dir.mkdir(parents=True, exist_ok=True)
122
+
123
+ if isinstance(entry, str):
124
+ if _is_url(entry):
125
+ source = entry
126
+ frame_id = _slugify(Path(entry).stem or f"frame_{index:06d}", f"frame_{index:06d}")
127
+ destination = dest_dir / f"{frame_id}.jpg"
128
+ _download_to_path(entry, destination)
129
+ else:
130
+ path = Path(entry)
131
+ if not path.exists():
132
+ raise FileNotFoundError(f"Frame path does not exist: {entry}")
133
+ frame_id = _slugify(path.stem, f"frame_{index:06d}")
134
+ destination = dest_dir / path.name
135
+ shutil.copy2(path, destination)
136
+ elif isinstance(entry, Mapping):
137
+ frame_id = _slugify(str(entry.get("frame_id") or entry.get("id") or f"frame_{index:06d}"), f"frame_{index:06d}")
138
+ timestamp = entry.get("timestamp")
139
+ metadata = {k: v for k, v in entry.items() if k not in {"path", "url", "content", "frame_id", "id", "timestamp"}}
140
+
141
+ if path := entry.get("path") or entry.get("local_path"):
142
+ path = Path(path)
143
+ if not path.exists():
144
+ raise FileNotFoundError(f"Frame path does not exist: {path}")
145
+ destination = dest_dir / (path.name if path.suffix else f"{frame_id}.jpg")
146
+ shutil.copy2(path, destination)
147
+ elif url := entry.get("url"):
148
+ source = url
149
+ suffix = Path(url).suffix or ".jpg"
150
+ destination = dest_dir / f"{frame_id}{suffix}"
151
+ _download_to_path(url, destination)
152
+ elif content := entry.get("content"):
153
+ destination = dest_dir / f"{frame_id}.jpg"
154
+ _write_base64(content, destination)
155
+ else:
156
+ raise ValueError("Frame entry must include 'path', 'url', or 'content'")
157
+ else:
158
+ raise TypeError("Unsupported frame specification")
159
+
160
+ if destination.suffix.lower() not in IMAGE_EXTENSIONS:
161
+ destination = destination.rename(destination.with_suffix(".png"))  # rename on disk so the recorded path points at an existing file
162
+
163
+ return FrameRecord(
164
+ index=index,
165
+ frame_id=_slugify(destination.stem, f"frame_{index:06d}"),
166
+ path=destination,
167
+ source=source,
168
+ timestamp=timestamp,
169
+ metadata=metadata,
170
+ )
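+ # Accepted entry shapes (values are illustrative): a local path string such as "/data/frames/0001.png",
+ # an http(s) URL, or a mapping like {"frame_id": "f0001", "path": "/data/frames/0001.png",
+ # "timestamp": "2024-01-01T00:00:00Z"}; a mapping may supply "url" or base64 "content" instead of "path".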
171
+
172
+
173
+ def _collect_frames(
174
+ runtime: WorkerRuntime,
175
+ scene_id: str,
176
+ payload: Mapping[str, Any],
177
+ tmp_dir: Path,
178
+ ) -> list[FrameRecord]:
179
+ frames_dir = tmp_dir / "frames"
180
+ frames_payload = payload.get("frames") or []
181
+ frame_limit = runtime.settings.max_frames_per_job
182
+
183
+ records: list[FrameRecord] = []
184
+ if frames_payload:
185
+ for entry in frames_payload:
186
+ if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
187
+ break
188
+ records.append(_resolve_frame_entry(entry, index=len(records), dest_dir=frames_dir))
189
+ else:
190
+ directory = payload.get("frames_dir") or payload.get("images_dir")
191
+ if directory:
192
+ directory_path = Path(directory)
193
+ if not directory_path.is_dir():
194
+ raise ValueError(f"frames_dir does not exist: {directory}")
195
+ for idx, file in enumerate(sorted(directory_path.iterdir())):
196
+ if file.suffix.lower() not in IMAGE_EXTENSIONS:
197
+ continue
198
+ if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
199
+ break
200
+ destination = frames_dir / file.name
201
+ shutil.copy2(file, destination)
202
+ records.append(
203
+ FrameRecord(
204
+ index=len(records),
205
+ frame_id=_slugify(file.stem, f"frame_{idx:06d}"),
206
+ path=destination,
207
+ )
208
+ )
209
+
210
+ if not records:
211
+ records = _collect_frames_from_scene_media(runtime, scene_id, frames_dir)
212
+
213
+ if not records:
214
+ raise ValueError(f"No valid frames found for scene '{scene_id}'")
215
+
216
+ limit = runtime.settings.max_frames_per_job
217
+ if limit and limit > 0 and len(records) > limit:
218
+ records = records[:limit]
219
+
220
+ for new_idx, record in enumerate(records):
221
+ if record.index != new_idx:
222
+ record.index = new_idx
223
+
224
+ return records
225
+
226
+
227
+ def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
228
+ result = dict(payload)
229
+ frames = result.pop("frames", None)
230
+ if frames is not None:
231
+ result["frame_count"] = len(frames)
232
+ if "frames_dir" in result:
233
+ result["frames_dir"] = str(result["frames_dir"])
234
+ return result
235
+
236
+
237
+ def _prepare_session_settings(
238
+ payload: Mapping[str, Any],
239
+ *,
240
+ mode: str,
241
+ streaming: bool,
242
+ frame_records: list[FrameRecord],
243
+ window_size: int | None = None,
244
+ ) -> dict[str, Any]:
245
+ base_settings = payload.get("session_settings") or {}
246
+ session_settings = dict(base_settings)
247
+ session_settings.update(
248
+ {
249
+ "mode": mode,
250
+ "streaming": streaming,
251
+ "frame_count": len(frame_records),
252
+ }
253
+ )
254
+ window_setting = window_size if window_size is not None else payload.get("window_size")
255
+ if window_setting:
256
+ try:
257
+ session_settings["window_size"] = int(window_setting)
258
+ except (TypeError, ValueError):
259
+ pass
260
+ return session_settings
261
+
262
+
263
+ def _collect_frames_from_scene_media(
264
+ runtime: WorkerRuntime,
265
+ scene_id: str,
266
+ dest_dir: Path,
267
+ ) -> list[FrameRecord]:
268
+ base_url = runtime.settings.scene_media_api_base_url
269
+ if not base_url:
270
+ raise ValueError(
271
+ "Scene media API base URL is not configured. Set API_BASE_URL"
272
+ )
273
+
274
+ base_url = base_url.rstrip("/")
275
+ dest_dir.mkdir(parents=True, exist_ok=True)
276
+
277
+ per_page = runtime.settings.scene_media_page_size
278
+ if per_page <= 0:
279
+ per_page = 100
280
+ per_page = max(1, min(per_page, 1000))
281
+ frame_limit = runtime.settings.max_frames_per_job
282
+
283
+ headers = {}
284
+ token = runtime.settings.scene_media_api_token
285
+ if token:
286
+ headers["Authorization"] = f"Bearer {token}"
287
+
288
+ url = f"{base_url}/scenes/{scene_id}/media"
289
+ session = requests.Session()
290
+ records: list[FrameRecord] = []
291
+ offset = 0
292
+
293
+ while True:
294
+ if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
295
+ break
296
+
297
+ request_limit = per_page
298
+ if frame_limit and frame_limit > 0:
299
+ remaining = frame_limit - len(records)
300
+ if remaining <= 0:
301
+ break
302
+ request_limit = min(request_limit, remaining)
303
+
304
+ params = {
305
+ "limit": request_limit,
306
+ "offset": offset,
307
+ "media_type": "image",
308
+ }
309
+ try:
310
+ response = session.get(url, params=params, headers=headers, timeout=30)
311
+ response.raise_for_status()
312
+ except requests.RequestException as exc:
313
+ raise RuntimeError(f"Failed to fetch media for scene '{scene_id}': {exc}") from exc
314
+
315
+ data = response.json()
316
+ items = data.get("items") or []
317
+ if not items:
318
+ break
319
+
320
+ for item in items:
321
+ if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
322
+ break
323
+ file_key = item.get("file")
324
+ if not file_key:
325
+ continue
326
+
327
+ idx = len(records)
328
+ source_path = Path(str(file_key))
329
+ suffix = source_path.suffix if source_path.suffix else ".png"
330
+ frame_id = _slugify(source_path.stem or f"frame_{idx:06d}", f"frame_{idx:06d}")
331
+ destination = dest_dir / f"{frame_id}{suffix}"
332
+
333
+ try:
334
+ runtime.storage.download_to_path(str(file_key), destination)
335
+ except Exception as exc: # pragma: no cover - download depends on external storage
336
+ raise RuntimeError(f"Failed to download media '{file_key}' for scene '{scene_id}': {exc}") from exc
337
+
338
+ records.append(
339
+ FrameRecord(
340
+ index=idx,
341
+ frame_id=frame_id,
342
+ path=destination,
343
+ source=str(file_key),
344
+ timestamp=item.get("captured_at"),
345
+ metadata={
346
+ "media_id": item.get("id"),
347
+ "media_type": item.get("media_type"),
348
+ },
349
+ )
350
+ )
351
+
352
+ if len(items) < request_limit:
353
+ break
354
+ offset += request_limit
355
+
356
+ return records
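+ # The scene media endpoint is assumed to return JSON shaped like
+ # {"items": [{"id": ..., "file": "<storage key>", "media_type": "image", "captured_at": ...}]},
+ # paginated via the limit/offset query parameters used above.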
357
+
358
+
359
+ def _pose_confidence(predictions: Mapping[str, np.ndarray]) -> np.ndarray | None:
360
+ if "world_points_conf" in predictions:
361
+ return np.asarray(predictions["world_points_conf"], dtype=np.float32)
362
+ if "depth_conf" in predictions:
363
+ return np.asarray(predictions["depth_conf"], dtype=np.float32)
364
+ return None
365
+
366
+
367
+ def _save_pointmaps(
368
+ *,
369
+ runtime: WorkerRuntime,
370
+ scene_id: str,
371
+ predictions: Mapping[str, np.ndarray],
372
+ frame_records: list[FrameRecord],
373
+ temp_dir: Path,
374
+ ) -> dict[str, Any]:
375
+ world_points = predictions.get("world_points")
376
+ if world_points is None:
377
+ world_points = predictions.get("world_points_from_depth")
378
+ if world_points is None:
379
+ raise RuntimeError("Predictions missing world points")
380
+
381
+ world_points = np.asarray(world_points)
382
+ confidence = _pose_confidence(predictions)
383
+ if confidence is None:
384
+ confidence = np.ones(world_points.shape[:-1], dtype=np.float32)
385
+
386
+ local_dir = temp_dir / "pointmaps"
387
+ local_dir.mkdir(parents=True, exist_ok=True)
388
+
389
+ entries: list[dict[str, Any]] = []
390
+ for record in frame_records:
391
+ idx = record.index
392
+ filename = f"{record.frame_id}.npz"
393
+ local_file = local_dir / filename
394
+ np.savez(
395
+ local_file,
396
+ xyz=np.asarray(world_points[idx], dtype=np.float32),
397
+ confidence=np.asarray(confidence[idx], dtype=np.float32),
398
+ )
399
+ key = runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir, filename)
400
+ uri = runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
401
+ entries.append(
402
+ {
403
+ "frame_id": record.frame_id,
404
+ "frame_index": record.index,
405
+ "url": uri,
406
+ "timestamp": record.timestamp,
407
+ }
408
+ )
409
+
410
+ directory_uri = runtime.storage.build_uri(
411
+ runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir)
412
+ )
413
+
414
+ return {
415
+ "pointmaps": entries,
416
+ "pointmap_dir": directory_uri,
417
+ }
418
+
419
+
420
+ def _write_poses_jsonl(
421
+ *,
422
+ runtime: WorkerRuntime,
423
+ scene_id: str,
424
+ job_id: str,
425
+ predictions: Mapping[str, np.ndarray],
426
+ frame_records: list[FrameRecord],
427
+ temp_dir: Path,
428
+ ) -> str:
429
+ extrinsic = np.asarray(predictions.get("extrinsic"))
430
+ intrinsic = predictions.get("intrinsic")
431
+ if intrinsic is not None:
432
+ intrinsic = np.asarray(intrinsic)
433
+
434
+ local_file = temp_dir / "poses.jsonl"
435
+ with local_file.open("w", encoding="utf-8") as handle:
436
+ for record in frame_records:
437
+ idx = record.index
438
+ payload = {
439
+ "job_id": job_id,
440
+ "scene_id": scene_id,
441
+ "frame_id": record.frame_id,
442
+ "frame_index": record.index,
443
+ "extrinsic": extrinsic[idx].tolist(),
444
+ }
445
+ if intrinsic is not None:
446
+ payload["intrinsic"] = intrinsic[idx].tolist()
447
+ if record.timestamp is not None:
448
+ payload["timestamp"] = record.timestamp
449
+ if record.source is not None:
450
+ payload["source"] = record.source
451
+ if record.metadata:
452
+ payload["metadata"] = record.metadata
453
+ handle.write(json.dumps(payload))
454
+ handle.write("\n")
455
+
456
+ key = runtime.storage.build_key(
457
+ scene_id,
458
+ runtime.settings.models_dir,
459
+ runtime.settings.poses_filename,
460
+ )
461
+ return runtime.storage.upload_file(local_file, key, content_type="application/json")
462
+
463
+
464
+ def _upload_cache(
465
+ *,
466
+ runtime: WorkerRuntime,
467
+ scene_id: str,
468
+ cache_path: Path | None,
469
+ ) -> str | None:
470
+ if cache_path is None or not cache_path.exists():
471
+ return None
472
+ key = runtime.storage.build_key(
473
+ scene_id,
474
+ runtime.settings.models_dir,
475
+ runtime.settings.session_cache_filename,
476
+ )
477
+ return runtime.storage.upload_file(cache_path, key, content_type="application/octet-stream")
478
+
479
+
480
+ def _write_predictions_npz(
481
+ *,
482
+ runtime: WorkerRuntime,
483
+ scene_id: str,
484
+ predictions: Mapping[str, np.ndarray],
485
+ temp_dir: Path,
486
+ ) -> str:
487
+ payload = {k: v for k, v in predictions.items() if isinstance(v, np.ndarray)}
488
+ local_file = temp_dir / runtime.settings.predictions_filename
489
+ np.savez(local_file, **payload)
490
+ key = runtime.storage.build_key(
491
+ scene_id,
492
+ runtime.settings.models_dir,
493
+ runtime.settings.predictions_filename,
494
+ )
495
+ return runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
496
+
497
+
498
+ def _write_session_settings(
499
+ *,
500
+ runtime: WorkerRuntime,
501
+ scene_id: str,
502
+ session_settings: Mapping[str, Any],
503
+ temp_dir: Path,
504
+ ) -> str:
505
+ local_file = temp_dir / runtime.settings.session_settings_filename
506
+ local_file.write_text(json.dumps(session_settings, indent=2), encoding="utf-8")
507
+ key = runtime.storage.build_key(
508
+ scene_id,
509
+ runtime.settings.models_dir,
510
+ runtime.settings.session_settings_filename,
511
+ )
512
+ return runtime.storage.upload_file(local_file, key, content_type="application/json")
513
+
514
+
515
+ def _write_selected_frames(
516
+ *,
517
+ runtime: WorkerRuntime,
518
+ scene_id: str,
519
+ selected_frames: list[dict[str, Any]],
520
+ top_k: int,
521
+ temp_dir: Path,
522
+ ) -> str | None:
523
+ if not selected_frames:
524
+ return None
525
+ local_file = temp_dir / runtime.settings.selected_frames_filename
526
+ payload = {"top_k": top_k, "frames": selected_frames}
527
+ local_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
528
+ key = runtime.storage.build_key(
529
+ scene_id,
530
+ runtime.settings.models_dir,
531
+ runtime.settings.selected_frames_filename,
532
+ )
533
+ return runtime.storage.upload_file(local_file, key, content_type="application/json")
534
+
535
+
536
+ def _compute_selected_frames(
537
+ predictions: Mapping[str, np.ndarray],
538
+ frame_records: list[FrameRecord],
539
+ top_k: int,
540
+ ) -> list[dict[str, Any]]:
541
+ if top_k <= 0:
542
+ return []
543
+ confidence = _pose_confidence(predictions)
544
+ if confidence is None:
545
+ return []
546
+ scores = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
547
+ indices = np.argsort(scores)[::-1][:top_k]
548
+ result = []
549
+ for idx in indices:
550
+ record = frame_records[int(idx)]
551
+ result.append(
552
+ {
553
+ "frame_id": record.frame_id,
554
+ "frame_index": record.index,
555
+ "score": float(scores[idx]),
556
+ }
557
+ )
558
+ return result
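+ # Worked example: for confidence of shape (S, H, W), each frame's score is the mean over its H*W values;
+ # with top_k=2 this might return [{"frame_id": "f0007", "frame_index": 7, "score": 0.91}, ...]
+ # (scores and ids are illustrative).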
559
+
560
+
561
+ def _save_scene_glb(
562
+ *,
563
+ runtime: WorkerRuntime,
564
+ scene_id: str,
565
+ predictions: Mapping[str, np.ndarray],
566
+ temp_dir: Path,
567
+ payload: Mapping[str, Any],
568
+ ) -> str:
569
+ local_file = temp_dir / runtime.settings.scene_glb_filename
570
+ scene = predictions_to_glb(
571
+ dict(predictions),
572
+ conf_thres=float(payload.get("conf_thres", 3.0)),
573
+ filter_by_frames=payload.get("frame_filter", "All"),
574
+ mask_black_bg=_as_bool(payload.get("mask_black_bg"), False),
575
+ mask_white_bg=_as_bool(payload.get("mask_white_bg"), False),
576
+ show_cam=_as_bool(payload.get("show_cam"), True),
577
+ mask_sky=_as_bool(payload.get("mask_sky"), False),
578
+ target_dir=str(temp_dir),
579
+ prediction_mode=payload.get("prediction_mode", "Predicted Pointmap"),
580
+ )
581
+ scene.export(file_obj=str(local_file))
582
+ key = runtime.storage.build_key(
583
+ scene_id,
584
+ runtime.settings.models_dir,
585
+ runtime.settings.scene_glb_filename,
586
+ )
587
+ return runtime.storage.upload_file(local_file, key, content_type="model/gltf-binary")
588
+
589
+
590
+ def _write_summary_json(
591
+ *,
592
+ runtime: WorkerRuntime,
593
+ scene_id: str,
594
+ summary: Mapping[str, Any],
595
+ temp_dir: Path,
596
+ ) -> str:
597
+ filename = runtime.settings.result_filename
598
+ local_file = temp_dir / filename
599
+ local_file.write_text(json.dumps(summary, indent=2), encoding="utf-8")
600
+ key = runtime.storage.build_key(
601
+ scene_id,
602
+ runtime.settings.models_dir,
603
+ filename,
604
+ )
605
+ return runtime.storage.upload_file(local_file, key, content_type="application/json")
606
+
607
+
608
+ def _upload_result_record(
609
+ *,
610
+ runtime: WorkerRuntime,
611
+ scene_id: str,
612
+ job_id: str,
613
+ payload: Mapping[str, Any],
614
+ ) -> str:
615
+ local = json.dumps(payload, indent=2).encode("utf-8")
616
+ key = runtime.storage.build_key(
617
+ scene_id,
618
+ runtime.settings.results_dir,
619
+ f"{job_id}.json",
620
+ )
621
+ return runtime.storage.upload_bytes(local, key, content_type="application/json")
622
+
623
+
624
+ def _model_dir_uri(runtime: WorkerRuntime, scene_id: str) -> str:
625
+ return runtime.storage.build_uri(
626
+ runtime.storage.build_key(scene_id, runtime.settings.models_dir)
627
+ )
628
+
629
+
630
+ def _generate_core_outputs(
631
+ *,
632
+ runtime: WorkerRuntime,
633
+ scene_id: str,
634
+ job_id: str,
635
+ predictions: Mapping[str, np.ndarray],
636
+ frame_records: list[FrameRecord],
637
+ inference: InferenceResult,
638
+ session_settings: Mapping[str, Any],
639
+ temp_dir: Path,
640
+ ) -> dict[str, Any]:
641
+ pointmap_info = _save_pointmaps(
642
+ runtime=runtime,
643
+ scene_id=scene_id,
644
+ predictions=predictions,
645
+ frame_records=frame_records,
646
+ temp_dir=temp_dir,
647
+ )
648
+
649
+ poses_url = _write_poses_jsonl(
650
+ runtime=runtime,
651
+ scene_id=scene_id,
652
+ job_id=job_id,
653
+ predictions=predictions,
654
+ frame_records=frame_records,
655
+ temp_dir=temp_dir,
656
+ )
657
+
658
+ cache_url = _upload_cache(
659
+ runtime=runtime,
660
+ scene_id=scene_id,
661
+ cache_path=inference.cache_path,
662
+ )
663
+
664
+ predictions_url = _write_predictions_npz(
665
+ runtime=runtime,
666
+ scene_id=scene_id,
667
+ predictions=predictions,
668
+ temp_dir=temp_dir,
669
+ )
670
+
671
+ session_settings_url = _write_session_settings(
672
+ runtime=runtime,
673
+ scene_id=scene_id,
674
+ session_settings=session_settings,
675
+ temp_dir=temp_dir,
676
+ )
677
+
678
+ extrinsic = np.asarray(predictions.get("extrinsic"))
679
+ intrinsic = predictions.get("intrinsic")
680
+ if intrinsic is not None:
681
+ intrinsic = np.asarray(intrinsic)
682
+
683
+ frames_payload: list[dict[str, Any]] = []
684
+ for entry in pointmap_info["pointmaps"]:
685
+ idx = entry["frame_index"]
686
+ frame = frame_records[idx]
687
+ frame_payload = {
688
+ "frame_id": frame.frame_id,
689
+ "frame_index": frame.index,
690
+ "pointmap_url": entry["url"],
691
+ "extrinsic": extrinsic[idx].tolist(),
692
+ }
693
+ if intrinsic is not None:
694
+ frame_payload["intrinsic"] = intrinsic[idx].tolist()
695
+ if frame.timestamp is not None:
696
+ frame_payload["timestamp"] = frame.timestamp
697
+ if frame.source is not None:
698
+ frame_payload["source"] = frame.source
699
+ frames_payload.append(frame_payload)
700
+
701
+ artifacts = {
702
+ "poses_url": poses_url,
703
+ "pointmap_dir": pointmap_info["pointmap_dir"],
704
+ "pointmaps": pointmap_info["pointmaps"],
705
+ "predictions_url": predictions_url,
706
+ "session_settings_url": session_settings_url,
707
+ }
708
+ if cache_url:
709
+ artifacts["kv_cache_url"] = cache_url
710
+
711
+ return {
712
+ "artifacts": artifacts,
713
+ "frames": frames_payload,
714
+ }
715
+
716
+
717
+ def _handle_pose_pointmap(
718
+ *,
719
+ runtime: WorkerRuntime,
720
+ payload: Mapping[str, Any],
721
+ mode: str,
722
+ streaming: bool,
723
+ job_id: str,
724
+ scene_id: str,
725
+ frame_records: list[FrameRecord],
726
+ inference: InferenceResult,
727
+ session_settings: Mapping[str, Any],
728
+ temp_dir: Path,
729
+ ) -> dict[str, Any]:
730
+ predictions = inference.predictions
731
+ core = _generate_core_outputs(
732
+ runtime=runtime,
733
+ scene_id=scene_id,
734
+ job_id=job_id,
735
+ predictions=predictions,
736
+ frame_records=frame_records,
737
+ inference=inference,
738
+ session_settings=session_settings,
739
+ temp_dir=temp_dir,
740
+ )
741
+
742
+ result_payload = {
743
+ "job_id": job_id,
744
+ "job_type": "pose_pointmap",
745
+ "scene_id": scene_id,
746
+ "mode": mode,
747
+ "streaming": streaming,
748
+ "frame_count": inference.total_frames,
749
+ "created_at": datetime.now(timezone.utc).isoformat(),
750
+ "artifacts": core["artifacts"],
751
+ "frames": core["frames"],
752
+ }
753
+
754
+ result_url = _upload_result_record(
755
+ runtime=runtime,
756
+ scene_id=scene_id,
757
+ job_id=job_id,
758
+ payload=result_payload,
759
+ )
760
+ result_payload["result_url"] = result_url
761
+ result_payload["model_dir"] = _model_dir_uri(runtime, scene_id)
762
+
763
+ return result_payload
764
+
765
+
766
+ JobHandler = Callable[..., dict[str, Any]]
767
+
768
+
769
+ def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler) -> dict[str, Any]:
770
+ runtime = get_runtime()
771
+ job = get_current_job()
772
+ payload = dict(payload)
773
+
774
+ job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
775
+ scene_id = payload.get("scene_id")
776
+ if not scene_id:
777
+ raise ValueError("Job payload is missing 'scene_id'")
778
+
779
+ payload.setdefault("job_type", job_type)
780
+ payload.setdefault("scene_id", scene_id)
781
+
782
+ mode = payload.get("mode") or runtime.settings.default_mode
783
+ streaming = _as_bool(payload.get("streaming"), runtime.settings.default_streaming)
784
+ window_size: int | None = None
785
+
786
+ if mode == "window":
787
+ streaming = True
788
+ payload["streaming"] = True
789
+ window_candidate = payload.get("window_size") or runtime.settings.stream_window_size
790
+ try:
791
+ window_size = int(window_candidate) if window_candidate else None
792
+ except (TypeError, ValueError):
793
+ window_size = runtime.settings.stream_window_size or None
794
+ if window_size and window_size > 0:
795
+ payload["window_size"] = window_size
796
+ else:
797
+ window_size = None
798
+
799
+ payload["mode"] = mode
800
+ timeout_override = payload.get("timeout")
801
+ if job is not None and timeout_override is not None:
802
+ try:
803
+ job.timeout = int(timeout_override)
804
+ except (TypeError, ValueError):
805
+ pass
806
+
807
+ # Default to 15 minutes if no timeout was already applied (job is None when run outside an RQ worker)
808
+ if job is not None and job.timeout is None:
809
+ job.timeout = 15 * 60
810
+
811
+ sanitized_payload = _sanitize_payload(payload)
812
+
813
+ job_meta = {
814
+ "job_id": job_id,
815
+ "job_type": job_type,
816
+ "scene_id": scene_id,
817
+ }
818
+
819
+ runtime.db.upsert_job(
820
+ job_id=job_id,
821
+ job_type=job_type,
822
+ scene_id=scene_id,
823
+ status="started",
824
+ payload=sanitized_payload,
825
+ )
826
+
827
+ runtime_emit(
828
+ runtime,
829
+ {
830
+ **job_meta,
831
+ "status": "started",
832
+ "progress": 0,
833
+ "ts": datetime.now(timezone.utc).timestamp(),
834
+ },
835
+ )
836
+
837
+ try:
838
+ with runtime.gpu_lock():
839
+ with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
840
+ temp_path = Path(tmp_dir)
841
+ frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
842
+ cache_path = temp_path / runtime.settings.session_cache_filename if streaming else None
843
+
844
+ tracker = ProgressTracker(runtime, job_meta)
845
+ inference = run_stream3r_inference(
846
+ runtime=runtime,
847
+ image_paths=[record.path for record in frame_records],
848
+ mode=mode,
849
+ streaming=streaming,
850
+ cache_output_path=cache_path,
851
+ progress_cb=tracker,
852
+ window_size=window_size if streaming and mode == "window" else None,
853
+ )
854
+
855
+ session_settings = _prepare_session_settings(
856
+ payload,
857
+ mode=mode,
858
+ streaming=streaming,
859
+ frame_records=frame_records,
860
+ window_size=window_size,
861
+ )
862
+
863
+ result_payload = handler(
864
+ runtime=runtime,
865
+ payload=payload,
866
+ mode=mode,
867
+ streaming=streaming,
868
+ job_id=job_id,
869
+ scene_id=scene_id,
870
+ frame_records=frame_records,
871
+ inference=inference,
872
+ session_settings=session_settings,
873
+ temp_dir=temp_path,
874
+ )
875
+
876
+ except Exception as exc:
877
+ error_text = traceback.format_exc()
878
+ runtime.db.upsert_job(
879
+ job_id=job_id,
880
+ job_type=job_type,
881
+ scene_id=scene_id,
882
+ status="failed",
883
+ error=error_text,
884
+ )
885
+ runtime_emit(
886
+ runtime,
887
+ {
888
+ **job_meta,
889
+ "status": "failed",
890
+ "ts": datetime.now(timezone.utc).timestamp(),
891
+ "error": str(exc),
892
+ },
893
+ )
894
+ logger.exception("Job %s failed", job_id)
895
+ raise
896
+
897
+ runtime.db.upsert_job(
898
+ job_id=job_id,
899
+ job_type=job_type,
900
+ scene_id=scene_id,
901
+ status="finished",
902
+ result=result_payload,
903
+ )
904
+
905
+ runtime_emit(
906
+ runtime,
907
+ {
908
+ **job_meta,
909
+ "status": "finished",
910
+ "progress": 100,
911
+ "result_url": result_payload.get("result_url"),
912
+ "model_dir": result_payload.get("model_dir"),
913
+ "ts": datetime.now(timezone.utc).timestamp(),
914
+ },
915
+ )
916
+
917
+ return result_payload
918
+
919
+
920
+ def pose_pointmap_job(payload: Mapping[str, Any]) -> dict[str, Any]:
921
+ """Process a pose + pointmap job."""
922
+
923
+ return _execute_job("pose_pointmap", payload, _handle_pose_pointmap)
924
+
925
+
926
+ def model_build_job(payload: Mapping[str, Any]) -> dict[str, Any]:
927
+ """Process a full model build job."""
928
+
929
+ return _execute_job("model_build", payload, _handle_model_build)
930
+
931
+
932
+ def _handle_model_build(
933
+ *,
934
+ runtime: WorkerRuntime,
935
+ payload: Mapping[str, Any],
936
+ mode: str,
937
+ streaming: bool,
938
+ job_id: str,
939
+ scene_id: str,
940
+ frame_records: list[FrameRecord],
941
+ inference: InferenceResult,
942
+ session_settings: Mapping[str, Any],
943
+ temp_dir: Path,
944
+ ) -> dict[str, Any]:
945
+ predictions = inference.predictions
946
+
947
+ core = _generate_core_outputs(
948
+ runtime=runtime,
949
+ scene_id=scene_id,
950
+ job_id=job_id,
951
+ predictions=predictions,
952
+ frame_records=frame_records,
953
+ inference=inference,
954
+ session_settings=session_settings,
955
+ temp_dir=temp_dir,
956
+ )
957
+
958
+ artifacts = dict(core["artifacts"])
959
+
960
+ top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
961
+ selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
962
+ selected_frames_url = _write_selected_frames(
963
+ runtime=runtime,
964
+ scene_id=scene_id,
965
+ selected_frames=selected_frames,
966
+ top_k=top_k,
967
+ temp_dir=temp_dir,
968
+ )
969
+ if selected_frames_url:
970
+ artifacts["selected_frames_url"] = selected_frames_url
971
+
972
+ scene_glb_url = _save_scene_glb(
973
+ runtime=runtime,
974
+ scene_id=scene_id,
975
+ predictions=predictions,
976
+ temp_dir=temp_dir,
977
+ payload=payload,
978
+ )
979
+ artifacts["scene_glb_url"] = scene_glb_url
980
+
981
+ summary_payload = {
982
+ "job_id": job_id,
983
+ "job_type": "model_build",
984
+ "scene_id": scene_id,
985
+ "frame_count": inference.total_frames,
986
+ "created_at": datetime.now(timezone.utc).isoformat(),
987
+ "artifacts": artifacts,
988
+ "selected_frames": selected_frames,
989
+ "parameters": {
990
+ "mode": mode,
991
+ "streaming": streaming,
992
+ "conf_thres": float(payload.get("conf_thres", 3.0)),
993
+ "frame_filter": payload.get("frame_filter", "All"),
994
+ "mask_black_bg": _as_bool(payload.get("mask_black_bg"), False),
995
+ "mask_white_bg": _as_bool(payload.get("mask_white_bg"), False),
996
+ "show_cam": _as_bool(payload.get("show_cam"), True),
997
+ "mask_sky": _as_bool(payload.get("mask_sky"), False),
998
+ "prediction_mode": payload.get("prediction_mode", "Predicted Pointmap"),
999
+ },
1000
+ }
1001
+
1002
+ summary_url = _write_summary_json(
1003
+ runtime=runtime,
1004
+ scene_id=scene_id,
1005
+ summary=summary_payload,
1006
+ temp_dir=temp_dir,
1007
+ )
1008
+ artifacts["summary_url"] = summary_url
1009
+
1010
+ result_record = dict(summary_payload)
1011
+ result_record["result_url"] = summary_url
1012
+ result_record_url = _upload_result_record(
1013
+ runtime=runtime,
1014
+ scene_id=scene_id,
1015
+ job_id=job_id,
1016
+ payload=result_record,
1017
+ )
1018
+
1019
+ result_payload = {
1020
+ "job_id": job_id,
1021
+ "job_type": "model_build",
1022
+ "scene_id": scene_id,
1023
+ "mode": mode,
1024
+ "streaming": streaming,
1025
+ "frame_count": inference.total_frames,
1026
+ "created_at": summary_payload["created_at"],
1027
+ "artifacts": artifacts,
1028
+ "frames": core["frames"],
1029
+ "selected_frames": selected_frames,
1030
+ "summary_url": summary_url,
1031
+ "result_url": summary_url,
1032
+ "result_record_url": result_record_url,
1033
+ "model_dir": _model_dir_uri(runtime, scene_id),
1034
+ }
1035
+
1036
+ return result_payload
worker/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Compatibility shim package for legacy job import paths."""
worker/stream3r/__init__.py ADDED
@@ -0,0 +1 @@
1
+ """Compatibility namespace for legacy worker module paths."""
worker/stream3r/jobs.py ADDED
@@ -0,0 +1,63 @@
1
+ """Legacy job dispatch entrypoints for compatibility with existing queues."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Any, Callable, Mapping
6
+
7
+ from stream3r.worker.tasks import model_build_job, pose_pointmap_job
8
+
9
+
10
+ _HANDLERS: dict[str, Callable[[Mapping[str, Any]], Any]] = {
11
+ "pose_pointmap": pose_pointmap_job,
12
+ "model_build": model_build_job,
13
+ }
14
+
15
+
16
+ def handle_job(*args: Any, **kwargs: Any) -> Any:
17
+ """Dispatch jobs enqueued with the legacy `worker.stream3r.jobs.handle_job` path.
18
+
19
+ Supports the following invocation patterns:
20
+
21
+ - ``handle_job(payload)`` where ``payload`` is a mapping containing ``job_type``.
22
+ - ``handle_job(job_type, payload)`` matching older enqueue signatures.
23
+ - ``handle_job(job_type=..., payload=...)`` keyword usage.
24
+ """
25
+
26
+ job_type: str | None = None
27
+ payload: Mapping[str, Any] | None = None
28
+
29
+ if args:
30
+ if isinstance(args[0], Mapping) and "job_type" in args[0]:
31
+ payload = args[0]
32
+ job_type = str(payload.get("job_type"))
33
+ else:
34
+ job_type = str(args[0])
35
+ if len(args) > 1 and isinstance(args[1], Mapping):
36
+ payload = args[1]
37
+
38
+ if "job_type" in kwargs and not job_type:
39
+ job_type = str(kwargs["job_type"])
40
+ if "payload" in kwargs and payload is None:
41
+ candidate = kwargs["payload"]
42
+ if isinstance(candidate, Mapping):
43
+ payload = candidate
44
+
45
+ if payload is None and args and isinstance(args[0], Mapping):
46
+ payload = args[0]
47
+ job_type = str(payload.get("job_type")) if payload else job_type
48
+
49
+ if payload is None:
50
+ raise ValueError("handle_job requires a payload mapping")
51
+
52
+ if not job_type:
53
+ job_type = str(payload.get("job_type", "")).strip()
54
+
55
+ if not job_type:
56
+ raise ValueError("handle_job payload is missing 'job_type'")
57
+
58
+ handler = _HANDLERS.get(job_type)
59
+ if handler is None:
60
+ raise ValueError(f"Unsupported job_type '{job_type}'")
61
+
62
+ return handler(payload)
63
+
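+
+ # Usage sketch (queue name, Redis URL, and payload fields are illustrative, not part of this change):
+ # a producer could enqueue through this legacy path with RQ, e.g.
+ #
+ #   from redis import Redis
+ #   from rq import Queue
+ #
+ #   Queue("stream3r", connection=Redis.from_url("redis://localhost:6379/0")).enqueue(
+ #       "worker.stream3r.jobs.handle_job",
+ #       {"job_type": "pose_pointmap", "scene_id": "scene-001", "frames_dir": "/data/frames"},
+ #   )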