Patrick Kastner commited on
Commit ·
bf2cbd5
1
Parent(s): 4be65d1
Overhaul
Browse filesAdd root endpoint for base URL health checks
Fallback to default model URL when MODEL_URL is unset
.
.
Update Dockerfile
- .dockerignore +6 -0
- .gitignore +1 -0
- Dockerfile +17 -0
- README.md +94 -3
- api.py +456 -34
- download_model.py +58 -0
- pyproject.toml +23 -0
- render.yaml +25 -0
- requirements.txt +0 -8
- test.ipynb +0 -106
- uv.lock +0 -0
.dockerignore
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.venv
|
| 3 |
+
__pycache__/
|
| 4 |
+
*.pyc
|
| 5 |
+
*.pyo
|
| 6 |
+
.DS_Store
|
.gitignore
CHANGED
|
@@ -154,3 +154,4 @@ cython_debug/
|
|
| 154 |
.space
|
| 155 |
.onnx
|
| 156 |
model.onnx
|
|
|
|
|
|
| 154 |
.space
|
| 155 |
.onnx
|
| 156 |
model.onnx
|
| 157 |
+
.DS_Store
|
Dockerfile
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.11-slim
|
| 2 |
+
|
| 3 |
+
# Install uv
|
| 4 |
+
COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
|
| 5 |
+
|
| 6 |
+
WORKDIR /app
|
| 7 |
+
|
| 8 |
+
# Install dependencies first for layer caching
|
| 9 |
+
COPY pyproject.toml uv.lock ./
|
| 10 |
+
RUN uv sync --frozen --no-dev --no-install-project
|
| 11 |
+
|
| 12 |
+
COPY . .
|
| 13 |
+
|
| 14 |
+
# Model is downloaded at startup via MODEL_URL env var
|
| 15 |
+
CMD ["sh", "-c", "uv run python download_model.py && uv run uvicorn api:app --host 0.0.0.0 --port ${PORT:-8000}"]
|
| 16 |
+
|
| 17 |
+
EXPOSE 8000
|
README.md
CHANGED
|
@@ -1,4 +1,95 @@
|
|
| 1 |
-
# GAN
|
| 2 |
|
| 3 |
-
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Eddy3D GAN Wind Prediction API
|
| 2 |
|
| 3 |
+
FastAPI service that serves the GAN surrogate model for urban pedestrian-level wind flow prediction.
|
| 4 |
+
|
| 5 |
+
## ONNX Model Hosting
|
| 6 |
+
|
| 7 |
+
The ONNX model file (`GAN-21-05-2023-23-Generative.onnx`, ~208 MB) is **not included** in this repository. The container downloads it at startup from a URL you provide.
|
| 8 |
+
|
| 9 |
+
The model is hosted on Hugging Face:
|
| 10 |
+
|
| 11 |
+
```
|
| 12 |
+
https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
`MODEL_URL` defaults to this URL if not provided, but you can override it via environment variable.
|
| 16 |
+
|
| 17 |
+
### Optional: SHA-256 Integrity Check
|
| 18 |
+
|
| 19 |
+
Generate a checksum and set `MODEL_SHA256` to verify downloads:
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
shasum -a 256 GAN-21-05-2023-23-Generative.onnx
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
## Local Development
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
# Install uv (if not already installed)
|
| 29 |
+
curl -LsSf https://astral.sh/uv/install.sh | sh
|
| 30 |
+
|
| 31 |
+
# Install dependencies
|
| 32 |
+
uv sync
|
| 33 |
+
|
| 34 |
+
# Place your model file
|
| 35 |
+
cp /path/to/GAN-21-05-2023-23-Generative.onnx model.onnx
|
| 36 |
+
|
| 37 |
+
# Run the server
|
| 38 |
+
uv run uvicorn api:app --reload --port 8000
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Deploying to Render
|
| 42 |
+
|
| 43 |
+
1. Push this repo to GitHub
|
| 44 |
+
2. Create a new **Web Service** on [Render](https://render.com)
|
| 45 |
+
3. Connect the GitHub repo
|
| 46 |
+
4. Render will auto-detect the `Dockerfile`
|
| 47 |
+
5. Set the following environment variables in the Render dashboard:
|
| 48 |
+
|
| 49 |
+
| Variable | Required | Example |
|
| 50 |
+
|----------|----------|---------|
|
| 51 |
+
| `MODEL_URL` | No | `https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx` |
|
| 52 |
+
| `MODEL_SHA256` | No | `abc123...` (hex digest for integrity check) |
|
| 53 |
+
| `RATE_LIMIT_REQUESTS` | No | `30` (default: 30 requests per window) |
|
| 54 |
+
| `RATE_LIMIT_WINDOW` | No | `minute` (options: second, minute, hour, day) |
|
| 55 |
+
| `ALLOWED_ORIGINS` | No | `*` (comma-separated CORS origins) |
|
| 56 |
+
| `LOG_LEVEL` | No | `INFO` |
|
| 57 |
+
|
| 58 |
+
Alternatively, use `render.yaml` for infrastructure-as-code deployment (Blueprint).
|
| 59 |
+
|
| 60 |
+
## API Endpoints
|
| 61 |
+
|
| 62 |
+
| Method | Path | Description |
|
| 63 |
+
|--------|------|-------------|
|
| 64 |
+
| GET | `/health` | Health check (not rate limited) |
|
| 65 |
+
| POST | `/predict` | Image upload → wind speeds JSON + output image (base64) |
|
| 66 |
+
| POST | `/predict_array` | Raw float array → wind speeds JSON + output image (base64) |
|
| 67 |
+
| POST | `/predict_image` | Image upload → output PNG stream (legacy) |
|
| 68 |
+
|
| 69 |
+
### `/predict_array` (primary endpoint)
|
| 70 |
+
|
| 71 |
+
Accepts a JSON body with a flat float array of 786,432 values (3 x 512 x 512), channel-first order (R, G, B), normalised to [-1, 1].
|
| 72 |
+
|
| 73 |
+
```json
|
| 74 |
+
{
|
| 75 |
+
"data": [0.1, -0.5, ...]
|
| 76 |
+
}
|
| 77 |
+
```
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
|
| 81 |
+
```json
|
| 82 |
+
{
|
| 83 |
+
"wind_speeds": [0.5, 1.2, ...],
|
| 84 |
+
"image_base64": "iVBORw0KGgo...",
|
| 85 |
+
"width": 512,
|
| 86 |
+
"height": 512
|
| 87 |
+
}
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
## Docker
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
docker build -t eddy3d-gan-api .
|
| 94 |
+
docker run -p 8000:8000 -e MODEL_URL="https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx" eddy3d-gan-api
|
| 95 |
+
```
|
api.py
CHANGED
|
@@ -1,51 +1,473 @@
|
|
| 1 |
-
|
| 2 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import io
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
import numpy as np
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import onnxruntime as rt
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
-
|
|
|
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
async def test():
|
| 13 |
-
return {"message": "API is working"}
|
| 14 |
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
-
#@app.api_route("/process_image", methods=["POST"])
|
| 22 |
-
@app.post("/process_image")
|
| 23 |
-
async def process_image(file: UploadFile = File(...)):
|
| 24 |
input_data = await file.read()
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
| 29 |
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
input_array = np.expand_dims(input_array, axis=0)
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
-
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
|
| 40 |
-
|
| 41 |
-
output_array = (outputs[0][0].transpose((1, 2, 0)) + 1) * 127.5
|
| 42 |
-
output_array = np.clip(output_array, 0, 255) # Clip values outside of valid range
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
|
|
|
|
|
|
| 46 |
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
byte_arr.seek(0) # Move cursor to the beginning of the file
|
| 50 |
|
| 51 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Eddy3D GAN Wind Prediction API
|
| 3 |
+
|
| 4 |
+
Serves the GAN surrogate model for urban wind flow prediction.
|
| 5 |
+
Accepts 512x512 input images or raw float arrays and returns predicted wind speeds.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
import io
|
| 10 |
+
import base64
|
| 11 |
+
import logging
|
| 12 |
+
import time
|
| 13 |
+
from contextlib import asynccontextmanager
|
| 14 |
+
from typing import Optional
|
| 15 |
+
|
| 16 |
import numpy as np
|
| 17 |
+
from PIL import Image
|
| 18 |
+
from fastapi import FastAPI, File, Request, UploadFile, HTTPException
|
| 19 |
+
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
+
from fastapi.responses import JSONResponse, StreamingResponse
|
| 21 |
+
from pydantic import BaseModel
|
| 22 |
import onnxruntime as rt
|
| 23 |
+
from slowapi import Limiter, _rate_limit_exceeded_handler
|
| 24 |
+
from slowapi.util import get_remote_address
|
| 25 |
+
from slowapi.errors import RateLimitExceeded
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Configuration
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")
|
| 31 |
+
PORT = int(os.environ.get("PORT", "8000"))
|
| 32 |
+
LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
|
| 33 |
+
ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
|
| 34 |
+
RATE_LIMIT_REQUESTS = os.environ.get("RATE_LIMIT_REQUESTS", "30")
|
| 35 |
+
RATE_LIMIT_WINDOW = os.environ.get("RATE_LIMIT_WINDOW", "minute")
|
| 36 |
+
MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
|
| 37 |
+
|
| 38 |
+
logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
|
| 39 |
+
logger = logging.getLogger(__name__)
|
| 40 |
+
|
| 41 |
+
# ---------------------------------------------------------------------------
|
| 42 |
+
# Turbo colormap (256 entries, RGB in [0, 1])
|
| 43 |
+
# Used for reverse-mapping predicted pixel colours to wind speed values.
|
| 44 |
+
# ---------------------------------------------------------------------------
|
| 45 |
+
TURBO_COLORMAP = np.array([
|
| 46 |
+
[0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149],
|
| 47 |
+
[0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844],
|
| 48 |
+
[0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314],
|
| 49 |
+
[0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558],
|
| 50 |
+
[0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578],
|
| 51 |
+
[0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373],
|
| 52 |
+
[0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942],
|
| 53 |
+
[0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286],
|
| 54 |
+
[0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406],
|
| 55 |
+
[0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300],
|
| 56 |
+
[0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968],
|
| 57 |
+
[0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412],
|
| 58 |
+
[0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631],
|
| 59 |
+
[0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624],
|
| 60 |
+
[0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393],
|
| 61 |
+
[0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936],
|
| 62 |
+
[0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254],
|
| 63 |
+
[0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347],
|
| 64 |
+
[0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214],
|
| 65 |
+
[0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857],
|
| 66 |
+
[0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275],
|
| 67 |
+
[0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461],
|
| 68 |
+
[0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303],
|
| 69 |
+
[0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773],
|
| 70 |
+
[0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896],
|
| 71 |
+
[0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697],
|
| 72 |
+
[0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202],
|
| 73 |
+
[0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436],
|
| 74 |
+
[0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423],
|
| 75 |
+
[0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190],
|
| 76 |
+
[0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761],
|
| 77 |
+
[0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161],
|
| 78 |
+
[0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416],
|
| 79 |
+
[0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550],
|
| 80 |
+
[0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590],
|
| 81 |
+
[0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559],
|
| 82 |
+
[0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484],
|
| 83 |
+
[0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389],
|
| 84 |
+
[0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299],
|
| 85 |
+
[0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240],
|
| 86 |
+
[0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237],
|
| 87 |
+
[0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316],
|
| 88 |
+
[0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500],
|
| 89 |
+
[0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651],
|
| 90 |
+
[0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627],
|
| 91 |
+
[0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448],
|
| 92 |
+
[0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137],
|
| 93 |
+
[0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713],
|
| 94 |
+
[0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199],
|
| 95 |
+
[0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614],
|
| 96 |
+
[0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981],
|
| 97 |
+
[0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321],
|
| 98 |
+
[0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654],
|
| 99 |
+
[0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002],
|
| 100 |
+
[0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386],
|
| 101 |
+
[0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826],
|
| 102 |
+
[0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345],
|
| 103 |
+
[0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963],
|
| 104 |
+
[0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701],
|
| 105 |
+
[0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581],
|
| 106 |
+
[0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623],
|
| 107 |
+
[0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849],
|
| 108 |
+
[0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280],
|
| 109 |
+
[0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937],
|
| 110 |
+
[0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835],
|
| 111 |
+
[0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960],
|
| 112 |
+
[0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294],
|
| 113 |
+
[0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815],
|
| 114 |
+
[0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504],
|
| 115 |
+
[0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343],
|
| 116 |
+
[0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310],
|
| 117 |
+
[0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386],
|
| 118 |
+
[0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552],
|
| 119 |
+
[0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788],
|
| 120 |
+
[0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074],
|
| 121 |
+
[0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391],
|
| 122 |
+
[0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719],
|
| 123 |
+
[0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038],
|
| 124 |
+
[0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328],
|
| 125 |
+
[0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570],
|
| 126 |
+
[0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744],
|
| 127 |
+
[0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831],
|
| 128 |
+
[0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811],
|
| 129 |
+
[0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663],
|
| 130 |
+
[0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369],
|
| 131 |
+
[0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918],
|
| 132 |
+
[0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358],
|
| 133 |
+
[0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706],
|
| 134 |
+
[0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971],
|
| 135 |
+
[0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165],
|
| 136 |
+
[0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297],
|
| 137 |
+
[0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376],
|
| 138 |
+
[0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412],
|
| 139 |
+
[0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417],
|
| 140 |
+
[0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398],
|
| 141 |
+
[0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367],
|
| 142 |
+
[0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332],
|
| 143 |
+
[0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305],
|
| 144 |
+
[0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294],
|
| 145 |
+
[0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310],
|
| 146 |
+
[0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362],
|
| 147 |
+
[0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461],
|
| 148 |
+
[0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616],
|
| 149 |
+
[0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837],
|
| 150 |
+
[0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134],
|
| 151 |
+
[0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516],
|
| 152 |
+
[0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993],
|
| 153 |
+
[0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521],
|
| 154 |
+
[0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082],
|
| 155 |
+
[0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677],
|
| 156 |
+
[0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305],
|
| 157 |
+
[0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966],
|
| 158 |
+
[0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660],
|
| 159 |
+
[0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387],
|
| 160 |
+
[0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148],
|
| 161 |
+
[0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942],
|
| 162 |
+
[0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769],
|
| 163 |
+
[0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629],
|
| 164 |
+
[0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522],
|
| 165 |
+
[0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449],
|
| 166 |
+
[0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408],
|
| 167 |
+
[0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401],
|
| 168 |
+
[0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427],
|
| 169 |
+
[0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486],
|
| 170 |
+
[0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579],
|
| 171 |
+
[0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705],
|
| 172 |
+
[0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863],
|
| 173 |
+
[0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055],
|
| 174 |
+
], dtype=np.float64) # shape (256, 3)
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
def _color_to_windspeed(raw_output: np.ndarray) -> list[float]:
|
| 178 |
+
"""Map raw model output to wind speed values using Turbo colormap reverse-lookup.
|
| 179 |
+
|
| 180 |
+
Ported from Prediction.ColorToNumber in the decompiled CFDComponent.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
raw_output: Model output array of shape (3, H, W) with values in [-1, 1].
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Flat list of wind speed values (0-15 m/s), length H*W.
|
| 187 |
+
"""
|
| 188 |
+
_, h, w = raw_output.shape
|
| 189 |
+
n_colors = TURBO_COLORMAP.shape[0]
|
| 190 |
+
|
| 191 |
+
# Convert from [-1, 1] to [0, 1]
|
| 192 |
+
r = (raw_output[0] + 1.0) / 2.0 # (H, W)
|
| 193 |
+
g = (raw_output[1] + 1.0) / 2.0
|
| 194 |
+
b = (raw_output[2] + 1.0) / 2.0
|
| 195 |
+
|
| 196 |
+
# Stack into (H*W, 3)
|
| 197 |
+
pixels = np.stack([r.ravel(), g.ravel(), b.ravel()], axis=1) # (H*W, 3)
|
| 198 |
+
|
| 199 |
+
# Find closest colormap entry for each pixel via Euclidean distance
|
| 200 |
+
# Using broadcasting: pixels (H*W, 1, 3) - colormap (1, 256, 3) -> (H*W, 256, 3)
|
| 201 |
+
diff = pixels[:, np.newaxis, :] - TURBO_COLORMAP[np.newaxis, :, :]
|
| 202 |
+
dists = np.sum(diff * diff, axis=2) # (H*W, 256)
|
| 203 |
+
indices = np.argmin(dists, axis=1) # (H*W,)
|
| 204 |
+
|
| 205 |
+
# Map index to wind speed: index / n_colors * 15.0
|
| 206 |
+
wind_speeds = indices.astype(np.float64) / n_colors * 15.0
|
| 207 |
+
|
| 208 |
+
return wind_speeds.tolist()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
# ---------------------------------------------------------------------------
|
| 212 |
+
# Rate limiting
|
| 213 |
+
# ---------------------------------------------------------------------------
|
| 214 |
+
_rate_limit_str = f"{RATE_LIMIT_REQUESTS}/{RATE_LIMIT_WINDOW}"
|
| 215 |
+
_rate_limiting_enabled = int(RATE_LIMIT_REQUESTS) > 0
|
| 216 |
+
|
| 217 |
+
limiter = Limiter(
|
| 218 |
+
key_func=get_remote_address,
|
| 219 |
+
default_limits=[_rate_limit_str] if _rate_limiting_enabled else [],
|
| 220 |
+
enabled=_rate_limiting_enabled,
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
# ---------------------------------------------------------------------------
|
| 224 |
+
# Application lifecycle
|
| 225 |
+
# ---------------------------------------------------------------------------
|
| 226 |
+
@asynccontextmanager
|
| 227 |
+
async def lifespan(app: FastAPI):
|
| 228 |
+
"""Load ONNX model at startup, release at shutdown."""
|
| 229 |
+
logger.info("Loading ONNX model from %s ...", MODEL_PATH)
|
| 230 |
+
if not os.path.exists(MODEL_PATH):
|
| 231 |
+
logger.error("Model file not found at %s", MODEL_PATH)
|
| 232 |
+
raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
|
| 233 |
+
|
| 234 |
+
session = rt.InferenceSession(MODEL_PATH)
|
| 235 |
+
input_name = session.get_inputs()[0].name
|
| 236 |
+
input_shape = session.get_inputs()[0].shape
|
| 237 |
+
output_name = session.get_outputs()[0].name
|
| 238 |
+
output_shape = session.get_outputs()[0].shape
|
| 239 |
+
logger.info("Model loaded. Input: %s %s Output: %s %s",
|
| 240 |
+
input_name, input_shape, output_name, output_shape)
|
| 241 |
+
|
| 242 |
+
app.state.session = session
|
| 243 |
+
app.state.input_name = input_name
|
| 244 |
+
app.state.model_loaded = True
|
| 245 |
+
|
| 246 |
+
yield
|
| 247 |
+
|
| 248 |
+
app.state.model_loaded = False
|
| 249 |
+
logger.info("Shutting down.")
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
app = FastAPI(
|
| 253 |
+
title="Eddy3D GAN Wind Prediction API",
|
| 254 |
+
version="2.0.0",
|
| 255 |
+
lifespan=lifespan,
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
app.state.limiter = limiter
|
| 259 |
+
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
| 260 |
+
|
| 261 |
+
app.add_middleware(
|
| 262 |
+
CORSMiddleware,
|
| 263 |
+
allow_origins=ALLOWED_ORIGINS,
|
| 264 |
+
allow_credentials=True,
|
| 265 |
+
allow_methods=["*"],
|
| 266 |
+
allow_headers=["*"],
|
| 267 |
+
)
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
# ---------------------------------------------------------------------------
|
| 271 |
+
# Helper: run inference
|
| 272 |
+
# ---------------------------------------------------------------------------
|
| 273 |
+
def _run_inference(image_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
|
| 274 |
+
"""Preprocess image, run model, return (raw_output, denormalized_output).
|
| 275 |
+
|
| 276 |
+
Returns:
|
| 277 |
+
raw_output: shape (3, 512, 512), values in [-1, 1]
|
| 278 |
+
denorm_output: shape (512, 512, 3), values in [0, 255], uint8
|
| 279 |
+
"""
|
| 280 |
+
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 281 |
+
input_array = np.array(image.resize((512, 512)), dtype=np.float32)
|
| 282 |
|
| 283 |
+
# Normalize [0, 255] -> [-1, 1]
|
| 284 |
+
input_array = (input_array / 127.5) - 1.0
|
| 285 |
+
input_array = np.expand_dims(input_array, axis=0) # (1, 512, 512, 3)
|
| 286 |
+
input_array = input_array.transpose((0, 3, 1, 2)) # (1, 3, 512, 512)
|
| 287 |
|
| 288 |
+
session = app.state.session
|
| 289 |
+
inputs = {app.state.input_name: input_array}
|
| 290 |
+
outputs = session.run(None, inputs)
|
| 291 |
|
| 292 |
+
raw = outputs[0][0] # (3, 512, 512)
|
|
|
|
|
|
|
| 293 |
|
| 294 |
+
# Denormalise for image output
|
| 295 |
+
denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5 # (512, 512, 3)
|
| 296 |
+
denorm = np.clip(denorm, 0, 255).astype(np.uint8)
|
| 297 |
+
|
| 298 |
+
return raw, denorm
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 302 |
+
"""Run model from a pre-normalised float32 array.
|
| 303 |
+
|
| 304 |
+
Args:
|
| 305 |
+
data: shape (3, 512, 512), float32, values already in [-1, 1].
|
| 306 |
+
|
| 307 |
+
Returns:
|
| 308 |
+
raw_output: shape (3, 512, 512), values in [-1, 1]
|
| 309 |
+
denorm_output: shape (512, 512, 3), values in [0, 255], uint8
|
| 310 |
+
"""
|
| 311 |
+
input_array = np.expand_dims(data, axis=0).astype(np.float32) # (1, 3, 512, 512)
|
| 312 |
+
|
| 313 |
+
session = app.state.session
|
| 314 |
+
inputs = {app.state.input_name: input_array}
|
| 315 |
+
outputs = session.run(None, inputs)
|
| 316 |
+
|
| 317 |
+
raw = outputs[0][0] # (3, 512, 512)
|
| 318 |
+
|
| 319 |
+
denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5
|
| 320 |
+
denorm = np.clip(denorm, 0, 255).astype(np.uint8)
|
| 321 |
+
|
| 322 |
+
return raw, denorm
|
| 323 |
+
|
| 324 |
+
|
| 325 |
+
class ArrayPredictRequest(BaseModel):
|
| 326 |
+
"""Request body for /predict_array.
|
| 327 |
+
|
| 328 |
+
The data field is a flat list of 786432 floats (3 * 512 * 512) in
|
| 329 |
+
channel-first order (R, G, B), with values already normalised to [-1, 1].
|
| 330 |
+
"""
|
| 331 |
+
data: list[float]
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
# ---------------------------------------------------------------------------
|
| 335 |
+
# Endpoints
|
| 336 |
+
# ---------------------------------------------------------------------------
|
| 337 |
+
@app.get("/")
|
| 338 |
+
async def root():
|
| 339 |
+
"""Basic service info for base URL checks."""
|
| 340 |
+
return {
|
| 341 |
+
"service": "Eddy3D GAN Wind Prediction API",
|
| 342 |
+
"status": "ok",
|
| 343 |
+
"endpoints": ["/health", "/predict", "/predict_array", "/predict_image"],
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
@app.get("/health")
|
| 348 |
+
async def health():
|
| 349 |
+
"""Health check (not rate limited)."""
|
| 350 |
+
return {
|
| 351 |
+
"status": "healthy",
|
| 352 |
+
"model_loaded": getattr(app.state, "model_loaded", False),
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
@app.post("/predict")
|
| 357 |
+
@limiter.limit(_rate_limit_str)
|
| 358 |
+
async def predict(request: Request, file: UploadFile = File(...)):
|
| 359 |
+
"""Run GAN inference and return wind speeds + output image.
|
| 360 |
+
|
| 361 |
+
Returns JSON:
|
| 362 |
+
wind_speeds: list of floats (length 262144), values 0–15 m/s
|
| 363 |
+
image_base64: base64-encoded PNG of the output image
|
| 364 |
+
width: 512
|
| 365 |
+
height: 512
|
| 366 |
+
"""
|
| 367 |
+
# Validate content type
|
| 368 |
+
if file.content_type and not file.content_type.startswith("image/"):
|
| 369 |
+
raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
|
| 370 |
+
|
| 371 |
+
input_data = await file.read()
|
| 372 |
+
if len(input_data) > MAX_UPLOAD_BYTES:
|
| 373 |
+
raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
|
| 374 |
+
|
| 375 |
+
t0 = time.perf_counter()
|
| 376 |
+
try:
|
| 377 |
+
raw, denorm = _run_inference(input_data)
|
| 378 |
+
except Exception as e:
|
| 379 |
+
logger.exception("Inference failed")
|
| 380 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 381 |
+
|
| 382 |
+
# Wind speeds from raw output
|
| 383 |
+
wind_speeds = _color_to_windspeed(raw)
|
| 384 |
+
|
| 385 |
+
# Encode output image as base64 PNG
|
| 386 |
+
output_image = Image.fromarray(denorm, "RGB")
|
| 387 |
+
buf = io.BytesIO()
|
| 388 |
+
output_image.save(buf, format="PNG")
|
| 389 |
+
image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 390 |
+
|
| 391 |
+
elapsed = time.perf_counter() - t0
|
| 392 |
+
logger.info("Inference completed in %.2f s", elapsed)
|
| 393 |
+
|
| 394 |
+
return JSONResponse({
|
| 395 |
+
"wind_speeds": wind_speeds,
|
| 396 |
+
"image_base64": image_b64,
|
| 397 |
+
"width": 512,
|
| 398 |
+
"height": 512,
|
| 399 |
+
})
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
@app.post("/predict_image")
|
| 403 |
+
@limiter.limit(_rate_limit_str)
|
| 404 |
+
async def predict_image(request: Request, file: UploadFile = File(...)):
|
| 405 |
+
"""Legacy endpoint: returns the output image as a PNG stream."""
|
| 406 |
+
if file.content_type and not file.content_type.startswith("image/"):
|
| 407 |
+
raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
|
| 408 |
|
|
|
|
|
|
|
|
|
|
| 409 |
input_data = await file.read()
|
| 410 |
+
if len(input_data) > MAX_UPLOAD_BYTES:
|
| 411 |
+
raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
|
| 412 |
+
|
| 413 |
+
try:
|
| 414 |
+
_, denorm = _run_inference(input_data)
|
| 415 |
+
except Exception as e:
|
| 416 |
+
logger.exception("Inference failed")
|
| 417 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 418 |
+
|
| 419 |
+
output_image = Image.fromarray(denorm, "RGB")
|
| 420 |
+
buf = io.BytesIO()
|
| 421 |
+
output_image.save(buf, format="PNG")
|
| 422 |
+
buf.seek(0)
|
| 423 |
+
|
| 424 |
+
return StreamingResponse(buf, media_type="image/png")
|
| 425 |
+
|
| 426 |
|
| 427 |
+
@app.post("/predict_array")
|
| 428 |
+
@limiter.limit(_rate_limit_str)
|
| 429 |
+
async def predict_array(request: Request, body: ArrayPredictRequest):
|
| 430 |
+
"""Run GAN inference from a raw float array (no image encoding needed).
|
| 431 |
|
| 432 |
+
Accepts a flat list of 786432 floats (3 * 512 * 512) in channel-first
|
| 433 |
+
order, already normalised to [-1, 1].
|
|
|
|
| 434 |
|
| 435 |
+
Returns JSON:
|
| 436 |
+
wind_speeds: list of floats (length 262144), values 0-15 m/s
|
| 437 |
+
image_base64: base64-encoded PNG of the output image
|
| 438 |
+
width: 512
|
| 439 |
+
height: 512
|
| 440 |
+
"""
|
| 441 |
+
expected = 3 * 512 * 512
|
| 442 |
+
if len(body.data) != expected:
|
| 443 |
+
raise HTTPException(
|
| 444 |
+
status_code=400,
|
| 445 |
+
detail=f"Expected {expected} floats (3*512*512), got {len(body.data)}.",
|
| 446 |
+
)
|
| 447 |
|
| 448 |
+
t0 = time.perf_counter()
|
| 449 |
+
try:
|
| 450 |
+
arr = np.array(body.data, dtype=np.float32).reshape((3, 512, 512))
|
| 451 |
+
raw, denorm = _run_inference_from_array(arr)
|
| 452 |
+
except HTTPException:
|
| 453 |
+
raise
|
| 454 |
+
except Exception as e:
|
| 455 |
+
logger.exception("Inference failed")
|
| 456 |
+
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 457 |
|
| 458 |
+
wind_speeds = _color_to_windspeed(raw)
|
|
|
|
|
|
|
| 459 |
|
| 460 |
+
output_image = Image.fromarray(denorm, "RGB")
|
| 461 |
+
buf = io.BytesIO()
|
| 462 |
+
output_image.save(buf, format="PNG")
|
| 463 |
+
image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 464 |
|
| 465 |
+
elapsed = time.perf_counter() - t0
|
| 466 |
+
logger.info("Array inference completed in %.2f s", elapsed)
|
|
|
|
| 467 |
|
| 468 |
+
return JSONResponse({
|
| 469 |
+
"wind_speeds": wind_speeds,
|
| 470 |
+
"image_base64": image_b64,
|
| 471 |
+
"width": 512,
|
| 472 |
+
"height": 512,
|
| 473 |
+
})
|
download_model.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Download ONNX model from cloud storage if not already present."""
|
| 2 |
+
|
| 3 |
+
import hashlib
|
| 4 |
+
import os
|
| 5 |
+
import sys
|
| 6 |
+
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
DEFAULT_MODEL_URL = (
|
| 10 |
+
"https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/"
|
| 11 |
+
"GAN-21-05-2023-23-Generative.onnx"
|
| 12 |
+
)
|
| 13 |
+
|
| 14 |
+
# Allow empty env var values to fall back to the default hosted model.
|
| 15 |
+
MODEL_URL = os.environ.get("MODEL_URL") or DEFAULT_MODEL_URL
|
| 16 |
+
MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")
|
| 17 |
+
MODEL_SHA256 = os.environ.get("MODEL_SHA256", "")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def download() -> None:
|
| 21 |
+
if os.path.exists(MODEL_PATH):
|
| 22 |
+
print(f"Model already exists at {MODEL_PATH}")
|
| 23 |
+
return
|
| 24 |
+
|
| 25 |
+
print(f"Downloading model from {MODEL_URL} ...")
|
| 26 |
+
response = requests.get(MODEL_URL, stream=True, timeout=600)
|
| 27 |
+
response.raise_for_status()
|
| 28 |
+
|
| 29 |
+
os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
|
| 30 |
+
|
| 31 |
+
total = int(response.headers.get("content-length", 0))
|
| 32 |
+
downloaded = 0
|
| 33 |
+
with open(MODEL_PATH, "wb") as f:
|
| 34 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 35 |
+
f.write(chunk)
|
| 36 |
+
downloaded += len(chunk)
|
| 37 |
+
if total:
|
| 38 |
+
pct = downloaded / total * 100
|
| 39 |
+
print(f"\r {downloaded / 1e6:.1f} / {total / 1e6:.1f} MB ({pct:.0f}%)", end="", flush=True)
|
| 40 |
+
print()
|
| 41 |
+
|
| 42 |
+
# Optional SHA-256 integrity check
|
| 43 |
+
if MODEL_SHA256:
|
| 44 |
+
h = hashlib.sha256()
|
| 45 |
+
with open(MODEL_PATH, "rb") as f:
|
| 46 |
+
for chunk in iter(lambda: f.read(8192), b""):
|
| 47 |
+
h.update(chunk)
|
| 48 |
+
if h.hexdigest() != MODEL_SHA256:
|
| 49 |
+
os.remove(MODEL_PATH)
|
| 50 |
+
print(f"ERROR: SHA-256 mismatch! Expected {MODEL_SHA256}, got {h.hexdigest()}", file=sys.stderr)
|
| 51 |
+
sys.exit(1)
|
| 52 |
+
print("SHA-256 verified.")
|
| 53 |
+
|
| 54 |
+
print(f"Model downloaded to {MODEL_PATH}")
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
if __name__ == "__main__":
|
| 58 |
+
download()
|
pyproject.toml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "eddy3d-gan-api"
|
| 3 |
+
version = "2.0.0"
|
| 4 |
+
description = "GAN surrogate model API for urban wind flow prediction"
|
| 5 |
+
readme = "README.md"
|
| 6 |
+
requires-python = ">=3.11"
|
| 7 |
+
dependencies = [
|
| 8 |
+
"fastapi==0.115.0",
|
| 9 |
+
"uvicorn[standard]==0.30.0",
|
| 10 |
+
"python-multipart==0.0.9",
|
| 11 |
+
"onnxruntime==1.18.0",
|
| 12 |
+
"pillow==10.4.0",
|
| 13 |
+
"numpy==1.26.4",
|
| 14 |
+
"requests>=2.32.3",
|
| 15 |
+
"slowapi==0.1.9",
|
| 16 |
+
]
|
| 17 |
+
|
| 18 |
+
[build-system]
|
| 19 |
+
requires = ["hatchling"]
|
| 20 |
+
build-backend = "hatchling.build"
|
| 21 |
+
|
| 22 |
+
[tool.uv]
|
| 23 |
+
package = false
|
render.yaml
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
services:
|
| 2 |
+
- type: web
|
| 3 |
+
name: eddy3d-gan-api
|
| 4 |
+
runtime: docker
|
| 5 |
+
dockerfilePath: ./Dockerfile
|
| 6 |
+
region: oregon
|
| 7 |
+
plan: starter
|
| 8 |
+
healthCheckPath: /health
|
| 9 |
+
envVars:
|
| 10 |
+
- key: MODEL_URL
|
| 11 |
+
value: "https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx"
|
| 12 |
+
- key: MODEL_PATH
|
| 13 |
+
value: /app/model.onnx
|
| 14 |
+
- key: MODEL_SHA256
|
| 15 |
+
value: ""
|
| 16 |
+
- key: PORT
|
| 17 |
+
value: "8000"
|
| 18 |
+
- key: LOG_LEVEL
|
| 19 |
+
value: INFO
|
| 20 |
+
- key: ALLOWED_ORIGINS
|
| 21 |
+
value: "*"
|
| 22 |
+
- key: RATE_LIMIT_REQUESTS
|
| 23 |
+
value: "30"
|
| 24 |
+
- key: RATE_LIMIT_WINDOW
|
| 25 |
+
value: minute
|
requirements.txt
DELETED
|
@@ -1,8 +0,0 @@
|
|
| 1 |
-
fastapi
|
| 2 |
-
uvicorn[standard]
|
| 3 |
-
python-multipart
|
| 4 |
-
uvicorn
|
| 5 |
-
onnxruntime
|
| 6 |
-
pillow
|
| 7 |
-
numpy
|
| 8 |
-
requests
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
test.ipynb
DELETED
|
@@ -1,106 +0,0 @@
|
|
| 1 |
-
{
|
| 2 |
-
"cells": [
|
| 3 |
-
{
|
| 4 |
-
"cell_type": "code",
|
| 5 |
-
"execution_count": 13,
|
| 6 |
-
"metadata": {},
|
| 7 |
-
"outputs": [],
|
| 8 |
-
"source": [
|
| 9 |
-
"import requests\n",
|
| 10 |
-
"from PIL import Image\n",
|
| 11 |
-
"import io\n"
|
| 12 |
-
]
|
| 13 |
-
},
|
| 14 |
-
{
|
| 15 |
-
"cell_type": "code",
|
| 16 |
-
"execution_count": 22,
|
| 17 |
-
"metadata": {},
|
| 18 |
-
"outputs": [
|
| 19 |
-
{
|
| 20 |
-
"name": "stdout",
|
| 21 |
-
"output_type": "stream",
|
| 22 |
-
"text": [
|
| 23 |
-
"b'{\"message\":\"API is working\"}'\n"
|
| 24 |
-
]
|
| 25 |
-
}
|
| 26 |
-
],
|
| 27 |
-
"source": [
|
| 28 |
-
"response = requests.get(\"http://gan-230522.eddy3d.com/test\")\n",
|
| 29 |
-
"\n",
|
| 30 |
-
"print(response.content)\n"
|
| 31 |
-
]
|
| 32 |
-
},
|
| 33 |
-
{
|
| 34 |
-
"cell_type": "code",
|
| 35 |
-
"execution_count": 35,
|
| 36 |
-
"metadata": {},
|
| 37 |
-
"outputs": [],
|
| 38 |
-
"source": [
|
| 39 |
-
"# Open your image file in binary mode\n",
|
| 40 |
-
"with open(\"test_input.png\", \"rb\") as f:\n",
|
| 41 |
-
" img_binary = f.read()\n",
|
| 42 |
-
"\n",
|
| 43 |
-
"\n",
|
| 44 |
-
"# Create a file-like object from the image binary data\n",
|
| 45 |
-
"img_file = io.BytesIO(img_binary)\n",
|
| 46 |
-
"\n",
|
| 47 |
-
"# Open the image using PIL\n",
|
| 48 |
-
"img_in = Image.open(img_file)\n",
|
| 49 |
-
"\n",
|
| 50 |
-
"# Display the image\n",
|
| 51 |
-
"img_in.show()\n",
|
| 52 |
-
"\n",
|
| 53 |
-
"\n",
|
| 54 |
-
"# Send the POST request\n",
|
| 55 |
-
"# response = requests.post('http://127.0.0.1:8000/process_image', files={'file': img_binary})\n",
|
| 56 |
-
"response = requests.post(\n",
|
| 57 |
-
" \"https://gan-230522.eddy3d.com/process_image\", files={\"file\": img_binary}\n",
|
| 58 |
-
")\n",
|
| 59 |
-
"\n",
|
| 60 |
-
"\n",
|
| 61 |
-
"# The response is an image file, so we need to convert it back to an image\n",
|
| 62 |
-
"image = Image.open(io.BytesIO(response.content))\n",
|
| 63 |
-
"\n",
|
| 64 |
-
"# Now you can display the image or save it to a file\n",
|
| 65 |
-
"image.show()\n"
|
| 66 |
-
]
|
| 67 |
-
},
|
| 68 |
-
{
|
| 69 |
-
"cell_type": "code",
|
| 70 |
-
"execution_count": 18,
|
| 71 |
-
"metadata": {},
|
| 72 |
-
"outputs": [
|
| 73 |
-
{
|
| 74 |
-
"name": "stdout",
|
| 75 |
-
"output_type": "stream",
|
| 76 |
-
"text": [
|
| 77 |
-
"b'{\"message\":\"API is working\"}'\n"
|
| 78 |
-
]
|
| 79 |
-
}
|
| 80 |
-
],
|
| 81 |
-
"source": []
|
| 82 |
-
}
|
| 83 |
-
],
|
| 84 |
-
"metadata": {
|
| 85 |
-
"kernelspec": {
|
| 86 |
-
"display_name": "ganapi",
|
| 87 |
-
"language": "python",
|
| 88 |
-
"name": "python3"
|
| 89 |
-
},
|
| 90 |
-
"language_info": {
|
| 91 |
-
"codemirror_mode": {
|
| 92 |
-
"name": "ipython",
|
| 93 |
-
"version": 3
|
| 94 |
-
},
|
| 95 |
-
"file_extension": ".py",
|
| 96 |
-
"mimetype": "text/x-python",
|
| 97 |
-
"name": "python",
|
| 98 |
-
"nbconvert_exporter": "python",
|
| 99 |
-
"pygments_lexer": "ipython3",
|
| 100 |
-
"version": "3.9.16"
|
| 101 |
-
},
|
| 102 |
-
"orig_nbformat": 4
|
| 103 |
-
},
|
| 104 |
-
"nbformat": 4,
|
| 105 |
-
"nbformat_minor": 2
|
| 106 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
uv.lock
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|