| | from compel import Compel, ReturnedEmbeddingsType |
| | import logging |
| | from abc import ABC |
| |
|
| | import diffusers |
| | import torch |
| | from diffusers import StableDiffusionXLPipeline |
| |
|
| | import numpy as np |
| | import threading |
| |
|
| | import base64 |
| | from io import BytesIO |
| | from PIL import Image |
| | import numpy as np |
| | import uuid |
| | from tempfile import TemporaryFile |
| | from google.cloud import storage |
| | import sys |
| | from flask import Flask, request, jsonify |
| |
|
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Record which diffusers release is installed at import time.
logger.info("Diffusers version %s", diffusers.__version__)
| |
|
class DiffusersHandler(ABC):
    """Text-to-image handler built around a Stable Diffusion XL pipeline.

    A request flows through ``preprocess`` -> ``inference`` -> ``postprocess``.
    ``initialize`` must be called once before the first request.
    """

    # Google Cloud Storage bucket that receives the generated PNG files.
    OUTPUT_BUCKET = "outputs-storage-prod"

    def __init__(self):
        self.initialized = False
        # Defined up front so the flag can be read before the first request;
        # preprocess() sets it to True and postprocess() resets it to False.
        self.working = False

    def initialize(self, properties):
        """Load the Stable Diffusion XL pipeline and move it to its device.

        The model location is taken from the first command-line argument
        (``sys.argv[1]``).

        Args:
            properties (dict): Handler settings. Only ``"gpu_id"`` is read;
                when it is missing/None, or CUDA is unavailable, the model
                is placed on the CPU.
        """
        logger.info("Loading diffusion model")
        logger.info("I'm totally new and updated")

        gpu_id = properties.get("gpu_id")
        if torch.cuda.is_available() and gpu_id is not None:
            device_str = "cuda:" + str(gpu_id)
        else:
            device_str = "cpu"

        print("my device is " + device_str)
        self.device = torch.device(device_str)

        model_path = sys.argv[1]
        self.pipe = StableDiffusionXLPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            use_safetensors=True,
        )

        logger.info("moving model to device: %s", device_str)
        self.pipe.to(self.device)

        logger.info(self.device)
        # The original message had a %s placeholder but no argument.
        logger.info(
            "Diffusion model from path %s loaded successfully", model_path
        )

        self.initialized = True

    def preprocess(self, raw_requests):
        """Extract the generation parameters from the first raw request.

        Args:
            raw_requests (list): List of request dicts; only element 0 is
                used. It must contain ``"prompt"``.

        Returns:
            dict: ``prompt``, ``negative_prompt``, ``width``, ``height``
            (None when absent), ``num_inference_steps`` (default 30) and
            ``guidance_scale`` (default 7.5).

        Raises:
            IndexError: If ``raw_requests`` is empty.
            KeyError: If the request has no ``"prompt"`` entry.
        """
        logger.info("Received requests: '%s'", raw_requests)
        self.working = True

        payload = raw_requests[0]
        processed_request = {
            "prompt": payload["prompt"],
            "negative_prompt": payload.get("negative_prompt"),
            "width": payload.get("width"),
            "height": payload.get("height"),
            "num_inference_steps": payload.get("num_inference_steps", 30),
            "guidance_scale": payload.get("guidance_scale", 7.5),
        }

        logger.info("Processed request: '%s'", processed_request)
        return processed_request

    def inference(self, request):
        """Generate images for a preprocessed request.

        The prompt is turned into embeddings with Compel and passed to the
        pipeline; every other entry of ``request`` is forwarded to the
        pipeline call as a keyword argument.

        Args:
            request (dict): Output of ``preprocess``; must contain
                ``"prompt"``. The dict is not modified.

        Returns:
            list: The ``images`` attribute of the pipeline output.
        """
        # Work on a copy so the caller's dict keeps its "prompt" entry.
        pipe_kwargs = dict(request)
        prompt = pipe_kwargs.pop("prompt")
        # Attribute kept for backward compatibility. The local variable is
        # what is used below, so another request served by this same
        # instance cannot overwrite the prompt between these two steps.
        self.prompt = prompt

        compel = Compel(
            tokenizer=[self.pipe.tokenizer, self.pipe.tokenizer_2],
            text_encoder=[self.pipe.text_encoder, self.pipe.text_encoder_2],
            returned_embeddings_type=(
                ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED
            ),
            requires_pooled=[False, True],
        )
        conditioning, pooled = compel(prompt)

        inferences = self.pipe(
            prompt_embeds=conditioning,
            pooled_prompt_embeds=pooled,
            **pipe_kwargs
        ).images

        logger.info("Generated image: '%s'", inferences)
        return inferences

    def postprocess(self, inference_outputs):
        """Upload each generated image as a PNG and return the object URLs.

        Every image is stored in ``OUTPUT_BUCKET`` under a random UUID name.

        Args:
            inference_outputs (list): Images returned by ``inference``; each
                must support ``save(file, format="png")``.

        Returns:
            list: One ``https://storage.googleapis.com/<bucket>/<uuid>.png``
            URL per image, in input order.
        """
        bucket_name = self.OUTPUT_BUCKET
        client = storage.Client()
        self.working = False
        bucket = client.get_bucket(bucket_name)

        outputs = []
        for image in inference_outputs:
            object_name = str(uuid.uuid4()) + ".png"
            blob = bucket.blob(object_name)

            # Serialize into a temporary file, then rewind it so the upload
            # reads from the start.
            with TemporaryFile() as tmp:
                image.save(tmp, format="png")
                tmp.seek(0)
                blob.upload_from_file(tmp, content_type="image/png")

            outputs.append(
                f"https://storage.googleapis.com/{bucket_name}/{object_name}"
            )
        return outputs
| |
|
| |
|
# Flask application that exposes the generation endpoint.
app = Flask(__name__)

# One handler is created per CUDA device; refuse to start without any.
gpu_count = torch.cuda.device_count()
if not gpu_count:
    raise ValueError("No GPUs available!")

handlers = [DiffusersHandler() for _ in range(gpu_count)]
for gpu_id, handler in enumerate(handlers):
    handler.initialize({"gpu_id": gpu_id})

# Round-robin cursor into `handlers`; the lock guards reads/updates of it.
handler_index = 0
handler_lock = threading.Lock()
| |
|
@app.route('/generate', methods=['POST'])
def generate_image():
    """Handle ``POST /generate``: run one text-to-image request.

    The body must be a JSON object containing ``"prompt"``; the optional
    entries are the ones read by ``DiffusersHandler.preprocess``. Handlers
    are picked in round-robin order.

    Returns:
        ``{"image_urls": [...]}`` on success;
        ``{"error": ...}`` with status 400 when the body is not a JSON
        object with a ``"prompt"`` entry;
        ``{"error": ..., "details": ...}`` with status 500 when generation
        or upload fails.
    """
    global handler_index

    # Reject malformed input up front so client mistakes are reported as 400
    # instead of surfacing as a 500 from inside the handler.
    raw_requests = request.get_json(silent=True)
    if not isinstance(raw_requests, dict) or "prompt" not in raw_requests:
        return jsonify(
            {"error": "Request body must be a JSON object with a 'prompt'"}
        ), 400

    # Only the cursor update needs the lock; generation runs outside it.
    with handler_lock:
        selected_handler = handlers[handler_index]
        handler_index = (handler_index + 1) % gpu_count

    try:
        processed_request = selected_handler.preprocess([raw_requests])
        inferences = selected_handler.inference(processed_request)
        outputs = selected_handler.postprocess(inferences)
    except Exception as e:
        # Top-level boundary: log with traceback and report a generic failure.
        logger.exception("Error during image generation: %s", e)
        return jsonify(
            {"error": "Failed to generate image", "details": str(e)}
        ), 500

    return jsonify({"image_urls": outputs})
| |
|
if __name__ == "__main__":
    # Listen on every interface, port 3000, with threaded request handling.
    app.run(host="0.0.0.0", port=3000, threaded=True)