Alex Mikulaniec commited on
Commit
9540255
·
1 Parent(s): cf991ea

Added new handler for the model

Browse files
Files changed (1) hide show
  1. handler.py +54 -22
handler.py CHANGED
@@ -1,39 +1,71 @@
1
- from typing import Dict, Any
2
- import torch
3
- from torch.cuda.amp import autocast
4
  from diffusers import StableDiffusionPipeline
5
  import base64
 
6
  from io import BytesIO
 
 
 
 
7
 
8
  # Setting the device
9
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
10
 
11
  class EndpointHandler():
12
 
13
- def __init__(self, path=""):
14
- # Load the model
15
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
 
16
  self.pipe = self.pipe.to(device)
17
 
18
- def __call__(self, data: Dict[str, Any]) -> Dict[str, str]:
19
- """
20
- Args:
21
- data (dict): Includes the input data for inference.
22
-
23
- Return:
24
- dict: Base64 encoded image.
25
- """
26
- inputs = data.get("inputs") # Getting the inputs from the data dictionary
27
-
28
- # Run inference pipeline
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  with autocast():
30
- output = self.pipe(inputs, guidance_scale=7.5)
31
- image = output['images'][0] # Accessing the 'images' key in the output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
 
33
- # Encoding image as base 64
34
  buffered = BytesIO()
35
- image.save(buffered, format="PNG")
36
  img_str = base64.b64encode(buffered.getvalue())
37
 
38
- # Returning the base64 image as a dictionary
39
- return {"image": img_str.decode()}
 
1
+ from basicsr.archs.rrdbnet_arch import RRDBNet
2
+ from realesrgan import RealESRGANer
 
3
  from diffusers import StableDiffusionPipeline
4
  import base64
5
+ from PIL import Image
6
  from io import BytesIO
7
+ import torch
8
+ from torch.cuda.amp import autocast
9
+ from typing import Dict, Any
10
+ import numpy as np
11
 
12
  # Setting the device
13
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
14
 
15
  class EndpointHandler():
16
 
17
+ def __init__(self, path="", esrgan_model_path=""):
18
+ # Load the StableDiffusionPipeline
19
  self.pipe = StableDiffusionPipeline.from_pretrained(path, torch_dtype=torch.float32)
20
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
21
  self.pipe = self.pipe.to(device)
22
 
23
+ # Load the ESRGAN state dictionary
24
+ checkpoint = torch.load(esrgan_model_path)
25
+
26
+ # Check if 'params_ema' is in the keys and filter the state_dict
27
+ if "params_ema" in checkpoint:
28
+ state_dict = checkpoint["params_ema"]
29
+ else:
30
+ state_dict = checkpoint
31
+
32
+ # Define the ESRGAN model architecture
33
+ self.model = RRDBNet(num_in_ch=3, num_out_ch=3, num_feat=64, num_block=6, num_grow_ch=32, scale=4)
34
+ self.model.load_state_dict(state_dict)
35
+ self.model.to(device)
36
+ self.model.eval()
37
+
38
+ # Create a RealESRGANer object for inference
39
+ self.upsampler = RealESRGANer(scale=4, model=self.model, tile=0, model_path=esrgan_model_path)
40
+
41
+
42
+ def __call__(self, data: Dict[str, Any], output_size=(512, )) -> Dict[str, str]:
43
+ inputs = data.get("inputs")
44
+ negative_prompt = data.get("negative_prompt", None)
45
+
46
+ # Run StableDiffusionPipeline
47
  with autocast():
48
+ output = self.pipe(inputs, guidance_scale=7.5, negative_prompt=negative_prompt)
49
+ image = output['images'][0]
50
+
51
+ # Normalize the image to [0, 1] range if it's not
52
+ image = np.clip(image, 0, 255) / 255.0
53
+
54
+ # Convert the StableDiffusionPipeline output to suitable format for ESRGAN
55
+ tensor_image = torch.from_numpy(np.array(image)).float().permute(2, 0, 1).unsqueeze(0).to(device)
56
+
57
+ # Process the image with ESRGAN
58
+ with torch.no_grad():
59
+ esrgan_output = self.model(tensor_image)
60
+
61
+ # Post-process the ESRGAN output to make it a PIL image
62
+ esrgan_output = esrgan_output.squeeze().permute(1, 2, 0).cpu().numpy()
63
+ esrgan_output = np.clip(esrgan_output, 0, 1) # Ensure the values are within [0, 1]
64
+ esrgan_image = Image.fromarray((esrgan_output * 255).astype('uint8'))
65
 
66
+ # Encoding ESRGAN image as base64
67
  buffered = BytesIO()
68
+ esrgan_image.save(buffered, format="PNG")
69
  img_str = base64.b64encode(buffered.getvalue())
70
 
71
+ return {"image": img_str.decode()}