File size: 5,256 Bytes
d7fbbae
 
 
ed76569
d7fbbae
173dacf
 
823b6c5
054360a
d7fbbae
a6e58fb
a11bfea
e5c7f2d
7e11826
c527e57
4f0170e
9acc856
76a5ad1
d7fbbae
 
173dacf
71c90fd
1c4ac64
d7fbbae
3d5a83e
d7fbbae
3ee0121
 
 
 
 
 
 
 
d64a422
d7fbbae
62f5035
3ee0121
d7fbbae
 
4cf153b
173dacf
 
 
 
 
3ee0121
d7fbbae
 
 
 
 
 
dd66595
3ee0121
d7fbbae
3ee0121
4cf153b
173dacf
 
d7fbbae
 
96e8dac
a11bfea
bf7ac25
 
 
d7fbbae
 
 
 
 
 
 
 
 
 
 
a11bfea
d7fbbae
 
 
 
 
739b156
d7fbbae
9d1e33c
d7fbbae
 
 
 
 
 
 
 
7e1ba99
035b212
d7fbbae
 
 
 
d3e1560
d7fbbae
 
 
 
 
 
 
 
 
 
 
 
 
 
a38c82f
d7fbbae
1c4ac64
7a6dce9
d7fbbae
 
 
 
a38c82f
d7fbbae
1c4ac64
7d6fd96
d7fbbae
 
 
 
 
 
 
fd57b60
96e8dac
d7fbbae
 
 
 
 
 
 
96e8dac
d7fbbae
 
 
3ee0121
d7fbbae
 
 
 
 
 
 
 
 
 
 
 
d64a422
d7fbbae
62f5035
d7fbbae
 
 
c4543d4
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
import gradio as gr
import numpy as np
import random
import spaces 
import torch
from diffusers import DiffusionPipeline
from typing import Tuple
from PIL import Image 

# Choose the execution target in one place: CUDA with half precision when a
# GPU is visible to torch, otherwise CPU with full precision.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

model_repo_id = "AiArtLab/sdxs-1b"

# Load the pipeline at import time so every request reuses the same weights.
# NOTE(review): trust_remote_code=True presumably lets the pipeline code
# shipped in the model repository run locally - confirm the repo is trusted.
pipe = DiffusionPipeline.from_pretrained(
    model_repo_id, torch_dtype=dtype, trust_remote_code=True
)
pipe = pipe.to(device)

# Bounds and increment shared by the width and height sliders.
STEP = 64
MIN_IMAGE_SIZE = 768
MAX_IMAGE_SIZE = 1408

# Upper bound for user-supplied and randomly drawn seeds (int32 maximum).
MAX_SEED = np.iinfo(np.int32).max

@spaces.GPU(duration=60)
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    refine_prompt: bool,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Image.Image, int, str]:
    """Generate one image with the module-level ``pipe``.

    Args:
        prompt: Text prompt for the image.
        negative_prompt: Negative prompt forwarded to the pipeline.
        seed: Seed to use when ``randomize_seed`` is false.
        randomize_seed: When true, ignore ``seed`` and draw a new one
            uniformly from ``[0, MAX_SEED]``.
        width: Requested image width, forwarded to the pipeline.
        height: Requested image height, forwarded to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.
        num_inference_steps: Number of inference steps forwarded to the
            pipeline.
        refine_prompt: When true, replace ``prompt`` with the first result
            of ``pipe.refine_prompts(prompt)`` before generating.
        progress: Gradio progress tracker created with ``track_tqdm=True``;
            it is not referenced in the body.

    Returns:
        A tuple ``(image, seed, prompt)``: the first generated image, the
        seed that was actually used, and the prompt that was actually used
        (the refined one when refinement was applied), so the caller can
        display them.
    """
    # Slider-backed inputs can arrive as floats; normalise them to the ints
    # declared in the signature before they reach the pipeline.
    seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    # Use the pipeline's dedicated prompt-refinement method. It returns a
    # list, so take the first element; if the list comes back empty, keep
    # the user's prompt instead of failing with an IndexError.
    if refine_prompt:
        refined_list = pipe.refine_prompts(prompt)
        if refined_list:
            prompt = refined_list[0]

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        seed=seed,
    )

    image = output.images[0]

    # Return the (possibly refined) prompt so it is shown in the interface.
    return image, seed, prompt

# Sample prompts; each long prompt is wrapped with implicit string
# concatenation, so the resulting values are single strings.
examples = [
    (
        "A young woman with striking blue eyes and pointed ears, adorned with a "
        "floral kimono and a tattoo. Her hair is styled in a braid, and she wears "
        "a pair of ears"
    ),
    (
        "A frozen river, surrounded by snow-covered trees, reflects the clear "
        "blue sky, with a warm glow from the setting sun."
    ),
    (
        "There is a young male character standing against a vibrant, colorful "
        "graffiti wall. he is wearing a straw hat, a black jacket adorned with "
        "gold accents, and black shorts."
    ),
    (
        "A man with dark hair and a beard is meticulously carving an intricate "
        "design on a piece of pottery. He is wearing a traditional scarf and a "
        "white shirt, and he is focused on his work."
    ),
    "girl, smiling, red eyes, blue hair, white shirt",
]

# Stylesheet text: centres the element with id "col-container" and caps its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""

# Keyword arguments common to the width and height sliders.
_size_slider_kwargs = dict(
    minimum=MIN_IMAGE_SIZE,
    maximum=MAX_IMAGE_SIZE,
    step=STEP,
)

with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Simple Diffusion (sdxs)")

        # Prompt box and the button that triggers generation.
        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=5,
                placeholder="Enter your prompt",
                value="cat",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        # Generated image.
        result = gr.Image(label="Result", show_label=False)

        # Optional controls, collapsed by default.
        with gr.Accordion("Advanced Settings", open=False):
            refine_prompt = gr.Checkbox(label="Refine Prompt", value=True)

            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value="bad quality grainy image with low details, incomplete text, despite numerous technical flaws and distorted figures",
            )

            seed = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=0)
            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Output resolution.
            with gr.Row():
                width = gr.Slider(label="Width", value=1024, **_size_slider_kwargs)
                height = gr.Slider(
                    label="Height", value=MAX_IMAGE_SIZE, **_size_slider_kwargs
                )

            # Sampling parameters.
            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=4.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Component order here must match the positional parameters of `infer`.
    infer_inputs = [
        prompt,
        negative_prompt,
        seed,
        randomize_seed,
        width,
        height,
        guidance_scale,
        num_inference_steps,
        refine_prompt,
    ]

    # Run on button click or prompt submit; the seed and prompt components are
    # outputs too, so they receive the values that `infer` returns.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=infer_inputs,
        outputs=[result, seed, prompt],
    )

# Launch the demo only when this file is run directly, not when it is imported.
if __name__ == "__main__":
    demo.launch()