Spaces:
Sleeping
Sleeping
| """ | |
| FastAPI inference API for Image Denoising + 2x Super-Resolution. | |
| Endpoints: | |
| POST /predict — upload image, returns upscaled PNG | |
| GET /health — health check | |
| GET /metrics — Prometheus metrics (auto-instrumented) | |
| """ | |
| import io | |
| import os | |
| import sys | |
| import time | |
| import logging | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| import torch | |
| import numpy as np | |
| from PIL import Image | |
| from fastapi import FastAPI, File, Form, UploadFile, HTTPException | |
| from fastapi.responses import Response | |
| from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST | |
| from starlette.responses import PlainTextResponse | |
| import config | |
| from app.inference import load_sr_model, upscale_image | |
# Configure the root logger at import time so INFO-level records are emitted
# (basicConfig does nothing if the root logger already has handlers).
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# App
# ---------------------------------------------------------------------------
# OpenAPI metadata for the service; passed straight through to FastAPI.
_APP_METADATA = {
    "title": "Image Denoiser + 2x Super-Resolution API",
    "description": "Upload a noisy image, get back a clean 2x upscaled version.",
    "version": "1.0.0",
}
app = FastAPI(**_APP_METADATA)
# ---------------------------------------------------------------------------
# Prometheus custom metrics
# ---------------------------------------------------------------------------
# Shared bucket boundaries for the input-size histograms: 32, 64, ..., 2048.
_PIXEL_BUCKETS = [2 ** exp for exp in range(5, 12)]

# Request outcomes, labelled by model variant and outcome status.
REQUEST_COUNT = Counter(
    "inference_requests_total",
    "Total inference requests",
    labelnames=["noise_type", "status"],
)

# Wall-clock time of the model call, per model variant.
INFERENCE_LATENCY = Histogram(
    "inference_duration_seconds",
    "Time spent running inference",
    labelnames=["noise_type"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0],
)

# Distribution of uploaded image dimensions.
IMAGE_WIDTH = Histogram(
    "input_image_width_pixels",
    "Width of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)
IMAGE_HEIGHT = Histogram(
    "input_image_height_pixels",
    "Height of input images in pixels",
    buckets=_PIXEL_BUCKETS,
)
# ---------------------------------------------------------------------------
# Model cache — load once at startup, reuse per request
# ---------------------------------------------------------------------------
# Use the default CUDA device when PyTorch can see one; otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

# noise_type -> loaded SR model; filled by load_all_models().
_models: dict = {}
@app.on_event("startup")
def load_all_models():
    """Load one SR model per entry in ``config.NOISE_TYPES`` into ``_models``.

    Registered as a FastAPI startup handler so the cache is populated before
    requests are served. A missing checkpoint (``FileNotFoundError``) is logged
    as a warning and that noise type is left out of the cache, so ``/predict``
    answers 503 for it. Any other loading error propagates and fails startup.
    """
    logger.info("Device: %s", DEVICE)
    for noise_type in config.NOISE_TYPES:
        try:
            _models[noise_type] = load_sr_model(noise_type, DEVICE)
        except FileNotFoundError as e:
            logger.warning("Checkpoint missing for %s: %s", noise_type, e)
        else:
            logger.info("Loaded SR model: %s", noise_type)
    if not _models:
        # Surface the "service is up but useless" state loudly at startup.
        logger.warning("No SR models loaded; every /predict request will return 503.")
| # --------------------------------------------------------------------------- | |
| # Routes | |
| # --------------------------------------------------------------------------- | |
@app.get("/health")
def health():
    """Report liveness, the inference device, and which models are cached.

    Returns:
        dict: ``status`` (always ``"ok"``), ``device`` (string form of
        ``DEVICE``), and ``models_loaded`` (the noise types currently in
        ``_models``, in insertion order).
    """
    return {
        "status": "ok",
        "device": str(DEVICE),
        "models_loaded": list(_models),
    }
@app.post("/predict")
def predict(
    file: UploadFile = File(..., description="Input image (JPEG, PNG, etc.)"),
    noise_type: str = Form("gaussian", description="gaussian | salt_pepper | speckle"),
):
    """Denoise and 2x-upscale an uploaded image, returning the result as PNG.

    Declared with plain ``def`` (not ``async def``) so FastAPI runs it in its
    threadpool; the blocking upload read and model call do not stall the
    event loop.

    Args:
        file: Uploaded image in any format Pillow can decode; converted to RGB.
        noise_type: Selects the cached model; must be in ``config.NOISE_TYPES``.

    Returns:
        Response: ``image/png`` body containing the upscaled image.

    Raises:
        HTTPException: 400 for an unknown ``noise_type`` or an undecodable
            image, 503 when the requested model is not loaded, 500 when
            inference fails.
    """
    if noise_type not in config.NOISE_TYPES:
        # Deliberately not counted in REQUEST_COUNT: the value is client-supplied
        # and using it as a label would allow unbounded label cardinality.
        raise HTTPException(
            status_code=400,
            detail=f"Invalid noise_type '{noise_type}'. Choose from: {config.NOISE_TYPES}",
        )
    if noise_type not in _models:
        # Count this like the other post-validation failures below.
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(
            status_code=503,
            detail=f"Model for '{noise_type}' not loaded. Check checkpoint files.",
        )
    model = _models[noise_type]

    try:
        image_bytes = file.file.read()
        with Image.open(io.BytesIO(image_bytes)) as source:
            # convert() loads the pixel data and returns a new image, so decode
            # errors surface here and `image` outlives the closed source.
            image = source.convert("RGB")
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        raise HTTPException(status_code=400, detail="Could not read image file.") from exc

    width, height = image.size
    IMAGE_WIDTH.observe(width)
    IMAGE_HEIGHT.observe(height)

    t0 = time.perf_counter()
    try:
        result = upscale_image(image, noise_type=noise_type, device=DEVICE, model=model)
    except Exception as exc:
        REQUEST_COUNT.labels(noise_type=noise_type, status="error").inc()
        # logger.exception records the traceback along with the message.
        logger.exception("Inference error: %s", exc)
        raise HTTPException(status_code=500, detail="Inference failed.") from exc
    elapsed = time.perf_counter() - t0

    INFERENCE_LATENCY.labels(noise_type=noise_type).observe(elapsed)
    REQUEST_COUNT.labels(noise_type=noise_type, status="success").inc()
    logger.info(
        "[%s] %dx%d -> %dx%d in %.2fs",
        noise_type, width, height, result.width, result.height, elapsed,
    )

    buf = io.BytesIO()
    result.save(buf, format="PNG")
    return Response(content=buf.getvalue(), media_type="image/png")
@app.get("/metrics")
def metrics():
    """Expose the default Prometheus registry in the text exposition format.

    This includes the custom request, latency and image-size metrics defined
    in this module, plus any collectors prometheus_client registers by
    default.
    """
    return PlainTextResponse(generate_latest(), media_type=CONTENT_TYPE_LATEST)