mastari commited on
Commit
b6af06b
·
1 Parent(s): abb6db1

Fix handler to include EndpointHandler class

Browse files
Files changed (1) hide show
  1. handler.py +35 -22
handler.py CHANGED
@@ -12,19 +12,21 @@ class EndpointHandler:
12
  def __init__(self, path="."):
13
  print("🔹 Initializing Real-ESRGAN x4 model...")
14
 
15
- self.model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
 
 
 
16
  self.model_path = os.path.join(path, "RealESRGAN_x4plus.pth")
17
 
18
- # Download model if not present
19
  if not os.path.exists(self.model_path):
20
- print(f"📥 Downloading RealESRGAN_x4plus.pth from {self.model_url} ...")
21
  r = requests.get(self.model_url)
22
  r.raise_for_status()
23
  with open(self.model_path, "wb") as f:
24
  f.write(r.content)
25
- print(f"✅ Download complete: {self.model_path}")
26
 
27
- # Build Real-ESRGAN model
28
  model = RRDBNet(
29
  num_in_ch=3,
30
  num_out_ch=3,
@@ -45,13 +47,9 @@ class EndpointHandler:
45
  print("✅ Real-ESRGAN model initialized and ready.")
46
 
47
  # ==========================================================
48
- # MAIN CALL METHOD
49
  # ==========================================================
50
  def __call__(self, data):
51
- """
52
- This is called automatically by Hugging Face Inference Toolkit.
53
- It receives the raw image bytes from the POST body.
54
- """
55
  try:
56
  image = self.preprocess(data)
57
  output = self.inference(image)
@@ -60,24 +58,39 @@ class EndpointHandler:
60
  return {"error": str(e)}
61
 
62
  # ==========================================================
63
- # PREPROCESS / INFERENCE / POSTPROCESS
64
  # ==========================================================
65
- def preprocess(self, request):
66
- """Convert raw bytes into a PIL image."""
67
- if isinstance(request, (bytes, bytearray)):
68
- return Image.open(io.BytesIO(request)).convert("RGB")
69
- raise ValueError("Expected raw image bytes.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
  def inference(self, image):
72
- """Run Real-ESRGAN model."""
73
  output, _ = self.upsampler.enhance(image, outscale=4)
74
  return output
75
 
76
  def postprocess(self, output_image):
77
- """Convert output image to base64-encoded PNG for JSON."""
78
- buffer = io.BytesIO()
79
- output_image.save(buffer, format="PNG")
80
- encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
81
- buffer.close()
82
  return {"image": encoded}
83
 
 
12
  def __init__(self, path="."):
13
  print("🔹 Initializing Real-ESRGAN x4 model...")
14
 
15
+ self.model_url = (
16
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
17
+ "RealESRGAN_x4plus.pth"
18
+ )
19
  self.model_path = os.path.join(path, "RealESRGAN_x4plus.pth")
20
 
21
+ # Download model weights if missing
22
  if not os.path.exists(self.model_path):
23
+ print(f"📥 Downloading RealESRGAN_x4plus.pth ...")
24
  r = requests.get(self.model_url)
25
  r.raise_for_status()
26
  with open(self.model_path, "wb") as f:
27
  f.write(r.content)
28
+ print(f"✅ Downloaded model to {self.model_path}")
29
 
 
30
  model = RRDBNet(
31
  num_in_ch=3,
32
  num_out_ch=3,
 
47
  print("✅ Real-ESRGAN model initialized and ready.")
48
 
49
  # ==========================================================
50
+ # Main callable
51
  # ==========================================================
52
  def __call__(self, data):
 
 
 
 
53
  try:
54
  image = self.preprocess(data)
55
  output = self.inference(image)
 
58
  return {"error": str(e)}
59
 
60
  # ==========================================================
61
+ # Steps
62
  # ==========================================================
63
+ def preprocess(self, data):
64
+ """Accept raw bytes OR dict-style payloads."""
65
+ # case 1: raw image bytes
66
+ if isinstance(data, (bytes, bytearray)):
67
+ return Image.open(io.BytesIO(data)).convert("RGB")
68
+
69
+ # case 2: dict with "inputs" key (HF default)
70
+ if isinstance(data, dict) and "inputs" in data:
71
+ img_field = data["inputs"]
72
+
73
+ # base64-encoded image
74
+ if isinstance(img_field, str):
75
+ try:
76
+ return Image.open(io.BytesIO(base64.b64decode(img_field))).convert("RGB")
77
+ except Exception:
78
+ raise ValueError("Invalid base64 image string in 'inputs'.")
79
+
80
+ # already bytes
81
+ if isinstance(img_field, (bytes, bytearray)):
82
+ return Image.open(io.BytesIO(img_field)).convert("RGB")
83
+
84
+ raise ValueError("Expected raw image bytes or {'inputs': <bytes/base64>}.")
85
 
86
  def inference(self, image):
 
87
  output, _ = self.upsampler.enhance(image, outscale=4)
88
  return output
89
 
90
  def postprocess(self, output_image):
91
+ buf = io.BytesIO()
92
+ output_image.save(buf, format="PNG")
93
+ encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
94
+ buf.close()
 
95
  return {"image": encoded}
96