mastari commited on
Commit
2389e63
·
0 Parent(s):

Initial GFPGAN custom handler

Browse files
Files changed (2) hide show
  1. handler.py +115 -0
  2. requirements.txt +11 -0
handler.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import torch
4
+ import base64
5
+ import requests
6
+ import numpy as np
7
+ from PIL import Image
8
+ from gfpgan import GFPGANer
9
+ from realesrgan import RealESRGANer
10
+ from basicsr.archs.rrdbnet_arch import RRDBNet
11
+
12
+
13
+ class EndpointHandler:
14
+ def __init__(self, path="."):
15
+ print("🚀 Initializing GFPGANv1 Face Restoration Pipeline...")
16
+
17
+ # ------------------------------------------------------------
18
+ # Load Real-ESRGAN (background upscaler)
19
+ # ------------------------------------------------------------
20
+ self.esrgan_url = (
21
+ "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/"
22
+ "RealESRGAN_x4plus.pth"
23
+ )
24
+ self.esrgan_path = os.path.join(path, "RealESRGAN_x4plus.pth")
25
+
26
+ if not os.path.exists(self.esrgan_path):
27
+ print("📥 Downloading RealESRGAN_x4plus.pth...")
28
+ r = requests.get(self.esrgan_url)
29
+ r.raise_for_status()
30
+ with open(self.esrgan_path, "wb") as f:
31
+ f.write(r.content)
32
+ print("✅ Downloaded Real-ESRGAN model.")
33
+
34
+ model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=23, num_grow_ch=32, scale=4)
35
+ device = "cuda" if torch.cuda.is_available() else "cpu"
36
+
37
+ self.bg_upsampler = RealESRGANer(
38
+ scale=4,
39
+ model_path=self.esrgan_path,
40
+ model=model,
41
+ half=False,
42
+ device=device,
43
+ )
44
+
45
+ # ------------------------------------------------------------
46
+ # Load GFPGAN model
47
+ # ------------------------------------------------------------
48
+ self.gfpgan_url = (
49
+ "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
50
+ )
51
+ self.gfpgan_path = os.path.join(path, "GFPGANv1.4.pth")
52
+
53
+ if not os.path.exists(self.gfpgan_path):
54
+ print("📥 Downloading GFPGANv1.4.pth...")
55
+ r = requests.get(self.gfpgan_url)
56
+ r.raise_for_status()
57
+ with open(self.gfpgan_path, "wb") as f:
58
+ f.write(r.content)
59
+ print("✅ Downloaded GFPGANv1.4.pth.")
60
+
61
+ self.face_enhancer = GFPGANer(
62
+ model_path=self.gfpgan_path,
63
+ upscale=4,
64
+ arch="clean",
65
+ channel_multiplier=2,
66
+ bg_upsampler=self.bg_upsampler,
67
+ )
68
+
69
+ print("✅ GFPGANv1 Face Restoration Pipeline Ready!")
70
+
71
+ # ------------------------------------------------------------------
72
+ # Main callable
73
+ # ------------------------------------------------------------------
74
+ def __call__(self, data):
75
+ try:
76
+ image = self._load_image(data)
77
+ restored = self._restore_face(image)
78
+ return self._encode_image(restored)
79
+ except Exception as e:
80
+ print("💥 Error:", str(e))
81
+ return {"error": str(e)}
82
+
83
+ # ------------------------------------------------------------------
84
+ # Helper functions
85
+ # ------------------------------------------------------------------
86
+ def _load_image(self, data):
87
+ # Handles raw bytes, base64 JSON, or dict
88
+ if isinstance(data, (bytes, bytearray)):
89
+ return Image.open(io.BytesIO(data)).convert("RGB")
90
+
91
+ if isinstance(data, dict):
92
+ field = data.get("inputs") or data.get("image")
93
+ if isinstance(field, str):
94
+ field = base64.b64decode(field)
95
+ return Image.open(io.BytesIO(field)).convert("RGB")
96
+
97
+ if isinstance(data, str):
98
+ decoded = base64.b64decode(data)
99
+ return Image.open(io.BytesIO(decoded)).convert("RGB")
100
+
101
+ raise ValueError("Expected image bytes or base64 string.")
102
+
103
+ def _restore_face(self, image):
104
+ img_np = np.array(image)[:, :, ::-1] # RGB -> BGR
105
+ cropped_faces, restored_faces, restored_img = self.face_enhancer.enhance(
106
+ img_np, has_aligned=False, only_center_face=False, paste_back=True
107
+ )
108
+ return Image.fromarray(restored_img[:, :, ::-1]) # back to RGB
109
+
110
+ def _encode_image(self, pil_img):
111
+ buf = io.BytesIO()
112
+ pil_img.save(buf, format="PNG")
113
+ encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
114
+ return {"image": encoded}
115
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==2.1.2
2
+ torchvision==0.16.2
3
+ gfpgan==1.3.8
4
+ realesrgan==0.3.0
5
+ basicsr==1.4.2
6
+ facexlib==0.3.0
7
+ numpy==1.26.4
8
+ Pillow>=10.0.0
9
+ opencv-python
10
+ requests
11
+