# File size: 5,242 Bytes
# 34e750a 7df3ea4 34e750a 7df3ea4 4984743 2389e63 34e750a 2389e63 7df3ea4 2389e63 eb396d8 34e750a 2389e63 7df3ea4 34e750a 7df3ea4 34e750a 3e3c736 34e750a 7df3ea4 c8867e7 3e3c736 7df3ea4 c8867e7 34e750a 7df3ea4 34e750a eb396d8 c8867e7 34e750a c8867e7 34e750a 4984743 c8867e7 4984743 34e750a 4984743 34e750a 4984743 34e750a c8867e7 34e750a 3e3c736 34e750a 3e3c736 34e750a 7df3ea4 34e750a 4984743 c8867e7 34e750a 4984743 c8867e7 4984743 34e750a 4984743 c8867e7 4984743 34e750a 4984743 c8867e7 34e750a 3e3c736 34e750a 7df3ea4 c8867e7 4984743 c8867e7 b51e8ba c8867e7 4984743 7df3ea4 c8867e7 34e750a 4984743 c8867e7 34e750a 2389e63 |
# 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import os
import io
import torch
import logging
import base64
import requests
import numpy as np
import cv2
from PIL import Image
from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact
# Module-wide logger keyed by this module's import name, so its level and
# handlers can be configured independently of other modules.
logger = logging.getLogger(name=__name__)
class EndpointHandler:
    """Restore faces in an image with GFPGAN v1.4 and upsample the background with Real-ESRGAN.

    The ``__init__(path)`` / ``__call__(data)`` shape presumably follows the
    Hugging Face Inference Endpoints custom-handler convention -- confirm
    against the deployment.

    Model weights are downloaded into ``path`` at construction time (unless a
    cached copy exists); the networks themselves are built lazily on the first
    call.
    """

    # Network timeout (seconds) passed to requests for model and image downloads.
    _HTTP_TIMEOUT = 60
    # Chunk size (bytes) used when streaming model weights to disk.
    _DOWNLOAD_CHUNK = 1 << 20

    def __init__(self, path="."):
        """Pick the compute device and make sure both weight files are on disk.

        Args:
            path: Directory used as the local cache for the ``.pth`` weight files.

        Raises:
            requests.RequestException: if a weight file must be downloaded and
                the download fails.
        """
        logger.info("[INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only enabled on GPU.
        self.half = self.device == "cuda"
        self.path = path

        # Model URLs (GFPGAN + Real-ESRGAN)
        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        # Local cache paths
        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Built on first use by _init_models().
        self.bg_upsampler = None
        self.restorer = None

        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)
        logger.info("Device: %s, half precision: %s", self.device, self.half)

    def _ensure_model(self, url, path):
        """Download ``url`` to ``path`` unless a non-empty file already exists there.

        The body is streamed to ``<path>.part`` and moved into place with
        ``os.replace`` only after the transfer completes, so an interrupted
        download never leaves a truncated file that a later start-up would
        mistake for a cached model.

        Raises:
            requests.HTTPError: if the server answers with an error status.
            requests.RequestException: on connection errors or timeouts.
            OSError: if the file cannot be written.
        """
        if os.path.isfile(path) and os.path.getsize(path) > 0:
            logger.info("Found cached model: %s", path)
            return

        logger.info("Downloading model from %s", url)
        os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
        tmp_path = f"{path}.part"
        try:
            with requests.get(url, stream=True, timeout=self._HTTP_TIMEOUT) as resp:
                resp.raise_for_status()
                with open(tmp_path, "wb") as fh:
                    for chunk in resp.iter_content(chunk_size=self._DOWNLOAD_CHUNK):
                        fh.write(chunk)
            os.replace(tmp_path, path)
        except BaseException:
            # Never leave a half-written temp file behind; re-raise the original error.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise
        logger.info("Model saved to %s", path)

    def _init_models(self):
        """Build the Real-ESRGAN background upsampler and the GFPGAN restorer if not built yet.

        Each model is created only while its attribute is still ``None``, so
        repeated calls after the first do no work.
        """
        if self.bg_upsampler is None:
            logger.info("Initializing Real-ESRGAN upsampler...")
            # Compact SRVGG network configuration used with realesr-general-x4v3.pth.
            model = SRVGGNetCompact(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_conv=32, upscale=4, act_type="prelu",
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,  # process large inputs tile by tile to bound memory use
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )

        if self.restorer is None:
            logger.info("Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
            logger.info("Models ready!")

    @staticmethod
    def _is_url(text):
        """Return True if ``text`` starts with an ``http://`` or ``https://`` scheme."""
        return text.startswith(("http://", "https://"))

    @staticmethod
    def _decode_base64_payload(text):
        """Decode a base64 string, dropping an optional ``data:<mime>;base64,`` prefix.

        Raises:
            binascii.Error: if the payload is not valid base64 (a ``ValueError`` subclass).
        """
        if text.startswith("data:"):
            # Data URI: the payload is everything after the first comma.
            _, _, text = text.partition(",")
        return base64.b64decode(text)

    def _load_image(self, data):
        """Turn a request payload into an RGB ``PIL.Image``.

        Accepted forms, optionally wrapped as ``{"inputs": ...}``:
          * a ``PIL.Image.Image`` instance,
          * raw encoded image bytes (``bytes`` / ``bytearray``),
          * an ``http://`` / ``https://`` URL string,
          * a base64 string, optionally prefixed by a ``data:...;base64,`` header.

        Raises:
            ValueError: if the payload type is unsupported or the base64 string
                cannot be decoded into an image.
            requests.RequestException: if downloading a URL fails or the server
                returns an error status.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, Image.Image):
            logger.info("Received PIL image input")
            return data.convert("RGB")

        if isinstance(data, (bytes, bytearray)):
            logger.info("Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if isinstance(data, str):
            if self._is_url(data):
                logger.info("Downloading image from URL: %s", data)
                resp = requests.get(data, timeout=self._HTTP_TIMEOUT)
                resp.raise_for_status()
                return Image.open(io.BytesIO(resp.content)).convert("RGB")

            logger.info("Decoding base64 image input")
            try:
                decoded = self._decode_base64_payload(data)
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as err:
                logger.error("Failed to decode base64: %s", err)
                raise ValueError("Invalid base64 image input") from err

        raise ValueError("Unsupported input type")

    def __call__(self, data):
        """Restore the faces in one input image and return it as base64 JPEG.

        Args:
            data: Any payload accepted by ``_load_image``.

        Returns:
            dict with ``"image"`` (base64-encoded JPEG string), ``"status"`` and
            ``"info"``.

        Raises:
            ValueError: for unusable input (see ``_load_image``).
            RuntimeError: if restoration yields no image or JPEG encoding fails.
        """
        logger.info("Starting GFPGAN inference pipeline...")
        self._init_models()

        image = self._load_image(data)
        input_rgb = np.array(image, dtype=np.uint8)
        logger.info("Input image shape: %s", input_rgb.shape)

        # GFPGANer / RealESRGANer follow OpenCV's BGR channel order; feeding them
        # RGB makes the networks see swapped colour channels.
        input_bgr = cv2.cvtColor(input_rgb, cv2.COLOR_RGB2BGR)

        _, _, restored_bgr = self.restorer.enhance(
            input_bgr, has_aligned=False, only_center_face=False, paste_back=True
        )
        if restored_bgr is None:
            raise RuntimeError("GFPGAN did not return a restored image")
        logger.info("Restoration complete, preparing output...")

        restored_bgr = np.clip(restored_bgr, 0, 255).astype(np.uint8)

        # cv2.imencode expects BGR, which is what the restorer returned.
        ok, buffer = cv2.imencode(".jpg", restored_bgr)
        if not ok:
            raise RuntimeError("Failed to JPEG-encode the restored image")
        b64_output = base64.b64encode(buffer).decode("utf-8")

        logger.info("Returning base64-encoded image JSON response")
        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)",
        }
|