zwww committed on
Commit
7fc2b1a
·
1 Parent(s): 97914a7

Create stable_diffusion_handler.py

Browse files
Files changed (1) hide show
  1. stable_diffusion_handler.py +93 -0
stable_diffusion_handler.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from abc import ABC
3
+
4
+ import diffusers
5
+ import torch
6
+ from diffusers import StableDiffusionPipeline
7
+
8
+ from ts.torch_handler.base_handler import BaseHandler
9
+ import numpy as np
10
+
11
+
12
# Module-level logger named after this module, per the stdlib logging convention.
logger = logging.getLogger(__name__)
# Log the installed diffusers version once at import time (lazy %-style args).
logger.info("Diffusers version %s", diffusers.__version__)
14
+
15
class DiffusersHandler(BaseHandler, ABC):
    """TorchServe handler that generates images from text prompts with a
    Stable Diffusion pipeline.

    Lifecycle: TorchServe calls ``initialize`` once per worker, then
    ``preprocess`` -> ``inference`` -> ``postprocess`` for each request batch.
    """

    def __init__(self):
        # Initialize BaseHandler state (context, model, etc.) before our own.
        super().__init__()
        self.initialized = False

    def initialize(self, ctx):
        """Load and initialize the Stable Diffusion model.

        Args:
            ctx (context): TorchServe context object carrying the manifest
                and system properties of the model archive.
        """
        logger.info("Loading diffusion model")

        self.manifest = ctx.manifest
        properties = ctx.system_properties
        model_dir = properties.get("model_dir")

        # Prefer the GPU TorchServe assigned to this worker; fall back to CPU.
        self.device = torch.device(
            "cuda:" + str(properties.get("gpu_id"))
            if torch.cuda.is_available() and properties.get("gpu_id") is not None
            else "cpu"
        )

        # Load from the model archive directory rather than a hard-coded "./":
        # the two only coincide when the worker's CWD happens to be model_dir,
        # and the success log below already claims model_dir was used.
        self.pipe = StableDiffusionPipeline.from_pretrained(model_dir)
        self.pipe.to(self.device)
        logger.info("Diffusion model from path %s loaded successfully", model_dir)

        self.initialized = True

    def preprocess(self, requests):
        """Extract the text prompt from the incoming request batch.

        Args:
            requests (list): Batch of request dicts from TorchServe.

        Returns:
            list: A single-element list containing the prompt string.
        """
        logger.info("Received requests: '%s'", requests)

        # NOTE(review): only the first request's prompt is used — this assumes
        # a batch size of 1; confirm against the model's batching config.
        text = requests[0]["prompt"]

        logger.info("pre-processed text: '%s'", text)

        return [text]

    def inference(self, inputs):
        """Generate the image(s) for the received prompt(s).

        Args:
            inputs (list): Prompt strings from ``preprocess``.

        Returns:
            list: PIL images produced by the pipeline, one per prompt.
        """
        # guidance_scale / num_inference_steps are the common SD defaults.
        inferences = self.pipe(
            inputs, guidance_scale=7.5, num_inference_steps=50
        ).images

        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_output):
        """Convert generated images into a TorchServe-serializable format.

        Args:
            inference_output (list): PIL images from ``inference``.

        Returns:
            list: One nested ``list`` of pixel values per image
            (``np.array(image).tolist()``), which is JSON-serializable.
        """
        return [np.array(image).tolist() for image in inference_output]