Patrick Kastner committed on
Commit
42b48e1
·
1 Parent(s): 9c8e922

Update api.py

Browse files
Files changed (1) hide show
  1. api.py +42 -132
api.py CHANGED
@@ -8,6 +8,7 @@ Accepts 512x512 input images or raw float arrays and returns predicted wind spee
8
  import os
9
  import io
10
  import base64
 
11
  import logging
12
  import time
13
  from contextlib import asynccontextmanager
@@ -15,9 +16,9 @@ from typing import Optional
15
 
16
  import numpy as np
17
  from PIL import Image
18
- from fastapi import FastAPI, File, Request, UploadFile, HTTPException
19
  from fastapi.middleware.cors import CORSMiddleware
20
- from fastapi.responses import JSONResponse, StreamingResponse
21
  from pydantic import BaseModel
22
  import onnxruntime as rt
23
  from slowapi import Limiter, _rate_limit_exceeded_handler
@@ -267,36 +268,7 @@ app.add_middleware(
267
  )
268
 
269
 
270
- # ---------------------------------------------------------------------------
271
- # Helper: run inference
272
- # ---------------------------------------------------------------------------
273
- def _run_inference(image_bytes: bytes) -> tuple[np.ndarray, np.ndarray]:
274
- """Preprocess image, run model, return (raw_output, denormalized_output).
275
-
276
- Returns:
277
- raw_output: shape (3, 512, 512), values in [-1, 1]
278
- denorm_output: shape (512, 512, 3), values in [0, 255], uint8
279
- """
280
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
281
- input_array = np.array(image.resize((512, 512)), dtype=np.float32)
282
-
283
- # Normalize [0, 255] -> [-1, 1]
284
- input_array = (input_array / 127.5) - 1.0
285
- input_array = np.expand_dims(input_array, axis=0) # (1, 512, 512, 3)
286
- input_array = input_array.transpose((0, 3, 1, 2)) # (1, 3, 512, 512)
287
-
288
- session = app.state.session
289
- inputs = {app.state.input_name: input_array}
290
- outputs = session.run(None, inputs)
291
-
292
- raw = outputs[0][0] # (3, 512, 512)
293
-
294
- # Denormalise for image output
295
- denorm = (raw.transpose((1, 2, 0)) + 1.0) * 127.5 # (512, 512, 3)
296
- denorm = np.clip(denorm, 0, 255).astype(np.uint8)
297
-
298
- return raw, denorm
299
-
300
 
301
  def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
302
  """Run model from a pre-normalised float32 array.
@@ -322,13 +294,16 @@ def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]
322
  return raw, denorm
323
 
324
 
325
- class ArrayPredictRequest(BaseModel):
326
- """Request body for /predict_array.
 
 
327
 
328
- The data field is a flat list of 786432 floats (3 * 512 * 512) in
329
- channel-first order (R, G, B), with values already normalised to [-1, 1].
 
330
  """
331
- data: list[float]
332
 
333
 
334
  # ---------------------------------------------------------------------------
@@ -337,10 +312,7 @@ class ArrayPredictRequest(BaseModel):
337
  @app.get("/")
338
  async def root():
339
  """Basic service info for base URL checks."""
340
- return {
341
- "service": "Eddy3D GAN Wind Prediction API",
342
- "status": "ok",
343
- "endpoints": ["/health", "/predict", "/predict_array", "/predict_image"],
344
  }
345
 
346
 
@@ -353,120 +325,58 @@ async def health():
353
  }
354
 
355
 
 
 
356
  @app.post("/predict")
357
  @limiter.limit(_rate_limit_str)
358
- async def predict(request: Request, file: UploadFile = File(...)):
359
- """Run GAN inference and return wind speeds + output image.
360
 
361
  Returns JSON:
362
- wind_speeds: list of floats (length 262144), values 0–15 m/s
363
  image_base64: base64-encoded PNG of the output image
364
  width: 512
365
  height: 512
366
  """
367
- # Validate content type
368
- if file.content_type and not file.content_type.startswith("image/"):
369
- raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
370
-
371
- input_data = await file.read()
372
- if len(input_data) > MAX_UPLOAD_BYTES:
373
- raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
374
-
375
  t0 = time.perf_counter()
376
  try:
377
- raw, denorm = _run_inference(input_data)
378
- except Exception as e:
379
- logger.exception("Inference failed")
380
- raise HTTPException(status_code=500, detail=f"Inference error: {e}")
381
-
382
- # Wind speeds from raw output
383
- wind_speeds = _color_to_windspeed(raw)
384
-
385
- # Encode output image as base64 PNG
386
- output_image = Image.fromarray(denorm, "RGB")
387
- buf = io.BytesIO()
388
- output_image.save(buf, format="PNG")
389
- image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
390
-
391
- elapsed = time.perf_counter() - t0
392
- logger.info("Inference completed in %.2f s", elapsed)
393
-
394
- return JSONResponse({
395
- "wind_speeds": wind_speeds,
396
- "image_base64": image_b64,
397
- "width": 512,
398
- "height": 512,
399
- })
400
-
401
-
402
- @app.post("/predict_image")
403
- @limiter.limit(_rate_limit_str)
404
- async def predict_image(request: Request, file: UploadFile = File(...)):
405
- """Legacy endpoint: returns the output image as a PNG stream."""
406
- if file.content_type and not file.content_type.startswith("image/"):
407
- raise HTTPException(status_code=400, detail="Uploaded file must be an image.")
408
-
409
- input_data = await file.read()
410
- if len(input_data) > MAX_UPLOAD_BYTES:
411
- raise HTTPException(status_code=413, detail="File exceeds 10 MB limit.")
412
-
413
- try:
414
- _, denorm = _run_inference(input_data)
415
- except Exception as e:
416
- logger.exception("Inference failed")
417
- raise HTTPException(status_code=500, detail=f"Inference error: {e}")
418
-
419
- output_image = Image.fromarray(denorm, "RGB")
420
- buf = io.BytesIO()
421
- output_image.save(buf, format="PNG")
422
- buf.seek(0)
423
-
424
- return StreamingResponse(buf, media_type="image/png")
425
 
 
426
 
427
- @app.post("/predict_array")
428
- @limiter.limit(_rate_limit_str)
429
- async def predict_array(request: Request, body: ArrayPredictRequest):
430
- """Run GAN inference from a raw float array (no image encoding needed).
431
 
432
- Accepts a flat list of 786432 floats (3 * 512 * 512) in channel-first
433
- order, already normalised to [-1, 1].
 
 
434
 
435
- Returns JSON:
436
- wind_speeds: list of floats (length 262144), values 0-15 m/s
437
- image_base64: base64-encoded PNG of the output image
438
- width: 512
439
- height: 512
440
- """
441
- expected = 3 * 512 * 512
442
- if len(body.data) != expected:
443
- raise HTTPException(
444
- status_code=400,
445
- detail=f"Expected {expected} floats (3*512*512), got {len(body.data)}.",
446
- )
447
-
448
- t0 = time.perf_counter()
449
- try:
450
- arr = np.array(body.data, dtype=np.float32).reshape((3, 512, 512))
451
- raw, denorm = _run_inference_from_array(arr)
452
  except HTTPException:
453
  raise
454
  except Exception as e:
455
  logger.exception("Inference failed")
456
  raise HTTPException(status_code=500, detail=f"Inference error: {e}")
457
 
458
- wind_speeds = _color_to_windspeed(raw)
459
-
460
- output_image = Image.fromarray(denorm, "RGB")
461
- buf = io.BytesIO()
462
- output_image.save(buf, format="PNG")
463
- image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
464
-
465
  elapsed = time.perf_counter() - t0
466
- logger.info("Array inference completed in %.2f s", elapsed)
467
 
468
  return JSONResponse({
469
- "wind_speeds": wind_speeds,
470
  "image_base64": image_b64,
471
  "width": 512,
472
  "height": 512,
 
8
  import os
9
  import io
10
  import base64
11
+ import gzip
12
  import logging
13
  import time
14
  from contextlib import asynccontextmanager
 
16
 
17
  import numpy as np
18
  from PIL import Image
19
+ from fastapi import FastAPI, Request, HTTPException
20
  from fastapi.middleware.cors import CORSMiddleware
21
+ from fastapi.responses import JSONResponse
22
  from pydantic import BaseModel
23
  import onnxruntime as rt
24
  from slowapi import Limiter, _rate_limit_exceeded_handler
 
268
  )
269
 
270
 
271
+ None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
272
 
273
  def _run_inference_from_array(data: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
274
  """Run model from a pre-normalised float32 array.
 
294
  return raw, denorm
295
 
296
 
297
+ None
298
+
299
+ class PredictRequest(BaseModel):
300
+ """Request body for /predict.
301
 
302
+ The data_b64 field is a base64-encoded, gzip-compressed flat array of
303
+ 786432 float32 values (3 * 512 * 512) in channel-first order (R, G, B),
304
+ with values already normalised to [-1, 1].
305
  """
306
+ data_b64: str
307
 
308
 
309
  # ---------------------------------------------------------------------------
 
312
  @app.get("/")
313
  async def root():
314
  """Basic service info for base URL checks."""
315
+ "endpoints": ["/health", "/predict"],
 
 
 
316
  }
317
 
318
 
 
325
  }
326
 
327
 
328
+ None
329
+
330
  @app.post("/predict")
331
  @limiter.limit(_rate_limit_str)
332
+ async def predict(request: Request, body: PredictRequest):
333
+ """Run GAN inference from a gzip-compressed, base64-encoded float32 array.
334
 
335
  Returns JSON:
336
+ wind_speeds_b64: base64-encoded, gzip-compressed float32 array (length 262144)
337
  image_base64: base64-encoded PNG of the output image
338
  width: 512
339
  height: 512
340
  """
 
 
 
 
 
 
 
 
341
  t0 = time.perf_counter()
342
  try:
343
+ compressed_bytes = base64.b64decode(body.data_b64)
344
+ raw_bytes = gzip.decompress(compressed_bytes)
345
+
346
+ arr = np.frombuffer(raw_bytes, dtype=np.float32)
347
+ expected = 3 * 512 * 512
348
+ if arr.size != expected:
349
+ raise HTTPException(
350
+ status_code=400,
351
+ detail=f"Expected {expected} floats (3*512*512), got {arr.size}.",
352
+ )
353
+
354
+ arr = arr.reshape((3, 512, 512))
355
+ raw, denorm = _run_inference_from_array(arr)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
+ wind_speeds_list = _color_to_windspeed(raw)
358
 
359
+ wind_speeds_arr = np.array(wind_speeds_list, dtype=np.float32)
360
+ wind_speeds_bytes = wind_speeds_arr.tobytes()
361
+ compressed_wind_speeds = gzip.compress(wind_speeds_bytes)
362
+ wind_speeds_b64 = base64.b64encode(compressed_wind_speeds).decode("ascii")
363
 
364
+ output_image = Image.fromarray(denorm, "RGB")
365
+ buf = io.BytesIO()
366
+ output_image.save(buf, format="PNG")
367
+ image_b64 = base64.b64encode(buf.getvalue()).decode("ascii")
368
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
369
  except HTTPException:
370
  raise
371
  except Exception as e:
372
  logger.exception("Inference failed")
373
  raise HTTPException(status_code=500, detail=f"Inference error: {e}")
374
 
 
 
 
 
 
 
 
375
  elapsed = time.perf_counter() - t0
376
+ logger.info("Binary inference completed in %.2f s", elapsed)
377
 
378
  return JSONResponse({
379
+ "wind_speeds_b64": wind_speeds_b64,
380
  "image_base64": image_b64,
381
  "width": 512,
382
  "height": 512,