Spaces:
Sleeping
Sleeping
File size: 4,864 Bytes
"""
FastAPI inference API for image denoising + 2x super-resolution.

Endpoints:
    POST /predict — upload an image; returns the upscaled PNG
    GET  /health  — health check
    GET  /metrics — Prometheus metrics (auto-instrumented)
"""
import io
import os
import sys
import time
import logging
# Prepend the parent of this file's directory to sys.path. The project-local
# imports further down (`config`, `app.inference`) presumably rely on this to
# resolve, so this line must stay above them.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
# NOTE(review): `np` is not referenced anywhere in this module as visible;
# possibly removable — confirm before deleting.
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from starlette.responses import PlainTextResponse
# Project-local modules (found via the sys.path tweak above).
import config
from app.inference import load_sr_model, upscale_image
# Configure root logging at INFO level at import time so the logger.info(...)
# calls in this module are actually emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# API metadata (title / description / version) handed to the FastAPI app.
_APP_METADATA = {
    "title": "Image Denoiser + 2x Super-Resolution API",
    "description": "Upload a noisy image, get back a clean 2x upscaled version.",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# ---------------------------------------------------------------------------
# Prometheus custom metrics
# ---------------------------------------------------------------------------
# Bucket upper bounds (pixels) shared by the input width and height histograms.
_PIXEL_BUCKETS = (32, 64, 128, 256, 512, 1024, 2048)

# /predict outcomes, labelled by noise type and status ("success" / "error").
REQUEST_COUNT = Counter(
    "inference_requests_total",
    "Total inference requests",
    labelnames=["noise_type", "status"],
)

# Time spent in the model call only (seconds), labelled by noise type.
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent running inference",
    labelnames=["noise_type"],
    buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0),
)

# Distribution of decoded input image dimensions.
IMAGE_WIDTH = Histogram(
    "input_image_width_pixels",
    "Width of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)
IMAGE_HEIGHT = Histogram(
    "input_image_height_pixels",
    "Height of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)
# ---------------------------------------------------------------------------
# Model cache — load once at startup, reuse per request
# ---------------------------------------------------------------------------
# Run on the GPU when PyTorch reports CUDA as available, otherwise on CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# noise_type -> loaded SR model; populated by the startup hook, read per request.
_models: dict = {}
@app.on_event("startup")
def load_all_models():
    """Fill the model cache with one SR model per configured noise type.

    A noise type whose checkpoint is missing (``load_sr_model`` raises
    ``FileNotFoundError``) is logged and skipped, so the service still starts
    with whatever models did load; ``/predict`` answers 503 for the skipped
    ones. Any other exception is not caught and propagates out of this hook.
    """
    logger.info(f"Device: {DEVICE}")
    for kind in config.NOISE_TYPES:
        try:
            model = load_sr_model(kind, DEVICE)
        except FileNotFoundError as err:
            logger.warning(f"Checkpoint missing for {kind}: {err}")
            continue
        _models[kind] = model
        logger.info(f"Loaded SR model: {kind}")
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Report liveness, the compute device, and which models are loaded.

    ``status`` is always ``"ok"``, even when no model loaded at startup.
    Inspect ``models_loaded`` to see which ``noise_type`` values
    ``/predict`` can currently serve.
    """
    return {
        "status": "ok",
        "device": str(DEVICE),
        # Snapshot of the cache keys (noise types), in load order.
        "models_loaded": list(_models),
    }
@app.post("/predict")
def predict(
    file: UploadFile = File(..., description="Input image (JPEG, PNG, etc.)"),
    noise_type: str = Form("gaussian", description="gaussian | salt_pepper | speckle"),
):
    """Denoise and 2x-upscale an uploaded image, returning it as a PNG.

    Declared as a plain ``def`` (not ``async def``): FastAPI runs sync
    endpoints in a worker threadpool, so the blocking model call does not
    stall the event loop.

    Args:
        file: Uploaded image. Anything Pillow can open is accepted and
            converted to RGB.
        noise_type: Selects the noise-specific model. It must be one of
            ``config.NOISE_TYPES``.

    Returns:
        An ``image/png`` response holding the model output.

    Raises:
        HTTPException:
            400 if ``noise_type`` is unknown or the image cannot be decoded.
            503 if the model for ``noise_type`` was not loaded at startup.
            500 if inference or PNG encoding fails.
    """
    # Validate before touching metrics: an arbitrary user-supplied string must
    # never become a Prometheus label value (unbounded label cardinality).
    if noise_type not in config.NOISE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid noise_type '{noise_type}'. Choose from: {config.NOISE_TYPES}",
        )
    if noise_type not in _models:
        raise HTTPException(
            status_code=503,
            detail=f"Model for '{noise_type}' not loaded. Check checkpoint files.",
        )

    # NOTE(review): the whole upload is read into memory with no size cap.
    try:
        image_bytes = file.file.read()
        # convert() forces Pillow to decode the data, so corrupt input fails
        # here; the context manager closes the source image afterwards.
        with Image.open(io.BytesIO(image_bytes)) as src:
            image = src.convert("RGB")
    except Exception as e:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # %r escapes the client-controlled filename (no log-line injection).
        logger.warning("Could not decode upload %r: %s", file.filename, e)
        raise HTTPException(status_code=400, detail="Could not read image file.") from e

    width, height = image.size
    IMAGE_WIDTH.observe(width)
    IMAGE_HEIGHT.observe(height)

    t0 = time.perf_counter()
    try:
        result = upscale_image(image, noise_type=noise_type, device=DEVICE, model=_models[noise_type])
    except Exception as e:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # logger.exception keeps the traceback, which is essential for model errors.
        logger.exception("Inference error: %s", e)
        raise HTTPException(status_code=500, detail="Inference failed.") from e
    elapsed = time.perf_counter() - t0
    INFERENCE_LATENCY.labels(noise_type=noise_type).observe(elapsed)

    # Encode before counting success, so a failed encode is recorded as an
    # error rather than as a successful request.
    buf = io.BytesIO()
    try:
        result.save(buf, format="PNG")
    except Exception as e:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        logger.exception("Failed to encode output as PNG: %s", e)
        raise HTTPException(status_code=500, detail="Failed to encode output image.") from e

    REQUEST_COUNT.labels(noise_type=noise_type, status="success").inc()
    logger.info(
        "[%s] %sx%s -> %sx%s in %.2fs",
        noise_type, width, height, result.width, result.height, elapsed,
    )
    return Response(content=buf.getvalue(), media_type="image/png")
@app.get("/metrics", response_class=PlainTextResponse)
def metrics():
    """Serve all metrics from the default Prometheus registry for scraping."""
    exposition = generate_latest()
    return PlainTextResponse(content=exposition, media_type=CONTENT_TYPE_LATEST)
|