"""
FastAPI inference API for Image Denoising + 2x Super-Resolution.

Endpoints:
    POST /predict  — upload an image (multipart ``file`` plus optional
                     ``noise_type`` form field); returns the denoised,
                     2x-upscaled image as PNG
    GET  /health   — health check; reports the device and the loaded models
    GET  /metrics  — Prometheus metrics: the custom inference metrics defined
                     in this module (recorded explicitly by /predict) plus
                     whatever prometheus_client's default registry collects
"""

import io
import os
import sys
import time
import logging

# Put the parent of this file's directory at the front of sys.path so the
# top-level `config` and `app.inference` imports below resolve. abspath keeps
# the entry valid even if the process later changes its working directory.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from starlette.responses import PlainTextResponse

import config
from app.inference import load_sr_model, upscale_image

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
app = FastAPI(
    title="Image Denoiser + 2x Super-Resolution API",
    description="Upload a noisy image, get back a clean 2x upscaled version.",
    version="1.0.0",
)

# ---------------------------------------------------------------------------
# Prometheus custom metrics
# ---------------------------------------------------------------------------
# Every distinct label combination creates a new time series, so callers
# should only use label values from a bounded set (e.g. validated noise types).
REQUEST_COUNT = Counter(
    "inference_requests_total",
    "Total inference requests",
    ["noise_type", "status"],
)
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent running inference",
    ["noise_type"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0],
)
# Input-size histograms; buckets are powers of two from 32 to 2048 pixels.
IMAGE_WIDTH = Histogram(
    "input_image_width_pixels",
    "Width of input images in pixels",
    buckets=[32, 64, 128, 256, 512, 1024, 2048],
)
IMAGE_HEIGHT = Histogram(
    "input_image_height_pixels",
    "Height of input images in pixels",
    buckets=[32, 64, 128, 256, 512, 1024, 2048],
)

# ---------------------------------------------------------------------------
# Model cache — load once at startup, reuse per request
# ---------------------------------------------------------------------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# noise_type -> SR model, filled once by the startup hook. A noise type whose
# checkpoint was missing at startup has no entry here.
_models: dict = {}


@app.on_event("startup")
def load_all_models():
    """Load one SR model per entry in ``config.NOISE_TYPES`` into ``_models``.

    A missing checkpoint (``FileNotFoundError``) is logged and skipped so the
    remaining noise types can still be served; ``/predict`` answers 503 for
    the skipped ones. Any other exception propagates out of the startup hook.
    """
    logger.info("Device: %s", DEVICE)
    for noise_type in config.NOISE_TYPES:
        try:
            _models[noise_type] = load_sr_model(noise_type, DEVICE)
        except FileNotFoundError as e:
            logger.warning("Checkpoint missing for %s: %s", noise_type, e)
        else:
            logger.info("Loaded SR model: %s", noise_type)


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Report liveness, the inference device, and which models are loaded.

    ``status`` is always ``"ok"``; inspect ``models_loaded`` to see whether a
    given noise type can actually be served.
    """
    return {
        "status": "ok",
        "device": str(DEVICE),
        "models_loaded": list(_models),
    }


@app.post("/predict")
def predict(
    file: UploadFile = File(..., description="Input image (JPEG, PNG, etc.)"),
    noise_type: str = Form("gaussian", description="gaussian | salt_pepper | speckle"),
):
    """Denoise and 2x-upscale an uploaded image; respond with it as PNG.

    Being a plain ``def``, FastAPI runs this in its threadpool, so the blocking
    decode / inference / encode work does not stall the event loop.

    Raises:
        HTTPException(400): ``noise_type`` is not in ``config.NOISE_TYPES``,
            or the upload cannot be decoded as an image.
        HTTPException(503): no model was loaded for ``noise_type``.
        HTTPException(500): ``upscale_image`` raised.
    """
    if noise_type not in config.NOISE_TYPES:
        # Deliberately not counted: noise_type is client-controlled here, and
        # using it as a metric label would allow unbounded label cardinality.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid noise_type '{noise_type}'. Choose from: {config.NOISE_TYPES}",
        )
    if noise_type not in _models:
        # noise_type is validated by now, so it is safe to use as a label.
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(
            status_code=503,
            detail=f"Model for '{noise_type}' not loaded. Check checkpoint files.",
        )
    model = _models[noise_type]

    try:
        # Image.open is lazy; convert() forces a full decode so corrupt or
        # truncated uploads fail here (-> 400) instead of during inference.
        image = Image.open(io.BytesIO(file.file.read())).convert("RGB")
    except Exception as e:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(status_code=400, detail="Could not read image file.") from e

    W, H = image.size
    IMAGE_WIDTH.observe(W)
    IMAGE_HEIGHT.observe(H)

    t0 = time.perf_counter()
    try:
        result = upscale_image(image, noise_type=noise_type, device=DEVICE, model=model)
    except Exception as e:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # logger.exception keeps the traceback, which logger.error dropped.
        logger.exception("Inference error: %s", e)
        raise HTTPException(status_code=500, detail="Inference failed.") from e
    elapsed = time.perf_counter() - t0

    INFERENCE_LATENCY.labels(noise_type=noise_type).observe(elapsed)
    REQUEST_COUNT.labels(noise_type=noise_type, status="success").inc()
    logger.info(
        "[%s] %sx%s -> %sx%s in %.2fs",
        noise_type, W, H, result.width, result.height, elapsed,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")


@app.get("/metrics", response_class=PlainTextResponse)
def metrics():
    """Expose everything in prometheus_client's default registry in text format."""
    return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)