File size: 2,803 Bytes
61f066f
 
cf991ea
 
61f066f
cf991ea
61f066f
 
 
 
cf991ea
 
 
 
 
 
6dfcee3
61f066f
cf991ea
61f066f
cf991ea
 
61f066f
6b539a6
61f066f
 
 
 
 
 
9540255
61f066f
 
 
 
 
9540255
61f066f
f041201
61f066f
 
 
 
 
 
 
cf991ea
61f066f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf991ea
61f066f
cf991ea
61f066f
cf991ea
 
61f066f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
from basicsr.archs.rrdbnet_arch import RRDBNet
from realesrgan import RealESRGANer
from diffusers import StableDiffusionPipeline
import base64
from PIL import Image
from io import BytesIO
import torch
from torch.cuda.amp import autocast
from typing import Dict, Any
import numpy as np

# Pick the compute device once at import time: prefer CUDA, fall back to CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

class EndpointHandler():
    """Text-to-image inference endpoint.

    Generates an image from a text prompt with a StableDiffusionPipeline,
    4x-upscales the result with a RealESRGAN (anime, 6-block) model, and
    returns the upscaled image as a base64-encoded PNG.
    """

    def __init__(self, path=""):
        """Load the diffusion pipeline and the ESRGAN upscaler.

        Args:
            path: Model repository path passed to
                ``StableDiffusionPipeline.from_pretrained``.
        """
        # Load the StableDiffusionPipeline and move it to the module device.
        self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.pipe = self.pipe.to(device)

        # Load the ESRGAN checkpoint. map_location keeps this working on
        # CPU-only hosts even when the checkpoint was saved from a CUDA device
        # (torch.load would otherwise raise trying to deserialize CUDA tensors).
        checkpoint = torch.load("/repository/RealESRGAN_x4plus_anime_6B.pth", map_location=device)

        # Real-ESRGAN checkpoints usually store the EMA weights under
        # 'params_ema'; fall back to the raw dict otherwise.
        if "params_ema" in checkpoint:
            state_dict = checkpoint["params_ema"]
        else:
            state_dict = checkpoint

        # Define the ESRGAN architecture (x4plus_anime_6B uses 6 RRDB blocks)
        # and load the weights for inference.
        self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
        self.model.load_state_dict(state_dict)
        self.model.to(device)
        self.model.eval()

        # NOTE(review): this RealESRGANer re-loads the checkpoint from disk and
        # is never used in __call__ (inference goes through self.model
        # directly). Kept so any external code holding self.upsampler still works.
        self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0, model_path="/repository/RealESRGAN_x4plus_anime_6B.pth")


    def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
        """Generate and upscale an image.

        Args:
            data: Request payload. Expected keys: ``"inputs"`` (the text
                prompt) and, optionally, ``"negative_prompt"``.
            output_size: Unused; accepted for backward compatibility with
                existing callers.

        Returns:
            Dict with a single key ``"image"`` holding a base64-encoded PNG.
        """
        inputs = data.get("inputs")
        negative_prompt = data.get("negative_prompt", None)

        # Run StableDiffusionPipeline under autocast (no-op on CPU).
        with autocast():
            output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
            image = output['images'][0]

        # Convert the PIL image once to a float32 CHW tensor in [0, 1].
        # (The previous code converted the PIL image to numpy implicitly via
        # np.clip and then again via np.array — a redundant double conversion.)
        arr = np.asarray(image).astype(np.float32) / 255.0
        tensor_image = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)

        # Upscale with ESRGAN; no gradients needed for inference.
        with torch.no_grad():
            esrgan_output = self.model(tensor_image)

        # Drop only the batch dim (squeeze(0), not squeeze(), so a singleton
        # channel could never be removed by accident), then back to HWC uint8.
        esrgan_output = esrgan_output.squeeze(0).permute(1, 2, 0).cpu().numpy()
        esrgan_output = np.clip(esrgan_output, 0, 1)  # Ensure values stay in [0, 1]
        esrgan_image = Image.fromarray((esrgan_output * 255).astype('uint8'))

        # Encode the upscaled image as base64 PNG for the JSON response.
        buffered = BytesIO()
        esrgan_image.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue())

        return {"image": img_str.decode()}