File size: 6,077 Bytes
c9d56e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c715c44
 
 
c9d56e0
 
 
 
 
 
 
 
 
 
 
 
c715c44
c9d56e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
import huggingface_hub
# Monkey-patch: if cached_download is missing, alias it to hf_hub_download.
# This must stay ABOVE `import diffusers`: presumably the pinned diffusers (or one
# of its dependencies) still does `from huggingface_hub import cached_download` at
# import time, and newer huggingface_hub releases no longer provide that name.
# NOTE(review): cached_download and hf_hub_download do not share a signature
# (URL-based vs repo_id/filename-based), so this alias only satisfies import-time
# lookups; a real call through `cached_download` would likely fail — confirm.
if not hasattr(huggingface_hub, "cached_download"):
    huggingface_hub.cached_download = huggingface_hub.hf_hub_download

# Log library versions at startup to make dependency mismatches easy to diagnose.
print("huggingface_hub version:", huggingface_hub.__version__)

import diffusers
print("diffusers version:", diffusers.__version__)
import numpy
print("numpy version:", numpy.__version__)
import spaces
import gradio as gr
import torch
from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
from PIL import Image
import os
import time

from utils.dl_utils import dl_cn_model, dl_cn_config, dl_tagger_model, dl_lora_model
from utils.image_utils import resize_image_aspect_ratio, base_generation

from utils.prompt_utils import execute_prompt, remove_color, remove_duplicates
from utils.tagger import modelLoad, analysis

# Asset directories (ControlNet, tagger, LoRA) live under the process's current
# working directory; create them up front so the downloaders have a target.
path = os.getcwd()
cn_dir, tagger_dir, lora_dir = (
    os.path.join(path, subdir) for subdir in ("controlnet", "tagger", "lora")
)
for _asset_dir in (cn_dir, tagger_dir, lora_dir):
    os.makedirs(_asset_dir, exist_ok=True)

# Fetch model weights/configs into those directories via the utils.dl_utils helpers.
dl_cn_model(cn_dir)
dl_cn_config(cn_dir)
dl_tagger_model(tagger_dir)
dl_lora_model(lora_dir)

def load_model(
    lora_dir,
    cn_dir,
    *,
    dtype=torch.float16,
    base_model="cagliostrolab/animagine-xl-3.1",
    vae_model="madebyollin/sdxl-vae-fp16-fix",
    lora_weight_name="lineart.safetensors",
):
    """Build the SDXL ControlNet img2img pipeline with the line-art LoRA applied.

    Args:
        lora_dir: Directory containing the LoRA weights file.
        cn_dir: Directory containing the ControlNet weights and config
            (loaded with ``use_safetensors=True``).
        dtype: Torch dtype used for the VAE, the ControlNet and the pipeline.
        base_model: Hub id of the SDXL base checkpoint.
        vae_model: Hub id of the VAE swapped into the pipeline.
        lora_weight_name: File name of the LoRA weights inside ``lora_dir``.

    Returns:
        A ``StableDiffusionXLControlNetImg2ImgPipeline`` with model CPU offload
        enabled and the LoRA weights loaded.
    """
    # Use the same dtype for every component so they agree at inference time.
    # NOTE(review): the fp16-fix VAE is presumably chosen because it is safe in float16.
    vae = AutoencoderKL.from_pretrained(vae_model, torch_dtype=dtype)
    controlnet = ControlNetModel.from_pretrained(cn_dir, torch_dtype=dtype, use_safetensors=True)

    pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
        base_model, controlnet=controlnet, vae=vae, torch_dtype=dtype
    )
    # Submodules are moved to the GPU only while they run, which keeps VRAM usage low.
    pipe.enable_model_cpu_offload()
    pipe.load_lora_weights(lora_dir, weight_name=lora_weight_name)
    return pipe

# Fixed style prefix prepended to every user prompt to steer the model toward
# clean, thick, monochrome line art.
_LINEART_PROMPT_PREFIX = "masterpiece, best quality, monochrome, sharp uniform black lines, vector style, very thick lineart, clean lineart, no shading, solid very thick black lines, no gradients, white background, "


def _build_lineart_prompt(user_prompt):
    """Return the final positive prompt: style prefix + user text, post-processed.

    The combined prompt goes through ``execute_prompt`` (with the tags below),
    then ``remove_duplicates``, then ``remove_color``.
    """
    prompt = _LINEART_PROMPT_PREFIX + user_prompt
    # NOTE(review): presumably execute_prompt removes these tags from the prompt —
    # confirm in utils.prompt_utils. A fresh list is built per call in case it mutates it.
    execute_tags = ["sketch", "transparent background"]
    prompt = execute_prompt(execute_tags, prompt)
    prompt = remove_duplicates(prompt)
    return remove_color(prompt)


@spaces.GPU(duration=120)
def predict(input_image_path, prompt, negative_prompt, controlnet_scale):
    """Turn the uploaded sketch into clean line art.

    Args:
        input_image_path: Filesystem path of the uploaded image (the UI's
            ``gr.Image`` uses ``type='filepath'``); ``None`` when nothing was uploaded.
        prompt: User prompt; a fixed line-art style prefix is prepended.
        negative_prompt: Negative prompt passed straight to the pipeline.
        controlnet_scale: ControlNet conditioning scale (the "Lineart fidelity" slider).

    Returns:
        The generated ``PIL.Image``, resized back to the input image's size.

    Raises:
        gr.Error: If no input image was provided.
    """
    # Fail fast with a user-visible message instead of crashing inside Image.open,
    # and before paying for the pipeline load.
    if not input_image_path:
        raise gr.Error("Please upload an input image first.")

    # NOTE(review): the full pipeline is rebuilt on every request, which costs a
    # large share of the 120 s GPU budget; consider caching it if compatible with
    # the Spaces GPU runtime and CPU offload.
    pipe = load_model(lora_dir, cn_dir)

    # Load the pixels eagerly so the underlying file handle is closed right away.
    with Image.open(input_image_path) as opened:
        input_image = opened.copy()

    # White canvas at the input's size used as the img2img init image; with
    # strength=1.0 it is fully noised, so the sketch steers the result through
    # the ControlNet conditioning image.
    base_image = base_generation(input_image.size, (255, 255, 255, 255)).convert("RGB")
    resize_image = resize_image_aspect_ratio(input_image)
    resize_base_image = resize_image_aspect_ratio(base_image)

    # Dedicated seeded generator: reproducible output without reseeding torch's
    # process-global RNG on every request.
    generator = torch.Generator().manual_seed(0)

    prompt = _build_lineart_prompt(prompt)
    print(prompt)

    started = time.perf_counter()
    output_image = pipe(
        image=resize_base_image,
        control_image=resize_image,
        strength=1.0,
        prompt=prompt,
        negative_prompt=negative_prompt,
        controlnet_conditioning_scale=float(controlnet_scale),
        generator=generator,
        num_inference_steps=40,
        eta=1.0,
    ).images[0]
    print(f"Time taken: {time.perf_counter() - started}")

    return output_image.resize(input_image.size, Image.LANCZOS)

@spaces.GPU(duration=120)
def prompt_analysis(input_image_path):
    """Tag an image and return the colour-free tags as prompt text.

    A fresh tagger model is loaded from ``tagger_dir`` on every call. The tags
    produced by ``analysis`` are passed through ``remove_color``. A list/tuple
    result is joined with ", "; any other result is returned unchanged.
    """
    model = modelLoad(tagger_dir)
    cleaned = remove_color(analysis(input_image_path, tagger_dir, model))
    return ", ".join(cleaned) if isinstance(cleaned, (list, tuple)) else cleaned


class Img2Img:
    """Gradio UI for the sketch-to-lineart demo.

    Attributes:
        demo: The ``gr.Blocks`` app built by ``layout``.
        tagger_model: Tagger model, loaded lazily on the first prompt analysis.
        input_image_path, prompt, negative_prompt, controlnet_scale, output_image:
            Gradio components created in ``layout``.
        canny_image: Unused placeholder, kept for backward compatibility.
    """

    def __init__(self):
        # Initialise state BEFORE building the UI: layout() stores the Gradio
        # components on self (e.g. self.input_image_path), so assigning
        # placeholders afterwards would overwrite those component references.
        self.tagger_model = None
        self.canny_image = None
        self.demo = self.layout()

    def process_prompt_analysis(self, input_image_path):
        """Tag the uploaded image and return the cleaned tags for the prompt box.

        The tagger model is loaded on first use and cached on the instance.

        Args:
            input_image_path: Filesystem path of the uploaded image, or ``None``.

        Returns:
            The colour-free tags as a string; list/tuple results are joined with ", ".

        Raises:
            gr.Error: If no input image was provided.
        """
        if not input_image_path:
            raise gr.Error("Please upload an input image first.")
        if self.tagger_model is None:
            self.tagger_model = modelLoad(tagger_dir)
        tags = analysis(input_image_path, tagger_dir, self.tagger_model)
        tags_clean = remove_color(tags)
        # The output is a Textbox, so normalise sequence results to a plain string.
        if isinstance(tags_clean, (list, tuple)):
            return ", ".join(tags_clean)
        return tags_clean

    def layout(self):
        """Build the Blocks UI, wire up its events, and return it."""
        # NOTE(review): no component sets elem_id="intro", so this rule is
        # currently unused.
        css = """
        #intro{
            max-width: 32rem;
            text-align: center;
            margin: 0 auto;
        }
        """
        with gr.Blocks(css=css) as demo:
            with gr.Row():
                with gr.Column():
                    self.input_image_path = gr.Image(label="Input image", type='filepath')
                    self.prompt = gr.Textbox(label="Prompt", lines=3)
                    self.negative_prompt = gr.Textbox(
                        label="Negative prompt",
                        lines=3,
                        value="sketch, lowres, error, extra digit, fewer digits, cropped, worst quality,low quality, normal quality, jpeg artifacts, blurry"
                    )
                    # Button that fills the prompt box from the tagger's analysis.
                    prompt_analysis_button = gr.Button("Prompt analysis")
                    self.controlnet_scale = gr.Slider(
                        minimum=0.5, maximum=1.25, value=1.0, step=0.01, label="Lineart fidelity"
                    )
                    generate_button = gr.Button(value="Generate", variant="primary")
                with gr.Column():
                    self.output_image = gr.Image(type="pil", label="Output image")

            prompt_analysis_button.click(
                self.process_prompt_analysis,
                inputs=[self.input_image_path],
                outputs=self.prompt
            )

            generate_button.click(
                fn=predict,
                inputs=[self.input_image_path, self.prompt, self.negative_prompt, self.controlnet_scale],
                outputs=self.output_image
            )
        return demo

# Build the UI, enable request queuing, and start the server.
# Blocks.queue() returns the Blocks itself, so launch() can be chained onto it.
img2img = Img2Img()
img2img.demo.queue().launch(share=True, show_error=True)