|
|
import os |
|
|
import io |
|
|
import torch |
|
|
import base64 |
|
|
import requests |
|
|
import numpy as np |
|
|
from PIL import Image |
|
|
from realesrgan import RealESRGANer |
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet |
|
|
|
|
|
|
|
|
class EndpointHandler:
    """Callable request handler that upscales an image 4x with Real-ESRGAN.

    On construction the handler makes sure the ``RealESRGAN_x4plus``
    checkpoint is present under ``path`` (downloading it when missing),
    builds the RRDBNet backbone and wraps both in a ``RealESRGANer``.

    Calling the instance runs ``preprocess`` -> ``inference`` ->
    ``postprocess`` and returns ``{"image": <base64 PNG>}`` on success or
    ``{"error": <message>}`` when any step raises.
    """

    MODEL_URL = (
        "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
        "RealESRGAN_x4plus.pth"
    )
    MODEL_FILENAME = "RealESRGAN_x4plus.pth"
    # Upscaling factor shared by the network definition and the enhance call.
    SCALE = 4
    # (connect, read) timeouts in seconds so a stalled download cannot hang
    # start-up forever.
    DOWNLOAD_TIMEOUT = (10, 300)
    DOWNLOAD_CHUNK_SIZE = 1 << 20
    # Maximum number of characters of the request echoed to the log.
    PREVIEW_LIMIT = 300

    def __init__(self, path="."):
        """Fetch the weights if necessary and build the upsampler.

        Args:
            path: Directory in which the checkpoint file is looked up and,
                when missing, stored.

        Raises:
            requests.RequestException: If the checkpoint download fails.
            OSError: If the checkpoint cannot be written to ``path``.
        """
        print("[INIT] Starting EndpointHandler initialization...")
        print(f"Working directory: {os.getcwd()}")
        print(f"Model path root: {path}")

        self.model_url = self.MODEL_URL
        self.model_path = os.path.join(path, self.MODEL_FILENAME)

        if os.path.exists(self.model_path):
            print(f"[CACHE] Model already exists at {self.model_path}")
        else:
            print(f"[DOWNLOAD] Fetching model weights from {self.model_url}")
            self._download_weights()
            print(f"[DOWNLOAD] Saved model to {self.model_path}")

        print("[MODEL] Building RRDBNet...")
        model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=23,
            num_grow_ch=32,
            scale=self.SCALE,
        )

        device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[DEVICE] Using device: {device}")

        self.upsampler = RealESRGANer(
            scale=self.SCALE,
            model_path=self.model_path,
            model=model,
            half=False,
            device=device,
        )

        print("[INIT DONE] Real-ESRGAN model initialized and ready.\n\n")

    def _download_weights(self):
        """Stream the checkpoint to ``self.model_path`` atomically.

        The data is written to a temporary sibling file and only renamed to
        the final name once the transfer completed, so an interrupted
        download never leaves a truncated file that a later start would
        mistake for a valid cached checkpoint.
        """
        tmp_path = f"{self.model_path}.{os.getpid()}.part"
        try:
            with requests.get(
                self.model_url,
                stream=True,
                timeout=self.DOWNLOAD_TIMEOUT,
            ) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as fh:
                    for chunk in response.iter_content(
                        chunk_size=self.DOWNLOAD_CHUNK_SIZE
                    ):
                        if chunk:
                            fh.write(chunk)
            os.replace(tmp_path, self.model_path)
        finally:
            # Only still present when the download or the rename failed.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)

    def __call__(self, data):
        """Upscale the image contained in ``data``.

        Args:
            data: Anything accepted by ``preprocess``.

        Returns:
            ``{"image": <base64-encoded PNG>}`` on success, otherwise
            ``{"error": <exception message>}``. Exceptions are reported in
            the returned dict rather than propagated.
        """
        print("[CALL] Endpoint invoked!")
        print(f"[CALL] Raw data type: {type(data)}")
        print(f"[CALL] Data preview: {self._preview(data)}...")

        try:
            print("[STEP] Preprocessing input...")
            image = self.preprocess(data)
            print(f"[STEP] Preprocessing complete! Image size: {image.size}")

            print("[STEP] Running inference...")
            output = self.inference(image)
            print("[STEP] Inference complete!")

            print("[STEP] Encoding output image...")
            result = self.postprocess(output)
            print("[STEP] Postprocessing complete!")

            return result
        except Exception as e:
            # Request boundary: report the failure to the caller instead of
            # letting it escape the handler.
            print("[ERROR] Exception during inference:", str(e))
            return {"error": str(e)}

    @classmethod
    def _preview(cls, data):
        """Return a short textual preview of ``data`` for logging.

        Strings and byte payloads are truncated *before* being formatted so
        a multi-megabyte request is not stringified just to log a snippet.
        """
        limit = cls.PREVIEW_LIMIT
        if isinstance(data, (str, bytes, bytearray)):
            return str(data[:limit])[:limit]
        if isinstance(data, dict):
            shortened = {key: cls._preview(value) for key, value in data.items()}
            return str(shortened)[:limit]
        return str(data)[:limit]

    def preprocess(self, data):
        """Convert a request payload into an RGB ``PIL.Image.Image``.

        Accepted payloads: a PIL image, raw image bytes, a base64 string
        (optionally a ``data:<mime>;base64,`` URI), a non-empty list whose
        first element is one of those (remaining elements are ignored), or a
        dict carrying any of the above under the ``"inputs"`` key.

        Raises:
            ValueError: If the payload type is unsupported or a string
                payload is not valid base64.
        """
        print(f"[PREPROCESS] Type received: {type(data)}")

        if isinstance(data, dict):
            print("[PREPROCESS] Detected dict input.")
            if "inputs" in data:
                data = data["inputs"]
                print(f"[PREPROCESS] Found 'inputs' key: {type(data)}")

        # Only the first element of a list payload is processed.
        if isinstance(data, list) and data:
            data = data[0]

        return self._to_rgb_image(data)

    def _to_rgb_image(self, item):
        """Turn a single image-like value into an RGB PIL image."""
        if isinstance(item, (bytes, bytearray)):
            print("[PREPROCESS] Treating input as raw bytes.")
            return Image.open(io.BytesIO(item)).convert("RGB")

        if isinstance(item, str):
            print(f"[PREPROCESS] Treating input as base64 string, len={len(item)}")
            decoded = self._decode_base64(item)
            return Image.open(io.BytesIO(decoded)).convert("RGB")

        if isinstance(item, Image.Image):
            print("[PREPROCESS] Got PIL.Image.Image directly.")
            return item.convert("RGB")

        raise ValueError("Unsupported input type. Expected image, bytes, or base64 data.")

    @staticmethod
    def _decode_base64(text):
        """Decode a base64 string, tolerating a leading data-URI header.

        Raises:
            ValueError: If the data URI has no ``,`` separator or the
                payload is not valid base64.
        """
        payload = text.strip()
        if payload.startswith("data:"):
            _header, separator, payload = payload.partition(",")
            if not separator:
                raise ValueError("Malformed data URI: missing ',' separator.")
        try:
            return base64.b64decode(payload)
        except ValueError as err:  # binascii.Error is a ValueError subclass
            raise ValueError("Input string is not valid base64 image data.") from err

    def inference(self, image):
        """Run Real-ESRGAN on an RGB PIL image and return the upscaled image.

        Args:
            image: RGB ``PIL.Image.Image`` (as produced by ``preprocess``).

        Returns:
            The upscaled result as a ``PIL.Image.Image``.
        """
        print("[INFERENCE] Running ESRGAN upscaling...")
        print(f"[INFERENCE] Input image size: {image.size}")

        # Reverse the channel axis (RGB -> BGR) before handing the array to
        # the upsampler; the contiguous copy avoids passing a negative-stride
        # view to external code.
        img_np = np.ascontiguousarray(np.asarray(image)[:, :, ::-1])
        print(f"[INFERENCE] Converted to NumPy: shape={img_np.shape}, dtype={img_np.dtype}")

        output, _ = self.upsampler.enhance(img_np, outscale=self.SCALE)
        print(f"[INFERENCE] Output NumPy shape: {output.shape}")

        # Reverse the channel axis again (BGR -> RGB) for PIL.
        output_rgb = Image.fromarray(np.ascontiguousarray(output[:, :, ::-1]))
        print(f"[INFERENCE] Converted back to PIL: size={output_rgb.size}")
        return output_rgb

    def postprocess(self, output_image):
        """Serialize an image to PNG and return it base64-encoded.

        Args:
            output_image: Object exposing ``save(file, format=...)``, e.g. a
                ``PIL.Image.Image``.

        Returns:
            ``{"image": <base64 string of the PNG bytes>}``.
        """
        print("[POSTPROCESS] Encoding image to base64...")
        with io.BytesIO() as buf:
            output_image.save(buf, format="PNG")
            raw_bytes = buf.getvalue()
        print(f"[POSTPROCESS] Output byte size: {len(raw_bytes)}")

        encoded = base64.b64encode(raw_bytes).decode("utf-8")
        print(f"[POSTPROCESS] Encoded base64 length: {len(encoded)}")
        return {"image": encoded}
|
|
|
|
|
|