"""Inference entry points for background removal, inpainting and upscaling.

`model_fn` loads the shared pipelines and `predict_fn` routes an incoming
request to the handler matching its task type.
NOTE(review): the `model_fn` / `predict_fn` names look like model-server
hooks (e.g. SageMaker-style) -- confirm against the deployment config.
"""

from io import BytesIO

import torch

from data.dataAccessor import update_db
from data.task import ModelType, Task, TaskType
from pipelines.inpainter import InPainter
from pipelines.prompt_modifier import PromptModifier
from pipelines.remove_background import RemoveBackground
from pipelines.upscaler import Upscaler
from util.cache import clear_cuda
from util.commons import (
    add_code_names,
    construct_default_s3_url,
    upload_image,
    upload_images,
)
from util.slack import Slack

# Global torch backend switches, applied once at import time.
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# Number of prompts (and negative prompts) sent to the inpainter per request.
num_return_sequences = 4
auto_mode = False

# Shared helpers, built once at import; their weights are loaded in model_fn.
slack = Slack()
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background of the task's image and upload the result.

    Returns:
        A dict with key ``generated_image_url`` holding the URL built from
        the uploaded object's key.
    """
    # A fresh RemoveBackground is created per call, unlike the module-level
    # pipelines above.
    remover = RemoveBackground()
    cutout = remover.remove(task.get_imageUrl())

    s3_key = "crecoAI/{}_rmbg.png".format(task.get_taskId())
    upload_image(cutout, s3_key)

    return {"generated_image_url": construct_default_s3_url(s3_key)}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Run the inpainting pipeline for ``task`` and upload the outputs.

    The prompt is passed through ``add_code_names`` and then either expanded
    by the prompt modifier (when the task asks for prompt engineering) or
    repeated ``num_return_sequences`` times.

    Returns:
        A dict with ``modified_prompts`` (the prompts handed to the
        inpainter) and ``generated_image_urls`` (the value returned by
        ``upload_images``).
    """
    base_prompt = add_code_names(task.get_prompt())
    prompts = (
        prompt_modifier.modify(base_prompt)
        if task.is_prompt_engineering()
        else [base_prompt] * num_return_sequences
    )
    print({"prompts": prompts})

    results = inpainter.process(
        prompt=prompts,
        image_url=task.get_imageUrl(),
        mask_image_url=task.get_maskImageUrl(),
        width=task.get_width(),
        height=task.get_height(),
        seed=task.get_seed(),
        negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
    )

    uploaded_urls = upload_images(results, "_inpaint", task.get_taskId())
    clear_cuda()

    return {"modified_prompts": prompts, "generated_image_urls": uploaded_urls}


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image with the anime or the real model.

    The anime upscaler is used when the task's model type equals
    ``ModelType.ANIME``; every other model type uses the default upscaler.

    Returns:
        A dict with key ``generated_image_url`` holding the URL built from
        the uploaded object's key.
    """
    s3_key = "crecoAI/{}_upscale.png".format(task.get_taskId())

    wants_anime = task.get_modelType() == ModelType.ANIME
    if wants_anime:
        print("Using Anime model")
        upscaled = upscaler.upscale_anime(task.get_imageUrl())
    else:
        print("Using Real model")
        upscaled = upscaler.upscale(task.get_imageUrl())

    # The upscaler output is wrapped in a BytesIO before upload
    # (presumably it is raw encoded image bytes -- TODO confirm).
    upload_image(BytesIO(upscaled), s3_key)

    return {"generated_image_url": construct_default_s3_url(s3_key)}


def model_fn(model_dir):
    """Load the prompt modifier, upscaler and inpainter.

    ``model_dir`` is accepted but not used. Returns ``None``.
    """
    print("Logs: model loaded .... starts")
    for component in (prompt_modifier, upscaler, inpainter):
        component.load()
    print("Logs: model loaded ....")


def predict_fn(data, pipe):
    """Dispatch a request payload to the handler for its task type.

    Args:
        data: Raw request payload; wrapped in a ``Task``.
        pipe: Unused here (``model_fn`` returns ``None``).

    Returns:
        The selected handler's result, or ``None`` when anything raised
        (including an unknown task type). Failures are printed and reported
        through ``slack.error_alert`` rather than propagated.
    """
    task = Task(data)
    print("task is ", data)

    try:
        requested = task.get_type()
        if requested == TaskType.REMOVE_BG:
            return remove_bg(task)
        if requested == TaskType.INPAINT:
            return inpaint(task)
        if requested == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        raise Exception("Invalid task type")
    except Exception as e:
        # Top-level boundary: report the failure and answer with None.
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None