Patrick Kastner commited on
Commit ·
42b48e1
1
Parent(s): 9c8e922
Update api.py
Browse files
api.py
CHANGED
|
@@ -8,6 +8,7 @@ Accepts 512x512 input images or raw float arrays and returns predicted wind spee
|
|
| 8 |
import os
|
| 9 |
import io
|
| 10 |
import base64
|
|
|
|
| 11 |
import logging
|
| 12 |
import time
|
| 13 |
from contextlib import asynccontextmanager
|
|
@@ -15,9 +16,9 @@ from typing import Optional
|
|
| 15 |
|
| 16 |
import numpy as np
|
| 17 |
from PIL import Image
|
| 18 |
-
from fastapi import FastAPI,
|
| 19 |
from fastapi.middleware.cors import CORSMiddleware
|
| 20 |
-
from fastapi.responses import JSONResponse
|
| 21 |
from pydantic import BaseModel
|
| 22 |
import onnxruntime as rt
|
| 23 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
@@ -267,36 +268,7 @@ app.add_middleware(
|
|
| 267 |
)
|
| 268 |
|
| 269 |
|
| 270 |
-
|
| 271 |
-
# Helper: run inference
|
| 272 |
-
# ---------------------------------------------------------------------------
|
| 273 |
-
def _run_inference(image_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
|
| 274 |
-
"""Preprocess image, run model, return (raw_output, denormalized_output).
|
| 275 |
-
|
| 276 |
-
Returns:
|
| 277 |
-
raw_output: shape (3, 512, 512), values in [-1, 1]
|
| 278 |
-
denorm_output: shape (512, 512, 3), values in [0, 255], uint8
|
| 279 |
-
"""
|
| 280 |
-
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
| 281 |
-
input_array = np.array(image.resize((512, 512)), dtype=np.float32)
|
| 282 |
-
|
| 283 |
-
# Normalize [0, 255] -> [-1, 1]
|
| 284 |
-
input_array = (input_array / 127.5) - 1.0
|
| 285 |
-
input_array = np.expand_dims(input_array, axis=0) # (1, 512, 512, 3)
|
| 286 |
-
input_array = input_array.transpose((0, 3, 1, 2)) # (1, 3, 512, 512)
|
| 287 |
-
|
| 288 |
-
session = app.state.session
|
| 289 |
-
inputs = {app.state.input_name: input_array}
|
| 290 |
-
outputs = session.run(None, inputs)
|
| 291 |
-
|
| 292 |
-
raw = outputs[0][0] # (3, 512, 512)
|
| 293 |
-
|
| 294 |
-
# Denormalise for image output
|
| 295 |
-
denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5 # (512, 512, 3)
|
| 296 |
-
denorm = np.clip(denorm, 0, 255).astype(np.uint8)
|
| 297 |
-
|
| 298 |
-
return raw, denorm
|
| 299 |
-
|
| 300 |
|
| 301 |
def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 302 |
"""Run model from a pre-normalised float32 array.
|
|
@@ -322,13 +294,16 @@ def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]
|
|
| 322 |
return raw, denorm
|
| 323 |
|
| 324 |
|
| 325 |
-
|
| 326 |
-
|
|
|
|
|
|
|
| 327 |
|
| 328 |
-
The
|
| 329 |
-
channel-first order (R, G, B),
|
|
|
|
| 330 |
"""
|
| 331 |
-
|
| 332 |
|
| 333 |
|
| 334 |
# ---------------------------------------------------------------------------
|
|
@@ -337,10 +312,7 @@ class ArrayPredictRequest(BaseModel):
|
|
| 337 |
@app.get("/")
|
| 338 |
async def root():
|
| 339 |
"""Basic service info for base URL checks."""
|
| 340 |
-
|
| 341 |
-
"service": "Eddy3D GAN Wind Prediction API",
|
| 342 |
-
"status": "ok",
|
| 343 |
-
"endpoints": ["/health", "/predict", "/predict_array", "/predict_image"],
|
| 344 |
}
|
| 345 |
|
| 346 |
|
|
@@ -353,120 +325,58 @@ async def health():
|
|
| 353 |
}
|
| 354 |
|
| 355 |
|
|
|
|
|
|
|
| 356 |
@app.post("/predict")
|
| 357 |
@limiter.limit(_rate_limit_str)
|
| 358 |
-
async def predict(request: Request,
|
| 359 |
-
"""Run GAN inference
|
| 360 |
|
| 361 |
Returns JSON:
|
| 362 |
-
|
| 363 |
image_base64: base64-encoded PNG of the output image
|
| 364 |
width: 512
|
| 365 |
height: 512
|
| 366 |
"""
|
| 367 |
-
# Validate content type
|
| 368 |
-
if file.content_type and not file.content_type.startswith("image/"):
|
| 369 |
-
raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
|
| 370 |
-
|
| 371 |
-
input_data = await file.read()
|
| 372 |
-
if len(input_data) > MAX_UPLOAD_BYTES:
|
| 373 |
-
raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
|
| 374 |
-
|
| 375 |
t0 = time.perf_counter()
|
| 376 |
try:
|
| 377 |
-
|
| 378 |
-
|
| 379 |
-
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
| 387 |
-
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
elapsed = time.perf_counter() - t0
|
| 392 |
-
logger.info("Inference completed in %.2f s", elapsed)
|
| 393 |
-
|
| 394 |
-
return JSONResponse({
|
| 395 |
-
"wind_speeds": wind_speeds,
|
| 396 |
-
"image_base64": image_b64,
|
| 397 |
-
"width": 512,
|
| 398 |
-
"height": 512,
|
| 399 |
-
})
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
@app.post("/predict_image")
|
| 403 |
-
@limiter.limit(_rate_limit_str)
|
| 404 |
-
async def predict_image(request: Request, file: UploadFile = File(...)):
|
| 405 |
-
"""Legacy endpoint: returns the output image as a PNG stream."""
|
| 406 |
-
if file.content_type and not file.content_type.startswith("image/"):
|
| 407 |
-
raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
|
| 408 |
-
|
| 409 |
-
input_data = await file.read()
|
| 410 |
-
if len(input_data) > MAX_UPLOAD_BYTES:
|
| 411 |
-
raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
|
| 412 |
-
|
| 413 |
-
try:
|
| 414 |
-
_, denorm = _run_inference(input_data)
|
| 415 |
-
except Exception as e:
|
| 416 |
-
logger.exception("Inference failed")
|
| 417 |
-
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 418 |
-
|
| 419 |
-
output_image = Image.fromarray(denorm, "RGB")
|
| 420 |
-
buf = io.BytesIO()
|
| 421 |
-
output_image.save(buf, format="PNG")
|
| 422 |
-
buf.seek(0)
|
| 423 |
-
|
| 424 |
-
return StreamingResponse(buf, media_type="image/png")
|
| 425 |
|
|
|
|
| 426 |
|
| 427 |
-
|
| 428 |
-
|
| 429 |
-
|
| 430 |
-
|
| 431 |
|
| 432 |
-
|
| 433 |
-
|
|
|
|
|
|
|
| 434 |
|
| 435 |
-
Returns JSON:
|
| 436 |
-
wind_speeds: list of floats (length 262144), values 0-15 m/s
|
| 437 |
-
image_base64: base64-encoded PNG of the output image
|
| 438 |
-
width: 512
|
| 439 |
-
height: 512
|
| 440 |
-
"""
|
| 441 |
-
expected = 3 * 512 * 512
|
| 442 |
-
if len(body.data) != expected:
|
| 443 |
-
raise HTTPException(
|
| 444 |
-
status_code=400,
|
| 445 |
-
detail=f"Expected {expected} floats (3*512*512), got {len(body.data)}.",
|
| 446 |
-
)
|
| 447 |
-
|
| 448 |
-
t0 = time.perf_counter()
|
| 449 |
-
try:
|
| 450 |
-
arr = np.array(body.data, dtype=np.float32).reshape((3, 512, 512))
|
| 451 |
-
raw, denorm = _run_inference_from_array(arr)
|
| 452 |
except HTTPException:
|
| 453 |
raise
|
| 454 |
except Exception as e:
|
| 455 |
logger.exception("Inference failed")
|
| 456 |
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 457 |
|
| 458 |
-
wind_speeds = _color_to_windspeed(raw)
|
| 459 |
-
|
| 460 |
-
output_image = Image.fromarray(denorm, "RGB")
|
| 461 |
-
buf = io.BytesIO()
|
| 462 |
-
output_image.save(buf, format="PNG")
|
| 463 |
-
image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 464 |
-
|
| 465 |
elapsed = time.perf_counter() - t0
|
| 466 |
-
logger.info("
|
| 467 |
|
| 468 |
return JSONResponse({
|
| 469 |
-
"
|
| 470 |
"image_base64": image_b64,
|
| 471 |
"width": 512,
|
| 472 |
"height": 512,
|
|
|
|
| 8 |
import os
|
| 9 |
import io
|
| 10 |
import base64
|
| 11 |
+
import gzip
|
| 12 |
import logging
|
| 13 |
import time
|
| 14 |
from contextlib import asynccontextmanager
|
|
|
|
| 16 |
|
| 17 |
import numpy as np
|
| 18 |
from PIL import Image
|
| 19 |
+
from fastapi import FastAPI, Request, HTTPException
|
| 20 |
from fastapi.middleware.cors import CORSMiddleware
|
| 21 |
+
from fastapi.responses import JSONResponse
|
| 22 |
from pydantic import BaseModel
|
| 23 |
import onnxruntime as rt
|
| 24 |
from slowapi import Limiter, _rate_limit_exceeded_handler
|
|
|
|
| 268 |
)
|
| 269 |
|
| 270 |
|
| 271 |
+
None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
|
| 273 |
def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 274 |
"""Run model from a pre-normalised float32 array.
|
|
|
|
| 294 |
return raw, denorm
|
| 295 |
|
| 296 |
|
| 297 |
+
None
|
| 298 |
+
|
| 299 |
+
class PredictRequest(BaseModel):
|
| 300 |
+
"""Request body for /predict.
|
| 301 |
|
| 302 |
+
The data_b64 field is a base64-encoded, gzip-compressed flat array of
|
| 303 |
+
786432 float32 values (3 * 512 * 512) in channel-first order (R, G, B),
|
| 304 |
+
with values already normalised to [-1, 1].
|
| 305 |
"""
|
| 306 |
+
data_b64: str
|
| 307 |
|
| 308 |
|
| 309 |
# ---------------------------------------------------------------------------
|
|
|
|
| 312 |
@app.get("/")
|
| 313 |
async def root():
|
| 314 |
"""Basic service info for base URL checks."""
|
| 315 |
+
"endpoints": ["/health", "/predict"],
|
|
|
|
|
|
|
|
|
|
| 316 |
}
|
| 317 |
|
| 318 |
|
|
|
|
| 325 |
}
|
| 326 |
|
| 327 |
|
| 328 |
+
None
|
| 329 |
+
|
| 330 |
@app.post("/predict")
|
| 331 |
@limiter.limit(_rate_limit_str)
|
| 332 |
+
async def predict(request: Request, body: PredictRequest):
|
| 333 |
+
"""Run GAN inference from a gzip-compressed, base64-encoded float32 array.
|
| 334 |
|
| 335 |
Returns JSON:
|
| 336 |
+
wind_speeds_b64: base64-encoded, gzip-compressed float32 array (length 262144)
|
| 337 |
image_base64: base64-encoded PNG of the output image
|
| 338 |
width: 512
|
| 339 |
height: 512
|
| 340 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 341 |
t0 = time.perf_counter()
|
| 342 |
try:
|
| 343 |
+
compressed_bytes = base64.b64decode(body.data_b64)
|
| 344 |
+
raw_bytes = gzip.decompress(compressed_bytes)
|
| 345 |
+
|
| 346 |
+
arr = np.frombuffer(raw_bytes, dtype=np.float32)
|
| 347 |
+
expected = 3 * 512 * 512
|
| 348 |
+
if arr.size != expected:
|
| 349 |
+
raise HTTPException(
|
| 350 |
+
status_code=400,
|
| 351 |
+
detail=f"Expected {expected} floats (3*512*512), got {arr.size}.",
|
| 352 |
+
)
|
| 353 |
+
|
| 354 |
+
arr = arr.reshape((3, 512, 512))
|
| 355 |
+
raw, denorm = _run_inference_from_array(arr)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 356 |
|
| 357 |
+
wind_speeds_list = _color_to_windspeed(raw)
|
| 358 |
|
| 359 |
+
wind_speeds_arr = np.array(wind_speeds_list, dtype=np.float32)
|
| 360 |
+
wind_speeds_bytes = wind_speeds_arr.tobytes()
|
| 361 |
+
compressed_wind_speeds = gzip.compress(wind_speeds_bytes)
|
| 362 |
+
wind_speeds_b64 = base64.b64encode(compressed_wind_speeds).decode("ascii")
|
| 363 |
|
| 364 |
+
output_image = Image.fromarray(denorm, "RGB")
|
| 365 |
+
buf = io.BytesIO()
|
| 366 |
+
output_image.save(buf, format="PNG")
|
| 367 |
+
image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
|
| 368 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 369 |
except HTTPException:
|
| 370 |
raise
|
| 371 |
except Exception as e:
|
| 372 |
logger.exception("Inference failed")
|
| 373 |
raise HTTPException(status_code=500, detail=f"Inference error: {e}")
|
| 374 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 375 |
elapsed = time.perf_counter() - t0
|
| 376 |
+
logger.info("Binary inference completed in %.2f s", elapsed)
|
| 377 |
|
| 378 |
return JSONResponse({
|
| 379 |
+
"wind_speeds_b64": wind_speeds_b64,
|
| 380 |
"image_base64": image_b64,
|
| 381 |
"width": 512,
|
| 382 |
"height": 512,
|