File size: 5,242 Bytes
34e750a
7df3ea4
34e750a
7df3ea4
4984743
2389e63
34e750a
 
2389e63
7df3ea4
2389e63
eb396d8
34e750a
2389e63
7df3ea4
34e750a
7df3ea4
34e750a
3e3c736
34e750a
 
 
7df3ea4
c8867e7
3e3c736
 
 
 
 
 
7df3ea4
c8867e7
34e750a
 
7df3ea4
34e750a
 
eb396d8
c8867e7
34e750a
 
c8867e7
34e750a
 
4984743
c8867e7
4984743
 
34e750a
 
4984743
34e750a
4984743
 
 
34e750a
 
c8867e7
34e750a
3e3c736
34e750a
3e3c736
 
34e750a
 
 
 
 
 
 
 
 
 
7df3ea4
 
34e750a
 
 
 
 
 
 
 
 
 
 
4984743
c8867e7
34e750a
 
 
 
4984743
 
 
 
 
 
 
 
 
c8867e7
4984743
 
 
 
 
 
 
 
 
34e750a
4984743
c8867e7
4984743
 
 
 
34e750a
4984743
 
c8867e7
34e750a
3e3c736
34e750a
7df3ea4
c8867e7
4984743
c8867e7
b51e8ba
c8867e7
 
 
 
4984743
7df3ea4
c8867e7
 
34e750a
4984743
 
c8867e7
34e750a
2389e63
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
import os
import io
import torch
import logging
import base64
import requests
import numpy as np
import cv2
from PIL import Image

from gfpgan import GFPGANer
from realesrgan import RealESRGANer
from realesrgan.archs.srvgg_arch import SRVGGNetCompact

logger = logging.getLogger(__name__)

class EndpointHandler:
    """Inference handler that restores faces with GFPGAN and Real-ESRGAN.

    Construction makes sure both weight files exist under ``path``
    (downloading them when missing). The networks themselves are built
    lazily on the first call, so creating the handler stays cheap.

    Calling the handler with an image (raw bytes, a base64 string, a
    ``data:`` URI, an http(s) URL, or a dict wrapping any of those under
    ``"inputs"``) returns a JSON-serialisable dict containing the restored
    image as a base64-encoded JPEG.
    """

    # Seconds to wait on any HTTP request (weights or input image).
    _HTTP_TIMEOUT = 60
    # Weights are streamed to disk in chunks of this many bytes.
    _DOWNLOAD_CHUNK_SIZE = 1 << 20

    def __init__(self, path="."):
        """Record configuration and make sure the model weights are cached.

        Args:
            path: Directory in which the weight files are looked up and,
                if missing, downloaded to.

        Raises:
            requests.RequestException: If a missing weight file cannot be
                downloaded.
        """
        logger.info("πŸš€ [INIT] GFPGAN + Real-ESRGAN handler starting...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is only requested when running on CUDA.
        self.half = self.device == "cuda"
        self.path = path

        # Model URLs (GFPGAN + RealESRGAN)
        self.gfpgan_model_url = (
            "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"
        )
        self.realesr_model_url = (
            "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.5.0/realesr-general-x4v3.pth"
        )

        # Local cache paths
        self.gfpgan_model_path = os.path.join(path, "GFPGANv1.4.pth")
        self.realesr_model_path = os.path.join(path, "realesr-general-x4v3.pth")

        # Both are built on first use by _init_models().
        self.bg_upsampler = None
        self.restorer = None

        # Ensure model weights exist
        self._ensure_model(self.gfpgan_model_url, self.gfpgan_model_path)
        self._ensure_model(self.realesr_model_url, self.realesr_model_path)

        logger.info(
            "🧠 Device: %s, half precision: %s", self.device, self.half
        )

    def _ensure_model(self, url, path):
        """Download the file at ``url`` to ``path`` unless it already exists.

        The body is streamed to ``<path>.part`` and renamed into place only
        after the transfer completed, so an interrupted download can never
        be mistaken for a valid cached model on the next start.

        Args:
            url: Location of the weight file.
            path: Destination path of the cached copy.

        Raises:
            requests.RequestException: On connection problems or a non-2xx
                HTTP status.
            OSError: If the file cannot be written.
        """
        if os.path.exists(path):
            logger.info("πŸ“ Found cached model: %s", path)
            return

        logger.info("⬇️ Downloading model from %s", url)
        partial_path = f"{path}.part"
        try:
            with requests.get(
                url, stream=True, timeout=self._HTTP_TIMEOUT
            ) as response:
                response.raise_for_status()
                with open(partial_path, "wb") as f:
                    for chunk in response.iter_content(
                        chunk_size=self._DOWNLOAD_CHUNK_SIZE
                    ):
                        if chunk:  # skip keep-alive chunks
                            f.write(chunk)
            os.replace(partial_path, path)
        finally:
            # After a successful rename the partial file is gone and this
            # is a no-op; after a failure it removes the truncated file.
            if os.path.exists(partial_path):
                os.remove(partial_path)
        logger.info("βœ… Model saved to %s", path)

    def _init_models(self):
        """Build the Real-ESRGAN upsampler and GFPGAN restorer if needed.

        Each model is created at most once and then reused. The upsampler
        is created first because the restorer receives it as its
        background upsampler.
        """
        if self.bg_upsampler is None:
            logger.info("🧩 Initializing Real-ESRGAN upsampler...")
            model = SRVGGNetCompact(
                num_in_ch=3, num_out_ch=3, num_feat=64,
                num_conv=32, upscale=4, act_type="prelu"
            )
            self.bg_upsampler = RealESRGANer(
                scale=4,
                model_path=self.realesr_model_path,
                model=model,
                tile=400,
                tile_pad=10,
                pre_pad=0,
                half=self.half,
                device=self.device,
            )

        if self.restorer is None:
            logger.info("🧬 Initializing GFPGAN restorer...")
            self.restorer = GFPGANer(
                model_path=self.gfpgan_model_path,
                upscale=2,
                arch="clean",
                channel_multiplier=2,
                bg_upsampler=self.bg_upsampler,
            )
            logger.info("βœ… Models ready!")

    def _load_image(self, data):
        """Turn a request payload into an RGB PIL image.

        Args:
            data: Raw image bytes, an http(s) URL, a base64 string
                (optionally as a ``data:...;base64,`` URI), or a dict
                holding one of those under the ``"inputs"`` key.

        Returns:
            The decoded image converted to RGB mode.

        Raises:
            ValueError: If the payload type is unsupported, the string is
                empty, or the base64 data does not decode to an image.
            requests.RequestException: If a URL cannot be fetched or
                answers with a non-2xx status.
        """
        if isinstance(data, dict) and "inputs" in data:
            data = data["inputs"]

        if isinstance(data, (bytes, bytearray)):
            logger.info("πŸ“¦ Received raw bytes input")
            return Image.open(io.BytesIO(data)).convert("RGB")

        if isinstance(data, str):
            text = data.strip()
            if not text:
                raise ValueError("Invalid base64 image input")

            if text[:8].lower().startswith(("http://", "https://")):
                # NOTE(review): the URL comes straight from the request, so
                # the server fetches arbitrary addresses on the caller's
                # behalf; restrict allowed hosts if callers are untrusted.
                logger.info("🌐 Downloading image from URL: %s", text)
                resp = requests.get(text, timeout=self._HTTP_TIMEOUT)
                # Fail on 4xx/5xx instead of feeding an error page to PIL.
                resp.raise_for_status()
                return Image.open(io.BytesIO(resp.content)).convert("RGB")

            # Base64
            logger.info("🧬 Decoding base64 image input")
            if text.startswith("data:"):
                # Keep only the payload that follows the data-URI header.
                _, _, text = text.partition(",")
            try:
                decoded = base64.b64decode(text)
                return Image.open(io.BytesIO(decoded)).convert("RGB")
            except Exception as e:
                logger.error("❌ Failed to decode base64: %s", e)
                raise ValueError("Invalid base64 image input") from e

        raise ValueError("Unsupported input type")

    @staticmethod
    def _rgb_to_bgr(rgb):
        """Return a contiguous copy of ``rgb`` with its last axis reversed."""
        return np.ascontiguousarray(rgb[..., ::-1])

    def __call__(self, data):
        """Restore the faces in the supplied image.

        Args:
            data: Any payload accepted by ``_load_image``.

        Returns:
            A dict with the base64-encoded JPEG under ``"image"`` plus
            ``"status"`` and ``"info"`` strings.

        Raises:
            ValueError: If the payload cannot be decoded into an image.
            RuntimeError: If restoration yields no image or JPEG encoding
                fails.
        """
        logger.info("βš™οΈ Starting GFPGAN inference pipeline...")
        self._init_models()

        # Load input
        image = self._load_image(data)
        input_img = np.array(image, dtype=np.uint8)
        logger.info("πŸ“ Input image shape: %s", input_img.shape)

        # PIL yields RGB while GFPGAN/Real-ESRGAN follow the OpenCV BGR
        # convention, so swap the channels before inference; otherwise the
        # networks would see red and blue exchanged.
        input_bgr = self._rgb_to_bgr(input_img)

        # Restore face(s)
        _cropped_faces, _restored_faces, restored_img = self.restorer.enhance(
            input_bgr, has_aligned=False, only_center_face=False, paste_back=True
        )
        if restored_img is None:
            raise RuntimeError("GFPGAN returned no restored image")

        logger.info("πŸ–ΌοΈ Restoration complete, preparing output...")

        # The result is already BGR, which is what cv2.imencode expects.
        restored_bgr = np.clip(restored_img, 0, 255).astype(np.uint8)

        # Encode output as base64 string for JSON
        encoded_ok, buffer = cv2.imencode(".jpg", restored_bgr)
        if not encoded_ok:
            raise RuntimeError("Failed to encode restored image as JPEG")
        b64_output = base64.b64encode(buffer.tobytes()).decode("utf-8")

        logger.info("βœ… Returning base64-encoded image JSON response")

        return {
            "image": b64_output,
            "status": "success",
            "info": "Restored with GFPGAN v1.4 + Real-ESRGAN x4v3 (RGB fixed)"
        }