"""Super-resolution HTTP API.

Exposes MewZoom 2X/4X upscaling, InvSR diffusion upscaling (processed as
background jobs polled by id), an ESRGAN 4x model exposed as "finegrain",
a side-by-side comparison image, and simple no-reference quality metrics.
"""

import json
import logging
import os
import re
import sys
import tempfile
import threading
import time
from contextlib import asynccontextmanager
from io import BytesIO
from pathlib import Path
from typing import Literal

import numpy as np
import torch
from fastapi import FastAPI, File, HTTPException, Query, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from PIL import Image, ImageDraw, ImageFont
from scipy import ndimage

from mewzoom.model import MewZoom

logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s")
logger = logging.getLogger(__name__)

MEWZOOM_MODELS = {"2x": "andrewdalpino/MewZoom-V1-2X-Unet", "4x": "andrewdalpino/MewZoom-V1-4X-Unet"}
# Longest input side (pixels) per pipeline; larger inputs are downscaled first.
MAX_DIM = {"2x": 2048, "4x": 1024, "invsr": 256}
CACHE_DIR = Path("models")
# Largest output (megapixels) a request may produce before it is rejected with 400.
_MAX_OUTPUT_MP = 64

_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
logger.info("Device: %s", _DEVICE)


# ── Shared helpers ───────────────────────────────────────────
def _open_rgb(image_bytes: bytes) -> Image.Image:
    """Decode uploaded bytes into an RGB PIL image.

    Raises:
        HTTPException(400): the payload is empty or cannot be decoded, so the
            client gets a client error instead of an opaque 500.
    """
    if not image_bytes:
        raise HTTPException(400, "Empty upload")
    try:
        with Image.open(BytesIO(image_bytes)) as src:
            # convert() returns a new, fully loaded image, so closing src is safe.
            return src.convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as e:
        raise HTTPException(400, f"Invalid image: {e}") from e


def _to_uint8(arr) -> np.ndarray:
    """Quantize a float image with values in [0, 1] to uint8.

    Rounds to nearest instead of truncating, avoiding a systematic
    half-level darkening of every pixel.
    """
    scaled = np.rint(np.asarray(arr, dtype=np.float32) * 255.0)
    return np.clip(scaled, 0, 255).astype(np.uint8)


def _png_bytes(img: Image.Image) -> bytes:
    """Encode *img* as PNG and return the bytes."""
    buf = BytesIO()
    img.save(buf, format="PNG")
    return buf.getvalue()


def _check_output_size(width: int, height: int, factor: int) -> None:
    """Reject requests whose upscaled output would exceed ``_MAX_OUTPUT_MP``.

    Raises:
        HTTPException(400): the output would be too large.
    """
    out_mp = width * factor * height * factor / 1e6
    if out_mp > _MAX_OUTPUT_MP:
        raise HTTPException(400, f"Output too large ({out_mp:.0f}MP)")


def _ensure_on_sys_path(path: Path) -> None:
    """Prepend *path* to ``sys.path`` unless it is already there.

    Prevents ``sys.path`` from growing by one entry on every request.
    """
    entry = str(path)
    if entry not in sys.path:
        sys.path.insert(0, entry)


# ── MewZoom ──────────────────────────────────────────────────
_mz_models: dict[str, MewZoom] = {}  # scale key ("2x"/"4x") -> loaded model


def _load_mewzoom(scale: str) -> MewZoom:
    """Return the MewZoom model for *scale*, loading and caching it on first use.

    The model is moved to ``_DEVICE`` and put in eval mode.
    Raises KeyError if *scale* is not a key of ``MEWZOOM_MODELS``.
    """
    cached = _mz_models.get(scale)
    if cached is not None:
        return cached
    repo_id = MEWZOOM_MODELS[scale]
    logger.info("Loading MewZoom %s ...", scale)
    CACHE_DIR.mkdir(exist_ok=True)
    model = MewZoom.from_pretrained(repo_id, cache_dir=str(CACHE_DIR))
    model.to(_DEVICE).eval()
    _mz_models[scale] = model
    n_params = sum(p.numel() for p in model.parameters())
    logger.info("MewZoom %s ready (%s params)", scale, f"{n_params:,}")
    return model


def _pil_to_tensor(img: Image.Image) -> torch.Tensor:
    """Convert an RGB PIL image into a float32 (C, H, W) tensor scaled to [0, 1]."""
    hwc = np.array(img, dtype=np.float32) / 255.0
    return torch.from_numpy(hwc).permute(2, 0, 1)


def _resize_if_needed(img: Image.Image, scale: str) -> tuple[Image.Image, bool]:
    """Downscale *img* (LANCZOS) so its longest side is at most ``MAX_DIM[scale]``.

    Unknown *scale* keys fall back to 1024. Each side is kept at least 1 pixel,
    so extremely thin images do not collapse to a zero-sized image.

    Returns:
        (image, resized) where *resized* tells whether a resize happened.
    """
    limit = MAX_DIM.get(scale, 1024)
    w, h = img.size
    longest = max(w, h)
    if longest <= limit:
        return img, False
    ratio = limit / longest
    new_size = (max(1, round(w * ratio)), max(1, round(h * ratio)))
    return img.resize(new_size, Image.LANCZOS), True


def upscale_mewzoom(image_bytes: bytes, scale: str) -> tuple[bytes, dict]:
    """Upscale an image with MewZoom.

    Args:
        image_bytes: encoded input image.
        scale: "2x" or "4x"; the leading digit is the upscale factor.

    Returns:
        (PNG bytes, info dict with scale, original input size, output size and
        whether the input was downscaled first).

    Raises:
        HTTPException(400): invalid image or output larger than ``_MAX_OUTPUT_MP``.
    """
    model = _load_mewzoom(scale)
    factor = int(scale[0])
    pil = _open_rgb(image_bytes)
    orig_w, orig_h = pil.size
    pil, resized = _resize_if_needed(pil, scale)
    _check_output_size(pil.width, pil.height, factor)

    x = _pil_to_tensor(pil).unsqueeze(0).to(_DEVICE)
    with torch.inference_mode():
        y = model.upscale(x)
    hwc = y.squeeze(0).permute(1, 2, 0).float().cpu().numpy()
    result = Image.fromarray(_to_uint8(hwc))

    info = {
        "scale": scale,
        "input": f"{orig_w}x{orig_h}",
        "output": f"{result.width}x{result.height}",
        "resized": resized,
    }
    return _png_bytes(result), info


# ── InvSR ────────────────────────────────────────────────────
_INVSR_PATH = Path("/app/InvSR")
_sampler_invsr = None
_invsr_status = "not_loaded"
_invsr_error = None
_invsr_jobs: dict[str, dict] = {}
_job_counter = 0
# Serializes InvSR inference: each call rewrites fields of the shared sampler config.
_invsr_lock = threading.Lock()
# num_steps -> diffusion timesteps used for sampling.
_INVSR_TIMESTEPS = {
    1: [200],
    2: [200, 100],
    3: [200, 100, 50],
    4: [200, 150, 100, 50],
    5: [250, 200, 150, 100, 50],
}
_INVSR_PATCH_MARKER = "# superres-api: InvSR patched for device-agnostic inference"


def _apply_patch(code: str, pattern: str, repl, label: str) -> str:
    """Apply one multiline regex substitution, warning if it matches nothing."""
    patched, count = re.subn(pattern, repl, code, flags=re.MULTILINE)
    if count == 0:
        logger.warning("InvSR patch '%s' matched nothing; upstream source may have changed", label)
    return patched


def _device_setup_block(match: re.Match) -> str:
    """Build the replacement for BaseSampler's init body, keeping its indentation."""
    indent = match.group("indent")
    lines = [
        "self.configs = configs",
        "if device == 'auto':",
        "    device = 'cuda' if torch.cuda.is_available() else 'cpu'",
        "self.device = torch.device(device)",
        "self.dtype = torch.float16 if self.device.type == 'cuda' else torch.float32",
        "self.setup_seed()",
        "self.build_model()",
    ]
    return "\n".join(indent + line for line in lines)


def _patch_invsr_source() -> None:
    """Rewrite InvSR's ``sampler_invsr.py`` in place so it runs on CPU or CUDA.

    The changes made to the file are:
    - add a ``device='auto'`` argument to ``BaseSampler.__init__``, which
      derives ``self.device`` and ``self.dtype`` from it;
    - replace hard-coded ``.cuda()`` / ``"cuda"`` / ``float16`` uses with
      ``self.device`` / ``self.dtype``;
    - guard ``torch.cuda.manual_seed_all`` so it only runs when CUDA exists;
    - move the ``datapipe`` import into the ``in_path.is_dir()`` branch.

    The patched file is written back to disk, so the function is idempotent.
    A marker, or the already-patched signature left by earlier versions of
    this function, makes a re-run a no-op. Otherwise a second pass would
    corrupt the file, for example by nesting the seed guard.
    """
    path = _INVSR_PATH / "sampler_invsr.py"
    code = path.read_text()
    if _INVSR_PATCH_MARKER in code or "def __init__(self, configs, device='auto')" in code:
        logger.info("InvSR source already patched")
        return

    # Drop the top-level (column 0) datapipe import ...
    code = _apply_patch(
        code, r"^from datapipe\.datasets import create_dataset[ \t]*\n", "", "drop top-level datapipe import"
    )
    # ... and re-import it inside the directory-input branch instead.
    code = _apply_patch(
        code,
        r"^(?P<head>[ \t]*if in_path\.is_dir\(\):[ \t]*\n)(?P<indent>[ \t]*)(?=data_config)",
        lambda m: f"{m['head']}{m['indent']}from datapipe.datasets import create_dataset\n{m['indent']}",
        "lazy datapipe import",
    )
    code = _apply_patch(
        code,
        r"(class BaseSampler\b[^\n]*:\s*def __init__\(self, configs)\)",
        r"\1, device='auto')",
        "BaseSampler device argument",
    )
    code = _apply_patch(
        code,
        r"^(?P<indent>[ \t]*)self\.configs = configs\s+self\.setup_seed\(\)\s+self\.build_model\(\)",
        _device_setup_block,
        "BaseSampler device setup",
    )
    code = _apply_patch(
        code,
        r"^(?P<indent>[ \t]*)torch\.cuda\.manual_seed_all\(seed\)",
        lambda m: (
            f"{m['indent']}if torch.cuda.is_available():\n"
            f"{m['indent']}    torch.cuda.manual_seed_all(seed)"
        ),
        "guard manual_seed_all",
    )
    literal_patches = [
        ('sd_pipe.to(f"cuda")', "sd_pipe.to(self.device)"),
        ("model_start.cuda()", "model_start.to(self.device)"),
        ('map_location=f"cuda"', "map_location=self.device"),
        (".type(torch.float16)", ".type(self.dtype)"),
        ("data['lq'].cuda()", "data['lq'].to(self.device)"),
        ("util_image.img2tensor(im_cond).cuda()", "util_image.img2tensor(im_cond).to(self.device)"),
    ]
    for old, new in literal_patches:
        # A function replacement keeps *new* from being parsed for backreferences.
        code = _apply_patch(code, re.escape(old), lambda _m, new=new: new, old)

    # Marker goes at the end so a shebang/encoding line at the top is untouched.
    path.write_text(code.rstrip("\n") + "\n\n" + _INVSR_PATCH_MARKER + "\n")
    logger.info("InvSR source patched for CPU")


def _load_invsr_sync() -> None:
    """Patch, download and load the InvSR sampler. Started in a background thread.

    Progress is published through ``_invsr_status``, which /health reports:
    patching -> downloading_sd_turbo -> downloading_noise_pred -> loading -> ready.
    On failure the status becomes "error" and the message is kept in ``_invsr_error``.
    """
    global _sampler_invsr, _invsr_status, _invsr_error
    try:
        _invsr_status = "patching"
        _patch_invsr_source()
        _ensure_on_sys_path(_INVSR_PATH)
        _ensure_on_sys_path(_INVSR_PATH / "src")
        from omegaconf import OmegaConf
        from huggingface_hub import snapshot_download, hf_hub_download
        from sampler_invsr import InvSamplerSR

        invsr_cache = str(CACHE_DIR / "invsr")
        CACHE_DIR.mkdir(exist_ok=True)

        _invsr_status = "downloading_sd_turbo"
        logger.info("Downloading SD-Turbo (~5GB, one-time, 10-20 min)...")
        snapshot_download("stabilityai/sd-turbo", cache_dir=invsr_cache, resume_download=True)
        logger.info("SD-Turbo downloaded")

        _invsr_status = "downloading_noise_pred"
        logger.info("Downloading noise predictor...")
        # hf_hub_download returns the local path of the cached file.
        ckpt = hf_hub_download("OAOA/InvSR", "noise_predictor_sd_turbo_v5.pth", cache_dir=invsr_cache)
        if not ckpt or not Path(ckpt).is_file():
            raise FileNotFoundError("Noise predictor not found")

        _invsr_status = "loading"
        cfg = OmegaConf.load(str(_INVSR_PATH / "configs" / "sample-sd-turbo.yaml"))
        cfg.sd_pipe.params.torch_dtype = "torch.float32"
        cfg.sd_pipe.params.cache_dir = invsr_cache
        cfg.sd_pipe.params.local_files_only = True
        cfg.model_start.ckpt_path = str(ckpt)
        cfg.timesteps = [200]
        cfg.bs = 1
        cfg.tiled_vae = True
        cfg.color_fix = "wavelet"
        cfg.basesr.chopping.pch_size = 128
        cfg.basesr.chopping.extra_bs = 8

        logger.info("Loading InvSR into memory...")
        sampler = InvSamplerSR(cfg, device="auto")
        if _DEVICE == "cpu":
            # Force the whole diffusion pipeline to float32 on CPU.
            sampler.sd_pipe = sampler.sd_pipe.to(dtype=torch.float32)
        # Publish only once fully configured: request handlers treat a non-None
        # sampler as usable.
        _sampler_invsr = sampler
        _invsr_status = "ready"
        logger.info("InvSR ready on %s", _DEVICE)
    except Exception as e:
        _invsr_status = "error"
        _invsr_error = str(e)
        logger.exception("InvSR load failed: %s", e)


def upscale_invsr(image_bytes: bytes, num_steps: int = 1) -> bytes:
    """Run InvSR super-resolution on an image and return the result as PNG bytes.

    Inputs whose longest side exceeds ``MAX_DIM["invsr"]`` are downscaled first.
    Calls are serialized with ``_invsr_lock``, because each call rewrites the
    shared sampler configuration (timesteps, patch size).

    Args:
        image_bytes: encoded input image.
        num_steps: 1-5; selects a timestep schedule from ``_INVSR_TIMESTEPS``.
            Unknown values fall back to a single step at t=200.

    Raises:
        HTTPException(500): InvSR failed to load.
        HTTPException(503): InvSR is still loading.
        HTTPException(400): the input is not a decodable image.
    """
    if _invsr_status == "error":
        raise HTTPException(500, f"InvSR failed to load: {_invsr_error}")
    if _sampler_invsr is None:
        raise HTTPException(503, f"InvSR is still {_invsr_status}. Check /health for progress.")
    sampler = _sampler_invsr
    _ensure_on_sys_path(_INVSR_PATH)
    from utils import util_image  # InvSR's own utils package

    pil, resized = _resize_if_needed(_open_rgb(image_bytes), "invsr")
    if resized:
        image_bytes = _png_bytes(pil)

    # util_image.imread reads from a path, so stage the upload in a temp file.
    tmp_path = None
    try:
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp:
            tmp_path = tmp.name
            tmp.write(image_bytes)
        im = util_image.imread(tmp_path, chn="rgb", dtype="float32")
    finally:
        if tmp_path is not None:
            os.unlink(tmp_path)

    im_cond = util_image.img2tensor(im).to(sampler.device)
    with _invsr_lock:
        sampler.configs.timesteps = list(_INVSR_TIMESTEPS.get(num_steps, [200]))
        sampler.configs.basesr.chopping.pch_size = 128
        result = sampler.sample_func(im_cond).squeeze(0)
    return _png_bytes(Image.fromarray(_to_uint8(result)))


# ── Finegrain ESRGAN 4X ──────────────────────────────────────
FG_ESRGAN_PATH = CACHE_DIR / "esrgan"
_fg_esrgan_model = None
_fg_esrgan_loading = False
# Ensures only one thread downloads/loads the ESRGAN weights at a time.
_fg_esrgan_lock = threading.Lock()
_ESRGAN_SCALE = 4  # two 2x nearest-neighbour upsample stages in _RRDBNet


def _conv_block(in_nc, out_nc):
    """Return a 3x3 same-padding convolution followed by LeakyReLU(0.2)."""
    return torch.nn.Sequential(
        torch.nn.Conv2d(in_nc, out_nc, kernel_size=3, padding=1),
        torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
    )


class _ResidualDenseBlock5C(torch.nn.Module):
    """Five densely connected convolutions with a 0.2-scaled residual.

    Attribute names (conv1..conv5) must match the checkpoint's state-dict keys.
    """

    def __init__(self, nf=64, gc=32):
        super().__init__()
        self.conv1 = _conv_block(nf, gc)
        self.conv2 = _conv_block(nf + gc, gc)
        self.conv3 = _conv_block(nf + 2 * gc, gc)
        self.conv4 = _conv_block(nf + 3 * gc, gc)
        self.conv5 = torch.nn.Sequential(torch.nn.Conv2d(nf + 4 * gc, nf, kernel_size=3, padding=1))

    def forward(self, x):
        # Each conv sees the block input concatenated with all previous outputs.
        features = [x]
        for conv in (self.conv1, self.conv2, self.conv3, self.conv4):
            features.append(conv(torch.cat(features, 1)))
        return self.conv5(torch.cat(features, 1)) * 0.2 + x


class _RRDB(torch.nn.Module):
    """Residual-in-residual dense block: three dense blocks plus a 0.2-scaled skip."""

    def __init__(self, nf):
        super().__init__()
        self.RDB1 = _ResidualDenseBlock5C(nf)
        self.RDB2 = _ResidualDenseBlock5C(nf)
        self.RDB3 = _ResidualDenseBlock5C(nf)

    def forward(self, x):
        return self.RDB3(self.RDB2(self.RDB1(x))) * 0.2 + x


class _SkipBlock(torch.nn.Module):
    """Add the input of ``sub`` to its output (stored as ``sub`` to match checkpoint keys)."""

    def __init__(self, sub):
        super().__init__()
        self.sub = sub

    def forward(self, x):
        return x + self.sub(x)


class _RRDBNet(torch.nn.Module):
    """ESRGAN generator laid out as one flat ``model`` Sequential.

    The index layout (model.0, model.1.sub.N, model.3, model.6, model.8,
    model.10) determines the state-dict keys the checkpoint must provide.
    Two nearest-neighbour 2x upsamples give a 4x total scale.
    """

    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23):
        super().__init__()
        trunk = torch.nn.Sequential(
            *(_RRDB(nf) for _ in range(nb)),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
        )
        self.model = torch.nn.Sequential(
            torch.nn.Conv2d(in_nc, nf, kernel_size=3, padding=1),
            _SkipBlock(trunk),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Upsample(scale_factor=2),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(nf, nf, kernel_size=3, padding=1),
            torch.nn.LeakyReLU(negative_slope=0.2, inplace=True),
            torch.nn.Conv2d(nf, out_nc, kernel_size=3, padding=1),
        )

    def forward(self, x):
        return self.model(x)


def _load_fg_esrgan():
    """Return the ESRGAN model, downloading and loading it on first call.

    Returns None if another thread is currently loading it, or if loading
    failed (the error is logged and a later call retries). The model runs on CPU.
    """
    global _fg_esrgan_model, _fg_esrgan_loading
    if _fg_esrgan_model is not None:
        return _fg_esrgan_model
    if not _fg_esrgan_lock.acquire(blocking=False):
        return None
    try:
        if _fg_esrgan_model is not None:  # finished by another thread meanwhile
            return _fg_esrgan_model
        _fg_esrgan_loading = True
        from huggingface_hub import hf_hub_download

        logger.info("Downloading ESRGAN 4x-UltraSharp model...")
        ckpt = hf_hub_download("philz1337x/upscaler", "4x-UltraSharp.pth", cache_dir=str(FG_ESRGAN_PATH))
        logger.info("Loading ESRGAN...")
        state = torch.load(ckpt, map_location="cpu", weights_only=True)
        model = _RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23)
        # strict=False tolerates extra keys, but missing ones would leave layers
        # randomly initialised and silently produce garbage, so reject those.
        keys = model.load_state_dict(state, strict=False)
        if keys.missing_keys:
            raise RuntimeError(
                f"ESRGAN checkpoint is missing {len(keys.missing_keys)} weights, "
                f"e.g. {keys.missing_keys[:3]}"
            )
        if keys.unexpected_keys:
            logger.warning("ESRGAN checkpoint has %d unused keys", len(keys.unexpected_keys))
        model.eval()
        _fg_esrgan_model = model
        logger.info("ESRGAN 4X ready (CPU)")
    except Exception as e:
        logger.exception("Failed to load ESRGAN: %s", e)
    finally:
        _fg_esrgan_loading = False
        _fg_esrgan_lock.release()
    return _fg_esrgan_model


def _esrgan_infer(model, tile: Image.Image) -> np.ndarray:
    """Upscale one RGB tile; returns a float32 (4H, 4W, 3) RGB array in [0, 1]."""
    # Channels are reversed (RGB -> BGR) before the model and back afterwards,
    # as in the original implementation.
    # NOTE(review): ESRGAN checkpoints are commonly trained on RGB input — confirm
    # which channel order 4x-UltraSharp expects.
    chw = np.ascontiguousarray(np.array(tile, dtype=np.float32)[:, :, ::-1].transpose(2, 0, 1)) / 255.0
    with torch.no_grad():
        out = model(torch.from_numpy(chw[np.newaxis]))
    return out.squeeze(0).clamp(0, 1).numpy().transpose(1, 2, 0)[:, :, ::-1]


def _tile_starts(length: int, tile: int, overlap: int) -> list[int]:
    """Start offsets of overlapping tiles covering [0, length); the last tile is flush with the end."""
    if length <= tile:
        return [0]
    stride = tile - overlap
    count = -(-(length - overlap) // stride)  # ceil division
    return [min(i * stride, length - tile) for i in range(count)]


def _esrgan_upscale_tiled(model, img: Image.Image, tile_size: int = 512, overlap: int = 64) -> np.ndarray:
    """Upscale *img* 4x, tiling inputs larger than *tile_size* and averaging overlaps.

    Returns a float32 (4H, 4W, 3) RGB array in [0, 1].
    """
    w, h = img.size
    if w <= tile_size and h <= tile_size:
        return _esrgan_infer(model, img)
    s = _ESRGAN_SCALE
    acc = np.zeros((h * s, w * s, 3), dtype=np.float32)
    weight = np.zeros((h * s, w * s, 1), dtype=np.float32)
    x_starts = _tile_starts(w, tile_size, overlap)
    for y1 in _tile_starts(h, tile_size, overlap):
        y2 = min(y1 + tile_size, h)
        for x1 in x_starts:
            x2 = min(x1 + tile_size, w)
            acc[y1 * s:y2 * s, x1 * s:x2 * s] += _esrgan_infer(model, img.crop((x1, y1, x2, y2)))
            weight[y1 * s:y2 * s, x1 * s:x2 * s] += 1.0
    return acc / np.maximum(weight, 1e-8)


def upscale_finegrain(image_bytes: bytes, use_sd_refinement: bool = False) -> tuple[bytes, dict]:
    """Upscale an image 4x with the ESRGAN (4x-UltraSharp) model.

    Args:
        image_bytes: encoded input image.
        use_sd_refinement: accepted for API compatibility. SD refinement is not
            implemented yet, so this flag currently has no effect.

    Returns:
        (PNG bytes, info dict with model name and input/output sizes).

    Raises:
        HTTPException(503): the ESRGAN model is not loaded.
        HTTPException(400): invalid image or output larger than ``_MAX_OUTPUT_MP``.
    """
    model = _load_fg_esrgan()
    if model is None:
        raise HTTPException(503, "ESRGAN model not loaded. Check /health.")
    img = _open_rgb(image_bytes)
    in_w, in_h = img.size
    # Guard memory: the tiled path allocates a float32 buffer of the full 4x output.
    _check_output_size(in_w, in_h, _ESRGAN_SCALE)
    out = Image.fromarray(_to_uint8(_esrgan_upscale_tiled(model, img)))
    # TODO: SD refinement pipeline (use_sd_refinement) is still a placeholder.
    info = {"model": "esrgan_4x", "input": f"{in_w}x{in_h}", "output": f"{out.width}x{out.height}"}
    return _png_bytes(out), info


# ── Metrics ──────────────────────────────────────────────────
def compute_metrics(img: Image.Image) -> dict:
    """Compute simple no-reference quality statistics for *img*.

    Keys:
        size: "WxH".
        sharpness: variance of the Laplacian of the grayscale image.
        entropy: Shannon entropy (bits) of the 256-bin grayscale histogram.
        edge_density: fraction of pixels whose Sobel gradient magnitude exceeds
            mean + 1 std.
        contrast_std: standard deviation of all pixel values of *img* as given.
    """
    gray = np.array(img.convert("L"), dtype=np.float64)
    counts, _ = np.histogram(gray, bins=256, range=(0, 256))
    probs = counts[counts > 0] / counts.sum()
    grad = np.hypot(ndimage.sobel(gray, axis=0), ndimage.sobel(gray, axis=1))
    edge_threshold = grad.mean() + grad.std()
    return {
        "size": f"{img.width}x{img.height}",
        "sharpness": round(float(ndimage.laplace(gray).var()), 4),
        "entropy": round(float(-np.sum(probs * np.log2(probs))), 4),
        "edge_density": round(float(np.mean(grad > edge_threshold)), 4),
        "contrast_std": round(float(np.array(img).std()), 2),
    }


def generate_comparison(image_bytes: bytes) -> tuple[bytes, dict]:
    """Build a labelled side-by-side PNG: original (resized to 2x size), MewZoom 2X, MewZoom 4X.

    Returns:
        (PNG bytes, metrics dict keyed by "original" and each MewZoom scale; the
        scale entries also carry "time_s" and the upscale info fields).
    """
    original = _open_rgb(image_bytes)
    metrics = {"original": compute_metrics(original)}
    upscaled = {}
    for scale in MEWZOOM_MODELS:
        t0 = time.perf_counter()
        png, info = upscale_mewzoom(image_bytes, scale)
        elapsed = time.perf_counter() - t0
        img = Image.open(BytesIO(png)).convert("RGB")
        upscaled[scale] = img
        metrics[scale] = {**compute_metrics(img), "time_s": round(elapsed, 3), **info}

    orig_r = original.resize(upscaled["2x"].size, Image.LANCZOS)
    images = [orig_r, upscaled["2x"], upscaled["4x"]]
    labels = ["Original", "MewZoom 2X", "MewZoom 4X"]
    label_h, gap = 30, 8
    max_h = max(i.height for i in images)
    total_w = sum(i.width for i in images) + gap * (len(images) - 1)
    canvas = Image.new("RGB", (total_w, max_h + label_h), (30, 30, 30))
    draw = ImageDraw.Draw(canvas)
    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 14)
    except Exception:
        font = ImageFont.load_default()
    x = 0
    for img, lbl in zip(images, labels):
        canvas.paste(img, (x, label_h))
        bb = draw.textbbox((0, 0), lbl, font=font)
        # Centre the label horizontally above its panel.
        draw.text((x + (img.width - (bb[2] - bb[0])) // 2, 6), lbl, fill=(255, 255, 255), font=font)
        x += img.width + gap
    return _png_bytes(canvas), metrics


# ── FastAPI ──────────────────────────────────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load MewZoom eagerly at startup; InvSR and ESRGAN load in background threads."""
    logger.info("Loading MewZoom models...")
    for s in MEWZOOM_MODELS:
        _load_mewzoom(s)
    threading.Thread(target=_load_invsr_sync, daemon=True).start()
    threading.Thread(target=_load_fg_esrgan, daemon=True).start()
    yield


app = FastAPI(title="Super-Resolution API", version="2.0.0", lifespan=lifespan,
              description="MewZoom 2X/4X + InvSR diffusion 4X + comparison + quality metrics")
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])


@app.get("/")
@app.get("/health")
async def health():
    """Report device and per-model loading status."""
    return JSONResponse({
        "status": "healthy",
        "device": _DEVICE,
        "models": ["2x", "4x", "invsr", "finegrain"],
        "gpu": torch.cuda.is_available(),
        "invsr_status": _invsr_status,
        "invsr_error": _invsr_error,
        "finegrain_loaded": _fg_esrgan_model is not None,
    })


# Inference is blocking CPU/GPU work, so every route below offloads it with
# run_in_threadpool instead of stalling the event loop (and /health) meanwhile.
@app.post("/upscale/2x")
async def route_2x(file: UploadFile = File(...)):
    r, i = await run_in_threadpool(upscale_mewzoom, await file.read(), "2x")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})


@app.post("/upscale/4x")
async def route_4x(file: UploadFile = File(...)):
    r, i = await run_in_threadpool(upscale_mewzoom, await file.read(), "4x")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})


@app.post("/upscale/compare")
async def route_compare(file: UploadFile = File(...),
                        format: Literal["image", "json", "both"] = Query("both")):
    img, m = await run_in_threadpool(generate_comparison, await file.read())
    if format == "json":
        return JSONResponse(m)
    if format == "image":
        return StreamingResponse(BytesIO(img), media_type="image/png")
    return StreamingResponse(BytesIO(img), media_type="image/png", headers={"X-Metrics": json.dumps(m)})


@app.post("/upscale/metrics")
async def route_metrics(file: UploadFile = File(...)):
    _, m = await run_in_threadpool(generate_comparison, await file.read())
    return JSONResponse(m)


@app.post("/upscale/finegrain")
async def route_finegrain(
    file: UploadFile = File(...),
    sd_refinement: bool = Query(False, description="Use SD1.5 refinement (GPU only)"),
):
    data = await file.read()
    try:
        r, i = await run_in_threadpool(upscale_finegrain, data, use_sd_refinement=sd_refinement)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(500, detail=f"Finegrain failed: {e}")
    return StreamingResponse(BytesIO(r), media_type="image/png", headers={"X-Info": json.dumps(i)})


def _run_invsr_job(job_id: str) -> None:
    """Worker-thread body: run one queued InvSR job and record its outcome in ``_invsr_jobs``."""
    job = _invsr_jobs.get(job_id)
    if not job:
        return
    try:
        job["status"] = "processing"
        # Pop the upload so finished jobs don't keep input bytes in memory.
        job["result"] = upscale_invsr(job.pop("image_bytes"), job["num_steps"])
        job["status"] = "done"
    except Exception as e:
        logger.exception("InvSR job %s failed", job_id)
        job["status"] = "error"
        job["error"] = str(e)


@app.post("/upscale/invsr")
async def route_invsr(file: UploadFile = File(...), num_steps: int = Query(1, ge=1, le=5)):
    """Queue an InvSR job and return its id; poll /upscale/invsr/{job_id} for the result."""
    global _job_counter
    if _invsr_status == "error":
        raise HTTPException(500, f"InvSR not loaded: {_invsr_error}")
    if _sampler_invsr is None:
        raise HTTPException(503, f"InvSR is {_invsr_status}. Check /health for status.")
    _job_counter += 1
    job_id = str(_job_counter)
    _invsr_jobs[job_id] = {"status": "queued", "image_bytes": await file.read(), "num_steps": num_steps}
    threading.Thread(target=_run_invsr_job, args=(job_id,), daemon=True).start()
    return JSONResponse({"job_id": job_id, "status": "queued", "check": f"/upscale/invsr/{job_id}"})


@app.get("/upscale/invsr/{job_id}")
async def route_invsr_status(job_id: str):
    """Return the PNG once the job is done, otherwise its status (and error, if any)."""
    job = _invsr_jobs.get(job_id)
    if not job:
        raise HTTPException(404, "Job not found")
    if job["status"] == "done":
        return StreamingResponse(BytesIO(job["result"]), media_type="image/png")
    return JSONResponse({"job_id": job_id, "status": job["status"], "error": job.get("error")})