Alexmikupro committed on
Commit
84ae75b
·
1 Parent(s): 7c5fa1a

return to the previous handler

Browse files
Files changed (1) hide show
  1. handler.py +19 -51
handler.py CHANGED
@@ -1,71 +1,39 @@
1
- from basicsr.archs.rrdbnet_arch import RRDBNet
2
- from realesrgan import RealESRGANer
 
3
  from diffusers import StableDiffusionPipeline
4
  import base64
5
- from PIL import Image
6
  from io import BytesIO
7
- import torch
8
- from torch.cuda.amp import autocast
9
- from typing import Dict, Any
10
- import numpy as np
11
 
12
  # Setting the device
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  class EndpointHandler():
16
 
17
- def __init__(self, path="", esrgan_model_path=""):
18
- # Load the StableDiffusionPipeline
19
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
20
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.pipe = self.pipe.to(device)
22
 
23
- # Load the ESRGAN state dictionary
24
- checkpoint = torch.load(esrgan_model_path)
25
-
26
- # Check if 'params_ema' is in the keys and filter the state_dict
27
- if "params_ema" in checkpoint:
28
- state_dict = checkpoint["params_ema"]
29
- else:
30
- state_dict = checkpoint
31
 
32
- # Define the ESRGAN model architecture
33
- self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
34
- self.model.load_state_dict(state_dict)
35
- self.model.to(device)
36
- self.model.eval()
37
 
38
- # Create a RealESRGANer object for inference
39
- self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0, model_path=esrgan_model_path)
40
-
41
-
42
- def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
43
- inputs = data.get("inputs")
44
- negative_prompt = data.get("negative_prompt", None)
45
-
46
- # Run StableDiffusionPipeline
47
  with autocast():
48
- output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
49
- image = output['images'][0]
50
-
51
- # Normalize the image to [0, 1] range if it's not
52
- image = np.clip(image, 0, 255) / 255.0
53
-
54
- # Convert the StableDiffusionPipeline output to suitable format for ESRGAN
55
- tensor_image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(device)
56
-
57
- # Process the image with ESRGAN
58
- with torch.no_grad():
59
- esrgan_output = self.model(tensor_image)
60
-
61
- # Post-process the ESRGAN output to make it a PIL image
62
- esrgan_output = esrgan_output.squeeze().permute(1, 2, 0).cpu().numpy()
63
- esrgan_output = np.clip(esrgan_output, 0, 1) # Ensure the values are within [0, 1]
64
- esrgan_image = Image.fromarray((esrgan_output * 255).astype('uint8'))
65
 
66
- # Encoding ESRGAN image as base64
67
  buffered = BytesIO()
68
- esrgan_image.save(buffered, format="PNG")
69
  img_str = base64.b64encode(buffered.getvalue())
70
 
 
71
  return {"image": img_str.decode()}
 
1
+ from typing import Dict, Any
2
+ import torch
3
+ from torch.cuda.amp import autocast
4
  from diffusers import StableDiffusionPipeline
5
  import base64
 
6
  from io import BytesIO
 
 
 
 
7
 
8
  # Setting the device
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
  class EndpointHandler():
12
 
13
+ def __init__(self, path=""):
14
+ # Load the model
15
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
 
16
  self.pipe = self.pipe.to(device)
17
 
18
+ def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
19
+ """
20
+ Args:
21
+ data (dict): Includes the input data for inference.
 
 
 
 
22
 
23
+ Return:
24
+ dict: Base64 encoded image.
25
+ """
26
+ inputs = data.get("inputs") # Getting the inputs from the data dictionary
 
27
 
28
+ # Run inference pipeline
 
 
 
 
 
 
 
 
29
  with autocast():
30
+ output = self.pipe(inputs, guidance_scale=7.5)
31
+ image = output['images'][0] # Accessing the 'images' key in the output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
+ # Encoding image as base 64
34
  buffered = BytesIO()
35
+ image.save(buffered, format="PNG")
36
  img_str = base64.b64encode(buffered.getvalue())
37
 
38
+ # Returning the base64 image as a dictionary
39
  return {"image": img_str.decode()}