Spaces:
Runtime error
Runtime error
bugfix
Browse files
app.py
CHANGED
|
@@ -365,31 +365,40 @@ def run_lora(prompt_bg, character_prompts_json, character_positions_json, lora_s
|
|
| 365 |
with calculateDuration("Set random seed"):
|
| 366 |
seed = random.randint(0, MAX_SEED)
|
| 367 |
|
| 368 |
-
# ็ผ็ ๆ็คบ่ฏ
|
| 369 |
with calculateDuration("Encoding prompts"):
|
| 370 |
# ็ผ็ ่ๆฏๆ็คบ่ฏ
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 371 |
bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
|
| 372 |
-
bg_prompt_embeds = pipe.text_encoder_2(bg_text_input.input_ids.to(device))[0]
|
| 373 |
bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
|
| 374 |
|
| 375 |
# ็ผ็ ่ง่ฒๆ็คบ่ฏ
|
| 376 |
character_prompt_embeds = []
|
| 377 |
character_pooled_embeds = []
|
| 378 |
for prompt in character_prompts:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
|
| 380 |
-
char_prompt_embeds = pipe.text_encoder_2(char_text_input.input_ids.to(device))[0]
|
| 381 |
char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
|
|
|
|
| 382 |
character_prompt_embeds.append(char_prompt_embeds)
|
| 383 |
character_pooled_embeds.append(char_pooled_embeds)
|
| 384 |
|
| 385 |
# ็ผ็ ไบๅจ็ป่ๆ็คบ่ฏ
|
|
|
|
|
|
|
|
|
|
| 386 |
details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
|
| 387 |
-
details_prompt_embeds = pipe.text_encoder_2(details_text_input.input_ids.to(device))[0]
|
| 388 |
details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
|
| 389 |
|
| 390 |
# ๅๅนถ่ๆฏๅไบๅจ็ป่็ๅตๅ
ฅ
|
| 391 |
prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
|
| 392 |
-
pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=1)
|
| 393 |
|
| 394 |
# ่งฃๆ่ง่ฒไฝ็ฝฎ
|
| 395 |
character_infos = []
|
|
|
|
| 365 |
with calculateDuration("Set random seed"):
|
| 366 |
seed = random.randint(0, MAX_SEED)
|
| 367 |
|
|
|
|
| 368 |
with calculateDuration("Encoding prompts"):
|
| 369 |
# ็ผ็ ่ๆฏๆ็คบ่ฏ
|
| 370 |
+
# ไฝฟ็จ tokenizer_2 ๅ text_encoder_2
|
| 371 |
+
bg_text_input_2 = pipe.tokenizer_2(prompt_bg, return_tensors="pt").to(device)
|
| 372 |
+
bg_prompt_embeds = pipe.text_encoder_2(bg_text_input_2.input_ids.to(device))[0]
|
| 373 |
+
|
| 374 |
+
# ไฝฟ็จ tokenizer ๅ text_encoder
|
| 375 |
bg_text_input = pipe.tokenizer(prompt_bg, return_tensors="pt").to(device)
|
|
|
|
| 376 |
bg_pooled_embeds = pipe.text_encoder(bg_text_input.input_ids.to(device)).pooler_output
|
| 377 |
|
| 378 |
# ็ผ็ ่ง่ฒๆ็คบ่ฏ
|
| 379 |
character_prompt_embeds = []
|
| 380 |
character_pooled_embeds = []
|
| 381 |
for prompt in character_prompts:
|
| 382 |
+
# ไฝฟ็จ tokenizer_2 ๅ text_encoder_2
|
| 383 |
+
char_text_input_2 = pipe.tokenizer_2(prompt, return_tensors="pt").to(device)
|
| 384 |
+
char_prompt_embeds = pipe.text_encoder_2(char_text_input_2.input_ids.to(device))[0]
|
| 385 |
+
# ไฝฟ็จ tokenizer ๅ text_encoder
|
| 386 |
char_text_input = pipe.tokenizer(prompt, return_tensors="pt").to(device)
|
|
|
|
| 387 |
char_pooled_embeds = pipe.text_encoder(char_text_input.input_ids.to(device)).pooler_output
|
| 388 |
+
|
| 389 |
character_prompt_embeds.append(char_prompt_embeds)
|
| 390 |
character_pooled_embeds.append(char_pooled_embeds)
|
| 391 |
|
| 392 |
# ็ผ็ ไบๅจ็ป่ๆ็คบ่ฏ
|
| 393 |
+
details_text_input_2 = pipe.tokenizer_2(prompt_details, return_tensors="pt").to(device)
|
| 394 |
+
details_prompt_embeds = pipe.text_encoder_2(details_text_input_2.input_ids.to(device))[0]
|
| 395 |
+
|
| 396 |
details_text_input = pipe.tokenizer(prompt_details, return_tensors="pt").to(device)
|
|
|
|
| 397 |
details_pooled_embeds = pipe.text_encoder(details_text_input.input_ids.to(device)).pooler_output
|
| 398 |
|
| 399 |
# ๅๅนถ่ๆฏๅไบๅจ็ป่็ๅตๅ
ฅ
|
| 400 |
prompt_embeds = torch.cat([bg_prompt_embeds, details_prompt_embeds], dim=1)
|
| 401 |
+
pooled_prompt_embeds = torch.cat([bg_pooled_embeds, details_pooled_embeds], dim=-1)
|
| 402 |
|
| 403 |
# ่งฃๆ่ง่ฒไฝ็ฝฎ
|
| 404 |
character_infos = []
|