File size: 3,160 Bytes
4adca93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
from io import BytesIO

import torch
from data.dataAccessor import update_db
from data.task import ModelType, Task, TaskType
from pipelines.inpainter import InPainter
from pipelines.prompt_modifier import PromptModifier
from pipelines.remove_background import RemoveBackground
from pipelines.upscaler import Upscaler
from util.cache import clear_cuda
from util.commons import (
    add_code_names,
    construct_default_s3_url,
    upload_image,
    upload_images,
)
from util.slack import Slack

# Global PyTorch backend switches, applied once at import time:
# - cudnn.benchmark lets cuDNN try several convolution algorithms and keep the
#   fastest one for the shapes it sees.
# - allow_tf32 permits TensorFloat-32 for CUDA matrix multiplications on
#   hardware that supports it (faster, slightly lower precision).
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True

# Number of prompts / negative prompts prepared per inpaint request; also
# handed to PromptModifier as its `num_of_sequences`.
num_return_sequences = 4
# NOTE(review): not referenced anywhere in the visible part of this module.
auto_mode = False

# Alerting helper; its `auto_send_alert` decorator wraps the task handlers and
# `error_alert` is called from predict_fn on failure.
slack = Slack()

# Pipeline singletons shared by all requests. They are only constructed here;
# model_fn() calls .load() on each of them.
prompt_modifier = PromptModifier(num_of_sequences=num_return_sequences)
upscaler = Upscaler()
inpainter = InPainter()


@update_db
@slack.auto_send_alert
def remove_bg(task: Task):
    """Strip the background from the task's image and upload the result.

    A fresh ``RemoveBackground`` pipeline is created for every call. The
    processed image is uploaded under ``crecoAI/<taskId>_rmbg.png``.

    Args:
        task: Request wrapper providing ``get_imageUrl()`` and ``get_taskId()``.

    Returns:
        A dict with one key, ``"generated_image_url"``, holding the URL built
        by ``construct_default_s3_url`` for the uploaded key.
    """
    cutout = RemoveBackground().remove(task.get_imageUrl())

    s3_key = f"crecoAI/{task.get_taskId()}_rmbg.png"
    upload_image(cutout, s3_key)

    result_url = construct_default_s3_url(s3_key)
    return {"generated_image_url": result_url}


@update_db
@slack.auto_send_alert
def inpaint(task: Task):
    """Run inpainting for a task and upload the generated images.

    The task prompt is first passed through ``add_code_names``. If the task
    asks for prompt engineering, the prompt is handed to
    ``prompt_modifier.modify`` (presumably returning one prompt per output —
    confirm against PromptModifier); otherwise the same prompt is repeated
    ``num_return_sequences`` times. The negative prompt is always repeated
    ``num_return_sequences`` times.

    Args:
        task: Request wrapper providing the prompt, image URL, mask URL,
            width, height, seed, negative prompt and task id.

    Returns:
        A dict with ``"modified_prompts"`` (the prompts actually used) and
        ``"generated_image_urls"`` (the value returned by ``upload_images``).

    Raises:
        Whatever ``inpainter.process`` or ``upload_images`` raise; the
        cleanup step below still runs before the exception propagates.
    """
    prompt = add_code_names(task.get_prompt())
    if task.is_prompt_engineering():
        prompt = prompt_modifier.modify(prompt)
    else:
        prompt = [prompt] * num_return_sequences

    print({"prompts": prompt})

    try:
        images = inpainter.process(
            prompt=prompt,
            image_url=task.get_imageUrl(),
            mask_image_url=task.get_maskImageUrl(),
            width=task.get_width(),
            height=task.get_height(),
            seed=task.get_seed(),
            negative_prompt=[task.get_negative_prompt()] * num_return_sequences,
        )
        generated_image_urls = upload_images(images, "_inpaint", task.get_taskId())
    finally:
        # Run the cleanup on the failure path too, so a failed request does
        # not skip it and leave the next request to deal with the leftovers.
        clear_cuda()

    return {"modified_prompts": prompt, "generated_image_urls": generated_image_urls}


@update_db
@slack.auto_send_alert
def upscale_image(task: Task):
    """Upscale the task's image and upload it as ``crecoAI/<taskId>_upscale.png``.

    The anime upscaler is used when the task's model type equals
    ``ModelType.ANIME``; every other model type uses the default upscaler.
    The upscaler output is wrapped in ``BytesIO`` before upload.

    Args:
        task: Request wrapper providing ``get_taskId()``, ``get_modelType()``
            and ``get_imageUrl()``.

    Returns:
        A dict with one key, ``"generated_image_url"``, holding the URL built
        by ``construct_default_s3_url`` for the uploaded key.
    """
    s3_key = f"crecoAI/{task.get_taskId()}_upscale.png"

    wants_anime = task.get_modelType() == ModelType.ANIME
    if not wants_anime:
        print("Using Real model")
        upscaled = upscaler.upscale(task.get_imageUrl())
    else:
        print("Using Anime model")
        upscaled = upscaler.upscale_anime(task.get_imageUrl())

    upload_image(BytesIO(upscaled), s3_key)

    result_url = construct_default_s3_url(s3_key)
    return {"generated_image_url": result_url}


def model_fn(model_dir):
    """Load every module-level pipeline; called to prepare the models.

    Args:
        model_dir: Accepted for the caller's interface; not used here.

    Returns:
        None. The loaded pipelines live in the module-level singletons.
    """
    print("Logs: model loaded .... starts")

    # Same order as before: prompt modifier, then upscaler, then inpainter.
    for pipeline in (prompt_modifier, upscaler, inpainter):
        pipeline.load()

    print("Logs: model loaded ....")
    return None


def predict_fn(data, pipe):
    """Build a Task from the request payload and route it to its handler.

    Args:
        data: Raw request payload, passed straight to ``Task``.
        pipe: Accepted for the caller's interface; not used here.

    Returns:
        The handler's result dict, or None when the task type is unknown or
        the handler raises. In both failure cases the error is printed and
        reported through ``slack.error_alert``.
    """
    task = Task(data)
    print("task is ", data)

    try:
        requested = task.get_type()

        # Ordered (task type, handler) pairs, compared with == in this order.
        routes = (
            (TaskType.REMOVE_BG, remove_bg),
            (TaskType.INPAINT, inpaint),
            (TaskType.UPSCALE_IMAGE, upscale_image),
        )
        for kind, handler in routes:
            if requested == kind:
                return handler(task)

        raise Exception("Invalid task type")
    except Exception as e:
        print(f"Error: {e}")
        slack.error_alert(task, e)
        return None