mastari commited on
Commit
4ed9f59
·
1 Parent(s): 1851a72

Add verified Real-ESRGAN handler with auto-download

Browse files
Files changed (1) hide show
  1. handler.py +61 -39
handler.py CHANGED
@@ -1,56 +1,78 @@
1
- # handler.py
2
- from base64 import b64encode, b64decode
3
- from io import BytesIO
4
- from pathlib import Path
5
- from PIL import Image
6
  import numpy as np
 
 
 
7
  from realesrgan import RealESRGANer
8
  from basicsr.archs.rrdbnet_arch import RRDBNet
9
 
 
 
 
 
 
 
 
 
 
10
  class EndpointHandler:
11
- def __init__(self, model_dir: str = "", **kwargs):
12
- """
13
- Called once when the endpoint starts.
14
- Loads the Real-ESRGAN model weights.
15
- """
16
  print("🔹 Initializing Real-ESRGAN x4 model...")
17
- model_path = str(Path(model_dir) / "RealESRGAN_x4plus.pth")
18
 
19
- # Build model
20
- rrdbnet = RRDBNet(num_in_ch=3, num_out_ch=3)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  self.upsampler = RealESRGANer(
22
  scale=4,
23
- model_path=model_path,
24
- model=rrdbnet,
25
- tile=0,
 
26
  pre_pad=0,
27
- half=True,
28
  )
29
 
 
 
30
  def __call__(self, data):
31
- """
32
- Called for each request.
33
- Expects a dict with 'inputs' = base64-encoded image or bytes.
34
- Returns base64-encoded upscaled image.
35
- """
36
- image = data.get("inputs")
37
-
38
- if isinstance(image, str):
39
- image = Image.open(BytesIO(b64decode(image)))
40
- elif isinstance(image, bytes):
41
- image = Image.open(BytesIO(image))
42
- else:
43
- raise ValueError("Input must be base64 string or bytes")
44
 
45
- image = np.array(image)
46
- image = image[:, :, ::-1] # RGB→BGR
47
- output, _ = self.upsampler.enhance(image, outscale=4)
48
- output = output[:, :, ::-1] # BGR→RGB
49
- out_img = Image.fromarray(output)
50
 
51
- buf = BytesIO()
52
- out_img.save(buf, format="PNG")
53
- encoded = b64encode(buf.getvalue()).decode("utf-8")
54
 
55
- return {"image": encoded}
56
 
 
1
+ import os
2
+ import io
3
+ import torch
 
 
4
  import numpy as np
5
+ import requests
6
+ import cv2
7
+ from PIL import Image
8
  from realesrgan import RealESRGANer
9
  from basicsr.archs.rrdbnet_arch import RRDBNet
10
 
11
+ # ======================================================
12
+ # CONFIG
13
+ # ======================================================
14
+ MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
15
+ MODEL_PATH = "/repository/RealESRGAN_x4plus.pth"
16
+
17
+ # ======================================================
18
+ # HANDLER CLASS
19
+ # ======================================================
20
  class EndpointHandler:
21
+ def __init__(self, model_dir):
 
 
 
 
22
  print("🔹 Initializing Real-ESRGAN x4 model...")
 
23
 
24
+ # Ensure the model weights exist
25
+ if not os.path.exists(MODEL_PATH):
26
+ print(f"📥 Downloading RealESRGAN_x4plus.pth from {MODEL_URL} ...")
27
+ response = requests.get(MODEL_URL, stream=True)
28
+ response.raise_for_status()
29
+ with open(MODEL_PATH, "wb") as f:
30
+ for chunk in response.iter_content(chunk_size=8192):
31
+ f.write(chunk)
32
+ print("✅ Download complete:", MODEL_PATH)
33
+ else:
34
+ print("✅ Model file already exists:", MODEL_PATH)
35
+
36
+ # Define RRDBNet model
37
+ self.model = RRDBNet(
38
+ num_in_ch=3, num_out_ch=3, num_feat=64,
39
+ num_block=23, num_grow_ch=32, scale=4
40
+ )
41
+
42
+ # Initialize Real-ESRGAN upsampler
43
  self.upsampler = RealESRGANer(
44
  scale=4,
45
+ model_path=MODEL_PATH,
46
+ model=self.model,
47
+ tile=0, # Disable tile mode for simplicity
48
+ tile_pad=10,
49
  pre_pad=0,
50
+ half=False # Disable FP16 for stability
51
  )
52
 
53
+ print("✅ Real-ESRGAN model initialized and ready.")
54
+
55
  def __call__(self, data):
56
+ print("🚀 Received inference request...")
57
+
58
+ # Get input image
59
+ image_bytes = data.get("inputs") or data.get("image")
60
+ if image_bytes is None:
61
+ raise ValueError("❌ No image data found in request payload.")
62
+
63
+ # Convert bytes to RGB numpy array
64
+ if isinstance(image_bytes, list):
65
+ image_bytes = bytes(image_bytes)
66
+ img = Image.open(io.BytesIO(image_bytes)).convert("RGB")
67
+ img = np.array(img)
 
68
 
69
+ # Run Real-ESRGAN enhancement
70
+ output, _ = self.upsampler.enhance(img, outscale=4)
71
+ print("✨ Image enhancement complete.")
 
 
72
 
73
+ # Convert to PNG bytes for return
74
+ _, buffer = cv2.imencode(".png", cv2.cvtColor(output, cv2.COLOR_RGB2BGR))
75
+ print("📦 Returning processed image bytes.")
76
 
77
+ return {"image": buffer.tobytes()}
78