Spaces: Runtime error

Use LORA

app.py CHANGED
@@ -425,6 +425,53 @@ def worker(input_image, image_position, prompts, n_prompt, seed, resolution, tot
 
     image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
 
+    # Load transformer model
+    if model_changed:
+        stream.output_queue.push(("progress", (None, "", make_progress_bar_html(0, "Loading transformer ..."))))
+
+        transformer = None
+        time.sleep(1.0)  # wait for the previous model to be unloaded
+        torch.cuda.empty_cache()
+        gc.collect()
+
+        previous_lora_file = lora_file
+        previous_lora_multiplier = lora_multiplier
+        previous_fp8_optimization = fp8_optimization
+
+        transformer = load_transfomer()  # bfloat16, on cpu
+
+        if lora_file is not None or fp8_optimization:
+            state_dict = transformer.state_dict()
+
+            # LoRA should be merged before fp8 optimization
+            if lora_file is not None:
+                # TODO It would be better to merge the LoRA into the state dict before creating the transformer instance.
+                # Use from_config() instead of from_pretrained to make the instance without loading.
+
+                print(f"Merging LoRA file {os.path.basename(lora_file)} ...")
+                state_dict = merge_lora_to_state_dict(state_dict, lora_file, lora_multiplier, device=gpu)
+                gc.collect()
+
+            if fp8_optimization:
+                TARGET_KEYS = ["transformer_blocks", "single_transformer_blocks"]
+                EXCLUDE_KEYS = ["norm"]  # Exclude norm layers (e.g., LayerNorm, RMSNorm) from FP8
+
+                # inplace optimization
+                print("Optimizing for fp8")
+                state_dict = optimize_state_dict_with_fp8(state_dict, gpu, TARGET_KEYS, EXCLUDE_KEYS, move_to_device=False)
+
+                # apply monkey patching
+                apply_fp8_monkey_patch(transformer, state_dict, use_scaled_mm=False)
+                gc.collect()
+
+            info = transformer.load_state_dict(state_dict, strict=True, assign=True)
+            print(f"LoRA and/or fp8 optimization applied: {info}")
+
+        if not high_vram:
+            DynamicSwapInstaller.install_model(transformer, device=gpu)
+        else:
+            transformer.to(gpu)
+
     # Sampling
 
     stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
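
Note on the LoRA step: `merge_lora_to_state_dict` folds the LoRA delta into the base weights before any FP8 conversion, so the low-rank update is applied in full precision rather than on top of already-quantized weights. A minimal sketch of how such a merge can work, assuming diffusers-style `lora_A`/`lora_B`/`alpha` key naming; the repo's actual helper, key layout, and scaling convention may differ:

# Hypothetical sketch of merging LoRA weights into a base state dict;
# not the actual merge_lora_to_state_dict implementation.
import torch

def merge_lora_sketch(state_dict, lora_sd, multiplier, device="cuda"):
    for key, down in lora_sd.items():
        if not key.endswith(".lora_A.weight"):  # assumed key naming
            continue
        base_key = key.replace(".lora_A.weight", ".weight")
        up = lora_sd[key.replace("lora_A", "lora_B")]
        rank = down.shape[0]
        alpha = float(lora_sd.get(key.replace("lora_A.weight", "alpha"), rank))
        scale = multiplier * alpha / rank

        w = state_dict[base_key].to(device, torch.float32)
        delta = up.to(device, torch.float32) @ down.to(device, torch.float32)
        # W' = W + multiplier * (alpha / rank) * (B @ A), cast back to the base dtype
        state_dict[base_key] = (w + scale * delta).to(state_dict[base_key].dtype).cpu()
    return state_dict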
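
Note on the FP8 step: `optimize_state_dict_with_fp8` plus `apply_fp8_monkey_patch` match a common pattern: quantize the targeted weights to float8 with a stored scale, then patch each affected module's forward to dequantize at matmul time (when `use_scaled_mm` is enabled, presumably routing through `torch._scaled_mm` instead). A rough per-tensor sketch under those assumptions; the real helpers' scale granularity and buffer names are not shown here:

# Hypothetical per-tensor FP8 quantization sketch; not the actual
# optimize_state_dict_with_fp8 / apply_fp8_monkey_patch implementation.
import torch
import torch.nn.functional as F

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max

def quantize_fp8_sketch(weight: torch.Tensor):
    # Per-tensor scale chosen so the largest magnitude maps onto the FP8 range.
    scale = weight.abs().max().float().clamp(min=1e-8) / FP8_MAX
    q = (weight.float() / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale

def fp8_linear_forward_sketch(module, x):
    # Dequantize on the fly: FP8 storage, bfloat16/float16 compute.
    # "scale_weight" is an assumed buffer name holding the stored scale.
    w = module.weight.to(x.dtype) * module.scale_weight.to(x.dtype)
    return F.linear(x, w, module.bias)

Keeping norm layers out of `TARGET_KEYS` (via `EXCLUDE_KEYS`) is the usual precaution: LayerNorm/RMSNorm weights are small but numerically sensitive, so they gain little from FP8 and degrade quality disproportionately.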