#!/usr/bin/env python3
# Source: test / benchmark_load_lora.py
# Author: varad-simpli
# Commit: 9be23cd ("Upload benchmark_load_lora.py", verified)
import argparse
import time
from typing import Optional

import torch
from diffusers import FluxPipeline
def benchmark_load_lora(
    base_model: str,
    lora_source: str,
    weight_name: Optional[str] = None,
    adapter_name: Optional[str] = None,
    dtype: torch.dtype = torch.bfloat16,
    runs: int = 3,
):
    """Time ``FluxPipeline.load_lora_weights`` over several runs.

    The base pipeline is loaded once and moved to CUDA when available
    (otherwise CPU). The LoRA is then loaded ``runs`` times. Between runs
    it is unloaded so that each run loads onto the base weights again.
    The LoRA from the final run is left loaded.

    Args:
        base_model: Repo id or local path passed to
            ``FluxPipeline.from_pretrained``.
        lora_source: LoRA repo id, local folder, or file path passed to
            ``load_lora_weights``.
        weight_name: Optional weight file name inside ``lora_source``.
            Forwarded to ``load_lora_weights`` only when given.
        adapter_name: Name to register the adapter under. Defaults to
            ``"lora"``.
        dtype: Torch dtype used when loading the base model.
        runs: Number of timed ``load_lora_weights`` calls. Must be >= 1.

    Returns:
        list[float]: Wall-clock seconds for each ``load_lora_weights`` run.

    Raises:
        ValueError: If ``runs`` is less than 1.
    """
    if runs < 1:
        # Fail before the slow base-model load. Otherwise there would be no
        # timing to report at the end.
        raise ValueError(f"runs must be >= 1, got {runs}")
    if adapter_name is None:
        adapter_name = "lora"

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _sync() -> None:
        # Wait for queued CUDA work so the timings include it.
        # torch.cuda.synchronize needs CUDA, so this is a no-op on CPU.
        if device.type == "cuda":
            torch.cuda.synchronize(device)

    print(f"Benchmarking on device {device}, torch.cuda.device_count()={torch.cuda.device_count()}.")

    print(f"1/4. Loading base model {base_model} …")
    t0 = time.perf_counter()
    pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=dtype, use_safetensors=True)
    base_load_s = time.perf_counter() - t0
    print(f" Base model loaded in {base_load_s:.3f} s")

    print(f"2/4. Moving pipeline to {device.type} …")
    t1 = time.perf_counter()
    pipe = pipe.to(device)
    _sync()
    move_s = time.perf_counter() - t1
    print(f" to('{device.type}') took {move_s:.3f} s")

    # Only forward weight_name when the caller supplied one, so the default
    # call matches a plain load_lora_weights(source, adapter_name=...) call.
    load_kwargs = {"adapter_name": adapter_name}
    if weight_name is not None:
        load_kwargs["weight_name"] = weight_name

    # Timed LoRA load loop.
    durations = []
    for i in range(runs):
        print(f"3.{i+1}/4. Running load_lora_weights (run {i+1}/{runs}) …")
        start = time.perf_counter()
        pipe.load_lora_weights(lora_source, **load_kwargs)
        _sync()
        duration = time.perf_counter() - start
        durations.append(duration)
        print(f" → run {i+1}: load_lora_weights took {duration:.3f} s")
        if i < runs - 1:
            # Unload (restoring overwritten params) so the next run loads
            # onto the base weights again.
            print(" Unloading LoRA …")
            pipe.unload_lora_weights(reset_to_overwritten_params=True)
            _sync()

    print("All runs complete.")
    avg = sum(durations) / len(durations)
    print(f"☆ Final run time: {durations[-1]:.3f} s")
    print(f"― average over {runs} runs ≈ {avg:.3f} s")
    return durations
if __name__ == "__main__":
    # CLI names for the torch dtypes the base model can be loaded in.
    _DTYPES = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }

    parser = argparse.ArgumentParser(
        description="Benchmark Flux.1‑dev load_lora_weights timing"
    )
    parser.add_argument("--model", default="black-forest-labs/FLUX.1-dev")
    parser.add_argument("--lora", required=True, help="LoRA adapter repo ID or local folder / file path")
    parser.add_argument(
        "--weight-name",
        default=None,
        help="Weight file name inside the LoRA repo/folder (optional)",
    )
    parser.add_argument(
        "--adapter-name",
        default=None,
        help="Adapter name to register the LoRA under (optional)",
    )
    parser.add_argument(
        "--dtype",
        choices=sorted(_DTYPES),
        default="bfloat16",
        help="Torch dtype used to load the base model",
    )
    parser.add_argument("--runs", type=int, default=3)
    args = parser.parse_args()

    # Reject a non-positive run count up front instead of after loading the model.
    if args.runs < 1:
        parser.error("--runs must be >= 1")

    benchmark_load_lora(
        base_model=args.model,
        lora_source=args.lora,
        weight_name=args.weight_name,
        adapter_name=args.adapter_name,
        dtype=_DTYPES[args.dtype],
        runs=args.runs,
    )