dikdimon commited on
Commit
8c0bf45
·
verified ·
1 Parent(s): 36646c6

Upload sd-webui-xl_vec using SD-Hub

Browse files
.gitattributes CHANGED
@@ -166,3 +166,6 @@ outputs/txt2img-images/2025-09-13/00014-1035.png filter=lfs diff=lfs merge=lfs -
166
  outputs/txt2img-images/2025-09-13/00015-1323.png filter=lfs diff=lfs merge=lfs -text
167
  outputs/txt2img-images/2025-09-13/00016-1728.png filter=lfs diff=lfs merge=lfs -text
168
  outputs/txt2img-images/2025-09-13/00017-2053.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
166
  outputs/txt2img-images/2025-09-13/00015-1323.png filter=lfs diff=lfs merge=lfs -text
167
  outputs/txt2img-images/2025-09-13/00016-1728.png filter=lfs diff=lfs merge=lfs -text
168
  outputs/txt2img-images/2025-09-13/00017-2053.png filter=lfs diff=lfs merge=lfs -text
169
+ sd-webui-xl_vec/images/crop_top.png filter=lfs diff=lfs merge=lfs -text
170
+ sd-webui-xl_vec/images/mult.png filter=lfs diff=lfs merge=lfs -text
171
+ sd-webui-xl_vec/images/original_size.png filter=lfs diff=lfs merge=lfs -text
sd-webui-xl_vec/.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ __pycache__
2
+ /.vs
sd-webui-xl_vec/LICENSE ADDED
Binary file (24 Bytes). View file
 
sd-webui-xl_vec/README.md ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # XL Vec
2
+
3
+ ## What is this?
4
+
5
+ This is an extension for [stable-diffusion-webui](https://github.com/AUTOMATIC1111/stable-diffusion-webui) which overwrites SDXL's CLIP outputs.
6
+
7
+ ## Usage
8
+
9
+ Input values as you like.
10
+
11
+ ![GUI](images/gui.png)
12
+
13
+ Overwritten values are dumped into stdout.
14
+
15
+ ![stdout](images/log.png)
16
+
17
+ ## Examples
18
+
19
+ ```
20
+ Hassaku XL alphaV0.7 / DPM++ 3M SDE / 30 steps / 576x1024
21
+ Prompt: a cute girl sitting in flower garden, clear anime face, insanely frilled white dress, absurdly long brown hair, small silver tiara, long sleeves highneck dress, looking at viewer
22
+ Negative Prompt: maid
23
+ ```
24
+
25
+ See PNGInfo for details.
26
+
27
+ ### Crop Top
28
+
29
+ ![Crop top](images/crop_top.png)
30
+
31
+ ### Original Width/Height
32
+
33
+ ![Original Width/Height](images/original_size.png)
34
+
35
+ ### Token Multiplier
36
+
37
+ ![Token Multiplier](images/mult.png)
sd-webui-xl_vec/images/crop_top.png ADDED

Git LFS Details

  • SHA256: 8aaf2946fae78c9dd3fff563c639fbf358c555b9c33e53163e6e1a314cfeb2b2
  • Pointer size: 133 Bytes
  • Size of remote file: 22.2 MB
sd-webui-xl_vec/images/gui.png ADDED
sd-webui-xl_vec/images/log.png ADDED
sd-webui-xl_vec/images/mult.png ADDED

Git LFS Details

  • SHA256: c4c72cc491048f76f6064f6e80223a2cc90851664a0819f2a78389532f1550a7
  • Pointer size: 133 Bytes
  • Size of remote file: 23.7 MB
sd-webui-xl_vec/images/original_size.png ADDED

Git LFS Details

  • SHA256: d5ce692a37c760f944f2c217c3a610373a2dd40288c2bc2c0e044cb096f6e5d1
  • Pointer size: 133 Bytes
  • Size of remote file: 23.4 MB
sd-webui-xl_vec/scripts/__pycache__/sdhook.cpython-310.pyc ADDED
Binary file (8.45 kB). View file
 
sd-webui-xl_vec/scripts/__pycache__/xl_clip.cpython-310.pyc ADDED
Binary file (1.34 kB). View file
 
sd-webui-xl_vec/scripts/__pycache__/xl_vec.cpython-310.pyc ADDED
Binary file (13.4 kB). View file
 
sd-webui-xl_vec/scripts/__pycache__/xl_vec_xyz.cpython-310.pyc ADDED
Binary file (5.93 kB). View file
 
sd-webui-xl_vec/scripts/sdhook.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from typing import Any, Callable, Union
3
+
4
+ from torch import nn
5
+ from torch.utils.hooks import RemovableHandle
6
+
7
+ from ldm.modules.diffusionmodules.openaimodel import (
8
+ TimestepEmbedSequential,
9
+ )
10
+ from ldm.modules.attention import (
11
+ SpatialTransformer,
12
+ BasicTransformerBlock,
13
+ CrossAttention,
14
+ MemoryEfficientCrossAttention,
15
+ )
16
+ from ldm.modules.diffusionmodules.openaimodel import (
17
+ ResBlock,
18
+ )
19
+ from modules.processing import StableDiffusionProcessing
20
+ from modules import shared
21
+
22
class ForwardHook:
    """Monkey-patches a module's ``forward`` so that ``fn`` wraps the original.

    ``fn`` is invoked as ``fn(module, original_forward, *args, **kwargs)``.
    Call :meth:`remove` to restore the original ``forward`` and drop all
    references held by this hook.
    """

    def __init__(self, module: nn.Module, fn: Callable[[nn.Module, Callable[..., Any], Any], Any]):
        self.o = module.forward        # original forward, restored by remove()
        self.fn = fn
        self.module = module
        self.module.forward = self.forward

    def remove(self):
        """Restore the original ``forward`` and neutralize this hook."""
        restorable = self.module is not None and self.o is not None
        if restorable:
            self.module.forward = self.o
        self.module, self.o, self.fn = None, None, None

    def forward(self, *args, **kwargs):
        """Dispatch to ``fn``; returns None once the hook has been removed."""
        if self.module is None or self.o is None or self.fn is None:
            return None
        return self.fn(self.module, self.o, *args, **kwargs)
42
+
43
+
44
class SDHook:
    """Base class for context-managed hook installers.

    Subclasses override ``hook_clip`` / ``hook_unet`` / ``hook_vae`` and use
    the ``hook_layer`` / ``hook_layer_pre`` / ``hook_forward`` helpers to
    register hooks; ``__exit__`` removes every registered hook again.
    """

    def __init__(self, enabled: bool):
        # When disabled, setup() and the hook_* helpers are no-ops.
        self._enabled = enabled
        self._handles: list[Union[RemovableHandle,ForwardHook]] = []

    @property
    def enabled(self):
        """Whether this hook installer is active."""
        return self._enabled

    @enabled.setter
    def enabled(self, v: bool):
        self._enabled = bool(v)

    @property
    def batch_num(self):
        """Index of the batch currently being generated (webui shared state)."""
        return shared.state.job_no

    @property
    def step_num(self):
        """Sampling step of the image currently being generated."""
        return shared.state.current_image_sampling_step

    def __enter__(self):
        # Nothing to do on entry; hooks are installed separately via setup().
        if self.enabled:
            pass

    def __exit__(self, exc_type, exc_value, traceback):
        # Remove every registered handle/hook, then let subclasses clean up.
        if self.enabled:
            for handle in self._handles:
                handle.remove()
            self._handles.clear()
            self.dispose()

    def dispose(self):
        """Extra cleanup for subclasses; called from __exit__."""
        pass

    def setup(
        self,
        p: StableDiffusionProcessing
    ):
        """Locate the model's UNet/VAE/CLIP submodules and install hooks."""
        if not self.enabled:
            return

        wrapper = getattr(p.sd_model, "model", None)

        unet: Union[nn.Module,None] = getattr(wrapper, "diffusion_model", None) if wrapper is not None else None
        vae: Union[nn.Module,None] = getattr(p.sd_model, "first_stage_model", None)
        clip: Union[nn.Module,None] = getattr(p.sd_model, "cond_stage_model", None)

        assert unet is not None, "p.sd_model.diffusion_model is not found. broken model???"
        self._do_hook(p, p.sd_model, unet=unet, vae=vae, clip=clip)  # type: ignore
        self.on_setup()

    def on_setup(self):
        """Called after hooks were installed; subclasses may override."""
        pass

    def _do_hook(
        self,
        p: StableDiffusionProcessing,
        model: Any,
        unet: Union[nn.Module,None],
        vae: Union[nn.Module,None],
        clip: Union[nn.Module,None]
    ):
        # Dispatch each present submodule to the matching hook_* override.
        assert model is not None, "empty model???"

        if clip is not None:
            self.hook_clip(p, clip)

        if unet is not None:
            self.hook_unet(p, unet)

        if vae is not None:
            self.hook_vae(p, vae)

    def hook_vae(
        self,
        p: StableDiffusionProcessing,
        vae: nn.Module
    ):
        """Override to install hooks on the VAE; default is a no-op."""
        pass

    def hook_unet(
        self,
        p: StableDiffusionProcessing,
        unet: nn.Module
    ):
        """Override to install hooks on the UNet; default is a no-op."""
        pass

    def hook_clip(
        self,
        p: StableDiffusionProcessing,
        clip: nn.Module
    ):
        """Override to install hooks on the CLIP encoder; default is a no-op."""
        pass

    def hook_layer(
        self,
        module: Union[nn.Module,Any],
        fn: Callable[[nn.Module, tuple, Any], Any]
    ):
        """Register a forward (post) hook on *module*; removed on __exit__."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_hook(fn))

    def hook_layer_pre(
        self,
        module: Union[nn.Module,Any],
        fn: Callable[[nn.Module, tuple], Any]
    ):
        """Register a forward pre-hook on *module*; removed on __exit__."""
        if not self.enabled:
            return

        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(module.register_forward_pre_hook(fn))

    def hook_forward(
        self,
        module: Union[nn.Module,Any],
        fn: Callable[[nn.Module, Callable[..., Any], Any], Any]
    ):
        """Replace *module*'s forward entirely via ForwardHook.

        NOTE(review): unlike the other helpers this does not check
        ``self.enabled`` — presumably intentional, but verify against callers.
        """
        assert module is not None
        assert isinstance(module, nn.Module)
        self._handles.append(ForwardHook(module, fn))

    def log(self, msg: str):
        """Write a message to stderr (keeps stdout clean for webui output)."""
        print(msg, file=sys.stderr)
175
+
176
+
177
def each_transformer(unet_block: TimestepEmbedSequential):
    """Yield every SpatialTransformer directly contained in *unet_block*."""
    yield from (child for child in unet_block.children()
                if isinstance(child, SpatialTransformer))
182
+
183
def each_basic_block(trans: SpatialTransformer):
    """Yield every BasicTransformerBlock directly contained in *trans*."""
    yield from (child for child in trans.transformer_blocks.children()
                if isinstance(child, BasicTransformerBlock))
188
+
189
def each_attns(unet_block: TimestepEmbedSequential):
    """Yield ``(transformer_index, block_depth, attn1, attn2)`` for every
    BasicTransformerBlock reachable from *unet_block*.

    Combines each_transformer and each_basic_block. Both attention layers
    must be one of the two known attention implementations.
    """
    known_attn = (CrossAttention, MemoryEfficientCrossAttention)
    for n, trans in enumerate(each_transformer(unet_block)):
        for depth, basic_block in enumerate(each_basic_block(trans)):
            attn1 = basic_block.attn1
            attn2 = basic_block.attn2
            assert isinstance(attn1, known_attn)
            assert isinstance(attn2, known_attn)
            yield n, depth, attn1, attn2
202
+
203
def each_unet_attn_layers(unet: nn.Module):
    """Yield ``(name, attention_module)`` pairs for every attention layer in
    the U-Net's input, middle and output blocks.

    Names look like ``IN03_00_00_xattn`` / ``M00_00_00_sattn`` / ``OUT05_...``
    where ``sattn`` is self-attention (attn1) and ``xattn`` cross-attention
    (attn2).
    """
    def from_block(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for n, d, attn1, attn2 in each_attns(block):
            common = dict(layer_index=layer_index, trans_index=n, block_index=d)
            yield fmt.format(attn_name='sattn', **common), attn1
            yield fmt.format(attn_name='xattn', **common), attn2

    def from_blocks(blocks: nn.ModuleList, fmt: str):
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from from_block(idx, block, fmt)

    inputs: nn.ModuleList = unet.input_blocks            # type: ignore
    middle: TimestepEmbedSequential = unet.middle_block  # type: ignore
    outputs: nn.ModuleList = unet.output_blocks          # type: ignore

    yield from from_blocks(inputs, 'IN{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from from_block(0, middle, 'M{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
    yield from from_blocks(outputs, 'OUT{layer_index:02}_{trans_index:02}_{block_index:02}_{attn_name}')
226
+
227
+
228
def each_unet_transformers(unet: nn.Module):
    """Yield ``(name, SpatialTransformer)`` pairs for every transformer in
    the U-Net's input, middle and output blocks (names like
    ``IN03_00_trans`` / ``M00_00_trans`` / ``OUT05_00_trans``)."""
    def from_block(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for n, trans in enumerate(each_transformer(block)):
            yield fmt.format(
                layer_index=layer_index,
                block_index=n,
                block_name='trans',
            ), trans

    def from_blocks(blocks: nn.ModuleList, fmt: str):
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from from_block(idx, block, fmt)

    inputs: nn.ModuleList = unet.input_blocks            # type: ignore
    middle: TimestepEmbedSequential = unet.middle_block  # type: ignore
    outputs: nn.ModuleList = unet.output_blocks          # type: ignore

    yield from from_blocks(inputs, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from from_block(0, middle, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from from_blocks(outputs, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
250
+
251
+
252
def each_resblock(unet_block: TimestepEmbedSequential):
    """Yield every ResBlock directly contained in *unet_block*."""
    yield from (child for child in unet_block.children()
                if isinstance(child, ResBlock))
256
+
257
def each_unet_resblock(unet: nn.Module):
    """Yield ``(name, ResBlock)`` pairs for every ResBlock in the U-Net's
    input, middle and output blocks (names like ``IN03_00_resblock``)."""
    def from_block(layer_index: int, block: TimestepEmbedSequential, fmt: str):
        for n, res in enumerate(each_resblock(block)):
            yield fmt.format(
                layer_index=layer_index,
                block_index=n,
                block_name='resblock',
            ), res

    def from_blocks(blocks: nn.ModuleList, fmt: str):
        for idx, block in enumerate(blocks.children()):
            if isinstance(block, TimestepEmbedSequential):
                yield from from_block(idx, block, fmt)

    inputs: nn.ModuleList = unet.input_blocks            # type: ignore
    middle: TimestepEmbedSequential = unet.middle_block  # type: ignore
    outputs: nn.ModuleList = unet.output_blocks          # type: ignore

    yield from from_blocks(inputs, 'IN{layer_index:02}_{block_index:02}_{block_name}')
    yield from from_block(0, middle, 'M{layer_index:02}_{block_index:02}_{block_name}')
    yield from from_blocks(outputs, 'OUT{layer_index:02}_{block_index:02}_{block_name}')
279
+
sd-webui-xl_vec/scripts/xl_clip.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import open_clip
3
+
4
+ try:
5
+ from sgm.modules import GeneralConditioner as CLIP_SDXL
6
+ from sgm.modules.encoders.modules import FrozenOpenCLIPEmbedder2
7
+ from modules.sd_hijack_open_clip import FrozenOpenCLIPEmbedder2WithCustomWords
8
+ except:
9
+ print(f"[XL Vec] failed to load `sgm.modules`")
10
+ raise
11
+
12
+
13
+
14
def get_pooled(clip: CLIP_SDXL, text: str, layer='last', index=-1):
    """Recompute the pooled text embedding for *text* using the second
    (OpenCLIP) embedder of an SDXL conditioner, pooling at an arbitrary
    token position instead of always at the EOT token.

    Args:
        clip: SDXL GeneralConditioner whose ``embedders[1]`` is the OpenCLIP
            text encoder (possibly wrapped by the webui hijack).
        text: Prompt to encode (tokenized fresh with open_clip).
        layer: Key into the transformer outputs ('last' by default).
        index: Token position to pool at; negative values count back from
            the EOT token (-1 == EOT itself).

    Returns:
        Tuple ``(pooled, real_index)``: the projected hidden state at the
        chosen position, and the resolved absolute token index.
    """
    # cf. sgm/modules/encoders/modules.py:FrozenOpenCLIPEmbedder2

    mod = clip.embedders[1]
    if isinstance(mod, FrozenOpenCLIPEmbedder2WithCustomWords):
        # unwrap the webui prompt-weighting wrapper to reach the raw encoder
        mod = mod.wrapped

    assert isinstance(mod, FrozenOpenCLIPEmbedder2)

    tokens = open_clip.tokenize([text]).to(mod.device)

    x = mod.model.token_embedding(tokens)  # [batch_size, n_ctx, d_model]
    x = x + mod.model.positional_embedding
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = mod.text_transformer_forward(x, attn_mask=mod.model.attn_mask)

    o = x[layer]
    o = mod.model.ln_final(o)

    # EOT position = argmax over token ids (EOT has the highest id in OpenCLIP)
    eot = tokens.argmax(dim=-1)
    p = torch.zeros_like(eot)
    if 0 <= index:
        p[0] = index
    else:
        # negative index counts backwards from the EOT token (-1 == EOT)
        p[0] = eot.item() + index + 1

    real_index = p.item()
    # 77 is the OpenCLIP context length
    assert 0 <= real_index < 77, f'index={index}, real_index={real_index}'

    pooled = (
        o[torch.arange(o.shape[0]), p]
        @ mod.model.text_projection
    )

    return pooled, real_index
sd-webui-xl_vec/scripts/xl_vec.py ADDED
@@ -0,0 +1,473 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
NAME = 'XL Vec'

import logging
import traceback
from threading import Lock
from torch import Tensor, FloatTensor, nn
import gradio as gr
from modules.processing import StableDiffusionProcessing
from modules import scripts

from scripts.sdhook import SDHook
from scripts.xl_clip import CLIP_SDXL, get_pooled
from scripts.xl_vec_xyz import init_xyz

# --- LOGGING ---
logger = logging.getLogger(__name__)

# --- CONSTANTS ---
SDXL_POOLED_DIM = 1280  # Width of the SDXL pooled embedding vector
AESTHETIC_SCORE_EPS = 0.01  # Tolerance when comparing float aesthetic-score values
DEFAULT_AESTHETIC_SCORE = 6.0  # WebUI's default SDXL aesthetic score

# --- PRESETS ---
# Resolution presets as (width, height); None means manual slider entry.
PRESETS = {
    "Manual / Custom": None,
    "1:1 Square (1024x1024)": (1024, 1024),
    "4:3 Photo (1152x896)": (1152, 896),
    "3:4 Portrait (896x1152)": (896, 1152),
    "16:9 Cinema (1344x768)": (1344, 768),
    "9:16 Mobile (768x1344)": (768, 1344),
    "21:9 Wide (1536x640)": (1536, 640),
    "2:3 Classic (832x1216)": (832, 1216),
}
+ }
34
+
35
+
36
def hook_input(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]]) -> tuple[dict[str, Tensor]]:
    """Pre-forward hook: intercept the CLIP conditioner's input and override
    the SDXL conditioning parameters (original/target size, crop offsets,
    aesthetic score).

    Args:
        args: Hook instance carrying the user's override values.
        mod: The hooked CLIP module (must be a GeneralConditioner).
        inputs: One-element tuple holding the conditioner's input dict.

    Returns:
        The (mutated in place) inputs tuple.
    """
    if not args.enabled:
        return inputs

    assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"
    input_data = inputs[0]

    def create(v: list[float], src: FloatTensor) -> FloatTensor:
        """Build a tensor from *v* on the same device/dtype as *src*."""
        return FloatTensor(v).to(dtype=src.dtype, device=src.device)

    def put(name: str, v: list[float]) -> None:
        """Replace input_data[name] with *v*, keeping the original shape."""
        if name in input_data:
            src = input_data[name]
            input_data[name] = create(v, src).reshape(src.shape)

    # Apply the geometry parameters; note the (height, width) / (top, left) order
    put('original_size_as_tuple', [args.original_height, args.original_width])
    put('crop_coords_top_left', [args.crop_top, args.crop_left])
    put('target_size_as_tuple', [args.target_height, args.target_width])

    # Positive vs. negative prompt is detected via the aesthetic score
    try:
        current_score = input_data['aesthetic_score'].item()
        if args.is_positive_prompt(current_score):
            put('aesthetic_score', [args.aesthetic_score])
        else:
            put('aesthetic_score', [args.negative_aesthetic_score])
    except (KeyError, AttributeError) as e:
        logger.warning(f"[XL Vec] Cannot access aesthetic_score: {e}")

    return inputs
81
+
82
+
83
def hook_output(args: 'Hook', mod: nn.Module, inputs: tuple[dict[str, Tensor]], output: dict) -> dict:
    """Forward hook: intercept the CLIP conditioner's output and overwrite
    the pooled token vector.

    Args:
        args: Hook instance carrying the user's override values.
        mod: The hooked CLIP module (must be a GeneralConditioner).
        inputs: The conditioner's input (used to read prompt text and score).
        output: Output dict whose 'vector' entry is patched in place.

    Returns:
        The (possibly mutated) output dict.
    """
    if not args.enabled:
        return output

    try:
        # Decide positive vs. negative prompt via the aesthetic score
        current_score = inputs[0]['aesthetic_score'].item()
        prompt, index, multiplier = args.get_prompt_params(current_score)

        # Nothing to do when the user left everything at the defaults
        if (prompt is None or len(prompt) == 0) and (index == -1 and multiplier == 1.0):
            return output

        # Fall back to the original prompt text when no override text is given
        if prompt is None or len(prompt) == 0:
            prompt = inputs[0]['txt'][0]

        assert isinstance(mod, CLIP_SDXL), f"Expected CLIP_SDXL, got {type(mod)}"

        # Compute the replacement pooled embedding; temporarily disable
        # ourselves under the lock so get_pooled does not re-enter our hooks
        with args._lock:
            args.enabled = False
            try:
                pooled, token_idx = get_pooled(mod, prompt, index=index)
            finally:
                args.enabled = True

        # Swap the vector in, checking the dimensionality first
        if output['vector'].shape[1] >= SDXL_POOLED_DIM:
            output['vector'][:, 0:SDXL_POOLED_DIM] = pooled[:] * multiplier
            logger.info(
                f"[XL Vec] Vector override: '{inputs[0]['txt']}' -> '{prompt}' "
                f"@ token {token_idx} [x{multiplier:.2f}]"
            )
        else:
            logger.error(
                f"[XL Vec] Vector dimension mismatch: expected >={SDXL_POOLED_DIM}, "
                f"got {output['vector'].shape[1]}"
            )

    except Exception as e:
        # Boundary handler: never break generation because of this extension
        logger.error(f"[XL Vec] Error in hook_output: {e}")
        traceback.print_exc()

    return output
140
+
141
+
142
class Hook(SDHook):
    """Hook that overrides SDXL's CLIP conditioning (sizes, crop, scores,
    pooled token vector)."""

    def __init__(
        self,
        enabled: bool,
        p: StableDiffusionProcessing,
        crop_left: float, crop_top: float,
        original_width: float, original_height: float,
        target_width: float, target_height: float,
        aesthetic_score: float, negative_aesthetic_score: float,
        extra_prompt: str | None, extra_negative_prompt: str | None,
        token_index: int | float, negative_token_index: int | float,
        eot_multiplier: float, negative_eot_multiplier: float,
        with_hr: bool,
        base_aesthetic_score: float,
    ):
        super().__init__(enabled)

        # Validate user input early (raises ValueError on bad values)
        self._validate_params(
            aesthetic_score, negative_aesthetic_score, base_aesthetic_score,
            original_width, original_height, target_width, target_height
        )

        self.p = p
        self.crop_left = float(crop_left)
        self.crop_top = float(crop_top)
        self.original_width = float(original_width)
        self.original_height = float(original_height)
        self.target_width = float(target_width)
        self.target_height = float(target_height)
        self.aesthetic_score = float(aesthetic_score)
        self.negative_aesthetic_score = float(negative_aesthetic_score)
        self.extra_prompt = extra_prompt
        self.extra_negative_prompt = extra_negative_prompt
        self.token_index = int(token_index)
        self.negative_token_index = int(negative_token_index)
        self.eot_multiplier = float(eot_multiplier)
        self.negative_eot_multiplier = float(negative_eot_multiplier)
        self.with_hr = bool(with_hr)
        self.base_aesthetic_score = float(base_aesthetic_score)

        # Thread safety: guards the temporary enabled=False window in hook_output
        self._lock = Lock()

    @staticmethod
    def _validate_params(
        aesthetic_score: float,
        negative_aesthetic_score: float,
        base_aesthetic_score: float,
        original_width: float,
        original_height: float,
        target_width: float,
        target_height: float
    ) -> None:
        """Validate input parameters; raises ValueError on out-of-range values."""
        for score, name in [
            (aesthetic_score, "aesthetic_score"),
            (negative_aesthetic_score, "negative_aesthetic_score"),
            (base_aesthetic_score, "base_aesthetic_score")
        ]:
            if not (0 <= score <= 10):
                raise ValueError(f"{name} должен быть в диапазоне [0, 10], получено {score}")

        for size, name in [
            (original_width, "original_width"),
            (original_height, "original_height"),
            (target_width, "target_width"),
            (target_height, "target_height")
        ]:
            if size < 0:
                raise ValueError(f"{name} не может быть отрицательным, получено {size}")

    def is_positive_prompt(self, aesthetic_score: float) -> bool:
        """Return True when *aesthetic_score* matches the configured base
        score (positive prompt), False otherwise (negative prompt)."""
        return abs(aesthetic_score - self.base_aesthetic_score) < AESTHETIC_SCORE_EPS

    def get_prompt_params(self, aesthetic_score: float) -> tuple[str | None, int, float]:
        """Return ``(prompt, token_index, multiplier)`` for the prompt kind
        (positive vs. negative) selected by *aesthetic_score*."""
        if self.is_positive_prompt(aesthetic_score):
            return self.extra_prompt, self.token_index, self.eot_multiplier
        else:
            return self.extra_negative_prompt, self.negative_token_index, self.negative_eot_multiplier

    def cleanup(self) -> None:
        """Remove all installed hooks, logging (not raising) any error."""
        try:
            self.__exit__(None, None, None)
        except Exception as e:
            logger.warning(f"[XL Vec] Error during cleanup: {e}")

    def hook_clip(self, p: StableDiffusionProcessing, clip: nn.Module) -> None:
        """Install the input/output hooks on the CLIP conditioner (SDXL only)."""
        if not hasattr(p.sd_model, 'is_sdxl') or not p.sd_model.is_sdxl:
            logger.debug("[XL Vec] Model is not SDXL, skipping hooks")
            return

        def inp(*args, **kwargs):
            return hook_input(self, *args, **kwargs)

        def outp(*args, **kwargs):
            return hook_output(self, *args, **kwargs)

        self.hook_layer_pre(clip, inp)
        self.hook_layer(clip, outp)
264
+
265
+
266
class Script(scripts.Script):
    """WebUI script exposing the SDXL conditioning overrides in the UI."""

    def title(self) -> str:
        return NAME

    def show(self, is_img2img) -> scripts.AlwaysVisible:
        # Always render the accordion; activation is governed by the checkbox.
        return scripts.AlwaysVisible

    def ui(self, is_img2img):
        """Build the gradio UI; returns the components passed to process()."""
        with gr.Accordion(NAME, open=False):
            with gr.Row():
                enabled = gr.Checkbox(label='Enable XL Vec', value=False)
                with_hr = gr.Checkbox(label='Active on Hires Fix', value=False, visible=False)

            # --- GEOMETRY SECTION ---
            with gr.Group():
                gr.Markdown("### 📐 SDXL Geometry & Size")

                preset_dropdown = gr.Dropdown(
                    label="⚡ Quick Resolution Preset",
                    choices=list(PRESETS.keys()),
                    value="Manual / Custom",
                    type="value"
                )

                with gr.Row():
                    original_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Width (-1=auto)'
                    )
                    original_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Original Height (-1=auto)'
                    )

                with gr.Row():
                    target_width = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Width (-1=auto)'
                    )
                    target_height = gr.Slider(
                        minimum=-1, maximum=4096, step=1, value=-1,
                        label='Target Height (-1=auto)'
                    )

                # Callback: Dropdown -> Sliders
                def apply_preset(choice):
                    if choice in PRESETS and PRESETS[choice] is not None:
                        w, h = PRESETS[choice]
                        return w, h, w, h
                    return gr.update(), gr.update(), gr.update(), gr.update()

                preset_dropdown.change(
                    fn=apply_preset,
                    inputs=[preset_dropdown],
                    outputs=[original_width, original_height, target_width, target_height]
                )

                # Callback: Sliders -> Dropdown (Reset to Manual)
                def reset_dropdown():
                    return "Manual / Custom"

                for slider in [original_width, original_height, target_width, target_height]:
                    slider.change(fn=reset_dropdown, inputs=None, outputs=[preset_dropdown])

                with gr.Accordion("✂️ Crop Settings", open=False):
                    with gr.Row():
                        crop_left = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Left'
                        )
                        crop_top = gr.Slider(
                            minimum=-10000, maximum=10000, step=1, value=0,
                            label='Crop Top'
                        )

            # --- AESTHETICS SECTION ---
            with gr.Group():
                gr.Markdown("### 🎨 Aesthetics")
                with gr.Row():
                    aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Positive Aesthetic Score"
                    )
                    negative_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=2.5,
                        label="Negative Aesthetic Score"
                    )

                with gr.Accordion("⚙️ Detection Threshold (Advanced)", open=False):
                    base_aesthetic_score = gr.Slider(
                        minimum=0.0, maximum=10.0, step=0.1, value=6.0,
                        label="Base Score Threshold"
                    )
                    gr.Info("Change this ONLY if you modified 'SDXL Aesthetic Score' in WebUI settings.")

            # --- VECTORS SECTION ---
            with gr.Accordion("🧠 Token & Vector Control", open=False):
                with gr.Row():
                    eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Pos. Vector Mult'
                    )
                    negative_eot_multiplier = gr.Slider(
                        minimum=-4.0, maximum=8.0, step=0.05, value=1.0,
                        label='Neg. Vector Mult'
                    )
                with gr.Row():
                    token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Pos. Token Index'
                    )
                    negative_token_index = gr.Slider(
                        minimum=-77, maximum=76, step=1, value=-1,
                        label='Neg. Token Index'
                    )
                with gr.Row():
                    extra_prompt = gr.Textbox(
                        lines=1, label='Extra Prompt',
                        placeholder="Override positive prompt text..."
                    )
                    extra_negative_prompt = gr.Textbox(
                        lines=1, label='Extra Negative',
                        placeholder="Override negative prompt text..."
                    )

        return [
            enabled, crop_left, crop_top, original_width, original_height,
            target_width, target_height, aesthetic_score, negative_aesthetic_score,
            extra_prompt, extra_negative_prompt, token_index, negative_token_index,
            eot_multiplier, negative_eot_multiplier, with_hr,
            base_aesthetic_score
        ]

    def process(
        self, p, enabled, crop_left, crop_top, original_width, original_height,
        target_width, target_height, aesthetic_score, negative_aesthetic_score,
        extra_prompt, extra_negative_prompt, token_index, negative_token_index,
        eot_multiplier, negative_eot_multiplier, with_hr,
        base_aesthetic_score=DEFAULT_AESTHETIC_SCORE
    ):
        """Validate the parameters and install the hook before generation."""

        # Tear down a hook left over from a previous run (if any)
        if getattr(self, 'last_hooker', None) is not None:
            self.last_hooker.cleanup()
            self.last_hooker = None

        if not enabled:
            return

        # Auto-fill sizes (-1 means "use the processing width/height")
        if original_width < 0:
            original_width = p.width
        if original_height < 0:
            original_height = p.height
        if target_width < 0:
            target_width = p.width
        if target_height < 0:
            target_height = p.height

        try:
            self.last_hooker = Hook(
                enabled=True, p=p,
                crop_left=crop_left, crop_top=crop_top,
                original_width=original_width, original_height=original_height,
                target_width=target_width, target_height=target_height,
                aesthetic_score=aesthetic_score,
                negative_aesthetic_score=negative_aesthetic_score,
                extra_prompt=extra_prompt, extra_negative_prompt=extra_negative_prompt,
                token_index=token_index, negative_token_index=negative_token_index,
                eot_multiplier=eot_multiplier,
                negative_eot_multiplier=negative_eot_multiplier,
                with_hr=with_hr, base_aesthetic_score=base_aesthetic_score
            )
        except ValueError as e:
            logger.error(f"[XL Vec] Invalid parameters: {e}")
            return

        self.last_hooker.setup(p)
        self.last_hooker.__enter__()

        # Record the settings in the generation metadata (infotext)
        p.extra_generation_params.update({
            f'[{NAME}] Enabled': enabled,
            f'[{NAME}] Original Size': f"{int(original_width)}x{int(original_height)}",
            f'[{NAME}] Target Size': f"{int(target_width)}x{int(target_height)}",
            f'[{NAME}] Aesthetic Score': aesthetic_score,
        })

        if crop_left != 0 or crop_top != 0:
            p.extra_generation_params[f'[{NAME}] Crop'] = f"{crop_left},{crop_top}"

        if abs(base_aesthetic_score - DEFAULT_AESTHETIC_SCORE) > AESTHETIC_SCORE_EPS:
            p.extra_generation_params[f'[{NAME}] Base Score'] = base_aesthetic_score

        if eot_multiplier != 1.0:
            p.extra_generation_params[f'[{NAME}] Token Mult'] = eot_multiplier

        # Drop cached conds so the new conditioning actually takes effect
        if hasattr(p, 'cached_c'):
            p.cached_c = [None, None]
        if hasattr(p, 'cached_uc'):
            p.cached_uc = [None, None]
471
+
472
+
473
+ init_xyz(Script, NAME)
sd-webui-xl_vec/scripts/xl_vec_xyz.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from typing import Union, List, Callable
3
+
4
+ from modules import scripts
5
+ from modules.processing import StableDiffusionProcessing, StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img
6
+
7
+
8
def __set_value(p: StableDiffusionProcessing, script: type, index: int, value):
    """Overwrite a single slot of `script`'s arguments inside p.script_args.

    Args:
        p: the active processing object (txt2img or img2img).
        script: the Script subclass whose argument slice should be patched.
        index: offset of the slot within the script's args slice.
        value: new value to store at that slot.

    This is the single-slot case of __set_values; delegating keeps one
    implementation instead of two copies of the same slicing logic.
    """
    __set_values(p, script, [index], [value])
22
+
23
def __set_values(p: StableDiffusionProcessing, script: type, indices: list[int], values: list):
    """Overwrite several slots of `script`'s arguments inside p.script_args.

    Finds every loaded instance of `script` in the runner that matches the
    processing type (txt2img vs img2img) and writes values[i] into slot
    args_from + indices[i] for each. p.script_args may be an immutable
    tuple, so it is rebuilt with its original container type at the end.
    """
    runner = (scripts.scripts_txt2img
              if isinstance(p, StableDiffusionProcessingTxt2Img)
              else scripts.scripts_img2img)

    new_args = list(p.script_args)
    for instance in runner.scripts:
        if not isinstance(instance, script):
            continue
        base = instance.args_from
        assert base is not None
        for offset, value in zip(indices, values):
            new_args[base + offset] = value

    p.script_args = type(p.script_args)(new_args)
38
+
39
+
40
+
41
def to_bool(v: str):
    """Parse an X/Y/Z grid cell string into a bool.

    Accepts, case-insensitively, any string containing 'true' or 'false',
    or anything int() understands (non-zero -> True). An empty string
    means False.

    Raises:
        ValueError: if the input matches none of the accepted forms.
    """
    if not v:
        return False
    lowered = v.lower()
    # Substring match keeps the original lenient behavior (e.g. ' true ').
    if 'true' in lowered:
        return True
    if 'false' in lowered:
        return False

    try:
        return bool(int(lowered))
    except ValueError:
        # Was a bare `except:`, which would also swallow SystemExit and
        # KeyboardInterrupt; only a failed int() conversion belongs here.
        acceptable = ['True', 'False', '1', '0']
        s = ', '.join(f'`{a}`' for a in acceptable)
        raise ValueError(f'value must be one of {s}.') from None
54
+
55
+
56
class AxisOptions:
    """Context manager that stages xyz-grid axis options.

    Options appended to `self.options` inside the `with` block are committed
    to the grid script's live option list only if the block exits cleanly;
    an exception discards them.
    """

    def __init__(self, AxisOption: type, axis_options: list):
        self.AxisOption = AxisOption   # xyz-grid's AxisOption class
        self.target = axis_options     # the grid script's live option list
        self.options = []              # staged options, flushed on clean exit

    def __enter__(self):
        del self.options[:]
        return self

    def __exit__(self, ex_type, ex_value, trace):
        # Commit staged options only when the with-block finished cleanly;
        # on exception they are simply dropped (exception propagates).
        if ex_type is None:
            self.target.extend(self.options)
            del self.options[:]

    def create(self, name: str, type_fn: Callable, action: Callable, choices: Union[List[str], None]):
        """Build (but do not register) one AxisOption instance."""
        if choices:
            # xyz-grid expects `choices` as a zero-arg callable.
            return self.AxisOption(name, type_fn, action, choices=lambda: choices)
        return self.AxisOption(name, type_fn, action)

    def add(self, axis_option):
        """Register an option directly on the grid's list, bypassing staging."""
        self.target.append(axis_option)
85
+
86
+
87
# One-shot guard so repeated Script construction does not re-register
# the same axis options with the xyz grid.
__init = False


def init_xyz(script: type, ext_name: str):
    """Register this extension's axis options with the X/Y/Z grid script.

    Scans the loaded script modules for xy_grid.py / xyz_grid.py, validates
    that the module exposes the expected `AxisOption` class and
    `axis_options` list, and appends our options via create_options().
    Safe to call multiple times; only the first call does any work.

    Args:
        script: this extension's Script class (used to find its args slice).
        ext_name: display prefix for option labels.
    """
    global __init

    if __init:
        return

    for data in scripts.scripts_data:
        name = os.path.basename(data.path)
        if name not in ('xy_grid.py', 'xyz_grid.py'):
            continue

        AxisOption = getattr(data.module, 'AxisOption', None)
        axis_options = getattr(data.module, 'axis_options', None)

        # The grid script's internals are not a stable API: only register
        # when the objects look like what we expect.
        if not isinstance(AxisOption, type):
            continue
        if not isinstance(axis_options, list):
            continue

        try:
            create_options(ext_name, script, AxisOption, axis_options)
        except Exception:
            # Best-effort: a grid-script API change must not break the whole
            # extension. (Was a bare `except:`, which would also swallow
            # SystemExit/KeyboardInterrupt.)
            pass

    __init = True
121
+
122
+
123
def create_options(ext_name: str, script: type, AxisOptionClass: type, axis_options: list):
    """Build and register every '[<ext_name>] ...' option on the xyz grid.

    Args:
        ext_name: display prefix for the option labels.
        script: this extension's Script class (locates its args slice).
        AxisOptionClass: the grid script's AxisOption type.
        axis_options: the grid script's live option list to append to.
    """
    with AxisOptions(AxisOptionClass, axis_options) as opts:
        # NOTE: both helpers previously used a mutable default `choices=[]`;
        # None is the safe sentinel and AxisOptions.create treats them alike.
        def define(param: str, index: int, type_fn: Callable, choices: Union[List[str], None] = None):
            # Single-value option: writes one slot of the script's args.
            def fn(p, x, xs):
                __set_value(p, script, index, x)

            return opts.create(f'[{ext_name}] {param}', type_fn, fn, choices)

        def define2(param: str, indices: List[int], type_fn: Callable, choices: Union[List[str], None] = None):
            # Multi-value option: one grid cell fans out into several slots.
            def fn(p, x, xs):
                __set_values(p, script, indices, x)

            return opts.create(f'[{ext_name}] {param}', type_fn, fn, choices)

        def wxh(s: str) -> List[float]:
            # Parse a '640x960'-style cell into [width, height].
            return [float(part) for part in s.split('x')]

        # Slot indices must match the Script UI component order.
        options = [
            define('Enabled', 0, to_bool, choices=['false', 'true']),
            define('Crop Left', 1, float),
            define('Crop Top', 2, float),
            define('Original Width', 3, float),
            define('Original Height', 4, float),
            define('Target Width', 5, float),
            define('Target Height', 6, float),
            define('Aesthetic Score', 7, float),
            define('Negative Aesthetic Score', 8, float),
            define2('Original WxH', [3, 4], wxh),
            define2('Target WxH', [5, 6], wxh),
            define('Extra Prompt', 9, str),
            define('Extra Negative Prompt', 10, str),
            define('Token Index', 11, int),
            define('Negative Token Index', 12, int),
            define('EOT Multiplier', 13, float),
            define('Negative EOT Multiplier', 14, float),
        ]

        for opt in options:
            opts.add(opt)