mastari commited on
Commit
abb6db1
·
1 Parent(s): 283bed4

Fix handler to include EndpointHandler class

Browse files
Files changed (1) hide show
  1. handler.py +70 -90
handler.py CHANGED
@@ -8,96 +8,76 @@ from realesrgan import RealESRGANer
8
  from basicsr.archs.rrdbnet_arch import RRDBNet
9
 
10
 
11
- # ==========================================================
12
- # MODEL INITIALIZATION
13
- # ==========================================================
14
- print("🔹 Initializing Real-ESRGAN x4 model...")
15
-
16
- MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
17
- MODEL_PATH = "/repository/RealESRGAN_x4plus.pth"
18
-
19
- # Download model if missing
20
- if not os.path.exists(MODEL_PATH):
21
- print(f"📥 Downloading RealESRGAN_x4plus.pth from {MODEL_URL} ...")
22
- r = requests.get(MODEL_URL)
23
- r.raise_for_status()
24
- with open(MODEL_PATH, "wb") as f:
25
- f.write(r.content)
26
- print(f"✅ Download complete: {MODEL_PATH}")
27
-
28
- # Create model architecture
29
- model = RRDBNet(
30
- num_in_ch=3,
31
- num_out_ch=3,
32
- num_feat=64,
33
- num_block=23,
34
- num_grow_ch=32,
35
- scale=4
36
- )
37
-
38
- # Load Real-ESRGAN model
39
- upsampler = RealESRGANer(
40
- scale=4,
41
- model_path=MODEL_PATH,
42
- model=model,
43
- half=False,
44
- device="cuda" if torch.cuda.is_available() else "cpu"
45
- )
46
-
47
- print("✅ Real-ESRGAN model initialized and ready.")
48
-
49
-
50
- # ==========================================================
51
- # HANDLER FUNCTIONS
52
- # ==========================================================
53
- def preprocess(request):
54
- """
55
- Converts request body (bytes) into a PIL image.
56
- """
57
- if isinstance(request, (bytes, bytearray)):
58
- image = Image.open(io.BytesIO(request)).convert("RGB")
59
- return image
60
- else:
61
- raise ValueError("Request body must be raw image bytes.")
62
-
63
-
64
- def inference(image: Image.Image):
65
- """
66
- Runs Real-ESRGAN on the input PIL image.
67
- """
68
- try:
69
- output, _ = upsampler.enhance(image, outscale=4)
 
 
 
 
70
  return output
71
- except Exception as e:
72
- raise RuntimeError(f"Inference failed: {e}")
73
-
74
-
75
- def postprocess(output_image: Image.Image):
76
- """
77
- Converts a PIL Image to base64-encoded PNG string for JSON output.
78
- """
79
- buffer = io.BytesIO()
80
- output_image.save(buffer, format="PNG")
81
- image_bytes = buffer.getvalue()
82
- buffer.close()
83
-
84
- encoded = base64.b64encode(image_bytes).decode("utf-8")
85
- return {"image": encoded}
86
-
87
 
88
- # ==========================================================
89
- # MAIN ENTRY POINT
90
- # ==========================================================
91
- def __call__(self, data):
92
- """
93
- This method is called automatically by Hugging Face Inference Toolkit.
94
- `data` is the raw HTTP request body (image bytes).
95
- """
96
- try:
97
- image = preprocess(data)
98
- output = inference(image)
99
- result = postprocess(output)
100
- return result
101
- except Exception as e:
102
- return {"error": str(e)}
103
 
 
8
  from basicsr.archs.rrdbnet_arch import RRDBNet
9
 
10
 
11
+ class EndpointHandler:
12
+ def __init__(self, path="."):
13
+ print("🔹 Initializing Real-ESRGAN x4 model...")
14
+
15
+ self.model_url = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth"
16
+ self.model_path = os.path.join(path, "RealESRGAN_x4plus.pth")
17
+
18
+ # Download model if not present
19
+ if not os.path.exists(self.model_path):
20
+ print(f"📥 Downloading RealESRGAN_x4plus.pth from {self.model_url} ...")
21
+ r = requests.get(self.model_url)
22
+ r.raise_for_status()
23
+ with open(self.model_path, "wb") as f:
24
+ f.write(r.content)
25
+ print(f"✅ Download complete: {self.model_path}")
26
+
27
+ # Build Real-ESRGAN model
28
+ model = RRDBNet(
29
+ num_in_ch=3,
30
+ num_out_ch=3,
31
+ num_feat=64,
32
+ num_block=23,
33
+ num_grow_ch=32,
34
+ scale=4,
35
+ )
36
+
37
+ self.upsampler = RealESRGANer(
38
+ scale=4,
39
+ model_path=self.model_path,
40
+ model=model,
41
+ half=False,
42
+ device="cuda" if torch.cuda.is_available() else "cpu",
43
+ )
44
+
45
+ print("✅ Real-ESRGAN model initialized and ready.")
46
+
47
+ # ==========================================================
48
+ # MAIN CALL METHOD
49
+ # ==========================================================
50
+ def __call__(self, data):
51
+ """
52
+ This is called automatically by Hugging Face Inference Toolkit.
53
+ It receives the raw image bytes from the POST body.
54
+ """
55
+ try:
56
+ image = self.preprocess(data)
57
+ output = self.inference(image)
58
+ return self.postprocess(output)
59
+ except Exception as e:
60
+ return {"error": str(e)}
61
+
62
+ # ==========================================================
63
+ # PREPROCESS / INFERENCE / POSTPROCESS
64
+ # ==========================================================
65
+ def preprocess(self, request):
66
+ """Convert raw bytes into a PIL image."""
67
+ if isinstance(request, (bytes, bytearray)):
68
+ return Image.open(io.BytesIO(request)).convert("RGB")
69
+ raise ValueError("Expected raw image bytes.")
70
+
71
+ def inference(self, image):
72
+ """Run Real-ESRGAN model."""
73
+ output, _ = self.upsampler.enhance(image, outscale=4)
74
  return output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
+ def postprocess(self, output_image):
77
+ """Convert output image to base64-encoded PNG for JSON."""
78
+ buffer = io.BytesIO()
79
+ output_image.save(buffer, format="PNG")
80
+ encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
81
+ buffer.close()
82
+ return {"image": encoded}
 
 
 
 
 
 
 
 
83