File size: 4,864 Bytes
8b83582
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
FastAPI inference API for Image Denoising + 2x Super-Resolution.

Endpoints:
  POST /predict        — upload image, returns upscaled PNG
  GET  /health         — health check
  GET  /metrics        — Prometheus metrics (custom counters/histograms defined in this module)
"""
import io
import os
import sys
import time
import logging

sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))

import torch
import numpy as np
from PIL import Image
from fastapi import FastAPI, File, Form, UploadFile, HTTPException
from fastapi.responses import Response
from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST
from starlette.responses import PlainTextResponse

import config
from app.inference import load_sr_model, upscale_image

# Configure logging once at import time at INFO level, which is the level
# used by the startup and per-request messages in this module.
logging.basicConfig(level=logging.INFO)
# Module-scoped logger, named after this module.
logger = logging.getLogger(__name__)

# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# The application object that every route and the startup hook below attach
# to; title/description/version are passed to FastAPI as API metadata.
app = FastAPI(
    title="Image Denoiser + 2x Super-Resolution API",
    description="Upload a noisy image, get back a clean 2x upscaled version.",
    version="1.0.0",
)

# ---------------------------------------------------------------------------
# Prometheus custom metrics
# ---------------------------------------------------------------------------
# Bucket upper bounds (in pixels) shared by both image-dimension histograms,
# kept in one place so width and height are always measured on the same scale.
_PIXEL_BUCKETS = (32, 64, 128, 256, 512, 1024, 2048)

# Number of /predict requests, labelled by noise type and outcome
# ("success" or "error").
REQUEST_COUNT = Counter(
    "inference_requests_total",
    "Total inference requests",
    ["noise_type", "status"],
)

# Wall-clock duration of successful inference calls, labelled by noise type.
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent running inference",
    ["noise_type"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0],
)

# Distribution of input image widths.
IMAGE_WIDTH = Histogram(
    "input_image_width_pixels",
    "Width of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)

# Distribution of input image heights.
IMAGE_HEIGHT = Histogram(
    "input_image_height_pixels",
    "Height of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)

# ---------------------------------------------------------------------------
# Model cache — load once at startup, reuse per request
# ---------------------------------------------------------------------------
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Maps a noise type to its loaded SR model. Filled by the startup hook and
# read by the /health and /predict handlers.
_models: dict = {}


@app.on_event("startup")
def load_all_models():
    """Fill the module-level model cache when the application starts.

    Attempts to load one SR model per entry of ``config.NOISE_TYPES`` onto
    ``DEVICE``. A ``FileNotFoundError`` from the loader is logged as a warning
    and that noise type is simply left out of the cache; any other exception
    propagates to the caller.
    """
    logger.info(f"Device: {DEVICE}")
    for noise_type in config.NOISE_TYPES:
        # Keep the try body down to the one call that can raise.
        try:
            model = load_sr_model(noise_type, DEVICE)
        except FileNotFoundError as e:
            logger.warning(f"Checkpoint missing for {noise_type}: {e}")
            continue
        _models[noise_type] = model
        logger.info(f"Loaded SR model: {noise_type}")


# ---------------------------------------------------------------------------
# Routes
# ---------------------------------------------------------------------------
@app.get("/health")
def health():
    """Report service status and which models are currently cached.

    Returns:
        A dict with a constant ``"status": "ok"``, the string form of
        ``DEVICE``, and the list of noise types present in the model cache.
    """
    return {
        "status": "ok",
        "device": str(DEVICE),
        # list() copies the cache keys directly; no comprehension needed.
        "models_loaded": list(_models),
    }


@app.post("/predict")
def predict(
    file: UploadFile = File(..., description="Input image (JPEG, PNG, etc.)"),
    noise_type: str = Form("gaussian", description="gaussian | salt_pepper | speckle"),
):
    """Run the cached SR model on an uploaded image and return a PNG.

    Args:
        file: Uploaded image; it is decoded with Pillow and converted to RGB.
        noise_type: Selects which cached model to use; must be one of
            ``config.NOISE_TYPES``.

    Returns:
        A ``Response`` with media type ``image/png`` containing the
        PNG-encoded output of ``upscale_image``.

    Raises:
        HTTPException: 400 when ``noise_type`` is unknown or the upload cannot
            be decoded, 503 when no model is cached for ``noise_type``, and
            500 when inference raises.
    """
    if noise_type not in config.NOISE_TYPES:
        # Not counted in REQUEST_COUNT: an arbitrary client-supplied value
        # would otherwise become a metric label.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid noise_type '{noise_type}'. Choose from: {config.NOISE_TYPES}",
        )

    if noise_type not in _models:
        raise HTTPException(
            status_code=503,
            detail=f"Model for '{noise_type}' not loaded. Check checkpoint files.",
        )

    try:
        image_bytes = file.file.read()
        image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(
            status_code=400, detail="Could not read image file."
        ) from exc

    W, H = image.size
    IMAGE_WIDTH.observe(W)
    IMAGE_HEIGHT.observe(H)

    t0 = time.perf_counter()
    try:
        result = upscale_image(
            image, noise_type=noise_type, device=DEVICE, model=_models[noise_type]
        )
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # logger.exception keeps the traceback that a plain error message drops.
        logger.exception("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail="Inference failed.") from exc

    # Latency is only recorded for successful inference calls.
    elapsed = time.perf_counter() - t0
    INFERENCE_LATENCY.labels(noise_type=noise_type).observe(elapsed)
    REQUEST_COUNT.labels(noise_type=noise_type, status="success").inc()

    logger.info(
        "[%s] %sx%s -> %sx%s in %.2fs",
        noise_type, W, H, result.width, result.height, elapsed,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    # getvalue() returns the whole buffer regardless of the stream position.
    return Response(content=buf.getvalue(), media_type="image/png")


@app.get("/metrics", response_class=PlainTextResponse)
def metrics():
    """Serve the current prometheus_client metrics for scraping.

    Returns:
        A ``PlainTextResponse`` carrying the output of ``generate_latest()``
        with the ``CONTENT_TYPE_LATEST`` media type.
    """
    payload = generate_latest()
    return PlainTextResponse(content=payload, media_type=CONTENT_TYPE_LATEST)