ihtesham0345 commited on
Commit
60113d3
·
1 Parent(s): 50b0d2b

Add Super-Resolution API: MewZoom 2X/4X + comparison + metrics

Browse files
Files changed (3) hide show
  1. Dockerfile +17 -0
  2. app.py +195 -0
  3. requirements.txt +9 -0
Dockerfile ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.10-slim
2
+
3
+ RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ libgl1 libglib2.0-0 git curl \
5
+ && rm -rf /var/lib/apt/lists/*
6
+
7
+ RUN useradd -m -u 1000 user
8
+ USER user
9
+ ENV PATH="/home/user/.local/bin:$PATH"
10
+ WORKDIR /app
11
+
12
+ COPY --chown=user requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+
15
+ COPY --chown=user app.py .
16
+
17
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ import time
4
+ from io import BytesIO
5
+ from pathlib import Path
6
+ from contextlib import asynccontextmanager
7
+ from typing import Literal
8
+
9
+ import numpy as np
10
+ import torch
11
+ from PIL import Image, ImageDraw, ImageFont
12
+ from scipy import ndimage
13
+
14
+ from fastapi import FastAPI, File, UploadFile, Query, HTTPException
15
+ from fastapi.middleware.cors import CORSMiddleware
16
+ from fastapi.responses import StreamingResponse, JSONResponse
17
+
18
+ from torchvision.io import decode_image, ImageReadMode
19
+ from torchvision.transforms.v2 import ToDtype, ToPILImage
20
+
21
+ from mewzoom.model import MewZoom
22
+
23
+ MODELS_CONFIG = {"2x": "andrewdalpino/MewZoom-V1-2X-Unet", "4x": "andrewdalpino/MewZoom-V1-4X-Unet"}
24
+ MAX_DIM = {"2x": 2048, "4x": 1024}
25
+ CACHE_DIR = Path("models")
26
+
27
+ logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
28
+ logger = logging.getLogger(__name__)
29
+
30
+ _models: dict[str, MewZoom] = {}
31
+ _image_to_tensor = ToDtype(torch.float32, scale=True)
32
+ _tensor_to_pil = ToPILImage()
33
+ _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
34
+
35
+
36
+ def _load_model(scale: str) -> MewZoom:
37
+ if scale in _models:
38
+ return _models[scale]
39
+ model_id = MODELS_CONFIG[scale]
40
+ logger.info("Loading %s (%s) on %s ...", scale, model_id, _DEVICE)
41
+ CACHE_DIR.mkdir(exist_ok=True)
42
+ model = MewZoom.from_pretrained(model_id, cache_dir=str(CACHE_DIR))
43
+ model.to(_DEVICE).eval()
44
+ _models[scale] = model
45
+ logger.info("%s loaded (%s params)", scale, f"{sum(p.numel() for p in model.parameters()):,}")
46
+ return model
47
+
48
+
49
+ def _resize_if_needed(img: Image.Image, scale: str) -> tuple[Image.Image, bool]:
50
+ max_dim = MAX_DIM[scale]
51
+ w, h = img.size
52
+ if max(w, h) <= max_dim:
53
+ return img, False
54
+ ratio = max_dim / max(w, h)
55
+ return img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS), True
56
+
57
+
58
+ def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
59
+ arr = np.array(img, dtype=np.float32) / 255.0
60
+ return torch.from_numpy(arr).permute(2, 0, 1)
61
+
62
+
63
+ def upscale_image(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
64
+ model = _load_model(scale)
65
+ factor = int(scale[0])
66
+ try:
67
+ pil = Image.open(BytesIO(image_bytes)).convert("RGB")
68
+ except Exception as e:
69
+ raise HTTPException(400, f"Bad image: {e}")
70
+ orig = (pil.width, pil.height)
71
+ pil, resized = _resize_if_needed(pil, scale)
72
+ out_mp = pil.width * factor * pil.height * factor / 1e6
73
+ if out_mp > 64:
74
+ raise HTTPException(400, f"Output too large ({out_mp:.0f}MP). Use smaller image.")
75
+ x = _pil_to_tensor(pil).unsqueeze(0).to(_DEVICE)
76
+ with torch.inference_mode():
77
+ y = model.upscale(x)
78
+ result = _tensor_to_pil(y.squeeze(0).cpu())
79
+ buf = BytesIO()
80
+ result.save(buf, format="PNG")
81
+ buf.seek(0)
82
+ return buf.getvalue(), {"scale": scale, "input": f"{orig[0]}x{orig[1]}", "output": f"{result.width}x{result.height}", "resized": resized}
83
+
84
+
85
+ def _laplacian_variance(img: Image.Image) -> float:
86
+ lap = ndimage.laplace(np.array(img.convert("L"), dtype=np.float64))
87
+ return float(lap.var())
88
+
89
+
90
+ def _entropy(img: Image.Image) -> float:
91
+ hist = np.histogram(np.array(img.convert("L")), bins=256, range=(0, 256))[0]
92
+ hist = hist[hist > 0] / hist.sum()
93
+ return float(-np.sum(hist * np.log2(hist)))
94
+
95
+
96
+ def _edge_density(img: Image.Image) -> float:
97
+ arr = np.array(img.convert("L"), dtype=np.float64)
98
+ mag = np.hypot(ndimage.sobel(arr, axis=0), ndimage.sobel(arr, axis=1))
99
+ return float(np.mean(mag > mag.mean() + mag.std()))
100
+
101
+
102
+ def compute_metrics(img: Image.Image) -> dict:
103
+ return {"size": f"{img.width}x{img.height}", "sharpness": round(_laplacian_variance(img), 4), "entropy": round(_entropy(img), 4), "edge_density": round(_edge_density(img), 4), "contrast_std": round(float(np.array(img).std()), 2)}
104
+
105
+
106
+ def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
107
+ original = Image.open(BytesIO(image_bytes)).convert("RGB")
108
+ metrics = {"original": compute_metrics(original)}
109
+ upscaled = {}
110
+ for scale in MODELS_CONFIG:
111
+ t0 = time.perf_counter()
112
+ result_bytes, info = upscale_image(image_bytes, scale)
113
+ elapsed = time.perf_counter() - t0
114
+ img = Image.open(BytesIO(result_bytes)).convert("RGB")
115
+ upscaled[scale] = img
116
+ metrics[scale] = {**compute_metrics(img), "time_s": round(elapsed, 3), **info}
117
+ orig_r = original.resize(upscaled["2x"].size, Image.LANCZOS)
118
+ images = [orig_r, upscaled["2x"], upscaled["4x"]]
119
+ labels = ["Original", "MewZoom 2X", "MewZoom 4X"]
120
+ label_h, gap = 30, 8
121
+ max_h = max(i.height for i in images)
122
+ total_w = sum(i.width for i in images) + gap * (len(images) - 1)
123
+ canvas = Image.new("RGB", (total_w, max_h + label_h), (30, 30, 30))
124
+ draw = ImageDraw.Draw(canvas)
125
+ try:
126
+ font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
127
+ except Exception:
128
+ font = ImageFont.load_default()
129
+ x = 0
130
+ for img, lbl in zip(images, labels):
131
+ canvas.paste(img, (x, label_h))
132
+ bbox = draw.textbbox((0, 0), lbl, font=font)
133
+ tw = bbox[2] - bbox[0]
134
+ draw.text((x + (img.width - tw) // 2, 6), lbl, fill=(255, 255, 255), font=font)
135
+ x += img.width + gap
136
+ buf = BytesIO()
137
+ canvas.save(buf, format="PNG")
138
+ buf.seek(0)
139
+ return buf.getvalue(), metrics
140
+
141
+
142
+ @asynccontextmanager
143
+ async def lifespan(app: FastAPI):
144
+ logger.info("Starting on %s, loading models...", _DEVICE)
145
+ for scale in MODELS_CONFIG:
146
+ _load_model(scale)
147
+ yield
148
+
149
+
150
+ app = FastAPI(
151
+ title="Super-Resolution API",
152
+ description="MewZoom 2X/4X upscaling + comparison + quality metrics. InvSR requires GPU (not on free tier).",
153
+ version="1.0.0",
154
+ lifespan=lifespan,
155
+ )
156
+ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
157
+
158
+
159
+ @app.get("/")
160
+ @app.get("/health")
161
+ async def health():
162
+ return JSONResponse({"status": "healthy", "device": str(_DEVICE), "models": list(MODELS_CONFIG.keys()), "gpu": torch.cuda.is_available()})
163
+
164
+
165
+ @app.post("/upscale/2x")
166
+ async def route_2x(file: UploadFile = File(...)):
167
+ result, info = upscale_image(await file.read(), "2x")
168
+ return StreamingResponse(BytesIO(result), media_type="image/png", headers={"X-Info": json.dumps(info)})
169
+
170
+
171
+ @app.post("/upscale/4x")
172
+ async def route_4x(file: UploadFile = File(...)):
173
+ result, info = upscale_image(await file.read(), "4x")
174
+ return StreamingResponse(BytesIO(result), media_type="image/png", headers={"X-Info": json.dumps(info)})
175
+
176
+
177
+ @app.post("/upscale/compare")
178
+ async def route_compare(file: UploadFile = File(...), format: Literal["image", "json", "both"] = Query("both")):
179
+ img, metrics = generate_comparison(await file.read())
180
+ if format == "json":
181
+ return JSONResponse(metrics)
182
+ if format == "image":
183
+ return StreamingResponse(BytesIO(img), media_type="image/png")
184
+ return StreamingResponse(BytesIO(img), media_type="image/png", headers={"X-Metrics": json.dumps(metrics)})
185
+
186
+
187
+ @app.post("/upscale/metrics")
188
+ async def route_metrics(file: UploadFile = File(...)):
189
+ _, metrics = generate_comparison(await file.read())
190
+ return JSONResponse(metrics)
191
+
192
+
193
+ @app.post("/upscale/invsr")
194
+ async def route_invsr(file: UploadFile = File(...)):
195
+ raise HTTPException(400, detail="InvSR (diffusion 4X) needs GPU. This Space is CPU. Use /upscale/2x or /upscale/4x.")
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ fastapi>=0.110.0
2
+ uvicorn[standard]>=0.29.0
3
+ python-multipart>=0.0.9
4
+ mewzoom~=1.0.0
5
+ torch>=2.0.0
6
+ torchvision>=0.15.0
7
+ Pillow>=10.0.0
8
+ scipy>=1.10.0
9
+ numpy>=1.23.0