File size: 3,160 Bytes
4adca93 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 |
from io import BytesIO
import torch
from data.dataAccessor import update_db
from data.task import ModelType, Task, TaskType
from pipelines.inpainter import InPainter
from pipelines.prompt_modifier import PromptModifier
from pipelines.remove_background import RemoveBackground
from pipelines.upscaler import Upscaler
from util.cache import clear_cuda
from util.commons import (
add_code_names,
construct_default_s3_url,
upload_image,
upload_images,
)
from util.slack import Slack
# Let cuDNN benchmark candidate convolution algorithms and cache the fastest one
# per input shape (see PyTorch `torch.backends.cudnn.benchmark`).
torch.backends.cudnn.benchmark = True
# Permit TensorFloat-32 for CUDA matmuls, trading some precision for speed on
# hardware that supports it (see PyTorch `torch.backends.cuda.matmul.allow_tf32`).
torch.backends.cuda.matmul.allow_tf32 = True
# Number of prompts per inpaint batch: passed to PromptModifier as
# num_of_sequences and used to replicate the prompt when prompt engineering is off.
num_return_sequences = 4
# NOTE(review): not referenced anywhere in the code visible here — confirm
# whether another module reads it before removing.
auto_mode = False
# Process-wide singletons shared by every request. model_fn() calls .load() on
# prompt_modifier, upscaler and inpainter; slack supplies the alert decorator
# and the error_alert() used by predict_fn().
slack = Slack()
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()
@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Remove the background from the task's input image and upload the result.

    A fresh ``RemoveBackground`` instance processes ``task.get_imageUrl()``; the
    output is uploaded under the key ``crecoAI/<taskId>_rmbg.png``.

    Returns:
        A dict whose only entry, ``generated_image_url``, is the default S3 URL
        of the uploaded image.
    """
    background_remover = RemoveBackground()
    cutout = background_remover.remove(task.get_imageUrl())
    destination_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, destination_key)
    public_url = construct_default_s3_url(destination_key)
    return dict(generated_image_url=public_url)
@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Inpaint the masked region of the task's image and upload the results.

    The task prompt is first passed through ``add_code_names``. If the task asks
    for prompt engineering, ``prompt_modifier.modify`` produces the prompt list;
    otherwise the single prompt is repeated ``num_return_sequences`` times.

    Returns:
        A dict with ``modified_prompts`` (the prompt list actually sent to the
        inpainter) and ``generated_image_urls`` (the value returned by
        ``upload_images``).
    """
    base_prompt = add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompts = prompt_modifier.modify(base_prompt)
    else:
        prompts = [base_prompt] * num_return_sequences
    print({"prompts": prompts})
    try:
        images = inpainter.process(
            prompt=prompts,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            # One negative prompt per positive prompt, so the two lists always
            # have the same length even if the modifier returns a different
            # number of prompts than num_return_sequences.
            negative_prompt=[task.get_negative_prompt()] * len(prompts),
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # Release cached GPU memory even when inference or the upload raises.
        # Before this change, a failure skipped the cleanup for this worker.
        clear_cuda()
    return {"modified_prompts": prompts, "generated_image_urls": generated_image_urls}
@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image with the anime or the real-photo model.

    ``ModelType.ANIME`` selects ``upscaler.upscale_anime``; every other model
    type selects ``upscaler.upscale``. The returned bytes are wrapped in a
    ``BytesIO`` and uploaded as ``crecoAI/<taskId>_upscale.png``.

    Returns:
        A dict whose only entry, ``generated_image_url``, is the default S3 URL
        of the uploaded image.
    """
    destination_key = f"crecoAI/{task.get_taskId()}_upscale.png"
    if task.get_modelType() != ModelType.ANIME:
        print("Using Real model")
        run_upscale = upscaler.upscale
    else:
        print("Using Anime model")
        run_upscale = upscaler.upscale_anime
    upscaled_bytes = run_upscale(task.get_imageUrl())
    upload_image(BytesIO(upscaled_bytes), destination_key)
    return dict(generated_image_url=construct_default_s3_url(destination_key))
def model_fn(model_dir):
    """Load the weights of the shared pipelines once per process.

    The ``model_fn``/``predict_fn`` naming looks like an inference-server hook
    convention (presumably SageMaker — confirm). ``model_dir`` is accepted for
    that interface but not used: each pipeline's ``load()`` locates its own
    weights.

    Returns:
        None. predict_fn does not use its ``pipe`` argument.
    """
    print("Logs: model loaded .... starts")
    for pipeline in (prompt_modifier, upscaler, inpainter):
        pipeline.load()
    print("Logs: model loaded ....")
    return None
def predict_fn(data, pipe):
    """Route one request to the handler for its task type.

    Args:
        data: raw request payload; it is wrapped in ``Task``.
        pipe: the value produced by the model-loading hook; unused here.

    Returns:
        The handler's result. Returns ``None`` if the task type is unknown or the
        handler raised. In that case the error and its traceback are printed and
        ``slack.error_alert`` is called with the task and the exception.
    """
    import traceback

    # Task construction stays outside the try: the error path needs a Task to
    # pass to slack.error_alert.
    task = Task(data)
    print("task is ", data)
    try:
        task_type = task.get_type()
        # Plain == comparisons (not a dict lookup) so that matching does not
        # depend on how TaskType members hash.
        if task_type == TaskType.REMOVE_BG:
            return remove_bg(task)
        elif task_type == TaskType.INPAINT:
            return inpaint(task)
        elif task_type == TaskType.UPSCALE_IMAGE:
            return upscale_image(task)
        else:
            # Include the offending value so the alert is actionable.
            raise ValueError(f"Invalid task type: {task_type!r}")
    except Exception as e:
        # Top-level request boundary: report and swallow so the server returns
        # None instead of crashing. The traceback was previously lost.
        print(f"Error: {e}")
        traceback.print_exc()
        slack.error_alert(task, e)
        return None
|