image-denoiser / api /main.py
Kajuto's picture
Initial commit - image denoiser + SR + MLOps stack
8b83582
"""
FastAPI inference API for Image Denoising + 2x Super-Resolution.
Endpoints:
POST /predict — upload image, returns upscaled PNG
GET /health — health check
GET /metrics — Prometheus metrics (custom inference metrics defined in this module)
"""
import io
import os
import sys
import time
import logging
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from starlette.responses import PlainTextResponse
import config
from app.inference import load_sr_model, upscale_image
# Module-scoped logger, named after this module's import path.
logger = logging.getLogger(__name__)
# Root logging at INFO; basicConfig is a no-op when the root logger already
# has handlers installed.
logging.basicConfig(level=logging.INFO)
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# ASGI application object. The route and startup decorators further down in
# this module register their handlers on it; title/description/version are
# the metadata passed to FastAPI for the service.
app = FastAPI(
title="Image Denoiser + 2x Super-Resolution API",
description="Upload a noisy image, get back a clean 2x upscaled version.",
version="1.0.0",
)
# ---------------------------------------------------------------------------
# Prometheus custom metrics
# ---------------------------------------------------------------------------
# Bucket edges (pixels) shared by the two input-dimension histograms.
_PIXEL_BUCKETS = (32, 64, 128, 256, 512, 1024, 2048)

# Outcome counter, labelled by the requested noise model and "success"/"error".
REQUEST_COUNT = Counter(
    name="inference_requests_total",
    documentation="Total inference requests",
    labelnames=("noise_type", "status"),
)

# Wall-clock duration of the upscaling call, labelled by noise model.
INFERENCE_LATENCY = Histogram(
    name="inference_duration_seconds",
    documentation="Time spent running inference",
    labelnames=("noise_type",),
    buckets=(0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0),
)

# Distribution of decoded input widths.
IMAGE_WIDTH = Histogram(
    name="input_image_width_pixels",
    documentation="Width of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)

# Distribution of decoded input heights.
IMAGE_HEIGHT = Histogram(
    name="input_image_height_pixels",
    documentation="Height of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)
# ---------------------------------------------------------------------------
# Model cache — load once at startup, reuse per request
# ---------------------------------------------------------------------------
# Prefer the GPU when CUDA is usable; otherwise run inference on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# noise_type -> loaded SR model; filled by the startup hook, read per request.
_models: dict = {}
@app.on_event("startup")
def load_all_models():
    """Populate the module-level model cache when the application starts.

    Attempts to load one SR model per entry in ``config.NOISE_TYPES``. A
    missing checkpoint (``FileNotFoundError``) is logged as a warning and
    skipped, leaving that noise type absent from the cache; any other
    exception propagates out of the startup hook.
    """
    logger.info(f"Device: {DEVICE}")
    for noise_type in config.NOISE_TYPES:
        try:
            sr_model = load_sr_model(noise_type, DEVICE)
        except FileNotFoundError as e:
            logger.warning(f"Checkpoint missing for {noise_type}: {e}")
            continue
        # Only reached when the checkpoint was found and loaded.
        _models[noise_type] = sr_model
        logger.info(f"Loaded SR model: {noise_type}")
# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
# Health check: always reports "ok", plus the compute device and the noise
# types whose models are currently in the cache.
@app.get("/health")
def health():
    return {
        "status": "ok",
        "device": str(DEVICE),
        # Iterating the cache dict yields its keys (the noise-type names).
        "models_loaded": list(_models),
    }
def _encode_png(image):
    """Serialize a PIL image to PNG and return the encoded bytes."""
    buf = io.BytesIO()
    image.save(buf, format="PNG")
    # getvalue() returns the full buffer regardless of the stream position.
    return buf.getvalue()


@app.post("/predict")
def predict(
    file: UploadFile = File(..., description="Input image (JPEG, PNG, etc.)"),
    noise_type: str = Form("gaussian", description="gaussian | salt_pepper | speckle"),
):
    """Denoise and upscale an uploaded image, returning the result as PNG.

    Args:
        file: Uploaded image; it is decoded with Pillow and converted to RGB.
        noise_type: Which cached model to use; must be one of
            ``config.NOISE_TYPES``.

    Returns:
        A ``Response`` with media type ``image/png`` holding the model output.

    Raises:
        HTTPException: 400 for an unknown ``noise_type`` or an undecodable
            upload, 503 when the model for ``noise_type`` was not loaded,
            500 when inference or PNG encoding fails.
    """
    if noise_type not in config.NOISE_TYPES:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid noise_type '{noise_type}'. Choose from: {config.NOISE_TYPES}",
        )
    if noise_type not in _models:
        raise HTTPException(
            status_code=503,
            detail=f"Model for '{noise_type}' not loaded. Check checkpoint files.",
        )

    # Decoding untrusted bytes can fail in many ways, so the catch is
    # deliberately broad at this request boundary; every failure becomes a 400.
    try:
        image_bytes = file.file.read()
        # convert() returns an independent image, so the source image can be
        # closed as soon as the conversion is done.
        with Image.open(io.BytesIO(image_bytes)) as raw:
            image = raw.convert("RGB")
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(status_code=400, detail="Could not read image file.") from exc

    width, height = image.size
    IMAGE_WIDTH.observe(width)
    IMAGE_HEIGHT.observe(height)

    t0 = time.perf_counter()
    try:
        result = upscale_image(
            image, noise_type=noise_type, device=DEVICE, model=_models[noise_type]
        )
        # Latency covers the model call only, not PNG serialization.
        elapsed = time.perf_counter() - t0
        # Encode inside the guarded region so that a serialization failure is
        # counted as an error instead of being recorded as a success first.
        png_bytes = _encode_png(result)
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # logger.exception keeps the traceback alongside the message.
        logger.exception("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail="Inference failed.") from exc

    INFERENCE_LATENCY.labels(noise_type=noise_type).observe(elapsed)
    REQUEST_COUNT.labels(noise_type=noise_type, status="success").inc()
    logger.info(
        "[%s] %sx%s -> %sx%s in %.2fs",
        noise_type, width, height, result.width, result.height, elapsed,
    )
    return Response(content=png_bytes, media_type="image/png")
# Prometheus scrape endpoint: serializes the default registry (generate_latest
# is called without an explicit registry) using the exposition content type.
@app.get("/metrics", response_class=PlainTextResponse)
def metrics():
    exposition = generate_latest()
    return PlainTextResponse(content=exposition, media_type=CONTENT_TYPE_LATEST)