pure load lora

app.py CHANGED
@@ -14,9 +14,12 @@ from io import BytesIO
 # from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.attention_processor import AttnProcessor2_0
 import torch.nn.functional as F
-
+import time
+import boto3
+from io import BytesIO
 import re
 import json
+
 # Log in to Hugging Face Hub
 HF_TOKEN = os.environ.get("HF_TOKEN")
 login(token=HF_TOKEN)
@@ -49,262 +52,16 @@ class calculateDuration:
         else:
             print(f"Elapsed time: {self.elapsed_time:.6f} seconds")

-# Define the mappings for locations, offsets and areas
-valid_locations = { # x, y in 90*90
-    'in the center': (45, 45),
-    'on the left': (15, 45),
-    'on the right': (75, 45),
-    'on the top': (45, 15),
-    'on the bottom': (45, 75),
-    'on the top-left': (15, 15),
-    'on the top-right': (75, 15),
-    'on the bottom-left': (15, 75),
-    'on the bottom-right': (75, 75)
-}
-
-valid_offsets = { # x, y in 90*90
-    'no offset': (0, 0),
-    'slightly to the left': (-10, 0),
-    'slightly to the right': (10, 0),
-    'slightly to the upper': (0, -10),
-    'slightly to the lower': (0, 10),
-    'slightly to the upper-left': (-10, -10),
-    'slightly to the upper-right': (10, -10),
-    'slightly to the lower-left': (-10, 10),
-    'slightly to the lower-right': (10, 10)
-}
-
-valid_areas = { # w, h in 90*90
-    "a small square area": (50, 50),
-    "a small vertical area": (40, 60),
-    "a small horizontal area": (60, 40),
-    "a medium-sized square area": (60, 60),
-    "a medium-sized vertical area": (50, 80),
-    "a medium-sized horizontal area": (80, 50),
-    "a large square area": (70, 70),
-    "a large vertical area": (60, 90),
-    "a large horizontal area": (90, 60)
-}
-
-# Function to parse a character's position
-def parse_character_position(character_position):
-    # Define the regex patterns
-    location_pattern = '|'.join(re.escape(key) for key in valid_locations.keys())
-    offset_pattern = '|'.join(re.escape(key) for key in valid_offsets.keys())
-    area_pattern = '|'.join(re.escape(key) for key in valid_areas.keys())
-
-    # Extract the location
-    location_match = re.search(location_pattern, character_position, re.IGNORECASE)
-    location = location_match.group(0) if location_match else 'in the center'
-
-    # Extract the offset
-    offset_match = re.search(offset_pattern, character_position, re.IGNORECASE)
-    offset = offset_match.group(0) if offset_match else 'no offset'
-
-    # Extract the area
-    area_match = re.search(area_pattern, character_position, re.IGNORECASE)
-    area = area_match.group(0) if area_match else 'a medium-sized square area'
-
-    return {
-        'location': location,
-        'offset': offset,
-        'area': area
-    }
-
-# Function to create the mask
-def create_attention_mask(image_width, image_height, location, offset, area):
-    # Images are usually scaled to 90x90 during generation, so define a base size first
-    base_size = 90
-
-    # Get the location coordinates
-    loc_x, loc_y = valid_locations.get(location, (45, 45))
-    # Get the offset
-    offset_x, offset_y = valid_offsets.get(offset, (0, 0))
-    # Get the area size
-    area_width, area_height = valid_areas.get(area, (60, 60))
-
-    # Compute the final position
-    final_x = loc_x + offset_x
-    final_y = loc_y + offset_y
-
-    # Map base-grid coordinates and sizes to the actual image dimensions
-    scale_x = image_width / base_size
-    scale_y = image_height / base_size
-
-    center_x = final_x * scale_x
-    center_y = final_y * scale_y
-    width = area_width * scale_x
-    height = area_height * scale_y
-
-    # Compute the top-left and bottom-right coordinates
-    x_start = int(max(center_x - width / 2, 0))
-    y_start = int(max(center_y - height / 2, 0))
-    x_end = int(min(center_x + width / 2, image_width))
-    y_end = int(min(center_y + height / 2, image_height))
-
-    # Create the mask
-    mask = torch.zeros((image_height, image_width), dtype=torch.float32, device="cuda")
-    mask[y_start:y_end, x_start:x_end] = 1.0
-
-    # Flatten to one dimension
-    mask_flat = mask.view(-1)  # shape: (image_height * image_width,)
-    return mask_flat
-
-# Custom attention processor
-
-class CustomCrossAttentionProcessor(AttnProcessor2_0):
-    def __init__(self, masks, adapter_names):
-        super().__init__()
-        self.masks = masks  # list of per-character masks (shape: [key_length])
-        self.adapter_names = adapter_names  # list of per-character LoRA adapter names
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-        **kwargs,
-    ):
-        """
-        A custom attention processor that applies the character masks during the attention computation.
-
-        Args:
-            attn: the attention module instance.
-            hidden_states: the input hidden states (query).
-            encoder_hidden_states: the encoder hidden states (key/value).
-            attention_mask: the attention mask.
-            temb: the time embedding (may be unused).
-            **kwargs: additional arguments.
-
-        Returns:
-            The processed hidden states.
-        """
-        # Get the current adapter_name
-        adapter_name = getattr(attn, 'adapter_name', None)
-        if adapter_name is None or adapter_name not in self.adapter_names:
-            # If there is no adapter_name, or it is not in our list, fall back to the parent __call__
-            return super().__call__(attn, hidden_states, encoder_hidden_states, attention_mask, temb, **kwargs)
-
-        # Look up the index of this adapter_name
-        idx = self.adapter_names.index(adapter_name)
-        mask = self.masks[idx]  # the corresponding mask (shape: [key_length])
-
-        # Below is the AttnProcessor2_0 implementation, with our custom mask logic inserted at the appropriate point
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-        else:
-            batch_size, sequence_length, _ = hidden_states.shape
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        else:
-            # If encoder_hidden_states is given, read off its shape
-            encoder_batch_size, key_length, _ = encoder_hidden_states.shape
-
-        if attention_mask is not None:
-            # Prepare the attention_mask if needed
-            attention_mask = attn.prepare_attention_mask(attention_mask, key_length, batch_size)
-            # attention_mask should have shape (batch_size, attn.heads, query_length, key_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-        else:
-            # If there is no attention_mask, create an all-zero one
-            attention_mask = torch.zeros(
-                batch_size, attn.heads, 1, key_length, device=hidden_states.device, dtype=hidden_states.dtype
-            )
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        if attn.norm_q is not None:
-            query = attn.norm_q(query)
-        if attn.norm_k is not None:
-            key = attn.norm_k(key)
-
-        # Compute the raw attention scores
-        # We need to apply the mask before the attention scores are computed;
-        # since PyTorch's scaled_dot_product_attention accepts an attention_mask argument, we adapt our mask to it
-
-        # Build the custom attention_mask
-        # mask has shape [key_length]; reshape it to (batch_size, attn.heads, 1, key_length)
-        custom_attention_mask = mask.view(1, 1, 1, -1).to(hidden_states.device, dtype=hidden_states.dtype)
-        # Keep valid positions at 0 and set masked positions to -1e9 (for float16, use -65504)
-        mask_value = -65504.0 if hidden_states.dtype == torch.float16 else -1e9
-        custom_attention_mask = (1.0 - custom_attention_mask) * mask_value  # valid positions 0, invalid positions -1e9
-
-        # Add the custom mask to the attention_mask
-        attention_mask = attention_mask + custom_attention_mask
-
-        # Compute the attention
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-# Function to replace the attention processors
-def replace_attention_processors(pipe, masks, adapter_names):
-    custom_processor = CustomCrossAttentionProcessor(masks, adapter_names)
-    for name, module in pipe.transformer.named_modules():
-        if hasattr(module, 'attn'):
-            module.attn.adapter_name = getattr(module, 'adapter_name', None)
-            module.attn.processor = custom_processor
-        if hasattr(module, 'cross_attn'):
-            module.cross_attn.adapter_name = getattr(module, 'adapter_name', None)
-            module.cross_attn.processor = custom_processor
-
 # Function to generate the image
-
-def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, seed, cfg_scale, width, height, progress):
+@spaces.GPU
+@torch.inference_mode()
+def generate_image(prompt, steps, seed, cfg_scale, width, height, progress):
     pipe.to(device)
     generator = torch.Generator(device=device).manual_seed(seed)
-
     with calculateDuration("Generating image"):
         # Generate image
         generated_image = pipe(
-            prompt_embeds=prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt=prompt,
             num_inference_steps=steps,
             guidance_scale=cfg_scale,
             width=width,
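The block deleted above implemented region-constrained cross-attention: a character's position phrase is looked up in the three tables, rasterized as a binary rectangle on the 90x90 base grid, flattened, and folded into the attention bias so a character's tokens only attend inside its rectangle. A self-contained sketch of that mechanism, with illustrative shapes and values rather than anything taken from the app:

import torch
import torch.nn.functional as F

# Values from the tables above: 'on the left' + 'slightly to the upper'
# + 'a small square area' on the 90x90 base grid.
loc_x, loc_y = 15, 45
off_x, off_y = 0, -10
area_w, area_h = 50, 50

base = 90
mask = torch.zeros(base, base)
x0 = max(loc_x + off_x - area_w // 2, 0)
y0 = max(loc_y + off_y - area_h // 2, 0)
mask[y0:y0 + area_h, x0:x0 + area_w] = 1.0   # rasterize the region
mask_flat = mask.view(-1)                    # (8100,), like create_attention_mask's output

# Turn the binary mask into an additive bias: kept positions contribute 0,
# masked positions a large negative number that softmax maps to ~0.
bias = (1.0 - mask_flat).view(1, 1, 1, -1) * -1e9

# Hypothetical shapes: 8 heads, 77 query tokens, key length = flattened mask length.
q = torch.randn(1, 8, 77, 64)
k = torch.randn(1, 8, base * base, 64)
v = torch.randn(1, 8, base * base, 64)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 8, 77, 64])

Positions where the mask is 1 keep their raw scores; everywhere else the negative bias drives the post-softmax weight to roughly zero, which is exactly what the deleted processor added on top of AttnProcessor2_0.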
@@ -315,111 +72,67 @@ def generate_image_with_embeddings(prompt_embeds, pooled_prompt_embeds, steps, s
     progress(99, "Generate success!")
     return generated_image

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+def upload_image_to_r2(image, account_id, access_key, secret_key, bucket_name):
+    print("upload_image_to_r2", account_id, access_key, secret_key, bucket_name)
+    connectionUrl = f"https://{account_id}.r2.cloudflarestorage.com"
+
+    s3 = boto3.client(
+        's3',
+        endpoint_url=connectionUrl,
+        region_name='auto',
+        aws_access_key_id=access_key,
+        aws_secret_access_key=secret_key
+    )
+
+    current_time = datetime.now().strftime("%Y/%m/%d/%H%M%S")
+    image_file = f"generated_images/{current_time}_{random.randint(0, MAX_SEED)}.png"
+    buffer = BytesIO()
+    image.save(buffer, "PNG")
+    buffer.seek(0)
+    s3.upload_fileobj(buffer, bucket_name, image_file)
+    print("upload finish", image_file)

-
-    with calculateDuration("Loading LoRA weights"):
-        pipe.unload_lora_weights()
-        adapter_names = []
-        for lora_info in lora_strings:
-            lora_repo = lora_info.get("repo")
-            weights = lora_info.get("weights")
-            adapter_name = lora_info.get("adapter_name")
-            if lora_repo and weights and adapter_name:
-                # Call pipe.load_lora_weights() to load the weights
-                pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
-                adapter_names.append(adapter_name)
-                # Set adapter_name as an attribute on the model
-                setattr(pipe.transformer, 'adapter_name', adapter_name)
+    return image_file

-
-            raise ValueError("Invalid LoRA string format. Each item must have 'repo', 'weights', and 'adapter_name' keys.")
-        adapter_weights = [lora_scale] * len(adapter_names)
-        # Call pipeline.set_adapters to set the adapters and their weights
-        pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)
+def run_lora(prompt, lora_strings_json, cfg_scale, steps, randomize_seed, seed, width, height, lora_scale, upload_to_r2, account_id, access_key, secret_key, bucket, progress=gr.Progress(track_tqdm=True)):

-    #
-    if
-
+    # Load LoRA weights
+    if lora_strings_json:
+        try:
+            lora_strings_json = json.loads(lora_strings_json)
+        except:
+            lora_strings_json = None
+    if lora_strings_json:
+        with calculateDuration("Loading LoRA weights"):
+            pipe.unload_lora_weights()
+            adapter_names = []
+            for lora_info in lora_strings:
+                lora_repo = lora_info.get("repo")
+                weights = lora_info.get("weights")
+                adapter_name = lora_info.get("adapter_name")
+                if lora_repo and weights and adapter_name:
+                    # Load the LoRA weights
+                    pipe.load_lora_weights(lora_repo, weight_name=weights, adapter_name=adapter_name)
+                    adapter_names.append(adapter_name)
+            adapter_weights = [lora_scale] * len(adapter_names)
+            pipe.set_adapters(adapter_names, adapter_weights=adapter_weights)

     # Set random seed for reproducibility
     if randomize_seed:
         with calculateDuration("Set random seed"):
             seed = random.randint(0, MAX_SEED)

-    with calculateDuration("Encoding prompts"):
-        # Encode the background prompt
-        # Use tokenizer_2 and text_encoder_2
-        bg_text_input_2 = pipe.tokenizer_2(prompt_bg, return_tensors="pt").to(device)
-        bg_prompt_embeds = pipe.text_encoder_2(bg_text_input_2.input_ids.to(device))[0]
-
-        # Use tokenizer and text_encoder
-        bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
-        bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
-
-        # Encode the character prompts
-        character_prompt_embeds = []
-        character_pooled_embeds = []
-        for prompt in character_prompts:
-            # Use tokenizer_2 and text_encoder_2
-            char_text_input_2 = pipe.tokenizer_2(prompt, return_tensors="pt").to(device)
-            char_prompt_embeds = pipe.text_encoder_2(char_text_input_2.input_ids.to(device))[0]
-            # Use tokenizer and text_encoder
-            char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
-            char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
-
-            character_prompt_embeds.append(char_prompt_embeds)
-            character_pooled_embeds.append(char_pooled_embeds)
-
-        # Encode the interaction-details prompt
-        details_text_input_2 = pipe.tokenizer_2(prompt_details, return_tensors="pt").to(device)
-        details_prompt_embeds = pipe.text_encoder_2(details_text_input_2.input_ids.to(device))[0]
-
-        details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
-        details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
-
-        # Merge the background and interaction-details embeddings
-        prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
-        pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=-1)
-
-        # Parse the character positions
-        character_infos = []
-        for position_str in character_positions:
-            info = parse_character_position(position_str)
-            character_infos.append(info)
-
-        # Create the character masks
-        masks = []
-        for info in character_infos:
-            mask = create_attention_mask(width, height, info['location'], info['offset'], info['area'])
-            masks.append(mask)
-
-        # Replace the attention processors
-        replace_attention_processors(pipe, masks, adapter_names)
-
     # Generate image
-    final_image =
+    final_image = generate_image(prompt, steps, seed, cfg_scale, width, height, progress)

-
-
+    if final_image:
+        if upload_to_r2:
+            with calculateDuration("Upload image"):
+                url = upload_image_to_r2(final_image, account_id, access_key, secret_key, bucket)
+            result = {"status": "success", "message": "upload image success", "url": url}
+        else:
+            result = {"status": "success", "message": "Image generated but not uploaded"}

     progress(100, "Completed!")

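One note on the added run_lora: the loop iterates lora_strings, a name this hunk never defines; the parsed JSON is bound to lora_strings_json, so the LoRA branch would raise a NameError as committed. A minimal corrected sketch of the intended load-and-activate sequence, assuming a diffusers pipeline with multi-adapter LoRA support:

import json

def load_loras(pipe, lora_strings_json: str, lora_scale: float):
    # Hedged sketch: same logic as the committed hunk, with the presumed
    # lora_strings -> parsed-JSON fix applied.
    try:
        lora_list = json.loads(lora_strings_json)  # [{"repo", "weights", "adapter_name"}, ...]
    except (json.JSONDecodeError, TypeError):
        return []
    pipe.unload_lora_weights()
    adapter_names = []
    for lora_info in lora_list:                    # was `lora_strings` in the commit
        repo = lora_info.get("repo")
        weights = lora_info.get("weights")
        adapter_name = lora_info.get("adapter_name")
        if repo and weights and adapter_name:
            pipe.load_lora_weights(repo, weight_name=weights, adapter_name=adapter_name)
            adapter_names.append(adapter_name)
    if adapter_names:
        pipe.set_adapters(adapter_names, adapter_weights=[lora_scale] * len(adapter_names))
    return adapter_names

set_adapters activates every loaded adapter with the same lora_scale here; a per-adapter adapter_weights list would weight them individually.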
@@ -439,11 +152,9 @@ with gr.Blocks(css=css) as demo:

     with gr.Column():

-
-        character_prompts = gr.Text(label="Character Prompts (JSON List)", placeholder='["Character 1 prompt", "Character 2 prompt"]', lines=5)
-        character_positions = gr.Text(label="Character Positions (JSON List)", placeholder='["Character 1 position", "Character 2 position"]', lines=5)
+        prompt = gr.Text(label="Prompt", placeholder="Enter prompt", lines=2)
         lora_strings_json = gr.Text(label="LoRA Strings (JSON List)", placeholder='[{"repo": "lora_repo1", "weights": "weights1", "adapter_name": "adapter_name1"}, {"repo": "lora_repo2", "weights": "weights2", "adapter_name": "adapter_name2"}]', lines=5)
-
+
         run_button = gr.Button("Run", scale=0)

         with gr.Accordion("Advanced Settings", open=False):
@@ -474,11 +185,8 @@ with gr.Blocks(css=css) as demo:
     json_text = gr.Text(label="Result JSON")

     inputs = [
-
-        character_prompts,
-        character_positions,
+        prompt,
         lora_strings_json,
-        prompt_details,
         cfg_scale,
         steps,
         randomize_seed,
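The inputs list is consumed positionally, so its order must match run_lora's parameters. The click binding itself falls outside the shown hunks; a hypothetical wiring consistent with the names above:

run_button.click(
    fn=run_lora,
    inputs=inputs,        # [prompt, lora_strings_json, cfg_scale, steps, randomize_seed, ...]
    outputs=[json_text],  # assumes run_lora returns the result dict for the JSON textbox
)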