ihtesham0345 committed on
Commit
1ea9514
·
1 Parent(s): 3cab236

Add real InvSR model with CPU/float32 support (SD-Turbo + noise predictor)

Browse files
Files changed (3) hide show
  1. Dockerfile +5 -2
  2. app.py +219 -99
  3. requirements.txt +7 -0
Dockerfile CHANGED
@@ -1,7 +1,7 @@
1
  FROM python:3.12-slim
2
 
3
  RUN apt-get update && apt-get install -y --no-install-recommends \
4
- libgl1 libglib2.0-0 git curl \
5
  && rm -rf /var/lib/apt/lists/*
6
 
7
  RUN useradd -m -u 1000 user
@@ -9,9 +9,12 @@ USER user
9
  ENV PATH="/home/user/.local/bin:$PATH"
10
  WORKDIR /app
11
 
 
 
 
12
  COPY --chown=user requirements.txt .
13
  RUN pip install --no-cache-dir -r requirements.txt
14
 
15
  COPY --chown=user app.py .
16
 
17
- CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
 
1
  FROM python:3.12-slim
2
 
3
  RUN apt-get update && apt-get install -y --no-install-recommends \
4
+ libgl1 libglib2.0-0 libsm6 libxext6 libxrender1 git curl \
5
  && rm -rf /var/lib/apt/lists/*
6
 
7
  RUN useradd -m -u 1000 user
 
9
  ENV PATH="/home/user/.local/bin:$PATH"
10
  WORKDIR /app
11
 
12
+ # Clone InvSR source (custom diffusers pipeline + noise predictor support)
13
+ RUN git clone --depth 1 https://github.com/zsyOAOA/InvSR.git /app/InvSR
14
+
15
  COPY --chown=user requirements.txt .
16
  RUN pip install --no-cache-dir -r requirements.txt
17
 
18
  COPY --chown=user app.py .
19
 
20
+ CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860", "--timeout-keep-alive", "600"]
app.py CHANGED
@@ -1,6 +1,4 @@
1
- import json
2
- import logging
3
- import time
4
  from io import BytesIO
5
  from pathlib import Path
6
  from contextlib import asynccontextmanager
@@ -15,112 +13,243 @@ from fastapi import FastAPI, File, UploadFile, Query, HTTPException
15
  from fastapi.middleware.cors import CORSMiddleware
16
  from fastapi.responses import StreamingResponse, JSONResponse
17
 
18
- from torchvision.io import decode_image, ImageReadMode
19
- from torchvision.transforms.v2 import ToDtype, ToPILImage
20
-
21
  from mewzoom.model import MewZoom
22
 
23
- MODELS_CONFIG = {"2x": "andrewdalpino/MewZoom-V1-2X-Unet", "4x": "andrewdalpino/MewZoom-V1-4X-Unet"}
24
- MAX_DIM = {"2x": 2048, "4x": 1024}
 
25
  CACHE_DIR = Path("models")
26
 
27
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
28
  logger = logging.getLogger(__name__)
29
 
30
- _models: dict[str, MewZoom] = {}
31
- _image_to_tensor = ToDtype(torch.float32, scale=True)
32
- _tensor_to_pil = ToPILImage()
33
- _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
34
 
35
 
36
- def _load_model(scale: str) -> MewZoom:
37
- if scale in _models:
38
- return _models[scale]
39
- model_id = MODELS_CONFIG[scale]
40
- logger.info("Loading %s (%s) on %s ...", scale, model_id, _DEVICE)
41
  CACHE_DIR.mkdir(exist_ok=True)
42
- model = MewZoom.from_pretrained(model_id, cache_dir=str(CACHE_DIR))
43
- model.to(_DEVICE).eval()
44
- _models[scale] = model
45
- logger.info("%s loaded (%s params)", scale, f"{sum(p.numel() for p in model.parameters()):,}")
46
- return model
 
 
 
 
 
47
 
48
 
49
  def _resize_if_needed(img: Image.Image, scale: str) -> tuple[Image.Image, bool]:
50
- max_dim = MAX_DIM[scale]
51
  w, h = img.size
52
  if max(w, h) <= max_dim:
53
  return img, False
54
- ratio = max_dim / max(w, h)
55
- return img.resize((int(w * ratio), int(h * ratio)), Image.LANCZOS), True
56
-
57
-
58
- def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
59
- arr = np.array(img, dtype=np.float32) / 255.0
60
- return torch.from_numpy(arr).permute(2, 0, 1)
61
 
62
 
63
- def upscale_image(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
64
- model = _load_model(scale)
65
  factor = int(scale[0])
66
- try:
67
- pil = Image.open(BytesIO(image_bytes)).convert("RGB")
68
- except Exception as e:
69
- raise HTTPException(400, f"Bad image: {e}")
70
  orig = (pil.width, pil.height)
71
  pil, resized = _resize_if_needed(pil, scale)
72
  out_mp = pil.width * factor * pil.height * factor / 1e6
73
  if out_mp > 64:
74
- raise HTTPException(400, f"Output too large ({out_mp:.0f}MP). Use smaller image.")
75
  x = _pil_to_tensor(pil).unsqueeze(0).to(_DEVICE)
76
  with torch.inference_mode():
77
  y = model.upscale(x)
78
- result = _tensor_to_pil(y.squeeze(0).cpu())
79
- buf = BytesIO()
80
- result.save(buf, format="PNG")
81
- buf.seek(0)
82
  return buf.getvalue(), {"scale": scale, "input": f"{orig[0]}x{orig[1]}", "output": f"{result.width}x{result.height}", "resized": resized}
83
 
84
 
85
- def _laplacian_variance(img: Image.Image) -> float:
86
- lap = ndimage.laplace(np.array(img.convert("L"), dtype=np.float64))
87
- return float(lap.var())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
 
 
89
 
90
- def _entropy(img: Image.Image) -> float:
91
- hist = np.histogram(np.array(img.convert("L")), bins=256, range=(0, 256))[0]
92
- hist = hist[hist > 0] / hist.sum()
93
- return float(-np.sum(hist * np.log2(hist)))
94
 
 
 
 
 
95
 
96
- def _edge_density(img: Image.Image) -> float:
97
- arr = np.array(img.convert("L"), dtype=np.float64)
98
- mag = np.hypot(ndimage.sobel(arr, axis=0), ndimage.sobel(arr, axis=1))
99
- return float(np.mean(mag > mag.mean() + mag.std()))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
 
 
 
 
101
 
 
 
 
 
 
 
 
 
102
  def compute_metrics(img: Image.Image) -> dict:
103
- return {"size": f"{img.width}x{img.height}", "sharpness": round(_laplacian_variance(img), 4), "entropy": round(_entropy(img), 4), "edge_density": round(_edge_density(img), 4), "contrast_std": round(float(np.array(img).std()), 2)}
 
 
 
 
 
 
 
 
 
 
 
104
 
105
 
106
  def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
107
  original = Image.open(BytesIO(image_bytes)).convert("RGB")
108
  metrics = {"original": compute_metrics(original)}
109
  upscaled = {}
110
- for scale in MODELS_CONFIG:
111
  t0 = time.perf_counter()
112
- result_bytes, info = upscale_image(image_bytes, scale)
113
- elapsed = time.perf_counter() - t0
114
- img = Image.open(BytesIO(result_bytes)).convert("RGB")
115
  upscaled[scale] = img
116
- metrics[scale] = {**compute_metrics(img), "time_s": round(elapsed, 3), **info}
117
  orig_r = original.resize(upscaled["2x"].size, Image.LANCZOS)
118
  images = [orig_r, upscaled["2x"], upscaled["4x"]]
119
  labels = ["Original", "MewZoom 2X", "MewZoom 4X"]
120
- label_h, gap = 30, 8
121
- max_h = max(i.height for i in images)
122
- total_w = sum(i.width for i in images) + gap * (len(images) - 1)
123
- canvas = Image.new("RGB", (total_w, max_h + label_h), (30, 30, 30))
124
  draw = ImageDraw.Draw(canvas)
125
  try:
126
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
@@ -128,29 +257,28 @@ def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
128
  font = ImageFont.load_default()
129
  x = 0
130
  for img, lbl in zip(images, labels):
131
- canvas.paste(img, (x, label_h))
132
- bbox = draw.textbbox((0, 0), lbl, font=font)
133
- tw = bbox[2] - bbox[0]
134
- draw.text((x + (img.width - tw) // 2, 6), lbl, fill=(255, 255, 255), font=font)
135
  x += img.width + gap
136
- buf = BytesIO()
137
- canvas.save(buf, format="PNG")
138
- buf.seek(0)
139
  return buf.getvalue(), metrics
140
 
141
 
 
142
  @asynccontextmanager
143
  async def lifespan(app: FastAPI):
144
- logger.info("Starting on %s, loading models...", _DEVICE)
145
- for scale in MODELS_CONFIG:
146
- _load_model(scale)
147
  yield
148
 
149
 
150
  app = FastAPI(
151
  title="Super-Resolution API",
152
- description="MewZoom 2X/4X upscaling + comparison + quality metrics. InvSR requires GPU (not on free tier).",
153
- version="1.0.0",
154
  lifespan=lifespan,
155
  )
156
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
@@ -159,50 +287,42 @@ app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], all
159
  @app.get("/")
160
  @app.get("/health")
161
  async def health():
162
- return JSONResponse({"status": "healthy", "device": str(_DEVICE), "models": list(MODELS_CONFIG.keys()), "gpu": torch.cuda.is_available()})
163
 
164
 
165
  @app.post("/upscale/2x")
166
  async def route_2x(file: UploadFile = File(...)):
167
- result, info = upscale_image(await file.read(), "2x")
168
- return StreamingResponse(BytesIO(result), media_type="image/png", headers={"X-Info": json.dumps(info)})
169
 
170
 
171
  @app.post("/upscale/4x")
172
  async def route_4x(file: UploadFile = File(...)):
173
- result, info = upscale_image(await file.read(), "4x")
174
- return StreamingResponse(BytesIO(result), media_type="image/png", headers={"X-Info": json.dumps(info)})
175
 
176
 
177
  @app.post("/upscale/compare")
178
  async def route_compare(file: UploadFile = File(...), format: Literal["image", "json", "both"] = Query("both")):
179
- img, metrics = generate_comparison(await file.read())
180
- if format == "json":
181
- return JSONResponse(metrics)
182
- if format == "image":
183
- return StreamingResponse(BytesIO(img), media_type="image/png")
184
- return StreamingResponse(BytesIO(img), media_type="image/png", headers={"X-Metrics": json.dumps(metrics)})
185
 
186
 
187
  @app.post("/upscale/metrics")
188
  async def route_metrics(file: UploadFile = File(...)):
189
- _, metrics = generate_comparison(await file.read())
190
- return JSONResponse(metrics)
191
 
192
 
193
  @app.post("/upscale/invsr")
194
  async def route_invsr(
195
  file: UploadFile = File(...),
196
- num_steps: int = Query(1, ge=1, le=5),
197
- tile_size: int = Query(128, ge=64, le=512),
198
  ):
199
- if torch.cuda.is_available():
200
- raise HTTPException(501, detail="InvSR GPU pipeline not bundled in this Space. Use the Colab notebook.")
201
- # Fallback to MewZoom 4X on CPU
202
- logger.info("InvSR endpoint called on CPU β€” falling back to MewZoom 4X")
203
- result, info = upscale_image(await file.read(), "4x")
204
- info["fallback"] = "InvSR not available on CPU, used MewZoom 4X instead"
205
- return StreamingResponse(
206
- BytesIO(result), media_type="image/png",
207
- headers={"X-Info": json.dumps(info)},
208
- )
 
1
+ import json, logging, time, sys, os, tempfile
 
 
2
  from io import BytesIO
3
  from pathlib import Path
4
  from contextlib import asynccontextmanager
 
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import StreamingResponse, JSONResponse
15
 
 
 
 
16
  from mewzoom.model import MewZoom
17
 
18
+ # ── Config ──────────────────────────────────────────────────
19
+ MEWZOOM_MODELS = {"2x": "andrewdalpino/MewZoom-V1-2X-Unet", "4x": "andrewdalpino/MewZoom-V1-4X-Unet"}
20
+ MAX_DIM = {"2x": 2048, "4x": 1024, "invsr": 256}
21
  CACHE_DIR = Path("models")
22
 
23
  logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
24
  logger = logging.getLogger(__name__)
25
 
26
+ _DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
27
+ logger.info("Device: %s", _DEVICE)
28
+
29
+ # ── MewZoom Models ──────────────────────────────────────────
30
+ _mz_models: dict[str, MewZoom] = {}
31
 
32
 
33
def _load_mewzoom(scale: str) -> MewZoom:
    """Return the MewZoom model for ``scale``, loading it on first use.

    Loaded models are memoised in the module-level ``_mz_models`` dict, so
    each scale is fetched and moved to ``_DEVICE`` only once per process.

    Args:
        scale: Key into ``MEWZOOM_MODELS`` (a ``KeyError`` propagates for an
            unknown key).

    Returns:
        The model in eval mode on ``_DEVICE``.
    """
    try:
        return _mz_models[scale]
    except KeyError:
        pass  # not loaded yet; fall through to the slow path

    repo_id = MEWZOOM_MODELS[scale]
    logger.info("Loading MewZoom %s (%s) ...", scale, repo_id)

    CACHE_DIR.mkdir(exist_ok=True)
    model = MewZoom.from_pretrained(repo_id, cache_dir=str(CACHE_DIR))
    model.to(_DEVICE).eval()
    _mz_models[scale] = model

    param_count = sum(p.numel() for p in model.parameters())
    logger.info("MewZoom %s ready (%s params)", scale, f"{param_count:,}")
    return model
44
+
45
+
46
+ def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
47
+ arr = np.array(img, dtype=np.float32) / 255.0
48
+ return torch.from_numpy(arr).permute(2, 0, 1)
49
 
50
 
51
  def _resize_if_needed(img: Image.Image, scale: str) -> tuple[Image.Image, bool]:
52
+ max_dim = MAX_DIM.get(scale, 1024)
53
  w, h = img.size
54
  if max(w, h) <= max_dim:
55
  return img, False
56
+ r = max_dim / max(w, h)
57
+ return img.resize((int(w * r), int(h * r)), Image.LANCZOS), True
 
 
 
 
 
58
 
59
 
60
def upscale_mewzoom(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
    """Upscale an encoded image with the MewZoom model for ``scale``.

    Args:
        image_bytes: Encoded image data as uploaded by the client.
        scale: Model key such as ``"2x"`` or ``"4x"``; its first character is
            parsed as the integer upscaling factor.

    Returns:
        ``(png_bytes, info)`` where ``info`` records the scale, the original
        and output dimensions, and whether the input was downscaled first.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image, or
            when the projected output would exceed 64 megapixels.
    """
    model = _load_mewzoom(scale)
    factor = int(scale[0])

    # Undecodable uploads are a client error, not a server fault: report 400
    # instead of letting the decoder exception surface as a 500.
    try:
        pil = Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Bad image: {e}") from e

    orig = (pil.width, pil.height)
    pil, resized = _resize_if_needed(pil, scale)

    out_mp = pil.width * factor * pil.height * factor / 1e6
    if out_mp > 64:
        raise HTTPException(400, f"Output too large ({out_mp:.0f}MP)")

    x = _pil_to_tensor(pil).unsqueeze(0).to(_DEVICE)
    with torch.inference_mode():
        y = model.upscale(x)

    # Drop the batch axis, move channels last, and rescale to 8-bit.
    result_np = (y.squeeze(0).permute(1, 2, 0).cpu().numpy() * 255).clip(0, 255).astype(np.uint8)
    result = Image.fromarray(result_np)

    buf = BytesIO()
    result.save(buf, format="PNG")
    return buf.getvalue(), {"scale": scale, "input": f"{orig[0]}x{orig[1]}", "output": f"{result.width}x{result.height}", "resized": resized}
76
 
77
 
78
+ # ── InvSR Model (Diffusion 4X) ──────────────────────────────
79
+ _INVSR_PATH = Path("/app/InvSR")
80
+ _sampler_invsr = None
81
+
82
+
83
+ def _patch_invsr():
84
+ """Patch InvSR source for CPU/float32 support."""
85
+ p = _INVSR_PATH / "sampler_invsr.py"
86
+ code = p.read_text()
87
+
88
+ # Remove basicsr import chain (not needed for inference)
89
+ code = code.replace("from datapipe.datasets import create_dataset", "")
90
+
91
+ # Add device param to BaseSampler
92
+ old_init = """class BaseSampler:
93
+ def __init__(self, configs):
94
+ '''
95
+ Input:
96
+ configs: config, see the yaml file in folder ./configs/
97
+ configs.sampler_config.{start_timesteps, padding_mod, seed, sf, num_sample_steps}
98
+ seed: int, random seed
99
+ '''
100
+ self.configs = configs
101
+
102
+ self.setup_seed()
103
+
104
+ self.build_model()
105
+
106
+ def setup_seed(self, seed=None):
107
+ seed = self.configs.seed if seed is None else seed
108
+ random.seed(seed)
109
+ np.random.seed(seed)
110
+ torch.manual_seed(seed)
111
+ torch.cuda.manual_seed_all(seed)"""
112
+
113
+ new_init = """class BaseSampler:
114
+ def __init__(self, configs, device='auto'):
115
+ self.configs = configs
116
+ if device == 'auto':
117
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
118
+ self.device = torch.device(device)
119
+ self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32
120
+ self.setup_seed()
121
+ self.build_model()
122
+
123
+ def setup_seed(self, seed=None):
124
+ seed = self.configs.seed if seed is None else seed
125
+ random.seed(seed)
126
+ np.random.seed(seed)
127
+ torch.manual_seed(seed)
128
+ if torch.cuda.is_available():
129
+ torch.cuda.manual_seed_all(seed)"""
130
+
131
+ code = code.replace(old_init, new_init)
132
+
133
+ # Replace .cuda() and .type(torch.float16) with device-aware versions
134
+ code = code.replace('sd_pipe.to(f"cuda")', "sd_pipe.to(self.device)")
135
+ code = code.replace("model_start.cuda()", "model_start.to(self.device)")
136
+ code = code.replace('map_location=f"cuda"', "map_location=self.device")
137
+ code = code.replace("im_cond.type(torch.float16)", "im_cond.type(self.dtype)")
138
+ code = code.replace(".type(torch.float16)", ".type(self.dtype)")
139
+ code = code.replace("data['lq'].cuda()", "data['lq'].to(self.device)")
140
+ code = code.replace("util_image.img2tensor(im_cond).cuda()", "util_image.img2tensor(im_cond).to(self.device)")
141
+
142
+ # Lazy import create_dataset in inference method
143
+ code = code.replace(
144
+ "if in_path.is_dir():\n data_config",
145
+ "if in_path.is_dir():\n from datapipe.datasets import create_dataset\n data_config",
146
+ )
147
 
148
+ p.write_text(code)
149
+ logger.info("InvSR sampler patched for CPU/float32")
150
 
 
 
 
 
151
 
152
def _load_invsr():
    """Build the InvSR sampler on first use and return the cached instance.

    On the first call this patches the cloned InvSR sources, puts the InvSR
    directories at the front of ``sys.path``, loads the SD-Turbo sample
    config, downloads the noise-predictor checkpoint if it is not cached yet,
    and constructs ``InvSamplerSR``.  Later calls return the object stored in
    the module-level ``_sampler_invsr``.
    """
    global _sampler_invsr
    if _sampler_invsr is not None:
        return _sampler_invsr

    _patch_invsr()
    # Inserted in this order so that "src" ends up ahead of the repo root.
    for import_root in (_INVSR_PATH, _INVSR_PATH / "src"):
        sys.path.insert(0, str(import_root))

    from omegaconf import OmegaConf
    from sampler_invsr import InvSamplerSR

    running_on_cpu = _DEVICE == "cpu"
    invsr_cache = CACHE_DIR / "invsr"

    cfg = OmegaConf.load(str(_INVSR_PATH / "configs" / "sample-sd-turbo.yaml"))
    cfg.sd_pipe.params.torch_dtype = "torch.float32" if running_on_cpu else "torch.float16"
    cfg.sd_pipe.params.cache_dir = str(invsr_cache)
    CACHE_DIR.mkdir(exist_ok=True)

    # Fetch the noise predictor checkpoint unless a previous run cached it.
    from torch.hub import download_url_to_file
    ckpt = invsr_cache / "noise_predictor_sd_turbo_v5.pth"
    ckpt.parent.mkdir(exist_ok=True)
    if not ckpt.exists():
        logger.info("Downloading noise predictor (~800MB)...")
        download_url_to_file(
            "https://huggingface.co/OAOA/InvSR/resolve/main/noise_predictor_sd_turbo_v5.pth",
            str(ckpt),
            progress=True,
        )
    cfg.model_start.ckpt_path = str(ckpt)

    # Sampling defaults for this service.
    cfg.timesteps = [200]
    cfg.bs = 1
    cfg.tiled_vae = True
    cfg.color_fix = "wavelet"
    cfg.basesr.chopping.pch_size = 128
    cfg.basesr.chopping.extra_bs = 8

    logger.info("Loading InvSR sampler (SD-Turbo ~5GB download on first run)...")
    _sampler_invsr = InvSamplerSR(cfg, device="auto")
    if running_on_cpu:
        # Cast the whole pipeline to float32 when running on CPU.
        _sampler_invsr.sd_pipe = _sampler_invsr.sd_pipe.to(dtype=torch.float32)
    logger.info("InvSR ready on %s", _DEVICE)
    return _sampler_invsr
191
+
192
+
193
def upscale_invsr(image_bytes: bytes, num_steps: int = 1) -> bytes:
    """Upscale an encoded image with the InvSR diffusion sampler.

    Args:
        image_bytes: Encoded image data as uploaded by the client.
        num_steps: Number of sampling steps (1-5); selects the timestep
            schedule.  Unknown values fall back to the single-step schedule.

    Returns:
        The upscaled image encoded as PNG bytes.

    Raises:
        HTTPException: 400 when the bytes cannot be decoded as an image.
    """
    sampler = _load_invsr()

    # The loader already registers this path on first use; only add it when it
    # is missing so that sys.path does not grow by one entry per request.
    invsr_root = str(_INVSR_PATH)
    if invsr_root not in sys.path:
        sys.path.insert(0, invsr_root)
    from utils import util_image

    try:
        pil = Image.open(BytesIO(image_bytes)).convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Bad image: {e}") from e

    # Enforce the "invsr" entry of MAX_DIM so oversized uploads are bounded
    # before they reach the diffusion pipeline.
    pil, _ = _resize_if_needed(pil, "invsr")

    # InvSR's reader takes a file path, so round-trip through a temp PNG.
    tmp = tempfile.NamedTemporaryFile(suffix=".png", delete=False)
    try:
        with tmp:  # guarantees the handle is closed even if the write fails
            pil.save(tmp, format="PNG")
        im = util_image.imread(tmp.name, chn="rgb", dtype="float32")
    finally:
        os.unlink(tmp.name)

    im_cond = util_image.img2tensor(im).to(sampler.device)

    steps_map = {
        1: [200],
        2: [200, 100],
        3: [200, 100, 50],
        4: [200, 150, 100, 50],
        5: [250, 200, 150, 100, 50],
    }
    sampler.configs.timesteps = steps_map.get(num_steps, [200])
    sampler.configs.basesr.chopping.pch_size = 128

    # NOTE(review): assumes sample_func returns a batched channels-last float
    # array scaled to [0, 1] — confirm against the InvSR sampler.
    result = sampler.sample_func(im_cond).squeeze(0)
    result = (result * 255).clip(0, 255).astype(np.uint8)

    buf = BytesIO()
    Image.fromarray(result).save(buf, format="PNG")
    return buf.getvalue()
217
+
218
+
219
+ # ── Metrics ─────────────────────────────────────────────────
220
  def compute_metrics(img: Image.Image) -> dict:
221
+ arr = np.array(img.convert("L"), dtype=np.float64)
222
+ lap = ndimage.laplace(arr)
223
+ hist = np.histogram(arr, bins=256, range=(0, 256))[0]
224
+ hist = hist[hist > 0] / hist.sum()
225
+ mag = np.hypot(ndimage.sobel(arr, axis=0), ndimage.sobel(arr, axis=1))
226
+ return {
227
+ "size": f"{img.width}x{img.height}",
228
+ "sharpness": round(float(lap.var()), 4),
229
+ "entropy": round(float(-np.sum(hist * np.log2(hist))), 4),
230
+ "edge_density": round(float(np.mean(mag > mag.mean() + mag.std())), 4),
231
+ "contrast_std": round(float(np.array(img).std()), 2),
232
+ }
233
 
234
 
235
  def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
236
  original = Image.open(BytesIO(image_bytes)).convert("RGB")
237
  metrics = {"original": compute_metrics(original)}
238
  upscaled = {}
239
+ for scale in MEWZOOM_MODELS:
240
  t0 = time.perf_counter()
241
+ rb, info = upscale_mewzoom(image_bytes, scale)
242
+ t = time.perf_counter() - t0
243
+ img = Image.open(BytesIO(rb)).convert("RGB")
244
  upscaled[scale] = img
245
+ metrics[scale] = {**compute_metrics(img), "time_s": round(t, 3), **info}
246
  orig_r = original.resize(upscaled["2x"].size, Image.LANCZOS)
247
  images = [orig_r, upscaled["2x"], upscaled["4x"]]
248
  labels = ["Original", "MewZoom 2X", "MewZoom 4X"]
249
+ lh, gap = 30, 8
250
+ mh = max(i.height for i in images)
251
+ tw = sum(i.width for i in images) + gap * (len(images) - 1)
252
+ canvas = Image.new("RGB", (tw, mh + lh), (30, 30, 30))
253
  draw = ImageDraw.Draw(canvas)
254
  try:
255
  font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
 
257
  font = ImageFont.load_default()
258
  x = 0
259
  for img, lbl in zip(images, labels):
260
+ canvas.paste(img, (x, lh))
261
+ bb = draw.textbbox((0, 0), lbl, font=font)
262
+ tw2 = bb[2] - bb[0]
263
+ draw.text((x + (img.width - tw2) // 2, 6), lbl, fill=(255, 255, 255), font=font)
264
  x += img.width + gap
265
+ buf = BytesIO(); canvas.save(buf, format="PNG"); buf.seek(0)
 
 
266
  return buf.getvalue(), metrics
267
 
268
 
269
+ # ── FastAPI App ─────────────────────────────────────────────
270
  @asynccontextmanager
271
  async def lifespan(app: FastAPI):
272
+ logger.info("Loading MewZoom models...")
273
+ for s in MEWZOOM_MODELS:
274
+ _load_mewzoom(s)
275
  yield
276
 
277
 
278
  app = FastAPI(
279
  title="Super-Resolution API",
280
+ description="MewZoom 2X/4X + InvSR 4X diffusion + comparison + quality metrics",
281
+ version="2.0.0",
282
  lifespan=lifespan,
283
  )
284
  app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
 
287
  @app.get("/")
288
  @app.get("/health")
289
  async def health():
290
+ return JSONResponse({"status": "healthy", "device": _DEVICE, "models": ["2x", "4x", "invsr"], "gpu": torch.cuda.is_available()})
291
 
292
 
293
  @app.post("/upscale/2x")
294
  async def route_2x(file: UploadFile = File(...)):
295
+ r, i = upscale_mewzoom(await file.read(), "2x")
296
+ return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})
297
 
298
 
299
  @app.post("/upscale/4x")
300
  async def route_4x(file: UploadFile = File(...)):
301
+ r, i = upscale_mewzoom(await file.read(), "4x")
302
+ return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})
303
 
304
 
305
  @app.post("/upscale/compare")
306
  async def route_compare(file: UploadFile = File(...), format: Literal["image", "json", "both"] = Query("both")):
307
+ img, m = generate_comparison(await file.read())
308
+ if format == "json": return JSONResponse(m)
309
+ if format == "image": return StreamingResponse(BytesIO(img), media_type="image/png")
310
+ return StreamingResponse(BytesIO(img), media_type="image/png", headers={"X-Metrics": json.dumps(m)})
 
 
311
 
312
 
313
  @app.post("/upscale/metrics")
314
  async def route_metrics(file: UploadFile = File(...)):
315
+ _, m = generate_comparison(await file.read())
316
+ return JSONResponse(m)
317
 
318
 
319
  @app.post("/upscale/invsr")
320
  async def route_invsr(
321
  file: UploadFile = File(...),
322
+ num_steps: int = Query(1, ge=1, le=5, description="1=fast, 5=best quality"),
 
323
  ):
324
+ try:
325
+ result = upscale_invsr(await file.read(), num_steps=num_steps)
326
+ except Exception as e:
327
+ raise HTTPException(500, detail=f"InvSR failed: {e}")
328
+ return StreamingResponse(BytesIO(result), media_type="image/png")
 
 
 
 
 
requirements.txt CHANGED
@@ -7,3 +7,10 @@ torchvision>=0.15.0
7
  Pillow>=10.0.0
8
  scipy>=1.10.0
9
  numpy>=1.23.0
 
 
 
 
 
 
 
 
7
  Pillow>=10.0.0
8
  scipy>=1.10.0
9
  numpy>=1.23.0
10
+ diffusers>=0.28.0
11
+ transformers>=4.37.0
12
+ accelerate>=0.28.0
13
+ omegaconf>=2.3.0
14
+ loguru>=0.7.0
15
+ einops>=0.7.0
16
+ opencv-python-headless>=4.8.0