recoilme committed
Commit 3d8b2a1 · verified · 1 Parent(s): e8aae27

Update pipeline_sdxs.py

Files changed (1):
  1. pipeline_sdxs.py +177 -356
pipeline_sdxs.py CHANGED
@@ -1,372 +1,193 @@
  import torch
- from diffusers import DiffusionPipeline
- from diffusers.utils import BaseOutput
- from dataclasses import dataclass
- from typing import List, Union, Optional, Tuple, Any
  from PIL import Image
- import numpy as np
- from tqdm import tqdm
-
- @dataclass
- class SdxsPipelineOutput(BaseOutput):
-     images: Union[List[Image.Image], np.ndarray]
-     refined_prompt: Optional[Union[str, List[str]]] = None
-
- class SdxsPipeline(DiffusionPipeline):
-     # NEW: Constant for the </think> token in Qwen3
-     END_THINK_TOKEN_ID = 151668
-
-     # Default refinement prompt template
-     DEFAULT_REFINE_TEMPLATE = (
-         "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
-         "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
-         "If the input is short, elaborate the subject using diverse attributes (style, pose, expression, lighting/color palette/mood). **Descriptions must avoid cliches and include diverse options.** "
-         "If the input is long, concisely pack the core subject and essential details into the final three-sentence format without losing crucial information. "
-         "Output **only** the final revised prompt in **English**, with absolutely no commentary, thinking text, or surrounding quotes.\n"
-         "User input prompt: {prompt}"
-     )
-
-     def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 192):
-         super().__init__()
-         self.register_modules(
-             vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
-             unet=unet, scheduler=scheduler
-         )
-         self.vae_scale_factor = 16
-         self.max_length = max_length
-
-     def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         device = device or self.device
-         dtype = dtype or next(self.unet.parameters()).dtype
-
-         # Convert to lists
-         if isinstance(prompt, str):
-             prompt = [prompt]
-         if isinstance(negative_prompt, str):
-             negative_prompt = [negative_prompt]
-
-         # If no prompts are given, use empty embeddings
-         if prompt is None and negative_prompt is None:
-             hidden_dim = 1024  # Embedding dimensionality
-             seq_len = self.max_length
-             batch_size = 1
-             # CHANGED: Return three elements: embeds, mask, pooled
-             empty_embeds = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
-             empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
-             empty_pooled = torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device)
-             return empty_embeds, empty_mask, empty_pooled
-
-         # Tokenization with fixed max_length and padding="max_length"
-         def encode_texts(texts, max_length=self.max_length):
-             with torch.no_grad():
-                 if isinstance(texts, str):
-                     texts = [texts]
-
-                 for i, prompt_item in enumerate(texts):
-                     messages = [
-                         {"role": "user", "content": prompt_item},
-                     ]
-                     prompt_item = self.tokenizer.apply_chat_template(
-                         messages,
-                         tokenize=False,
-                         add_generation_prompt=True,
-                         enable_thinking=True,
-                     )
-                     texts[i] = prompt_item
-
-                 toks = self.tokenizer(
-                     texts,
-                     return_tensors="pt",
-                     padding="max_length",
-                     truncation=True,
-                     max_length=max_length
-                 ).to(device)
-                 outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
-
-                 # Token embeddings (for cross-attention)
-                 hidden = outs.hidden_states[-2]  # Use the second-to-last hidden state
-                 # Attention mask (for cross-attention)
-                 attention_mask = toks["attention_mask"]
-
-                 # Pooled embedding (for class/time conditioning): take the embedding of the last non-padding token.
-                 sequence_lengths = attention_mask.sum(dim=1) - 1
-                 batch_size = hidden.shape[0]
-                 pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
-
-                 # --- NEW LOGIC: CONCATENATION FOR CROSS-ATTENTION ---
-                 # 1. Expand the pooled vector to a sequence: [B, 1, 1024]
-                 pooled_expanded = pooled.unsqueeze(1)
-
-                 # 2. Concatenate the token sequence and the pooled vector
-                 # !!! CHANGE HERE !!!: the pooled vector goes FIRST
-                 # Result: [B, 1 + L, 1024]; the pooled vector becomes a token at the START.
-                 new_encoder_hidden_states = torch.cat([pooled_expanded, hidden], dim=1)
-
-                 # 3. Update the attention mask for the new token
-                 # Attention mask: [B, 1 + L]; a 1 is prepended.
-                 # torch.ones((batch_size, 1), device=device) creates a [B, 1] mask of ones.
-                 new_attention_mask = torch.cat([torch.ones((batch_size, 1), device=device), attention_mask], dim=1)
-
-                 return new_encoder_hidden_states, new_attention_mask, pooled
-
-         # Encode positive and negative prompts
-         # FIX: Return (None, None, None) to avoid UnboundLocalError
-         pos_result = encode_texts(prompt) if prompt is not None else (None, None, None)
-         neg_result = encode_texts(negative_prompt) if negative_prompt is not None else (None, None, None)
-
-         pos_embeddings, pos_mask, pos_pooled = pos_result
-         neg_embeddings, neg_mask, neg_pooled = neg_result
-
-         # Align batch sizes
-         batch_size = max(
-             pos_embeddings.shape[0] if pos_embeddings is not None else 0,
-             neg_embeddings.shape[0] if neg_embeddings is not None else 0
-         )
-
-         # Repeat embeddings, masks, and pooled vectors up to batch_size
-         if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
-             pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
-             pos_mask = pos_mask.repeat(batch_size, 1)
-             pos_pooled = pos_pooled.repeat(batch_size, 1)
-
-         # FIX: Check that neg_embeddings exists before accessing its shape[0]
-         if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
-             neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
-             neg_mask = neg_mask.repeat(batch_size, 1)
-             neg_pooled = neg_pooled.repeat(batch_size, 1)
-
-         # Concatenate for guidance (embeddings and masks)
-         # Make sure all three components exist before concatenating
-         if pos_embeddings is not None and neg_embeddings is not None:
-             text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
-             attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
-             pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
-         elif pos_embeddings is not None:
-             text_embeddings = pos_embeddings
-             attention_mask = pos_mask
-             pooled_embeddings = pos_pooled
-         else:  # Only neg_embeddings
-             text_embeddings = neg_embeddings
-             attention_mask = neg_mask
-             pooled_embeddings = neg_pooled
-
-         # Return a tuple
-         return (
-             text_embeddings.to(device=device, dtype=dtype),
-             attention_mask.to(device=device, dtype=torch.int64),
-             pooled_embeddings.to(device=device, dtype=dtype)
-         )
-
-     @torch.no_grad()
-     def generate_latents(
-         self,
-         text_embeddings,
-         attention_mask,
-         pooled_embeddings,
-         height: int = 1536,
-         width: int = 1280,
-         num_inference_steps: int = 40,
-         guidance_scale: float = 4.0,
-         latent_channels: int = 16,
-         batch_size: int = 1,
-         generator=None,
-     ):
-         device = self.device
-         dtype = next(self.unet.parameters()).dtype
-
-         self.scheduler.set_timesteps(num_inference_steps, device=device)
-
-         # Split embeddings and masks into conditional and unconditional parts
-         if guidance_scale > 1:
-             neg_embeds, pos_embeds = text_embeddings.chunk(2)
-             neg_mask, pos_mask = attention_mask.chunk(2)
-             neg_pooled, pos_pooled = pooled_embeddings.chunk(2)
-
-             # Repeat if batch_size is larger
-             if batch_size > pos_embeds.shape[0]:
-                 pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
-                 neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
-                 pos_mask = pos_mask.repeat(batch_size, 1)
-                 neg_mask = neg_mask.repeat(batch_size, 1)
-                 pos_pooled = pos_pooled.repeat(batch_size, 1)
-                 neg_pooled = neg_pooled.repeat(batch_size, 1)
-
-             text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
-             unet_attention_mask = torch.cat([neg_mask, pos_mask], dim=0)
-             unet_pooled_embeddings = torch.cat([neg_pooled, pos_pooled], dim=0)
-         else:
-             text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
-             unet_attention_mask = attention_mask.repeat(batch_size, 1)
-             unet_pooled_embeddings = pooled_embeddings.repeat(batch_size, 1)
-
-         # Initialize latents
-         latent_shape = (
-             batch_size,
-             latent_channels,
-             height // self.vae_scale_factor,
-             width // self.vae_scale_factor
-         )
-         latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
-
-         # Diffusion loop
-         for t in tqdm(self.scheduler.timesteps, desc="Generating"):
-             latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
-
-             noise_pred = self.unet(
-                 latent_input,
-                 t,
-                 encoder_hidden_states=text_embeddings,
-                 encoder_attention_mask=unet_attention_mask,
-                 # added_cond_kwargs={'text_embeds': unet_pooled_embeddings}
-             ).sample
-
-             if guidance_scale > 1:
-                 noise_uncond, noise_text = noise_pred.chunk(2)
-                 noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
-
-             latents = self.scheduler.step(noise_pred, t, latents).prev_sample
-
-         return latents
-
-     def decode_latents(self, latents, output_type="pil"):
-         """Decode latents into images."""
-         latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
-         with torch.no_grad():
-             images = self.vae.decode(latents).sample
-         images = (images / 2 + 0.5).clamp(0, 1)
-
-         if output_type == "pil":
-             images = images.cpu().permute(0, 2, 3, 1).float().numpy()
-             images = (images * 255).round().astype("uint8")
-             return [Image.fromarray(image) for image in images]
-         return images.cpu().permute(0, 2, 3, 1).float().numpy()
-
-     # CHANGE: __call__ is now correctly inside the SdxsPipeline class
-     @torch.no_grad()
-     def __call__(
-         self,
-         prompt: Optional[Union[str, List[str]]] = None,
-         height: int = 1280,
-         width: int = 1024,
-         num_inference_steps: int = 40,
-         guidance_scale: float = 4.0,
-         latent_channels: int = 16,
-         output_type: str = "pil",
-         return_dict: bool = True,
-         batch_size: int = 1,
-         seed: Optional[int] = None,
-         negative_prompt: Optional[Union[str, List[str]]] = None,
-         text_embeddings: Optional[torch.FloatTensor] = None,
-         refine_prompt: bool = True,
-         refine_template: Optional[str] = None,
-     ):
-         device = self.device
-         generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
-
-         refined_prompt_output = None
-
-         # 1. PROMPT REFINEMENT LOGIC
-         if refine_prompt and prompt is not None and text_embeddings is None:
-
-             is_str_input = isinstance(prompt, str)
-             original_prompts = [prompt] if is_str_input else prompt
-
-             template = refine_template if refine_template is not None else self.DEFAULT_REFINE_TEMPLATE
-
-             refined_list = []
-
-             for p in original_prompts:
-                 # 1.1. Format the prompt according to the Qwen chat rules
-                 messages = [
-                     {"role": "user", "content": template.format(prompt=p)}  # Template with the user prompt inside
-                 ]
-
-                 # CHANGE: Use the chat template to prepare the text
-                 text = self.tokenizer.apply_chat_template(
-                     messages,
-                     tokenize=False,
-                     add_generation_prompt=True,
-                     enable_thinking=True
-                 )
-
-                 model_inputs = self.tokenizer([text], return_tensors="pt", truncation=True).to(device)
-
-                 try:
-                     # 1.2. Text generation (requires self.text_encoder to have a .generate() method)
-                     generated_ids = self.text_encoder.generate(
-                         **model_inputs,
-                         max_new_tokens=32768,  # Cap to avoid overly long generation
-                         do_sample=True,
-                         pad_token_id=self.tokenizer.eos_token_id
-                     )
-
-                     # 1.3. Strip the input prompt
-                     # CHANGE: Keep only the tokens generated by the model
-                     output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
-
-                     # 1.4. Parse out the "thinking" content
-                     try:
-                         # CHANGE: Search for END_THINK_TOKEN_ID (151668) from the end
-                         # output_ids[::-1].index(151668) finds the index in the reversed list
-                         index = len(output_ids) - output_ids[::-1].index(self.END_THINK_TOKEN_ID)
-                     except ValueError:
-                         # If the </think> token is not found, start from the beginning
-                         index = 0
-
-                     # CHANGE: Decode only the content after </think>
-                     refined_text = self.tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n")
-
-                     # 1.5. Appending the original prompt at the end (refined_text + user_prompt) is currently disabled
-                     final_refined_text = f"{refined_text.strip()}"  # \n{p}
-
-                 except AttributeError:
-                     print("WARNING: self.text_encoder has no .generate() method. Prompt refinement skipped.")
-                     final_refined_text = p  # Fall back to the original prompt
-                 except Exception as e:
-                     print(f"Prompt refinement failed: {e}. Using the original prompt.")
-                     final_refined_text = p
-
-                 refined_list.append(final_refined_text)
-
-             # Update the prompt and keep the refined version for the output
-             prompt = refined_list[0] if is_str_input else refined_list
-             refined_prompt_output = prompt  # Already a list or a string here
-
-         # 2. PROMPT ENCODING (existing logic)
-         if text_embeddings is None:
-             if prompt is None and negative_prompt is None:
-                 raise ValueError("You must provide prompt, negative_prompt, or text_embeddings")
-
-             # Call the standard encode_prompt function
-             text_embeddings, attention_mask, pooled_embeddings = self.encode_prompt(
-                 prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype
-             )
-         else:
-             raise NotImplementedError("Passing text_embeddings directly does not yet support passing the attention mask and pooled embedding. Use prompt/negative_prompt.")
-
-         # 3. LATENT GENERATION (existing logic)
-         latents = self.generate_latents(
-             text_embeddings=text_embeddings,
-             attention_mask=attention_mask,
-             pooled_embeddings=pooled_embeddings,
-             height=height,
-             width=width,
-             num_inference_steps=num_inference_steps,
-             guidance_scale=guidance_scale,
-             latent_channels=latent_channels,
-             batch_size=batch_size,
-             generator=generator
-         )
-
-         # 4. DECODING (existing logic)
-         images = self.decode_latents(latents, output_type=output_type)
-
-         # 5. RETURN RESULT
-         if not return_dict:
-             return images
-         # CHANGE: Return the refined prompt as well
-         return SdxsPipelineOutput(images=images, refined_prompt=refined_prompt_output)
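Note: the removed encode_prompt builds the cross-attention input by prepending the pooled embedding (the last non-padding token) to the token sequence, so the UNet sees it as an extra leading token. A minimal standalone sketch of that shape manipulation, assuming hypothetical sizes; illustrative only, not part of the commit:

import torch

B, L, D = 2, 192, 1024                      # hypothetical batch, sequence length, hidden size
hidden = torch.randn(B, L, D)               # token embeddings from the text encoder
mask = torch.ones(B, L, dtype=torch.int64)  # attention mask (1 = real token)

last = mask.sum(dim=1) - 1                  # index of the last non-padding token per sample
pooled = hidden[torch.arange(B), last]      # [B, D] pooled embedding

# The pooled vector goes FIRST, as in the removed code above
states = torch.cat([pooled.unsqueeze(1), hidden], dim=1)              # [B, 1 + L, D]
mask = torch.cat([torch.ones(B, 1, dtype=torch.int64), mask], dim=1)  # [B, 1 + L]
assert states.shape == (B, L + 1, D) and mask.shape == (B, L + 1)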
 
+ import gradio as gr
+ import numpy as np
+ import random
+ import spaces
  import torch
+ from diffusers import DiffusionPipeline, AutoencoderKL, UNet2DConditionModel, FlowMatchEulerDiscreteScheduler, AsymmetricAutoencoderKL
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ from typing import Optional, Union, List, Tuple
  from PIL import Image
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ dtype = torch.float16 if torch.cuda.is_available() else torch.float32
+ model_repo_id = "AiArtLab/sdxs-08b"
+
+ pipe = DiffusionPipeline.from_pretrained(
+     model_repo_id,
+     torch_dtype=dtype,
+     trust_remote_code=True
+ ).to(device)
+
+ # NEW: Initialize Qwen3 for prompt refinement
+ llm_model_id = "Qwen/Qwen3-0.6B"
+ tokenizer = AutoTokenizer.from_pretrained(llm_model_id)
+ llm_model = AutoModelForCausalLM.from_pretrained(llm_model_id, torch_dtype="auto", device_map="auto")
+
+ MAX_SEED = np.iinfo(np.int32).max
+ MIN_IMAGE_SIZE = 640
+ MAX_IMAGE_SIZE = 1280
+ STEP = 64
+
+ # NEW: LLM settings
+ END_THINK_TOKEN_ID = 151668
+ DEFAULT_REFINE_TEMPLATE = (
+     "You are a skilled text-to-image prompt engineer whose sole function is to transform the user's input into an aesthetically optimized, detailed, and visually descriptive three-sentence output. "
+     "**The primary subject (e.g., 'girl', 'dog', 'house') MUST be the main focus of the revised prompt and MUST be described in rich detail within the first sentence or two.** "
+     "Output **only** the final revised prompt in **English**, with absolutely no commentary, thinking text, or surrounding quotes.\n"
+     "User input prompt: {prompt}"
+ )
+
+ @spaces.GPU(duration=30)
+ def infer(
+     prompt: str,
+     negative_prompt: str,
+     seed: int,
+     randomize_seed: bool,
+     width: int,
+     height: int,
+     guidance_scale: float,
+     num_inference_steps: int,
+     refine_prompt: bool,
+     progress=gr.Progress(track_tqdm=True),
+ ) -> Tuple[Image.Image, int, str]:  # The prompt is returned last
+
+     if randomize_seed:
+         seed = random.randint(0, MAX_SEED)
+
+     # NEW: Prompt refinement logic
+     if refine_prompt and prompt:
+         messages = [{"role": "user", "content": DEFAULT_REFINE_TEMPLATE.format(prompt=prompt)}]
+         text = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True, enable_thinking=True)
+         model_inputs = tokenizer([text], return_tensors="pt").to(llm_model.device)
+
+         generated_ids = llm_model.generate(**model_inputs, max_new_tokens=2048, do_sample=True, pad_token_id=tokenizer.eos_token_id)
+         output_ids = generated_ids[0][len(model_inputs.input_ids[0]):].tolist()
+
+         try:
+             index = len(output_ids) - output_ids[::-1].index(END_THINK_TOKEN_ID)
+         except ValueError:
+             index = 0
+
+         prompt = tokenizer.decode(output_ids[index:], skip_special_tokens=True).strip("\n").strip()
+
+     output = pipe(
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         guidance_scale=guidance_scale,
+         num_inference_steps=num_inference_steps,
+         width=width,
+         height=height,
+         seed=seed,
+     )
+
+     image = output.images[0]
+     return image, seed, prompt  # Return the (possibly refined) prompt
+
+ examples = [
+     "A frozen river, surrounded by snow-covered trees, reflects the clear blue sky, with a warm glow from the setting sun.",
+     "A young woman with striking blue eyes and pointed ears, adorned with a floral kimono and a tattoo. Her hair is styled in a braid, and she wears a pair of ears",
+     "A volcano explodes, creating a skull face shadow in embers with lightning illuminating the clouds.",
+     "There is a young male character standing against a vibrant, colorful graffiti wall. He is wearing a straw hat, a black jacket adorned with gold accents, and black shorts.",
+     "A man with dark hair and a beard is meticulously carving an intricate design on a piece of pottery. He is wearing a traditional scarf and a white shirt, and he is focused on his work.",
+     "girl, smiling, red eyes, blue hair, white shirt"
+ ]
+
+ css = """
+ #col-container {
+     margin: 0 auto;
+     max-width: 640px;
+ }
+ """
+
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.Markdown(" # Simple Diffusion (sdxs-08b)")
+
+         with gr.Row():
+             prompt = gr.Text(
+                 label="Prompt",
+                 show_label=False,
+                 max_lines=5,
+                 placeholder="Enter your prompt",
+                 container=False,
+             )
+
+             run_button = gr.Button("Run", scale=0, variant="primary")
+
+         result = gr.Image(label="Result", show_label=False)
+
+         with gr.Accordion("Advanced Settings", open=False):
+             # value changed to True
+             refine_prompt = gr.Checkbox(label="Refine Prompt with Qwen3", value=True)
+
+             negative_prompt = gr.Text(
+                 label="Negative prompt",
+                 max_lines=1,
+                 placeholder="Enter a negative prompt",
+                 value="bad quality, low resolution"
+             )
+
+             seed = gr.Slider(
+                 label="Seed",
+                 minimum=0,
+                 maximum=MAX_SEED,
+                 step=1,
+                 value=0,
+             )
+             randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
+
+             with gr.Row():
+                 width = gr.Slider(
+                     label="Width",
+                     minimum=MIN_IMAGE_SIZE,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=STEP,
+                     value=1024,
+                 )
+
+                 height = gr.Slider(
+                     label="Height",
+                     minimum=MIN_IMAGE_SIZE,
+                     maximum=MAX_IMAGE_SIZE,
+                     step=STEP,
+                     value=MAX_IMAGE_SIZE,
+                 )
+
+             with gr.Row():
+                 guidance_scale = gr.Slider(
+                     label="Guidance scale",
+                     minimum=0.0,
+                     maximum=10.0,
+                     step=0.5,
+                     value=4.0,
+                 )
+
+                 num_inference_steps = gr.Slider(
+                     label="Number of inference steps",
+                     minimum=1,
+                     maximum=50,
+                     step=1,
+                     value=40,
+                 )
+
+         gr.Examples(examples=examples, inputs=[prompt])
+
+     gr.on(
+         triggers=[run_button.click, prompt.submit],
+         fn=infer,
+         inputs=[
+             prompt,
+             negative_prompt,
+             seed,
+             randomize_seed,
+             width,
+             height,
+             guidance_scale,
+             num_inference_steps,
+             refine_prompt,
+         ],
+         outputs=[result, seed, prompt],  # prompt added so the refined text updates in the UI
+     )
+
+ if __name__ == "__main__":
+     demo.launch()
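Note: infer() reuses the </think> parsing from the removed pipeline: with enable_thinking=True, Qwen3 emits a reasoning segment first, and only the text after the last </think> token (id 151668) is kept as the refined prompt. A minimal standalone sketch of that step, using hypothetical token ids; illustrative only, not part of the commit:

END_THINK_TOKEN_ID = 151668
output_ids = [9906, 151668, 3014, 2015]  # hypothetical ids: [thinking..., </think>, answer...]

try:
    # the index in the reversed list gives the distance of the last </think> from the end
    index = len(output_ids) - output_ids[::-1].index(END_THINK_TOKEN_ID)
except ValueError:
    index = 0  # no </think> found: keep the whole output

answer_ids = output_ids[index:]  # [3014, 2015]; decode with tokenizer.decode(answer_ids, skip_special_tokens=True)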