Patrick Kastner commited on
Commit
bf2cbd5
·
1 Parent(s): 4be65d1

Add root endpoint for base URL health checks

Fallback to default model URL when MODEL_URL is unset

.

.

Update Dockerfile

Files changed (11) hide show
  1. .dockerignore +6 -0
  2. .gitignore +1 -0
  3. Dockerfile +17 -0
  4. README.md +94 -3
  5. api.py +456 -34
  6. download_model.py +58 -0
  7. pyproject.toml +23 -0
  8. render.yaml +25 -0
  9. requirements.txt +0 -8
  10. test.ipynb +0 -106
  11. uv.lock +0 -0
.dockerignore ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ .git
2
+ .venv
3
+ __pycache__/
4
+ *.pyc
5
+ *.pyo
6
+ .DS_Store
.gitignore CHANGED
@@ -154,3 +154,4 @@ cython_debug/
154
  .space
155
  .onnx
156
  model.onnx
 
 
154
  .space
155
  .onnx
156
  model.onnx
157
+ .DS_Store
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.11-slim
2
+
3
+ # Install uv
4
+ COPY --from=ghcr.io/astral-sh/uv:latest /uv /uvx /bin/
5
+
6
+ WORKDIR /app
7
+
8
+ # Install dependencies first for layer caching
9
+ COPY pyproject.toml uv.lock ./
10
+ RUN uv sync --frozen --no-dev --no-install-project
11
+
12
+ COPY . .
13
+
14
+ # Model is downloaded at startup via MODEL_URL env var
15
+ CMD ["sh", "-c", "uv run python download_model.py && uv run uvicorn api:app --host 0.0.0.0 --port ${PORT:-8000}"]
16
+
17
+ EXPOSE 8000
README.md CHANGED
@@ -1,4 +1,95 @@
1
- # GAN-API
2
 
3
- http://gan-230522.eddy3d.com/test
4
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Eddy3D GAN Wind Prediction API
2
 
3
+ FastAPI service that serves the GAN surrogate model for urban pedestrian-level wind flow prediction.
4
+
5
+ ## ONNX Model Hosting
6
+
7
+ The ONNX model file (`GAN-21-05-2023-23-Generative.onnx`, ~208 MB) is **not included** in this repository. The container downloads it at startup from a URL you provide.
8
+
9
+ The model is hosted on Hugging Face:
10
+
11
+ ```
12
+ https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx
13
+ ```
14
+
15
+ `MODEL_URL` defaults to this URL if not provided, but you can override it via environment variable.
16
+
17
+ ### Optional: SHA-256 Integrity Check
18
+
19
+ Generate a checksum and set `MODEL_SHA256` to verify downloads:
20
+
21
+ ```bash
22
+ shasum -a 256 GAN-21-05-2023-23-Generative.onnx
23
+ ```
24
+
25
+ ## Local Development
26
+
27
+ ```bash
28
+ # Install uv (if not already installed)
29
+ curl -LsSf https://astral.sh/uv/install.sh | sh
30
+
31
+ # Install dependencies
32
+ uv sync
33
+
34
+ # Place your model file
35
+ cp /path/to/GAN-21-05-2023-23-Generative.onnx model.onnx
36
+
37
+ # Run the server
38
+ uv run uvicorn api:app --reload --port 8000
39
+ ```
40
+
41
+ ## Deploying to Render
42
+
43
+ 1. Push this repo to GitHub
44
+ 2. Create a new **Web Service** on [Render](https://render.com)
45
+ 3. Connect the GitHub repo
46
+ 4. Render will auto-detect the `Dockerfile`
47
+ 5. Set the following environment variables in the Render dashboard:
48
+
49
+ | Variable | Required | Example |
50
+ |----------|----------|---------|
51
+ | `MODEL_URL` | No | `https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx` |
52
+ | `MODEL_SHA256` | No | `abc123...` (hex digest for integrity check) |
53
+ | `RATE_LIMIT_REQUESTS` | No | `30` (default: 30 requests per window) |
54
+ | `RATE_LIMIT_WINDOW` | No | `minute` (options: second, minute, hour, day) |
55
+ | `ALLOWED_ORIGINS` | No | `*` (comma-separated CORS origins) |
56
+ | `LOG_LEVEL` | No | `INFO` |
57
+
58
+ Alternatively, use `render.yaml` for infrastructure-as-code deployment (Blueprint).
59
+
60
+ ## API Endpoints
61
+
62
+ | Method | Path | Description |
63
+ |--------|------|-------------|
64
+ | GET | `/health` | Health check (not rate limited) |
65
+ | POST | `/predict` | Image upload → wind speeds JSON + output image (base64) |
66
+ | POST | `/predict_array` | Raw float array → wind speeds JSON + output image (base64) |
67
+ | POST | `/predict_image` | Image upload → output PNG stream (legacy) |
68
+
69
+ ### `/predict_array` (primary endpoint)
70
+
71
+ Accepts a JSON body with a flat float array of 786,432 values (3 x 512 x 512), channel-first order (R, G, B), normalised to [-1, 1].
72
+
73
+ ```json
74
+ {
75
+ "data": [0.1, -0.5, ...]
76
+ }
77
+ ```
78
+
79
+ Returns:
80
+
81
+ ```json
82
+ {
83
+ "wind_speeds": [0.5, 1.2, ...],
84
+ "image_base64": "iVBORw0KGgo...",
85
+ "width": 512,
86
+ "height": 512
87
+ }
88
+ ```
89
+
90
+ ## Docker
91
+
92
+ ```bash
93
+ docker build -t eddy3d-gan-api .
94
+ docker run -p 8000:8000 -e MODEL_URL="https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx" eddy3d-gan-api
95
+ ```
api.py CHANGED
@@ -1,51 +1,473 @@
1
- from fastapi import FastAPI, File, UploadFile, responses
2
- from PIL import Image
 
 
 
 
 
 
3
  import io
 
 
 
 
 
 
4
  import numpy as np
 
 
 
 
 
5
  import onnxruntime as rt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- app = FastAPI()
 
 
 
8
 
9
- sess = rt.InferenceSession("model.onnx")
 
 
10
 
11
- @app.get("/test")
12
- async def test():
13
- return {"message": "API is working"}
14
 
15
- @app.api_route("/dummy", methods=["POST"])
16
- async def dummy_function(file: UploadFile = File(...)):
17
- print("Dummy function called")
18
- # Perform some dummy processing or return a dummy response
19
- return {"message": "Dummy function called"}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
 
21
- #@app.api_route("/process_image", methods=["POST"])
22
- @app.post("/process_image")
23
- async def process_image(file: UploadFile = File(...)):
24
  input_data = await file.read()
25
- input_image = Image.open(io.BytesIO(input_data)).convert("RGB")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
- # Resize and convert to numpy array
28
- input_array = np.array(input_image.resize((512, 512)))
 
 
29
 
30
- # Normalize to [-1, 1]
31
- input_array = (input_array.astype(np.float32) / 127.5) - 1
32
- input_array = np.expand_dims(input_array, axis=0)
33
 
34
- # Reorder dimensions to match model's expectations
35
- input_array = input_array.transpose((0, 3, 1, 2)) # Move the color dimension to the correct position
 
 
 
 
 
 
 
 
 
 
36
 
37
- inputs = {sess.get_inputs()[0].name: input_array}
38
- outputs = sess.run(None, inputs)
 
 
 
 
 
 
 
39
 
40
- # Denormalize from [-1, 1] to [0, 255]
41
- output_array = (outputs[0][0].transpose((1, 2, 0)) + 1) * 127.5
42
- output_array = np.clip(output_array, 0, 255) # Clip values outside of valid range
43
 
44
- # Convert back to image
45
- output_image = Image.fromarray(output_array.astype(np.uint8), "RGB")
 
 
46
 
47
- byte_arr = io.BytesIO()
48
- output_image.save(byte_arr, format='PNG')
49
- byte_arr.seek(0) # Move cursor to the beginning of the file
50
 
51
- return responses.StreamingResponse(byte_arr, media_type="image/png")
 
 
 
 
 
 
1
+ """
2
+ Eddy3D GAN Wind Prediction API
3
+
4
+ Serves the GAN surrogate model for urban wind flow prediction.
5
+ Accepts 512x512 input images or raw float arrays and returns predicted wind speeds.
6
+ """
7
+
8
+ import os
9
  import io
10
+ import base64
11
+ import logging
12
+ import time
13
+ from contextlib import asynccontextmanager
14
+ from typing import Optional
15
+
16
  import numpy as np
17
+ from PIL import Image
18
+ from fastapi import FastAPI, File, Request, UploadFile, HTTPException
19
+ from fastapi.middleware.cors import CORSMiddleware
20
+ from fastapi.responses import JSONResponse, StreamingResponse
21
+ from pydantic import BaseModel
22
  import onnxruntime as rt
23
+ from slowapi import Limiter, _rate_limit_exceeded_handler
24
+ from slowapi.util import get_remote_address
25
+ from slowapi.errors import RateLimitExceeded
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Configuration
29
+ # ---------------------------------------------------------------------------
30
+ MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")
31
+ PORT = int(os.environ.get("PORT", "8000"))
32
+ LOG_LEVEL = os.environ.get("LOG_LEVEL", "INFO").upper()
33
+ ALLOWED_ORIGINS = os.environ.get("ALLOWED_ORIGINS", "*").split(",")
34
+ RATE_LIMIT_REQUESTS = os.environ.get("RATE_LIMIT_REQUESTS", "30")
35
+ RATE_LIMIT_WINDOW = os.environ.get("RATE_LIMIT_WINDOW", "minute")
36
+ MAX_UPLOAD_BYTES = 10 * 1024 * 1024 # 10 MB
37
+
38
+ logging.basicConfig(level=LOG_LEVEL, format="%(asctime)s %(levelname)s %(message)s")
39
+ logger = logging.getLogger(__name__)
40
+
41
+ # ---------------------------------------------------------------------------
42
+ # Turbo colormap (256 entries, RGB in [0, 1])
43
+ # Used for reverse-mapping predicted pixel colours to wind speed values.
44
+ # ---------------------------------------------------------------------------
45
+ TURBO_COLORMAP = np.array([
46
+ [0.18995, 0.07176, 0.23217], [0.19483, 0.08339, 0.26149],
47
+ [0.19956, 0.09498, 0.29024], [0.20415, 0.10652, 0.31844],
48
+ [0.20860, 0.11802, 0.34607], [0.21291, 0.12947, 0.37314],
49
+ [0.21708, 0.14087, 0.39964], [0.22111, 0.15223, 0.42558],
50
+ [0.22500, 0.16354, 0.45096], [0.22875, 0.17481, 0.47578],
51
+ [0.23236, 0.18603, 0.50004], [0.23582, 0.19720, 0.52373],
52
+ [0.23915, 0.20833, 0.54686], [0.24234, 0.21941, 0.56942],
53
+ [0.24539, 0.23044, 0.59142], [0.24830, 0.24143, 0.61286],
54
+ [0.25107, 0.25237, 0.63374], [0.25369, 0.26327, 0.65406],
55
+ [0.25618, 0.27412, 0.67381], [0.25853, 0.28492, 0.69300],
56
+ [0.26074, 0.29568, 0.71162], [0.26280, 0.30639, 0.72968],
57
+ [0.26473, 0.31706, 0.74718], [0.26652, 0.32768, 0.76412],
58
+ [0.26816, 0.33825, 0.78050], [0.26967, 0.34878, 0.79631],
59
+ [0.27103, 0.35926, 0.81156], [0.27226, 0.36970, 0.82624],
60
+ [0.27334, 0.38008, 0.84037], [0.27429, 0.39043, 0.85393],
61
+ [0.27509, 0.40072, 0.86692], [0.27576, 0.41097, 0.87936],
62
+ [0.27628, 0.42118, 0.89123], [0.27667, 0.43134, 0.90254],
63
+ [0.27691, 0.44145, 0.91328], [0.27701, 0.45152, 0.92347],
64
+ [0.27698, 0.46153, 0.93309], [0.27680, 0.47151, 0.94214],
65
+ [0.27648, 0.48144, 0.95064], [0.27603, 0.49132, 0.95857],
66
+ [0.27543, 0.50115, 0.96594], [0.27469, 0.51094, 0.97275],
67
+ [0.27381, 0.52069, 0.97899], [0.27273, 0.53040, 0.98461],
68
+ [0.27106, 0.54015, 0.98930], [0.26878, 0.54995, 0.99303],
69
+ [0.26592, 0.55979, 0.99583], [0.26252, 0.56967, 0.99773],
70
+ [0.25862, 0.57958, 0.99876], [0.25425, 0.58950, 0.99896],
71
+ [0.24946, 0.59943, 0.99835], [0.24427, 0.60937, 0.99697],
72
+ [0.23874, 0.61931, 0.99485], [0.23288, 0.62923, 0.99202],
73
+ [0.22676, 0.63913, 0.98851], [0.22039, 0.64901, 0.98436],
74
+ [0.21382, 0.65886, 0.97959], [0.20708, 0.66866, 0.97423],
75
+ [0.20021, 0.67842, 0.96833], [0.19326, 0.68812, 0.96190],
76
+ [0.18625, 0.69775, 0.95498], [0.17923, 0.70732, 0.94761],
77
+ [0.17223, 0.71680, 0.93981], [0.16529, 0.72620, 0.93161],
78
+ [0.15844, 0.73551, 0.92305], [0.15173, 0.74472, 0.91416],
79
+ [0.14519, 0.75381, 0.90496], [0.13886, 0.76279, 0.89550],
80
+ [0.13278, 0.77165, 0.88580], [0.12698, 0.78037, 0.87590],
81
+ [0.12151, 0.78896, 0.86581], [0.11639, 0.79740, 0.85559],
82
+ [0.11167, 0.80569, 0.84525], [0.10738, 0.81381, 0.83484],
83
+ [0.10357, 0.82177, 0.82437], [0.10026, 0.82955, 0.81389],
84
+ [0.09750, 0.83714, 0.80342], [0.09532, 0.84455, 0.79299],
85
+ [0.09377, 0.85175, 0.78264], [0.09287, 0.85875, 0.77240],
86
+ [0.09267, 0.86554, 0.76230], [0.09320, 0.87211, 0.75237],
87
+ [0.09451, 0.87844, 0.74265], [0.09662, 0.88454, 0.73316],
88
+ [0.09958, 0.89040, 0.72393], [0.10342, 0.89600, 0.71500],
89
+ [0.10815, 0.90142, 0.70599], [0.11374, 0.90673, 0.69651],
90
+ [0.12014, 0.91193, 0.68660], [0.12733, 0.91701, 0.67627],
91
+ [0.13526, 0.92197, 0.66556], [0.14391, 0.92680, 0.65448],
92
+ [0.15323, 0.93151, 0.64308], [0.16319, 0.93609, 0.63137],
93
+ [0.17377, 0.94053, 0.61938], [0.18491, 0.94484, 0.60713],
94
+ [0.19659, 0.94901, 0.59466], [0.20877, 0.95304, 0.58199],
95
+ [0.22142, 0.95692, 0.56914], [0.23449, 0.96065, 0.55614],
96
+ [0.24797, 0.96423, 0.54303], [0.26180, 0.96765, 0.52981],
97
+ [0.27597, 0.97092, 0.51653], [0.29042, 0.97403, 0.50321],
98
+ [0.30513, 0.97697, 0.48987], [0.32006, 0.97974, 0.47654],
99
+ [0.33517, 0.98234, 0.46325], [0.35043, 0.98477, 0.45002],
100
+ [0.36581, 0.98702, 0.43688], [0.38127, 0.98909, 0.42386],
101
+ [0.39678, 0.99098, 0.41098], [0.41229, 0.99268, 0.39826],
102
+ [0.42778, 0.99419, 0.38575], [0.44321, 0.99551, 0.37345],
103
+ [0.45854, 0.99663, 0.36140], [0.47375, 0.99755, 0.34963],
104
+ [0.48879, 0.99828, 0.33816], [0.50362, 0.99879, 0.32701],
105
+ [0.51822, 0.99910, 0.31622], [0.53255, 0.99919, 0.30581],
106
+ [0.54658, 0.99907, 0.29581], [0.56026, 0.99873, 0.28623],
107
+ [0.57357, 0.99817, 0.27712], [0.58646, 0.99739, 0.26849],
108
+ [0.59891, 0.99638, 0.26038], [0.61088, 0.99514, 0.25280],
109
+ [0.62233, 0.99366, 0.24579], [0.63323, 0.99195, 0.23937],
110
+ [0.64362, 0.98999, 0.23356], [0.65394, 0.98775, 0.22835],
111
+ [0.66428, 0.98524, 0.22370], [0.67462, 0.98246, 0.21960],
112
+ [0.68494, 0.97941, 0.21602], [0.69525, 0.97610, 0.21294],
113
+ [0.70553, 0.97255, 0.21032], [0.71577, 0.96875, 0.20815],
114
+ [0.72596, 0.96470, 0.20640], [0.73610, 0.96043, 0.20504],
115
+ [0.74617, 0.95593, 0.20406], [0.75617, 0.95121, 0.20343],
116
+ [0.76608, 0.94627, 0.20311], [0.77591, 0.94113, 0.20310],
117
+ [0.78563, 0.93579, 0.20336], [0.79524, 0.93025, 0.20386],
118
+ [0.80473, 0.92452, 0.20459], [0.81410, 0.91861, 0.20552],
119
+ [0.82333, 0.91253, 0.20663], [0.83241, 0.90627, 0.20788],
120
+ [0.84133, 0.89986, 0.20926], [0.85010, 0.89328, 0.21074],
121
+ [0.85868, 0.88655, 0.21230], [0.86709, 0.87968, 0.21391],
122
+ [0.87530, 0.87267, 0.21555], [0.88331, 0.86553, 0.21719],
123
+ [0.89112, 0.85826, 0.21880], [0.89870, 0.85087, 0.22038],
124
+ [0.90605, 0.84337, 0.22188], [0.91317, 0.83576, 0.22328],
125
+ [0.92004, 0.82806, 0.22456], [0.92666, 0.82025, 0.22570],
126
+ [0.93301, 0.81236, 0.22667], [0.93909, 0.80439, 0.22744],
127
+ [0.94489, 0.79634, 0.22800], [0.95039, 0.78823, 0.22831],
128
+ [0.95560, 0.78005, 0.22836], [0.96049, 0.77181, 0.22811],
129
+ [0.96507, 0.76352, 0.22754], [0.96931, 0.75519, 0.22663],
130
+ [0.97323, 0.74682, 0.22536], [0.97679, 0.73842, 0.22369],
131
+ [0.98000, 0.73000, 0.22161], [0.98289, 0.72140, 0.21918],
132
+ [0.98549, 0.71250, 0.21650], [0.98781, 0.70330, 0.21358],
133
+ [0.98986, 0.69382, 0.21043], [0.99163, 0.68408, 0.20706],
134
+ [0.99314, 0.67408, 0.20348], [0.99438, 0.66386, 0.19971],
135
+ [0.99535, 0.65341, 0.19577], [0.99607, 0.64277, 0.19165],
136
+ [0.99654, 0.63193, 0.18738], [0.99675, 0.62093, 0.18297],
137
+ [0.99672, 0.60977, 0.17842], [0.99644, 0.59846, 0.17376],
138
+ [0.99593, 0.58703, 0.16899], [0.99517, 0.57549, 0.16412],
139
+ [0.99419, 0.56386, 0.15918], [0.99297, 0.55214, 0.15417],
140
+ [0.99153, 0.54036, 0.14910], [0.98987, 0.52854, 0.14398],
141
+ [0.98799, 0.51667, 0.13883], [0.98590, 0.50479, 0.13367],
142
+ [0.98360, 0.49291, 0.12849], [0.98108, 0.48104, 0.12332],
143
+ [0.97837, 0.46920, 0.11817], [0.97545, 0.45740, 0.11305],
144
+ [0.97234, 0.44565, 0.10797], [0.96904, 0.43399, 0.10294],
145
+ [0.96555, 0.42241, 0.09798], [0.96187, 0.41093, 0.09310],
146
+ [0.95801, 0.39958, 0.08831], [0.95398, 0.38836, 0.08362],
147
+ [0.94977, 0.37729, 0.07905], [0.94538, 0.36638, 0.07461],
148
+ [0.94084, 0.35566, 0.07031], [0.93612, 0.34513, 0.06616],
149
+ [0.93125, 0.33482, 0.06218], [0.92623, 0.32473, 0.05837],
150
+ [0.92105, 0.31489, 0.05475], [0.91572, 0.30530, 0.05134],
151
+ [0.91024, 0.29599, 0.04814], [0.90463, 0.28696, 0.04516],
152
+ [0.89888, 0.27824, 0.04243], [0.89298, 0.26981, 0.03993],
153
+ [0.88691, 0.26152, 0.03753], [0.88066, 0.25334, 0.03521],
154
+ [0.87422, 0.24526, 0.03297], [0.86760, 0.23730, 0.03082],
155
+ [0.86079, 0.22945, 0.02875], [0.85380, 0.22170, 0.02677],
156
+ [0.84662, 0.21407, 0.02487], [0.83926, 0.20654, 0.02305],
157
+ [0.83172, 0.19912, 0.02131], [0.82399, 0.19182, 0.01966],
158
+ [0.81608, 0.18462, 0.01809], [0.80799, 0.17753, 0.01660],
159
+ [0.79971, 0.17055, 0.01520], [0.79125, 0.16368, 0.01387],
160
+ [0.78260, 0.15693, 0.01264], [0.77377, 0.15028, 0.01148],
161
+ [0.76476, 0.14374, 0.01041], [0.75556, 0.13731, 0.00942],
162
+ [0.74617, 0.13098, 0.00851], [0.73661, 0.12477, 0.00769],
163
+ [0.72686, 0.11867, 0.00695], [0.71692, 0.11268, 0.00629],
164
+ [0.70680, 0.10680, 0.00571], [0.69650, 0.10102, 0.00522],
165
+ [0.68602, 0.09536, 0.00481], [0.67535, 0.08980, 0.00449],
166
+ [0.66449, 0.08436, 0.00424], [0.65345, 0.07902, 0.00408],
167
+ [0.64223, 0.07380, 0.00401], [0.63082, 0.06868, 0.00401],
168
+ [0.61923, 0.06367, 0.00410], [0.60746, 0.05878, 0.00427],
169
+ [0.59550, 0.05399, 0.00453], [0.58336, 0.04931, 0.00486],
170
+ [0.57103, 0.04474, 0.00529], [0.55852, 0.04028, 0.00579],
171
+ [0.54583, 0.03593, 0.00638], [0.53295, 0.03169, 0.00705],
172
+ [0.51989, 0.02756, 0.00780], [0.50664, 0.02354, 0.00863],
173
+ [0.49321, 0.01963, 0.00955], [0.47960, 0.01583, 0.01055],
174
+ ], dtype=np.float64) # shape (256, 3)
175
+
176
+
177
+ def _color_to_windspeed(raw_output: np.ndarray) -> list[float]:
178
+ """Map raw model output to wind speed values using Turbo colormap reverse-lookup.
179
+
180
+ Ported from Prediction.ColorToNumber in the decompiled CFDComponent.
181
+
182
+ Args:
183
+ raw_output: Model output array of shape (3, H, W) with values in [-1, 1].
184
+
185
+ Returns:
186
+ Flat list of wind speed values (0-15 m/s), length H*W.
187
+ """
188
+ _, h, w = raw_output.shape
189
+ n_colors = TURBO_COLORMAP.shape[0]
190
+
191
+ # Convert from [-1, 1] to [0, 1]
192
+ r = (raw_output[0] + 1.0) / 2.0 # (H, W)
193
+ g = (raw_output[1] + 1.0) / 2.0
194
+ b = (raw_output[2] + 1.0) / 2.0
195
+
196
+ # Stack into (H*W, 3)
197
+ pixels = np.stack([r.ravel(), g.ravel(), b.ravel()], axis=1) # (H*W, 3)
198
+
199
+ # Find closest colormap entry for each pixel via Euclidean distance
200
+ # Using broadcasting: pixels (H*W, 1, 3) - colormap (1, 256, 3) -> (H*W, 256, 3)
201
+ diff = pixels[:, np.newaxis, :] - TURBO_COLORMAP[np.newaxis, :, :]
202
+ dists = np.sum(diff * diff, axis=2) # (H*W, 256)
203
+ indices = np.argmin(dists, axis=1) # (H*W,)
204
+
205
+ # Map index to wind speed: index / n_colors * 15.0
206
+ wind_speeds = indices.astype(np.float64) / n_colors * 15.0
207
+
208
+ return wind_speeds.tolist()
209
+
210
+
211
+ # ---------------------------------------------------------------------------
212
+ # Rate limiting
213
+ # ---------------------------------------------------------------------------
214
+ _rate_limit_str = f"{RATE_LIMIT_REQUESTS}/{RATE_LIMIT_WINDOW}"
215
+ _rate_limiting_enabled = int(RATE_LIMIT_REQUESTS) > 0
216
+
217
+ limiter = Limiter(
218
+ key_func=get_remote_address,
219
+ default_limits=[_rate_limit_str] if _rate_limiting_enabled else [],
220
+ enabled=_rate_limiting_enabled,
221
+ )
222
+
223
+ # ---------------------------------------------------------------------------
224
+ # Application lifecycle
225
+ # ---------------------------------------------------------------------------
226
+ @asynccontextmanager
227
+ async def lifespan(app: FastAPI):
228
+ """Load ONNX model at startup, release at shutdown."""
229
+ logger.info("Loading ONNX model from %s ...", MODEL_PATH)
230
+ if not os.path.exists(MODEL_PATH):
231
+ logger.error("Model file not found at %s", MODEL_PATH)
232
+ raise FileNotFoundError(f"Model file not found at {MODEL_PATH}")
233
+
234
+ session = rt.InferenceSession(MODEL_PATH)
235
+ input_name = session.get_inputs()[0].name
236
+ input_shape = session.get_inputs()[0].shape
237
+ output_name = session.get_outputs()[0].name
238
+ output_shape = session.get_outputs()[0].shape
239
+ logger.info("Model loaded. Input: %s %s Output: %s %s",
240
+ input_name, input_shape, output_name, output_shape)
241
+
242
+ app.state.session = session
243
+ app.state.input_name = input_name
244
+ app.state.model_loaded = True
245
+
246
+ yield
247
+
248
+ app.state.model_loaded = False
249
+ logger.info("Shutting down.")
250
+
251
+
252
+ app = FastAPI(
253
+ title="Eddy3D GAN Wind Prediction API",
254
+ version="2.0.0",
255
+ lifespan=lifespan,
256
+ )
257
+
258
+ app.state.limiter = limiter
259
+ app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
260
+
261
+ app.add_middleware(
262
+ CORSMiddleware,
263
+ allow_origins=ALLOWED_ORIGINS,
264
+ allow_credentials=True,
265
+ allow_methods=["*"],
266
+ allow_headers=["*"],
267
+ )
268
+
269
+
270
+ # ---------------------------------------------------------------------------
271
+ # Helper: run inference
272
+ # ---------------------------------------------------------------------------
273
+ def _run_inference(image_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
274
+ """Preprocess image, run model, return (raw_output, denormalized_output).
275
+
276
+ Returns:
277
+ raw_output: shape (3, 512, 512), values in [-1, 1]
278
+ denorm_output: shape (512, 512, 3), values in [0, 255], uint8
279
+ """
280
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
281
+ input_array = np.array(image.resize((512, 512)), dtype=np.float32)
282
 
283
+ # Normalize [0, 255] -> [-1, 1]
284
+ input_array = (input_array / 127.5) - 1.0
285
+ input_array = np.expand_dims(input_array, axis=0) # (1, 512, 512, 3)
286
+ input_array = input_array.transpose((0, 3, 1, 2)) # (1, 3, 512, 512)
287
 
288
+ session = app.state.session
289
+ inputs = {app.state.input_name: input_array}
290
+ outputs = session.run(None, inputs)
291
 
292
+ raw = outputs[0][0] # (3, 512, 512)
 
 
293
 
294
+ # Denormalise for image output
295
+ denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5 # (512, 512, 3)
296
+ denorm = np.clip(denorm, 0, 255).astype(np.uint8)
297
+
298
+ return raw, denorm
299
+
300
+
301
+ def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
302
+ """Run model from a pre-normalised float32 array.
303
+
304
+ Args:
305
+ data: shape (3, 512, 512), float32, values already in [-1, 1].
306
+
307
+ Returns:
308
+ raw_output: shape (3, 512, 512), values in [-1, 1]
309
+ denorm_output: shape (512, 512, 3), values in [0, 255], uint8
310
+ """
311
+ input_array = np.expand_dims(data, axis=0).astype(np.float32) # (1, 3, 512, 512)
312
+
313
+ session = app.state.session
314
+ inputs = {app.state.input_name: input_array}
315
+ outputs = session.run(None, inputs)
316
+
317
+ raw = outputs[0][0] # (3, 512, 512)
318
+
319
+ denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5
320
+ denorm = np.clip(denorm, 0, 255).astype(np.uint8)
321
+
322
+ return raw, denorm
323
+
324
+
325
+ class ArrayPredictRequest(BaseModel):
326
+ """Request body for /predict_array.
327
+
328
+ The data field is a flat list of 786432 floats (3 * 512 * 512) in
329
+ channel-first order (R, G, B), with values already normalised to [-1, 1].
330
+ """
331
+ data: list[float]
332
+
333
+
334
+ # ---------------------------------------------------------------------------
335
+ # Endpoints
336
+ # ---------------------------------------------------------------------------
337
+ @app.get("/")
338
+ async def root():
339
+ """Basic service info for base URL checks."""
340
+ return {
341
+ "service": "Eddy3D GAN Wind Prediction API",
342
+ "status": "ok",
343
+ "endpoints": ["/health", "/predict", "/predict_array", "/predict_image"],
344
+ }
345
+
346
+
347
+ @app.get("/health")
348
+ async def health():
349
+ """Health check (not rate limited)."""
350
+ return {
351
+ "status": "healthy",
352
+ "model_loaded": getattr(app.state, "model_loaded", False),
353
+ }
354
+
355
+
356
+ @app.post("/predict")
357
+ @limiter.limit(_rate_limit_str)
358
+ async def predict(request: Request, file: UploadFile = File(...)):
359
+ """Run GAN inference and return wind speeds + output image.
360
+
361
+ Returns JSON:
362
+ wind_speeds: list of floats (length 262144), values 0–15 m/s
363
+ image_base64: base64-encoded PNG of the output image
364
+ width: 512
365
+ height: 512
366
+ """
367
+ # Validate content type
368
+ if file.content_type and not file.content_type.startswith("image/"):
369
+ raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
370
+
371
+ input_data = await file.read()
372
+ if len(input_data) > MAX_UPLOAD_BYTES:
373
+ raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
374
+
375
+ t0 = time.perf_counter()
376
+ try:
377
+ raw, denorm = _run_inference(input_data)
378
+ except Exception as e:
379
+ logger.exception("Inference failed")
380
+ raise HTTPException(status_code=500, detail=f"Inference error: {e}")
381
+
382
+ # Wind speeds from raw output
383
+ wind_speeds = _color_to_windspeed(raw)
384
+
385
+ # Encode output image as base64 PNG
386
+ output_image = Image.fromarray(denorm, "RGB")
387
+ buf = io.BytesIO()
388
+ output_image.save(buf, format="PNG")
389
+ image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
390
+
391
+ elapsed = time.perf_counter() - t0
392
+ logger.info("Inference completed in %.2f s", elapsed)
393
+
394
+ return JSONResponse({
395
+ "wind_speeds": wind_speeds,
396
+ "image_base64": image_b64,
397
+ "width": 512,
398
+ "height": 512,
399
+ })
400
+
401
+
402
+ @app.post("/predict_image")
403
+ @limiter.limit(_rate_limit_str)
404
+ async def predict_image(request: Request, file: UploadFile = File(...)):
405
+ """Legacy endpoint: returns the output image as a PNG stream."""
406
+ if file.content_type and not file.content_type.startswith("image/"):
407
+ raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
408
 
 
 
 
409
  input_data = await file.read()
410
+ if len(input_data) > MAX_UPLOAD_BYTES:
411
+ raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
412
+
413
+ try:
414
+ _, denorm = _run_inference(input_data)
415
+ except Exception as e:
416
+ logger.exception("Inference failed")
417
+ raise HTTPException(status_code=500, detail=f"Inference error: {e}")
418
+
419
+ output_image = Image.fromarray(denorm, "RGB")
420
+ buf = io.BytesIO()
421
+ output_image.save(buf, format="PNG")
422
+ buf.seek(0)
423
+
424
+ return StreamingResponse(buf, media_type="image/png")
425
+
426
 
427
+ @app.post("/predict_array")
428
+ @limiter.limit(_rate_limit_str)
429
+ async def predict_array(request: Request, body: ArrayPredictRequest):
430
+ """Run GAN inference from a raw float array (no image encoding needed).
431
 
432
+ Accepts a flat list of 786432 floats (3 * 512 * 512) in channel-first
433
+ order, already normalised to [-1, 1].
 
434
 
435
+ Returns JSON:
436
+ wind_speeds: list of floats (length 262144), values 0-15 m/s
437
+ image_base64: base64-encoded PNG of the output image
438
+ width: 512
439
+ height: 512
440
+ """
441
+ expected = 3 * 512 * 512
442
+ if len(body.data) != expected:
443
+ raise HTTPException(
444
+ status_code=400,
445
+ detail=f"Expected {expected} floats (3*512*512), got {len(body.data)}.",
446
+ )
447
 
448
+ t0 = time.perf_counter()
449
+ try:
450
+ arr = np.array(body.data, dtype=np.float32).reshape((3, 512, 512))
451
+ raw, denorm = _run_inference_from_array(arr)
452
+ except HTTPException:
453
+ raise
454
+ except Exception as e:
455
+ logger.exception("Inference failed")
456
+ raise HTTPException(status_code=500, detail=f"Inference error: {e}")
457
 
458
+ wind_speeds = _color_to_windspeed(raw)
 
 
459
 
460
+ output_image = Image.fromarray(denorm, "RGB")
461
+ buf = io.BytesIO()
462
+ output_image.save(buf, format="PNG")
463
+ image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
464
 
465
+ elapsed = time.perf_counter() - t0
466
+ logger.info("Array inference completed in %.2f s", elapsed)
 
467
 
468
+ return JSONResponse({
469
+ "wind_speeds": wind_speeds,
470
+ "image_base64": image_b64,
471
+ "width": 512,
472
+ "height": 512,
473
+ })
download_model.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Download ONNX model from cloud storage if not already present."""
2
+
3
+ import hashlib
4
+ import os
5
+ import sys
6
+
7
+ import requests
8
+
9
+ DEFAULT_MODEL_URL = (
10
+ "https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/"
11
+ "GAN-21-05-2023-23-Generative.onnx"
12
+ )
13
+
14
+ # Allow empty env var values to fall back to the default hosted model.
15
+ MODEL_URL = os.environ.get("MODEL_URL") or DEFAULT_MODEL_URL
16
+ MODEL_PATH = os.environ.get("MODEL_PATH", "model.onnx")
17
+ MODEL_SHA256 = os.environ.get("MODEL_SHA256", "")
18
+
19
+
20
+ def download() -> None:
21
+ if os.path.exists(MODEL_PATH):
22
+ print(f"Model already exists at {MODEL_PATH}")
23
+ return
24
+
25
+ print(f"Downloading model from {MODEL_URL} ...")
26
+ response = requests.get(MODEL_URL, stream=True, timeout=600)
27
+ response.raise_for_status()
28
+
29
+ os.makedirs(os.path.dirname(MODEL_PATH) or ".", exist_ok=True)
30
+
31
+ total = int(response.headers.get("content-length", 0))
32
+ downloaded = 0
33
+ with open(MODEL_PATH, "wb") as f:
34
+ for chunk in response.iter_content(chunk_size=8192):
35
+ f.write(chunk)
36
+ downloaded += len(chunk)
37
+ if total:
38
+ pct = downloaded / total * 100
39
+ print(f"\r {downloaded / 1e6:.1f} / {total / 1e6:.1f} MB ({pct:.0f}%)", end="", flush=True)
40
+ print()
41
+
42
+ # Optional SHA-256 integrity check
43
+ if MODEL_SHA256:
44
+ h = hashlib.sha256()
45
+ with open(MODEL_PATH, "rb") as f:
46
+ for chunk in iter(lambda: f.read(8192), b""):
47
+ h.update(chunk)
48
+ if h.hexdigest() != MODEL_SHA256:
49
+ os.remove(MODEL_PATH)
50
+ print(f"ERROR: SHA-256 mismatch! Expected {MODEL_SHA256}, got {h.hexdigest()}", file=sys.stderr)
51
+ sys.exit(1)
52
+ print("SHA-256 verified.")
53
+
54
+ print(f"Model downloaded to {MODEL_PATH}")
55
+
56
+
57
+ if __name__ == "__main__":
58
+ download()
pyproject.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "eddy3d-gan-api"
3
+ version = "2.0.0"
4
+ description = "GAN surrogate model API for urban wind flow prediction"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "fastapi==0.115.0",
9
+ "uvicorn[standard]==0.30.0",
10
+ "python-multipart==0.0.9",
11
+ "onnxruntime==1.18.0",
12
+ "pillow==10.4.0",
13
+ "numpy==1.26.4",
14
+ "requests>=2.32.3",
15
+ "slowapi==0.1.9",
16
+ ]
17
+
18
+ [build-system]
19
+ requires = ["hatchling"]
20
+ build-backend = "hatchling.build"
21
+
22
+ [tool.uv]
23
+ package = false
render.yaml ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ services:
2
+ - type: web
3
+ name: eddy3d-gan-api
4
+ runtime: docker
5
+ dockerfilePath: ./Dockerfile
6
+ region: oregon
7
+ plan: starter
8
+ healthCheckPath: /health
9
+ envVars:
10
+ - key: MODEL_URL
11
+ value: "https://huggingface.co/SustainableUrbanSystemsLab/UrbanWind-GAN/resolve/main/GAN-21-05-2023-23-Generative.onnx"
12
+ - key: MODEL_PATH
13
+ value: /app/model.onnx
14
+ - key: MODEL_SHA256
15
+ value: ""
16
+ - key: PORT
17
+ value: "8000"
18
+ - key: LOG_LEVEL
19
+ value: INFO
20
+ - key: ALLOWED_ORIGINS
21
+ value: "*"
22
+ - key: RATE_LIMIT_REQUESTS
23
+ value: "30"
24
+ - key: RATE_LIMIT_WINDOW
25
+ value: minute
requirements.txt DELETED
@@ -1,8 +0,0 @@
1
- fastapi
2
- uvicorn[standard]
3
- python-multipart
4
- uvicorn
5
- onnxruntime
6
- pillow
7
- numpy
8
- requests
 
 
 
 
 
 
 
 
 
test.ipynb DELETED
@@ -1,106 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": 13,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "import requests\n",
10
- "from PIL import Image\n",
11
- "import io\n"
12
- ]
13
- },
14
- {
15
- "cell_type": "code",
16
- "execution_count": 22,
17
- "metadata": {},
18
- "outputs": [
19
- {
20
- "name": "stdout",
21
- "output_type": "stream",
22
- "text": [
23
- "b'{\"message\":\"API is working\"}'\n"
24
- ]
25
- }
26
- ],
27
- "source": [
28
- "response = requests.get(\"http://gan-230522.eddy3d.com/test\")\n",
29
- "\n",
30
- "print(response.content)\n"
31
- ]
32
- },
33
- {
34
- "cell_type": "code",
35
- "execution_count": 35,
36
- "metadata": {},
37
- "outputs": [],
38
- "source": [
39
- "# Open your image file in binary mode\n",
40
- "with open(\"test_input.png\", \"rb\") as f:\n",
41
- " img_binary = f.read()\n",
42
- "\n",
43
- "\n",
44
- "# Create a file-like object from the image binary data\n",
45
- "img_file = io.BytesIO(img_binary)\n",
46
- "\n",
47
- "# Open the image using PIL\n",
48
- "img_in = Image.open(img_file)\n",
49
- "\n",
50
- "# Display the image\n",
51
- "img_in.show()\n",
52
- "\n",
53
- "\n",
54
- "# Send the POST request\n",
55
- "# response = requests.post('http://127.0.0.1:8000/process_image', files={'file': img_binary})\n",
56
- "response = requests.post(\n",
57
- " \"https://gan-230522.eddy3d.com/process_image\", files={\"file\": img_binary}\n",
58
- ")\n",
59
- "\n",
60
- "\n",
61
- "# The response is an image file, so we need to convert it back to an image\n",
62
- "image = Image.open(io.BytesIO(response.content))\n",
63
- "\n",
64
- "# Now you can display the image or save it to a file\n",
65
- "image.show()\n"
66
- ]
67
- },
68
- {
69
- "cell_type": "code",
70
- "execution_count": 18,
71
- "metadata": {},
72
- "outputs": [
73
- {
74
- "name": "stdout",
75
- "output_type": "stream",
76
- "text": [
77
- "b'{\"message\":\"API is working\"}'\n"
78
- ]
79
- }
80
- ],
81
- "source": []
82
- }
83
- ],
84
- "metadata": {
85
- "kernelspec": {
86
- "display_name": "ganapi",
87
- "language": "python",
88
- "name": "python3"
89
- },
90
- "language_info": {
91
- "codemirror_mode": {
92
- "name": "ipython",
93
- "version": 3
94
- },
95
- "file_extension": ".py",
96
- "mimetype": "text/x-python",
97
- "name": "python",
98
- "nbconvert_exporter": "python",
99
- "pygments_lexer": "ipython3",
100
- "version": "3.9.16"
101
- },
102
- "orig_nbformat": 4
103
- },
104
- "nbformat": 4,
105
- "nbformat_minor": 2
106
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
uv.lock ADDED
The diff for this file is too large to render. See raw diff