File size: 7,484 Bytes
d7fbbae
 
 
ed76569
d7fbbae
76a5ad1
834fe59
3ee0121
054360a
0074787
d7fbbae
a6e58fb
4f0170e
 
e5c7f2d
3b1dbe2
e1eaf72
 
3845f1e
3b1dbe2
e5c7f2d
 
 
7e11826
c527e57
4f0170e
9acc856
76a5ad1
d7fbbae
 
2ab3a9e
e5c7f2d
d7fbbae
1d3f34b
d7fbbae
3ee0121
 
 
 
 
 
 
 
 
e5c7f2d
 
d7fbbae
e5c7f2d
3ee0121
d7fbbae
 
3ee0121
 
d7fbbae
 
 
 
 
 
dd66595
e5c7f2d
 
 
3ee0121
d7fbbae
3ee0121
 
 
 
d7fbbae
 
96e8dac
 
62dba47
bf7ac25
 
 
d7fbbae
 
 
 
 
 
 
 
 
 
 
971e8d0
d7fbbae
 
 
 
 
739b156
d7fbbae
 
 
 
 
 
 
3ee0121
 
 
 
 
e5c7f2d
3ee0121
 
d7fbbae
 
 
 
 
 
3018664
d7fbbae
 
 
 
 
 
 
 
 
 
 
 
3ee0121
45c7008
 
3ee0121
 
e5c7f2d
 
 
 
 
 
 
 
 
3ee0121
d7fbbae
 
 
a38c82f
d7fbbae
2ab3a9e
 
d7fbbae
 
 
 
a38c82f
d7fbbae
2ab3a9e
 
d7fbbae
 
 
 
 
 
 
fd57b60
96e8dac
d7fbbae
 
 
 
 
 
 
96e8dac
d7fbbae
 
 
3ee0121
e5c7f2d
d7fbbae
 
 
 
 
 
 
 
 
 
 
 
e5c7f2d
31fa6f3
d7fbbae
e5c7f2d
d7fbbae
 
 
3ee0121
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
import gradio as gr
import numpy as np
import random
import spaces 
import torch
from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler
from typing import Optional, Union, List, Tuple
from PIL import Image


# Use the GPU in half precision when one is available; otherwise run on the
# CPU in full precision.
if torch.cuda.is_available():
    device, dtype = "cuda", torch.float16
else:
    device, dtype = "cpu", torch.float32

# Repository id passed to DiffusionPipeline.from_pretrained below.
model_repo_id = "AiArtLab/sdxs"

# Default instruction for the LLM prompt-refinement step; also pre-fills the
# editable "Refine Prompt Template" textbox in the UI. The user's text is meant
# to be substituted into the "{prompt}" placeholder on the last line — assumed
# to be str.format-style substitution inside the pipeline (TODO confirm), in
# which case any other literal braces added here would need doubling.
DEFAULT_REFINE_TEMPLATE = (
    "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
    "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
    "If the input is short, elaborate the subject using diverse attributes (style, pose, expression, lighting/color palette/mood). **Descriptions must avoid cliches and include diverse options.** "
    "If the input is long, concisely pack the core subject and essential details into the final three-sentence format without losing crucial information. "
    "Output **only** the final revised prompt in **English**, with absolutely no commentary, thinking text, or surrounding quotes.\n"
    "User input prompt: {prompt}"
)

# Load the sdxs pipeline once at import time and move it to the chosen device.
# NOTE(security): trust_remote_code=True executes Python code shipped inside the
# model repository, so model_repo_id must only ever point at a trusted repo.
pipe = DiffusionPipeline.from_pretrained(model_repo_id, torch_dtype=dtype, trust_remote_code=True)
pipe = pipe.to(device)

# Upper bound for seeds (UI slider and random draws): the largest signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Pixel bounds shared by the width and height sliders.
MIN_IMAGE_SIZE, MAX_IMAGE_SIZE = 768, 1536

@spaces.GPU(duration=30)
def infer(
    prompt: str,
    negative_prompt: str,
    seed: int,
    randomize_seed: bool,
    width: int,
    height: int,
    guidance_scale: float,
    num_inference_steps: int,
    refine_prompt: bool,
    # Template handed to the pipeline's prompt-refinement step.
    refine_template: str,
    progress=gr.Progress(track_tqdm=True),
) -> Tuple[Image.Image, int, Optional[str]]:
    """Generate one image from ``prompt`` with the sdxs pipeline.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        seed: Seed forwarded to the pipeline; replaced when ``randomize_seed``
            is true.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        width: Output width in pixels.
        height: Output height in pixels.
        guidance_scale: Forwarded unchanged to the pipeline.
        num_inference_steps: Forwarded to the pipeline as an integer.
        refine_prompt: Whether the pipeline should first expand the prompt
            with its LLM refinement step.
        refine_template: Refinement instruction; must contain the ``{prompt}``
            placeholder. A blank value falls back to ``DEFAULT_REFINE_TEMPLATE``.
        progress: Gradio progress tracker; ``track_tqdm=True`` lets Gradio
            surface tqdm progress reported during generation.

    Returns:
        ``(image, seed, refined_prompt)``: the first generated image, the seed
        actually used (so the UI can display it), and the refined prompt when
        the pipeline returned one as a string, otherwise ``None``.

    Raises:
        gr.Error: If refinement is enabled and ``refine_template`` is non-blank
            but lacks the ``{prompt}`` placeholder.
    """
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # Gradio sliders may hand numeric values over as floats; coerce them so the
    # pipeline gets integer sizes/step counts and the returned seed matches the
    # declared ``int`` return type.
    seed = int(seed)
    width = int(width)
    height = int(height)
    num_inference_steps = int(num_inference_steps)

    if refine_prompt:
        if not refine_template or not refine_template.strip():
            # The template box was cleared: use the built-in default rather
            # than sending an empty instruction to the LLM.
            refine_template = DEFAULT_REFINE_TEMPLATE
        elif "{prompt}" not in refine_template:
            # Without the placeholder the user's prompt presumably never reaches
            # the refinement LLM; report it in the UI instead of generating blindly.
            raise gr.Error("Refine template must contain the {prompt} placeholder.")

    output = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        seed=seed,
        refine_prompt=refine_prompt,
        refine_template=refine_template,
    )

    image = output.images[0]
    # Only a string is shown in the "Refined Prompt" box. getattr also tolerates
    # an output object without the attribute (possible with the remote-code
    # pipeline — unverified).
    refined_prompt = getattr(output, "refined_prompt", None)
    if not isinstance(refined_prompt, str):
        refined_prompt = None

    return image, seed, refined_prompt

# Sample prompts shown under the prompt box via gr.Examples; clicking one fills
# in the prompt field.
# NOTE(review): the second example ends oddly ("she wears a pair of ears") and
# may be truncated; kept verbatim.
examples: List[str] = [
    "A frozen river, surrounded by snow-covered trees, reflects the clear blue sky, with a warm glow from the setting sun.",
    "A young woman with striking blue eyes and pointed ears, adorned with a floral kimono and a tattoo. Her hair is styled in a braid, and she wears a pair of ears",
    "A volcano explodes, creating a skull face shadow in embers with lightning illuminating the clouds.",
    "There is a young male character standing against a vibrant, colorful graffiti wall. he is wearing a straw hat, a black jacket adorned with gold accents, and black shorts.",
    "A man with dark hair and a beard is meticulously carving an intricate design on a piece of pottery. He is wearing a traditional scarf and a white shirt, and he is focused on his work.",
    "girl, smiling, red eyes, blue hair, white shirt",
]

# Stylesheet passed to gr.Blocks: centers the element with id "col-container"
# and caps its width at 640px.
css = (
    "\n"
    "#col-container {\n"
    "    margin: 0 auto;\n"
    "    max-width: 640px;\n"
    "}\n"
)

# Build the Gradio UI: a prompt row with a Run button, the generated image, a
# read-only box for the refined prompt, and an "Advanced Settings" accordion
# holding the remaining arguments of infer. Components are laid out in the
# order they are created inside these context managers.
with gr.Blocks(css=css) as demo:
    # elem_id ties this column to the "#col-container" rule in css.
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Simple Diffusion (sdxs)")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=5,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)

        # Read-only display of the refined prompt returned by infer. The
        # placeholder (Russian) reads: "The refined prompt will appear here if
        # the 'Refine prompt' option is selected".
        refined_prompt_output = gr.Text(
            label="Refined Prompt (Уточненный промпт)",
            max_lines=5,
            placeholder="Уточненный промпт появится здесь, если выбрана опция 'Уточнить промпт'",
            interactive=False,
            show_label=True
        )

        with gr.Accordion("Advanced Settings", open=False):
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                value ="bad quality, low resolution, oversaturated, sketch"
            )

            # Also listed in the event outputs below, so the slider shows the
            # seed infer actually used (e.g. a freshly randomized one).
            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # Toggles the pipeline's prompt refinement. The info text (Russian)
            # reads: "Uses an LLM to expand and detail the entered prompt before
            # generating the image."
            refine_checkbox = gr.Checkbox(
                label="Refine Prompt (Уточнить промпт)", 
                value=True,
                info="Использует LLM для расширения и детализации введенного промпта перед генерацией изображения."
            )

            # Editable refinement template. The info text (Russian) reads:
            # "Template for the LLM. Must contain the {prompt} placeholder."
            refine_template_input = gr.Text(
                label="Refine Prompt Template (Шаблон уточнения)",
                value=DEFAULT_REFINE_TEMPLATE, # Pre-filled with the default template
                lines=10,
                show_label=True,
                info="Шаблон для LLM. Должен содержать плейсхолдер {prompt}."
            )

            # Size sliders span [MIN_IMAGE_SIZE, MAX_IMAGE_SIZE] in 64-pixel steps.
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=MIN_IMAGE_SIZE,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1280,
                )

                height = gr.Slider(
                    label="Height",
                    minimum=MIN_IMAGE_SIZE,
                    maximum=MAX_IMAGE_SIZE,
                    step=64,
                    value=1536,
                )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=4.0,
                )

                num_inference_steps = gr.Slider(
                    label="Number of inference steps",
                    minimum=1,
                    maximum=50,
                    step=1,
                    value=40,
                )

        gr.Examples(examples=examples, inputs=[prompt])

    # Run infer on a Run-button click or on Enter in the prompt box. The inputs
    # are passed positionally, so their order must match infer's parameters
    # (prompt ... refine_template). The outputs map onto infer's
    # (image, seed, refined_prompt) return tuple.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            width,
            height,
            guidance_scale,
            num_inference_steps,
            refine_checkbox,
            refine_template_input,
        ],
        outputs=[result, seed, refined_prompt_output],
    )

if "__main__" == __name__:
    # Serve the UI only when run as a script, not when this module is imported.
    demo.launch()