Commit 1c5aca1 · 1 parent: 08c2845

init worker working
Files changed:
- .vscode/launch.json +12 -0
- design_docs/stream3r_api.md +269 -0
- design_docs/worker.md +300 -0
- requirements.txt +8 -1
- stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc +0 -0
- stream3r/models/components/utils/geometry.py +7 -1
- stream3r/worker/__init__.py +8 -0
- stream3r/worker/config.py +213 -0
- stream3r/worker/db.py +170 -0
- stream3r/worker/main.py +59 -0
- stream3r/worker/pipeline.py +144 -0
- stream3r/worker/runtime.py +159 -0
- stream3r/worker/storage.py +180 -0
- stream3r/worker/tasks.py +1036 -0
- worker/__init__.py +1 -0
- worker/stream3r/__init__.py +1 -0
- worker/stream3r/jobs.py +63 -0
.vscode/launch.json CHANGED

```diff
@@ -13,5 +13,17 @@
       "python": "/home/robot_op/miniconda3/envs/stream3r/bin/python",
       "justMyCode": true
     },
+    {
+      "name": "Python: STream3R Worker",
+      "type": "debugpy",
+      "request": "launch",
+      "module": "stream3r.worker.main",
+      "args": ["--log-level", "INFO"],
+      "console": "integratedTerminal",
+      "cwd": "${workspaceFolder}",
+      "envFile": "${workspaceFolder}/.env",
+      "python": "/home/robot_op/miniconda3/envs/stream3r/bin/python",
+      "justMyCode": true
+    }
   ]
 }
```
design_docs/stream3r_api.md ADDED (+269 lines)

# STream3R API — Job Orchestration and Integration Plan

## Executive Summary

This document proposes a lightweight STream3R API service that wraps the async job system implemented by the RQ worker. The API exposes RESTful endpoints and Redis Stream subscriptions that allow upstream applications to submit reconstruction jobs, track progress, and retrieve completed artifacts. Responsibilities are split cleanly: the API handles request validation, persistence, and orchestration, while the GPU worker (documented in `design_docs/worker.md`) performs heavy inference and storage of artifacts.

**Key outcomes**
- Unified interface for both short `pose_pointmap` jobs and long-running `model_build` jobs.
- Consistent job lifecycle backed by Postgres and Redis, mirroring the worker contract.
- Environment-agnostic integration so other services can enqueue jobs or consume progress events without GPU access.

---

## 1. Scope & Goals

### Goals
- Provide a REST API for submitting jobs and querying job status or results.
- Offer optional server-sent events (SSE) or WebSocket feeds for near-real-time updates on `pose_pointmap` jobs.
- Enforce validation and idempotency for job submissions.
- Expose artifacts written by the worker (S3 URLs, local paths) without re-hosting the files.

### Non-Goals
- Implement GPU inference or artifact generation (handled by the RQ worker).
- Manage long-term artifact retention or CDN delivery.
- Provide fine-grained authorization beyond bearer token / API key patterns (left to integration).

---

## 2. Architecture Overview

```
Client ──HTTP──► STream3R API ──RQ enqueue──► Redis queue ──► Worker
  │                  │                                          │
  │                  ├── Postgres (job records) ◄───────────────┤
  │                  ├── Redis Stream (events)  ◄───────────────┤
  │                  └── S3 (artifact URLs)     ◄── worker writes
  │
  └── Poll /jobs/{id} or subscribe via SSE/WebSocket for progress updates
```

Components:
- **FastAPI** service (recommended) running behind an ASGI server.
- **Redis**: shared with the worker for queues and events.
- **Postgres**: the `stream3r_jobs` table is the canonical job record.
- **S3/Backblaze** (or local storage): artifact URLs returned by the worker.
- **RQ worker** (implemented separately) executing jobs and updating state.

---

## 3. Endpoints

### `POST /jobs`
Submit a job for either `pose_pointmap` or `model_build`.

**Request body**
```json
{
  "job_type": "pose_pointmap",
  "scene_id": "SCENE123",
  "mode": "causal",
  "streaming": true,
  "frames": [
    {"url": "https://.../frame_0000.jpg"},
    {"path": "/data/captures/frame_0001.png"}
  ],
  "session_settings": {"prediction_mode": "Predicted Pointmap"},
  "client_request_id": "optional-idempotency-key"
}
```

**Behavior**
- Validate the payload (non-empty frames, supported job type, etc.).
- If `client_request_id` is provided, search Postgres for an existing job with the same key to ensure idempotency.
- Assign a `job_id` (UUID) and enqueue the payload into the appropriate RQ queue (`pose_pointmap` or `model_build`).
- Insert a `stream3r_jobs` row with `status=queued`.
- Return `202 Accepted` with job metadata:

```json
{
  "job_id": "uuid",
  "status": "queued",
  "job_type": "pose_pointmap",
  "scene_id": "SCENE123"
}
```
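To make the submission flow concrete, here is a minimal sketch of this handler using FastAPI and RQ under the contract above. The in-memory `_JOBS` store and its two helpers are stand-ins for the Postgres layer, not code from this commit:

```python
# Minimal sketch of POST /jobs (FastAPI + RQ assumed).
import uuid

from fastapi import FastAPI, HTTPException
from redis import Redis
from rq import Queue

app = FastAPI()
redis = Redis.from_url("redis://localhost:6379/0")
queues = {
    "pose_pointmap": Queue("pose_pointmap", connection=redis),
    "model_build": Queue("model_build", connection=redis),
}

_JOBS: dict[str, dict] = {}  # stand-in for the stream3r_jobs table


def find_job_by_request_id(request_id: str) -> dict | None:
    return next((j for j in _JOBS.values()
                 if j.get("client_request_id") == request_id), None)


def insert_job(**row) -> None:
    _JOBS[row["job_id"]] = row


@app.post("/jobs", status_code=202)
def submit_job(body: dict) -> dict:
    job_type = body.get("job_type")
    if job_type not in queues:
        raise HTTPException(422, f"unsupported job_type: {job_type!r}")
    if not body.get("scene_id"):
        raise HTTPException(422, "scene_id is required")
    if not body.get("frames") and not body.get("frames_dir"):
        raise HTTPException(422, "frames must be non-empty")

    # Idempotency: replay the stored record when the key was seen before.
    request_id = body.get("client_request_id")
    if request_id and (existing := find_job_by_request_id(request_id)):
        return {k: existing[k] for k in ("job_id", "status", "job_type", "scene_id")}

    job_id = str(uuid.uuid4())
    insert_job(job_id=job_id, job_type=job_type, scene_id=body["scene_id"],
               status="queued", payload=body, client_request_id=request_id)
    # Worker functions are resolved by dotted path on the worker side.
    queues[job_type].enqueue(f"stream3r.worker.tasks.{job_type}_job",
                             {**body, "job_id": job_id})
    return {"job_id": job_id, "status": "queued",
            "job_type": job_type, "scene_id": body["scene_id"]}
```

The queue names and the `pose_pointmap_job` / `model_build_job` task names match `stream3r/worker/__init__.py` in this commit; everything else is illustrative.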
### `GET /jobs/{job_id}`
Fetch job state and artifact references from Postgres.

**Response example** (`status=finished`)
```json
{
  "job_id": "uuid",
  "job_type": "model_build",
  "scene_id": "SCENE123",
  "status": "finished",
  "created_at": "...",
  "started_at": "...",
  "completed_at": "...",
  "result": {
    "result_url": "s3://bucket/scene/SCENE123/stream3r/models/summary.json",
    "model_dir": "s3://bucket/scene/SCENE123/stream3r/models/",
    "artifacts": {
      "scene_glb_url": "...",
      "poses_url": "...",
      "pointmaps": [{"frame_id": "frame_0000", "url": "..."}]
    }
  },
  "error": null
}
```

### `GET /jobs/{job_id}/events`
Server-Sent Events endpoint bridging Redis Streams.

- Uses `XREAD` on `stream3r:events` with a `job_id` filter.
- Suitable for browser or gateway consumers needing near-real-time progress.
- Emits lines like:

```
event: progress
data: {"progress": 60, "status": "progress"}

event: finished
data: {"result_url": "s3://..."}
```

Optionally provide a WebSocket variant if SSE is insufficient.
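A sketch of the SSE bridge, assuming FastAPI's `StreamingResponse` and redis-py; the field names follow the `stream3r:events` contract described in section 4:

```python
# Sketch of GET /jobs/{job_id}/events: bridge the Redis Stream to SSE.
import json

from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from redis import Redis

app = FastAPI()
redis = Redis.from_url("redis://localhost:6379/0", decode_responses=True)


@app.get("/jobs/{job_id}/events")
def job_events(job_id: str) -> StreamingResponse:
    def generate():
        last_id = "0-0"  # replay from the start; use "$" for new events only
        while True:
            # Block up to 15 s waiting for new entries on the event stream.
            batches = redis.xread({"stream3r:events": last_id},
                                  block=15_000, count=100)
            if not batches:
                yield ": keep-alive\n\n"  # SSE comment keeps proxies open
                continue
            for _stream, entries in batches:
                for entry_id, fields in entries:
                    last_id = entry_id
                    if fields.get("job_id") != job_id:
                        continue  # the stream is shared; filter per job
                    status = fields.get("status", "progress")
                    yield f"event: {status}\ndata: {json.dumps(fields)}\n\n"
                    if status in ("finished", "failed"):
                        return
    return StreamingResponse(generate(), media_type="text/event-stream")
```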
### `GET /jobs`
Paged listing/filtering (optional but useful for dashboards).

Parameters: `scene_id`, `job_type`, `status`, pagination cursors.

---

## 4. Data Model & Persistence

### Postgres (`stream3r_jobs`)

The API is the authoritative owner of the job record. It should create and migrate the following schema during startup (extend with `client_request_id` if idempotency keys are required):

```sql
CREATE TABLE IF NOT EXISTS stream3r_jobs (
    job_id UUID PRIMARY KEY,
    job_type TEXT NOT NULL,           -- 'pose_pointmap' | 'model_build'
    scene_id TEXT NOT NULL,
    status TEXT NOT NULL,             -- 'queued' | 'started' | 'finished' | 'failed'
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    started_at TIMESTAMPTZ,
    completed_at TIMESTAMPTZ,
    payload JSONB,                    -- enqueue-time payload
    result JSONB,                     -- worker-published result bundle
    error TEXT,
    client_request_id TEXT UNIQUE     -- optional idempotency key
);

CREATE INDEX IF NOT EXISTS stream3r_jobs_scene_id_idx ON stream3r_jobs(scene_id);
CREATE INDEX IF NOT EXISTS stream3r_jobs_status_idx ON stream3r_jobs(status);
```

### Redis

- **Queues**: two RQ queues exist: `pose_pointmap` for latency-sensitive pose extraction and `model_build` for long reconstruction jobs. The API selects the queue based on `job_type`.
- **Event stream**: the worker pushes lifecycle updates to the Redis Stream `stream3r:events`. Every entry is a flat map with the following fields:

```
job_id, job_type, scene_id,
status, progress, result_url, model_dir,
error, ts
```

`status` takes `started`, `progress`, `finished`, or `failed`. `progress` is an integer percentage (0–100). `result_url` and `model_dir` mirror the URLs stored in Postgres when a job completes.

### Artifact Storage Layout

The worker persists artifacts to S3/Backblaze (or local storage) under a deterministic folder hierarchy. The API does not move files but should surface these URLs verbatim so consumers know where to fetch results:

```
scene/{scene_id}/stream3r/
  results/
    {job_id}.json            # per-job result JSON (pose_pointmap)
  models/
    kv_cache.pt              # serialized KV cache
    predictions.npz          # packed model outputs
    session_settings.json    # runtime/config settings
    selected_frames.json     # frame subset indices (optional)
    scene.glb                # fused 3D scene
    poses.jsonl              # per-frame extrinsics
    summary.json             # canonical model_build result JSON
  pointmaps/
    {frame_token}.npz        # per-frame world coords + confidence
```

`pose_pointmap` jobs typically populate `results/{job_id}.json` plus `pointmaps/`; `model_build` jobs populate the `models/` subtree. All URLs returned by the worker use this structure.

---

## 5. Job Lifecycle

1. **Submit** (`POST /jobs`):
   - Validate input, persist the `queued` row, enqueue the payload.
   - Return `job_id`.
2. **Worker processing**:
   - The worker acquires the GPU lock, runs inference, streams events, writes artifacts, and updates the DB.
3. **Status checks**:
   - Clients poll `GET /jobs/{id}` or subscribe to `/jobs/{id}/events`.
4. **Completion**:
   - The job row contains `status=finished` and a `result` JSON with URLs.
   - The API response is the source of truth for artifact discovery.
5. **Failure**:
   - The worker updates the DB with `status=failed` and an `error` string.
   - The API surfaces the error in `GET /jobs/{id}` and via events.

---

## 6. Request Validation Contracts

`POST /jobs` validation rules (a request-model sketch follows this list):
- `job_type` ∈ {`pose_pointmap`, `model_build`}.
- `scene_id` is a non-empty string.
- `frames` list size ≥ 1 (unless `frames_dir` is provided).
- Each frame entry must have exactly one of `url`, `path`, or `content` (base64 image string).
- `mode` defaults to `causal`; forbid `full` for streaming jobs to match worker behavior.
- Optional numeric fields are converted to `int`/`float` before enqueueing.
- Enforce a configurable maximum frame count to avoid resource exhaustion.
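These rules map naturally onto a request model; a sketch assuming Pydantic v2, with field names taken from the request example in section 3 (the `MAX_FRAMES` cap is an illustrative placeholder):

```python
# Sketch of the POST /jobs request model under the validation rules above.
from typing import Literal

from pydantic import BaseModel, Field, model_validator

MAX_FRAMES = 500  # illustrative cap; make this configurable


class FrameRef(BaseModel):
    url: str | None = None
    path: str | None = None
    content: str | None = None  # base64-encoded image

    @model_validator(mode="after")
    def exactly_one_source(self) -> "FrameRef":
        if sum(v is not None for v in (self.url, self.path, self.content)) != 1:
            raise ValueError("frame needs exactly one of url, path, content")
        return self


class JobRequest(BaseModel):
    job_type: Literal["pose_pointmap", "model_build"]
    scene_id: str = Field(min_length=1)
    mode: str = "causal"
    streaming: bool = True
    frames: list[FrameRef] = Field(default_factory=list, max_length=MAX_FRAMES)
    frames_dir: str | None = None
    session_settings: dict = Field(default_factory=dict)
    client_request_id: str | None = None

    @model_validator(mode="after")
    def check_frames_and_mode(self) -> "JobRequest":
        if not self.frames and not self.frames_dir:
            raise ValueError("provide frames or frames_dir")
        if self.streaming and self.mode == "full":
            raise ValueError("mode='full' is not supported for streaming jobs")
        return self
```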
---

## 7. Security & Authentication

- Deploy behind an API gateway that injects `X-Client-Id` or similar metadata for auditing.
- Support bearer token / API key auth via middleware; store hashed keys in Postgres if needed.
- Restrict access to the internal network when possible, since artifacts contain scene data.
- Sanitize inbound URLs to prevent SSRF; optionally proxy downloads through a whitelist.

---

## 8. Observability & Operations

- **Logging**: structured logs capturing `job_id`, `scene_id`, `client_request_id`, and remote IP.
- **Metrics**: track enqueue latency, job duration (from DB timestamps), queue depth, and event lag.
- **Health checks**: `GET /healthz` verifying Redis and Postgres connectivity.
- **Backpressure**: before accepting a new job, check the queue length; if it is above a threshold, return `429 Too Many Requests`.
- **Timeouts**: configure HTTP request timeouts to avoid hanging on large payloads.

---

## 9. Deployment Considerations

- Package as a standalone FastAPI app (e.g., `stream3r_api.main:app`).
- Run under Uvicorn/Gunicorn with worker counts sized for I/O-bound traffic.
- Configure the service with the same environment variables as the worker (`STREAM3R_REDIS_URL`, `STREAM3R_DB_DSN`, etc.); an example follows.
- Use infrastructure-as-code to provision the Redis, Postgres, and S3 credentials shared with the worker.
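For reference, a plausible `.env` for the API process, using the variable names read by `stream3r/worker/config.py` in this commit; every value below is a placeholder:

```
# Shared with the worker; variable names from stream3r/worker/config.py.
STREAM3R_REDIS_URL=redis://redis:6379/0
STREAM3R_DB_DSN=postgresql://stream3r:CHANGE_ME@postgres:5432/stream3r
STREAM3R_STORAGE_BUCKET=example-scenes-bucket
STREAM3R_S3_ENDPOINT=https://s3.example-endpoint.com
STREAM3R_EVENTS_STREAM=stream3r:events
STREAM3R_QUEUE_POSE=pose_pointmap
STREAM3R_QUEUE_MODEL=model_build
```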
---

## 10. Future Enhancements

- **Job cancellation**: add `DELETE /jobs/{id}` to flag jobs for cancellation (requires worker support).
- **Scene-level dashboards**: aggregate artifacts from multiple jobs for a scene.
- **Signed download URLs**: the API could issue pre-signed URLs for public sharing, decoupled from worker credentials.
- **Batch submissions**: support uploading a tar/zip and asynchronously unpacking/validating frames.

---

**References**
- `design_docs/worker.md` — Worker design and artifact contracts
- `stream3r/worker/tasks.py` — Concrete payload fields exchanged with the API
design_docs/worker.md ADDED (+300 lines)

# STream3R — Jobs, Events, and Storage Design

## **Executive Summary**

The **STream3R Job System** provides an asynchronous GPU job orchestration layer for 3D scene reconstruction and perception tasks.
It standardizes how **pose and world-coordinate extraction** and **scene model building** are executed, stored, and tracked across services.

### **Primary Goals**

- **Asynchronous GPU processing:**
  All heavy inference runs on background RQ workers; FastAPI services only enqueue jobs and monitor progress.
- **Unified observability:**
  - **Redis Streams** for job lifecycle events and progress (`stream3r:events`)
  - **Postgres (`stream3r_jobs`)** as the canonical job record
  - **S3/Backblaze** for durable artifacts and results
- **Two calling modes:**
  - `get_pose_and_world_coords` → **Streams-based (Option A)** for near-real-time updates
  - `create_model` → **Polling (Option B)** for long-running model generation
- **Consistent storage under** `/scene/{scene_id}/stream3r/`, containing:
  - `kv_cache.pt` — serialized key/value cache state
  - `predictions.npz` — packed outputs from the model build
  - `session_settings.json` — runtime/config parameters
  - `selected_frames.json` — frame subset selection
  - `scene.glb` — final assembled scene model
  - `poses.jsonl` — per-frame extrinsics (camera poses)
  - `pointmaps/*.npz` — per-frame world coordinates + confidence maps

### **Key Outcomes**
- Clean separation of API ↔ GPU worker responsibilities
- Event-driven feedback for quick jobs; reliable polling for long ones
- Durable, versioned scene data under a unified layout
- End-to-end traceability of all STream3R jobs via Redis + Postgres + S3

---

## 1. Queues, Streams, and Locks

| Component | Purpose | Notes |
|------------|----------|-------|
| `pose_pointmap` | RQ queue for latency-sensitive `pose_pointmap` jobs | |
| `model_build` | RQ queue for long `model_build` jobs | |
| `stream3r:events` | Redis Stream for all job events (`started`, `progress`, `finished`, `failed`) | trimmed periodically |
| `gpu:lock` | Redis lock ensuring a single GPU job at a time per machine | |

Each Stream event is a flat map of strings:

```
job_id, job_type, scene_id,
status, progress, result_url, model_dir,
error, ts
```
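To make the `gpu:lock` row concrete, a sketch of how a worker process might hold the lock with redis-py's built-in `Lock`; the timeout values mirror the `STREAM3R_GPU_LOCK_*` defaults in `stream3r/worker/config.py`:

```python
# Sketch: serialize GPU jobs on one machine via the gpu:lock Redis key.
from redis import Redis

redis = Redis.from_url("redis://localhost:6379/0")


def run_inference() -> None:
    """Placeholder for the actual GPU work."""


# timeout: the lock auto-expires if a worker dies mid-job;
# blocking_timeout: how long a queued job waits for the GPU before giving up.
lock = redis.lock("gpu:lock", timeout=3600, blocking_timeout=600)
if not lock.acquire():
    raise TimeoutError("GPU busy: could not acquire gpu:lock")
try:
    run_inference()
finally:
    lock.release()
```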
---

## 2. S3 / Backblaze Storage Layout

All STream3R artifacts live under a **scene folder**:

```
s3://<bucket>/scene/{scene_id}/stream3r/
  results/
    {job_id}.json            # per-job result JSON (pose_pointmap)
  models/
    kv_cache.pt              # serialized KV cache
    predictions.npz          # packed model outputs
    session_settings.json    # runtime/config settings
    selected_frames.json     # frame subset indices
    scene.glb                # fused 3D scene
    poses.jsonl              # per-frame extrinsics
    summary.json             # canonical model_build result JSON
  pointmaps/
    {frame_token}.npz        # per-frame world_coords + confidence
```

**Key result URLs**
- Pose/pointmap job → `s3://.../scene/{scene_id}/stream3r/results/{job_id}.json`
- Model build job → `s3://.../scene/{scene_id}/stream3r/models/summary.json`

---

## 3. Database: `stream3r_jobs`

Canonical job table in Postgres.

```sql
CREATE TABLE IF NOT EXISTS stream3r_jobs (
    job_id UUID PRIMARY KEY,
    job_type TEXT NOT NULL,           -- 'pose_pointmap' | 'model_build'
    scene_id TEXT NOT NULL,
    status TEXT NOT NULL,             -- 'queued' | 'started' | 'finished' | 'failed'
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    started_at TIMESTAMPTZ,
    completed_at TIMESTAMPTZ,
    payload JSONB,                    -- enqueue-time payload
    result JSONB,                     -- URLs / metrics
    error TEXT
);

CREATE INDEX IF NOT EXISTS stream3r_jobs_scene_id_idx ON stream3r_jobs(scene_id);
CREATE INDEX IF NOT EXISTS stream3r_jobs_status_idx ON stream3r_jobs(status);
```

**Upsert pattern** (implemented by `upsert_job` in `stream3r/worker/db.py`, added in this commit):

* Insert on enqueue (`queued`)
* Update on start → `started`
* Update on finish → `finished`, add `result`
* Update on failure → `failed`, add `error`
---

## 4. Result JSON Schemas

### a. Pose + World Coords (per-frame)

`s3://…/scene/{scene_id}/stream3r/results/{job_id}.json`

```json
{
  "job_id": "uuid",
  "job_type": "pose_pointmap",
  "scene_id": "SCENE123",
  "artifacts": {
    "pointmap_url": "s3://.../scene/SCENE123/stream3r/pointmaps/frame_000010.npz"
  },
  "pose": { "R": [[...]], "t": [x, y, z] },
  "intrinsics": { "fx": ..., "fy": ..., "cx": ..., "cy": ... },
  "metrics": { "runtime_s": 1.23 },
  "stream3r": { "cfg": "configs/stream3r_base.yaml", "commit": "<git_sha>" }
}
```

### b. Model Build (scene-level)

`s3://…/scene/{scene_id}/stream3r/models/summary.json`

```json
{
  "job_id": "uuid",
  "job_type": "model_build",
  "scene_id": "SCENE123",
  "artifacts": {
    "model_dir": "s3://.../scene/SCENE123/stream3r/models/",
    "kv_cache": "s3://.../scene/SCENE123/stream3r/models/kv_cache.pt",
    "predictions": "s3://.../scene/SCENE123/stream3r/models/predictions.npz",
    "session_settings": "s3://.../scene/SCENE123/stream3r/models/session_settings.json",
    "selected_frames": "s3://.../scene/SCENE123/stream3r/models/selected_frames.json",
    "scene_glb": "s3://.../scene/SCENE123/stream3r/models/scene.glb",
    "poses_jsonl": "s3://.../scene/SCENE123/stream3r/models/poses.jsonl"
  },
  "metrics": { "frames": 128, "runtime_s": 42.3 },
  "stream3r": { "cfg": "configs/stream3r_base.yaml", "commit": "<git_sha>" }
}
```

---

## 5. Caller API Responsibilities

### `get_pose_and_world_coords` → **Option A (Streams)**

1. Enqueue job → get `job_id`
2. `XREAD BLOCK` on `stream3r:events` until `status=finished`
3. On finish:
   * Fetch `result_url`
   * Load the JSON → retrieve `pose`, `intrinsics`, and `pointmap_url`
   * Download the `.npz` to get `world_coords` + `confidence`
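A sketch of the Option A consumer loop with redis-py; the enqueue call is elided, and result parsing follows the schemas in section 4:

```python
# Sketch: block on stream3r:events until a specific job finishes (Option A).
from redis import Redis

redis = Redis.from_url("redis://localhost:6379/0", decode_responses=True)


def wait_for_job(job_id: str, block_ms: int = 30_000) -> dict:
    last_id = "$"  # only consider events published after we start listening
    while True:
        batches = redis.xread({"stream3r:events": last_id}, block=block_ms)
        if not batches:
            raise TimeoutError(f"no events within {block_ms} ms for job {job_id}")
        for _stream, entries in batches:
            for entry_id, fields in entries:
                last_id = entry_id          # resume after the last seen entry
                if fields.get("job_id") != job_id:
                    continue                # the stream is shared across jobs
                if fields["status"] == "failed":
                    raise RuntimeError(fields.get("error", "job failed"))
                if fields["status"] == "finished":
                    return fields           # carries result_url / model_dir
```

From the returned `result_url` the caller loads the result JSON and then downloads the referenced `pointmaps/*.npz`.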
### `create_model` → **Option B (Polling)**

1. Enqueue job → return `job_id` immediately
2. Periodically poll `GET /jobs/{job_id}`
3. On `finished`:
   * Read `result` with `result_url` + `model_dir`
   * Download `summary.json` and the listed model files

---

## 6. Worker Event & Persistence Flow

1. **Acquire the GPU lock**
2. **Emit** `started`
3. **Upsert** the DB row (`stream3r_jobs`)
4. **Run inference**, emitting `progress` events (every N frames)
5. **Save** artifacts to S3:
   * `pointmaps/*.npz` with `{xyz, conf}`
   * `poses.jsonl`
   * Model outputs listed above
6. **Write** the result JSON → emit `finished`
7. **Update** the DB row → `status=finished, result=…`
8. On error → emit `failed`, update the DB
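A sketch of the emit step used throughout this flow, pairing `XADD` with the `MAXLEN` trim recommended in section 8; the field names follow the section 1 contract:

```python
# Sketch: publish one lifecycle event to the capped stream3r:events stream.
import time

from redis import Redis

redis = Redis.from_url("redis://localhost:6379/0")


def emit_event(job_id: str, job_type: str, scene_id: str,
               status: str, progress: int = 0, **extra: str) -> None:
    fields = {
        "job_id": job_id, "job_type": job_type, "scene_id": scene_id,
        "status": status, "progress": str(progress), "ts": str(time.time()),
        **extra,  # e.g. result_url=..., model_dir=..., error=...
    }
    # approximate=True lets Redis trim lazily (the "~" in XTRIM MAXLEN ~N).
    redis.xadd("stream3r:events", fields, maxlen=50_000, approximate=True)


# e.g. emit_event(job_id, "model_build", "SCENE123", "progress", progress=40)
```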
---

## 7. Example Event Payloads (Redis Stream)

**Started**

```
job_id=uuid
job_type=pose_pointmap
scene_id=SCENE123
status=started
progress=1
ts=1730312345.12
```

**Progress**

```
job_id=uuid
job_type=model_build
scene_id=SCENE123
status=progress
progress=40
ts=1730312456.22
```

**Finished**

```
job_id=uuid
job_type=model_build
scene_id=SCENE123
status=finished
progress=100
result_url=s3://bucket/scene/SCENE123/stream3r/models/summary.json
model_dir=s3://bucket/scene/SCENE123/stream3r/models/
ts=1730312567.33
```

**Failed**

```
job_id=uuid
job_type=pose_pointmap
scene_id=SCENE123
status=failed
error=RuntimeError: CUDA OOM
ts=1730312570.00
```

---

## 8. Operational Guidelines

| Concern | Best Practice |
| -------------------------- | --------------------------------------------------------------- |
| **GPU safety** | Use `gpu:lock` to serialize jobs per GPU |
| **Redis Stream retention** | `XTRIM stream3r:events MAXLEN ~50000` |
| **Durability** | All artifacts and summaries must persist to S3/Backblaze |
| **DB reliability** | Upsert on each transition; retry writes if the DB is unavailable |
| **Idempotency** | Support a caller-supplied `job_id` or `request_id` |
| **Security** | Keep Redis internal; use signed or private S3 URLs |
| **Backpressure** | The enqueueing API should reject (`429`) when queue depth grows too large |
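The retention guideline above is a one-liner from redis-py when issued by a periodic task (a sketch; the cap mirrors the table):

```python
# Sketch: cap stream3r:events to roughly the newest 50,000 entries.
from redis import Redis

redis = Redis.from_url("redis://localhost:6379/0")
redis.xtrim("stream3r:events", maxlen=50_000, approximate=True)
```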
---

## 9. End-to-End Flows

### 🔹 Pose + World Coords (short job)

1. API enqueues job → returns `job_id`
2. Client subscribes via the Redis Stream (blocking `XREAD`)
3. Worker runs inference → writes `pointmap.npz` + `result.json`
4. Worker emits `finished` → client downloads results

### 🔹 Model Build (long job)

1. API enqueues → returns `job_id`
2. Client polls `GET /jobs/{id}` or the DB row
3. Worker fuses frames → writes the full scene model files
4. Worker updates the DB + emits `finished`
5. Client retrieves `summary.json` + artifacts under `/scene/{scene_id}/stream3r/models/`

---

## 10. Summary

| Component | Responsibility | Persistence |
| ------------------------------ | --------------------------------------- | ------------------------------- |
| **FastAPI API** | Enqueue jobs, expose `/jobs/{id}` | DB (via worker), Redis (events) |
| **GPU Worker** | Execute STream3R inference, emit events | S3/Backblaze, DB |
| **Redis Streams** | Event bus for progress + completion | ephemeral |
| **Postgres (`stream3r_jobs`)** | Canonical job record | durable |
| **S3/Backblaze `/scene/`** | Scene artifacts, model data | durable |

---

**Outcome:**
This design provides an **asynchronous, event-driven, and durable** framework for managing STream3R GPU jobs, with standardized scene storage, traceable job metadata, and clear integration points for both real-time and long-running workflows.
requirements.txt CHANGED

```diff
@@ -43,10 +43,17 @@ pyglet<2
 huggingface-hub[torch]>=0.22
 spaces
 
+# --------- worker --------- #
+redis
+rq
+boto3
+psycopg2-binary
+requests
+
 # --------- eval --------- #
 accelerate
 evo
 
 # --------- demo --------- #
 gradio==5.17.1
-onnxruntime
+onnxruntime
```
stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc CHANGED

Binary files a/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc and b/stream3r/models/components/utils/__pycache__/geometry.cpython-311.pyc differ
stream3r/models/components/utils/geometry.py CHANGED

```diff
@@ -32,8 +32,14 @@ def unproject_depth_map_to_point_map(
 
     world_points_list = []
     for frame_idx in range(depth_map.shape[0]):
+        intrinsic = intrinsics_cam[frame_idx]
+        if intrinsic.shape[-2:] != (3, 3):
+            intrinsic = intrinsic.reshape(-1, 3, 3)[0]
+        extrinsic = extrinsics_cam[frame_idx]
+        if extrinsic.shape[-2:] != (3, 4):
+            extrinsic = extrinsic.reshape(-1, 3, 4)[0]
         cur_world_points, _, _ = depth_to_world_coords_points(
-            depth_map[frame_idx].squeeze(-1),
+            depth_map[frame_idx].squeeze(-1), extrinsic, intrinsic
         )
         world_points_list.append(cur_world_points)
     world_points_array = np.stack(world_points_list, axis=0)
```
stream3r/worker/__init__.py ADDED (+8 lines)

```python
"""Worker utilities for running STream3R jobs via RQ."""

from .tasks import model_build_job, pose_pointmap_job

__all__ = [
    "pose_pointmap_job",
    "model_build_job",
]
```
stream3r/worker/config.py ADDED (+213 lines)

```python
"""Configuration helpers for the STream3R RQ worker."""

from __future__ import annotations

from dataclasses import dataclass, field
import os
from pathlib import Path


def _env_bool(name: str, default: bool) -> bool:
    value = os.getenv(name)
    if value is None:
        return default
    value = value.strip().lower()
    if value in {"1", "true", "yes", "y", "on"}:
        return True
    if value in {"0", "false", "no", "n", "off"}:
        return False
    return default


def _env_bool_any(default: bool, *names: str) -> bool:
    for name in names:
        if not name:
            continue
        value = os.getenv(name)
        if value is None:
            continue
        value = value.strip().lower()
        if value in {"1", "true", "yes", "y", "on"}:
            return True
        if value in {"0", "false", "no", "n", "off"}:
            return False
    return default


def _env_int(name: str, default: int) -> int:
    value = os.getenv(name)
    if value is None:
        return default
    try:
        return int(value)
    except ValueError:
        return default


def _env_value(primary: str, *aliases: str, default: str | None = None) -> str | None:
    """Best-effort environment lookup with fallbacks for legacy names."""

    names = (primary, *aliases)
    for name in names:
        if not name:
            continue
        value = os.getenv(name)
        if value:
            return value
    return default


@dataclass(slots=True)
class WorkerSettings:
    """Runtime configuration derived from environment variables."""

    redis_url: str = "redis://localhost:6379/0"
    redis_events_stream: str = "stream3r:events"
    redis_stream_maxlen: int = 50_000
    redis_healthcheck_interval: int = 30

    pose_queue: str = "pose_pointmap"
    model_queue: str = "model_build"

    gpu_lock_key: str = "gpu:lock"
    gpu_lock_timeout: int = 3600
    gpu_lock_blocking_timeout: int = 600

    storage_prefix: str = "scene"
    s3_bucket: str | None = None
    s3_endpoint_url: str | None = None
    s3_region: str | None = None
    s3_profile: str | None = None
    s3_force_path_style: bool = False
    aws_access_key_id: str | None = None
    aws_secret_access_key: str | None = None
    aws_session_token: str | None = None
    local_storage_root: Path = field(default_factory=lambda: Path("storage"))

    db_dsn: str | None = None
    job_table: str = "stream3r_jobs"

    model_id: str = "yslan/STream3R"
    model_revision: str | None = None
    model_device_preference: str | None = None
    model_dtype: str | None = None

    default_mode: str = "window"
    default_streaming: bool = True
    download_workers: int = 4

    worker_name: str | None = None

    session_cache_filename: str = "kv_cache.pt"
    predictions_filename: str = "predictions.npz"
    poses_filename: str = "poses.jsonl"
    result_filename: str = "summary.json"
    selected_frames_filename: str = "selected_frames.json"
    scene_glb_filename: str = "scene.glb"
    session_settings_filename: str = "session_settings.json"
    summary_results_filename: str = "summary.json"

    pointmap_dir: str = "pointmaps"
    models_dir: str = "models"
    results_dir: str = "results"

    scene_media_api_base_url: str | None = None
    scene_media_api_token: str | None = None
    scene_media_page_size: int = 200
    stream_window_size: int = 10
    max_frames_per_job: int = 10

    @classmethod
    def from_env(cls) -> "WorkerSettings":
        base = cls()

        kwargs: dict[str, object] = {
            "redis_url": _env_value("STREAM3R_REDIS_URL", "REDIS_URL", default=base.redis_url),
            "redis_events_stream": os.getenv("STREAM3R_EVENTS_STREAM", base.redis_events_stream),
            "redis_stream_maxlen": _env_int("STREAM3R_EVENTS_MAXLEN", base.redis_stream_maxlen),
            "redis_healthcheck_interval": _env_int(
                "STREAM3R_REDIS_HEALTHCHECK", base.redis_healthcheck_interval
            ),
            "pose_queue": os.getenv("STREAM3R_QUEUE_POSE", base.pose_queue),
            "model_queue": os.getenv("STREAM3R_QUEUE_MODEL", base.model_queue),
            "gpu_lock_key": os.getenv("STREAM3R_GPU_LOCK_KEY", base.gpu_lock_key),
            "gpu_lock_timeout": _env_int("STREAM3R_GPU_LOCK_TIMEOUT", base.gpu_lock_timeout),
            "gpu_lock_blocking_timeout": _env_int(
                "STREAM3R_GPU_LOCK_BLOCK", base.gpu_lock_blocking_timeout
            ),
            "storage_prefix": os.getenv("STREAM3R_STORAGE_PREFIX", base.storage_prefix),
            "s3_bucket": _env_value("STREAM3R_STORAGE_BUCKET", "S3_BUCKET", default=base.s3_bucket) or None,
            "s3_endpoint_url": _env_value("STREAM3R_S3_ENDPOINT", "S3_ENDPOINT", default=base.s3_endpoint_url) or None,
            "s3_region": _env_value("STREAM3R_S3_REGION", "AWS_REGION", default=base.s3_region) or None,
            "s3_profile": os.getenv("STREAM3R_S3_PROFILE", base.s3_profile or "") or None,
            "s3_force_path_style": _env_bool_any(
                base.s3_force_path_style,
                "STREAM3R_S3_FORCE_PATH",
                "S3_FORCE_PATH_STYLE",
            ),
            "aws_access_key_id": _env_value("AWS_ACCESS_KEY_ID", default=base.aws_access_key_id) or None,
            "aws_secret_access_key": _env_value("AWS_SECRET_ACCESS_KEY", default=base.aws_secret_access_key) or None,
            "aws_session_token": _env_value("AWS_SESSION_TOKEN", default=base.aws_session_token) or None,
            "local_storage_root": Path(
                os.getenv("STREAM3R_LOCAL_STORAGE", str(base.local_storage_root))
            ).resolve(),
            "db_dsn": _env_value("STREAM3R_DB_DSN", "DATABASE_URL", default=base.db_dsn) or None,
            "job_table": os.getenv("STREAM3R_JOBS_TABLE", base.job_table),
            "model_id": os.getenv("STREAM3R_MODEL_ID", base.model_id),
            "model_revision": os.getenv("STREAM3R_MODEL_REVISION", base.model_revision or "") or None,
            "model_device_preference": os.getenv(
                "STREAM3R_MODEL_DEVICE", base.model_device_preference or ""
            )
            or None,
            "model_dtype": os.getenv("STREAM3R_MODEL_DTYPE", base.model_dtype or "") or None,
            "default_mode": os.getenv("STREAM3R_DEFAULT_MODE", base.default_mode),
            "default_streaming": _env_bool(
                "STREAM3R_DEFAULT_STREAMING", base.default_streaming
            ),
            "download_workers": _env_int("STREAM3R_DOWNLOAD_WORKERS", base.download_workers),
            "worker_name": os.getenv("STREAM3R_WORKER_NAME", base.worker_name or "") or None,
            "session_cache_filename": os.getenv(
                "STREAM3R_SESSION_CACHE", base.session_cache_filename
            ),
            "predictions_filename": os.getenv(
                "STREAM3R_PREDICTIONS_FILE", base.predictions_filename
            ),
            "poses_filename": os.getenv("STREAM3R_POSES_FILE", base.poses_filename),
            "result_filename": os.getenv("STREAM3R_RESULT_FILE", base.result_filename),
            "selected_frames_filename": os.getenv(
                "STREAM3R_SELECTED_FRAMES_FILE", base.selected_frames_filename
            ),
            "scene_glb_filename": os.getenv("STREAM3R_SCENE_GLB_FILE", base.scene_glb_filename),
            "session_settings_filename": os.getenv(
                "STREAM3R_SESSION_SETTINGS_FILE", base.session_settings_filename
            ),
            "summary_results_filename": os.getenv(
                "STREAM3R_SUMMARY_FILE", base.summary_results_filename
            ),
            "pointmap_dir": os.getenv("STREAM3R_POINTMAP_DIR", base.pointmap_dir),
            "models_dir": os.getenv("STREAM3R_MODELS_DIR", base.models_dir),
            "results_dir": os.getenv("STREAM3R_RESULTS_DIR", base.results_dir),
            "scene_media_api_base_url": _env_value(
                "STREAM3R_MEDIA_API_BASE_URL",
                "API_BASE_URL",
                default=base.scene_media_api_base_url,
            )
            or None,
            "scene_media_api_token": _env_value(
                "STREAM3R_MEDIA_API_TOKEN",
                "MEDIA_API_TOKEN",
                default=base.scene_media_api_token,
            )
            or None,
            "scene_media_page_size": _env_int(
                "STREAM3R_MEDIA_PAGE_SIZE", base.scene_media_page_size
            ),
            "stream_window_size": _env_int(
                "STREAM3R_WINDOW_SIZE", base.stream_window_size
            ),
            "max_frames_per_job": _env_int(
                "STREAM3R_MAX_FRAMES", base.max_frames_per_job
            ),
        }

        return cls(**kwargs)
```
stream3r/worker/db.py ADDED (+170 lines)

```python
"""Database helpers for persisting job metadata."""

from __future__ import annotations

import json
import logging
from contextlib import contextmanager
from typing import Iterator, Mapping

try:  # Optional dependency
    import psycopg2
    from psycopg2.extensions import connection as PGConnection
    from psycopg2.extras import Json
except ModuleNotFoundError:  # pragma: no cover - exercised when psycopg2 missing
    psycopg2 = None  # type: ignore[assignment]
    PGConnection = None  # type: ignore[assignment]
    Json = None  # type: ignore[assignment]

from .config import WorkerSettings


logger = logging.getLogger(__name__)


SCHEMA_SQL = """
CREATE TABLE IF NOT EXISTS {table_name} (
    job_id UUID PRIMARY KEY,
    job_type TEXT NOT NULL,
    scene_id TEXT NOT NULL,
    status TEXT NOT NULL,
    created_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    started_at TIMESTAMPTZ,
    completed_at TIMESTAMPTZ,
    payload JSONB,
    result JSONB,
    error TEXT
);

CREATE INDEX IF NOT EXISTS {table_name}_scene_id_idx ON {table_name}(scene_id);
CREATE INDEX IF NOT EXISTS {table_name}_status_idx ON {table_name}(status);
"""


class DatabaseError(RuntimeError):
    """Raised when database operations fail."""


class BaseDatabaseClient:
    """Small interface for database persistence."""

    def ensure_schema(self) -> None:  # pragma: no cover - noop implementation
        raise NotImplementedError

    def upsert_job(
        self,
        *,
        job_id: str,
        job_type: str,
        scene_id: str,
        status: str,
        payload: Mapping[str, object] | None = None,
        result: Mapping[str, object] | None = None,
        error: str | None = None,
    ) -> None:
        raise NotImplementedError

    def close(self) -> None:  # pragma: no cover - noop implementation
        pass


class NoopDatabaseClient(BaseDatabaseClient):
    """Fallback when no database configuration is provided."""

    def ensure_schema(self) -> None:  # pragma: no cover - nothing to do
        logger.debug("Database is disabled; skipping schema creation")

    def upsert_job(
        self,
        *,
        job_id: str,
        job_type: str,
        scene_id: str,
        status: str,
        payload: Mapping[str, object] | None = None,
        result: Mapping[str, object] | None = None,
        error: str | None = None,
    ) -> None:
        logger.debug(
            "Noop DB: job_id=%s job_type=%s scene_id=%s status=%s", job_id, job_type, scene_id, status
        )


class DatabaseClient(BaseDatabaseClient):
    """Postgres implementation using psycopg2."""

    def __init__(self, settings: WorkerSettings):
        if psycopg2 is None:  # pragma: no cover - optional dependency guard
            raise DatabaseError("psycopg2-binary is required for database support")

        self.settings = settings

    @contextmanager
    def _connect(self) -> Iterator[PGConnection]:
        conn = psycopg2.connect(self.settings.db_dsn)  # type: ignore[arg-type]
        try:
            conn.autocommit = True
            yield conn
        finally:
            conn.close()

    def ensure_schema(self) -> None:
        table_name = self.settings.job_table
        with self._connect() as conn:
            with conn.cursor() as cur:
                cur.execute(SCHEMA_SQL.format(table_name=table_name))

    def upsert_job(
        self,
        *,
        job_id: str,
        job_type: str,
        scene_id: str,
        status: str,
        payload: Mapping[str, object] | None = None,
        result: Mapping[str, object] | None = None,
        error: str | None = None,
    ) -> None:
        table = self.settings.job_table
        payload_json = Json(payload) if payload is not None and Json is not None else None
        result_json = Json(result) if result is not None and Json is not None else None

        with self._connect() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    f"""
                    INSERT INTO {table} (job_id, job_type, scene_id, status, payload, result, error, started_at, completed_at)
                    VALUES (%s, %s, %s, %s, %s, %s, %s,
                            CASE WHEN %s = 'started' THEN now() ELSE NULL END,
                            CASE WHEN %s IN ('finished', 'failed') THEN now() ELSE NULL END)
                    ON CONFLICT (job_id)
                    DO UPDATE SET
                        job_type = EXCLUDED.job_type,
                        scene_id = EXCLUDED.scene_id,
                        status = EXCLUDED.status,
                        payload = COALESCE(EXCLUDED.payload, {table}.payload),
                        result = COALESCE(EXCLUDED.result, {table}.result),
                        error = EXCLUDED.error,
                        started_at = COALESCE({table}.started_at, EXCLUDED.started_at),
                        completed_at = COALESCE({table}.completed_at, EXCLUDED.completed_at)
                    """,
                    (
                        job_id,
                        job_type,
                        scene_id,
                        status,
                        payload_json,
                        result_json,
                        error,
                        status,
                        status,
                    ),
                )


def create_database_client(settings: WorkerSettings) -> BaseDatabaseClient:
    """Factory that returns a database client or a noop fallback."""

    if not settings.db_dsn:
        return NoopDatabaseClient()
    return DatabaseClient(settings)
```
stream3r/worker/main.py ADDED (+59 lines)

```python
"""CLI entrypoint to launch an RQ worker for STream3R jobs."""

from __future__ import annotations

import argparse
import logging
from typing import Sequence

from rq import Queue, Worker

from .config import WorkerSettings
from .runtime import get_runtime


logger = logging.getLogger(__name__)


def _parse_args(default_queues: Sequence[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="Run the STream3R RQ worker")
    parser.add_argument(
        "--queue",
        "--queues",
        dest="queues",
        action="append",
        help="Queue names to listen to (can be repeated)",
    )
    parser.add_argument(
        "--burst",
        action="store_true",
        help="Run in burst mode (exit when queues are empty)",
    )
    parser.add_argument(
        "--log-level",
        default="INFO",
        help="Logging level",
    )
    args = parser.parse_args()
    if not args.queues:
        args.queues = list(default_queues)
    return args


def main() -> None:
    settings = WorkerSettings.from_env()
    args = _parse_args([settings.pose_queue, settings.model_queue])
    logging.basicConfig(level=getattr(logging, str(args.log_level).upper(), logging.INFO))

    runtime = get_runtime()

    queues = [Queue(name, connection=runtime.redis) for name in args.queues]
    for queue in queues:
        logger.info("Listening on queue '%s'", queue.name)

    worker = Worker(queues, name=settings.worker_name)
    worker.work(burst=args.burst)


if __name__ == "__main__":  # pragma: no cover
    main()
```
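To exercise a worker started by this entrypoint, a producer only needs Redis and RQ. A minimal sketch; the payload fields follow `design_docs/worker.md`, while the exact schema accepted by the task lives in `stream3r/worker/tasks.py`:

```python
# Sketch: enqueue a pose_pointmap job for the worker launched by main.py.
from redis import Redis
from rq import Queue

redis = Redis.from_url("redis://localhost:6379/0")
queue = Queue("pose_pointmap", connection=redis)

job = queue.enqueue(
    "stream3r.worker.tasks.pose_pointmap_job",  # resolved on the worker side
    {
        "job_id": "00000000-0000-0000-0000-000000000000",  # UUID from the API
        "scene_id": "SCENE123",
        "mode": "causal",
        "streaming": True,
        "frames": [{"url": "https://example.com/frame_0000.jpg"}],
    },
)
print(job.id)
```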
stream3r/worker/pipeline.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Inference pipeline utilities reused by worker jobs."""

from __future__ import annotations

from contextlib import nullcontext
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Mapping

import numpy as np
import torch

from stream3r.models.components.utils.load_fn import load_and_preprocess_images
from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri
from stream3r.stream_session import StreamSession

from .runtime import WorkerRuntime


ProgressCallback = Callable[[int, int], None]


@dataclass
class InferenceResult:
    predictions: dict[str, np.ndarray]
    total_frames: int
    cache_path: Path | None


def _to_numpy(payload):
    if isinstance(payload, torch.Tensor):
        return payload.detach().cpu().numpy()
    if isinstance(payload, dict):
        return {k: _to_numpy(v) for k, v in payload.items()}
    if isinstance(payload, (list, tuple)):
        converted = [_to_numpy(item) for item in payload]
        return type(payload)(converted)
    return payload


def run_stream3r_inference(
    *,
    runtime: WorkerRuntime,
    image_paths: Iterable[Path],
    mode: str,
    streaming: bool,
    cache_output_path: Path | None,
    progress_cb: ProgressCallback | None = None,
    window_size: int | None = None,
) -> InferenceResult:
    """Execute STream3R inference for the provided frames."""

    image_list = [Path(p) for p in image_paths]
    if not image_list:
        raise ValueError("No images provided to inference pipeline")

    model = runtime.get_model()
    device = runtime.model_device()

    images = load_and_preprocess_images([str(path) for path in image_list])
    total_frames = images.shape[0]

    autocast_dtype = runtime.autocast_dtype()
    autocast_ctx = (
        torch.autocast(device_type=device.type, dtype=autocast_dtype)
        if device.type == "cuda"
        else nullcontext()
    )

    predictions: Mapping[str, torch.Tensor]
    cache_path: Path | None = None

    model.eval()

    if window_size is not None and window_size <= 0:
        window_size = None

    with torch.no_grad():
        if streaming:
            session_kwargs = {"mode": mode}
            if window_size is not None:
                session_kwargs["window_size"] = window_size
            session = StreamSession(model, **session_kwargs)
            session.clear()
            for idx in range(total_frames):
                frame = images[idx : idx + 1].to(device)
                with autocast_ctx:
                    session.forward_stream(frame)
                if progress_cb is not None:
                    progress_cb(idx + 1, total_frames)

            if cache_output_path is not None:
                session.save_cache(str(cache_output_path))
                cache_path = cache_output_path

            predictions = session.get_all_predictions()
        else:
            inputs = images.to(device)
            with autocast_ctx:
                predictions = model(inputs, mode=mode)
            if progress_cb is not None:
                progress_cb(total_frames, total_frames)

    predictions = dict(predictions)

    # Augment predictions with pose matrices and world coordinates
    height, width = images.shape[-2:]

    pose_enc = predictions.get("pose_enc")
    if pose_enc is None:
        raise RuntimeError("Model predictions missing 'pose_enc'")

    if not isinstance(pose_enc, torch.Tensor):
        pose_enc = torch.as_tensor(pose_enc)

    if pose_enc.dim() == 2:  # streaming cache might drop batch dim
        pose_enc = pose_enc.unsqueeze(0)

    extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, (height, width))
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            predictions[key] = value

    predictions_np = {key: _to_numpy(value) for key, value in predictions.items()}

    pose_enc_np = predictions_np.pop("pose_enc", None)
    if pose_enc_np is not None and pose_enc_np.ndim >= 3:
        predictions_np["pose_enc"] = pose_enc_np

    # Remove batch dimension if present
    for key, value in list(predictions_np.items()):
        if isinstance(value, np.ndarray) and value.shape[0] == 1:
            predictions_np[key] = np.squeeze(value, axis=0)

    torch.cuda.empty_cache()

    return InferenceResult(
        predictions=predictions_np,
        total_frames=total_frames,
        cache_path=cache_path,
    )
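
Note: a minimal usage sketch for this pipeline (not part of the commit). It assumes the
worker environment is already configured; the frame paths are placeholders:

    from pathlib import Path

    from stream3r.worker.pipeline import run_stream3r_inference
    from stream3r.worker.runtime import get_runtime

    runtime = get_runtime()
    result = run_stream3r_inference(
        runtime=runtime,
        image_paths=[Path("frames/000000.png"), Path("frames/000001.png")],
        mode=runtime.settings.default_mode,
        streaming=True,                      # per-frame StreamSession path
        cache_output_path=None,              # pass a Path here to persist the KV cache
        progress_cb=lambda done, total: print(f"{done}/{total} frames"),
    )
    print(result.total_frames, sorted(result.predictions))
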
stream3r/worker/runtime.py
ADDED
@@ -0,0 +1,159 @@
"""Runtime registry for shared worker resources."""

from __future__ import annotations

import logging
from contextlib import contextmanager
from threading import Lock
from typing import Any, Dict, Mapping

import redis
import torch

from stream3r.models.stream3r import STream3R

from .config import WorkerSettings
from .db import BaseDatabaseClient, create_database_client
from .storage import StorageClient, create_storage_client


logger = logging.getLogger(__name__)


class WorkerRuntime:
    """Holds shared state reused across RQ jobs."""

    def __init__(self, settings: WorkerSettings):
        self.settings = settings
        self._redis = redis.Redis.from_url(
            settings.redis_url,
            decode_responses=False,
            health_check_interval=settings.redis_healthcheck_interval,
        )
        self.storage: StorageClient = create_storage_client(settings)
        self.db: BaseDatabaseClient = create_database_client(settings)

        try:
            self.db.ensure_schema()
        except Exception as exc:  # pragma: no cover - depends on external DB
            logger.warning("Failed to ensure job schema: %s", exc)

        self._model: STream3R | None = None
        self._model_lock = Lock()
        self._device: torch.device | None = None
        self._autocast_dtype: torch.dtype | None = None

    # -----------------------------------------------------------------
    # Redis helpers
    # -----------------------------------------------------------------
    @property
    def redis(self) -> redis.Redis:
        return self._redis

    def emit_event(self, payload: Mapping[str, Any]) -> None:
        try:
            stream = self.settings.redis_events_stream
            data = {k: str(v) for k, v in payload.items() if v is not None}
            maxlen = self.settings.redis_stream_maxlen
            self._redis.xadd(stream, data, maxlen=maxlen, approximate=True)
        except redis.RedisError as exc:  # pragma: no cover - depends on Redis
            logger.warning("Failed to emit event to Redis: %s", exc)

    @contextmanager
    def gpu_lock(self) -> Any:
        lock = self._redis.lock(
            self.settings.gpu_lock_key,
            timeout=self.settings.gpu_lock_timeout,
            blocking_timeout=self.settings.gpu_lock_blocking_timeout,
        )
        acquired = False
        try:
            acquired = lock.acquire(blocking=True)
            if not acquired:
                raise TimeoutError("Timed out waiting for GPU lock")
            yield
        finally:
            if acquired:
                try:
                    lock.release()
                except redis.RedisError:  # pragma: no cover - depends on Redis
                    logger.debug("GPU lock already released")

    # -----------------------------------------------------------------
    # Model helpers
    # -----------------------------------------------------------------
    def _resolve_device(self) -> torch.device:
        if self._device is not None:
            return self._device

        preference = self.settings.model_device_preference
        if preference:
            try:
                device = torch.device(preference)
            except (ValueError, RuntimeError):
                logger.warning("Unknown device preference '%s', falling back to auto", preference)
                device = None
        else:
            device = None

        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self._device = device
        return device

    def _resolve_autocast_dtype(self) -> torch.dtype:
        if self._autocast_dtype is not None:
            return self._autocast_dtype

        dtype_name = self.settings.model_dtype
        if dtype_name:
            try:
                self._autocast_dtype = getattr(torch, dtype_name)
                return self._autocast_dtype
            except AttributeError:
                logger.warning("Unsupported dtype '%s', using default", dtype_name)

        if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
            self._autocast_dtype = torch.bfloat16
        else:
            self._autocast_dtype = torch.float16 if torch.cuda.is_available() else torch.float32
        return self._autocast_dtype

    def get_model(self) -> STream3R:
        with self._model_lock:
            if self._model is None:
                logger.info("Loading STream3R model '%s'", self.settings.model_id)
                model = STream3R.from_pretrained(
                    self.settings.model_id,
                    revision=self.settings.model_revision,
                )
                device = self._resolve_device()
                model.to(device)
                model.eval()
                self._model = model
            return self._model

    def model_device(self) -> torch.device:
        return self._resolve_device()

    def autocast_dtype(self) -> torch.dtype:
        return self._resolve_autocast_dtype()

    # -----------------------------------------------------------------
    def close(self) -> None:
        try:
            self.db.close()
        except AttributeError:
            pass


_RUNTIME: WorkerRuntime | None = None


def get_runtime() -> WorkerRuntime:
    global _RUNTIME
    if _RUNTIME is None:
        settings = WorkerSettings.from_env()
        _RUNTIME = WorkerRuntime(settings)
    return _RUNTIME
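
Note: a short sketch of how RQ jobs are expected to use this runtime (illustrative only;
the event fields mirror what tasks.py emits):

    from stream3r.worker.runtime import get_runtime

    runtime = get_runtime()              # process-wide singleton
    with runtime.gpu_lock():             # Redis lock serializing GPU work across workers
        model = runtime.get_model()      # loaded once, cached behind a thread lock
    runtime.emit_event({"job_id": "demo", "status": "progress", "progress": 50})
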
stream3r/worker/storage.py
ADDED
@@ -0,0 +1,180 @@
"""Storage backends for persisting STream3R job artifacts."""

from __future__ import annotations

import json
import shutil
from pathlib import Path
from typing import Mapping

try:  # Lazy import to keep optional dependency
    import boto3
    from botocore.client import Config as Boto3Config
    from botocore.exceptions import BotoCoreError, ClientError
except ModuleNotFoundError:  # pragma: no cover - fallback when boto3 missing
    boto3 = None  # type: ignore[assignment]
    Boto3Config = None  # type: ignore[assignment]
    BotoCoreError = ClientError = Exception  # type: ignore[assignment]

from .config import WorkerSettings


class StorageError(RuntimeError):
    """Raised when artifact persistence fails."""


class StorageClient:
    """Abstract base class providing a minimal upload interface."""

    def __init__(self, settings: WorkerSettings):
        self.settings = settings

    # --- key builders -------------------------------------------------
    def build_key(self, scene_id: str, *parts: str) -> str:
        components = [str(self.settings.storage_prefix), str(scene_id), "stream3r"]
        for part in parts:
            if not part:
                continue
            components.append(str(part).strip("/"))
        return "/".join(components)

    def build_uri(self, key: str) -> str:
        raise NotImplementedError

    # --- upload primitives -------------------------------------------
    def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str:
        raise NotImplementedError

    def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str:
        raise NotImplementedError

    def upload_json(self, payload: Mapping[str, object], key: str) -> str:
        data = json.dumps(payload, allow_nan=False).encode("utf-8")
        return self.upload_bytes(data, key, content_type="application/json")

    def download_to_path(self, key: str, destination: Path) -> Path:
        """Download an object identified by *key* into *destination*."""

        raise NotImplementedError

    # --- helpers ------------------------------------------------------
    def ensure_dir(self, scene_id: str, *parts: str) -> str:
        """Create a logical directory path under the storage prefix."""
        key = self.build_key(scene_id, *parts)
        if not key.endswith("/"):
            key = f"{key}/"
        return key


class S3StorageClient(StorageClient):
    """S3-compatible storage backend."""

    def __init__(self, settings: WorkerSettings):
        if boto3 is None:  # pragma: no cover - guarded by optional dependency
            raise StorageError("boto3 is required for S3 storage but is not installed")

        super().__init__(settings)

        session_kwargs: dict[str, object] = {}
        if settings.s3_profile:
            session_kwargs["profile_name"] = settings.s3_profile
        if settings.s3_region:
            session_kwargs["region_name"] = settings.s3_region
        if settings.aws_access_key_id and settings.aws_secret_access_key:
            session_kwargs["aws_access_key_id"] = settings.aws_access_key_id
            session_kwargs["aws_secret_access_key"] = settings.aws_secret_access_key
        if settings.aws_session_token:
            session_kwargs["aws_session_token"] = settings.aws_session_token

        session = boto3.session.Session(**session_kwargs)
        config = None
        if settings.s3_force_path_style and Boto3Config is not None:
            config = Boto3Config(s3={"addressing_style": "path"})

        self._client = session.client(
            "s3",
            endpoint_url=settings.s3_endpoint_url,
            config=config,
        )

        if not settings.s3_bucket:
            raise StorageError("STREAM3R_STORAGE_BUCKET is required for S3 storage")

    def build_uri(self, key: str) -> str:
        bucket = self.settings.s3_bucket
        return f"s3://{bucket}/{key}"

    def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str:
        extra_args = {"ContentType": content_type} if content_type else None
        try:
            self._client.upload_file(str(local_path), self.settings.s3_bucket, key, ExtraArgs=extra_args)
        except (BotoCoreError, ClientError) as exc:  # pragma: no cover - network side effects
            raise StorageError(f"Failed to upload {local_path} to {key}: {exc}") from exc
        return self.build_uri(key)

    def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str:
        # Fix: only include ContentType when it is set; boto3 rejects ContentType=None.
        put_kwargs: dict[str, object] = {
            "Bucket": self.settings.s3_bucket,
            "Key": key,
            "Body": data,
        }
        if content_type:
            put_kwargs["ContentType"] = content_type
        try:
            self._client.put_object(**put_kwargs)
        except (BotoCoreError, ClientError) as exc:  # pragma: no cover - network side effects
            raise StorageError(f"Failed to upload payload to {key}: {exc}") from exc
        return self.build_uri(key)

    def download_to_path(self, key: str, destination: Path) -> Path:
        destination.parent.mkdir(parents=True, exist_ok=True)
        object_key = str(key).lstrip("/")
        try:
            self._client.download_file(self.settings.s3_bucket, object_key, str(destination))
        except (BotoCoreError, ClientError) as exc:  # pragma: no cover - network side effects
            raise StorageError(
                f"Failed to download {object_key} from bucket {self.settings.s3_bucket} to {destination}: {exc}"
            ) from exc
        return destination


class LocalStorageClient(StorageClient):
    """On-disk storage backend for development and testing."""

    def __init__(self, settings: WorkerSettings):
        super().__init__(settings)
        self.root = settings.local_storage_root
        self.root.mkdir(parents=True, exist_ok=True)

    def _resolve(self, key: str) -> Path:
        path = self.root.joinpath(*key.split("/"))
        path.parent.mkdir(parents=True, exist_ok=True)
        return path

    def build_uri(self, key: str) -> str:
        return str(self._resolve(key))

    def upload_file(self, local_path: Path, key: str, *, content_type: str | None = None) -> str:  # noqa: ARG002
        destination = self._resolve(key)
        shutil.copyfile(local_path, destination)
        return str(destination)

    def upload_bytes(self, data: bytes, key: str, *, content_type: str | None = None) -> str:  # noqa: ARG002
        destination = self._resolve(key)
        destination.write_bytes(data)
        return str(destination)

    def download_to_path(self, key: str, destination: Path) -> Path:
        source = self._resolve(key)
        if not source.exists():
            raise StorageError(f"Local object not found for key: {key}")
        destination.parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(source, destination)
        return destination


def create_storage_client(settings: WorkerSettings) -> StorageClient:
    """Instantiate the appropriate storage backend."""

    if settings.s3_bucket:
        return S3StorageClient(settings)
    return LocalStorageClient(settings)
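
Note: a quick sketch of the storage abstraction (illustrative; the scene id and key parts
are placeholders). With STREAM3R_STORAGE_BUCKET unset, create_storage_client falls back to
the on-disk LocalStorageClient:

    from stream3r.worker.config import WorkerSettings
    from stream3r.worker.storage import create_storage_client

    settings = WorkerSettings.from_env()
    storage = create_storage_client(settings)
    key = storage.build_key("scene_001", "models", "poses.jsonl")
    # key -> "<storage_prefix>/scene_001/stream3r/models/poses.jsonl"
    uri = storage.upload_json({"ok": True}, key)
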
stream3r/worker/tasks.py
ADDED
@@ -0,0 +1,1036 @@
"""RQ job entrypoints for STream3R worker."""

from __future__ import annotations

import base64
import json
import logging
import re
import shutil
import tempfile
import traceback
import uuid
from dataclasses import dataclass, field
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Callable, Mapping

import numpy as np
import requests
from rq import get_current_job

from stream3r.utils.visual_utils import predictions_to_glb

from .pipeline import InferenceResult, run_stream3r_inference
from .runtime import WorkerRuntime, get_runtime


logger = logging.getLogger(__name__)

IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".bmp", ".webp"}
_SAFE_CHARS = re.compile(r"[^0-9A-Za-z_-]")


def _as_bool(value: Any, default: bool) -> bool:
    if isinstance(value, bool):
        return value
    if isinstance(value, str):
        lowered = value.strip().lower()
        if lowered in {"1", "true", "yes", "y", "on"}:
            return True
        if lowered in {"0", "false", "no", "n", "off"}:
            return False
    return default


def _as_int(value: Any, default: int) -> int:
    try:
        return int(value)
    except (TypeError, ValueError):
        return default


@dataclass(slots=True)
class FrameRecord:
    index: int
    frame_id: str
    path: Path
    source: str | None = None
    timestamp: str | None = None
    metadata: dict[str, Any] = field(default_factory=dict)


class ProgressTracker:
    """Aggregates frame progress to percentage updates."""

    def __init__(self, runtime: WorkerRuntime, job_meta: Mapping[str, str | None]):
        self.runtime = runtime
        self.job_meta = job_meta
        self.last_value = -1

    def __call__(self, processed: int, total: int) -> None:
        if total <= 0:
            return
        percent = int(round((processed / total) * 100))
        percent = max(0, min(100, percent))
        if percent == self.last_value:
            return
        self.last_value = percent
        payload = {
            **self.job_meta,
            "status": "progress",
            "progress": percent,
            "ts": datetime.now(timezone.utc).timestamp(),
        }
        runtime_emit(self.runtime, payload)


def runtime_emit(runtime: WorkerRuntime, payload: Mapping[str, Any]) -> None:
    runtime.emit_event(payload)


def _slugify(value: str, fallback: str) -> str:
    candidate = _SAFE_CHARS.sub("_", value).strip("_")
    if not candidate:
        candidate = fallback
    return candidate[:128]


def _is_url(value: str) -> bool:
    return value.startswith("http://") or value.startswith("https://")


def _download_to_path(url: str, destination: Path) -> None:
    response = requests.get(url, stream=True, timeout=60)
    response.raise_for_status()
    with destination.open("wb") as handle:
        for chunk in response.iter_content(chunk_size=1 << 16):
            if chunk:
                handle.write(chunk)


def _write_base64(content: str, destination: Path) -> None:
    data = base64.b64decode(content)
    destination.write_bytes(data)


def _resolve_frame_entry(entry: Any, *, index: int, dest_dir: Path) -> FrameRecord:
    metadata: dict[str, Any] = {}
    timestamp = None
    source = None
    dest_dir.mkdir(parents=True, exist_ok=True)

    if isinstance(entry, str):
        if _is_url(entry):
            source = entry
            frame_id = _slugify(Path(entry).stem or f"frame_{index:06d}", f"frame_{index:06d}")
            destination = dest_dir / f"{frame_id}.jpg"
            _download_to_path(entry, destination)
        else:
            path = Path(entry)
            if not path.exists():
                raise FileNotFoundError(f"Frame path does not exist: {entry}")
            frame_id = _slugify(path.stem, f"frame_{index:06d}")
            destination = dest_dir / path.name
            shutil.copy2(path, destination)
    elif isinstance(entry, Mapping):
        frame_id = _slugify(str(entry.get("frame_id") or entry.get("id") or f"frame_{index:06d}"), f"frame_{index:06d}")
        timestamp = entry.get("timestamp")
        metadata = {k: v for k, v in entry.items() if k not in {"path", "url", "content", "frame_id", "id", "timestamp"}}

        if path := entry.get("path") or entry.get("local_path"):
            path = Path(path)
            if not path.exists():
                raise FileNotFoundError(f"Frame path does not exist: {path}")
            destination = dest_dir / (path.name if path.suffix else f"{frame_id}.jpg")
            shutil.copy2(path, destination)
        elif url := entry.get("url"):
            source = url
            suffix = Path(url).suffix or ".jpg"
            destination = dest_dir / f"{frame_id}{suffix}"
            _download_to_path(url, destination)
        elif content := entry.get("content"):
            destination = dest_dir / f"{frame_id}.jpg"
            _write_base64(content, destination)
        else:
            raise ValueError("Frame entry must include 'path', 'url', or 'content'")
    else:
        raise TypeError("Unsupported frame specification")

    if destination.suffix.lower() not in IMAGE_EXTENSIONS:
        destination = destination.with_suffix(".png")

    return FrameRecord(
        index=index,
        frame_id=_slugify(destination.stem, f"frame_{index:06d}"),
        path=destination,
        source=source,
        timestamp=timestamp,
        metadata=metadata,
    )


def _collect_frames(
    runtime: WorkerRuntime,
    scene_id: str,
    payload: Mapping[str, Any],
    tmp_dir: Path,
) -> list[FrameRecord]:
    frames_dir = tmp_dir / "frames"
    frames_payload = payload.get("frames") or []
    frame_limit = runtime.settings.max_frames_per_job

    records: list[FrameRecord] = []
    if frames_payload:
        for entry in frames_payload:
            if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
                break
            records.append(_resolve_frame_entry(entry, index=len(records), dest_dir=frames_dir))
    else:
        directory = payload.get("frames_dir") or payload.get("images_dir")
        if directory:
            directory_path = Path(directory)
            if not directory_path.is_dir():
                raise ValueError(f"frames_dir does not exist: {directory}")
            # Fix: create the staging directory before copying; only
            # _resolve_frame_entry and the scene-media path create it otherwise.
            frames_dir.mkdir(parents=True, exist_ok=True)
            for idx, file in enumerate(sorted(directory_path.iterdir())):
                if file.suffix.lower() not in IMAGE_EXTENSIONS:
                    continue
                if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
                    break
                destination = frames_dir / file.name
                shutil.copy2(file, destination)
                records.append(
                    FrameRecord(
                        index=len(records),
                        frame_id=_slugify(file.stem, f"frame_{idx:06d}"),
                        path=destination,
                    )
                )

    if not records:
        records = _collect_frames_from_scene_media(runtime, scene_id, frames_dir)

    if not records:
        raise ValueError(f"No valid frames found for scene '{scene_id}'")

    limit = runtime.settings.max_frames_per_job
    if limit and limit > 0 and len(records) > limit:
        records = records[:limit]

    for new_idx, record in enumerate(records):
        if record.index != new_idx:
            record.index = new_idx

    return records
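
Note: for reference, _resolve_frame_entry accepts several frame shapes; a job payload
might look like this (values illustrative):

    payload = {
        "scene_id": "scene_001",
        "frames": [
            "/data/scene_001/000000.png",                     # local path (copied)
            {"url": "https://example.com/000001.jpg"},        # remote URL (downloaded)
            {"frame_id": "f2", "content": "<base64 image>"},  # inline base64 (decoded)
        ],
    }

If both "frames" and "frames_dir" are absent, collection falls back to the scene media
API via _collect_frames_from_scene_media below.
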
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _sanitize_payload(payload: Mapping[str, Any]) -> dict[str, Any]:
|
| 228 |
+
result = dict(payload)
|
| 229 |
+
frames = result.pop("frames", None)
|
| 230 |
+
if frames is not None:
|
| 231 |
+
result["frame_count"] = len(frames)
|
| 232 |
+
if "frames_dir" in result:
|
| 233 |
+
result["frames_dir"] = str(result["frames_dir"])
|
| 234 |
+
return result
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def _prepare_session_settings(
|
| 238 |
+
payload: Mapping[str, Any],
|
| 239 |
+
*,
|
| 240 |
+
mode: str,
|
| 241 |
+
streaming: bool,
|
| 242 |
+
frame_records: list[FrameRecord],
|
| 243 |
+
window_size: int | None = None,
|
| 244 |
+
) -> dict[str, Any]:
|
| 245 |
+
base_settings = payload.get("session_settings") or {}
|
| 246 |
+
session_settings = dict(base_settings)
|
| 247 |
+
session_settings.update(
|
| 248 |
+
{
|
| 249 |
+
"mode": mode,
|
| 250 |
+
"streaming": streaming,
|
| 251 |
+
"frame_count": len(frame_records),
|
| 252 |
+
}
|
| 253 |
+
)
|
| 254 |
+
window_setting = window_size if window_size is not None else payload.get("window_size")
|
| 255 |
+
if window_setting:
|
| 256 |
+
try:
|
| 257 |
+
session_settings["window_size"] = int(window_setting)
|
| 258 |
+
except (TypeError, ValueError):
|
| 259 |
+
pass
|
| 260 |
+
return session_settings
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def _collect_frames_from_scene_media(
|
| 264 |
+
runtime: WorkerRuntime,
|
| 265 |
+
scene_id: str,
|
| 266 |
+
dest_dir: Path,
|
| 267 |
+
) -> list[FrameRecord]:
|
| 268 |
+
base_url = runtime.settings.scene_media_api_base_url
|
| 269 |
+
if not base_url:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
"Scene media API base URL is not configured. Set API_BASE_URL"
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
base_url = base_url.rstrip("/")
|
| 275 |
+
dest_dir.mkdir(parents=True, exist_ok=True)
|
| 276 |
+
|
| 277 |
+
per_page = runtime.settings.scene_media_page_size
|
| 278 |
+
if per_page <= 0:
|
| 279 |
+
per_page = 100
|
| 280 |
+
per_page = max(1, min(per_page, 1000))
|
| 281 |
+
frame_limit = runtime.settings.max_frames_per_job
|
| 282 |
+
|
| 283 |
+
headers = {}
|
| 284 |
+
token = runtime.settings.scene_media_api_token
|
| 285 |
+
if token:
|
| 286 |
+
headers["Authorization"] = f"Bearer {token}"
|
| 287 |
+
|
| 288 |
+
url = f"{base_url}/scenes/{scene_id}/media"
|
| 289 |
+
session = requests.Session()
|
| 290 |
+
records: list[FrameRecord] = []
|
| 291 |
+
offset = 0
|
| 292 |
+
|
| 293 |
+
while True:
|
| 294 |
+
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
|
| 295 |
+
break
|
| 296 |
+
|
| 297 |
+
request_limit = per_page
|
| 298 |
+
if frame_limit and frame_limit > 0:
|
| 299 |
+
remaining = frame_limit - len(records)
|
| 300 |
+
if remaining <= 0:
|
| 301 |
+
break
|
| 302 |
+
request_limit = min(request_limit, remaining)
|
| 303 |
+
|
| 304 |
+
params = {
|
| 305 |
+
"limit": request_limit,
|
| 306 |
+
"offset": offset,
|
| 307 |
+
"media_type": "image",
|
| 308 |
+
}
|
| 309 |
+
try:
|
| 310 |
+
response = session.get(url, params=params, headers=headers, timeout=30)
|
| 311 |
+
response.raise_for_status()
|
| 312 |
+
except requests.RequestException as exc:
|
| 313 |
+
raise RuntimeError(f"Failed to fetch media for scene '{scene_id}': {exc}") from exc
|
| 314 |
+
|
| 315 |
+
data = response.json()
|
| 316 |
+
items = data.get("items") or []
|
| 317 |
+
if not items:
|
| 318 |
+
break
|
| 319 |
+
|
| 320 |
+
for item in items:
|
| 321 |
+
if frame_limit and frame_limit > 0 and len(records) >= frame_limit:
|
| 322 |
+
break
|
| 323 |
+
file_key = item.get("file")
|
| 324 |
+
if not file_key:
|
| 325 |
+
continue
|
| 326 |
+
|
| 327 |
+
idx = len(records)
|
| 328 |
+
source_path = Path(str(file_key))
|
| 329 |
+
suffix = source_path.suffix if source_path.suffix else ".png"
|
| 330 |
+
frame_id = _slugify(source_path.stem or f"frame_{idx:06d}", f"frame_{idx:06d}")
|
| 331 |
+
destination = dest_dir / f"{frame_id}{suffix}"
|
| 332 |
+
|
| 333 |
+
try:
|
| 334 |
+
runtime.storage.download_to_path(str(file_key), destination)
|
| 335 |
+
except Exception as exc: # pragma: no cover - download depends on external storage
|
| 336 |
+
raise RuntimeError(f"Failed to download media '{file_key}' for scene '{scene_id}': {exc}") from exc
|
| 337 |
+
|
| 338 |
+
records.append(
|
| 339 |
+
FrameRecord(
|
| 340 |
+
index=idx,
|
| 341 |
+
frame_id=frame_id,
|
| 342 |
+
path=destination,
|
| 343 |
+
source=str(file_key),
|
| 344 |
+
timestamp=item.get("captured_at"),
|
| 345 |
+
metadata={
|
| 346 |
+
"media_id": item.get("id"),
|
| 347 |
+
"media_type": item.get("media_type"),
|
| 348 |
+
},
|
| 349 |
+
)
|
| 350 |
+
)
|
| 351 |
+
|
| 352 |
+
if len(items) < request_limit:
|
| 353 |
+
break
|
| 354 |
+
offset += request_limit
|
| 355 |
+
|
| 356 |
+
return records
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def _pose_confidence(predictions: Mapping[str, np.ndarray]) -> np.ndarray | None:
|
| 360 |
+
if "world_points_conf" in predictions:
|
| 361 |
+
return np.asarray(predictions["world_points_conf"], dtype=np.float32)
|
| 362 |
+
if "depth_conf" in predictions:
|
| 363 |
+
return np.asarray(predictions["depth_conf"], dtype=np.float32)
|
| 364 |
+
return None
|
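
Note: the scene media fetch above assumes a paginated JSON response of roughly this shape
(field names taken from the code, values illustrative), shown here as a Python dict:

    page = {
        "items": [
            {
                "id": "m-123",
                "media_type": "image",
                "file": "scenes/scene_001/raw/000000.png",  # storage key for download_to_path
                "captured_at": "2024-01-01T00:00:00Z",
            },
        ],
    }

Pagination stops when a page returns fewer than request_limit items or the
max_frames_per_job cap is reached.
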
def _save_pointmaps(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    temp_dir: Path,
) -> dict[str, Any]:
    world_points = predictions.get("world_points")
    if world_points is None:
        world_points = predictions.get("world_points_from_depth")
    if world_points is None:
        raise RuntimeError("Predictions missing world points")

    world_points = np.asarray(world_points)
    confidence = _pose_confidence(predictions)
    if confidence is None:
        confidence = np.ones(world_points.shape[:-1], dtype=np.float32)

    local_dir = temp_dir / "pointmaps"
    local_dir.mkdir(parents=True, exist_ok=True)

    entries: list[dict[str, Any]] = []
    for record in frame_records:
        idx = record.index
        filename = f"{record.frame_id}.npz"
        local_file = local_dir / filename
        np.savez(
            local_file,
            xyz=np.asarray(world_points[idx], dtype=np.float32),
            confidence=np.asarray(confidence[idx], dtype=np.float32),
        )
        key = runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir, filename)
        uri = runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")
        entries.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "url": uri,
                "timestamp": record.timestamp,
            }
        )

    directory_uri = runtime.storage.build_uri(
        runtime.storage.build_key(scene_id, runtime.settings.pointmap_dir)
    )

    return {
        "pointmaps": entries,
        "pointmap_dir": directory_uri,
    }


def _write_poses_jsonl(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    job_id: str,
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    temp_dir: Path,
) -> str:
    extrinsic = np.asarray(predictions.get("extrinsic"))
    intrinsic = predictions.get("intrinsic")
    if intrinsic is not None:
        intrinsic = np.asarray(intrinsic)

    local_file = temp_dir / "poses.jsonl"
    with local_file.open("w", encoding="utf-8") as handle:
        for record in frame_records:
            idx = record.index
            payload = {
                "job_id": job_id,
                "scene_id": scene_id,
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "extrinsic": extrinsic[idx].tolist(),
            }
            if intrinsic is not None:
                payload["intrinsic"] = intrinsic[idx].tolist()
            if record.timestamp is not None:
                payload["timestamp"] = record.timestamp
            if record.source is not None:
                payload["source"] = record.source
            if record.metadata:
                payload["metadata"] = record.metadata
            handle.write(json.dumps(payload))
            handle.write("\n")

    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.poses_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")


def _upload_cache(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    cache_path: Path | None,
) -> str | None:
    if cache_path is None or not cache_path.exists():
        return None
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.session_cache_filename,
    )
    return runtime.storage.upload_file(cache_path, key, content_type="application/octet-stream")


def _write_predictions_npz(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    predictions: Mapping[str, np.ndarray],
    temp_dir: Path,
) -> str:
    payload = {k: v for k, v in predictions.items() if isinstance(v, np.ndarray)}
    local_file = temp_dir / runtime.settings.predictions_filename
    np.savez(local_file, **payload)
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.predictions_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/octet-stream")


def _write_session_settings(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> str:
    local_file = temp_dir / runtime.settings.session_settings_filename
    local_file.write_text(json.dumps(session_settings, indent=2), encoding="utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.session_settings_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")


def _write_selected_frames(
    *,
    runtime: WorkerRuntime,
    scene_id: str,
    selected_frames: list[dict[str, Any]],
    top_k: int,
    temp_dir: Path,
) -> str | None:
    if not selected_frames:
        return None
    local_file = temp_dir / runtime.settings.selected_frames_filename
    payload = {"top_k": top_k, "frames": selected_frames}
    local_file.write_text(json.dumps(payload, indent=2), encoding="utf-8")
    key = runtime.storage.build_key(
        scene_id,
        runtime.settings.models_dir,
        runtime.settings.selected_frames_filename,
    )
    return runtime.storage.upload_file(local_file, key, content_type="application/json")


def _compute_selected_frames(
    predictions: Mapping[str, np.ndarray],
    frame_records: list[FrameRecord],
    top_k: int,
) -> list[dict[str, Any]]:
    if top_k <= 0:
        return []
    confidence = _pose_confidence(predictions)
    if confidence is None:
        return []
    scores = confidence.reshape(confidence.shape[0], -1).mean(axis=1)
    indices = np.argsort(scores)[::-1][:top_k]
    result = []
    for idx in indices:
        record = frame_records[int(idx)]
        result.append(
            {
                "frame_id": record.frame_id,
                "frame_index": record.index,
                "score": float(scores[idx]),
            }
        )
    return result
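
Note: a tiny worked example of the top-k selection above (illustrative numbers). Each
frame's confidence map is flattened and averaged, then frames are ranked by mean score:

    import numpy as np

    conf = np.array([[[0.2, 0.4]], [[0.9, 0.7]], [[0.5, 0.5]]])  # 3 frames, 1x2 maps
    scores = conf.reshape(conf.shape[0], -1).mean(axis=1)        # [0.3, 0.8, 0.5]
    top2 = np.argsort(scores)[::-1][:2]                          # -> frame indices [1, 2]
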
| 559 |
+
|
| 560 |
+
|
| 561 |
+
def _save_scene_glb(
|
| 562 |
+
*,
|
| 563 |
+
runtime: WorkerRuntime,
|
| 564 |
+
scene_id: str,
|
| 565 |
+
predictions: Mapping[str, np.ndarray],
|
| 566 |
+
temp_dir: Path,
|
| 567 |
+
payload: Mapping[str, Any],
|
| 568 |
+
) -> str:
|
| 569 |
+
local_file = temp_dir / runtime.settings.scene_glb_filename
|
| 570 |
+
scene = predictions_to_glb(
|
| 571 |
+
dict(predictions),
|
| 572 |
+
conf_thres=float(payload.get("conf_thres", 3.0)),
|
| 573 |
+
filter_by_frames=payload.get("frame_filter", "All"),
|
| 574 |
+
mask_black_bg=_as_bool(payload.get("mask_black_bg"), False),
|
| 575 |
+
mask_white_bg=_as_bool(payload.get("mask_white_bg"), False),
|
| 576 |
+
show_cam=_as_bool(payload.get("show_cam"), True),
|
| 577 |
+
mask_sky=_as_bool(payload.get("mask_sky"), False),
|
| 578 |
+
target_dir=str(temp_dir),
|
| 579 |
+
prediction_mode=payload.get("prediction_mode", "Predicted Pointmap"),
|
| 580 |
+
)
|
| 581 |
+
scene.export(file_obj=str(local_file))
|
| 582 |
+
key = runtime.storage.build_key(
|
| 583 |
+
scene_id,
|
| 584 |
+
runtime.settings.models_dir,
|
| 585 |
+
runtime.settings.scene_glb_filename,
|
| 586 |
+
)
|
| 587 |
+
return runtime.storage.upload_file(local_file, key, content_type="model/gltf-binary")
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
def _write_summary_json(
|
| 591 |
+
*,
|
| 592 |
+
runtime: WorkerRuntime,
|
| 593 |
+
scene_id: str,
|
| 594 |
+
summary: Mapping[str, Any],
|
| 595 |
+
temp_dir: Path,
|
| 596 |
+
) -> str:
|
| 597 |
+
filename = runtime.settings.result_filename
|
| 598 |
+
local_file = temp_dir / filename
|
| 599 |
+
local_file.write_text(json.dumps(summary, indent=2), encoding="utf-8")
|
| 600 |
+
key = runtime.storage.build_key(
|
| 601 |
+
scene_id,
|
| 602 |
+
runtime.settings.models_dir,
|
| 603 |
+
filename,
|
| 604 |
+
)
|
| 605 |
+
return runtime.storage.upload_file(local_file, key, content_type="application/json")
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def _upload_result_record(
|
| 609 |
+
*,
|
| 610 |
+
runtime: WorkerRuntime,
|
| 611 |
+
scene_id: str,
|
| 612 |
+
job_id: str,
|
| 613 |
+
payload: Mapping[str, Any],
|
| 614 |
+
) -> str:
|
| 615 |
+
local = json.dumps(payload, indent=2).encode("utf-8")
|
| 616 |
+
key = runtime.storage.build_key(
|
| 617 |
+
scene_id,
|
| 618 |
+
runtime.settings.results_dir,
|
| 619 |
+
f"{job_id}.json",
|
| 620 |
+
)
|
| 621 |
+
return runtime.storage.upload_bytes(local, key, content_type="application/json")
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def _model_dir_uri(runtime: WorkerRuntime, scene_id: str) -> str:
|
| 625 |
+
return runtime.storage.build_uri(
|
| 626 |
+
runtime.storage.build_key(scene_id, runtime.settings.models_dir)
|
| 627 |
+
)
|
| 628 |
+
|
| 629 |
+
|
| 630 |
+
def _generate_core_outputs(
|
| 631 |
+
*,
|
| 632 |
+
runtime: WorkerRuntime,
|
| 633 |
+
scene_id: str,
|
| 634 |
+
job_id: str,
|
| 635 |
+
predictions: Mapping[str, np.ndarray],
|
| 636 |
+
frame_records: list[FrameRecord],
|
| 637 |
+
inference: InferenceResult,
|
| 638 |
+
session_settings: Mapping[str, Any],
|
| 639 |
+
temp_dir: Path,
|
| 640 |
+
) -> dict[str, Any]:
|
| 641 |
+
pointmap_info = _save_pointmaps(
|
| 642 |
+
runtime=runtime,
|
| 643 |
+
scene_id=scene_id,
|
| 644 |
+
predictions=predictions,
|
| 645 |
+
frame_records=frame_records,
|
| 646 |
+
temp_dir=temp_dir,
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
poses_url = _write_poses_jsonl(
|
| 650 |
+
runtime=runtime,
|
| 651 |
+
scene_id=scene_id,
|
| 652 |
+
job_id=job_id,
|
| 653 |
+
predictions=predictions,
|
| 654 |
+
frame_records=frame_records,
|
| 655 |
+
temp_dir=temp_dir,
|
| 656 |
+
)
|
| 657 |
+
|
| 658 |
+
cache_url = _upload_cache(
|
| 659 |
+
runtime=runtime,
|
| 660 |
+
scene_id=scene_id,
|
| 661 |
+
cache_path=inference.cache_path,
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
predictions_url = _write_predictions_npz(
|
| 665 |
+
runtime=runtime,
|
| 666 |
+
scene_id=scene_id,
|
| 667 |
+
predictions=predictions,
|
| 668 |
+
temp_dir=temp_dir,
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
session_settings_url = _write_session_settings(
|
| 672 |
+
runtime=runtime,
|
| 673 |
+
scene_id=scene_id,
|
| 674 |
+
session_settings=session_settings,
|
| 675 |
+
temp_dir=temp_dir,
|
| 676 |
+
)
|
| 677 |
+
|
| 678 |
+
extrinsic = np.asarray(predictions.get("extrinsic"))
|
| 679 |
+
intrinsic = predictions.get("intrinsic")
|
| 680 |
+
if intrinsic is not None:
|
| 681 |
+
intrinsic = np.asarray(intrinsic)
|
| 682 |
+
|
| 683 |
+
frames_payload: list[dict[str, Any]] = []
|
| 684 |
+
for entry in pointmap_info["pointmaps"]:
|
| 685 |
+
idx = entry["frame_index"]
|
| 686 |
+
frame = frame_records[idx]
|
| 687 |
+
frame_payload = {
|
| 688 |
+
"frame_id": frame.frame_id,
|
| 689 |
+
"frame_index": frame.index,
|
| 690 |
+
"pointmap_url": entry["url"],
|
| 691 |
+
"extrinsic": extrinsic[idx].tolist(),
|
| 692 |
+
}
|
| 693 |
+
if intrinsic is not None:
|
| 694 |
+
frame_payload["intrinsic"] = intrinsic[idx].tolist()
|
| 695 |
+
if frame.timestamp is not None:
|
| 696 |
+
frame_payload["timestamp"] = frame.timestamp
|
| 697 |
+
if frame.source is not None:
|
| 698 |
+
frame_payload["source"] = frame.source
|
| 699 |
+
frames_payload.append(frame_payload)
|
| 700 |
+
|
| 701 |
+
artifacts = {
|
| 702 |
+
"poses_url": poses_url,
|
| 703 |
+
"pointmap_dir": pointmap_info["pointmap_dir"],
|
| 704 |
+
"pointmaps": pointmap_info["pointmaps"],
|
| 705 |
+
"predictions_url": predictions_url,
|
| 706 |
+
"session_settings_url": session_settings_url,
|
| 707 |
+
}
|
| 708 |
+
if cache_url:
|
| 709 |
+
artifacts["kv_cache_url"] = cache_url
|
| 710 |
+
|
| 711 |
+
return {
|
| 712 |
+
"artifacts": artifacts,
|
| 713 |
+
"frames": frames_payload,
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
|
| 717 |
+
def _handle_pose_pointmap(
|
| 718 |
+
*,
|
| 719 |
+
runtime: WorkerRuntime,
|
| 720 |
+
payload: Mapping[str, Any],
|
| 721 |
+
mode: str,
|
| 722 |
+
streaming: bool,
|
| 723 |
+
job_id: str,
|
| 724 |
+
scene_id: str,
|
| 725 |
+
frame_records: list[FrameRecord],
|
| 726 |
+
inference: InferenceResult,
|
| 727 |
+
session_settings: Mapping[str, Any],
|
| 728 |
+
temp_dir: Path,
|
| 729 |
+
) -> dict[str, Any]:
|
| 730 |
+
predictions = inference.predictions
|
| 731 |
+
core = _generate_core_outputs(
|
| 732 |
+
runtime=runtime,
|
| 733 |
+
scene_id=scene_id,
|
| 734 |
+
job_id=job_id,
|
| 735 |
+
predictions=predictions,
|
| 736 |
+
frame_records=frame_records,
|
| 737 |
+
inference=inference,
|
| 738 |
+
session_settings=session_settings,
|
| 739 |
+
temp_dir=temp_dir,
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
result_payload = {
|
| 743 |
+
"job_id": job_id,
|
| 744 |
+
"job_type": "pose_pointmap",
|
| 745 |
+
"scene_id": scene_id,
|
| 746 |
+
"mode": mode,
|
| 747 |
+
"streaming": streaming,
|
| 748 |
+
"frame_count": inference.total_frames,
|
| 749 |
+
"created_at": datetime.now(timezone.utc).isoformat(),
|
| 750 |
+
"artifacts": core["artifacts"],
|
| 751 |
+
"frames": core["frames"],
|
| 752 |
+
}
|
| 753 |
+
|
| 754 |
+
result_url = _upload_result_record(
|
| 755 |
+
runtime=runtime,
|
| 756 |
+
scene_id=scene_id,
|
| 757 |
+
job_id=job_id,
|
| 758 |
+
payload=result_payload,
|
| 759 |
+
)
|
| 760 |
+
result_payload["result_url"] = result_url
|
| 761 |
+
result_payload["model_dir"] = _model_dir_uri(runtime, scene_id)
|
| 762 |
+
|
| 763 |
+
return result_payload
|
| 764 |
+
|
| 765 |
+
|
| 766 |
+
JobHandler = Callable[..., dict[str, Any]]
|
| 767 |
+
|
| 768 |
+
|
| 769 |
+
def _execute_job(job_type: str, payload: Mapping[str, Any], handler: JobHandler) -> dict[str, Any]:
|
| 770 |
+
runtime = get_runtime()
|
| 771 |
+
job = get_current_job()
|
| 772 |
+
payload = dict(payload)
|
| 773 |
+
|
| 774 |
+
job_id = str(payload.get("job_id") or (job.id if job else uuid.uuid4()))
|
| 775 |
+
scene_id = payload.get("scene_id")
|
| 776 |
+
if not scene_id:
|
| 777 |
+
raise ValueError("Job payload is missing 'scene_id'")
|
| 778 |
+
|
| 779 |
+
payload.setdefault("job_type", job_type)
|
| 780 |
+
payload.setdefault("scene_id", scene_id)
|
| 781 |
+
|
| 782 |
+
mode = payload.get("mode") or runtime.settings.default_mode
|
| 783 |
+
streaming = _as_bool(payload.get("streaming"), runtime.settings.default_streaming)
|
| 784 |
+
window_size: int | None = None
|
| 785 |
+
|
| 786 |
+
if mode == "window":
|
| 787 |
+
streaming = True
|
| 788 |
+
payload["streaming"] = True
|
| 789 |
+
window_candidate = payload.get("window_size") or runtime.settings.stream_window_size
|
| 790 |
+
try:
|
| 791 |
+
window_size = int(window_candidate) if window_candidate else None
|
| 792 |
+
except (TypeError, ValueError):
|
| 793 |
+
window_size = runtime.settings.stream_window_size or None
|
| 794 |
+
if window_size and window_size > 0:
|
| 795 |
+
payload["window_size"] = window_size
|
| 796 |
+
else:
|
| 797 |
+
window_size = None
|
| 798 |
+
|
| 799 |
+
payload["mode"] = mode
|
| 800 |
+
timeout_override = payload.get("timeout")
|
| 801 |
+
if timeout_override is not None:
|
| 802 |
+
try:
|
| 803 |
+
job.timeout = int(timeout_override)
|
| 804 |
+
except (TypeError, ValueError):
|
| 805 |
+
pass
|
| 806 |
+
|
| 807 |
+
# Default to 15 minutes if no timeout already applied
|
| 808 |
+
if job.timeout is None:
|
| 809 |
+
job.timeout = 15 * 60
|
| 810 |
+
|
| 811 |
+
    sanitized_payload = _sanitize_payload(payload)

    job_meta = {
        "job_id": job_id,
        "job_type": job_type,
        "scene_id": scene_id,
    }

    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="started",
        payload=sanitized_payload,
    )

    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "started",
            "progress": 0,
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )

    try:
        with runtime.gpu_lock():
            with tempfile.TemporaryDirectory(prefix=f"stream3r_{job_id}_") as tmp_dir:
                temp_path = Path(tmp_dir)
                frame_records = _collect_frames(runtime, scene_id, payload, temp_path)
                cache_path = temp_path / runtime.settings.session_cache_filename if streaming else None

                tracker = ProgressTracker(runtime, job_meta)
                inference = run_stream3r_inference(
                    runtime=runtime,
                    image_paths=[record.path for record in frame_records],
                    mode=mode,
                    streaming=streaming,
                    cache_output_path=cache_path,
                    progress_cb=tracker,
                    window_size=window_size if streaming and mode == "window" else None,
                )

                session_settings = _prepare_session_settings(
                    payload,
                    mode=mode,
                    streaming=streaming,
                    frame_records=frame_records,
                    window_size=window_size,
                )

                result_payload = handler(
                    runtime=runtime,
                    payload=payload,
                    mode=mode,
                    streaming=streaming,
                    job_id=job_id,
                    scene_id=scene_id,
                    frame_records=frame_records,
                    inference=inference,
                    session_settings=session_settings,
                    temp_dir=temp_path,
                )

    except Exception as exc:
        error_text = traceback.format_exc()
        runtime.db.upsert_job(
            job_id=job_id,
            job_type=job_type,
            scene_id=scene_id,
            status="failed",
            error=error_text,
        )
        runtime_emit(
            runtime,
            {
                **job_meta,
                "status": "failed",
                "ts": datetime.now(timezone.utc).timestamp(),
                "error": str(exc),
            },
        )
        logger.exception("Job %s failed", job_id)
        raise

    runtime.db.upsert_job(
        job_id=job_id,
        job_type=job_type,
        scene_id=scene_id,
        status="finished",
        result=result_payload,
    )

    runtime_emit(
        runtime,
        {
            **job_meta,
            "status": "finished",
            "progress": 100,
            "result_url": result_payload.get("result_url"),
            "model_dir": result_payload.get("model_dir"),
            "ts": datetime.now(timezone.utc).timestamp(),
        },
    )

    return result_payload


def pose_pointmap_job(payload: Mapping[str, Any]) -> dict[str, Any]:
    """Process a pose + pointmap job."""

    return _execute_job("pose_pointmap", payload, _handle_pose_pointmap)


def model_build_job(payload: Mapping[str, Any]) -> dict[str, Any]:
    """Process a full model build job."""

    return _execute_job("model_build", payload, _handle_model_build)


def _handle_model_build(
    *,
    runtime: WorkerRuntime,
    payload: Mapping[str, Any],
    mode: str,
    streaming: bool,
    job_id: str,
    scene_id: str,
    frame_records: list[FrameRecord],
    inference: InferenceResult,
    session_settings: Mapping[str, Any],
    temp_dir: Path,
) -> dict[str, Any]:
    predictions = inference.predictions

    core = _generate_core_outputs(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        predictions=predictions,
        frame_records=frame_records,
        inference=inference,
        session_settings=session_settings,
        temp_dir=temp_dir,
    )

    artifacts = dict(core["artifacts"])

    top_k = _as_int(payload.get("top_k_frames") or payload.get("top_k"), 0)
    selected_frames = _compute_selected_frames(predictions, frame_records, top_k)
    selected_frames_url = _write_selected_frames(
        runtime=runtime,
        scene_id=scene_id,
        selected_frames=selected_frames,
        top_k=top_k,
        temp_dir=temp_dir,
    )
    if selected_frames_url:
        artifacts["selected_frames_url"] = selected_frames_url

    scene_glb_url = _save_scene_glb(
        runtime=runtime,
        scene_id=scene_id,
        predictions=predictions,
        temp_dir=temp_dir,
        payload=payload,
    )
    artifacts["scene_glb_url"] = scene_glb_url

    summary_payload = {
        "job_id": job_id,
        "job_type": "model_build",
        "scene_id": scene_id,
        "frame_count": inference.total_frames,
        "created_at": datetime.now(timezone.utc).isoformat(),
        "artifacts": artifacts,
        "selected_frames": selected_frames,
        "parameters": {
            "mode": mode,
            "streaming": streaming,
            "conf_thres": float(payload.get("conf_thres", 3.0)),
            "frame_filter": payload.get("frame_filter", "All"),
            "mask_black_bg": _as_bool(payload.get("mask_black_bg"), False),
            "mask_white_bg": _as_bool(payload.get("mask_white_bg"), False),
            "show_cam": _as_bool(payload.get("show_cam"), True),
            "mask_sky": _as_bool(payload.get("mask_sky"), False),
            "prediction_mode": payload.get("prediction_mode", "Predicted Pointmap"),
        },
    }

    summary_url = _write_summary_json(
        runtime=runtime,
        scene_id=scene_id,
        summary=summary_payload,
        temp_dir=temp_dir,
    )
    artifacts["summary_url"] = summary_url

    result_record = dict(summary_payload)
    result_record["result_url"] = summary_url
    result_record_url = _upload_result_record(
        runtime=runtime,
        scene_id=scene_id,
        job_id=job_id,
        payload=result_record,
    )

    result_payload = {
        "job_id": job_id,
        "job_type": "model_build",
        "scene_id": scene_id,
        "mode": mode,
        "streaming": streaming,
        "frame_count": inference.total_frames,
        "created_at": summary_payload["created_at"],
        "artifacts": artifacts,
        "frames": core["frames"],
        "selected_frames": selected_frames,
        "summary_url": summary_url,
        "result_url": summary_url,
        "result_record_url": result_record_url,
        "model_dir": _model_dir_uri(runtime, scene_id),
    }

    return result_payload
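
For reference, a minimal sketch of how an upstream producer could enqueue one of these task entrypoints with RQ. The Redis URL, queue name, and payload values below are illustrative assumptions, not constants defined by this commit.

# Illustrative only: Redis URL, queue name, and payload values are assumed.
from redis import Redis
from rq import Queue

queue = Queue("stream3r", connection=Redis.from_url("redis://localhost:6379/0"))

job = queue.enqueue(
    "stream3r.worker.tasks.model_build_job",
    {
        "scene_id": "scene-abc",   # required; _execute_job raises without it
        "mode": "window",          # "window" forces streaming and a window_size
        "top_k_frames": 8,         # optional frame selection for model_build
        "timeout": 1800,           # optional per-job timeout override, seconds
    },
)
print(job.id)

If "job_id" is omitted from the payload, _execute_job falls back to the RQ job id, so the id printed above is also the id recorded via runtime.db.upsert_job.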
worker/__init__.py
ADDED
@@ -0,0 +1 @@
"""Compatibility shim package for legacy job import paths."""

worker/stream3r/__init__.py
ADDED
@@ -0,0 +1 @@
"""Compatibility namespace for legacy worker module paths."""
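
These empty packages exist only so the legacy dotted path still resolves when an RQ worker deserializes an old job reference. A quick, illustrative sanity check:

# Illustrative only: confirm the legacy module path imports on a new worker.
import importlib

jobs_module = importlib.import_module("worker.stream3r.jobs")
assert callable(jobs_module.handle_job)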
worker/stream3r/jobs.py
ADDED
@@ -0,0 +1,63 @@
"""Legacy job dispatch entrypoints for compatibility with existing queues."""
|
| 2 |
+
|
| 3 |
+
from __future__ import annotations
|
| 4 |
+
|
| 5 |
+
from typing import Any, Callable, Mapping
|
| 6 |
+
|
| 7 |
+
from stream3r.worker.tasks import model_build_job, pose_pointmap_job
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
_HANDLERS: dict[str, Callable[[Mapping[str, Any]], Any]] = {
|
| 11 |
+
"pose_pointmap": pose_pointmap_job,
|
| 12 |
+
"model_build": model_build_job,
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def handle_job(*args: Any, **kwargs: Any) -> Any:
|
| 17 |
+
"""Dispatch jobs enqueued with the legacy `worker.stream3r.jobs.handle_job` path.
|
| 18 |
+
|
| 19 |
+
Supports the following invocation patterns:
|
| 20 |
+
|
| 21 |
+
- ``handle_job(payload)`` where ``payload`` is a mapping containing ``job_type``.
|
| 22 |
+
- ``handle_job(job_type, payload)`` matching older enqueue signatures.
|
| 23 |
+
- ``handle_job(job_type=..., payload=...)`` keyword usage.
|
| 24 |
+
"""
|
| 25 |
+
|
| 26 |
+
job_type: str | None = None
|
| 27 |
+
payload: Mapping[str, Any] | None = None
|
| 28 |
+
|
| 29 |
+
    if args:
        if isinstance(args[0], Mapping):
            # ``handle_job(payload)``: a mapping first argument is always the
            # payload; pick up the job type from it when present so a payload
            # without "job_type" is not misread as a job-type string.
            payload = args[0]
            if payload.get("job_type"):
                job_type = str(payload["job_type"])
        else:
            # ``handle_job(job_type, payload)``: older positional signature.
            job_type = str(args[0])
            if len(args) > 1 and isinstance(args[1], Mapping):
                payload = args[1]

    if not job_type and "job_type" in kwargs:
        job_type = str(kwargs["job_type"])
    if payload is None and isinstance(kwargs.get("payload"), Mapping):
        payload = kwargs["payload"]

    if payload is None:
        raise ValueError("handle_job requires a payload mapping")

    if not job_type:
        job_type = str(payload.get("job_type", "")).strip()

    if not job_type:
        raise ValueError("handle_job payload is missing 'job_type'")

    handler = _HANDLERS.get(job_type)
    if handler is None:
        raise ValueError(f"Unsupported job_type '{job_type}'")

    return handler(payload)
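
The three call shapes the shim accepts, shown with placeholder payload values (assumed, not taken from this commit). In practice each call runs the full job, so these entrypoints are normally reached through an RQ worker rather than invoked directly:

# Illustrative only: all three shapes dispatch to pose_pointmap_job.
from worker.stream3r.jobs import handle_job

handle_job({"job_type": "pose_pointmap", "scene_id": "scene-abc"})
handle_job("pose_pointmap", {"scene_id": "scene-abc"})
handle_job(job_type="pose_pointmap", payload={"scene_id": "scene-abc"})

# Or enqueued through RQ with the legacy dotted path (queue name assumed):
# queue.enqueue("worker.stream3r.jobs.handle_job", payload)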